Rasta02 commited on
Commit
468071c
·
verified ·
1 Parent(s): 4864430

Upload backend/model_downloader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backend/model_downloader.py +88 -0
backend/model_downloader.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Automatic model downloader utility.
3
+ Downloads models from Hugging Face if they don't exist locally.
4
+ """
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from huggingface_hub import snapshot_download, hf_hub_download
9
+
10
+ # Hugging Face repository
11
+ HF_REPO = "Rasta02/dataku"
12
+
13
+ def ensure_models_exist(model_paths_dict):
14
+ """
15
+ Check if models exist, download from HF if missing.
16
+
17
+ Args:
18
+ model_paths_dict: Dictionary of model paths to check
19
+ e.g., {'lama': '/path/to/lama', 'sttn': '/path/to/sttn.pth'}
20
+ """
21
+ print("[Model Checker] Checking for required models...")
22
+
23
+ for model_name, model_path in model_paths_dict.items():
24
+ # Check if model exists
25
+ if os.path.exists(model_path):
26
+ if os.path.isdir(model_path):
27
+ # Check if directory has files
28
+ if len(os.listdir(model_path)) > 0:
29
+ print(f"[✓] {model_name} model found at {model_path}")
30
+ continue
31
+ else:
32
+ # File exists
33
+ print(f"[✓] {model_name} model found at {model_path}")
34
+ continue
35
+
36
+ # Model doesn't exist, download from HF
37
+ print(f"[!] {model_name} model not found. Downloading from Hugging Face...")
38
+ try:
39
+ # Determine the relative path from backend directory
40
+ backend_dir = os.path.dirname(os.path.abspath(__file__))
41
+ rel_path = os.path.relpath(model_path, backend_dir)
42
+
43
+ if os.path.isabs(model_path) or model_path.endswith('.pth') or model_path.endswith('.pt'):
44
+ # Download specific file
45
+ filename = os.path.basename(model_path)
46
+ repo_path = rel_path.replace(filename, '')
47
+
48
+ hf_hub_download(
49
+ repo_id=HF_REPO,
50
+ filename=os.path.join(repo_path, filename).replace('\\', '/'),
51
+ local_dir=backend_dir,
52
+ local_dir_use_symlinks=False
53
+ )
54
+ print(f"[✓] Downloaded {model_name} model successfully!")
55
+ else:
56
+ # Download entire directory
57
+ os.makedirs(model_path, exist_ok=True)
58
+ snapshot_download(
59
+ repo_id=HF_REPO,
60
+ allow_patterns=f"{rel_path}/*",
61
+ local_dir=backend_dir,
62
+ local_dir_use_symlinks=False
63
+ )
64
+ print(f"[✓] Downloaded {model_name} model directory successfully!")
65
+ except Exception as e:
66
+ print(f"[✗] Failed to download {model_name} model: {e}")
67
+ print(f"[!] Please manually download models from https://huggingface.co/{HF_REPO}")
68
+ sys.exit(1)
69
+
70
+ print("[Model Checker] All models are ready!")
71
+
72
+
73
+ def download_backend_models():
74
+ """Download all backend models if they don't exist"""
75
+ import config
76
+
77
+ model_paths = {
78
+ 'LAMA': config.LAMA_MODEL_PATH,
79
+ 'STTN': config.STTN_MODEL_PATH,
80
+ 'ProPainter': config.VIDEO_INPAINT_MODEL_PATH,
81
+ 'Detection': config.DET_MODEL_PATH,
82
+ }
83
+
84
+ ensure_models_exist(model_paths)
85
+
86
+
87
+ if __name__ == '__main__':
88
+ download_backend_models()