XiangpengYang committed on
Commit
3ea6bf9
·
1 Parent(s): 3b98a6f

Enhance model download from HF Hub

Browse files
Files changed (2) hide show
  1. app.py +9 -0
  2. videox_fun/models/wan_transformer3d.py +29 -0
app.py CHANGED
@@ -272,6 +272,15 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
272
  with gr.Column(variant="panel"):
273
  # Hide model selection
274
  diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
 
 
 
 
 
 
 
 
 
275
  base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
276
 
277
  # Set default LoRA alpha to 1.0 (matching inference.py)
 
272
  with gr.Column(variant="panel"):
273
  # Hide model selection
274
  diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
275
+
276
+ # Use snapshot download for the VideoCoF repo to get all weights (including safetensors)
277
+ try:
278
+ from huggingface_hub import snapshot_download
279
+ print("Downloading VideoCoF weights...")
280
+ snapshot_download(repo_id="XiangpengYang/VideoCoF")
281
+ except Exception as e:
282
+ print(f"Warning: Failed to pre-download VideoCoF weights: {e}")
283
+
284
  base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
285
 
286
  # Set default LoRA alpha to 1.0 (matching inference.py)
videox_fun/models/wan_transformer3d.py CHANGED
@@ -1161,6 +1161,35 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1161
  ):
1162
  if subfolder is not None:
1163
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1164
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1165
 
1166
  config_file = os.path.join(pretrained_model_path, 'config.json')
 
1161
  ):
1162
  if subfolder is not None:
1163
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1164
+
1165
+ # Handle HF hub download for model weights
1166
+ import os
1167
+ from huggingface_hub import hf_hub_download, snapshot_download
1168
+
1169
+ # Check if pretrained_model_path is a Hugging Face repo ID
1170
+ if not os.path.exists(pretrained_model_path):
1171
+ try:
1172
+ print(f"Downloading model from HF repo: {pretrained_model_path}...")
1173
+ # Download config.json
1174
+ config_path = hf_hub_download(repo_id=pretrained_model_path, filename="config.json")
1175
+ pretrained_model_path = os.path.dirname(config_path)
1176
+
1177
+ # We also need model weights.
1178
+ # Try to download safetensors first
1179
+ try:
1180
+ hf_hub_download(repo_id=pretrained_model_path, filename="diffusion_pytorch_model.safetensors")
1181
+ except:
1182
+ pass
1183
+
1184
+ # Try bin
1185
+ try:
1186
+ hf_hub_download(repo_id=pretrained_model_path, filename="diffusion_pytorch_model.bin")
1187
+ except:
1188
+ pass
1189
+
1190
+ except Exception as e:
1191
+ print(f"Failed to download model from HF: {e}")
1192
+
1193
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1194
 
1195
  config_file = os.path.join(pretrained_model_path, 'config.json')