XiangpengYang committed on
Commit
31189f6
·
1 Parent(s): 3ea6bf9

down wan 14b

Browse files
app.py CHANGED
@@ -276,10 +276,12 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
276
  # Use snapshot download for the VideoCoF repo to get all weights (including safetensors)
277
  try:
278
  from huggingface_hub import snapshot_download
 
 
279
  print("Downloading VideoCoF weights...")
280
  snapshot_download(repo_id="XiangpengYang/VideoCoF")
281
  except Exception as e:
282
- print(f"Warning: Failed to pre-download VideoCoF weights: {e}")
283
 
284
  base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
285
 
 
276
  # Use snapshot download for the VideoCoF repo to get all weights (including safetensors)
277
  try:
278
  from huggingface_hub import snapshot_download
279
+ print("Downloading Wan2.1-T2V-14B weights...")
280
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")
281
  print("Downloading VideoCoF weights...")
282
  snapshot_download(repo_id="XiangpengYang/VideoCoF")
283
  except Exception as e:
284
+ print(f"Warning: Failed to pre-download weights: {e}")
285
 
286
  base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
287
 
videox_fun/models/wan_image_encoder.py CHANGED
@@ -537,8 +537,22 @@ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
537
  return filtered_kwargs
538
 
539
  model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  if pretrained_model_path.endswith(".safetensors"):
541
- from safetensors.torch import load_file, safe_open
542
  state_dict = load_file(pretrained_model_path)
543
  else:
544
  state_dict = torch.load(pretrained_model_path, map_location="cpu")
@@ -549,5 +563,5 @@ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
549
  m, u = model.load_state_dict(state_dict)
550
 
551
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
- print(m, u)
553
  return model
 
537
  return filtered_kwargs
538
 
539
  model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
+
541
+ # Handle HF hub download
542
+ import os
543
+ from huggingface_hub import hf_hub_download
544
+
545
+ # If path doesn't exist locally, assume it's a repo ID and try to download
546
+ if not os.path.exists(pretrained_model_path):
547
+ try:
548
+ # Try to download models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth from the repo
549
+ print(f"Downloading CLIP model from {pretrained_model_path}...")
550
+ pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
551
+ except Exception as e:
552
+ print(f"Failed to download CLIP model from HF: {e}")
553
+
554
  if pretrained_model_path.endswith(".safetensors"):
555
+ from safetensors.torch import load_file
556
  state_dict = load_file(pretrained_model_path)
557
  else:
558
  state_dict = torch.load(pretrained_model_path, map_location="cpu")
 
563
  m, u = model.load_state_dict(state_dict)
564
 
565
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
566
+ # print(m, u)
567
  return model
videox_fun/models/wan_text_encoder.py CHANGED
@@ -311,6 +311,26 @@ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
311
  valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
  filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
  return filtered_kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  if low_cpu_mem_usage:
316
  try:
 
311
  valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
  filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
  return filtered_kwargs
314
+
315
+ # Check if pretrained_model_path is a Hugging Face repo ID and download if needed
316
+ import os
317
+ from huggingface_hub import hf_hub_download
318
+
319
+ # If path doesn't exist locally, assume it's a repo ID and try to download the weights
320
+ if not os.path.exists(pretrained_model_path):
321
+ try:
322
+ # Try to download models_t5_umt5-xxl-enc-bf16.pth from the repo
323
+ # Note: The user mentioned `models_t5_umt5-xxl-enc-bf16.pth` in previous context or similar.
324
+ # But here we should check the file name. Wan repo usually has `models_t5_umt5-xxl-enc-bf16.pth` inside a folder or root.
325
+ # However, usually we download the file that corresponds to this class.
326
+ # Let's assume the user passes the full path or we default to a standard name if it's a directory/repo.
327
+ # If pretrained_model_path is a repo ID, we need the filename.
328
+ # Based on `wan_civitai.yaml`: text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
329
+ print(f"Downloading text encoder from {pretrained_model_path}...")
330
+ pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="models_t5_umt5-xxl-enc-bf16.pth")
331
+ except Exception as e:
332
+ print(f"Failed to download Text Encoder from HF: {e}")
333
+ # Fallback to original path logic which might fail later if file missing
334
 
335
  if low_cpu_mem_usage:
336
  try:
videox_fun/models/wan_vae.py CHANGED
@@ -691,16 +691,32 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
691
  return filtered_kwargs
692
 
693
  model = cls(**filter_kwargs(cls, additional_kwargs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
  if pretrained_model_path.endswith(".safetensors"):
695
- from safetensors.torch import load_file, safe_open
696
  state_dict = load_file(pretrained_model_path)
697
  else:
698
  state_dict = torch.load(pretrained_model_path, map_location="cpu")
 
699
  tmp_state_dict = {}
700
  for key in state_dict:
701
  tmp_state_dict["model." + key] = state_dict[key]
702
  state_dict = tmp_state_dict
703
  m, u = model.load_state_dict(state_dict, strict=False)
704
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
705
- print(m, u)
706
  return model
 
691
  return filtered_kwargs
692
 
693
  model = cls(**filter_kwargs(cls, additional_kwargs))
694
+
695
+ # Check if pretrained_model_path is a Hugging Face repo ID (e.g., "Wan-AI/Wan2.1-T2V-14B")
696
+ import os
697
+ from huggingface_hub import hf_hub_download
698
+
699
+ # If path doesn't exist locally, assume it's a repo ID and try to download VAE file
700
+ if not os.path.exists(pretrained_model_path):
701
+ try:
702
+ # Try to download Wan2.1_VAE.pth from the repo
703
+ print(f"Downloading Wan2.1_VAE.pth from {pretrained_model_path}...")
704
+ pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="Wan2.1_VAE.pth")
705
+ except Exception as e:
706
+ print(f"Failed to download VAE from HF: {e}")
707
+ # Fallback or re-raise if needed, but torch.load will fail anyway if path is invalid
708
+
709
  if pretrained_model_path.endswith(".safetensors"):
710
+ from safetensors.torch import load_file
711
  state_dict = load_file(pretrained_model_path)
712
  else:
713
  state_dict = torch.load(pretrained_model_path, map_location="cpu")
714
+
715
  tmp_state_dict = {}
716
  for key in state_dict:
717
  tmp_state_dict["model." + key] = state_dict[key]
718
  state_dict = tmp_state_dict
719
  m, u = model.load_state_dict(state_dict, strict=False)
720
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
721
+ # print(m, u)
722
  return model