File size: 2,026 Bytes
9b07278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Lightweight model setup utility for dwpose-editor.

Provides a CLI to pre-download required ONNX models into the local cache
directory used by the app (./models), without initializing ONNX sessions.
"""

import os
import sys
from typing import List, Optional

try:
    from huggingface_hub import hf_hub_download
except Exception:  # pragma: no cover
    # Optional dependency: any import-time failure disables downloading;
    # download_models() checks for None and reports it instead of crashing.
    hf_hub_download = None


# Hugging Face repo id that the model files are downloaded from.
DEFAULT_REPO_ID = "yzd-v/DWPose"
# Local directory handed to hf_hub_download as its cache_dir.
DEFAULT_CACHE_DIR = "./models"
# Model files fetched by download_models(), in download order.
REQUIRED_FILES: List[str] = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]


def ensure_dir(path: str) -> None:
    """Create directory *path*, including missing parents; no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)


def download_models(repo_id: str = DEFAULT_REPO_ID, cache_dir: str = DEFAULT_CACHE_DIR) -> int:
    """Download every file in REQUIRED_FILES from a Hugging Face repo into cache_dir.

    Each file is attempted independently: a failure is reported and the
    remaining files are still tried, so one bad file does not hide the
    status of the others.

    Args:
        repo_id: Hugging Face repository id to download from.
        cache_dir: Directory passed to ``hf_hub_download`` as ``cache_dir``;
            it is created first if it does not exist.

    Returns:
        0 if every file downloaded, 1 if at least one download failed, and
        2 if ``huggingface_hub`` is unavailable.
    """
    if hf_hub_download is None:
        print("[ERROR] huggingface-hub is not installed. Run: pip install huggingface-hub", file=sys.stderr)
        return 2

    ensure_dir(cache_dir)

    failed: List[str] = []
    for filename in REQUIRED_FILES:
        print(f"[SETUP] Downloading {filename} from {repo_id} -> {cache_dir} ...")
        try:
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        except Exception as e:
            # Best-effort: record the failure and continue with the remaining files.
            failed.append(filename)
            print(f"[SETUP] ERROR: failed to download {filename}: {e}", file=sys.stderr)
        else:
            print(f"[SETUP] OK: {local_path}")

    if failed:
        print(
            f"[SETUP] {len(failed)} of {len(REQUIRED_FILES)} file(s) failed: {', '.join(failed)}",
            file=sys.stderr,
        )
        return 1
    return 0


def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: python -m utils.model_setup [--repo REPO] [--cache-dir DIR]

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The status code returned by ``download_models`` (0 on success).
    """
    # Local import: argparse is only used by this entry point.
    import argparse

    parser = argparse.ArgumentParser(
        description="Pre-download DWPose models into a local cache directory (default: ./models)"
    )
    parser.add_argument(
        "--repo", default=DEFAULT_REPO_ID, help="Hugging Face repo id (default: %(default)s)"
    )
    parser.add_argument(
        "--cache-dir", default=DEFAULT_CACHE_DIR, help="Local cache directory (default: %(default)s)"
    )
    args = parser.parse_args(argv)

    return download_models(repo_id=args.repo, cache_dir=args.cache_dir)


if __name__ == "__main__":  # pragma: no cover
    # main() returns an int status; use it as the process exit code.
    exit_code = main()
    sys.exit(exit_code)