scy639 commited on
Commit
a02fb90
·
verified ·
1 Parent(s): 62e20ca

Update hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +11 -11
hf_model.py CHANGED
@@ -131,32 +131,32 @@ class UniBioTransferModel(LatentDiffusion, PyTorchModelHubMixin):
131
  arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
132
  face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"
133
 
134
- def _download_file(repo, filename, local_path):
135
- local_path = Path(local_path)
136
- local_path.parent.mkdir(parents=True, exist_ok=True)
137
  print(f"Downloading {filename} from {repo}...")
138
  token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
139
  hf_hub_download(
140
  repo_id=repo,
141
  filename=filename,
142
- local_dir=str(local_path.parent),
143
  local_dir_use_symlinks=False,
144
  token=token,
145
  )
146
-
147
  if not ckpt_path.exists():
148
- _download_file(repo_id, "checkpoints/pretrained.ckpt", ckpt_path)
149
  if not json_path.exists():
150
- _download_file(repo_id, "checkpoints/pretrained.json", json_path)
151
-
152
  if download_sd14 and not sd14_path.exists():
153
- _download_file(SD14_REPO, SD14_FILENAME, sd14_path)
154
 
155
  if download_deps:
156
  if not arcface_path.exists():
157
- _download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", arcface_path)
158
  if not face_parsing_path.exists():
159
- _download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", face_parsing_path)
160
 
161
  seed_everything(42)
162
 
 
131
  arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
132
  face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"
133
 
134
+ def _download_file(repo, filename, target_root):
135
+ target_root = Path(target_root)
136
+ target_root.mkdir(parents=True, exist_ok=True)
137
  print(f"Downloading {filename} from {repo}...")
138
  token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
139
  hf_hub_download(
140
  repo_id=repo,
141
  filename=filename,
142
+ local_dir=str(target_root),
143
  local_dir_use_symlinks=False,
144
  token=token,
145
  )
146
+
147
  if not ckpt_path.exists():
148
+ _download_file(repo_id, "checkpoints/pretrained.ckpt", cache_dir)
149
  if not json_path.exists():
150
+ _download_file(repo_id, "checkpoints/pretrained.json", cache_dir)
151
+
152
  if download_sd14 and not sd14_path.exists():
153
+ _download_file(SD14_REPO, SD14_FILENAME, cache_dir)
154
 
155
  if download_deps:
156
  if not arcface_path.exists():
157
+ _download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", cache_dir)
158
  if not face_parsing_path.exists():
159
+ _download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", cache_dir)
160
 
161
  seed_everything(42)
162