Spaces:
Running on Zero
Running on Zero
Update hf_model.py
Browse files- hf_model.py +11 -11
hf_model.py
CHANGED
|
@@ -131,32 +131,32 @@ class UniBioTransferModel(LatentDiffusion, PyTorchModelHubMixin):
|
|
| 131 |
arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
|
| 132 |
face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"
|
| 133 |
|
| 134 |
-
def _download_file(repo, filename,
|
| 135 |
-
|
| 136 |
-
|
| 137 |
print(f"Downloading {filename} from {repo}...")
|
| 138 |
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
| 139 |
hf_hub_download(
|
| 140 |
repo_id=repo,
|
| 141 |
filename=filename,
|
| 142 |
-
local_dir=str(
|
| 143 |
local_dir_use_symlinks=False,
|
| 144 |
token=token,
|
| 145 |
)
|
| 146 |
-
|
| 147 |
if not ckpt_path.exists():
|
| 148 |
-
_download_file(repo_id, "checkpoints/pretrained.ckpt",
|
| 149 |
if not json_path.exists():
|
| 150 |
-
_download_file(repo_id, "checkpoints/pretrained.json",
|
| 151 |
-
|
| 152 |
if download_sd14 and not sd14_path.exists():
|
| 153 |
-
_download_file(SD14_REPO, SD14_FILENAME,
|
| 154 |
|
| 155 |
if download_deps:
|
| 156 |
if not arcface_path.exists():
|
| 157 |
-
_download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth",
|
| 158 |
if not face_parsing_path.exists():
|
| 159 |
-
_download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth",
|
| 160 |
|
| 161 |
seed_everything(42)
|
| 162 |
|
|
|
|
| 131 |
arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
|
| 132 |
face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"
|
| 133 |
|
| 134 |
+
def _download_file(repo, filename, target_root):
|
| 135 |
+
target_root = Path(target_root)
|
| 136 |
+
target_root.mkdir(parents=True, exist_ok=True)
|
| 137 |
print(f"Downloading {filename} from {repo}...")
|
| 138 |
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
| 139 |
hf_hub_download(
|
| 140 |
repo_id=repo,
|
| 141 |
filename=filename,
|
| 142 |
+
local_dir=str(target_root),
|
| 143 |
local_dir_use_symlinks=False,
|
| 144 |
token=token,
|
| 145 |
)
|
| 146 |
+
|
| 147 |
if not ckpt_path.exists():
|
| 148 |
+
_download_file(repo_id, "checkpoints/pretrained.ckpt", cache_dir)
|
| 149 |
if not json_path.exists():
|
| 150 |
+
_download_file(repo_id, "checkpoints/pretrained.json", cache_dir)
|
| 151 |
+
|
| 152 |
if download_sd14 and not sd14_path.exists():
|
| 153 |
+
_download_file(SD14_REPO, SD14_FILENAME, cache_dir)
|
| 154 |
|
| 155 |
if download_deps:
|
| 156 |
if not arcface_path.exists():
|
| 157 |
+
_download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", cache_dir)
|
| 158 |
if not face_parsing_path.exists():
|
| 159 |
+
_download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", cache_dir)
|
| 160 |
|
| 161 |
seed_everything(42)
|
| 162 |
|