Spaces: Running on Zero
Commit · 31189f6
1 Parent(s): 3ea6bf9
down wan 14b
Browse files
- app.py +3 -1
- videox_fun/models/wan_image_encoder.py +16 -2
- videox_fun/models/wan_text_encoder.py +20 -0
- videox_fun/models/wan_vae.py +18 -2
app.py
CHANGED
|
@@ -276,10 +276,12 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
|
|
| 276 |
# Use snapshot download for the VideoCoF repo to get all weights (including safetensors)
|
| 277 |
try:
|
| 278 |
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 279 |
print("Downloading VideoCoF weights...")
|
| 280 |
snapshot_download(repo_id="XiangpengYang/VideoCoF")
|
| 281 |
except Exception as e:
|
| 282 |
-
print(f"Warning: Failed to pre-download
|
| 283 |
|
| 284 |
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
|
| 285 |
|
|
|
|
| 276 |
# Use snapshot download for the VideoCoF repo to get all weights (including safetensors)
|
| 277 |
try:
|
| 278 |
from huggingface_hub import snapshot_download
|
| 279 |
+
print("Downloading Wan2.1-T2V-14B weights...")
|
| 280 |
+
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")
|
| 281 |
print("Downloading VideoCoF weights...")
|
| 282 |
snapshot_download(repo_id="XiangpengYang/VideoCoF")
|
| 283 |
except Exception as e:
|
| 284 |
+
print(f"Warning: Failed to pre-download weights: {e}")
|
| 285 |
|
| 286 |
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
|
| 287 |
|
videox_fun/models/wan_image_encoder.py
CHANGED
|
@@ -537,8 +537,22 @@ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
| 537 |
return filtered_kwargs
|
| 538 |
|
| 539 |
model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
if pretrained_model_path.endswith(".safetensors"):
|
| 541 |
-
from safetensors.torch import load_file
|
| 542 |
state_dict = load_file(pretrained_model_path)
|
| 543 |
else:
|
| 544 |
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
|
@@ -549,5 +563,5 @@ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
| 549 |
m, u = model.load_state_dict(state_dict)
|
| 550 |
|
| 551 |
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 552 |
-
print(m, u)
|
| 553 |
return model
|
|
|
|
| 537 |
return filtered_kwargs
|
| 538 |
|
| 539 |
model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
|
| 540 |
+
|
| 541 |
+
# Handle HF hub download
|
| 542 |
+
import os
|
| 543 |
+
from huggingface_hub import hf_hub_download
|
| 544 |
+
|
| 545 |
+
# If path doesn't exist locally, assume it's a repo ID and try to download
|
| 546 |
+
if not os.path.exists(pretrained_model_path):
|
| 547 |
+
try:
|
| 548 |
+
# Try to download models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth from the repo
|
| 549 |
+
print(f"Downloading CLIP model from {pretrained_model_path}...")
|
| 550 |
+
pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
| 551 |
+
except Exception as e:
|
| 552 |
+
print(f"Failed to download CLIP model from HF: {e}")
|
| 553 |
+
|
| 554 |
if pretrained_model_path.endswith(".safetensors"):
|
| 555 |
+
from safetensors.torch import load_file
|
| 556 |
state_dict = load_file(pretrained_model_path)
|
| 557 |
else:
|
| 558 |
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
|
|
|
| 563 |
m, u = model.load_state_dict(state_dict)
|
| 564 |
|
| 565 |
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 566 |
+
# print(m, u)
|
| 567 |
return model
|
videox_fun/models/wan_text_encoder.py
CHANGED
|
@@ -311,6 +311,26 @@ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
| 311 |
valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
|
| 312 |
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
|
| 313 |
return filtered_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
if low_cpu_mem_usage:
|
| 316 |
try:
|
|
|
|
| 311 |
valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
|
| 312 |
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
|
| 313 |
return filtered_kwargs
|
| 314 |
+
|
| 315 |
+
# Check if pretrained_model_path is a Hugging Face repo ID and download if needed
|
| 316 |
+
import os
|
| 317 |
+
from huggingface_hub import hf_hub_download
|
| 318 |
+
|
| 319 |
+
# If path doesn't exist locally, assume it's a repo ID and try to download the weights
|
| 320 |
+
if not os.path.exists(pretrained_model_path):
|
| 321 |
+
try:
|
| 322 |
+
# Try to download models_t5_umt5-xxl-enc-bf16.pth from the repo
|
| 323 |
+
# Note: The user mentioned `models_t5_umt5-xxl-enc-bf16.pth` in previous context or similar.
|
| 324 |
+
# But here we should check the file name. Wan repo usually has `models_t5_umt5-xxl-enc-bf16.pth` inside a folder or root.
|
| 325 |
+
# However, usually we download the file that corresponds to this class.
|
| 326 |
+
# Let's assume the user passes the full path or we default to a standard name if it's a directory/repo.
|
| 327 |
+
# If pretrained_model_path is a repo ID, we need the filename.
|
| 328 |
+
# Based on `wan_civitai.yaml`: text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 329 |
+
print(f"Downloading text encoder from {pretrained_model_path}...")
|
| 330 |
+
pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="models_t5_umt5-xxl-enc-bf16.pth")
|
| 331 |
+
except Exception as e:
|
| 332 |
+
print(f"Failed to download Text Encoder from HF: {e}")
|
| 333 |
+
# Fallback to original path logic which might fail later if file missing
|
| 334 |
|
| 335 |
if low_cpu_mem_usage:
|
| 336 |
try:
|
videox_fun/models/wan_vae.py
CHANGED
|
@@ -691,16 +691,32 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
| 691 |
return filtered_kwargs
|
| 692 |
|
| 693 |
model = cls(**filter_kwargs(cls, additional_kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
if pretrained_model_path.endswith(".safetensors"):
|
| 695 |
-
from safetensors.torch import load_file
|
| 696 |
state_dict = load_file(pretrained_model_path)
|
| 697 |
else:
|
| 698 |
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
|
|
|
| 699 |
tmp_state_dict = {}
|
| 700 |
for key in state_dict:
|
| 701 |
tmp_state_dict["model." + key] = state_dict[key]
|
| 702 |
state_dict = tmp_state_dict
|
| 703 |
m, u = model.load_state_dict(state_dict, strict=False)
|
| 704 |
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 705 |
-
print(m, u)
|
| 706 |
return model
|
|
|
|
| 691 |
return filtered_kwargs
|
| 692 |
|
| 693 |
model = cls(**filter_kwargs(cls, additional_kwargs))
|
| 694 |
+
|
| 695 |
+
# Check if pretrained_model_path is a Hugging Face repo ID (e.g., "Wan-AI/Wan2.1-T2V-14B")
|
| 696 |
+
import os
|
| 697 |
+
from huggingface_hub import hf_hub_download
|
| 698 |
+
|
| 699 |
+
# If path doesn't exist locally, assume it's a repo ID and try to download VAE file
|
| 700 |
+
if not os.path.exists(pretrained_model_path):
|
| 701 |
+
try:
|
| 702 |
+
# Try to download Wan2.1_VAE.pth from the repo
|
| 703 |
+
print(f"Downloading Wan2.1_VAE.pth from {pretrained_model_path}...")
|
| 704 |
+
pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="Wan2.1_VAE.pth")
|
| 705 |
+
except Exception as e:
|
| 706 |
+
print(f"Failed to download VAE from HF: {e}")
|
| 707 |
+
# Fallback or re-raise if needed, but torch.load will fail anyway if path is invalid
|
| 708 |
+
|
| 709 |
if pretrained_model_path.endswith(".safetensors"):
|
| 710 |
+
from safetensors.torch import load_file
|
| 711 |
state_dict = load_file(pretrained_model_path)
|
| 712 |
else:
|
| 713 |
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
| 714 |
+
|
| 715 |
tmp_state_dict = {}
|
| 716 |
for key in state_dict:
|
| 717 |
tmp_state_dict["model." + key] = state_dict[key]
|
| 718 |
state_dict = tmp_state_dict
|
| 719 |
m, u = model.load_state_dict(state_dict, strict=False)
|
| 720 |
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 721 |
+
# print(m, u)
|
| 722 |
return model
|