shaocong committed on
Commit d02e776 · 2 parents: f4470c4, 296d033

add dkt tools example

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +1 -1
  3. dkt/__init__.py +4 -0
  4. dkt/configs/__init__.py +0 -0
  5. dkt/configs/model_config.py +68 -0
  6. dkt/lora/__init__.py +45 -0
  7. dkt/models/__init__.py +1 -0
  8. dkt/models/attention.py +89 -0
  9. dkt/models/downloader.py +111 -0
  10. dkt/models/lora.py +197 -0
  11. dkt/models/model_manager.py +421 -0
  12. dkt/models/tiler.py +234 -0
  13. dkt/models/utils.py +182 -0
  14. dkt/models/wan_video_camera_controller.py +202 -0
  15. dkt/models/wan_video_dit.py +719 -0
  16. dkt/models/wan_video_image_encoder.py +902 -0
  17. dkt/models/wan_video_motion_controller.py +44 -0
  18. dkt/models/wan_video_text_encoder.py +269 -0
  19. dkt/models/wan_video_vace.py +113 -0
  20. dkt/models/wan_video_vae.py +1376 -0
  21. dkt/pipelines/__init__.py +0 -0
  22. dkt/pipelines/pipeline.py +1965 -0
  23. dkt/prompters/__init__.py +13 -0
  24. dkt/prompters/base_prompter.py +70 -0
  25. dkt/prompters/wan_prompter.py +109 -0
  26. dkt/schedulers/__init__.py +1 -0
  27. dkt/schedulers/flow_match.py +126 -0
  28. dkt/utils/__init__.py +261 -0
  29. dkt/vram_management/__init__.py +2 -0
  30. dkt/vram_management/gradient_checkpointing.py +34 -0
  31. dkt/vram_management/layers.py +213 -0
  32. examples/1.mp4 +3 -0
  33. examples/10.mp4 +3 -0
  34. examples/178db6e89ab682bfc612a3290fec58dd.mp4 +3 -0
  35. examples/18.mp4 +3 -0
  36. examples/1b0daeb776471c7389b36cee53049417.mp4 +3 -0
  37. examples/2.mp4 +3 -0
  38. examples/27.mp4 +3 -0
  39. examples/28.mp4 +3 -0
  40. examples/3.mp4 +3 -0
  41. examples/30.mp4 +3 -0
  42. examples/31.mp4 +3 -0
  43. examples/32.mp4 +3 -0
  44. examples/33.mp4 +3 -0
  45. examples/35.mp4 +3 -0
  46. examples/36.mp4 +3 -0
  47. examples/39.mp4 +3 -0
  48. examples/4.mp4 +3 -0
  49. examples/40.mp4 +3 -0
  50. examples/5.mp4 +3 -0
.gitattributes CHANGED
@@ -84,4 +84,5 @@ examples/28.mp4 filter=lfs diff=lfs merge=lfs -text
  examples/4.mp4 filter=lfs diff=lfs merge=lfs -text
  examples/extra_5.mp4 filter=lfs diff=lfs merge=lfs -text
  examples/extra_9.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/IMG_5703.MOV filter=lfs diff=lfs merge=lfs -text
  examples/IMG_5703.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -11,4 +11,4 @@ license: apache-2.0
  short_description: DKT-Normal
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
dkt/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .models import *
+ from .prompters import *
+ from .schedulers import *
+ from .pipelines import *
dkt/configs/__init__.py ADDED
File without changes
dkt/configs/model_config.py ADDED
@@ -0,0 +1,68 @@
+ from typing_extensions import Literal, TypeAlias
+
+
+ from ..models.wan_video_dit import WanModel
+ from ..models.wan_video_text_encoder import WanTextEncoder
+ from ..models.wan_video_image_encoder import WanImageEncoder
+ from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
+ from ..models.wan_video_motion_controller import WanMotionControllerModel
+ from ..models.wan_video_vace import VaceWanModel
+ model_loader_configs = [
+     # These configs are provided for detecting model type automatically.
+     # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
+     (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
+     (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
+     (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
+     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
+     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
+     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
+     (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+     (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+     (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
+     (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
+ ]
+ huggingface_model_loader_configs = [
+     # These configs are provided for detecting model type automatically.
+     # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
+     ("ChatGLMModel", "dkt.models.kolors_text_encoder", "kolors_text_encoder", None),
+     ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
+     ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
+     ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
+     # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
+     ("T5EncoderModel", "dkt.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
+     ("CogVideoXTransformer3DModel", "dkt.models.cog_dit", "cog_dit", "CogDiT"),
+     ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
+     ("LlamaForCausalLM", "dkt.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
+     ("LlavaForConditionalGeneration", "dkt.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
+     ("Step1Model", "dkt.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
+     ("Qwen2_5_VLForConditionalGeneration", "dkt.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
+ ]
+ patch_model_loader_configs = [
+     # These configs are provided for detecting model type automatically.
+     # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
+     # ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
+ ]
+
+ preset_models_on_huggingface = {
+
+ }
+ preset_models_on_modelscope = {
+
+ }
+ Preset_model_id: TypeAlias = Literal[
+     ...
+
+ ]
dkt/lora/__init__.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+
+
+
+ class GeneralLoRALoader:
+     def __init__(self, device="cpu", torch_dtype=torch.float32):
+         self.device = device
+         self.torch_dtype = torch_dtype
+
+
+     def get_name_dict(self, lora_state_dict):
+         lora_name_dict = {}
+         for key in lora_state_dict:
+             if ".lora_B." not in key:
+                 continue
+             keys = key.split(".")
+             if len(keys) > keys.index("lora_B") + 2:
+                 keys.pop(keys.index("lora_B") + 1)
+             keys.pop(keys.index("lora_B"))
+             if keys[0] == "diffusion_model":
+                 keys.pop(0)
+             keys.pop(-1)
+             target_name = ".".join(keys)
+             lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+         return lora_name_dict
+
+
+     def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
+         updated_num = 0
+         lora_name_dict = self.get_name_dict(state_dict_lora)
+         for name, module in model.named_modules():
+             if name in lora_name_dict:
+                 weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
+                 weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
+                 if len(weight_up.shape) == 4:
+                     weight_up = weight_up.squeeze(3).squeeze(2)
+                     weight_down = weight_down.squeeze(3).squeeze(2)
+                     weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                 else:
+                     weight_lora = alpha * torch.mm(weight_up, weight_down)
+                 state_dict = module.state_dict()
+                 state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
+                 module.load_state_dict(state_dict)
+                 updated_num += 1
+         print(f"{updated_num} tensors are updated by LoRA.")
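Not part of the diff: a minimal usage sketch of the loader above, assuming a toy nn.Module and PEFT-style key names such as 0.lora_B.default.weight (the layout produced by WanLoRAConverter.align_to_dkt_format later in this commit).

import torch
from dkt.lora import GeneralLoRALoader

# Toy module and LoRA tensors, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(8, 8))
lora = {
    "0.lora_A.default.weight": torch.randn(4, 8),
    "0.lora_B.default.weight": torch.randn(8, 4),
}
loader = GeneralLoRALoader(device="cpu", torch_dtype=torch.float32)
loader.load(model, lora, alpha=1.0)  # prints "1 tensors are updated by LoRA."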
dkt/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model_manager import *
dkt/models/attention.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ from einops import rearrange
+
+
+ def low_version_attention(query, key, value, attn_bias=None):
+     scale = 1 / query.shape[-1] ** 0.5
+     query = query * scale
+     attn = torch.matmul(query, key.transpose(-2, -1))
+     if attn_bias is not None:
+         attn = attn + attn_bias
+     attn = attn.softmax(-1)
+     return attn @ value
+
+
+ class Attention(torch.nn.Module):
+
+     def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+         super().__init__()
+         dim_inner = head_dim * num_heads
+         kv_dim = kv_dim if kv_dim is not None else q_dim
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+
+         self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
+         self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+         self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+         self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+
+     def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
+         batch_size = q.shape[0]
+         ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+         ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+         ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+         hidden_states = hidden_states + scale * ip_hidden_states
+         return hidden_states
+
+     def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         batch_size = encoder_hidden_states.shape[0]
+
+         q = self.to_q(hidden_states)
+         k = self.to_k(encoder_hidden_states)
+         v = self.to_v(encoder_hidden_states)
+
+         q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+         if qkv_preprocessor is not None:
+             q, k, v = qkv_preprocessor(q, k, v)
+
+         hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+         if ipadapter_kwargs is not None:
+             hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+         hidden_states = hidden_states.to(q.dtype)
+
+         hidden_states = self.to_out(hidden_states)
+
+         return hidden_states
+
+     def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         q = self.to_q(hidden_states)
+         k = self.to_k(encoder_hidden_states)
+         v = self.to_v(encoder_hidden_states)
+
+         q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
+         k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
+         v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
+
+         if attn_mask is not None:
+             hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
+         else:
+             import xformers.ops as xops
+             hidden_states = xops.memory_efficient_attention(q, k, v)
+         hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
+
+         hidden_states = hidden_states.to(q.dtype)
+         hidden_states = self.to_out(hidden_states)
+
+         return hidden_states
+
+     def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+         return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
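Not part of the diff: a short self-attention smoke test of the module above; the dimensions are arbitrary.

import torch
from dkt.models.attention import Attention

attn = Attention(q_dim=64, num_heads=4, head_dim=16)
x = torch.randn(2, 10, 64)   # (batch, tokens, q_dim)
out = attn(x)                # self-attention: encoder_hidden_states defaults to hidden_states
print(out.shape)             # torch.Size([2, 10, 64])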
dkt/models/downloader.py ADDED
@@ -0,0 +1,111 @@
+ from huggingface_hub import hf_hub_download
+ from modelscope import snapshot_download
+ import os, shutil
+ from typing_extensions import Literal, TypeAlias
+ from typing import List
+ from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
+
+
+ def download_from_modelscope(model_id, origin_file_path, local_dir):
+     os.makedirs(local_dir, exist_ok=True)
+     file_name = os.path.basename(origin_file_path)
+     if file_name in os.listdir(local_dir):
+         print(f" {file_name} has been already in {local_dir}.")
+     else:
+         print(f" Start downloading {os.path.join(local_dir, file_name)}")
+         snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
+         downloaded_file_path = os.path.join(local_dir, origin_file_path)
+         target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
+         if downloaded_file_path != target_file_path:
+             shutil.move(downloaded_file_path, target_file_path)
+             shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+ def download_from_huggingface(model_id, origin_file_path, local_dir):
+     os.makedirs(local_dir, exist_ok=True)
+     file_name = os.path.basename(origin_file_path)
+     if file_name in os.listdir(local_dir):
+         print(f" {file_name} has been already in {local_dir}.")
+     else:
+         print(f" Start downloading {os.path.join(local_dir, file_name)}")
+         hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
+         downloaded_file_path = os.path.join(local_dir, origin_file_path)
+         target_file_path = os.path.join(local_dir, file_name)
+         if downloaded_file_path != target_file_path:
+             shutil.move(downloaded_file_path, target_file_path)
+             shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+ Preset_model_website: TypeAlias = Literal[
+     "HuggingFace",
+     "ModelScope",
+ ]
+ website_to_preset_models = {
+     "HuggingFace": preset_models_on_huggingface,
+     "ModelScope": preset_models_on_modelscope,
+ }
+ website_to_download_fn = {
+     "HuggingFace": download_from_huggingface,
+     "ModelScope": download_from_modelscope,
+ }
+
+
+ def download_customized_models(
+     model_id,
+     origin_file_path,
+     local_dir,
+     downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+ ):
+     downloaded_files = []
+     for website in downloading_priority:
+         # Check if the file is downloaded.
+         file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+         if file_to_download in downloaded_files:
+             continue
+         # Download
+         website_to_download_fn[website](model_id, origin_file_path, local_dir)
+         if os.path.basename(origin_file_path) in os.listdir(local_dir):
+             downloaded_files.append(file_to_download)
+     return downloaded_files
+
+
+ def download_models(
+     model_id_list: List[Preset_model_id] = [],
+     downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+ ):
+     print(f"Downloading models: {model_id_list}")
+     downloaded_files = []
+     load_files = []
+
+     for model_id in model_id_list:
+         for website in downloading_priority:
+             if model_id in website_to_preset_models[website]:
+
+                 # Parse model metadata
+                 model_metadata = website_to_preset_models[website][model_id]
+                 if isinstance(model_metadata, list):
+                     file_data = model_metadata
+                 else:
+                     file_data = model_metadata.get("file_list", [])
+
+                 # Try downloading the model from this website.
+                 model_files = []
+                 for model_id, origin_file_path, local_dir in file_data:
+                     # Check if the file is downloaded.
+                     file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+                     if file_to_download in downloaded_files:
+                         continue
+                     # Download
+                     website_to_download_fn[website](model_id, origin_file_path, local_dir)
+                     if os.path.basename(origin_file_path) in os.listdir(local_dir):
+                         downloaded_files.append(file_to_download)
+                         model_files.append(file_to_download)
+
+                 # If the model is successfully downloaded, break.
+                 if len(model_files) > 0:
+                     if isinstance(model_metadata, dict) and "load_path" in model_metadata:
+                         model_files = model_metadata["load_path"]
+                     load_files.extend(model_files)
+                     break
+
+     return load_files
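Not part of the diff: a hedged sketch of download_customized_models; the repository id, file name, and local directory below are placeholders, and the call tries ModelScope before Hugging Face by default.

from dkt.models.downloader import download_customized_models

files = download_customized_models(
    model_id="Wan-AI/Wan2.1-T2V-1.3B",                        # placeholder repository id
    origin_file_path="diffusion_pytorch_model.safetensors",   # placeholder file name
    local_dir="models/wan",                                   # placeholder target directory
)
print(files)  # e.g. ["models/wan/diffusion_pytorch_model.safetensors"]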
dkt/models/lora.py ADDED
@@ -0,0 +1,197 @@
+ import torch
+ from .wan_video_dit import WanModel
+
+
+
+ class LoRAFromCivitai:
+     def __init__(self):
+         self.supported_model_classes = []
+         self.lora_prefix = []
+         self.renamed_lora_prefix = {}
+         self.special_keys = {}
+
+
+     def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+         for key in state_dict:
+             if ".lora_up" in key:
+                 return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
+         return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
+
+
+     def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+         renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
+         state_dict_ = {}
+         for key in state_dict:
+             if ".lora_up" not in key:
+                 continue
+             if not key.startswith(lora_prefix):
+                 continue
+             weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
+             weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
+             if len(weight_up.shape) == 4:
+                 weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
+                 weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
+                 lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+             else:
+                 lora_weight = alpha * torch.mm(weight_up, weight_down)
+             target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
+             for special_key in self.special_keys:
+                 target_name = target_name.replace(special_key, self.special_keys[special_key])
+             state_dict_[target_name] = lora_weight.cpu()
+         return state_dict_
+
+
+     def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
+         state_dict_ = {}
+         for key in state_dict:
+             if ".lora_B." not in key:
+                 continue
+             if not key.startswith(lora_prefix):
+                 continue
+             weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+             weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
+             if len(weight_up.shape) == 4:
+                 weight_up = weight_up.squeeze(3).squeeze(2)
+                 weight_down = weight_down.squeeze(3).squeeze(2)
+                 lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+             else:
+                 lora_weight = alpha * torch.mm(weight_up, weight_down)
+             keys = key.split(".")
+             keys.pop(keys.index("lora_B"))
+             target_name = ".".join(keys)
+             target_name = target_name[len(lora_prefix):]
+             state_dict_[target_name] = lora_weight.cpu()
+         return state_dict_
+
+
+     def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
+         state_dict_model = model.state_dict()
+         state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
+         if model_resource == "diffusers":
+             state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
+         elif model_resource == "civitai":
+             state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
+         if isinstance(state_dict_lora, tuple):
+             state_dict_lora = state_dict_lora[0]
+         if len(state_dict_lora) > 0:
+             print(f" {len(state_dict_lora)} tensors are updated.")
+         for name in state_dict_lora:
+             fp8 = False
+             if state_dict_model[name].dtype == torch.float8_e4m3fn:
+                 state_dict_model[name] = state_dict_model[name].to(state_dict_lora[name].dtype)
+                 fp8 = True
+             state_dict_model[name] += state_dict_lora[name].to(
+                 dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
+             if fp8:
+                 state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
+         model.load_state_dict(state_dict_model)
+
+
+     def match(self, model, state_dict_lora):
+         for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
+             if not isinstance(model, model_class):
+                 continue
+             state_dict_model = model.state_dict()
+             for model_resource in ["diffusers", "civitai"]:
+                 try:
+                     state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
+                     converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
+                         else model.__class__.state_dict_converter().from_civitai
+                     state_dict_lora_ = converter_fn(state_dict_lora_)
+                     if isinstance(state_dict_lora_, tuple):
+                         state_dict_lora_ = state_dict_lora_[0]
+                     if len(state_dict_lora_) == 0:
+                         continue
+                     for name in state_dict_lora_:
+                         if name not in state_dict_model:
+                             break
+                     else:
+                         return lora_prefix, model_resource
+                 except:
+                     pass
+         return None
+
+
+ class GeneralLoRAFromPeft:
+     def __init__(self):
+         self.supported_model_classes = [WanModel]
+
+
+     def get_name_dict(self, lora_state_dict):
+         lora_name_dict = {}
+         for key in lora_state_dict:
+             if ".lora_B." not in key:
+                 continue
+             keys = key.split(".")
+             if len(keys) > keys.index("lora_B") + 2:
+                 keys.pop(keys.index("lora_B") + 1)
+             keys.pop(keys.index("lora_B"))
+             if keys[0] == "diffusion_model":
+                 keys.pop(0)
+             target_name = ".".join(keys)
+             lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+         return lora_name_dict
+
+
+     def match(self, model: torch.nn.Module, state_dict_lora):
+         lora_name_dict = self.get_name_dict(state_dict_lora)
+         model_name_dict = {name: None for name, _ in model.named_parameters()}
+         matched_num = sum([i in model_name_dict for i in lora_name_dict])
+         if matched_num == len(lora_name_dict):
+             return "", ""
+         else:
+             return None
+
+
+     def fetch_device_and_dtype(self, state_dict):
+         device, dtype = None, None
+         for name, param in state_dict.items():
+             device, dtype = param.device, param.dtype
+             break
+         computation_device = device
+         computation_dtype = dtype
+         if computation_device == torch.device("cpu"):
+             if torch.cuda.is_available():
+                 computation_device = torch.device("cuda")
+         if computation_dtype == torch.float8_e4m3fn:
+             computation_dtype = torch.float32
+         return device, dtype, computation_device, computation_dtype
+
+
+     def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
+         state_dict_model = model.state_dict()
+         device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
+         lora_name_dict = self.get_name_dict(state_dict_lora)
+         for name in lora_name_dict:
+             weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
+             weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
+             if len(weight_up.shape) == 4:
+                 weight_up = weight_up.squeeze(3).squeeze(2)
+                 weight_down = weight_down.squeeze(3).squeeze(2)
+                 weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+             else:
+                 weight_lora = alpha * torch.mm(weight_up, weight_down)
+             weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
+             weight_patched = weight_model + weight_lora
+             state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
+         print(f" {len(lora_name_dict)} tensors are updated.")
+         model.load_state_dict(state_dict_model)
+
+
+
+ class WanLoRAConverter:
+     def __init__(self):
+         pass
+
+     @staticmethod
+     def align_to_opensource_format(state_dict, **kwargs):
+         state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
+         return state_dict
+
+     @staticmethod
+     def align_to_dkt_format(state_dict, **kwargs):
+         state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
+         return state_dict
+
+ def get_lora_loaders():
+     return [GeneralLoRAFromPeft()]
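Not part of the diff: the two align_* helpers above only rename keys. An illustrative round trip with a made-up key:

from dkt.models.lora import WanLoRAConverter

sd = {"blocks.0.self_attn.q.lora_A.default.weight": None}   # made-up key, value unused
opensource = WanLoRAConverter.align_to_opensource_format(sd)
print(list(opensource))
# ['diffusion_model.blocks.0.self_attn.q.lora_A.weight']
print(list(WanLoRAConverter.align_to_dkt_format(opensource)))
# back to the original key layout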
dkt/models/model_manager.py ADDED
@@ -0,0 +1,421 @@
+ import os, torch, json, importlib
+ from typing import List
+
+ from .downloader import download_models
+ from .lora import get_lora_loaders
+
+ from ..configs.model_config import model_loader_configs
+ from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
+
+ from .downloader import Preset_model_id, Preset_model_website
+
+ def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
+     loaded_model_names, loaded_models = [], []
+     for model_name, model_class in zip(model_names, model_classes):
+         print(f" model_name: {model_name} model_class: {model_class.__name__}")
+         state_dict_converter = model_class.state_dict_converter()
+         if model_resource == "civitai":
+             state_dict_results = state_dict_converter.from_civitai(state_dict)
+         elif model_resource == "diffusers":
+             state_dict_results = state_dict_converter.from_diffusers(state_dict)
+         if isinstance(state_dict_results, tuple):
+             model_state_dict, extra_kwargs = state_dict_results
+             print(f" This model is initialized with extra kwargs: {extra_kwargs}")
+         else:
+             model_state_dict, extra_kwargs = state_dict_results, {}
+         torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
+         with init_weights_on_device():
+             model = model_class(**extra_kwargs)
+         if hasattr(model, "eval"):
+             model = model.eval()
+         model.load_state_dict(model_state_dict, assign=True)
+         model = model.to(dtype=torch_dtype, device=device)
+         loaded_model_names.append(model_name)
+         loaded_models.append(model)
+     return loaded_model_names, loaded_models
+
+
+ def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
+     loaded_model_names, loaded_models = [], []
+     for model_name, model_class in zip(model_names, model_classes):
+         if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+             model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+         else:
+             model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
+         if torch_dtype == torch.float16 and hasattr(model, "half"):
+             model = model.half()
+         try:
+             model = model.to(device=device)
+         except:
+             pass
+         loaded_model_names.append(model_name)
+         loaded_models.append(model)
+     return loaded_model_names, loaded_models
+
+
+ def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
+     print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
+     base_state_dict = base_model.state_dict()
+     base_model.to("cpu")
+     del base_model
+     model = model_class(**extra_kwargs)
+     model.load_state_dict(base_state_dict, strict=False)
+     model.load_state_dict(state_dict, strict=False)
+     model.to(dtype=torch_dtype, device=device)
+     return model
+
+
+ def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
+     loaded_model_names, loaded_models = [], []
+     for model_name, model_class in zip(model_names, model_classes):
+         while True:
+             for model_id in range(len(model_manager.model)):
+                 base_model_name = model_manager.model_name[model_id]
+                 if base_model_name == model_name:
+                     base_model_path = model_manager.model_path[model_id]
+                     base_model = model_manager.model[model_id]
+                     print(f" Adding patch model to {base_model_name} ({base_model_path})")
+                     patched_model = load_single_patch_model_from_single_file(
+                         state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
+                     loaded_model_names.append(base_model_name)
+                     loaded_models.append(patched_model)
+                     model_manager.model.pop(model_id)
+                     model_manager.model_path.pop(model_id)
+                     model_manager.model_name.pop(model_id)
+                     break
+             else:
+                 break
+     return loaded_model_names, loaded_models
+
+
+
+ class ModelDetectorTemplate:
+     def __init__(self):
+         pass
+
+     def match(self, file_path="", state_dict={}):
+         return False
+
+     def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+         return [], []
+
+
+
+ class ModelDetectorFromSingleFile:
+     def __init__(self, model_loader_configs=[]):
+         self.keys_hash_with_shape_dict = {}
+         self.keys_hash_dict = {}
+         for metadata in model_loader_configs:
+             self.add_model_metadata(*metadata)
+
+
+     def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
+         self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
+         if keys_hash is not None:
+             self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
+
+
+     def match(self, file_path="", state_dict={}):
+         if isinstance(file_path, str) and os.path.isdir(file_path):
+             return False
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+         keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+         if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+             return True
+         keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+         if keys_hash in self.keys_hash_dict:
+             return True
+         return False
+
+
+     def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+
+         # Load models with strict matching
+         keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+         if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+             model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+             loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+             return loaded_model_names, loaded_models
+
+         # Load models without strict matching
+         # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
+         keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+         if keys_hash in self.keys_hash_dict:
+             model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
+             loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+             return loaded_model_names, loaded_models
+
+         return loaded_model_names, loaded_models
+
+
+
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
+     def __init__(self, model_loader_configs=[]):
+         super().__init__(model_loader_configs)
+
+
+     def match(self, file_path="", state_dict={}):
+         if isinstance(file_path, str) and os.path.isdir(file_path):
+             return False
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+         splited_state_dict = split_state_dict_with_prefix(state_dict)
+         for sub_state_dict in splited_state_dict:
+             if super().match(file_path, sub_state_dict):
+                 return True
+         return False
+
+
+     def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+         # Split the state_dict and load from each component
+         splited_state_dict = split_state_dict_with_prefix(state_dict)
+         valid_state_dict = {}
+         for sub_state_dict in splited_state_dict:
+             if super().match(file_path, sub_state_dict):
+                 valid_state_dict.update(sub_state_dict)
+         if super().match(file_path, valid_state_dict):
+             loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
+         else:
+             loaded_model_names, loaded_models = [], []
+             for sub_state_dict in splited_state_dict:
+                 if super().match(file_path, sub_state_dict):
+                     loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
+                     loaded_model_names += loaded_model_names_
+                     loaded_models += loaded_models_
+         return loaded_model_names, loaded_models
+
+
+
+ class ModelDetectorFromHuggingfaceFolder:
+     def __init__(self, model_loader_configs=[]):
+         self.architecture_dict = {}
+         for metadata in model_loader_configs:
+             self.add_model_metadata(*metadata)
+
+
+     def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
+         self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
+
+
+     def match(self, file_path="", state_dict={}):
+         if not isinstance(file_path, str) or os.path.isfile(file_path):
+             return False
+         file_list = os.listdir(file_path)
+         if "config.json" not in file_list:
+             return False
+         with open(os.path.join(file_path, "config.json"), "r") as f:
+             config = json.load(f)
+         if "architectures" not in config and "_class_name" not in config:
+             return False
+         return True
+
+
+     def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+         with open(os.path.join(file_path, "config.json"), "r") as f:
+             config = json.load(f)
+         loaded_model_names, loaded_models = [], []
+         architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
+         for architecture in architectures:
+             huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
+             if redirected_architecture is not None:
+                 architecture = redirected_architecture
+             model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
+             loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
+             loaded_model_names += loaded_model_names_
+             loaded_models += loaded_models_
+         return loaded_model_names, loaded_models
+
+
+
+ class ModelDetectorFromPatchedSingleFile:
+     def __init__(self, model_loader_configs=[]):
+         self.keys_hash_with_shape_dict = {}
+         for metadata in model_loader_configs:
+             self.add_model_metadata(*metadata)
+
+
+     def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
+         self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
+
+
+     def match(self, file_path="", state_dict={}):
+         if not isinstance(file_path, str) or os.path.isdir(file_path):
+             return False
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+         keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+         if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+             return True
+         return False
+
+
+     def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+
+         # Load models with strict matching
+         loaded_model_names, loaded_models = [], []
+         keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+         if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+             model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+             loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
+                 state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
+             loaded_model_names += loaded_model_names_
+             loaded_models += loaded_models_
+         return loaded_model_names, loaded_models
+
+
+
+ class ModelManager:
+     def __init__(
+         self,
+         torch_dtype=torch.float16,
+         device="cuda",
+         model_id_list: List[Preset_model_id] = [],
+         downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+         file_path_list: List[str] = [],
+     ):
+         self.torch_dtype = torch_dtype
+         self.device = device
+         self.model = []
+         self.model_path = []
+         self.model_name = []
+         downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
+         self.model_detector = [
+             ModelDetectorFromSingleFile(model_loader_configs),
+             ModelDetectorFromSplitedSingleFile(model_loader_configs),
+         ]
+         self.load_models(downloaded_files + file_path_list)
+
+
+     def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
+         print(f"Loading models from file: {file_path}")
+         if len(state_dict) == 0:
+             state_dict = load_state_dict(file_path)
+         model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
+         for model_name, model in zip(model_names, models):
+             self.model.append(model)
+             self.model_path.append(file_path)
+             self.model_name.append(model_name)
+         print(f" The following models are loaded: {model_names}.")
+
+
+     def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
+         print(f"Loading models from folder: {file_path}")
+         model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
+         for model_name, model in zip(model_names, models):
+             self.model.append(model)
+             self.model_path.append(file_path)
+             self.model_name.append(model_name)
+         print(f" The following models are loaded: {model_names}.")
+
+
+     def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
+         print(f"Loading patch models from file: {file_path}")
+         model_names, models = load_patch_model_from_single_file(
+             state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
+         for model_name, model in zip(model_names, models):
+             self.model.append(model)
+             self.model_path.append(file_path)
+             self.model_name.append(model_name)
+         print(f" The following patched models are loaded: {model_names}.")
+
+
+     def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
+         if isinstance(file_path, list):
+             for file_path_ in file_path:
+                 self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
+         else:
+             print(f"Loading LoRA models from file: {file_path}")
+             is_loaded = False
+             if len(state_dict) == 0:
+                 state_dict = load_state_dict(file_path)
+             for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
+                 for lora in get_lora_loaders():
+                     match_results = lora.match(model, state_dict)
+                     if match_results is not None:
+                         print(f" Adding LoRA to {model_name} ({model_path}).")
+                         lora_prefix, model_resource = match_results
+                         lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+                         is_loaded = True
+                         break
+             if not is_loaded:
+                 print(f" Cannot load LoRA: {file_path}")
+
+
+     def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
+         print(f"Loading models from: {file_path}")
+         if device is None: device = self.device
+         if torch_dtype is None: torch_dtype = self.torch_dtype
+         if isinstance(file_path, list):
+             state_dict = {}
+             for path in file_path:
+                 state_dict.update(load_state_dict(path))
+         elif os.path.isfile(file_path):
+             state_dict = load_state_dict(file_path)
+         else:
+             state_dict = None
+         for model_detector in self.model_detector:
+             if model_detector.match(file_path, state_dict):
+                 model_names, models = model_detector.load(
+                     file_path, state_dict,
+                     device=device, torch_dtype=torch_dtype,
+                     allowed_model_names=model_names, model_manager=self
+                 )
+                 for model_name, model in zip(model_names, models):
+                     self.model.append(model)
+                     self.model_path.append(file_path)
+                     self.model_name.append(model_name)
+                 print(f" The following models are loaded: {model_names}.")
+                 break
+         else:
+             print(f" We cannot detect the model type. No models are loaded.")
+
+
+     def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
+         for file_path in file_path_list:
+             self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
+
+
+     def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
+         fetched_models = []
+         fetched_model_paths = []
+         for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
+             if file_path is not None and file_path != model_path:
+                 continue
+             if model_name == model_name_:
+                 fetched_models.append(model)
+                 fetched_model_paths.append(model_path)
+         if len(fetched_models) == 0:
+             print(f"No {model_name} models available.")
+             return None
+         if len(fetched_models) == 1:
+             print(f"Using {model_name} from {fetched_model_paths[0]}.")
+             model = fetched_models[0]
+             path = fetched_model_paths[0]
+         else:
+             if index is None:
+                 model = fetched_models[0]
+                 path = fetched_model_paths[0]
+                 print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+             elif isinstance(index, int):
+                 model = fetched_models[:index]
+                 path = fetched_model_paths[:index]
+                 print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
+             else:
+                 model = fetched_models
+                 path = fetched_model_paths
+                 print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
+         if require_model_path:
+             return model, path
+         else:
+             return model
+
+
+     def to(self, device):
+         for model in self.model:
+             model.to(device)
+
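Not part of the diff: a minimal sketch of how the manager above is typically driven; the checkpoint and LoRA paths below are placeholders, and model names such as wan_video_dit come from dkt/configs/model_config.py.

import torch
from dkt import ModelManager

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/wan_video_dit.safetensors",   # placeholder checkpoint paths
    "models/wan_video_vae.safetensors",
])
model_manager.load_lora("models/my_lora.safetensors", lora_alpha=1.0)  # placeholder LoRA path
dit = model_manager.fetch_model("wan_video_dit")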
dkt/models/tiler.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
+ class TileWorker:
6
+ def __init__(self):
7
+ pass
8
+
9
+
10
+ def mask(self, height, width, border_width):
11
+ # Create a mask with shape (height, width).
12
+ # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
13
+ x = torch.arange(height).repeat(width, 1).T
14
+ y = torch.arange(width).repeat(height, 1)
15
+ mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
16
+ mask = (mask / border_width).clip(0, 1)
17
+ return mask
18
+
19
+
20
+ def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
21
+ # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
22
+ batch_size, channel, _, _ = model_input.shape
23
+ model_input = model_input.to(device=tile_device, dtype=tile_dtype)
24
+ unfold_operator = torch.nn.Unfold(
25
+ kernel_size=(tile_size, tile_size),
26
+ stride=(tile_stride, tile_stride)
27
+ )
28
+ model_input = unfold_operator(model_input)
29
+ model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
30
+
31
+ return model_input
32
+
33
+
34
+ def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
35
+ # Call y=forward_fn(x) for each tile
36
+ tile_num = model_input.shape[-1]
37
+ model_output_stack = []
38
+
39
+ for tile_id in range(0, tile_num, tile_batch_size):
40
+
41
+ # process input
42
+ tile_id_ = min(tile_id + tile_batch_size, tile_num)
43
+ x = model_input[:, :, :, :, tile_id: tile_id_]
44
+ x = x.to(device=inference_device, dtype=inference_dtype)
45
+ x = rearrange(x, "b c h w n -> (n b) c h w")
46
+
47
+ # process output
48
+ y = forward_fn(x)
49
+ y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
50
+ y = y.to(device=tile_device, dtype=tile_dtype)
51
+ model_output_stack.append(y)
52
+
53
+ model_output = torch.concat(model_output_stack, dim=-1)
54
+ return model_output
55
+
56
+
57
+ def io_scale(self, model_output, tile_size):
58
+ # Determine the size modification happened in forward_fn
59
+ # We only consider the same scale on height and width.
60
+ io_scale = model_output.shape[2] / tile_size
61
+ return io_scale
62
+
63
+
64
+ def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
65
+ # The reversed function of tile
66
+ mask = self.mask(tile_size, tile_size, border_width)
67
+ mask = mask.to(device=tile_device, dtype=tile_dtype)
68
+ mask = rearrange(mask, "h w -> 1 1 h w 1")
69
+ model_output = model_output * mask
70
+
71
+ fold_operator = torch.nn.Fold(
72
+ output_size=(height, width),
73
+ kernel_size=(tile_size, tile_size),
74
+ stride=(tile_stride, tile_stride)
75
+ )
76
+ mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
77
+ model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
78
+ model_output = fold_operator(model_output) / fold_operator(mask)
79
+
80
+ return model_output
81
+
82
+
83
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
84
+ # Prepare
85
+ inference_device, inference_dtype = model_input.device, model_input.dtype
86
+ height, width = model_input.shape[2], model_input.shape[3]
87
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
88
+
89
+ # tile
90
+ model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
91
+
92
+ # inference
93
+ model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
94
+
95
+ # resize
96
+ io_scale = self.io_scale(model_output, tile_size)
97
+ height, width = int(height*io_scale), int(width*io_scale)
98
+ tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
99
+ border_width = int(border_width*io_scale)
100
+
101
+ # untile
102
+ model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
103
+
104
+ # Done!
105
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype)
106
+ return model_output
107
+
108
+
109
+
110
+ class FastTileWorker:
111
+ def __init__(self):
112
+ pass
113
+
114
+
115
+ def build_mask(self, data, is_bound):
116
+ _, _, H, W = data.shape
117
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
118
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
119
+ border_width = (H + W) // 4
120
+ pad = torch.ones_like(h) * border_width
121
+ mask = torch.stack([
122
+ pad if is_bound[0] else h + 1,
123
+ pad if is_bound[1] else H - h,
124
+ pad if is_bound[2] else w + 1,
125
+ pad if is_bound[3] else W - w
126
+ ]).min(dim=0).values
127
+ mask = mask.clip(1, border_width)
128
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
129
+ mask = rearrange(mask, "H W -> 1 H W")
130
+ return mask
131
+
132
+
133
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
134
+ # Prepare
135
+ B, C, H, W = model_input.shape
136
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
137
+ weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
138
+ values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
139
+
140
+ # Split tasks
141
+ tasks = []
142
+ for h in range(0, H, tile_stride):
143
+ for w in range(0, W, tile_stride):
144
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
145
+ continue
146
+ h_, w_ = h + tile_size, w + tile_size
147
+ if h_ > H: h, h_ = H - tile_size, H
148
+ if w_ > W: w, w_ = W - tile_size, W
149
+ tasks.append((h, h_, w, w_))
150
+
151
+ # Run
152
+ for hl, hr, wl, wr in tasks:
153
+ # Forward
154
+ hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
155
+
156
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
157
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
158
+ weight[:, :, hl:hr, wl:wr] += mask
159
+ values /= weight
160
+ return values
161
+
162
+
163
+
164
+ class TileWorker2Dto3D:
165
+ """
166
+ Process 3D tensors, but only enable TileWorker on 2D.
167
+ """
168
+ def __init__(self):
169
+ pass
170
+
171
+
172
+ def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
173
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
174
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
175
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
176
+ border_width = (H + W) // 4 if border_width is None else border_width
177
+ pad = torch.ones_like(h) * border_width
178
+ mask = torch.stack([
179
+ pad if is_bound[0] else t + 1,
180
+ pad if is_bound[1] else T - t,
181
+ pad if is_bound[2] else h + 1,
182
+ pad if is_bound[3] else H - h,
183
+ pad if is_bound[4] else w + 1,
184
+ pad if is_bound[5] else W - w
185
+ ]).min(dim=0).values
186
+ mask = mask.clip(1, border_width)
187
+ mask = (mask / border_width).to(dtype=dtype, device=device)
188
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
189
+ return mask
190
+
191
+
192
+ def tiled_forward(
193
+ self,
194
+ forward_fn,
195
+ model_input,
196
+ tile_size, tile_stride,
197
+ tile_device="cpu", tile_dtype=torch.float32,
198
+ computation_device="cuda", computation_dtype=torch.float32,
199
+ border_width=None, scales=[1, 1, 1, 1],
200
+ progress_bar=lambda x:x
201
+ ):
202
+ B, C, T, H, W = model_input.shape
203
+ scale_C, scale_T, scale_H, scale_W = scales
204
+ tile_size_H, tile_size_W = tile_size
205
+ tile_stride_H, tile_stride_W = tile_stride
206
+
207
+ value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
208
+ weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
209
+
210
+ # Split tasks
211
+ tasks = []
212
+ for h in range(0, H, tile_stride_H):
213
+ for w in range(0, W, tile_stride_W):
214
+ if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
215
+ continue
216
+ h_, w_ = h + tile_size_H, w + tile_size_W
217
+ if h_ > H: h, h_ = max(H - tile_size_H, 0), H
218
+ if w_ > W: w, w_ = max(W - tile_size_W, 0), W
219
+ tasks.append((h, h_, w, w_))
220
+
221
+ # Run
222
+ for hl, hr, wl, wr in progress_bar(tasks):
223
+ mask = self.build_mask(
224
+ int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
225
+ tile_dtype, tile_device,
226
+ is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
227
+ border_width=border_width
228
+ )
229
+ grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
230
+ grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
231
+ value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
232
+ weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
233
+ value = value / weight
234
+ return value
dkt/models/utils.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch, os
2
+ from safetensors import safe_open
3
+ from contextlib import contextmanager
4
+ import hashlib
5
+
6
+ @contextmanager
7
+ def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
8
+
9
+ old_register_parameter = torch.nn.Module.register_parameter
10
+ if include_buffers:
11
+ old_register_buffer = torch.nn.Module.register_buffer
12
+
13
+ def register_empty_parameter(module, name, param):
14
+ old_register_parameter(module, name, param)
15
+ if param is not None:
16
+ param_cls = type(module._parameters[name])
17
+ kwargs = module._parameters[name].__dict__
18
+ kwargs["requires_grad"] = param.requires_grad
19
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
20
+
21
+ def register_empty_buffer(module, name, buffer, persistent=True):
22
+ old_register_buffer(module, name, buffer, persistent=persistent)
23
+ if buffer is not None:
24
+ module._buffers[name] = module._buffers[name].to(device)
25
+
26
+ def patch_tensor_constructor(fn):
27
+ def wrapper(*args, **kwargs):
28
+ kwargs["device"] = device
29
+ return fn(*args, **kwargs)
30
+
31
+ return wrapper
32
+
33
+ if include_buffers:
34
+ tensor_constructors_to_patch = {
35
+ torch_function_name: getattr(torch, torch_function_name)
36
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
37
+ }
38
+ else:
39
+ tensor_constructors_to_patch = {}
40
+
41
+ try:
42
+ torch.nn.Module.register_parameter = register_empty_parameter
43
+ if include_buffers:
44
+ torch.nn.Module.register_buffer = register_empty_buffer
45
+ for torch_function_name in tensor_constructors_to_patch.keys():
46
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
47
+ yield
48
+ finally:
49
+ torch.nn.Module.register_parameter = old_register_parameter
50
+ if include_buffers:
51
+ torch.nn.Module.register_buffer = old_register_buffer
52
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
53
+ setattr(torch, torch_function_name, old_torch_function)
54
+
55
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
56
+ state_dict = {}
57
+ for file_name in os.listdir(file_path):
58
+ if "." in file_name and file_name.split(".")[-1] in [
59
+ "safetensors", "bin", "ckpt", "pth", "pt"
60
+ ]:
61
+ state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
62
+ return state_dict
63
+
64
+
65
+ def load_state_dict(file_path, torch_dtype=None, device="cpu"):
66
+ if file_path.endswith(".safetensors"):
67
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
68
+ else:
69
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
70
+
71
+
72
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
73
+ state_dict = {}
74
+ with safe_open(file_path, framework="pt", device=str(device)) as f:
75
+ for k in f.keys():
76
+ state_dict[k] = f.get_tensor(k)
77
+ if torch_dtype is not None:
78
+ state_dict[k] = state_dict[k].to(torch_dtype)
79
+ return state_dict
80
+
81
+
82
+ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
83
+ state_dict = torch.load(file_path, map_location=device, weights_only=True)
84
+ if torch_dtype is not None:
85
+ for i in state_dict:
86
+ if isinstance(state_dict[i], torch.Tensor):
87
+ state_dict[i] = state_dict[i].to(torch_dtype)
88
+ return state_dict
89
+
90
+
91
+ def search_for_embeddings(state_dict):
92
+ embeddings = []
93
+ for k in state_dict:
94
+ if isinstance(state_dict[k], torch.Tensor):
95
+ embeddings.append(state_dict[k])
96
+ elif isinstance(state_dict[k], dict):
97
+ embeddings += search_for_embeddings(state_dict[k])
98
+ return embeddings
99
+
100
+
101
+ def search_parameter(param, state_dict):
102
+ for name, param_ in state_dict.items():
103
+ if param.numel() == param_.numel():
104
+ if param.shape == param_.shape:
105
+ if torch.dist(param, param_) < 1e-3:
106
+ return name
107
+ else:
108
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
109
+ return name
110
+ return None
111
+
112
+
113
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
114
+ matched_keys = set()
115
+ with torch.no_grad():
116
+ for name in source_state_dict:
117
+ rename = search_parameter(source_state_dict[name], target_state_dict)
118
+ if rename is not None:
119
+ print(f'"{name}": "{rename}",')
120
+ matched_keys.add(rename)
121
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
122
+ length = source_state_dict[name].shape[0] // 3
123
+ rename = []
124
+ for i in range(3):
125
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
126
+ if None not in rename:
127
+ print(f'"{name}": {rename},')
128
+ for rename_ in rename:
129
+ matched_keys.add(rename_)
130
+ for name in target_state_dict:
131
+ if name not in matched_keys:
132
+ print("Cannot find", name, target_state_dict[name].shape)
133
+
134
+
135
+ def search_for_files(folder, extensions):
136
+ files = []
137
+ if os.path.isdir(folder):
138
+ for file in sorted(os.listdir(folder)):
139
+ files += search_for_files(os.path.join(folder, file), extensions)
140
+ elif os.path.isfile(folder):
141
+ for extension in extensions:
142
+ if folder.endswith(extension):
143
+ files.append(folder)
144
+ break
145
+ return files
146
+
147
+
148
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
149
+ keys = []
150
+ for key, value in state_dict.items():
151
+ if isinstance(key, str):
152
+ if isinstance(value, torch.Tensor):
153
+ if with_shape:
154
+ shape = "_".join(map(str, list(value.shape)))
155
+ keys.append(key + ":" + shape)
156
+ keys.append(key)
157
+ elif isinstance(value, dict):
158
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
159
+ keys.sort()
160
+ keys_str = ",".join(keys)
161
+ return keys_str
162
+
163
+
164
+ def split_state_dict_with_prefix(state_dict):
165
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
166
+ prefix_dict = {}
167
+ for key in keys:
168
+ prefix = key if "." not in key else key.split(".")[0]
169
+ if prefix not in prefix_dict:
170
+ prefix_dict[prefix] = []
171
+ prefix_dict[prefix].append(key)
172
+ state_dicts = []
173
+ for prefix, keys in prefix_dict.items():
174
+ sub_state_dict = {key: state_dict[key] for key in keys}
175
+ state_dicts.append(sub_state_dict)
176
+ return state_dicts
177
+
178
+
179
+ def hash_state_dict_keys(state_dict, with_shape=True):
180
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
181
+ keys_str = keys_str.encode(encoding="UTF-8")
182
+ return hashlib.md5(keys_str).hexdigest()
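A quick illustration of how the helpers above compose; the module path is assumed and the state dict below is a stand-in rather than a real checkpoint:

import torch
from dkt.models.utils import init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix

# Build a module on the meta device (no memory allocated), as done before loading real weights.
with init_weights_on_device():
    layer = torch.nn.Linear(1024, 1024)
print(layer.weight.device)                            # meta

# Stand-in state dict; in practice it comes from load_state_dict(...) or load_state_dict_from_folder(...).
state_dict = {
    "blocks.0.self_attn.q.weight": torch.zeros(8, 8),
    "blocks.0.self_attn.q.bias": torch.zeros(8),
    "head.head.weight": torch.zeros(4, 8),
}
print(hash_state_dict_keys(state_dict))               # MD5 over the sorted "key:shape" strings
for sub in split_state_dict_with_prefix(state_dict):  # groups keys by their first dotted prefix
    print(sorted(sub.keys()))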
dkt/models/wan_video_camera_controller.py ADDED
@@ -0,0 +1,202 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+ import os
6
+ from typing_extensions import Literal
7
+
8
+ class SimpleAdapter(nn.Module):
9
+ def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
10
+ super(SimpleAdapter, self).__init__()
11
+
12
+ # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
13
+ self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
14
+
15
+ # Convolution: reduce spatial dimensions by a factor
16
+ # of 2 (without overlap)
17
+ self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
18
+
19
+ # Residual blocks for feature extraction
20
+ self.residual_blocks = nn.Sequential(
21
+ *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
22
+ )
23
+
24
+ def forward(self, x):
25
+ # Reshape to merge the frame dimension into batch
26
+ bs, c, f, h, w = x.size()
27
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
28
+
29
+ # Pixel Unshuffle operation
30
+ x_unshuffled = self.pixel_unshuffle(x)
31
+
32
+ # Convolution operation
33
+ x_conv = self.conv(x_unshuffled)
34
+
35
+ # Feature extraction with residual blocks
36
+ out = self.residual_blocks(x_conv)
37
+
38
+ # Reshape to restore original bf dimension
39
+ out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
40
+
41
+ # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
42
+ out = out.permute(0, 2, 1, 3, 4)
43
+
44
+ return out
45
+
46
+ def process_camera_coordinates(
47
+ self,
48
+ direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
49
+ length: int,
50
+ height: int,
51
+ width: int,
52
+ speed: float = 1/54,
53
+ origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
54
+ ):
55
+ if origin is None:
56
+ origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
57
+ coordinates = generate_camera_coordinates(direction, length, speed, origin)
58
+ plucker_embedding = process_pose_file(coordinates, width, height)
59
+ return plucker_embedding
60
+
61
+
62
+
63
+ class ResidualBlock(nn.Module):
64
+ def __init__(self, dim):
65
+ super(ResidualBlock, self).__init__()
66
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
67
+ self.relu = nn.ReLU(inplace=True)
68
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
69
+
70
+ def forward(self, x):
71
+ residual = x
72
+ out = self.relu(self.conv1(x))
73
+ out = self.conv2(out)
74
+ out += residual
75
+ return out
76
+
77
+ class Camera(object):
78
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
79
+ """
80
+ def __init__(self, entry):
81
+ fx, fy, cx, cy = entry[1:5]
82
+ self.fx = fx
83
+ self.fy = fy
84
+ self.cx = cx
85
+ self.cy = cy
86
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
87
+ w2c_mat_4x4 = np.eye(4)
88
+ w2c_mat_4x4[:3, :] = w2c_mat
89
+ self.w2c_mat = w2c_mat_4x4
90
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
91
+
92
+ def get_relative_pose(cam_params):
93
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
94
+ """
95
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
96
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
97
+ cam_to_origin = 0
98
+ target_cam_c2w = np.array([
99
+ [1, 0, 0, 0],
100
+ [0, 1, 0, -cam_to_origin],
101
+ [0, 0, 1, 0],
102
+ [0, 0, 0, 1]
103
+ ])
104
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
105
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
106
+ ret_poses = np.array(ret_poses, dtype=np.float32)
107
+ return ret_poses
108
+
109
+ def custom_meshgrid(*args):
110
+ # torch>=2.0.0 only
111
+ return torch.meshgrid(*args, indexing='ij')
112
+
113
+
114
+ def ray_condition(K, c2w, H, W, device):
115
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
116
+ """
117
+ # c2w: B, V, 4, 4
118
+ # K: B, V, 4
119
+
120
+ B = K.shape[0]
121
+
122
+ j, i = custom_meshgrid(
123
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
124
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
125
+ )
126
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
127
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
128
+
129
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
130
+
131
+ zs = torch.ones_like(i) # [B, HxW]
132
+ xs = (i - cx) / fx * zs
133
+ ys = (j - cy) / fy * zs
134
+ zs = zs.expand_as(ys)
135
+
136
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
137
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
138
+
139
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
140
+ rays_o = c2w[..., :3, 3] # B, V, 3
141
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
142
+ # c2w @ directions

143
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
144
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
145
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
146
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
147
+ return plucker
148
+
149
+
150
+ def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
151
+ if return_poses:
152
+ return cam_params
153
+ else:
154
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
155
+
156
+ sample_wh_ratio = width / height
157
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
158
+
159
+ if pose_wh_ratio > sample_wh_ratio:
160
+ resized_ori_w = height * pose_wh_ratio
161
+ for cam_param in cam_params:
162
+ cam_param.fx = resized_ori_w * cam_param.fx / width
163
+ else:
164
+ resized_ori_h = width / pose_wh_ratio
165
+ for cam_param in cam_params:
166
+ cam_param.fy = resized_ori_h * cam_param.fy / height
167
+
168
+ intrinsic = np.asarray([[cam_param.fx * width,
169
+ cam_param.fy * height,
170
+ cam_param.cx * width,
171
+ cam_param.cy * height]
172
+ for cam_param in cam_params], dtype=np.float32)
173
+
174
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
175
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
176
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
177
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
178
+ plucker_embedding = plucker_embedding[None]
179
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
180
+ return plucker_embedding
181
+
182
+
183
+
184
+ def generate_camera_coordinates(
185
+ direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
186
+ length: int,
187
+ speed: float = 1/54,
188
+ origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
189
+ ):
190
+ coordinates = [list(origin)]
191
+ while len(coordinates) < length:
192
+ coor = coordinates[-1].copy()
193
+ if "Left" in direction:
194
+ coor[9] += speed
195
+ if "Right" in direction:
196
+ coor[9] -= speed
197
+ if "Up" in direction:
198
+ coor[13] += speed
199
+ if "Down" in direction:
200
+ coor[13] -= speed
201
+ coordinates.append(coor)
202
+ return coordinates
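A minimal sketch of how the pose helpers above combine into a Plücker embedding for the control adapter; the direction, frame count, and resolution are illustrative assumptions:

from dkt.models.wan_video_camera_controller import generate_camera_coordinates, process_pose_file

coords = generate_camera_coordinates("LeftUp", length=21)   # 21 camera entries, default speed 1/54
plucker = process_pose_file(coords, width=672, height=384)  # (F, H, W, 6) per the rearrange above
print(plucker.shape)                                        # torch.Size([21, 384, 672, 6])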
dkt/models/wan_video_dit.py ADDED
@@ -0,0 +1,719 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple, Optional
6
+ from einops import rearrange
7
+ from .utils import hash_state_dict_keys
8
+ from .wan_video_camera_controller import SimpleAdapter
9
+ try:
10
+ import flash_attn_interface
11
+ FLASH_ATTN_3_AVAILABLE = True
12
+ except ModuleNotFoundError:
13
+ FLASH_ATTN_3_AVAILABLE = False
14
+
15
+ try:
16
+ import flash_attn
17
+ FLASH_ATTN_2_AVAILABLE = True
18
+ except ModuleNotFoundError:
19
+ FLASH_ATTN_2_AVAILABLE = False
20
+
21
+ try:
22
+ from sageattention import sageattn
23
+ SAGE_ATTN_AVAILABLE = True
24
+ except ModuleNotFoundError:
25
+ SAGE_ATTN_AVAILABLE = False
26
+
27
+
28
+ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
29
+ if compatibility_mode:
30
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
31
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
32
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
33
+ x = F.scaled_dot_product_attention(q, k, v)
34
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
35
+ elif FLASH_ATTN_3_AVAILABLE:
36
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
37
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
38
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
39
+ x = flash_attn_interface.flash_attn_func(q, k, v)
40
+ if isinstance(x,tuple):
41
+ x = x[0]
42
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
43
+ elif FLASH_ATTN_2_AVAILABLE:
44
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
45
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
46
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
47
+ x = flash_attn.flash_attn_func(q, k, v)
48
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
49
+ elif SAGE_ATTN_AVAILABLE:
50
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
51
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
52
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
53
+ x = sageattn(q, k, v)
54
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
55
+ else:
56
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
57
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
58
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
59
+ x = F.scaled_dot_product_attention(q, k, v)
60
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
61
+ return x
62
+
63
+
64
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
65
+ return (x * (1 + scale) + shift)
66
+
67
+
68
+ def sinusoidal_embedding_1d(dim, position):
69
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
70
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
71
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
72
+ return x.to(position.dtype)
73
+
74
+
75
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
76
+ # 3d rope precompute
77
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
78
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
79
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
80
+ return f_freqs_cis, h_freqs_cis, w_freqs_cis
81
+
82
+
83
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
84
+ # 1d rope precompute
85
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
86
+ [: (dim // 2)].double() / dim))
87
+ freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
88
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
89
+ return freqs_cis
90
+
91
+
92
+ def rope_apply(x, freqs, num_heads):
93
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
94
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
95
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
96
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
97
+ return x_out.to(x.dtype)
98
+
99
+
100
+ class RMSNorm(nn.Module):
101
+ def __init__(self, dim, eps=1e-5):
102
+ super().__init__()
103
+ self.eps = eps
104
+ self.weight = nn.Parameter(torch.ones(dim))
105
+
106
+ def norm(self, x):
107
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
108
+
109
+ def forward(self, x):
110
+ dtype = x.dtype
111
+ return self.norm(x.float()).to(dtype) * self.weight
112
+
113
+
114
+ class AttentionModule(nn.Module):
115
+ def __init__(self, num_heads):
116
+ super().__init__()
117
+ self.num_heads = num_heads
118
+
119
+ def forward(self, q, k, v):
120
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
121
+ return x
122
+
123
+
124
+ class SelfAttention(nn.Module):
125
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
126
+ super().__init__()
127
+ self.dim = dim
128
+ self.num_heads = num_heads
129
+ self.head_dim = dim // num_heads
130
+
131
+ self.q = nn.Linear(dim, dim)
132
+ self.k = nn.Linear(dim, dim)
133
+ self.v = nn.Linear(dim, dim)
134
+ self.o = nn.Linear(dim, dim)
135
+ self.norm_q = RMSNorm(dim, eps=eps)
136
+ self.norm_k = RMSNorm(dim, eps=eps)
137
+
138
+ self.attn = AttentionModule(self.num_heads)
139
+
140
+ def forward(self, x, freqs):
141
+ q = self.norm_q(self.q(x))
142
+ k = self.norm_k(self.k(x))
143
+ v = self.v(x)
144
+ q = rope_apply(q, freqs, self.num_heads)
145
+ k = rope_apply(k, freqs, self.num_heads)
146
+ x = self.attn(q, k, v)
147
+ return self.o(x)
148
+
149
+
150
+ class CrossAttention(nn.Module):
151
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
152
+ super().__init__()
153
+ self.dim = dim
154
+ self.num_heads = num_heads
155
+ self.head_dim = dim // num_heads
156
+
157
+ self.q = nn.Linear(dim, dim)
158
+ self.k = nn.Linear(dim, dim)
159
+ self.v = nn.Linear(dim, dim)
160
+ self.o = nn.Linear(dim, dim)
161
+ self.norm_q = RMSNorm(dim, eps=eps)
162
+ self.norm_k = RMSNorm(dim, eps=eps)
163
+ self.has_image_input = has_image_input
164
+ if has_image_input:
165
+ self.k_img = nn.Linear(dim, dim)
166
+ self.v_img = nn.Linear(dim, dim)
167
+ self.norm_k_img = RMSNorm(dim, eps=eps)
168
+
169
+ self.attn = AttentionModule(self.num_heads)
170
+
171
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
172
+ if self.has_image_input:
173
+ img = y[:, :257]
174
+ ctx = y[:, 257:]
175
+ else:
176
+ ctx = y
177
+ q = self.norm_q(self.q(x))
178
+ k = self.norm_k(self.k(ctx))
179
+ v = self.v(ctx)
180
+ x = self.attn(q, k, v)
181
+ if self.has_image_input:
182
+ k_img = self.norm_k_img(self.k_img(img))
183
+ v_img = self.v_img(img)
184
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
185
+ x = x + y
186
+ return self.o(x)
187
+
188
+
189
+ class GateModule(nn.Module):
190
+ def __init__(self,):
191
+ super().__init__()
192
+
193
+ def forward(self, x, gate, residual):
194
+ return x + gate * residual
195
+
196
+ class DiTBlock(nn.Module):
197
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
198
+ super().__init__()
199
+ self.dim = dim
200
+ self.num_heads = num_heads
201
+ self.ffn_dim = ffn_dim
202
+
203
+ self.self_attn = SelfAttention(dim, num_heads, eps)
204
+ self.cross_attn = CrossAttention(
205
+ dim, num_heads, eps, has_image_input=has_image_input)
206
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
207
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
208
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
209
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
210
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
211
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
212
+ self.gate = GateModule()
213
+
214
+ def forward(self, x, context, t_mod, freqs):
215
+ has_seq = len(t_mod.shape) == 4
216
+ chunk_dim = 2 if has_seq else 1
217
+ # msa: multi-head self-attention; mlp: multi-layer perceptron
218
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
219
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
220
+ if has_seq:
221
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
222
+ shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
223
+ shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
224
+ )
225
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
226
+ x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
227
+ x = x + self.cross_attn(self.norm3(x), context)
228
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
229
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
230
+ return x
231
+
232
+
233
+ class MLP(torch.nn.Module):
234
+ def __init__(self, in_dim, out_dim, has_pos_emb=False):
235
+ super().__init__()
236
+ self.proj = torch.nn.Sequential(
237
+ nn.LayerNorm(in_dim),
238
+ nn.Linear(in_dim, in_dim),
239
+ nn.GELU(),
240
+ nn.Linear(in_dim, out_dim),
241
+ nn.LayerNorm(out_dim)
242
+ )
243
+ self.has_pos_emb = has_pos_emb
244
+ if has_pos_emb:
245
+ self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
246
+
247
+ def forward(self, x):
248
+ if self.has_pos_emb:
249
+ x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
250
+ return self.proj(x)
251
+
252
+
253
+ class Head(nn.Module):
254
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
255
+ super().__init__()
256
+ self.dim = dim
257
+ self.patch_size = patch_size
258
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
259
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
260
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
261
+
262
+ def forward(self, x, t_mod):
263
+ if len(t_mod.shape) == 3:
264
+ shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
265
+ x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
266
+ else:
267
+ shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
268
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
269
+ return x
270
+
271
+
272
+ class WanModel(torch.nn.Module):
273
+ def __init__(
274
+ self,
275
+ dim: int,
276
+ in_dim: int,
277
+ ffn_dim: int,
278
+ out_dim: int,
279
+ text_dim: int,
280
+ freq_dim: int,
281
+ eps: float,
282
+ patch_size: Tuple[int, int, int],
283
+ num_heads: int,
284
+ num_layers: int,
285
+ has_image_input: bool,
286
+ has_image_pos_emb: bool = False,
287
+ has_ref_conv: bool = False,
288
+ add_control_adapter: bool = False,
289
+ in_dim_control_adapter: int = 24,
290
+ seperated_timestep: bool = False,
291
+ require_vae_embedding: bool = True,
292
+ require_clip_embedding: bool = True,
293
+ fuse_vae_embedding_in_latents: bool = False,
294
+ ):
295
+ super().__init__()
296
+ self.dim = dim
297
+ self.freq_dim = freq_dim
298
+ self.has_image_input = has_image_input
299
+ self.patch_size = patch_size
300
+ self.seperated_timestep = seperated_timestep
301
+ self.require_vae_embedding = require_vae_embedding
302
+ self.require_clip_embedding = require_clip_embedding
303
+ self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
304
+
305
+ self.patch_embedding = nn.Conv3d(
306
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
307
+ self.text_embedding = nn.Sequential(
308
+ nn.Linear(text_dim, dim),
309
+ nn.GELU(approximate='tanh'),
310
+ nn.Linear(dim, dim)
311
+ )
312
+ self.time_embedding = nn.Sequential(
313
+ nn.Linear(freq_dim, dim),
314
+ nn.SiLU(),
315
+ nn.Linear(dim, dim)
316
+ )
317
+ self.time_projection = nn.Sequential(
318
+ nn.SiLU(), nn.Linear(dim, dim * 6))
319
+ self.blocks = nn.ModuleList([
320
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
321
+ for _ in range(num_layers)
322
+ ])
323
+ self.head = Head(dim, out_dim, patch_size, eps)
324
+ head_dim = dim // num_heads
325
+ self.freqs = precompute_freqs_cis_3d(head_dim)
326
+
327
+ if has_image_input:
328
+ self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
329
+ if has_ref_conv:
330
+ self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
331
+ self.has_image_pos_emb = has_image_pos_emb
332
+ self.has_ref_conv = has_ref_conv
333
+ if add_control_adapter:
334
+ self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
335
+ else:
336
+ self.control_adapter = None
337
+
338
+ def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
339
+
340
+ x = self.patch_embedding(x)  # e.g. from [1, 48, 21, 30, 40] to [1, 1536, 21, 15, 20]
341
+ if self.control_adapter is not None and control_camera_latents_input is not None:
342
+ y_camera = self.control_adapter(control_camera_latents_input)
343
+ x = [u + v for u, v in zip(x, y_camera)]
344
+ x = x[0].unsqueeze(0)
345
+ grid_size = x.shape[2:]  # the latent grid size (F, H, W)
346
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
347
+ return x, grid_size # x, grid_size: (f, h, w)
348
+
349
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
350
+ return rearrange(
351
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
352
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
353
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
354
+ )
355
+
356
+ def forward(self,
357
+ x: torch.Tensor,
358
+ timestep: torch.Tensor,
359
+ context: torch.Tensor,
360
+ clip_feature: Optional[torch.Tensor] = None,
361
+ y: Optional[torch.Tensor] = None,
362
+ use_gradient_checkpointing: bool = False,
363
+ use_gradient_checkpointing_offload: bool = False,
364
+ **kwargs,
365
+ ):
366
+ t = self.time_embedding(
367
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
368
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
369
+ context = self.text_embedding(context)
370
+
371
+ if self.has_image_input:
372
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
373
+ clip_embdding = self.img_emb(clip_feature)
374
+ context = torch.cat([clip_embdding, context], dim=1)
375
+
376
+ x, (f, h, w) = self.patchify(x)
377
+
378
+ freqs = torch.cat([
379
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
380
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
381
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
382
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
383
+
384
+ def create_custom_forward(module):
385
+ def custom_forward(*inputs):
386
+ return module(*inputs)
387
+ return custom_forward
388
+
389
+ for block in self.blocks:
390
+ if self.training and use_gradient_checkpointing:
391
+ if use_gradient_checkpointing_offload:
392
+ with torch.autograd.graph.save_on_cpu():
393
+ x = torch.utils.checkpoint.checkpoint(
394
+ create_custom_forward(block),
395
+ x, context, t_mod, freqs,
396
+ use_reentrant=False,
397
+ )
398
+ else:
399
+ x = torch.utils.checkpoint.checkpoint(
400
+ create_custom_forward(block),
401
+ x, context, t_mod, freqs,
402
+ use_reentrant=False,
403
+ )
404
+ else:
405
+ x = block(x, context, t_mod, freqs)
406
+
407
+ x = self.head(x, t)
408
+ x = self.unpatchify(x, (f, h, w))
409
+ return x
410
+
411
+ @staticmethod
412
+ def state_dict_converter():
413
+ return WanModelStateDictConverter()
414
+
415
+
416
+ class WanModelStateDictConverter:
417
+ def __init__(self):
418
+ pass
419
+
420
+ def from_diffusers(self, state_dict):
421
+ rename_dict = {
422
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
423
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
424
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
425
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
426
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
427
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
428
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
429
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
430
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
431
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
432
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
433
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
434
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
435
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
436
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
437
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
438
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
439
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
440
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
441
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
442
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
443
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
444
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
445
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
446
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
447
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
448
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
449
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
450
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
451
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
452
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
453
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
454
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
455
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
456
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
457
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
458
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
459
+ "patch_embedding.bias": "patch_embedding.bias",
460
+ "patch_embedding.weight": "patch_embedding.weight",
461
+ "scale_shift_table": "head.modulation",
462
+ "proj_out.bias": "head.head.bias",
463
+ "proj_out.weight": "head.head.weight",
464
+ }
465
+ state_dict_ = {}
466
+ for name, param in state_dict.items():
467
+ if name in rename_dict:
468
+ state_dict_[rename_dict[name]] = param
469
+ else:
470
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
471
+ if name_ in rename_dict:
472
+ name_ = rename_dict[name_]
473
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
474
+ state_dict_[name_] = param
475
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
476
+ config = {
477
+ "model_type": "t2v",
478
+ "patch_size": (1, 2, 2),
479
+ "text_len": 512,
480
+ "in_dim": 16,
481
+ "dim": 5120,
482
+ "ffn_dim": 13824,
483
+ "freq_dim": 256,
484
+ "text_dim": 4096,
485
+ "out_dim": 16,
486
+ "num_heads": 40,
487
+ "num_layers": 40,
488
+ "window_size": (-1, -1),
489
+ "qk_norm": True,
490
+ "cross_attn_norm": True,
491
+ "eps": 1e-6,
492
+ }
493
+ else:
494
+ config = {}
495
+ return state_dict_, config
496
+
497
+ def from_civitai(self, state_dict):
498
+ state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
499
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
500
+ config = {
501
+ "has_image_input": False,
502
+ "patch_size": [1, 2, 2],
503
+ "in_dim": 16,
504
+ "dim": 1536,
505
+ "ffn_dim": 8960,
506
+ "freq_dim": 256,
507
+ "text_dim": 4096,
508
+ "out_dim": 16,
509
+ "num_heads": 12,
510
+ "num_layers": 30,
511
+ "eps": 1e-6
512
+ }
513
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
514
+ config = {
515
+ "has_image_input": False,
516
+ "patch_size": [1, 2, 2],
517
+ "in_dim": 16,
518
+ "dim": 5120,
519
+ "ffn_dim": 13824,
520
+ "freq_dim": 256,
521
+ "text_dim": 4096,
522
+ "out_dim": 16,
523
+ "num_heads": 40,
524
+ "num_layers": 40,
525
+ "eps": 1e-6
526
+ }
527
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
528
+ config = {
529
+ "has_image_input": True,
530
+ "patch_size": [1, 2, 2],
531
+ "in_dim": 36,
532
+ "dim": 5120,
533
+ "ffn_dim": 13824,
534
+ "freq_dim": 256,
535
+ "text_dim": 4096,
536
+ "out_dim": 16,
537
+ "num_heads": 40,
538
+ "num_layers": 40,
539
+ "eps": 1e-6
540
+ }
541
+ elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
542
+ config = {
543
+ "has_image_input": True,
544
+ "patch_size": [1, 2, 2],
545
+ "in_dim": 36,
546
+ "dim": 1536,
547
+ "ffn_dim": 8960,
548
+ "freq_dim": 256,
549
+ "text_dim": 4096,
550
+ "out_dim": 16,
551
+ "num_heads": 12,
552
+ "num_layers": 30,
553
+ "eps": 1e-6
554
+ }
555
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
556
+ config = {
557
+ "has_image_input": True,
558
+ "patch_size": [1, 2, 2],
559
+ "in_dim": 36,
560
+ "dim": 5120,
561
+ "ffn_dim": 13824,
562
+ "freq_dim": 256,
563
+ "text_dim": 4096,
564
+ "out_dim": 16,
565
+ "num_heads": 40,
566
+ "num_layers": 40,
567
+ "eps": 1e-6
568
+ }
569
+ elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
570
+ # 1.3B PAI control
571
+ config = {
572
+ "has_image_input": True,
573
+ "patch_size": [1, 2, 2],
574
+ "in_dim": 48,
575
+ "dim": 1536,
576
+ "ffn_dim": 8960,
577
+ "freq_dim": 256,
578
+ "text_dim": 4096,
579
+ "out_dim": 16,
580
+ "num_heads": 12,
581
+ "num_layers": 30,
582
+ "eps": 1e-6
583
+ }
584
+ elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
585
+ # 14B PAI control
586
+ config = {
587
+ "has_image_input": True,
588
+ "patch_size": [1, 2, 2],
589
+ "in_dim": 48,
590
+ "dim": 5120,
591
+ "ffn_dim": 13824,
592
+ "freq_dim": 256,
593
+ "text_dim": 4096,
594
+ "out_dim": 16,
595
+ "num_heads": 40,
596
+ "num_layers": 40,
597
+ "eps": 1e-6
598
+ }
599
+ elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
600
+ config = {
601
+ "has_image_input": True,
602
+ "patch_size": [1, 2, 2],
603
+ "in_dim": 36,
604
+ "dim": 5120,
605
+ "ffn_dim": 13824,
606
+ "freq_dim": 256,
607
+ "text_dim": 4096,
608
+ "out_dim": 16,
609
+ "num_heads": 40,
610
+ "num_layers": 40,
611
+ "eps": 1e-6,
612
+ "has_image_pos_emb": True
613
+ }
614
+ elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
615
+ # 1.3B PAI control v1.1
616
+ config = {
617
+ "has_image_input": True,
618
+ "patch_size": [1, 2, 2],
619
+ "in_dim": 48,
620
+ "dim": 1536,
621
+ "ffn_dim": 8960,
622
+ "freq_dim": 256,
623
+ "text_dim": 4096,
624
+ "out_dim": 16,
625
+ "num_heads": 12,
626
+ "num_layers": 30,
627
+ "eps": 1e-6,
628
+ "has_ref_conv": True
629
+ }
630
+ elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
631
+ # 14B PAI control v1.1
632
+ config = {
633
+ "has_image_input": True,
634
+ "patch_size": [1, 2, 2],
635
+ "in_dim": 48,
636
+ "dim": 5120,
637
+ "ffn_dim": 13824,
638
+ "freq_dim": 256,
639
+ "text_dim": 4096,
640
+ "out_dim": 16,
641
+ "num_heads": 40,
642
+ "num_layers": 40,
643
+ "eps": 1e-6,
644
+ "has_ref_conv": True
645
+ }
646
+ elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
647
+ # 1.3B PAI control-camera v1.1
648
+ config = {
649
+ "has_image_input": True,
650
+ "patch_size": [1, 2, 2],
651
+ "in_dim": 32,
652
+ "dim": 1536,
653
+ "ffn_dim": 8960,
654
+ "freq_dim": 256,
655
+ "text_dim": 4096,
656
+ "out_dim": 16,
657
+ "num_heads": 12,
658
+ "num_layers": 30,
659
+ "eps": 1e-6,
660
+ "has_ref_conv": False,
661
+ "add_control_adapter": True,
662
+ "in_dim_control_adapter": 24,
663
+ }
664
+ elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
665
+ # 14B PAI control-camera v1.1
666
+ config = {
667
+ "has_image_input": True,
668
+ "patch_size": [1, 2, 2],
669
+ "in_dim": 32,
670
+ "dim": 5120,
671
+ "ffn_dim": 13824,
672
+ "freq_dim": 256,
673
+ "text_dim": 4096,
674
+ "out_dim": 16,
675
+ "num_heads": 40,
676
+ "num_layers": 40,
677
+ "eps": 1e-6,
678
+ "has_ref_conv": False,
679
+ "add_control_adapter": True,
680
+ "in_dim_control_adapter": 24,
681
+ }
682
+ elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
683
+ # Wan-AI/Wan2.2-TI2V-5B
684
+ config = {
685
+ "has_image_input": False,
686
+ "patch_size": [1, 2, 2],
687
+ "in_dim": 48,
688
+ "dim": 3072,
689
+ "ffn_dim": 14336,
690
+ "freq_dim": 256,
691
+ "text_dim": 4096,
692
+ "out_dim": 48,
693
+ "num_heads": 24,
694
+ "num_layers": 30,
695
+ "eps": 1e-6,
696
+ "seperated_timestep": True,
697
+ "require_clip_embedding": False,
698
+ "require_vae_embedding": False,
699
+ "fuse_vae_embedding_in_latents": True,
700
+ }
701
+ elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
702
+ # Wan-AI/Wan2.2-I2V-A14B
703
+ config = {
704
+ "has_image_input": False,
705
+ "patch_size": [1, 2, 2],
706
+ "in_dim": 36,
707
+ "dim": 5120,
708
+ "ffn_dim": 13824,
709
+ "freq_dim": 256,
710
+ "text_dim": 4096,
711
+ "out_dim": 16,
712
+ "num_heads": 40,
713
+ "num_layers": 40,
714
+ "eps": 1e-6,
715
+ "require_clip_embedding": False,
716
+ }
717
+ else:
718
+ config = {}
719
+ return state_dict, config
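To make the forward contract of WanModel concrete, here is a toy-sized instantiation; the dimensions are deliberately tiny and do not correspond to any of the checkpoint configs listed above. Note that flash_attention falls back to PyTorch's scaled_dot_product_attention when neither flash-attn nor sageattention is installed, so this sketch runs on CPU:

import torch
from dkt.models.wan_video_dit import WanModel  # assumed module path

model = WanModel(
    dim=128, in_dim=16, ffn_dim=256, out_dim=16, text_dim=32, freq_dim=32,
    eps=1e-6, patch_size=(1, 2, 2), num_heads=4, num_layers=2, has_image_input=False,
)
latents = torch.randn(1, 16, 5, 16, 16)        # (B, C, F, H, W) video latents
timestep = torch.tensor([500.0])
context = torch.randn(1, 77, 32)               # (B, L, text_dim) text embeddings
out = model(latents, timestep=timestep, context=context)
print(out.shape)                               # torch.Size([1, 16, 5, 16, 16]), unpatchified back to the input shape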
dkt/models/wan_video_image_encoder.py ADDED
@@ -0,0 +1,902 @@
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from .wan_video_dit import flash_attention
12
+
13
+
14
+ class SelfAttention(nn.Module):
15
+
16
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
17
+ assert dim % num_heads == 0
18
+ super().__init__()
19
+ self.dim = dim
20
+ self.num_heads = num_heads
21
+ self.head_dim = dim // num_heads
22
+ self.eps = eps
23
+
24
+ # layers
25
+ self.q = nn.Linear(dim, dim)
26
+ self.k = nn.Linear(dim, dim)
27
+ self.v = nn.Linear(dim, dim)
28
+ self.o = nn.Linear(dim, dim)
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ def forward(self, x, mask):
32
+ """
33
+ x: [B, L, C].
34
+ """
35
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
36
+
37
+ # compute query, key, value
38
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
39
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
40
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
41
+
42
+ # compute attention
43
+ p = self.dropout.p if self.training else 0.0
44
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
45
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
46
+
47
+ # output
48
+ x = self.o(x)
49
+ x = self.dropout(x)
50
+ return x
51
+
52
+
53
+ class AttentionBlock(nn.Module):
54
+
55
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.num_heads = num_heads
59
+ self.post_norm = post_norm
60
+ self.eps = eps
61
+
62
+ # layers
63
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
64
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
65
+ self.ffn = nn.Sequential(
66
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
67
+ nn.Dropout(dropout))
68
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
69
+
70
+ def forward(self, x, mask):
71
+ if self.post_norm:
72
+ x = self.norm1(x + self.attn(x, mask))
73
+ x = self.norm2(x + self.ffn(x))
74
+ else:
75
+ x = x + self.attn(self.norm1(x), mask)
76
+ x = x + self.ffn(self.norm2(x))
77
+ return x
78
+
79
+
80
+ class XLMRoberta(nn.Module):
81
+ """
82
+ XLMRobertaModel with no pooler and no LM head.
83
+ """
84
+
85
+ def __init__(self,
86
+ vocab_size=250002,
87
+ max_seq_len=514,
88
+ type_size=1,
89
+ pad_id=1,
90
+ dim=1024,
91
+ num_heads=16,
92
+ num_layers=24,
93
+ post_norm=True,
94
+ dropout=0.1,
95
+ eps=1e-5):
96
+ super().__init__()
97
+ self.vocab_size = vocab_size
98
+ self.max_seq_len = max_seq_len
99
+ self.type_size = type_size
100
+ self.pad_id = pad_id
101
+ self.dim = dim
102
+ self.num_heads = num_heads
103
+ self.num_layers = num_layers
104
+ self.post_norm = post_norm
105
+ self.eps = eps
106
+
107
+ # embeddings
108
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
109
+ self.type_embedding = nn.Embedding(type_size, dim)
110
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
111
+ self.dropout = nn.Dropout(dropout)
112
+
113
+ # blocks
114
+ self.blocks = nn.ModuleList([
115
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
116
+ for _ in range(num_layers)
117
+ ])
118
+
119
+ # norm layer
120
+ self.norm = nn.LayerNorm(dim, eps=eps)
121
+
122
+ def forward(self, ids):
123
+ """
124
+ ids: [B, L] of torch.LongTensor.
125
+ """
126
+ b, s = ids.shape
127
+ mask = ids.ne(self.pad_id).long()
128
+
129
+ # embeddings
130
+ x = self.token_embedding(ids) + \
131
+ self.type_embedding(torch.zeros_like(ids)) + \
132
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
133
+ if self.post_norm:
134
+ x = self.norm(x)
135
+ x = self.dropout(x)
136
+
137
+ # blocks
138
+ mask = torch.where(
139
+ mask.view(b, 1, 1, s).gt(0), 0.0,
140
+ torch.finfo(x.dtype).min)
141
+ for block in self.blocks:
142
+ x = block(x, mask)
143
+
144
+ # output
145
+ if not self.post_norm:
146
+ x = self.norm(x)
147
+ return x
148
+
149
+
150
+ def xlm_roberta_large(pretrained=False,
151
+ return_tokenizer=False,
152
+ device='cpu',
153
+ **kwargs):
154
+ """
155
+ XLMRobertaLarge adapted from Huggingface.
156
+ """
157
+ # params
158
+ cfg = dict(
159
+ vocab_size=250002,
160
+ max_seq_len=514,
161
+ type_size=1,
162
+ pad_id=1,
163
+ dim=1024,
164
+ num_heads=16,
165
+ num_layers=24,
166
+ post_norm=True,
167
+ dropout=0.1,
168
+ eps=1e-5)
169
+ cfg.update(**kwargs)
170
+
171
+ # init model
172
+ if pretrained:
173
+ from sora import DOWNLOAD_TO_CACHE
174
+
175
+ # init a meta model
176
+ with torch.device('meta'):
177
+ model = XLMRoberta(**cfg)
178
+
179
+ # load checkpoint
180
+ model.load_state_dict(
181
+ torch.load(
182
+ DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
183
+ map_location=device),
184
+ assign=True)
185
+ else:
186
+ # init a model on device
187
+ with torch.device(device):
188
+ model = XLMRoberta(**cfg)
189
+
190
+ # init tokenizer
191
+ if return_tokenizer:
192
+ from sora.data import HuggingfaceTokenizer
193
+ tokenizer = HuggingfaceTokenizer(
194
+ name='xlm-roberta-large',
195
+ seq_len=model.text_len,
196
+ clean='whitespace')
197
+ return model, tokenizer
198
+ else:
199
+ return model
200
+
201
+
202
+
203
+ def pos_interpolate(pos, seq_len):
204
+ if pos.size(1) == seq_len:
205
+ return pos
206
+ else:
207
+ src_grid = int(math.sqrt(pos.size(1)))
208
+ tar_grid = int(math.sqrt(seq_len))
209
+ n = pos.size(1) - src_grid * src_grid
210
+ return torch.cat([
211
+ pos[:, :n],
212
+ F.interpolate(
213
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
214
+ 0, 3, 1, 2),
215
+ size=(tar_grid, tar_grid),
216
+ mode='bicubic',
217
+ align_corners=False).flatten(2).transpose(1, 2)
218
+ ],
219
+ dim=1)
220
+
221
+
222
+ class QuickGELU(nn.Module):
223
+
224
+ def forward(self, x):
225
+ return x * torch.sigmoid(1.702 * x)
226
+
227
+
228
+ class LayerNorm(nn.LayerNorm):
229
+
230
+ def forward(self, x):
231
+ return super().forward(x).type_as(x)
232
+
233
+
234
+ class SelfAttention(nn.Module):
235
+
236
+ def __init__(self,
237
+ dim,
238
+ num_heads,
239
+ causal=False,
240
+ attn_dropout=0.0,
241
+ proj_dropout=0.0):
242
+ assert dim % num_heads == 0
243
+ super().__init__()
244
+ self.dim = dim
245
+ self.num_heads = num_heads
246
+ self.head_dim = dim // num_heads
247
+ self.causal = causal
248
+ self.attn_dropout = attn_dropout
249
+ self.proj_dropout = proj_dropout
250
+
251
+ # layers
252
+ self.to_qkv = nn.Linear(dim, dim * 3)
253
+ self.proj = nn.Linear(dim, dim)
254
+
255
+ def forward(self, x):
256
+ """
257
+ x: [B, L, C].
258
+ """
259
+ # compute query, key, value
260
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
261
+
262
+ # compute attention
263
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
264
+
265
+ # output
266
+ x = self.proj(x)
267
+ x = F.dropout(x, self.proj_dropout, self.training)
268
+ return x
269
+
270
+
271
+ class SwiGLU(nn.Module):
272
+
273
+ def __init__(self, dim, mid_dim):
274
+ super().__init__()
275
+ self.dim = dim
276
+ self.mid_dim = mid_dim
277
+
278
+ # layers
279
+ self.fc1 = nn.Linear(dim, mid_dim)
280
+ self.fc2 = nn.Linear(dim, mid_dim)
281
+ self.fc3 = nn.Linear(mid_dim, dim)
282
+
283
+ def forward(self, x):
284
+ x = F.silu(self.fc1(x)) * self.fc2(x)
285
+ x = self.fc3(x)
286
+ return x
287
+
288
+
289
+ class AttentionBlock(nn.Module):
290
+
291
+ def __init__(self,
292
+ dim,
293
+ mlp_ratio,
294
+ num_heads,
295
+ post_norm=False,
296
+ causal=False,
297
+ activation='quick_gelu',
298
+ attn_dropout=0.0,
299
+ proj_dropout=0.0,
300
+ norm_eps=1e-5):
301
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
302
+ super().__init__()
303
+ self.dim = dim
304
+ self.mlp_ratio = mlp_ratio
305
+ self.num_heads = num_heads
306
+ self.post_norm = post_norm
307
+ self.causal = causal
308
+ self.norm_eps = norm_eps
309
+
310
+ # layers
311
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
312
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
313
+ proj_dropout)
314
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
315
+ if activation == 'swi_glu':
316
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
317
+ else:
318
+ self.mlp = nn.Sequential(
319
+ nn.Linear(dim, int(dim * mlp_ratio)),
320
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
321
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
322
+
323
+ def forward(self, x):
324
+ if self.post_norm:
325
+ x = x + self.norm1(self.attn(x))
326
+ x = x + self.norm2(self.mlp(x))
327
+ else:
328
+ x = x + self.attn(self.norm1(x))
329
+ x = x + self.mlp(self.norm2(x))
330
+ return x
331
+
332
+
333
+ class AttentionPool(nn.Module):
334
+
335
+ def __init__(self,
336
+ dim,
337
+ mlp_ratio,
338
+ num_heads,
339
+ activation='gelu',
340
+ proj_dropout=0.0,
341
+ norm_eps=1e-5):
342
+ assert dim % num_heads == 0
343
+ super().__init__()
344
+ self.dim = dim
345
+ self.mlp_ratio = mlp_ratio
346
+ self.num_heads = num_heads
347
+ self.head_dim = dim // num_heads
348
+ self.proj_dropout = proj_dropout
349
+ self.norm_eps = norm_eps
350
+
351
+ # layers
352
+ gain = 1.0 / math.sqrt(dim)
353
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
354
+ self.to_q = nn.Linear(dim, dim)
355
+ self.to_kv = nn.Linear(dim, dim * 2)
356
+ self.proj = nn.Linear(dim, dim)
357
+ self.norm = LayerNorm(dim, eps=norm_eps)
358
+ self.mlp = nn.Sequential(
359
+ nn.Linear(dim, int(dim * mlp_ratio)),
360
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
361
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
362
+
363
+ def forward(self, x):
364
+ """
365
+ x: [B, L, C].
366
+ """
367
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
368
+
369
+ # compute query, key, value
370
+ q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
371
+ k, v = self.to_kv(x).chunk(2, dim=-1)
372
+
373
+ # compute attention
374
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
375
+ x = x.reshape(b, 1, c)
376
+
377
+ # output
378
+ x = self.proj(x)
379
+ x = F.dropout(x, self.proj_dropout, self.training)
380
+
381
+ # mlp
382
+ x = x + self.mlp(self.norm(x))
383
+ return x[:, 0]
384
+
385
+
386
+ class VisionTransformer(nn.Module):
387
+
388
+ def __init__(self,
389
+ image_size=224,
390
+ patch_size=16,
391
+ dim=768,
392
+ mlp_ratio=4,
393
+ out_dim=512,
394
+ num_heads=12,
395
+ num_layers=12,
396
+ pool_type='token',
397
+ pre_norm=True,
398
+ post_norm=False,
399
+ activation='quick_gelu',
400
+ attn_dropout=0.0,
401
+ proj_dropout=0.0,
402
+ embedding_dropout=0.0,
403
+ norm_eps=1e-5):
404
+ if image_size % patch_size != 0:
405
+ print(
406
+ '[WARNING] image_size is not divisible by patch_size',
407
+ flush=True)
408
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
409
+ out_dim = out_dim or dim
410
+ super().__init__()
411
+ self.image_size = image_size
412
+ self.patch_size = patch_size
413
+ self.num_patches = (image_size // patch_size)**2
414
+ self.dim = dim
415
+ self.mlp_ratio = mlp_ratio
416
+ self.out_dim = out_dim
417
+ self.num_heads = num_heads
418
+ self.num_layers = num_layers
419
+ self.pool_type = pool_type
420
+ self.post_norm = post_norm
421
+ self.norm_eps = norm_eps
422
+
423
+ # embeddings
424
+ gain = 1.0 / math.sqrt(dim)
425
+ self.patch_embedding = nn.Conv2d(
426
+ 3,
427
+ dim,
428
+ kernel_size=patch_size,
429
+ stride=patch_size,
430
+ bias=not pre_norm)
431
+ if pool_type in ('token', 'token_fc'):
432
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
433
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
434
+ 1, self.num_patches +
435
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
436
+ self.dropout = nn.Dropout(embedding_dropout)
437
+
438
+ # transformer
439
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
440
+ self.transformer = nn.Sequential(*[
441
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
442
+ activation, attn_dropout, proj_dropout, norm_eps)
443
+ for _ in range(num_layers)
444
+ ])
445
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
446
+
447
+ # head
448
+ if pool_type == 'token':
449
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
450
+ elif pool_type == 'token_fc':
451
+ self.head = nn.Linear(dim, out_dim)
452
+ elif pool_type == 'attn_pool':
453
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
454
+ proj_dropout, norm_eps)
455
+
456
+ def forward(self, x, interpolation=False, use_31_block=False):
457
+ b = x.size(0)
458
+
459
+ # embeddings
460
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
461
+ if self.pool_type in ('token', 'token_fc'):
462
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
463
+ if interpolation:
464
+ e = pos_interpolate(self.pos_embedding, x.size(1))
465
+ else:
466
+ e = self.pos_embedding
467
+ e = e.to(dtype=x.dtype, device=x.device)
468
+ x = self.dropout(x + e)
469
+ if self.pre_norm is not None:
470
+ x = self.pre_norm(x)
471
+
472
+ # transformer
473
+ if use_31_block:
474
+ x = self.transformer[:-1](x)
475
+ return x
476
+ else:
477
+ x = self.transformer(x)
478
+ return x
479
+
480
+
481
+ class CLIP(nn.Module):
482
+
483
+ def __init__(self,
484
+ embed_dim=512,
485
+ image_size=224,
486
+ patch_size=16,
487
+ vision_dim=768,
488
+ vision_mlp_ratio=4,
489
+ vision_heads=12,
490
+ vision_layers=12,
491
+ vision_pool='token',
492
+ vision_pre_norm=True,
493
+ vision_post_norm=False,
494
+ vocab_size=49408,
495
+ text_len=77,
496
+ text_dim=512,
497
+ text_mlp_ratio=4,
498
+ text_heads=8,
499
+ text_layers=12,
500
+ text_causal=True,
501
+ text_pool='argmax',
502
+ text_head_bias=False,
503
+ logit_bias=None,
504
+ activation='quick_gelu',
505
+ attn_dropout=0.0,
506
+ proj_dropout=0.0,
507
+ embedding_dropout=0.0,
508
+ norm_eps=1e-5):
509
+ super().__init__()
510
+ self.embed_dim = embed_dim
511
+ self.image_size = image_size
512
+ self.patch_size = patch_size
513
+ self.vision_dim = vision_dim
514
+ self.vision_mlp_ratio = vision_mlp_ratio
515
+ self.vision_heads = vision_heads
516
+ self.vision_layers = vision_layers
517
+ self.vision_pool = vision_pool
518
+ self.vision_pre_norm = vision_pre_norm
519
+ self.vision_post_norm = vision_post_norm
520
+ self.vocab_size = vocab_size
521
+ self.text_len = text_len
522
+ self.text_dim = text_dim
523
+ self.text_mlp_ratio = text_mlp_ratio
524
+ self.text_heads = text_heads
525
+ self.text_layers = text_layers
526
+ self.text_causal = text_causal
527
+ self.text_pool = text_pool
528
+ self.text_head_bias = text_head_bias
529
+ self.norm_eps = norm_eps
530
+
531
+ # models
532
+ self.visual = VisionTransformer(
533
+ image_size=image_size,
534
+ patch_size=patch_size,
535
+ dim=vision_dim,
536
+ mlp_ratio=vision_mlp_ratio,
537
+ out_dim=embed_dim,
538
+ num_heads=vision_heads,
539
+ num_layers=vision_layers,
540
+ pool_type=vision_pool,
541
+ pre_norm=vision_pre_norm,
542
+ post_norm=vision_post_norm,
543
+ activation=activation,
544
+ attn_dropout=attn_dropout,
545
+ proj_dropout=proj_dropout,
546
+ embedding_dropout=embedding_dropout,
547
+ norm_eps=norm_eps)
548
+ self.textual = TextTransformer(
549
+ vocab_size=vocab_size,
550
+ text_len=text_len,
551
+ dim=text_dim,
552
+ mlp_ratio=text_mlp_ratio,
553
+ out_dim=embed_dim,
554
+ num_heads=text_heads,
555
+ num_layers=text_layers,
556
+ causal=text_causal,
557
+ pool_type=text_pool,
558
+ head_bias=text_head_bias,
559
+ activation=activation,
560
+ attn_dropout=attn_dropout,
561
+ proj_dropout=proj_dropout,
562
+ embedding_dropout=embedding_dropout,
563
+ norm_eps=norm_eps)
564
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
565
+ if logit_bias is not None:
566
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
567
+
568
+ # initialize weights
569
+ self.init_weights()
570
+
571
+ def forward(self, imgs, txt_ids):
572
+ """
573
+ imgs: [B, 3, H, W] of torch.float32.
574
+ - mean: [0.48145466, 0.4578275, 0.40821073]
575
+ - std: [0.26862954, 0.26130258, 0.27577711]
576
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
577
+ """
578
+ xi = self.visual(imgs)
579
+ xt = self.textual(txt_ids)
580
+ return xi, xt
581
+
582
+ def init_weights(self):
583
+ # embeddings
584
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
585
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
586
+
587
+ # attentions
588
+ for modality in ['visual', 'textual']:
589
+ dim = self.vision_dim if modality == 'visual' else self.text_dim
590
+ transformer = getattr(self, modality).transformer
591
+ proj_gain = (1.0 / math.sqrt(dim)) * (
592
+ 1.0 / math.sqrt(2 * len(transformer)))
593
+ attn_gain = 1.0 / math.sqrt(dim)
594
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
595
+ for block in transformer:
596
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
597
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
598
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
599
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
600
+
601
+ def param_groups(self):
602
+ groups = [{
603
+ 'params': [
604
+ p for n, p in self.named_parameters()
605
+ if 'norm' in n or n.endswith('bias')
606
+ ],
607
+ 'weight_decay': 0.0
608
+ }, {
609
+ 'params': [
610
+ p for n, p in self.named_parameters()
611
+ if not ('norm' in n or n.endswith('bias'))
612
+ ]
613
+ }]
614
+ return groups
615
+
616
+
617
+ class XLMRobertaWithHead(XLMRoberta):
618
+
619
+ def __init__(self, **kwargs):
620
+ self.out_dim = kwargs.pop('out_dim')
621
+ super().__init__(**kwargs)
622
+
623
+ # head
624
+ mid_dim = (self.dim + self.out_dim) // 2
625
+ self.head = nn.Sequential(
626
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
627
+ nn.Linear(mid_dim, self.out_dim, bias=False))
628
+
629
+ def forward(self, ids):
630
+ # xlm-roberta
631
+ x = super().forward(ids)
632
+
633
+ # average pooling
634
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
635
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
636
+
637
+ # head
638
+ x = self.head(x)
639
+ return x
640
+
641
+
642
+ class XLMRobertaCLIP(nn.Module):
643
+
644
+ def __init__(self,
645
+ embed_dim=1024,
646
+ image_size=224,
647
+ patch_size=14,
648
+ vision_dim=1280,
649
+ vision_mlp_ratio=4,
650
+ vision_heads=16,
651
+ vision_layers=32,
652
+ vision_pool='token',
653
+ vision_pre_norm=True,
654
+ vision_post_norm=False,
655
+ activation='gelu',
656
+ vocab_size=250002,
657
+ max_text_len=514,
658
+ type_size=1,
659
+ pad_id=1,
660
+ text_dim=1024,
661
+ text_heads=16,
662
+ text_layers=24,
663
+ text_post_norm=True,
664
+ text_dropout=0.1,
665
+ attn_dropout=0.0,
666
+ proj_dropout=0.0,
667
+ embedding_dropout=0.0,
668
+ norm_eps=1e-5):
669
+ super().__init__()
670
+ self.embed_dim = embed_dim
671
+ self.image_size = image_size
672
+ self.patch_size = patch_size
673
+ self.vision_dim = vision_dim
674
+ self.vision_mlp_ratio = vision_mlp_ratio
675
+ self.vision_heads = vision_heads
676
+ self.vision_layers = vision_layers
677
+ self.vision_pre_norm = vision_pre_norm
678
+ self.vision_post_norm = vision_post_norm
679
+ self.activation = activation
680
+ self.vocab_size = vocab_size
681
+ self.max_text_len = max_text_len
682
+ self.type_size = type_size
683
+ self.pad_id = pad_id
684
+ self.text_dim = text_dim
685
+ self.text_heads = text_heads
686
+ self.text_layers = text_layers
687
+ self.text_post_norm = text_post_norm
688
+ self.norm_eps = norm_eps
689
+
690
+ # models
691
+ self.visual = VisionTransformer(
692
+ image_size=image_size,
693
+ patch_size=patch_size,
694
+ dim=vision_dim,
695
+ mlp_ratio=vision_mlp_ratio,
696
+ out_dim=embed_dim,
697
+ num_heads=vision_heads,
698
+ num_layers=vision_layers,
699
+ pool_type=vision_pool,
700
+ pre_norm=vision_pre_norm,
701
+ post_norm=vision_post_norm,
702
+ activation=activation,
703
+ attn_dropout=attn_dropout,
704
+ proj_dropout=proj_dropout,
705
+ embedding_dropout=embedding_dropout,
706
+ norm_eps=norm_eps)
707
+ self.textual = None
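+ # note: the text tower is deliberately dropped here; WanImageEncoder below only uses
+ # the visual branch (see encode_image), so this forward() path is not expected to run.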
708
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
709
+
710
+ def forward(self, imgs, txt_ids):
711
+ """
712
+ imgs: [B, 3, H, W] of torch.float32.
713
+ - mean: [0.48145466, 0.4578275, 0.40821073]
714
+ - std: [0.26862954, 0.26130258, 0.27577711]
715
+ txt_ids: [B, L] of torch.long.
716
+ Encoded by data.CLIPTokenizer.
717
+ """
718
+ xi = self.visual(imgs)
719
+ xt = self.textual(txt_ids)
720
+ return xi, xt
721
+
722
+ def param_groups(self):
723
+ groups = [{
724
+ 'params': [
725
+ p for n, p in self.named_parameters()
726
+ if 'norm' in n or n.endswith('bias')
727
+ ],
728
+ 'weight_decay': 0.0
729
+ }, {
730
+ 'params': [
731
+ p for n, p in self.named_parameters()
732
+ if not ('norm' in n or n.endswith('bias'))
733
+ ]
734
+ }]
735
+ return groups
736
+
737
+
738
+ def _clip(pretrained=False,
739
+ pretrained_name=None,
740
+ model_cls=CLIP,
741
+ return_transforms=False,
742
+ return_tokenizer=False,
743
+ tokenizer_padding='eos',
744
+ dtype=torch.float32,
745
+ device='cpu',
746
+ **kwargs):
747
+ # init model
748
+ if pretrained and pretrained_name:
749
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
750
+
751
+ # init a meta model
752
+ with torch.device('meta'):
753
+ model = model_cls(**kwargs)
754
+
755
+ # checkpoint path
756
+ checkpoint = f'models/clip/{pretrained_name}'
757
+ if dtype in (torch.float16, torch.bfloat16):
758
+ suffix = '-' + {
759
+ torch.float16: 'fp16',
760
+ torch.bfloat16: 'bf16'
761
+ }[dtype]
762
+ if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
763
+ checkpoint = f'{checkpoint}{suffix}'
764
+ checkpoint += '.pth'
765
+
766
+ # load
767
+ model.load_state_dict(
768
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
769
+ assign=True,
770
+ strict=False)
771
+ else:
772
+ # init a model on device
773
+ with torch.device(device):
774
+ model = model_cls(**kwargs)
775
+
776
+ # set device
777
+ output = (model,)
778
+
779
+ # init transforms
780
+ if return_transforms:
781
+ # mean and std
782
+ if 'siglip' in pretrained_name.lower():
783
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
784
+ else:
785
+ mean = [0.48145466, 0.4578275, 0.40821073]
786
+ std = [0.26862954, 0.26130258, 0.27577711]
787
+
788
+ # transforms
789
+ transforms = T.Compose([
790
+ T.Resize((model.image_size, model.image_size),
791
+ interpolation=T.InterpolationMode.BICUBIC),
792
+ T.ToTensor(),
793
+ T.Normalize(mean=mean, std=std)
794
+ ])
795
+ output += (transforms,)
796
+
797
+ # init tokenizer
798
+ if return_tokenizer:
799
+ from sora import data
800
+ if 'siglip' in pretrained_name.lower():
801
+ tokenizer = data.HuggingfaceTokenizer(
802
+ name=f'timm/{pretrained_name}',
803
+ seq_len=model.text_len,
804
+ clean='canonicalize')
805
+ elif 'xlm' in pretrained_name.lower():
806
+ tokenizer = data.HuggingfaceTokenizer(
807
+ name='xlm-roberta-large',
808
+ seq_len=model.max_text_len - 2,
809
+ clean='whitespace')
810
+ elif 'mba' in pretrained_name.lower():
811
+ tokenizer = data.HuggingfaceTokenizer(
812
+ name='facebook/xlm-roberta-xl',
813
+ seq_len=model.max_text_len - 2,
814
+ clean='whitespace')
815
+ else:
816
+ tokenizer = data.CLIPTokenizer(
817
+ seq_len=model.text_len, padding=tokenizer_padding)
818
+ output += (tokenizer,)
819
+ return output[0] if len(output) == 1 else output
820
+
821
+
822
+ def clip_xlm_roberta_vit_h_14(
823
+ pretrained=False,
824
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
825
+ **kwargs):
826
+ cfg = dict(
827
+ embed_dim=1024,
828
+ image_size=224,
829
+ patch_size=14,
830
+ vision_dim=1280,
831
+ vision_mlp_ratio=4,
832
+ vision_heads=16,
833
+ vision_layers=32,
834
+ vision_pool='token',
835
+ activation='gelu',
836
+ vocab_size=250002,
837
+ max_text_len=514,
838
+ type_size=1,
839
+ pad_id=1,
840
+ text_dim=1024,
841
+ text_heads=16,
842
+ text_layers=24,
843
+ text_post_norm=True,
844
+ text_dropout=0.1,
845
+ attn_dropout=0.0,
846
+ proj_dropout=0.0,
847
+ embedding_dropout=0.0)
848
+ cfg.update(**kwargs)
849
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
850
+
851
+
852
+ class WanImageEncoder(torch.nn.Module):
853
+
854
+ def __init__(self):
855
+ super().__init__()
856
+ # init model
857
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
858
+ pretrained=False,
859
+ return_transforms=True,
860
+ return_tokenizer=False,
861
+ dtype=torch.float32,
862
+ device="cpu")
863
+
864
+ def encode_image(self, videos):
865
+ # preprocess
866
+ size = (self.model.image_size,) * 2
867
+ videos = torch.cat([
868
+ F.interpolate(
869
+ u,
870
+ size=size,
871
+ mode='bicubic',
872
+ align_corners=False) for u in videos
873
+ ])
874
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
875
+
876
+ # forward
877
+ dtype = next(iter(self.model.visual.parameters())).dtype
878
+ videos = videos.to(dtype)
879
+ out = self.model.visual(videos, use_31_block=True)
880
+ return out
881
+
882
+ @staticmethod
883
+ def state_dict_converter():
884
+ return WanImageEncoderStateDictConverter()
885
+
886
+
887
+ class WanImageEncoderStateDictConverter:
888
+ def __init__(self):
889
+ pass
890
+
891
+ def from_diffusers(self, state_dict):
892
+ return state_dict
893
+
894
+ def from_civitai(self, state_dict):
895
+ state_dict_ = {}
896
+ for name, param in state_dict.items():
897
+ if name.startswith("textual."):
898
+ continue
899
+ name = "model." + name
900
+ state_dict_[name] = param
901
+ return state_dict_
902
+
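A minimal usage sketch for the image encoder above (illustrative only; the checkpoint filename and the input resolution are assumptions, not files shipped in this commit):

import torch
from dkt.models.wan_video_image_encoder import WanImageEncoder

encoder = WanImageEncoder().eval()
# Weights are normally loaded elsewhere in the repo; an assumed flow would be:
# raw = torch.load("open-clip-xlm-roberta-large-vit-huge-14.pth", map_location="cpu")
# encoder.load_state_dict(encoder.state_dict_converter().from_civitai(raw), strict=False)

image = torch.rand(1, 3, 480, 832) * 2 - 1    # pixel values in [-1, 1]
with torch.no_grad():
    features = encoder.encode_image([image])  # roughly [1, 257, 1280] for ViT-H/14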
dkt/models/wan_video_motion_controller.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .wan_video_dit import sinusoidal_embedding_1d
4
+
5
+
6
+
7
+ class WanMotionControllerModel(torch.nn.Module):
8
+ def __init__(self, freq_dim=256, dim=1536):
9
+ super().__init__()
10
+ self.freq_dim = freq_dim
11
+ self.linear = nn.Sequential(
12
+ nn.Linear(freq_dim, dim),
13
+ nn.SiLU(),
14
+ nn.Linear(dim, dim),
15
+ nn.SiLU(),
16
+ nn.Linear(dim, dim * 6),
17
+ )
18
+
19
+ def forward(self, motion_bucket_id):
20
+ emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
21
+ emb = self.linear(emb)
22
+ return emb
23
+
24
+ def init(self):
25
+ state_dict = self.linear[-1].state_dict()
26
+ state_dict = {i: state_dict[i] * 0 for i in state_dict}
27
+ self.linear[-1].load_state_dict(state_dict)
28
+
29
+ @staticmethod
30
+ def state_dict_converter():
31
+ return WanMotionControllerModelDictConverter()
32
+
33
+
34
+
35
+ class WanMotionControllerModelDictConverter:
36
+ def __init__(self):
37
+ pass
38
+
39
+ def from_diffusers(self, state_dict):
40
+ return state_dict
41
+
42
+ def from_civitai(self, state_dict):
43
+ return state_dict
44
+
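A minimal usage sketch for the motion controller above (illustrative; the bucket value is an assumption, and the output shape follows from freq_dim and dim):

import torch
from dkt.models.wan_video_motion_controller import WanMotionControllerModel

controller = WanMotionControllerModel(freq_dim=256, dim=1536)
controller.init()                        # zero the last layer so the controller starts as a no-op

motion_bucket_id = torch.tensor([2.0])   # assumed per-sample motion strength
emb = controller(motion_bucket_id)       # expected shape: [1, 1536 * 6]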
dkt/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,269 @@
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
+ class T5SelfAttention(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ dim_attn,
118
+ dim_ffn,
119
+ num_heads,
120
+ num_buckets,
121
+ shared_pos=True,
122
+ dropout=0.1):
123
+ super(T5SelfAttention, self).__init__()
124
+ self.dim = dim
125
+ self.dim_attn = dim_attn
126
+ self.dim_ffn = dim_ffn
127
+ self.num_heads = num_heads
128
+ self.num_buckets = num_buckets
129
+ self.shared_pos = shared_pos
130
+
131
+ # layers
132
+ self.norm1 = T5LayerNorm(dim)
133
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
134
+ self.norm2 = T5LayerNorm(dim)
135
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
136
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
137
+ num_buckets, num_heads, bidirectional=True)
138
+
139
+ def forward(self, x, mask=None, pos_bias=None):
140
+ e = pos_bias if self.shared_pos else self.pos_embedding(
141
+ x.size(1), x.size(1))
142
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
143
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
144
+ return x
145
+
146
+
147
+ class T5RelativeEmbedding(nn.Module):
148
+
149
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
150
+ super(T5RelativeEmbedding, self).__init__()
151
+ self.num_buckets = num_buckets
152
+ self.num_heads = num_heads
153
+ self.bidirectional = bidirectional
154
+ self.max_dist = max_dist
155
+
156
+ # layers
157
+ self.embedding = nn.Embedding(num_buckets, num_heads)
158
+
159
+ def forward(self, lq, lk):
160
+ device = self.embedding.weight.device
161
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
162
+ # torch.arange(lq).unsqueeze(1).to(device)
163
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
164
+ torch.arange(lq, device=device).unsqueeze(1)
165
+ rel_pos = self._relative_position_bucket(rel_pos)
166
+ rel_pos_embeds = self.embedding(rel_pos)
167
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
168
+ 0) # [1, N, Lq, Lk]
169
+ return rel_pos_embeds.contiguous()
170
+
171
+ def _relative_position_bucket(self, rel_pos):
172
+ # preprocess
173
+ if self.bidirectional:
174
+ num_buckets = self.num_buckets // 2
175
+ rel_buckets = (rel_pos > 0).long() * num_buckets
176
+ rel_pos = torch.abs(rel_pos)
177
+ else:
178
+ num_buckets = self.num_buckets
179
+ rel_buckets = 0
180
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
181
+
182
+ # embeddings for small and large positions
183
+ max_exact = num_buckets // 2
184
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
185
+ math.log(self.max_dist / max_exact) *
186
+ (num_buckets - max_exact)).long()
187
+ rel_pos_large = torch.min(
188
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
189
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
190
+ return rel_buckets
191
+
192
+ def init_weights(m):
193
+ if isinstance(m, T5LayerNorm):
194
+ nn.init.ones_(m.weight)
195
+ elif isinstance(m, T5FeedForward):
196
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
197
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
198
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
199
+ elif isinstance(m, T5Attention):
200
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
201
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
202
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
203
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
204
+ elif isinstance(m, T5RelativeEmbedding):
205
+ nn.init.normal_(
206
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
+ class WanTextEncoder(torch.nn.Module):
210
+
211
+ def __init__(self,
212
+ vocab=256384,
213
+ dim=4096,
214
+ dim_attn=4096,
215
+ dim_ffn=10240,
216
+ num_heads=64,
217
+ num_layers=24,
218
+ num_buckets=32,
219
+ shared_pos=False,
220
+ dropout=0.1):
221
+ super(WanTextEncoder, self).__init__()
222
+ self.dim = dim
223
+ self.dim_attn = dim_attn
224
+ self.dim_ffn = dim_ffn
225
+ self.num_heads = num_heads
226
+ self.num_layers = num_layers
227
+ self.num_buckets = num_buckets
228
+ self.shared_pos = shared_pos
229
+
230
+ # layers
231
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
232
+ else nn.Embedding(vocab, dim)
233
+ self.pos_embedding = T5RelativeEmbedding(
234
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
235
+ self.dropout = nn.Dropout(dropout)
236
+ self.blocks = nn.ModuleList([
237
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
238
+ shared_pos, dropout) for _ in range(num_layers)
239
+ ])
240
+ self.norm = T5LayerNorm(dim)
241
+
242
+ # initialize weights
243
+ self.apply(init_weights)
244
+
245
+ def forward(self, ids, mask=None):
246
+ x = self.token_embedding(ids)
247
+ x = self.dropout(x)
248
+ e = self.pos_embedding(x.size(1),
249
+ x.size(1)) if self.shared_pos else None
250
+ for block in self.blocks:
251
+ x = block(x, mask, pos_bias=e)
252
+ x = self.norm(x)
253
+ x = self.dropout(x)
254
+ return x
255
+
256
+ @staticmethod
257
+ def state_dict_converter():
258
+ return WanTextEncoderStateDictConverter()
259
+
260
+
261
+ class WanTextEncoderStateDictConverter:
262
+ def __init__(self):
263
+ pass
264
+
265
+ def from_diffusers(self, state_dict):
266
+ return state_dict
267
+
268
+ def from_civitai(self, state_dict):
269
+ return state_dict
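A minimal usage sketch for the text encoder above, using a deliberately tiny configuration so it runs on CPU in seconds (the shipped checkpoint uses the defaults in the class):

import torch
from dkt.models.wan_video_text_encoder import WanTextEncoder

encoder = WanTextEncoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                         num_heads=4, num_layers=2, num_buckets=8).eval()
ids = torch.randint(0, 1000, (1, 12))        # token ids from a tokenizer (assumed)
mask = torch.ones(1, 12, dtype=torch.long)   # 1 = real token, 0 = padding
with torch.no_grad():
    context = encoder(ids, mask)             # [1, 12, 64]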
dkt/models/wan_video_vace.py ADDED
@@ -0,0 +1,113 @@
 
 
1
+ import torch
2
+ from .wan_video_dit import DiTBlock
3
+ from .utils import hash_state_dict_keys
4
+
5
+ class VaceWanAttentionBlock(DiTBlock):
6
+ def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
7
+ super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
8
+ self.block_id = block_id
9
+ if block_id == 0:
10
+ self.before_proj = torch.nn.Linear(self.dim, self.dim)
11
+ self.after_proj = torch.nn.Linear(self.dim, self.dim)
12
+
13
+ def forward(self, c, x, context, t_mod, freqs):
14
+ if self.block_id == 0:
15
+ c = self.before_proj(c) + x
16
+ all_c = []
17
+ else:
18
+ all_c = list(torch.unbind(c))
19
+ c = all_c.pop(-1)
20
+ c = super().forward(c, context, t_mod, freqs)
21
+ c_skip = self.after_proj(c)
22
+ all_c += [c_skip, c]
23
+ c = torch.stack(all_c)
24
+ return c
25
+
26
+
27
+ class VaceWanModel(torch.nn.Module):
28
+ def __init__(
29
+ self,
30
+ vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
31
+ vace_in_dim=96,
32
+ patch_size=(1, 2, 2),
33
+ has_image_input=False,
34
+ dim=1536,
35
+ num_heads=12,
36
+ ffn_dim=8960,
37
+ eps=1e-6,
38
+ ):
39
+ super().__init__()
40
+ self.vace_layers = vace_layers
41
+ self.vace_in_dim = vace_in_dim
42
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
43
+
44
+ # vace blocks
45
+ self.vace_blocks = torch.nn.ModuleList([
46
+ VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
47
+ for i in self.vace_layers
48
+ ])
49
+
50
+ # vace patch embeddings
51
+ self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
52
+
53
+ def forward(
54
+ self, x, vace_context, context, t_mod, freqs,
55
+ use_gradient_checkpointing: bool = False,
56
+ use_gradient_checkpointing_offload: bool = False,
57
+ ):
58
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
59
+ c = [u.flatten(2).transpose(1, 2) for u in c]
60
+ c = torch.cat([
61
+ torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
62
+ dim=1) for u in c
63
+ ])
64
+
65
+ def create_custom_forward(module):
66
+ def custom_forward(*inputs):
67
+ return module(*inputs)
68
+ return custom_forward
69
+
70
+ for block in self.vace_blocks:
71
+ if use_gradient_checkpointing_offload:
72
+ with torch.autograd.graph.save_on_cpu():
73
+ c = torch.utils.checkpoint.checkpoint(
74
+ create_custom_forward(block),
75
+ c, x, context, t_mod, freqs,
76
+ use_reentrant=False,
77
+ )
78
+ elif use_gradient_checkpointing:
79
+ c = torch.utils.checkpoint.checkpoint(
80
+ create_custom_forward(block),
81
+ c, x, context, t_mod, freqs,
82
+ use_reentrant=False,
83
+ )
84
+ else:
85
+ c = block(c, x, context, t_mod, freqs)
86
+ hints = torch.unbind(c)[:-1]
87
+ return hints
88
+
89
+ @staticmethod
90
+ def state_dict_converter():
91
+ return VaceWanModelDictConverter()
92
+
93
+
94
+ class VaceWanModelDictConverter:
95
+ def __init__(self):
96
+ pass
97
+
98
+ def from_civitai(self, state_dict):
99
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
100
+ if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B
101
+ config = {
102
+ "vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
103
+ "vace_in_dim": 96,
104
+ "patch_size": (1, 2, 2),
105
+ "has_image_input": False,
106
+ "dim": 5120,
107
+ "num_heads": 40,
108
+ "ffn_dim": 13824,
109
+ "eps": 1e-06,
110
+ }
111
+ else:
112
+ config = {}
113
+ return state_dict_, config
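A minimal sketch of how the VACE state-dict converter above might be used (the checkpoint filename is an assumption; only the "vace" prefix filtering and the hash-based config lookup come from the code):

import torch
from dkt.models.wan_video_vace import VaceWanModel

converter = VaceWanModel.state_dict_converter()
# raw = torch.load("Wan2.1-VACE-14B.ckpt", map_location="cpu")   # assumed file name
# vace_state_dict, config = converter.from_civitai(raw)
# model = VaceWanModel(**config) if config else VaceWanModel()
# model.load_state_dict(vace_state_dict, strict=False)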
dkt/models/wan_video_vae.py ADDED
@@ -0,0 +1,1376 @@
 
 
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :,
29
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
30
+ return mask
31
+
32
+
33
+ class CausalConv3d(nn.Conv3d):
34
+ """
35
+ Causal 3D convolution.
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
41
+ self.padding[1], 2 * self.padding[0], 0)
42
+ self.padding = (0, 0, 0)
43
+
44
+ def forward(self, x, cache_x=None):
45
+ padding = list(self._padding)
46
+ if cache_x is not None and self._padding[4] > 0:
47
+ cache_x = cache_x.to(x.device)
48
+ x = torch.cat([cache_x, x], dim=2)
49
+ padding[4] -= cache_x.shape[2]
50
+ x = F.pad(x, padding)
51
+
52
+ return super().forward(x)
53
+
54
+
55
+ class RMS_norm(nn.Module):
56
+
57
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
58
+ super().__init__()
59
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
60
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
61
+
62
+ self.channel_first = channel_first
63
+ self.scale = dim**0.5
64
+ self.gamma = nn.Parameter(torch.ones(shape))
65
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
66
+
67
+ def forward(self, x):
68
+ return F.normalize(
69
+ x, dim=(1 if self.channel_first else
70
+ -1)) * self.scale * self.gamma + self.bias
71
+
72
+
73
+ class Upsample(nn.Upsample):
74
+
75
+ def forward(self, x):
76
+ """
77
+ Fix bfloat16 support for nearest neighbor interpolation.
78
+ """
79
+ return super().forward(x.float()).type_as(x)
80
+
81
+
82
+ class Resample(nn.Module):
83
+
84
+ def __init__(self, dim, mode):
85
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
86
+ 'downsample3d')
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.mode = mode
90
+
91
+ # layers
92
+ if mode == 'upsample2d':
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
96
+ elif mode == 'upsample3d':
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
99
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
100
+ self.time_conv = CausalConv3d(dim,
101
+ dim * 2, (3, 1, 1),
102
+ padding=(1, 0, 0))
103
+
104
+ elif mode == 'downsample2d':
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == 'downsample3d':
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(dim,
113
+ dim, (3, 1, 1),
114
+ stride=(2, 1, 1),
115
+ padding=(0, 0, 0))
116
+
117
+ else:
118
+ self.resample = nn.Identity()
119
+
120
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
121
+ b, c, t, h, w = x.size()
122
+ if self.mode == 'upsample3d':
123
+ if feat_cache is not None:
124
+ idx = feat_idx[0]
125
+ if feat_cache[idx] is None:
126
+ feat_cache[idx] = 'Rep'
127
+ feat_idx[0] += 1
128
+ else:
129
+
130
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
131
+ if cache_x.shape[2] < 2 and feat_cache[
132
+ idx] is not None and feat_cache[idx] != 'Rep':
133
+ # cache the last frame of the last two chunks
134
+ cache_x = torch.cat([
135
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
136
+ cache_x.device), cache_x
137
+ ],
138
+ dim=2)
139
+ if cache_x.shape[2] < 2 and feat_cache[
140
+ idx] is not None and feat_cache[idx] == 'Rep':
141
+ cache_x = torch.cat([
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2)
146
+ if feat_cache[idx] == 'Rep':
147
+ x = self.time_conv(x)
148
+ else:
149
+ x = self.time_conv(x, feat_cache[idx])
150
+ feat_cache[idx] = cache_x
151
+ feat_idx[0] += 1
152
+
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
159
+ x = self.resample(x)
160
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
161
+
162
+ if self.mode == 'downsample3d':
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
184
+ conv.weight.data.copy_(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight.data.copy_(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+
199
+ def patchify(x, patch_size):
200
+ if patch_size == 1:
201
+ return x
202
+ if x.dim() == 4:
203
+ x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
204
+ elif x.dim() == 5:
205
+ x = rearrange(x,
206
+ "b c f (h q) (w r) -> b (c r q) f h w",
207
+ q=patch_size,
208
+ r=patch_size)
209
+ else:
210
+ raise ValueError(f"Invalid input shape: {x.shape}")
211
+ return x
212
+
213
+
214
+ def unpatchify(x, patch_size):
215
+ if patch_size == 1:
216
+ return x
217
+ if x.dim() == 4:
218
+ x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
219
+ elif x.dim() == 5:
220
+ x = rearrange(x,
221
+ "b (c r q) f h w -> b c f (h q) (w r)",
222
+ q=patch_size,
223
+ r=patch_size)
224
+ return x
225
+
226
+
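+ # Illustrative shape example (assumed sizes): with patch_size=2, patchify folds each 2x2
+ # spatial patch into channels, e.g. [1, 16, 9, 64, 64] -> [1, 64, 9, 32, 32];
+ # unpatchify with the same patch_size reverses the rearrangement exactly.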
227
+ class Resample38(Resample):
228
+
229
+ def __init__(self, dim, mode):
230
+ assert mode in (
231
+ "none",
232
+ "upsample2d",
233
+ "upsample3d",
234
+ "downsample2d",
235
+ "downsample3d",
236
+ )
237
+ super(Resample, self).__init__()
238
+ self.dim = dim
239
+ self.mode = mode
240
+
241
+ # layers
242
+ if mode == "upsample2d":
243
+ self.resample = nn.Sequential(
244
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
245
+ nn.Conv2d(dim, dim, 3, padding=1),
246
+ )
247
+ elif mode == "upsample3d":
248
+ self.resample = nn.Sequential(
249
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
250
+ nn.Conv2d(dim, dim, 3, padding=1),
251
+ )
252
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
253
+ elif mode == "downsample2d":
254
+ self.resample = nn.Sequential(
255
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
256
+ )
257
+ elif mode == "downsample3d":
258
+ self.resample = nn.Sequential(
259
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
260
+ )
261
+ self.time_conv = CausalConv3d(
262
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
263
+ )
264
+ else:
265
+ self.resample = nn.Identity()
266
+
267
+ class ResidualBlock(nn.Module):
268
+
269
+ def __init__(self, in_dim, out_dim, dropout=0.0):
270
+ super().__init__()
271
+ self.in_dim = in_dim
272
+ self.out_dim = out_dim
273
+
274
+ # layers
275
+ self.residual = nn.Sequential(
276
+ RMS_norm(in_dim, images=False), nn.SiLU(),
277
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
278
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
279
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
280
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
281
+ if in_dim != out_dim else nn.Identity()
282
+
283
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
284
+ h = self.shortcut(x)
285
+ for layer in self.residual:
286
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
287
+ idx = feat_idx[0]
288
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
289
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
290
+ # cache the last frame of the last two chunks
291
+ cache_x = torch.cat([
292
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
293
+ cache_x.device), cache_x
294
+ ],
295
+ dim=2)
296
+ x = layer(x, feat_cache[idx])
297
+ feat_cache[idx] = cache_x
298
+ feat_idx[0] += 1
299
+ else:
300
+ x = layer(x)
301
+ return x + h
302
+
303
+
304
+ class AttentionBlock(nn.Module):
305
+ """
306
+ Causal self-attention with a single head.
307
+ """
308
+
309
+ def __init__(self, dim):
310
+ super().__init__()
311
+ self.dim = dim
312
+
313
+ # layers
314
+ self.norm = RMS_norm(dim)
315
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
316
+ self.proj = nn.Conv2d(dim, dim, 1)
317
+
318
+ # zero out the last layer params
319
+ nn.init.zeros_(self.proj.weight)
320
+
321
+ def forward(self, x):
322
+ identity = x
323
+ b, c, t, h, w = x.size()
324
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
325
+ x = self.norm(x)
326
+ # compute query, key, value
327
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
328
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
329
+
330
+ # apply attention
331
+ x = F.scaled_dot_product_attention(
332
+ q,
333
+ k,
334
+ v,
335
+ #attn_mask=block_causal_mask(q, block_size=h * w)
336
+ )
337
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
338
+
339
+ # output
340
+ x = self.proj(x)
341
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
342
+ return x + identity
343
+
344
+
345
+ class AvgDown3D(nn.Module):
346
+ def __init__(
347
+ self,
348
+ in_channels,
349
+ out_channels,
350
+ factor_t,
351
+ factor_s=1,
352
+ ):
353
+ super().__init__()
354
+ self.in_channels = in_channels
355
+ self.out_channels = out_channels
356
+ self.factor_t = factor_t
357
+ self.factor_s = factor_s
358
+ self.factor = self.factor_t * self.factor_s * self.factor_s
359
+
360
+ assert in_channels * self.factor % out_channels == 0
361
+ self.group_size = in_channels * self.factor // out_channels
362
+
363
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
364
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
365
+ pad = (0, 0, 0, 0, pad_t, 0)
366
+ x = F.pad(x, pad)
367
+ B, C, T, H, W = x.shape
368
+ x = x.view(
369
+ B,
370
+ C,
371
+ T // self.factor_t,
372
+ self.factor_t,
373
+ H // self.factor_s,
374
+ self.factor_s,
375
+ W // self.factor_s,
376
+ self.factor_s,
377
+ )
378
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
379
+ x = x.view(
380
+ B,
381
+ C * self.factor,
382
+ T // self.factor_t,
383
+ H // self.factor_s,
384
+ W // self.factor_s,
385
+ )
386
+ x = x.view(
387
+ B,
388
+ self.out_channels,
389
+ self.group_size,
390
+ T // self.factor_t,
391
+ H // self.factor_s,
392
+ W // self.factor_s,
393
+ )
394
+ x = x.mean(dim=2)
395
+ return x
396
+
397
+
398
+ class DupUp3D(nn.Module):
399
+ def __init__(
400
+ self,
401
+ in_channels: int,
402
+ out_channels: int,
403
+ factor_t,
404
+ factor_s=1,
405
+ ):
406
+ super().__init__()
407
+ self.in_channels = in_channels
408
+ self.out_channels = out_channels
409
+
410
+ self.factor_t = factor_t
411
+ self.factor_s = factor_s
412
+ self.factor = self.factor_t * self.factor_s * self.factor_s
413
+
414
+ assert out_channels * self.factor % in_channels == 0
415
+ self.repeats = out_channels * self.factor // in_channels
416
+
417
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
418
+ x = x.repeat_interleave(self.repeats, dim=1)
419
+ x = x.view(
420
+ x.size(0),
421
+ self.out_channels,
422
+ self.factor_t,
423
+ self.factor_s,
424
+ self.factor_s,
425
+ x.size(2),
426
+ x.size(3),
427
+ x.size(4),
428
+ )
429
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
430
+ x = x.view(
431
+ x.size(0),
432
+ self.out_channels,
433
+ x.size(2) * self.factor_t,
434
+ x.size(4) * self.factor_s,
435
+ x.size(6) * self.factor_s,
436
+ )
437
+ if first_chunk:
438
+ x = x[:, :, self.factor_t - 1 :, :, :]
439
+ return x
440
+
441
+
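+ # Illustrative shape example (assumed sizes): AvgDown3D(96, 192, factor_t=2, factor_s=2)
+ # averages 2x2x2 neighbourhoods into grouped channels, e.g. [1, 96, 8, 64, 64] -> [1, 192, 4, 32, 32];
+ # DupUp3D is the matching duplication-based upsampling used in the decoder shortcut path.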
442
+ class Down_ResidualBlock(nn.Module):
443
+ def __init__(
444
+ self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
445
+ ):
446
+ super().__init__()
447
+
448
+ # Shortcut path with downsample
449
+ self.avg_shortcut = AvgDown3D(
450
+ in_dim,
451
+ out_dim,
452
+ factor_t=2 if temperal_downsample else 1,
453
+ factor_s=2 if down_flag else 1,
454
+ )
455
+
456
+ # Main path with residual blocks and downsample
457
+ downsamples = []
458
+ for _ in range(mult):
459
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
460
+ in_dim = out_dim
461
+
462
+ # Add the final downsample block
463
+ if down_flag:
464
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
465
+ downsamples.append(Resample38(out_dim, mode=mode))
466
+
467
+ self.downsamples = nn.Sequential(*downsamples)
468
+
469
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
470
+ x_copy = x.clone()
471
+ for module in self.downsamples:
472
+ x = module(x, feat_cache, feat_idx)
473
+
474
+ return x + self.avg_shortcut(x_copy)
475
+
476
+
477
+ class Up_ResidualBlock(nn.Module):
478
+ def __init__(
479
+ self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
480
+ ):
481
+ super().__init__()
482
+ # Shortcut path with upsample
483
+ if up_flag:
484
+ self.avg_shortcut = DupUp3D(
485
+ in_dim,
486
+ out_dim,
487
+ factor_t=2 if temperal_upsample else 1,
488
+ factor_s=2 if up_flag else 1,
489
+ )
490
+ else:
491
+ self.avg_shortcut = None
492
+
493
+ # Main path with residual blocks and upsample
494
+ upsamples = []
495
+ for _ in range(mult):
496
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
497
+ in_dim = out_dim
498
+
499
+ # Add the final upsample block
500
+ if up_flag:
501
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
502
+ upsamples.append(Resample38(out_dim, mode=mode))
503
+
504
+ self.upsamples = nn.Sequential(*upsamples)
505
+
506
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
507
+ x_main = x.clone()
508
+ for module in self.upsamples:
509
+ x_main = module(x_main, feat_cache, feat_idx)
510
+ if self.avg_shortcut is not None:
511
+ x_shortcut = self.avg_shortcut(x, first_chunk)
512
+ return x_main + x_shortcut
513
+ else:
514
+ return x_main
515
+
516
+
517
+ class Encoder3d(nn.Module):
518
+
519
+ def __init__(self,
520
+ dim=128,
521
+ z_dim=4,
522
+ dim_mult=[1, 2, 4, 4],
523
+ num_res_blocks=2,
524
+ attn_scales=[],
525
+ temperal_downsample=[True, True, False],
526
+ dropout=0.0):
527
+ super().__init__()
528
+ self.dim = dim
529
+ self.z_dim = z_dim
530
+ self.dim_mult = dim_mult
531
+ self.num_res_blocks = num_res_blocks
532
+ self.attn_scales = attn_scales
533
+ self.temperal_downsample = temperal_downsample
534
+
535
+ # dimensions
536
+ dims = [dim * u for u in [1] + dim_mult]
537
+ scale = 1.0
538
+
539
+ # init block
540
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
541
+
542
+ # downsample blocks
543
+ downsamples = []
544
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
545
+ # residual (+attention) blocks
546
+ for _ in range(num_res_blocks):
547
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
548
+ if scale in attn_scales:
549
+ downsamples.append(AttentionBlock(out_dim))
550
+ in_dim = out_dim
551
+
552
+ # downsample block
553
+ if i != len(dim_mult) - 1:
554
+ mode = 'downsample3d' if temperal_downsample[
555
+ i] else 'downsample2d'
556
+ downsamples.append(Resample(out_dim, mode=mode))
557
+ scale /= 2.0
558
+ self.downsamples = nn.Sequential(*downsamples)
559
+
560
+ # middle blocks
561
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
562
+ AttentionBlock(out_dim),
563
+ ResidualBlock(out_dim, out_dim, dropout))
564
+
565
+ # output blocks
566
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
567
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
568
+
569
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
570
+ if feat_cache is not None:
571
+ idx = feat_idx[0]
572
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
573
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
574
+ # cache the last frame of the last two chunks
575
+ cache_x = torch.cat([
576
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
577
+ cache_x.device), cache_x
578
+ ],
579
+ dim=2)
580
+ x = self.conv1(x, feat_cache[idx])
581
+ feat_cache[idx] = cache_x
582
+ feat_idx[0] += 1
583
+ else:
584
+ x = self.conv1(x)
585
+
586
+ ## downsamples
587
+ for layer in self.downsamples:
588
+ if feat_cache is not None:
589
+ x = layer(x, feat_cache, feat_idx)
590
+ else:
591
+ x = layer(x)
592
+
593
+ ## middle
594
+ for layer in self.middle:
595
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
596
+ x = layer(x, feat_cache, feat_idx)
597
+ else:
598
+ x = layer(x)
599
+
600
+ ## head
601
+ for layer in self.head:
602
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
603
+ idx = feat_idx[0]
604
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
605
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
606
+ # cache the last frame of the last two chunks
607
+ cache_x = torch.cat([
608
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
609
+ cache_x.device), cache_x
610
+ ],
611
+ dim=2)
612
+ x = layer(x, feat_cache[idx])
613
+ feat_cache[idx] = cache_x
614
+ feat_idx[0] += 1
615
+ else:
616
+ x = layer(x)
617
+ return x
618
+
619
+
620
+ class Encoder3d_38(nn.Module):
621
+
622
+ def __init__(self,
623
+ dim=128,
624
+ z_dim=4,
625
+ dim_mult=[1, 2, 4, 4],
626
+ num_res_blocks=2,
627
+ attn_scales=[],
628
+ temperal_downsample=[False, True, True],
629
+ dropout=0.0):
630
+ super().__init__()
631
+ self.dim = dim
632
+ self.z_dim = z_dim
633
+ self.dim_mult = dim_mult
634
+ self.num_res_blocks = num_res_blocks
635
+ self.attn_scales = attn_scales
636
+ self.temperal_downsample = temperal_downsample
637
+
638
+ # dimensions
639
+ dims = [dim * u for u in [1] + dim_mult]
640
+ scale = 1.0
641
+
642
+ # init block
643
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
644
+
645
+ # downsample blocks
646
+ downsamples = []
647
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
648
+ t_down_flag = (
649
+ temperal_downsample[i] if i < len(temperal_downsample) else False
650
+ )
651
+ downsamples.append(
652
+ Down_ResidualBlock(
653
+ in_dim=in_dim,
654
+ out_dim=out_dim,
655
+ dropout=dropout,
656
+ mult=num_res_blocks,
657
+ temperal_downsample=t_down_flag,
658
+ down_flag=i != len(dim_mult) - 1,
659
+ )
660
+ )
661
+ scale /= 2.0
662
+ self.downsamples = nn.Sequential(*downsamples)
663
+
664
+ # middle blocks
665
+ self.middle = nn.Sequential(
666
+ ResidualBlock(out_dim, out_dim, dropout),
667
+ AttentionBlock(out_dim),
668
+ ResidualBlock(out_dim, out_dim, dropout),
669
+ )
670
+
671
+ # # output blocks
672
+ self.head = nn.Sequential(
673
+ RMS_norm(out_dim, images=False),
674
+ nn.SiLU(),
675
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
676
+ )
677
+
678
+
679
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
680
+
681
+ if feat_cache is not None:
682
+ idx = feat_idx[0]
683
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
684
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
685
+ cache_x = torch.cat(
686
+ [
687
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
688
+ cache_x,
689
+ ],
690
+ dim=2,
691
+ )
692
+ x = self.conv1(x, feat_cache[idx])
693
+ feat_cache[idx] = cache_x
694
+ feat_idx[0] += 1
695
+ else:
696
+ x = self.conv1(x)
697
+
698
+ ## downsamples
699
+ for layer in self.downsamples:
700
+ if feat_cache is not None:
701
+ x = layer(x, feat_cache, feat_idx)
702
+ else:
703
+ x = layer(x)
704
+
705
+ ## middle
706
+ for layer in self.middle:
707
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
708
+ x = layer(x, feat_cache, feat_idx)
709
+ else:
710
+ x = layer(x)
711
+
712
+ ## head
713
+ for layer in self.head:
714
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
715
+ idx = feat_idx[0]
716
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
717
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
718
+ cache_x = torch.cat(
719
+ [
720
+ feat_cache[idx][:, :, -1, :, :]
721
+ .unsqueeze(2)
722
+ .to(cache_x.device),
723
+ cache_x,
724
+ ],
725
+ dim=2,
726
+ )
727
+ x = layer(x, feat_cache[idx])
728
+ feat_cache[idx] = cache_x
729
+ feat_idx[0] += 1
730
+ else:
731
+ x = layer(x)
732
+
733
+ return x
734
+
735
+
736
+ class Decoder3d(nn.Module):
737
+
738
+ def __init__(self,
739
+ dim=128,
740
+ z_dim=4,
741
+ dim_mult=[1, 2, 4, 4],
742
+ num_res_blocks=2,
743
+ attn_scales=[],
744
+ temperal_upsample=[False, True, True],
745
+ dropout=0.0):
746
+ super().__init__()
747
+ self.dim = dim
748
+ self.z_dim = z_dim
749
+ self.dim_mult = dim_mult
750
+ self.num_res_blocks = num_res_blocks
751
+ self.attn_scales = attn_scales
752
+ self.temperal_upsample = temperal_upsample
753
+
754
+ # dimensions
755
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
756
+ scale = 1.0 / 2**(len(dim_mult) - 2)
757
+
758
+ # init block
759
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
760
+
761
+ # middle blocks
762
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
763
+ AttentionBlock(dims[0]),
764
+ ResidualBlock(dims[0], dims[0], dropout))
765
+
766
+ # upsample blocks
767
+ upsamples = []
768
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
769
+ # residual (+attention) blocks
770
+ if i == 1 or i == 2 or i == 3:
771
+ in_dim = in_dim // 2
772
+ for _ in range(num_res_blocks + 1):
773
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
774
+ if scale in attn_scales:
775
+ upsamples.append(AttentionBlock(out_dim))
776
+ in_dim = out_dim
777
+
778
+ # upsample block
779
+ if i != len(dim_mult) - 1:
780
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
781
+ upsamples.append(Resample(out_dim, mode=mode))
782
+ scale *= 2.0
783
+ self.upsamples = nn.Sequential(*upsamples)
784
+
785
+ # output blocks
786
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
787
+ CausalConv3d(out_dim, 3, 3, padding=1))
788
+
789
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
790
+ ## conv1
791
+ if feat_cache is not None:
792
+ idx = feat_idx[0]
793
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
794
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
795
+ # cache the last frame of the last two chunks
796
+ cache_x = torch.cat([
797
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
798
+ cache_x.device), cache_x
799
+ ],
800
+ dim=2)
801
+ x = self.conv1(x, feat_cache[idx])
802
+ feat_cache[idx] = cache_x
803
+ feat_idx[0] += 1
804
+ else:
805
+ x = self.conv1(x)
806
+
807
+ ## middle
808
+ for layer in self.middle:
809
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
810
+ x = layer(x, feat_cache, feat_idx)
811
+ else:
812
+ x = layer(x)
813
+
814
+ ## upsamples
815
+ for layer in self.upsamples:
816
+ if feat_cache is not None:
817
+ x = layer(x, feat_cache, feat_idx)
818
+ else:
819
+ x = layer(x)
820
+
821
+ ## head
822
+ for layer in self.head:
823
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
824
+ idx = feat_idx[0]
825
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
826
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
827
+ # cache the last frame of the last two chunks
828
+ cache_x = torch.cat([
829
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
830
+ cache_x.device), cache_x
831
+ ],
832
+ dim=2)
833
+ x = layer(x, feat_cache[idx])
834
+ feat_cache[idx] = cache_x
835
+ feat_idx[0] += 1
836
+ else:
837
+ x = layer(x)
838
+ return x
839
+
840
+
841
+
842
+ class Decoder3d_38(nn.Module):
843
+
844
+ def __init__(self,
845
+ dim=128,
846
+ z_dim=4,
847
+ dim_mult=[1, 2, 4, 4],
848
+ num_res_blocks=2,
849
+ attn_scales=[],
850
+ temperal_upsample=[False, True, True],
851
+ dropout=0.0):
852
+ super().__init__()
853
+ self.dim = dim
854
+ self.z_dim = z_dim
855
+ self.dim_mult = dim_mult
856
+ self.num_res_blocks = num_res_blocks
857
+ self.attn_scales = attn_scales
858
+ self.temperal_upsample = temperal_upsample
859
+
860
+ # dimensions
861
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
862
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
863
+ # init block
864
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
865
+
866
+ # middle blocks
867
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
868
+ AttentionBlock(dims[0]),
869
+ ResidualBlock(dims[0], dims[0], dropout))
870
+
871
+ # upsample blocks
872
+ upsamples = []
873
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
874
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
875
+ upsamples.append(
876
+ Up_ResidualBlock(in_dim=in_dim,
877
+ out_dim=out_dim,
878
+ dropout=dropout,
879
+ mult=num_res_blocks + 1,
880
+ temperal_upsample=t_up_flag,
881
+ up_flag=i != len(dim_mult) - 1))
882
+ self.upsamples = nn.Sequential(*upsamples)
883
+
884
+ # output blocks
885
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
886
+ CausalConv3d(out_dim, 12, 3, padding=1))
887
+
888
+
889
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
890
+ if feat_cache is not None:
891
+ idx = feat_idx[0]
892
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
893
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
894
+ cache_x = torch.cat(
895
+ [
896
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
897
+ cache_x,
898
+ ],
899
+ dim=2,
900
+ )
901
+ x = self.conv1(x, feat_cache[idx])
902
+ feat_cache[idx] = cache_x
903
+ feat_idx[0] += 1
904
+ else:
905
+ x = self.conv1(x)
906
+
907
+ for layer in self.middle:
908
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
909
+ x = layer(x, feat_cache, feat_idx)
910
+ else:
911
+ x = layer(x)
912
+
913
+ ## upsamples
914
+ for layer in self.upsamples:
915
+ if feat_cache is not None:
916
+ x = layer(x, feat_cache, feat_idx, first_chunk)
917
+ else:
918
+ x = layer(x)
919
+
920
+ ## head
921
+ for layer in self.head:
922
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
923
+ idx = feat_idx[0]
924
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
925
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
926
+ cache_x = torch.cat(
927
+ [
928
+ feat_cache[idx][:, :, -1, :, :]
929
+ .unsqueeze(2)
930
+ .to(cache_x.device),
931
+ cache_x,
932
+ ],
933
+ dim=2,
934
+ )
935
+ x = layer(x, feat_cache[idx])
936
+ feat_cache[idx] = cache_x
937
+ feat_idx[0] += 1
938
+ else:
939
+ x = layer(x)
940
+ return x
941
+
942
+
943
+ def count_conv3d(model):
944
+ count = 0
945
+ for m in model.modules():
946
+ if isinstance(m, CausalConv3d):
947
+ count += 1
948
+ return count
949
+
950
+
951
+ class VideoVAE_(nn.Module):
952
+
953
+ def __init__(self,
954
+ dim=96,
955
+ z_dim=16,
956
+ dim_mult=[1, 2, 4, 4],
957
+ num_res_blocks=2,
958
+ attn_scales=[],
959
+ temperal_downsample=[False, True, True],
960
+ dropout=0.0):
961
+ super().__init__()
962
+ self.dim = dim
963
+ self.z_dim = z_dim
964
+ self.dim_mult = dim_mult
965
+ self.num_res_blocks = num_res_blocks
966
+ self.attn_scales = attn_scales
967
+ self.temperal_downsample = temperal_downsample
968
+ self.temperal_upsample = temperal_downsample[::-1]
969
+
970
+ # modules
971
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
972
+ attn_scales, self.temperal_downsample, dropout)
973
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
974
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
975
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
976
+ attn_scales, self.temperal_upsample, dropout)
977
+
978
+ def forward(self, x):
979
+ mu, log_var = self.encode(x)
980
+ z = self.reparameterize(mu, log_var)
981
+ x_recon = self.decode(z)
982
+ return x_recon, mu, log_var
983
+
984
+ def encode(self, x, scale):
985
+ self.clear_cache()
986
+ ## cache
987
+ t = x.shape[2]
988
+ iter_ = 1 + (t - 1) // 4
989
+
990
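+ # Encode the first frame on its own, then the remaining frames in chunks of 4, reusing the causal feature cache across chunks.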
+ for i in range(iter_):
991
+ self._enc_conv_idx = [0]
992
+ if i == 0:
993
+ out = self.encoder(x[:, :, :1, :, :],
994
+ feat_cache=self._enc_feat_map,
995
+ feat_idx=self._enc_conv_idx)
996
+ else:
997
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
998
+ feat_cache=self._enc_feat_map,
999
+ feat_idx=self._enc_conv_idx)
1000
+ out = torch.cat([out, out_], 2)
1001
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
1002
+ if isinstance(scale[0], torch.Tensor):
1003
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
1004
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1005
+ 1, self.z_dim, 1, 1, 1)
1006
+ else:
1007
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
1008
+ mu = (mu - scale[0]) * scale[1]
1009
+ return mu
1010
+
1011
+ def decode(self, z, scale):
1012
+ self.clear_cache()
1013
+ # z: [b,c,t,h,w]
1014
+ if isinstance(scale[0], torch.Tensor):
1015
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
1016
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1017
+ 1, self.z_dim, 1, 1, 1)
1018
+ else:
1019
+ scale = scale.to(dtype=z.dtype, device=z.device)
1020
+ z = z / scale[1] + scale[0]
1021
+ iter_ = z.shape[2]
1022
+ x = self.conv2(z)
1023
+ for i in range(iter_):
1024
+ self._conv_idx = [0]
1025
+ if i == 0:
1026
+ out = self.decoder(x[:, :, i:i + 1, :, :],
1027
+ feat_cache=self._feat_map,
1028
+ feat_idx=self._conv_idx)
1029
+ else:
1030
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
1031
+ feat_cache=self._feat_map,
1032
+ feat_idx=self._conv_idx)
1033
+ out = torch.cat([out, out_], 2) # may add tensor offload
1034
+ return out
1035
+
1036
+ def reparameterize(self, mu, log_var):
1037
+ std = torch.exp(0.5 * log_var)
1038
+ eps = torch.randn_like(std)
1039
+ return eps * std + mu
1040
+
1041
+ def sample(self, imgs, deterministic=False):
1042
+ mu, log_var = self.encode(imgs)
1043
+ if deterministic:
1044
+ return mu
1045
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
1046
+ return mu + std * torch.randn_like(std)
1047
+
1048
+ def clear_cache(self):
1049
+ self._conv_num = count_conv3d(self.decoder)
1050
+ self._conv_idx = [0]
1051
+ self._feat_map = [None] * self._conv_num
1052
+ # cache encode
1053
+ self._enc_conv_num = count_conv3d(self.encoder)
1054
+ self._enc_conv_idx = [0]
1055
+ self._enc_feat_map = [None] * self._enc_conv_num
1056
+
1057
+
1058
+ class WanVideoVAE(nn.Module):
1059
+
1060
+ def __init__(self, z_dim=16):
1061
+ super().__init__()
1062
+
1063
+ mean = [
1064
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
1065
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
1066
+ ]
1067
+ std = [
1068
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
1069
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
1070
+ ]
1071
+ self.mean = torch.tensor(mean)
1072
+ self.std = torch.tensor(std)
1073
+ self.scale = [self.mean, 1.0 / self.std]
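+ # scale = (per-channel mean, 1/std): encode() applies (mu - mean) * (1/std) and decode() inverts it.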
1074
+
1075
+ # init model
1076
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
1077
+ self.upsampling_factor = 8
1078
+ self.z_dim = z_dim
1079
+
1080
+
1081
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
1082
+ x = torch.ones((length,))
1083
+ if not left_bound:
1084
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
1085
+ if not right_bound:
1086
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
1087
+ return x
1088
+
1089
+
1090
+ def build_mask(self, data, is_bound, border_width):
1091
+ _, _, _, H, W = data.shape
1092
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
1093
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
1094
+
1095
+ h = repeat(h, "H -> H W", H=H, W=W)
1096
+ w = repeat(w, "W -> H W", H=H, W=W)
1097
+
1098
+ mask = torch.stack([h, w]).min(dim=0).values
1099
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
1100
+ return mask
1101
+
1102
+
1103
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
1104
+ _, _, T, H, W = hidden_states.shape
1105
+ size_h, size_w = tile_size
1106
+ stride_h, stride_w = tile_stride
1107
+
1108
+ # Split tasks
1109
+ tasks = []
1110
+ for h in range(0, H, stride_h):
1111
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
1112
+ for w in range(0, W, stride_w):
1113
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
1114
+ h_, w_ = h + size_h, w + size_w
1115
+ tasks.append((h, h_, w, w_))
1116
+
1117
+ data_device = "cpu"
1118
+ computation_device = device
1119
+
1120
+ out_T = T * 4 - 3
1121
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
1122
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
1123
+
1124
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
1125
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
1126
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
1127
+
1128
+ mask = self.build_mask(
1129
+ hidden_states_batch,
1130
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
1131
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
1132
+ ).to(dtype=hidden_states.dtype, device=data_device)
1133
+
1134
+ target_h = h * self.upsampling_factor
1135
+ target_w = w * self.upsampling_factor
1136
+ values[
1137
+ :,
1138
+ :,
1139
+ :,
1140
+ target_h:target_h + hidden_states_batch.shape[3],
1141
+ target_w:target_w + hidden_states_batch.shape[4],
1142
+ ] += hidden_states_batch * mask
1143
+ weight[
1144
+ :,
1145
+ :,
1146
+ :,
1147
+ target_h: target_h + hidden_states_batch.shape[3],
1148
+ target_w: target_w + hidden_states_batch.shape[4],
1149
+ ] += mask
1150
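+ # Normalize the accumulated tile outputs by the blend weights so overlapping regions are averaged smoothly.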
+ values = values / weight
1151
+ values = values.clamp_(-1, 1)
1152
+ return values
1153
+
1154
+
1155
+ def tiled_encode(self, video, device, tile_size, tile_stride):
1156
+ _, _, T, H, W = video.shape
1157
+ size_h, size_w = tile_size
1158
+ stride_h, stride_w = tile_stride
1159
+
1160
+ # Split tasks
1161
+ tasks = []
1162
+ for h in range(0, H, stride_h):
1163
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
1164
+ for w in range(0, W, stride_w):
1165
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
1166
+ h_, w_ = h + size_h, w + size_w
1167
+ tasks.append((h, h_, w, w_))
1168
+
1169
+ data_device = "cpu"
1170
+ computation_device = device
1171
+
1172
+ out_T = (T + 3) // 4
1173
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
1174
+ values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
1175
+
1176
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
1177
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
1178
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
1179
+
1180
+ mask = self.build_mask(
1181
+ hidden_states_batch,
1182
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
1183
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
1184
+ ).to(dtype=video.dtype, device=data_device)
1185
+
1186
+ target_h = h // self.upsampling_factor
1187
+ target_w = w // self.upsampling_factor
1188
+ values[
1189
+ :,
1190
+ :,
1191
+ :,
1192
+ target_h:target_h + hidden_states_batch.shape[3],
1193
+ target_w:target_w + hidden_states_batch.shape[4],
1194
+ ] += hidden_states_batch * mask
1195
+ weight[
1196
+ :,
1197
+ :,
1198
+ :,
1199
+ target_h: target_h + hidden_states_batch.shape[3],
1200
+ target_w: target_w + hidden_states_batch.shape[4],
1201
+ ] += mask
1202
+ values = values / weight
1203
+ return values
1204
+
1205
+
1206
+ def single_encode(self, video, device):
1207
+ video = video.to(device)
1208
+ x = self.model.encode(video, self.scale)
1209
+ return x
1210
+
1211
+
1212
+ def single_decode(self, hidden_state, device):
1213
+ hidden_state = hidden_state.to(device)
1214
+ video = self.model.decode(hidden_state, self.scale)
1215
+ return video.clamp_(-1, 1)
1216
+
1217
+
1218
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
1219
+
1220
+ videos = [video.to("cpu") for video in videos]
1221
+ hidden_states = []
1222
+ for video in videos:
1223
+ video = video.unsqueeze(0)
1224
+ if tiled:
1225
+ tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)
1226
+ tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)
1227
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
1228
+ else:
1229
+ hidden_state = self.single_encode(video, device)
1230
+ hidden_state = hidden_state.squeeze(0)
1231
+ hidden_states.append(hidden_state)
1232
+ hidden_states = torch.stack(hidden_states)
1233
+ return hidden_states
1234
+
1235
+
1236
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
1237
+ if tiled:
1238
+ video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
1239
+ else:
1240
+ video = self.single_decode(hidden_states, device)
1241
+ return video
1242
+
1243
+
1244
+ @staticmethod
1245
+ def state_dict_converter():
1246
+ return WanVideoVAEStateDictConverter()
1247
+
1248
+
1249
+ class WanVideoVAEStateDictConverter:
1250
+
1251
+ def __init__(self):
1252
+ pass
1253
+
1254
+ def from_civitai(self, state_dict):
1255
+ state_dict_ = {}
1256
+ if 'model_state' in state_dict:
1257
+ state_dict = state_dict['model_state']
1258
+ for name in state_dict:
1259
+ state_dict_['model.' + name] = state_dict[name]
1260
+ return state_dict_
1261
+
1262
+
1263
+ class VideoVAE38_(VideoVAE_):
1264
+
1265
+ def __init__(self,
1266
+ dim=160,
1267
+ z_dim=48,
1268
+ dec_dim=256,
1269
+ dim_mult=[1, 2, 4, 4],
1270
+ num_res_blocks=2,
1271
+ attn_scales=[],
1272
+ temperal_downsample=[False, True, True],
1273
+ dropout=0.0):
1274
+ super(VideoVAE_, self).__init__()
1275
+ self.dim = dim
1276
+ self.z_dim = z_dim
1277
+ self.dim_mult = dim_mult
1278
+ self.num_res_blocks = num_res_blocks
1279
+ self.attn_scales = attn_scales
1280
+ self.temperal_downsample = temperal_downsample
1281
+ self.temperal_upsample = temperal_downsample[::-1]
1282
+
1283
+ # modules
1284
+ self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks,
1285
+ attn_scales, self.temperal_downsample, dropout)
1286
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
1287
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
1288
+ self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks,
1289
+ attn_scales, self.temperal_upsample, dropout)
1290
+
1291
+
1292
+ def encode(self, x, scale):
1293
+ self.clear_cache()
1294
+ x = patchify(x, patch_size=2)
1295
+ t = x.shape[2]
1296
+ iter_ = 1 + (t - 1) // 4
1297
+ for i in range(iter_):
1298
+ self._enc_conv_idx = [0]
1299
+ if i == 0:
1300
+ out = self.encoder(x[:, :, :1, :, :],
1301
+ feat_cache=self._enc_feat_map,
1302
+ feat_idx=self._enc_conv_idx)
1303
+ else:
1304
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
1305
+ feat_cache=self._enc_feat_map,
1306
+ feat_idx=self._enc_conv_idx)
1307
+ out = torch.cat([out, out_], 2)
1308
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
1309
+ if isinstance(scale[0], torch.Tensor):
1310
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
1311
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1312
+ 1, self.z_dim, 1, 1, 1)
1313
+ else:
1314
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
1315
+ mu = (mu - scale[0]) * scale[1]
1316
+ self.clear_cache()
1317
+ return mu
1318
+
1319
+
1320
+ def decode(self, z, scale):
1321
+ self.clear_cache()
1322
+ if isinstance(scale[0], torch.Tensor):
1323
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
1324
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1325
+ 1, self.z_dim, 1, 1, 1)
1326
+ else:
1327
+ scale = scale.to(dtype=z.dtype, device=z.device)
1328
+ z = z / scale[1] + scale[0]
1329
+ iter_ = z.shape[2]
1330
+ x = self.conv2(z)
1331
+ for i in range(iter_):
1332
+ self._conv_idx = [0]
1333
+ if i == 0:
1334
+ out = self.decoder(x[:, :, i:i + 1, :, :],
1335
+ feat_cache=self._feat_map,
1336
+ feat_idx=self._conv_idx,
1337
+ first_chunk=True)
1338
+ else:
1339
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
1340
+ feat_cache=self._feat_map,
1341
+ feat_idx=self._conv_idx)
1342
+ out = torch.cat([out, out_], 2)
1343
+ out = unpatchify(out, patch_size=2)
1344
+ self.clear_cache()
1345
+ return out
1346
+
1347
+
1348
+ class WanVideoVAE38(WanVideoVAE):
1349
+
1350
+ def __init__(self, z_dim=48, dim=160):
1351
+ super(WanVideoVAE, self).__init__()
1352
+
1353
+ mean = [
1354
+ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
1355
+ -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
1356
+ -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
1357
+ -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
1358
+ -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
1359
+ 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667
1360
+ ]
1361
+ std = [
1362
+ 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
1363
+ 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
1364
+ 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
1365
+ 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
1366
+ 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
1367
+ 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
1368
+ ]
1369
+ self.mean = torch.tensor(mean)
1370
+ self.std = torch.tensor(std)
1371
+ self.scale = [self.mean, 1.0 / self.std]
1372
+
1373
+ # init model
1374
+ self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
1375
+ self.upsampling_factor = 16
1376
+ self.z_dim = z_dim
dkt/pipelines/__init__.py ADDED
File without changes
dkt/pipelines/pipeline.py ADDED
@@ -0,0 +1,1965 @@
1
+ import torch, warnings, glob, os, types
2
+ import numpy as np
3
+ from PIL import Image
4
+ from einops import repeat, reduce
5
+ from typing import Optional, Union
6
+ from dataclasses import dataclass
7
+ from modelscope import snapshot_download as ms_snap_download
8
+ from huggingface_hub import snapshot_download as hf_snap_download
9
+
10
+ from einops import rearrange
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ from typing import Optional
15
+ from typing_extensions import Literal
16
+
17
+ from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
18
+ from ..models import ModelManager, load_state_dict
19
+ from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
20
+ from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
21
+ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
22
+ from ..models.wan_video_image_encoder import WanImageEncoder
23
+ from ..models.wan_video_vace import VaceWanModel
24
+ from ..models.wan_video_motion_controller import WanMotionControllerModel
25
+ from ..schedulers.flow_match import FlowMatchScheduler
26
+ from ..prompters import WanPrompter
27
+ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
28
+ from ..lora import GeneralLoRALoader
29
+
30
+ from loguru import logger
31
+
32
+
33
+
34
+ import spaces
35
+
36
+
37
+ class BasePipeline(torch.nn.Module):
38
+
39
+ def __init__(
40
+ self,
41
+ device="cuda", torch_dtype=torch.float16,
42
+ height_division_factor=64, width_division_factor=64,
43
+ time_division_factor=None, time_division_remainder=None,
44
+ ):
45
+ super().__init__()
46
+ # The device and torch_dtype are used for storing intermediate variables, not models.
47
+ self.device = device
48
+ self.torch_dtype = torch_dtype
49
+ # The following parameters are used for shape check.
50
+ self.height_division_factor = height_division_factor
51
+ self.width_division_factor = width_division_factor
52
+ self.time_division_factor = time_division_factor
53
+ self.time_division_remainder = time_division_remainder
54
+ self.vram_management_enabled = False
55
+
56
+
57
+ def to(self, *args, **kwargs):
58
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
59
+ if device is not None:
60
+ self.device = device
61
+ if dtype is not None:
62
+ self.torch_dtype = dtype
63
+ super().to(*args, **kwargs)
64
+ return self
65
+
66
+
67
+ def check_resize_height_width(self, height, width, num_frames=None):
68
+ # Shape check
69
+ if height % self.height_division_factor != 0:
70
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
71
+ print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
72
+ if width % self.width_division_factor != 0:
73
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
74
+ print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
75
+ if num_frames is None:
76
+ return height, width
77
+ else:
78
+ if num_frames % self.time_division_factor != self.time_division_remainder:
79
+ num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
80
+ print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
81
+ return height, width, num_frames
82
+
83
+
84
+ def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
85
+ # Transform a PIL.Image to torch.Tensor
86
+ image = torch.Tensor(np.array(image, dtype=np.float32))
87
+ image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
88
+ image = image * ((max_value - min_value) / 255) + min_value
89
+ image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
90
+ return image
91
+
92
+
93
+ def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
94
+ # Transform a list of PIL.Image to torch.Tensor
95
+
96
+
97
+ if hasattr(video, 'length') and video.length is not None:
98
+ video = [self.preprocess_image(video[idx], torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for idx in range(video.length)]
99
+ else:
100
+ video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
101
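+ # Pattern tokens are single letters separated by spaces, so the character index of "T" divided by 2 gives its dimension index.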
+ video = torch.stack(video, dim=pattern.index("T") // 2)
102
+ return video
103
+
104
+
105
+ def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
106
+ # Transform a torch.Tensor to PIL.Image
107
+ if pattern != "H W C":
108
+ vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
109
+ image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
110
+ image = image.to(device="cpu", dtype=torch.uint8)
111
+ image = Image.fromarray(image.numpy())
112
+ return image
113
+
114
+
115
+ def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
116
+ # Transform a torch.Tensor to list of PIL.Image
117
+ if pattern != "T H W C":
118
+ vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
119
+ video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
120
+ return video
121
+
122
+
123
+ def load_models_to_device(self, model_names=[]):
124
+ if self.vram_management_enabled:
125
+ # offload models
126
+ for name, model in self.named_children():
127
+ if name not in model_names:
128
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
129
+ for module in model.modules():
130
+ if hasattr(module, "offload"):
131
+ module.offload()
132
+ else:
133
+ model.cpu()
134
+ torch.cuda.empty_cache()
135
+ # onload models
136
+ for name, model in self.named_children():
137
+ if name in model_names:
138
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
139
+ for module in model.modules():
140
+ if hasattr(module, "onload"):
141
+ module.onload()
142
+ else:
143
+ model.to(self.device)
144
+
145
+
146
+ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
147
+ # Initialize Gaussian noise
148
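+ # Draw the noise on rand_device (CPU by default) so a fixed seed yields the same noise regardless of the compute device, then move it to the target device/dtype.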
+ generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
149
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
150
+ noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
151
+ return noise
152
+
153
+
154
+ def enable_cpu_offload(self):
155
+ warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
156
+ self.vram_management_enabled = True
157
+
158
+
159
+ def get_vram(self):
160
+ return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
161
+
162
+
163
+ def freeze_except(self, model_names):
164
+ for name, model in self.named_children():
165
+ if name in model_names:
166
+ model.train()
167
+ model.requires_grad_(True)
168
+ else:
169
+ model.eval()
170
+ model.requires_grad_(False)
171
+
172
+
173
+ @dataclass
174
+ class ModelConfig:
175
+ path: Union[str, list[str]] = None
176
+ model_id: str = None
177
+ origin_file_pattern: Union[str, list[str]] = None
178
+ download_resource: str = "ModelScope"
179
+ offload_device: Optional[Union[str, torch.device]] = None
180
+ offload_dtype: Optional[torch.dtype] = None
181
+
182
+ def download_if_necessary(self, local_model_path="./checkpoints", skip_download=False, use_usp=False):
183
+ if self.path is None:
184
+ # Check model_id and origin_file_pattern
185
+ if self.model_id is None:
186
+ raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
187
+
188
+ # Skip if not in rank 0
189
+ if use_usp:
190
+ import torch.distributed as dist
191
+ skip_download = dist.get_rank() != 0
192
+
193
+ # Check whether the origin path is a folder
194
+ if self.origin_file_pattern is None or self.origin_file_pattern == "":
195
+ self.origin_file_pattern = ""
196
+ allow_file_pattern = None
197
+ is_folder = True
198
+ elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
199
+ allow_file_pattern = self.origin_file_pattern + "*"
200
+ is_folder = True
201
+ else:
202
+ allow_file_pattern = self.origin_file_pattern
203
+ is_folder = False
204
+
205
+ # Download
206
+ if not skip_download:
207
+
208
+ # downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
209
+ #!========================================================================================================================
210
+ downloaded_files = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
211
+ #!========================================================================================================================
212
+
213
+ if downloaded_files is None or len(downloaded_files) == 0 or not os.path.exists(downloaded_files[0]):
214
+ #todo
215
+ # if 'Wan2' in self.model_id:
216
+ # ms_snap_download(
217
+ # self.model_id,
218
+ # local_dir=os.path.join(local_model_path, self.model_id),
219
+ # allow_file_pattern=allow_file_pattern,
220
+ # ignore_file_pattern=downloaded_files,
221
+ # )
222
+ # else:
223
+ hf_snap_download(
224
+ repo_id=self.model_id,
225
+ local_dir=os.path.join(local_model_path, self.model_id),
226
+ allow_patterns=allow_file_pattern,
227
+ ignore_patterns=downloaded_files if downloaded_files else None
228
+ )
229
+
230
+ # Let rank 1, 2, ... wait for rank 0
231
+ if use_usp:
232
+ import torch.distributed as dist
233
+ dist.barrier(device_ids=[dist.get_rank()])
234
+
235
+ # Return downloaded files
236
+ if is_folder:
237
+ self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
238
+ else:
239
+ self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
240
+ if isinstance(self.path, list) and len(self.path) == 1:
241
+ self.path = self.path[0]
242
+
243
+
244
+
245
+
246
+ class WanVideoPipeline(BasePipeline):
247
+
248
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
249
+ super().__init__(
250
+ device=device, torch_dtype=torch_dtype,
251
+ height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
252
+ )
253
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
254
+
255
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
256
+ self.text_encoder: WanTextEncoder = None
257
+ self.image_encoder: WanImageEncoder = None
258
+ self.dit: WanModel = None
259
+ self.dit2: WanModel = None
260
+ self.vae: WanVideoVAE = None
261
+ self.motion_controller: WanMotionControllerModel = None
262
+ self.vace: VaceWanModel = None
263
+ self.in_iteration_models = ("dit", "motion_controller", "vace")
264
+ self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
265
+ self.unit_runner = PipelineUnitRunner()
266
+ self.units = [
267
+ WanVideoUnit_ShapeChecker(),
268
+ WanVideoUnit_NoiseInitializer(),
269
+ WanVideoUnit_InputVideoEmbedder(),
270
+ WanVideoUnit_PromptEmbedder(),
271
+ # WanVideoUnit_ImageEmbedderVAE(),
272
+ # WanVideoUnit_ImageEmbedderCLIP(),
273
+ # WanVideoUnit_ImageEmbedderFused(),
274
+ # WanVideoUnit_FunControl(),
275
+ WanVideoUnit_FunControl_Mask(),
276
+ # WanVideoUnit_FunReference(),
277
+ # WanVideoUnit_FunCameraControl(),
278
+ # WanVideoUnit_SpeedControl(),
279
+ # WanVideoUnit_VACE(),
280
+ # WanVideoUnit_UnifiedSequenceParallel(),
281
+ # WanVideoUnit_TeaCache(),
282
+ # WanVideoUnit_CfgMerger(),
283
+ ]
284
+ self.model_fn = model_fn_wan_video
285
+
286
+
287
+ def load_lora(self, module, path, alpha=1):
288
+ loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
289
+ lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
290
+ loader.load(module, lora, alpha=alpha)
291
+
292
+
293
+ def training_loss(self, **inputs):
294
+ max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
295
+ min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
296
+ timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
297
+ timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
298
+ #* During single-step denoising, each call returns pure noise
299
+ #? This presumably refers to input_latents
300
+ #* inputs["latents"] already exists, but it is exactly equal to inputs["noise"]; here it is updated and overwritten
301
+ inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
302
+ training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
303
+
304
+ noise_pred = self.model_fn(**inputs, timestep=timestep)#* timestep === 1
305
+
306
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
307
+ loss = loss * self.scheduler.training_weight(timestep)
308
+ return loss
309
+
310
+
311
+ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
312
+ self.vram_management_enabled = True
313
+ if num_persistent_param_in_dit is not None:
314
+ vram_limit = None
315
+ else:
316
+ if vram_limit is None:
317
+ vram_limit = self.get_vram()
318
+ vram_limit = vram_limit - vram_buffer
319
+ if self.text_encoder is not None:
320
+ dtype = next(iter(self.text_encoder.parameters())).dtype
321
+ enable_vram_management(
322
+ self.text_encoder,
323
+ module_map = {
324
+ torch.nn.Linear: AutoWrappedLinear,
325
+ torch.nn.Embedding: AutoWrappedModule,
326
+ T5RelativeEmbedding: AutoWrappedModule,
327
+ T5LayerNorm: AutoWrappedModule,
328
+ },
329
+ module_config = dict(
330
+ offload_dtype=dtype,
331
+ offload_device="cpu",
332
+ onload_dtype=dtype,
333
+ onload_device="cpu",
334
+ computation_dtype=self.torch_dtype,
335
+ computation_device=self.device,
336
+ ),
337
+ vram_limit=vram_limit,
338
+ )
339
+ if self.dit is not None:
340
+ dtype = next(iter(self.dit.parameters())).dtype
341
+ device = "cpu" if vram_limit is not None else self.device
342
+ enable_vram_management(
343
+ self.dit,
344
+ module_map = {
345
+ torch.nn.Linear: AutoWrappedLinear,
346
+ torch.nn.Conv3d: AutoWrappedModule,
347
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
348
+ RMSNorm: AutoWrappedModule,
349
+ torch.nn.Conv2d: AutoWrappedModule,
350
+ },
351
+ module_config = dict(
352
+ offload_dtype=dtype,
353
+ offload_device="cpu",
354
+ onload_dtype=dtype,
355
+ onload_device=device,
356
+ computation_dtype=self.torch_dtype,
357
+ computation_device=self.device,
358
+ ),
359
+ max_num_param=num_persistent_param_in_dit,
360
+ overflow_module_config = dict(
361
+ offload_dtype=dtype,
362
+ offload_device="cpu",
363
+ onload_dtype=dtype,
364
+ onload_device="cpu",
365
+ computation_dtype=self.torch_dtype,
366
+ computation_device=self.device,
367
+ ),
368
+ vram_limit=vram_limit,
369
+ )
370
+ if self.dit2 is not None:
371
+ dtype = next(iter(self.dit2.parameters())).dtype
372
+ device = "cpu" if vram_limit is not None else self.device
373
+ enable_vram_management(
374
+ self.dit2,
375
+ module_map = {
376
+ torch.nn.Linear: AutoWrappedLinear,
377
+ torch.nn.Conv3d: AutoWrappedModule,
378
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
379
+ RMSNorm: AutoWrappedModule,
380
+ torch.nn.Conv2d: AutoWrappedModule,
381
+ },
382
+ module_config = dict(
383
+ offload_dtype=dtype,
384
+ offload_device="cpu",
385
+ onload_dtype=dtype,
386
+ onload_device=device,
387
+ computation_dtype=self.torch_dtype,
388
+ computation_device=self.device,
389
+ ),
390
+ max_num_param=num_persistent_param_in_dit,
391
+ overflow_module_config = dict(
392
+ offload_dtype=dtype,
393
+ offload_device="cpu",
394
+ onload_dtype=dtype,
395
+ onload_device="cpu",
396
+ computation_dtype=self.torch_dtype,
397
+ computation_device=self.device,
398
+ ),
399
+ vram_limit=vram_limit,
400
+ )
401
+ if self.vae is not None:
402
+ dtype = next(iter(self.vae.parameters())).dtype
403
+ enable_vram_management(
404
+ self.vae,
405
+ module_map = {
406
+ torch.nn.Linear: AutoWrappedLinear,
407
+ torch.nn.Conv2d: AutoWrappedModule,
408
+ RMS_norm: AutoWrappedModule,
409
+ CausalConv3d: AutoWrappedModule,
410
+ Upsample: AutoWrappedModule,
411
+ torch.nn.SiLU: AutoWrappedModule,
412
+ torch.nn.Dropout: AutoWrappedModule,
413
+ },
414
+ module_config = dict(
415
+ offload_dtype=dtype,
416
+ offload_device="cpu",
417
+ onload_dtype=dtype,
418
+ onload_device=self.device,
419
+ computation_dtype=self.torch_dtype,
420
+ computation_device=self.device,
421
+ ),
422
+ )
423
+ if self.image_encoder is not None:
424
+ dtype = next(iter(self.image_encoder.parameters())).dtype
425
+ enable_vram_management(
426
+ self.image_encoder,
427
+ module_map = {
428
+ torch.nn.Linear: AutoWrappedLinear,
429
+ torch.nn.Conv2d: AutoWrappedModule,
430
+ torch.nn.LayerNorm: AutoWrappedModule,
431
+ },
432
+ module_config = dict(
433
+ offload_dtype=dtype,
434
+ offload_device="cpu",
435
+ onload_dtype=dtype,
436
+ onload_device="cpu",
437
+ computation_dtype=dtype,
438
+ computation_device=self.device,
439
+ ),
440
+ )
441
+ if self.motion_controller is not None:
442
+ dtype = next(iter(self.motion_controller.parameters())).dtype
443
+ enable_vram_management(
444
+ self.motion_controller,
445
+ module_map = {
446
+ torch.nn.Linear: AutoWrappedLinear,
447
+ },
448
+ module_config = dict(
449
+ offload_dtype=dtype,
450
+ offload_device="cpu",
451
+ onload_dtype=dtype,
452
+ onload_device="cpu",
453
+ computation_dtype=dtype,
454
+ computation_device=self.device,
455
+ ),
456
+ )
457
+ if self.vace is not None:
458
+ device = "cpu" if vram_limit is not None else self.device
459
+ enable_vram_management(
460
+ self.vace,
461
+ module_map = {
462
+ torch.nn.Linear: AutoWrappedLinear,
463
+ torch.nn.Conv3d: AutoWrappedModule,
464
+ torch.nn.LayerNorm: AutoWrappedModule,
465
+ RMSNorm: AutoWrappedModule,
466
+ },
467
+ module_config = dict(
468
+ offload_dtype=dtype,
469
+ offload_device="cpu",
470
+ onload_dtype=dtype,
471
+ onload_device=device,
472
+ computation_dtype=self.torch_dtype,
473
+ computation_device=self.device,
474
+ ),
475
+ vram_limit=vram_limit,
476
+ )
477
+
478
+
479
+ def initialize_usp(self):
480
+ import torch.distributed as dist
481
+ from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
482
+ dist.init_process_group(backend="nccl", init_method="env://")
483
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
484
+ initialize_model_parallel(
485
+ sequence_parallel_degree=dist.get_world_size(),
486
+ ring_degree=1,
487
+ ulysses_degree=dist.get_world_size(),
488
+ )
489
+ torch.cuda.set_device(dist.get_rank())
490
+
491
+
492
+ def enable_usp(self):
493
+ from xfuser.core.distributed import get_sequence_parallel_world_size
494
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
495
+
496
+ for block in self.dit.blocks:
497
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
498
+ self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
499
+ if self.dit2 is not None:
500
+ for block in self.dit2.blocks:
501
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
502
+ self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
503
+ self.sp_size = get_sequence_parallel_world_size()
504
+ self.use_unified_sequence_parallel = True
505
+
506
+
507
+ @staticmethod
508
+ def from_pretrained(
509
+ torch_dtype: torch.dtype = torch.bfloat16,
510
+ device: Union[str, torch.device] = "cuda",
511
+ model_configs: list[ModelConfig] = [],
512
+ tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
513
+ local_model_path: str = "./checkpoints",
514
+ skip_download: bool = False,
515
+ redirect_common_files: bool = True,
516
+ use_usp=False,
517
+ training_strategy='origin',
518
+ ):
519
+
520
+ # Redirect model path
521
+
522
+ if redirect_common_files:
523
+
524
+ redirect_dict = {
525
+ "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
526
+ "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
527
+ "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
528
+ }
529
+ for model_config in model_configs:
530
+ if model_config.origin_file_pattern is None or model_config.model_id is None:
531
+ continue
532
+ if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
533
+ print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
534
+ model_config.model_id = redirect_dict[model_config.origin_file_pattern]
535
+
536
+ # Initialize pipeline
537
+
538
+ if training_strategy == 'origin':
539
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
540
+ logger.warning("Using origin generative model training")
541
+ else:
542
+ raise ValueError(f"Invalid training strategy: {training_strategy}")
543
+
544
+ if use_usp: pipe.initialize_usp()
545
+
546
+ # Download and load models
547
+ model_manager = ModelManager()
548
+
549
+ for model_config in model_configs:
550
+ model_config.download_if_necessary(use_usp=use_usp)
551
+ model_manager.load_model(
552
+ model_config.path,
553
+ device=model_config.offload_device or device,
554
+ torch_dtype=model_config.offload_dtype or torch_dtype
555
+ )
556
+
557
+ # Load models
558
+ pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
559
+ dit = model_manager.fetch_model("wan_video_dit", index=2)
560
+ if isinstance(dit, list):
561
+ pipe.dit, pipe.dit2 = dit
562
+ else:
563
+ pipe.dit = dit
564
+ pipe.vae = model_manager.fetch_model("wan_video_vae")
565
+ pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
566
+ pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
567
+ pipe.vace = model_manager.fetch_model("wan_video_vace")
568
+
569
+ # Size division factor
570
+ if pipe.vae is not None:
571
+ pipe.height_division_factor = pipe.vae.upsampling_factor * 2
572
+ pipe.width_division_factor = pipe.vae.upsampling_factor * 2
573
+
574
+ # Initialize tokenizer
575
+ tokenizer_config.download_if_necessary(use_usp=use_usp)
576
+ pipe.prompter.fetch_models(pipe.text_encoder)
577
+ pipe.prompter.fetch_tokenizer(tokenizer_config.path)
578
+
579
+ # Unified Sequence Parallel
580
+ if use_usp: pipe.enable_usp()
581
+ return pipe
582
+
583
+
584
+
585
+
586
+ @torch.no_grad()
587
+ def __call__(
588
+ self,
589
+ # Prompt
590
+ prompt: str,
591
+ negative_prompt: Optional[str] = "",
592
+ # Image-to-video
593
+ input_image: Optional[Image.Image] = None,
594
+ # First-last-frame-to-video
595
+ end_image: Optional[Image.Image] = None,
596
+ # Video-to-video
597
+ input_video: Optional[list[Image.Image]] = None,
598
+ denoising_strength: Optional[float] = 1.0,
599
+ # ControlNet
600
+ control_video: Optional[list[Image.Image]] = None,
601
+ reference_image: Optional[Image.Image] = None,
602
+ # Camera control
603
+ camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
604
+ camera_control_speed: Optional[float] = 1/54,
605
+ camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
606
+ # VACE
607
+ vace_video: Optional[list[Image.Image]] = None,
608
+ vace_video_mask: Optional[Image.Image] = None,
609
+ vace_reference_image: Optional[Image.Image] = None,
610
+ vace_scale: Optional[float] = 1.0,
611
+ # Randomness
612
+ seed: Optional[int] = None,
613
+ rand_device: Optional[str] = "cpu",
614
+ # Shape
615
+ height: Optional[int] = 480,
616
+ width: Optional[int] = 832,
617
+ num_frames=81,
618
+ # Classifier-free guidance
619
+ cfg_scale: Optional[float] = 5.0,
620
+ cfg_merge: Optional[bool] = False,
621
+ # Boundary
622
+ switch_DiT_boundary: Optional[float] = 0.875,
623
+ # Scheduler
624
+ num_inference_steps: Optional[int] = 50,
625
+ sigma_shift: Optional[float] = 5.0,
626
+ # Speed control
627
+ motion_bucket_id: Optional[int] = None,
628
+ # VAE tiling
629
+ tiled: Optional[bool] = True,
630
+ tile_size: Optional[tuple[int, int]] = (30, 52),
631
+ tile_stride: Optional[tuple[int, int]] = (15, 26),
632
+ # Sliding window
633
+ sliding_window_size: Optional[int] = None,
634
+ sliding_window_stride: Optional[int] = None,
635
+ # Teacache
636
+ tea_cache_l1_thresh: Optional[float] = None,
637
+ tea_cache_model_id: Optional[str] = "",
638
+ # progress_bar
639
+ progress_bar_cmd=tqdm,
640
+ mask: Optional[Image.Image] = None,
641
+ ):
642
+
643
+
644
+
645
+ # Scheduler
646
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
647
+
648
+ # Inputs
649
+ inputs_posi = {
650
+ "prompt": prompt,
651
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
652
+ }
653
+ inputs_nega = {
654
+ "negative_prompt": negative_prompt,
655
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
656
+ }
657
+ inputs_shared = {
658
+ "input_image": input_image,
659
+ "end_image": end_image,
660
+ "input_video": input_video, "denoising_strength": denoising_strength,
661
+ "control_video": control_video, "reference_image": reference_image,
662
+ "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
663
+ "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
664
+ "seed": seed, "rand_device": rand_device,
665
+ "height": height, "width": width, "num_frames": num_frames,
666
+ "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
667
+ "sigma_shift": sigma_shift,
668
+ "motion_bucket_id": motion_bucket_id,
669
+ "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
670
+ "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
671
+ "mask":mask,
672
+ }
673
+
674
+ for unit in self.units:
675
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
676
+
677
+ # Denoise
678
+ self.load_models_to_device(self.in_iteration_models)
679
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
680
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
681
+ # Switch DiT if necessary
682
+ if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and models["dit"] is not self.dit2:
683
+ self.load_models_to_device(self.in_iteration_models_2)
684
+ models["dit"] = self.dit2
685
+
686
+ # Timestep
687
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
688
+
689
+ # Inference
690
+ noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
691
+ if cfg_scale != 1.0:
692
+ if cfg_merge:
693
+ noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
694
+ else:
695
+ noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
696
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
697
+ else:
698
+ noise_pred = noise_pred_posi
699
+
700
+ # Scheduler
701
+ inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
702
+ if "first_frame_latents" in inputs_shared:
703
+ inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
704
+
705
+ # VACE (TODO: remove it)
706
+ if vace_reference_image is not None:
707
+ inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
708
+
709
+ # Decode
710
+ self.load_models_to_device(['vae'])
711
+ vae_outs = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
712
+ # from einops import reduce
713
+ # video = reduce(vae_outs, 'b c t h w -> b c t', 'mean')
714
+
715
+ video = self.vae_output_to_video(vae_outs)
716
+ self.load_models_to_device([])
717
+
718
+ return video,vae_outs
719
+
720
+
721
+
722
+
723
+ def extract_frames_from_video_file(video_path):
724
+ try:
725
+ cap = cv2.VideoCapture(video_path)
726
+ frames = []
727
+
728
+ fps = cap.get(cv2.CAP_PROP_FPS)
729
+ if fps <= 0:
730
+ fps = 15.0
731
+
732
+ while True:
733
+ ret, frame = cap.read()
734
+ if not ret:
735
+ break
736
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
737
+ frame_rgb = Image.fromarray(frame_rgb)
738
+ frames.append(frame_rgb)
739
+
740
+ cap.release()
741
+ return frames, fps
742
+ except Exception as e:
743
+ logger.error(f"Error extracting frames from {video_path}: {str(e)}")
744
+ return [], 15.0
745
+
746
+
747
+ def resize_frame(frame, height, width):
748
+ frame = np.array(frame)
749
+ frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
750
+ frame = torch.nn.functional.interpolate(frame, (height, width), mode="bicubic", align_corners=False, antialias=True)
751
+ frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy()
752
+ frame = Image.fromarray(frame)
753
+ return frame
754
+
755
+
756
+
757
+ from moge.model.v2 import MoGeModel
758
+ from tools.eval_utils import transfer_pred_disp2depth, transfer_pred_disp2depth_v2, colorize_depth_map
759
+ from tools.depth2pcd import depth2pcd
760
+ import cv2, copy
761
+
762
+
763
+ class DKTPipeline:
764
+ def __init__(self, model_path=None, is14B=False, is_depth=True):
765
+
766
+ if is14B:
767
+ if model_path is None:
768
+ if is_depth: #* 14B depth model
769
+ model_path = 'Daniellesry/DKT-Depth-14B'
770
+ else:#* 14B normal Model
771
+ model_path = 'Daniellesry/DKT-Normal-14B'
772
+ self.main_pipe = self.init_model_14B(model_path)
773
+ else:
774
+ if model_path is None:
775
+ if is_depth: #* 1.3B depth model
776
+ model_path = 'Daniellesry/DKT-Depth-1-3B'
777
+ else:#* 1.3B normal Model
778
+ raise ValueError("1.3B normal model is not available")
779
+ model_path = ...
780
+
781
+ self.main_pipe = self.init_model(model_path)
782
+
783
+
784
+ if is_depth:
785
+ self.prompt = 'depth'
786
+ # self.moge_pipe = self.load_moge_model()
787
+ else:
788
+ self.prompt = 'normal'
789
+
790
+ logger.info(f'DKT_PIPELINE init success, model_path: {model_path}, is14B: {is14B}, is_depth: {is_depth}, prompt: {self.prompt}')
791
+
792
+
793
+
794
+ def init_model_14B(self,model_id):
795
+ """加载14B模型到指定GPU"""
796
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
797
+
798
+ pipe = WanVideoPipeline.from_pretrained(
799
+ torch_dtype=torch.bfloat16,
800
+ device=device,
801
+ # model_configs=[
802
+ # ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
803
+ # ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
804
+ # ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
805
+ # ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
806
+ # ],
807
+
808
+ model_configs=[
809
+ ModelConfig(model_id="alibaba-pai/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
810
+ ModelConfig(model_id="alibaba-pai/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
811
+ ModelConfig(model_id="alibaba-pai/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
812
+ ModelConfig(model_id="alibaba-pai/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
813
+ ],
814
+ redirect_common_files=False,
815
+ training_strategy="origin",
816
+ )
817
+
818
+
819
+ lora_config = ModelConfig(
820
+ model_id=model_id,
821
+ origin_file_pattern="*.safetensors",
822
+ offload_device="cpu",
823
+ )
824
+ lora_config.download_if_necessary(use_usp=False)
825
+
826
+ pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)
827
+
828
+
829
+ pipe.enable_vram_management()
830
+
831
+ return pipe
832
+
833
+
834
+
835
+ def init_model(self, model_id):
836
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
837
+
838
+ pipe = WanVideoPipeline.from_pretrained(
839
+ torch_dtype=torch.bfloat16,
840
+ device=device,
841
+ model_configs=[
842
+ ModelConfig(
843
+ model_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
844
+ origin_file_pattern="diffusion_pytorch_model*.safetensors",
845
+ offload_device="cpu",
846
+ ),
847
+ ModelConfig(
848
+ model_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
849
+ origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
850
+ offload_device="cpu",
851
+ ),
852
+ ModelConfig(
853
+ model_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
854
+ origin_file_pattern="Wan2.1_VAE.pth",
855
+ offload_device="cpu",
856
+ ),
857
+ ModelConfig(
858
+ model_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
859
+ origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
860
+ offload_device="cpu",
861
+ ),
862
+ ],
863
+ training_strategy="origin",
864
+ )
865
+
866
+
867
+
868
+
869
+ lora_config = ModelConfig(
870
+ model_id=model_id ,
871
+ origin_file_pattern="*.safetensors",
872
+ offload_device="cpu",
873
+ )
874
+ lora_config.download_if_necessary(use_usp=False)
875
+
876
+ pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)  # TODO: verify that this works
877
+ pipe.enable_vram_management()
878
+ return pipe
879
+
880
+
881
+ def load_moge_model(self):
882
+ device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
883
+
884
+
885
+ print(f'device for loading MoGe: {device}')
886
+
887
+
888
+ cached_dir = 'checkpoints/moge_ckpt/moge-2-vitl-normal'
889
+ os.makedirs(cached_dir, exist_ok=True)
890
+ cached_model_path = f'{cached_dir}/model.pt'
891
+
892
+ if os.path.exists(cached_model_path):
893
+ logger.info(f"Found cached model at {cached_model_path}, loading from cache...")
894
+ moge_pipe = MoGeModel.from_pretrained(cached_model_path).to(device)
895
+ else:
896
+ logger.info(f"Cache not found at {cached_model_path}, downloading from HuggingFace...")
897
+ moge_pipe = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal').to(device)
898
+
899
+ if moge_pipe is not None:
900
+ logger.info(f'moge init success')
901
+ print(f'moge init success')
902
+ else:
903
+ logger.error(f'moge init failed')
904
+ print(f'moge init failed')
905
+ raise Exception('moge init failed')
906
+
907
+ return moge_pipe
908
+
909
+
910
+
911
+
912
+ @spaces.GPU(duration=240)
913
+ @torch.inference_mode()
914
+ def __call__(self, video_file,
915
+ negative_prompt='', height=480, width=832,
916
+ num_inference_steps=5, window_size=21,
917
+ overlap=3, vis_pc=False, return_rgb=False, get_moge_intrinsics=False):
918
+
919
+ origin_frames, input_fps = extract_frames_from_video_file(video_file)
920
+
921
+ frame_length = len(origin_frames)
922
+
923
+ original_width, original_height = origin_frames[0].size
924
+
925
+ ROTATE = False
926
+ if original_width < original_height:  #* ensure the width is the longer side
927
+ ROTATE = True
928
+ origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
929
+ tmp = original_width
930
+ original_width = original_height
931
+ original_height = tmp
932
+
933
+
934
+ frames = [resize_frame(frame, height, width) for frame in origin_frames]
935
+
936
+
937
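+ # The video VAE expects num_frames % 4 == 1; pad by repeating the last frame and trim back to frame_length after generation.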
+ if (frame_length - 1) % 4 != 0:
938
+ new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
939
+ frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
940
+
941
+
942
+
943
+
944
+ video, vae_outs = self.main_pipe(
945
+ prompt=self.prompt,
946
+ negative_prompt=negative_prompt,
947
+ control_video=frames,
948
+ height=height,
949
+ width=width,
950
+ num_frames=len(frames),
951
+ seed=1,
952
+ tiled=False,
953
+ num_inference_steps=num_inference_steps,
954
+ sliding_window_size=window_size,
955
+ sliding_window_stride=window_size - overlap,
956
+ cfg_scale=1.0,
957
+ )
958
+ torch.cuda.empty_cache()
959
+
960
+ processed_video = video[:frame_length]
961
+ processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
962
+
963
+ if ROTATE:
964
+ processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
965
+ origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
966
+
967
+ color_predictions = []
968
+ if self.prompt == 'depth':
969
+ prediced_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
970
+ prediced_depth_map_np = np.stack(prediced_depth_map_np)
971
+ prediced_depth_map_np = prediced_depth_map_np / 255.0
972
+
973
+ __min = prediced_depth_map_np.min()
974
+ __max = prediced_depth_map_np.max()
975
+
976
+ prediced_depth_map_np_normalized = (prediced_depth_map_np - __min) / (__max - __min)
977
+ color_predictions = [colorize_depth_map(item) for item in prediced_depth_map_np_normalized]
978
+ else:
979
+ color_predictions = processed_video
980
+ prediced_depth_map_np = None
981
+
982
+ return_dict = {}
983
+
984
+ return_dict['depth_map'] = prediced_depth_map_np
985
+ return_dict['colored_depth_map'] = color_predictions
986
+
987
+
988
+
989
+ if vis_pc and self.prompt == 'depth':
990
+ vis_pc_num = 4
991
+ indices = np.linspace(0, frame_length-1, vis_pc_num)
992
+ indices = np.round(indices).astype(np.int32)
993
+ return_dict['point_clouds'] = self.prediction2pc(prediced_depth_map_np, origin_frames, indices)
994
+
995
+ if return_rgb:
996
+ return_dict['rgb_frames'] = origin_frames
997
+
998
+
999
+ if get_moge_intrinsics:
1000
+ demo_idx = 0
1001
+ origin_frames_demo = origin_frames[demo_idx]
1002
+ prediced_depth_map_np_demo = prediced_depth_map_np[demo_idx]
1003
+ input_image_np = np.array(origin_frames_demo)
1004
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=torch.device("cuda")).permute(2, 0, 1)
1005
+
1006
+ output = self.moge_pipe.infer(input_image)
1007
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
1008
+ moge_mask = output['mask'].cpu().numpy().astype(bool)
1009
+ moge_depth = output['depth'].cpu().numpy()
1010
+
1011
+ metric_depth, scale, shift = transfer_pred_disp2depth(prediced_depth_map_np_demo, moge_depth, moge_mask, return_scale_shift=True)
1012
+
1013
+
1014
+
1015
+ moge_intrinsics[0, 0] *= original_width
1016
+ moge_intrinsics[1, 1] *= original_height
1017
+ moge_intrinsics[0, 2] *= original_width
1018
+ moge_intrinsics[1, 2] *= original_height
1019
+
1020
+
1021
+ return_dict['moge_intrinsics'] = moge_intrinsics
1022
+ return_dict['moge_mask'] = moge_mask
1023
+ return_dict['scale'] = scale
1024
+ return_dict['shift'] = shift
1025
+ # return_dict['moge_depth'] = moge_depth
1026
+ # return_dict['metric_depth'] = metric_depth
1027
+
1028
+ return return_dict
1029
+
1030
+
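For reference, a minimal sketch of the frame-count padding applied in `__call__` above: the Wan VAE compresses time by 4x, so clips are padded (by repeating the last frame) to the next length of the form 4*k + 1, and the matching latent length is (frames - 1) // 4 + 1.

def pad_to_4k_plus_1(frame_length: int) -> int:
    # mirror of the padding logic in __call__ above
    if (frame_length - 1) % 4 == 0:
        return frame_length
    return ((frame_length - 1) // 4 + 1) * 4 + 1

assert pad_to_4k_plus_1(81) == 81   # already 4*20 + 1
assert pad_to_4k_plus_1(50) == 53   # three copies of the last frame are appended
assert (53 - 1) // 4 + 1 == 14      # resulting latent length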
1031
+
1032
+ def prediction2pc(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
1033
+ resize_W,resize_H = RGB_frames[0].size
1034
+ pcds = []
1035
+ moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0")
1036
+
1037
+ for idx in tqdm(indices):
1038
+ orgin_rgb_frame = RGB_frames[idx]
1039
+ predicted_depth = prediction_depth_map[idx]
1040
+
1041
+ # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
1042
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
1043
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
1044
+ output = self.moge_pipe.infer(input_image)
1045
+
1046
+ #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
1047
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
1048
+ moge_mask = output['mask'].cpu().numpy()
1049
+ moge_depth = output['depth'].cpu().numpy()
1050
+
1051
+
1052
+ metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
1053
+
1054
+ moge_intrinsics[0, 0] *= resize_W
1055
+ moge_intrinsics[1, 1] *= resize_H
1056
+ moge_intrinsics[0, 2] *= resize_W
1057
+ moge_intrinsics[1, 2] *= resize_H
1058
+
1059
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
1060
+
1061
+ if return_pcd:
1062
+ #* [15,50], [2,3]
1063
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
1064
+ pcd = pcd.select_by_index(ind)
1065
+ #todo downsample
1066
+
1067
+ pcds.append(pcd)
1068
+
1069
+ return pcds
1070
+
1071
+
1072
+
1073
+
1074
+
1075
+
1076
+
1077
+ @spaces.GPU()
1078
+ @torch.inference_mode()
1079
+ def prediction2pc_v2(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
1080
+ """
1081
+ call MoGe once
1082
+ """
1083
+ resize_W,resize_H = RGB_frames[0].size
1084
+ pcds = []
1085
+ moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0")
1086
+
1087
+ for iidx, idx in enumerate(tqdm(indices)):
1088
+
1089
+ orgin_rgb_frame = RGB_frames[idx]
1090
+ predicted_depth = prediction_depth_map[idx]
1091
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
1092
+
1093
+
1094
+ if iidx == 0:
1095
+ # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
1096
+ if input_image_np.max() > 1:
1097
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
1098
+ else:
1099
+ input_image = torch.tensor(input_image_np, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
1100
+
1101
+ print(f'moge devices: {moge_device}') #* why cpu?
1102
+
1103
+ output = self.moge_pipe.infer(input_image)
1104
+
1105
+ #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
1106
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
1107
+ moge_mask = output['mask'].cpu().numpy()
1108
+ moge_depth = output['depth'].cpu().numpy()
1109
+
1110
+
1111
+ print(f'moge_mask dtype: {moge_mask.dtype}, shape: {moge_mask.shape}, valid_count: {moge_mask.sum()}, total: {moge_mask.size}')
1112
+
1113
+ if moge_mask.sum() == 0:
1114
+ print('stopping point cloud generation due to an error during MoGe inference')
1115
+ return pcds
1116
+
1117
+ # Ensure moge_mask is boolean type
1118
+ moge_mask = moge_mask.astype(bool)
1119
+
1120
+ metric_depth, scale, shift = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask, return_scale_shift=True)
1121
+
1122
+ # Debug: check metric_depth
1123
+
1124
+ print(f'metric_depth shape: {metric_depth.shape}, min: {metric_depth.min():.4f}, max: {metric_depth.max():.4f}')
1125
+
1126
+ moge_intrinsics[0, 0] *= resize_W
1127
+ moge_intrinsics[1, 1] *= resize_H
1128
+ moge_intrinsics[0, 2] *= resize_W
1129
+ moge_intrinsics[1, 2] *= resize_H
1130
+ else:
1131
+ metric_depth = transfer_pred_disp2depth_v2(predicted_depth, scale, shift)
1132
+
1133
+
1134
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
1135
+
1136
+ if return_pcd:
1137
+ #* [15,50], [2,3]
1138
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
1139
+ pcd = pcd.select_by_index(ind)
1140
+
1141
+
1142
+ pcds.append({'point':np.asarray(pcd.points),
1143
+ 'color': np.asarray(pcd.colors) if pcd.has_colors() else None} )
1144
+
1145
+ return pcds
1146
+
1147
+
1148
+
1149
+
1150
+
1151
+
1152
+
1153
+
1154
+
1155
+ def prediction2pc_v3(self, prediction_depth_map, RGB_frames, indices, scale, shift,moge_intrinsics, moge_mask, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
1156
+ """
1157
+ call MoGe once
1158
+ """
1159
+ resize_W,resize_H = RGB_frames[0].size
1160
+ pcds = []
1161
+
1162
+ for idx in tqdm(indices):
1163
+
1164
+ orgin_rgb_frame = RGB_frames[idx]
1165
+ predicted_depth = prediction_depth_map[idx]
1166
+
1167
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
1168
+ metric_depth = transfer_pred_disp2depth_v2(predicted_depth, scale, shift)
1169
+
1170
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
1171
+
1172
+ if return_pcd:
1173
+ #* [15,50], [2,3]
1174
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
1175
+ pcd = pcd.select_by_index(ind)
1176
+
1177
+
1178
+ pcds.append({'point':np.asarray(pcd.points),
1179
+ 'color': np.asarray(pcd.colors) if pcd.has_colors() else None} )
1180
+
1181
+ return pcds
1182
+
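The helpers transfer_pred_disp2depth / transfer_pred_disp2depth_v2 are not shown in this diff; what follows is a hedged sketch of the scale/shift alignment they presumably perform (the function name and the exact working space, disparity vs. depth, are assumptions, not confirmed here):

import numpy as np

def align_scale_shift(pred, ref_depth, mask):
    # least-squares fit of scale s and shift t so that s * pred + t ~= ref_depth on the valid mask
    x, y = pred[mask].ravel(), ref_depth[mask].ravel()
    A = np.stack([x, np.ones_like(x)], axis=1)
    (s, t), *_ = np.linalg.lstsq(A, y, rcond=None)
    return s * pred + t, s, t

prediction2pc_v2/_v3 reuse the (scale, shift) estimated on the first frame for all later frames, so MoGe only has to run once per clip.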
1183
+
1184
+
1185
+
1186
+
1187
+
1188
+
1189
+
1190
+
1191
+
1192
+
1193
+
1194
+
1195
+
1196
+
1197
+
1198
+
1199
+
1200
+
1201
+
1202
+
1203
+
1204
+
1205
+
1206
+ class WanVideoUnit_ShapeChecker(PipelineUnit):
1207
+ def __init__(self):
1208
+ super().__init__(input_params=("height", "width", "num_frames"))
1209
+
1210
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames):
1211
+ height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
1212
+ return {"height": height, "width": width, "num_frames": num_frames}
1213
+
1214
+
1215
+
1216
+ class WanVideoUnit_NoiseInitializer(PipelineUnit):
1217
+ def __init__(self):
1218
+ super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"))
1219
+
1220
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
1221
+ length = (num_frames - 1) // 4 + 1
1222
+ if vace_reference_image is not None:
1223
+ length += 1
1224
+ shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
1225
+ noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
1226
+ if vace_reference_image is not None:
1227
+ noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
1228
+ return {"noise": noise}
1229
+
1230
+
1231
+
1232
+ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
1233
+ def __init__(self):
1234
+ super().__init__(
1235
+ input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
1236
+ onload_model_names=("vae",)
1237
+ )
1238
+
1239
+ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
1240
+ if input_video is None:
1241
+ return {"latents": noise}
1242
+
1243
+ pipe.load_models_to_device(["vae"])#* input_video is the GT
1244
+ input_video = pipe.preprocess_video(input_video) #* [B,3,F,W,H]
1245
+ #* [B,3,(F/4) + 1 ,W/8,H/8]
1246
+ input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1247
+ if vace_reference_image is not None:
1248
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
1249
+ vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
1250
+ input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
1251
+ #? during training, the input_latents are independent of the noise,
1252
+ #? but during inference, the input_latents are noised (via add_noise) to form the initial latents
1253
+ if pipe.scheduler.training:
1254
+ return {"latents": noise, "input_latents": input_latents}
1255
+ else:
1256
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
1257
+ return {"latents": latents}
1258
+
1259
+
1260
+
1261
+ class WanVideoUnit_PromptEmbedder(PipelineUnit):
1262
+ def __init__(self):
1263
+ super().__init__(
1264
+ seperate_cfg=True,
1265
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
1266
+ input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
1267
+ onload_model_names=("text_encoder",)
1268
+ )
1269
+
1270
+ def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
1271
+ pipe.load_models_to_device(self.onload_model_names)
1272
+ prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device)
1273
+ return {"context": prompt_emb}
1274
+
1275
+
1276
+
1277
+ class WanVideoUnit_ImageEmbedder(PipelineUnit):
1278
+ """
1279
+ Deprecated
1280
+ """
1281
+ def __init__(self):
1282
+ super().__init__(
1283
+ input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
1284
+ onload_model_names=("image_encoder", "vae")
1285
+ )
1286
+
1287
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
1288
+ if input_image is None or pipe.image_encoder is None:
1289
+ return {}
1290
+ pipe.load_models_to_device(self.onload_model_names)
1291
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
1292
+ clip_context = pipe.image_encoder.encode_image([image])
1293
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) #* indicate which image is reference image
1294
+ msk[:, 1:] = 0
1295
+ if end_image is not None:
1296
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
1297
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
1298
+ if pipe.dit.has_image_pos_emb:
1299
+ clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
1300
+ msk[:, -1:] = 1
1301
+ else:
1302
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
1303
+
1304
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
1305
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
1306
+ msk = msk.transpose(1, 2)[0]
1307
+
1308
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
1309
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1310
+ y = torch.concat([msk, y])
1311
+ y = y.unsqueeze(0)
1312
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
1313
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1314
+ return {"clip_feature": clip_context, "y": y}
1315
+
1316
+
1317
+
1318
+ class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
1319
+ def __init__(self):
1320
+ super().__init__(
1321
+ input_params=("input_image", "end_image", "height", "width"),
1322
+ onload_model_names=("image_encoder",)
1323
+ )
1324
+
1325
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
1326
+ if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
1327
+ return {}
1328
+ pipe.load_models_to_device(self.onload_model_names)
1329
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
1330
+ clip_context = pipe.image_encoder.encode_image([image])
1331
+ if end_image is not None:
1332
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
1333
+ if pipe.dit.has_image_pos_emb:
1334
+ clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
1335
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
1336
+ return {"clip_feature": clip_context}
1337
+
1338
+
1339
+
1340
+ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
1341
+ def __init__(self):
1342
+ super().__init__(
1343
+ input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
1344
+ onload_model_names=("vae",)
1345
+ )
1346
+
1347
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
1348
+ if input_image is None or not pipe.dit.require_vae_embedding:
1349
+ return {}
1350
+ pipe.load_models_to_device(self.onload_model_names)
1351
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
1352
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
1353
+ msk[:, 1:] = 0
1354
+ if end_image is not None:
1355
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
1356
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
1357
+ msk[:, -1:] = 1
1358
+ else:
1359
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
1360
+
1361
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
1362
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
1363
+ msk = msk.transpose(1, 2)[0]
1364
+
1365
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
1366
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1367
+ y = torch.concat([msk, y])
1368
+ y = y.unsqueeze(0)
1369
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1370
+ return {"y": y}
1371
+
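A minimal shape walk-through of the mask construction above (assuming num_frames = 4*k + 1, so it lines up with the VAE latent length):

import torch
num_frames, h, w = 81, 60, 80          # h, w stand for height//8, width//8
msk = torch.ones(1, num_frames, h, w)
msk[:, 1:] = 0                          # only the first (conditioning) frame stays on
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, h, w).transpose(1, 2)[0]
print(msk.shape)                        # torch.Size([4, 21, 60, 80]), i.e. (4, (num_frames-1)//4 + 1, h, w)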
1372
+
1373
+
1374
+ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
1375
+ """
1376
+ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
1377
+ """
1378
+ def __init__(self):
1379
+ super().__init__(
1380
+ input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
1381
+ onload_model_names=("vae",)
1382
+ )
1383
+
1384
+ def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
1385
+ if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
1386
+ return {}
1387
+ pipe.load_models_to_device(self.onload_model_names)
1388
+ image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
1389
+ z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
1390
+ latents[:, :, 0: 1] = z
1391
+ return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
1392
+
1393
+
1394
+
1395
+ class WanVideoUnit_FunControl(PipelineUnit):
1396
+ def __init__(self):
1397
+ super().__init__(
1398
+ input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
1399
+ onload_model_names=("vae",)
1400
+ )
1401
+
1402
+ def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
1403
+
1404
+ if control_video is None:
1405
+ return {}
1406
+ pipe.load_models_to_device(self.onload_model_names)
1407
+ #* transfer to torch.tensor from PIL.Image
1408
+ #* result size: [1, 3, F, H, W]
1409
+ control_video = pipe.preprocess_video(control_video)
1410
+
1411
+ control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1412
+ #* size of control_latents: [1, 3, (F/4) + 1 , H/8, W/8]
1413
+ control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
1414
+
1415
+ if clip_feature is None or y is None:
1416
+ #* this branch is used during training
1417
+ clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
1418
+
1419
+ # y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
1420
+
1421
+ #* [1, 16, (F/4) + 1 , H/8, W/8]
1422
+ y = torch.zeros((1, 16, control_latents.shape[-3], height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
1423
+ else:
1424
+ y = y[:, -16:]
1425
+ #* control_latents: [1, 16, 21, 60, 80]; y: [1, 16, 21, 60, 80])
1426
+
1427
+ #* [1, 32, (F/4) + 1 , H/8, W/8], 前16个通道是control_latents, 后16个通道是y(或者说0 vector)
1428
+ y = torch.concat([control_latents, y], dim=1)
1429
+ return {"clip_feature": clip_feature, "y": y}
1430
+
1431
+
1432
+
1433
+ class WanVideoUnit_FunControl_Mask(PipelineUnit):
1434
+ def __init__(self):
1435
+ super().__init__(
1436
+ input_params=("control_video", "mask","num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
1437
+ onload_model_names=("vae",)
1438
+ )
1439
+
1440
+ def process(self, pipe: WanVideoPipeline, control_video, mask, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
1441
+
1442
+ if control_video is None:
1443
+ return {}
1444
+ pipe.load_models_to_device(self.onload_model_names)
1445
+ #* transfer to torch.tensor from PIL.Image
1446
+ #* result size: [1, 3, F, H, W]
1447
+
1448
+ control_video = pipe.preprocess_video(control_video)
1449
+
1450
+
1451
+ control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1452
+
1453
+ #* size of control_latents: [1, 3, (F/4) + 1 , H/8, W/8]
1454
+ control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
1455
+
1456
+ if mask is not None:
1457
+ mask = pipe.preprocess_video(mask)
1458
+ mask_latents = pipe.vae.encode(mask, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1459
+ mask_latents = mask_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
1460
+
1461
+
1462
+ if clip_feature is None or y is None:
1463
+ #* this branch is used during training
1464
+ clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
1465
+
1466
+ # y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
1467
+
1468
+ #* [1, 16, (F/4) + 1 , H/8, W/8]
1469
+ y = torch.zeros((1, 16, control_latents.shape[-3], height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
1470
+ else:
1471
+ y = y[:, -16:]
1472
+
1473
+ #* control_latents: [1, 16, 21, 60, 80]; y: [1, 16, 21, 60, 80])
1474
+
1475
+ #* [1, 32, (F/4) + 1 , H/8, W/8]; the first 16 channels are control_latents, the last 16 channels are y (or a zero vector)
1476
+
1477
+ if mask is not None:
1478
+ y = torch.concat([control_latents, mask_latents], dim=1)
1479
+ # logger.warning(f"mask is provided, using mask_latents instead of y")
1480
+ else:
1481
+ y = torch.concat([control_latents, y], dim=1)
1482
+ # logger.warning(f"mask is not provided, using y")
1483
+
1484
+ return {"clip_feature": clip_feature, "y": y}
1485
+
1486
+
1487
+
1488
+ class WanVideoUnit_FunReference(PipelineUnit):
1489
+ def __init__(self):
1490
+ super().__init__(
1491
+ input_params=("reference_image", "height", "width", "reference_image"),
1492
+ onload_model_names=("vae",)
1493
+ )
1494
+
1495
+ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
1496
+ if reference_image is None:
1497
+ return {}
1498
+ pipe.load_models_to_device(["vae"])
1499
+ reference_image = reference_image.resize((width, height))
1500
+ reference_latents = pipe.preprocess_video([reference_image])
1501
+ reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
1502
+ clip_feature = pipe.preprocess_image(reference_image)
1503
+ clip_feature = pipe.image_encoder.encode_image([clip_feature])
1504
+ return {"reference_latents": reference_latents, "clip_feature": clip_feature}
1505
+
1506
+
1507
+
1508
+
1509
+ class WanVideoUnit_FunCameraControl(PipelineUnit):
1510
+ def __init__(self):
1511
+ super().__init__(
1512
+ input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
1513
+ onload_model_names=("vae",)
1514
+ )
1515
+
1516
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
1517
+ if camera_control_direction is None:
1518
+ return {}
1519
+ camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
1520
+ camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
1521
+
1522
+ control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
1523
+ control_camera_latents = torch.concat(
1524
+ [
1525
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
1526
+ control_camera_video[:, :, 1:]
1527
+ ], dim=2
1528
+ ).transpose(1, 2)
1529
+ b, f, c, h, w = control_camera_latents.shape
1530
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
1531
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
1532
+ control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
1533
+
1534
+ input_image = input_image.resize((width, height))
1535
+ input_latents = pipe.preprocess_video([input_image])
1536
+ pipe.load_models_to_device(self.onload_model_names)
1537
+ input_latents = pipe.vae.encode(input_latents, device=pipe.device)
1538
+ y = torch.zeros_like(latents).to(pipe.device)
1539
+ y[:, :, :1] = input_latents
1540
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1541
+ return {"control_camera_latents_input": control_camera_latents_input, "y": y}
1542
+
1543
+
1544
+
1545
+ class WanVideoUnit_SpeedControl(PipelineUnit):
1546
+ def __init__(self):
1547
+ super().__init__(input_params=("motion_bucket_id",))
1548
+
1549
+ def process(self, pipe: WanVideoPipeline, motion_bucket_id):
1550
+ if motion_bucket_id is None:
1551
+ return {}
1552
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
1553
+ return {"motion_bucket_id": motion_bucket_id}
1554
+
1555
+
1556
+
1557
+ class WanVideoUnit_VACE(PipelineUnit):
1558
+ def __init__(self):
1559
+ super().__init__(
1560
+ input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
1561
+ onload_model_names=("vae",)
1562
+ )
1563
+
1564
+ def process(
1565
+ self,
1566
+ pipe: WanVideoPipeline,
1567
+ vace_video, vace_video_mask, vace_reference_image, vace_scale,
1568
+ height, width, num_frames,
1569
+ tiled, tile_size, tile_stride
1570
+ ):
1571
+ if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
1572
+ pipe.load_models_to_device(["vae"])
1573
+ if vace_video is None:
1574
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
1575
+ else:
1576
+ vace_video = pipe.preprocess_video(vace_video)
1577
+
1578
+ if vace_video_mask is None:
1579
+ vace_video_mask = torch.ones_like(vace_video)
1580
+ else:
1581
+ vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
1582
+
1583
+ inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
1584
+ reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
1585
+ inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1586
+ reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1587
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
1588
+
1589
+ vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
1590
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
1591
+
1592
+ if vace_reference_image is None:
1593
+ pass
1594
+ else:
1595
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
1596
+ vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
1597
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
1598
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
1599
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
1600
+
1601
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
1602
+ return {"vace_context": vace_context, "vace_scale": vace_scale}
1603
+ else:
1604
+ return {"vace_context": None, "vace_scale": vace_scale}
1605
+
1606
+
1607
+
1608
+ class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
1609
+ def __init__(self):
1610
+ super().__init__(input_params=())
1611
+
1612
+ def process(self, pipe: WanVideoPipeline):
1613
+ if hasattr(pipe, "use_unified_sequence_parallel"):
1614
+ if pipe.use_unified_sequence_parallel:
1615
+ return {"use_unified_sequence_parallel": True}
1616
+ return {}
1617
+
1618
+
1619
+
1620
+ class WanVideoUnit_TeaCache(PipelineUnit):
1621
+ def __init__(self):
1622
+ super().__init__(
1623
+ seperate_cfg=True,
1624
+ input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
1625
+ input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
1626
+ )
1627
+
1628
+ def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
1629
+ if tea_cache_l1_thresh is None:
1630
+ return {}
1631
+ return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)}
1632
+
1633
+
1634
+
1635
+ class WanVideoUnit_CfgMerger(PipelineUnit):
1636
+ def __init__(self):
1637
+ super().__init__(take_over=True)
1638
+ self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
1639
+
1640
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
1641
+ if not inputs_shared["cfg_merge"]:
1642
+ return inputs_shared, inputs_posi, inputs_nega
1643
+ for name in self.concat_tensor_names:
1644
+ tensor_posi = inputs_posi.get(name)
1645
+ tensor_nega = inputs_nega.get(name)
1646
+ tensor_shared = inputs_shared.get(name)
1647
+ if tensor_posi is not None and tensor_nega is not None:
1648
+ inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
1649
+ elif tensor_shared is not None:
1650
+ inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)
1651
+ inputs_posi.clear()
1652
+ inputs_nega.clear()
1653
+ return inputs_shared, inputs_posi, inputs_nega
1654
+
1655
+
1656
+
1657
+ class TeaCache:
1658
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
1659
+ self.num_inference_steps = num_inference_steps
1660
+ self.step = 0
1661
+ self.accumulated_rel_l1_distance = 0
1662
+ self.previous_modulated_input = None
1663
+ self.rel_l1_thresh = rel_l1_thresh
1664
+ self.previous_residual = None
1665
+ self.previous_hidden_states = None
1666
+
1667
+ self.coefficients_dict = {
1668
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
1669
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
1670
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
1671
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
1672
+ }
1673
+ if model_id not in self.coefficients_dict:
1674
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
1675
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
1676
+ self.coefficients = self.coefficients_dict[model_id]
1677
+
1678
+ def check(self, dit: WanModel, x, t_mod):
1679
+ modulated_inp = t_mod.clone()
1680
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
1681
+ should_calc = True
1682
+ self.accumulated_rel_l1_distance = 0
1683
+ else:
1684
+ coefficients = self.coefficients
1685
+ rescale_func = np.poly1d(coefficients)
1686
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
1687
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
1688
+ should_calc = False
1689
+ else:
1690
+ should_calc = True
1691
+ self.accumulated_rel_l1_distance = 0
1692
+ self.previous_modulated_input = modulated_inp
1693
+ self.step += 1
1694
+ if self.step == self.num_inference_steps:
1695
+ self.step = 0
1696
+ if should_calc:
1697
+ self.previous_hidden_states = x.clone()
1698
+ return not should_calc
1699
+
1700
+ def store(self, hidden_states):
1701
+ self.previous_residual = hidden_states - self.previous_hidden_states
1702
+ self.previous_hidden_states = None
1703
+
1704
+ def update(self, hidden_states):
1705
+ hidden_states = hidden_states + self.previous_residual
1706
+ return hidden_states
1707
+
1708
+
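A hedged usage sketch of how the DiT forward pass consults TeaCache (the threshold value is illustrative, and run_all_dit_blocks is a placeholder for the block loop in model_fn_wan_video below):

tea_cache = TeaCache(num_inference_steps=50, rel_l1_thresh=0.05, model_id="Wan2.1-T2V-1.3B")
# if tea_cache.check(dit, x, t_mod):   # True -> this step is close enough to the last computed one
#     x = tea_cache.update(x)          # reuse the cached residual instead of running the blocks
# else:
#     x = run_all_dit_blocks(x, ...)   # full computation
#     tea_cache.store(x)               # cache the new residual for later steps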
1709
+
1710
+ class TemporalTiler_BCTHW:
1711
+ def __init__(self):
1712
+ pass
1713
+
1714
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
1715
+ x = torch.ones((length,))
1716
+ if border_width == 0:
1717
+ return x
1718
+
1719
+ shift = 0.5
1720
+ if not left_bound:
1721
+ x[:border_width] = (torch.arange(border_width) + shift) / border_width
1722
+ if not right_bound:
1723
+ x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
1724
+ return x
1725
+
1726
+ def build_mask(self, data, is_bound, border_width):
1727
+ _, _, T, _, _ = data.shape
1728
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
1729
+ mask = repeat(t, "T -> 1 1 T 1 1")
1730
+ return mask
1731
+
1732
+ def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
1733
+ tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
1734
+ tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
1735
+
1736
+ B, C, T, H, W = tensor_dict[tensor_names[0]].shape
1737
+ if batch_size is not None:
1738
+ B *= batch_size
1739
+ data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype
1740
+ value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
1741
+ weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
1742
+ for t in range(0, T, sliding_window_stride):
1743
+ if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: #* if the previous window already covered the last frame, skip the remaining windows
1744
+ continue
1745
+ t_ = min(t + sliding_window_size, T)
1746
+
1747
+ model_kwargs.update({
1748
+ tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \
1749
+ for tensor_name in tensor_names
1750
+ })
1751
+ model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
1752
+
1753
+ mask = self.build_mask(
1754
+ model_output,
1755
+ is_bound=(t == 0, t_ == T),
1756
+ border_width=(sliding_window_size - sliding_window_stride,)
1757
+ ).to(device=data_device, dtype=data_dtype)
1758
+
1759
+ # logger.info(f"t: {t}, t_: {t_}, sliding_window_size: {sliding_window_size}, sliding_window_stride: {sliding_window_stride}")
1760
+
1761
+ value[:, :, t: t_, :, :] += model_output * mask
1762
+ weight[:, :, t: t_, :, :] += mask
1763
+ value /= weight
1764
+ model_kwargs.update(tensor_dict)
1765
+ return value
1766
+
1767
+
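For reference, the 1D blending mask built by TemporalTiler_BCTHW for a 21-frame window with stride 18 (overlap 3):

import torch
window, stride = 21, 18
border = window - stride                              # 3 overlapping latent frames
x = torch.ones(window)
x[:border] = (torch.arange(border) + 0.5) / border    # ramp up: 0.167, 0.5, 0.833
x[-border:] = torch.flip(x[:border], dims=(0,))       # ramp down on the trailing edge
# windows touching t == 0 or t == T keep a hard edge there (is_bound), and the
# accumulated `weight` tensor renormalises wherever neighbouring windows overlap.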
1768
+
1769
+ def model_fn_wan_video(
1770
+ dit: WanModel,
1771
+ motion_controller: WanMotionControllerModel = None,
1772
+ vace: VaceWanModel = None,
1773
+ latents: torch.Tensor = None,
1774
+ timestep: torch.Tensor = None,
1775
+ context: torch.Tensor = None,
1776
+ clip_feature: Optional[torch.Tensor] = None,
1777
+ y: Optional[torch.Tensor] = None,
1778
+ reference_latents = None,
1779
+ vace_context = None,
1780
+ vace_scale = 1.0,
1781
+ tea_cache: TeaCache = None,
1782
+ use_unified_sequence_parallel: bool = False,
1783
+ motion_bucket_id: Optional[torch.Tensor] = None,
1784
+ sliding_window_size: Optional[int] = None,
1785
+ sliding_window_stride: Optional[int] = None,
1786
+ cfg_merge: bool = False,
1787
+ use_gradient_checkpointing: bool = False,
1788
+ use_gradient_checkpointing_offload: bool = False,
1789
+ control_camera_latents_input = None,
1790
+ fuse_vae_embedding_in_latents: bool = False,
1791
+ **kwargs,
1792
+ ):
1793
+
1794
+
1795
+ if sliding_window_size is not None and sliding_window_stride is not None: #* skip for training,
1796
+ model_kwargs = dict(
1797
+ dit=dit,
1798
+ motion_controller=motion_controller,
1799
+ vace=vace,
1800
+ latents=latents,
1801
+ timestep=timestep,
1802
+ context=context,
1803
+ clip_feature=clip_feature,
1804
+ y=y,
1805
+ reference_latents=reference_latents,
1806
+ vace_context=vace_context,
1807
+ vace_scale=vace_scale,
1808
+ tea_cache=tea_cache,
1809
+ use_unified_sequence_parallel=use_unified_sequence_parallel,
1810
+ motion_bucket_id=motion_bucket_id,
1811
+ )
1812
+ return TemporalTiler_BCTHW().run(
1813
+ model_fn_wan_video,
1814
+ sliding_window_size, sliding_window_stride,
1815
+ latents.device, latents.dtype,
1816
+ model_kwargs=model_kwargs,
1817
+ tensor_names=["latents", "y"],
1818
+ batch_size=2 if cfg_merge else 1
1819
+ )
1820
+
1821
+ if use_unified_sequence_parallel:#* skip
1822
+ import torch.distributed as dist
1823
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
1824
+ get_sequence_parallel_world_size,
1825
+ get_sp_group)
1826
+
1827
+
1828
+ # Timestep
1829
+ if dit.seperated_timestep and fuse_vae_embedding_in_latents:
1830
+ timestep = torch.concat([
1831
+ torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
1832
+ torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
1833
+ ]).flatten()
1834
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
1835
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
1836
+ t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
1837
+ t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
1838
+ t = t_chunks[get_sequence_parallel_rank()]
1839
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
1840
+ else:#* this branch
1841
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) #* out: torch.Size([1, 1536])
1842
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) #* out: torch.Size([1, 6, 1536]); dit.dim: 1536
1843
+
1844
+
1845
+
1846
+
1847
+ if motion_bucket_id is not None and motion_controller is not None: #* skip
1848
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
1849
+
1850
+ context = dit.text_embedding(context)#* text prompt, e.g. "depth": from torch.Size([1, 512, 4096]) to torch.Size([1, 512, 1536])
1851
+ #todo double-check this x
1852
+ #* [1, 16, (F-1)/4, H/8, W/8], pure Gaussian noise or the noised ground-truth latents
1853
+ x = latents
1854
+
1855
+ # Merged cfg
1856
+ #* the batch dimension must stay consistent with context's batch size
1857
+ if x.shape[0] != context.shape[0]:
1858
+ x = torch.concat([x] * context.shape[0], dim=0)
1859
+ if timestep.shape[0] != context.shape[0]:
1860
+ timestep = torch.concat([timestep] * context.shape[0], dim=0)
1861
+
1862
+ # Image Embedding
1863
+ """
1864
+ new parameters:
1865
+ #* require_vae_embedding
1866
+ #* require_clip_embedding
1867
+ """
1868
+
1869
+ # todo: x is the target video (i.e. the depth/normal video) after noising, or pure Gaussian noise; y is the input rgb video
1870
+ #todo double-check this y, [1, 32, (F-1)/4, H/8, W/8]
1871
+ if y is not None and dit.require_vae_embedding:
1872
+ x = torch.cat([x, y], dim=1)# (b, c_x + c_y, f, h, w) #* [1, 48, (F-1)/4, H/8, W/8]
1873
+ if clip_feature is not None and dit.require_clip_embedding:
1874
+ #* clip_feature is initialized by zero, from torch.Size([1, 257, 1280]) to torch.Size([1, 257, 1536])
1875
+ clip_embdding = dit.img_emb(clip_feature)
1876
+ #* concat 257 and 512 to form torch.Size([1, 769, 1536])
1877
+ context = torch.cat([clip_embdding, context], dim=1)
1878
+
1879
+ # Add camera control
1880
+ #* from torch.Size([1, 48, (F-1)/4, H/8, W/8]),
1881
+ #* to [1, 1536, (F-1)/4, H/16, W/16] (via the MLP inside the function)
1882
+ #* to [1, 1536, ( (F-1)/4 * H/16 * W/16)]
1883
+ #* x_out: [1, 1536, ( (F-1)/4 * H/16 * W/16)]
1884
+ x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
1885
+
1886
+ # Reference image
1887
+ if reference_latents is not None: #* skip
1888
+ if len(reference_latents.shape) == 5:
1889
+ reference_latents = reference_latents[:, :, 0]
1890
+ reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
1891
+ x = torch.concat([reference_latents, x], dim=1)
1892
+ f += 1
1893
+
1894
+ #* RoPE position embedding for 3D video, [ ( (F-1)/4 * H/16 * W/16), 1, 64]
1895
+ freqs = torch.cat([
1896
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
1897
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
1898
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
1899
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
1900
+
1901
+ # TeaCache
1902
+ if tea_cache is not None:#*skip
1903
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
1904
+ else:
1905
+ tea_cache_update = False
1906
+
1907
+ if vace_context is not None:#*skip
1908
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
1909
+
1910
+ # blocks
1911
+ if use_unified_sequence_parallel:#* skip
1912
+ if dist.is_initialized() and dist.get_world_size() > 1:
1913
+ chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
1914
+ pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
1915
+ chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
1916
+ x = chunks[get_sequence_parallel_rank()]
1917
+ if tea_cache_update:
1918
+ x = tea_cache.update(x)
1919
+ else:
1920
+ def create_custom_forward(module):
1921
+ def custom_forward(*inputs):
1922
+ return module(*inputs)
1923
+ return custom_forward
1924
+ #* pass through dit blocks 30 times
1925
+ for block_id, block in enumerate(dit.blocks):
1926
+ if use_gradient_checkpointing_offload:
1927
+ with torch.autograd.graph.save_on_cpu():
1928
+ x = torch.utils.checkpoint.checkpoint(
1929
+ create_custom_forward(block),
1930
+ x, context, t_mod, freqs,
1931
+ use_reentrant=False,
1932
+ )
1933
+ elif use_gradient_checkpointing:
1934
+ x = torch.utils.checkpoint.checkpoint(
1935
+ create_custom_forward(block),
1936
+ x, context, t_mod, freqs,
1937
+ use_reentrant=False,
1938
+ )
1939
+ else:
1940
+ x = block(x, context, t_mod, freqs)#* x_in: [1, ( (F-1)/4 * H/16 * W/16), 1536], context_in: [1, 769, 1536], t_mod_in: [1, 6, 1536], freqs_in: [ ( (F-1)/4 * H/16 * W/16), 1, 64], x_out: [1, ( (F-1)/4 * H/16 * W/16), 1536]
1941
+ if vace_context is not None and block_id in vace.vace_layers_mapping:#* skip
1942
+ current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
1943
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
1944
+ current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
1945
+ current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
1946
+ x = x + current_vace_hint * vace_scale
1947
+ if tea_cache is not None:#* skip
1948
+ tea_cache.store(x)
1949
+
1950
+ #* x_in: [1, ( (F-1)/4 * H/16 * W/16), 1536], t_in: [1, 1536],
1951
+ #* x_out: [1, ( (F-1)/4 * H/16 * W/16), 64]
1952
+ x = dit.head(x, t)
1953
+ if use_unified_sequence_parallel:#* skip
1954
+ if dist.is_initialized() and dist.get_world_size() > 1:
1955
+ x = get_sp_group().all_gather(x, dim=1)
1956
+ x = x[:, :-pad_shape] if pad_shape > 0 else x
1957
+
1958
+ # Remove reference latents
1959
+ if reference_latents is not None:#* skip
1960
+ x = x[:, reference_latents.shape[1]:]
1961
+ f -= 1
1962
+
1963
+ #* unpatchify, from [1, ( (F-1)/4 * H/16 * W/16), 64] to [1, 16, (F-1)/4, H/8, W/8]
1964
+ x = dit.unpatchify(x, (f, h, w))
1965
+ return x
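A rough shape walk-through of model_fn_wan_video for an 832x480, 81-frame clip with the 1.3B config (dim = 1536); the numbers are illustrative:

F, H, W = 81, 480, 832
f = (F - 1) // 4 + 1          # 21 latent frames
h, w = H // 16, W // 16       # 30 x 52 patch grid (VAE /8, then 2x2 spatial patchify)
tokens = f * h * w            # 32760 tokens of width 1536 flow through the 30 DiT blocks
# dit.head maps each token to 64 = 16 * 2 * 2 values, and unpatchify restores the
# [1, 16, f, H//8, W//8] latent that the VAE later decodes back to video frames.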
dkt/prompters/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Only wan_prompter is used by entry files
2
+ # from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
3
+ # from .sd_prompter import SDPrompter
4
+ # from .sdxl_prompter import SDXLPrompter
5
+ # from .sd3_prompter import SD3Prompter
6
+ # from .hunyuan_dit_prompter import HunyuanDiTPrompter
7
+ # from .kolors_prompter import KolorsPrompter
8
+ # from .flux_prompter import FluxPrompter
9
+ # from .omost import OmostPromter
10
+ # from .cog_prompter import CogPrompter
11
+ # from .hunyuan_video_prompter import HunyuanVideoPrompter
12
+ # from .stepvideo_prompter import StepVideoPrompter
13
+ from .wan_prompter import WanPrompter
dkt/prompters/base_prompter.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..models.model_manager import ModelManager
2
+ import torch
3
+
4
+
5
+
6
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
7
+ # Get model_max_length from self.tokenizer
8
+ length = tokenizer.model_max_length if max_length is None else max_length
9
+
10
+ # To avoid the truncation warning, temporarily set self.tokenizer.model_max_length to a huge value (effectively +inf).
11
+ tokenizer.model_max_length = 99999999
12
+
13
+ # Tokenize it!
14
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15
+
16
+ # Determine the real length.
17
+ max_length = (input_ids.shape[1] + length - 1) // length * length
18
+
19
+ # Restore tokenizer.model_max_length
20
+ tokenizer.model_max_length = length
21
+
22
+ # Tokenize it again with fixed length.
23
+ input_ids = tokenizer(
24
+ prompt,
25
+ return_tensors="pt",
26
+ padding="max_length",
27
+ max_length=max_length,
28
+ truncation=True
29
+ ).input_ids
30
+
31
+ # Reshape input_ids to fit the text encoder.
32
+ num_sentence = input_ids.shape[1] // length
33
+ input_ids = input_ids.reshape((num_sentence, length))
34
+
35
+ return input_ids
36
+
37
+
38
+
39
+ class BasePrompter:
40
+ def __init__(self):
41
+ self.refiners = []
42
+ self.extenders = []
43
+
44
+
45
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
46
+ for refiner_class in refiner_classes:
47
+ refiner = refiner_class.from_model_manager(model_manager)
48
+ self.refiners.append(refiner)
49
+
50
+ def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
51
+ for extender_class in extender_classes:
52
+ extender = extender_class.from_model_manager(model_manager)
53
+ self.extenders.append(extender)
54
+
55
+
56
+ @torch.no_grad()
57
+ def process_prompt(self, prompt, positive=True):
58
+ if isinstance(prompt, list):
59
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
60
+ else:
61
+ for refiner in self.refiners:
62
+ prompt = refiner(prompt, positive=positive)
63
+ return prompt
64
+
65
+ @torch.no_grad()
66
+ def extend_prompt(self, prompt:str, positive=True):
67
+ extended_prompt = dict(prompt=prompt)
68
+ for extender in self.extenders:
69
+ extended_prompt = extender(extended_prompt)
70
+ return extended_prompt
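A small arithmetic example of what tokenize_long_prompt does with an over-long prompt (numbers are illustrative):

length = 256                                                # tokenizer.model_max_length
n_tokens = 300                                              # real prompt length
max_length = (n_tokens + length - 1) // length * length     # rounded up to 512
num_sentence = max_length // length                         # 2 -> input_ids reshaped to (2, 256)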
dkt/prompters/wan_prompter.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_prompter import BasePrompter
2
+ from ..models.wan_video_text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import os, torch
5
+ import ftfy
6
+ import html
7
+ import string
8
+ import regex as re
9
+
10
+
11
+ def basic_clean(text):
12
+ text = ftfy.fix_text(text)
13
+ text = html.unescape(html.unescape(text))
14
+ return text.strip()
15
+
16
+
17
+ def whitespace_clean(text):
18
+ text = re.sub(r'\s+', ' ', text)
19
+ text = text.strip()
20
+ return text
21
+
22
+
23
+ def canonicalize(text, keep_punctuation_exact_string=None):
24
+ text = text.replace('_', ' ')
25
+ if keep_punctuation_exact_string:
26
+ text = keep_punctuation_exact_string.join(
27
+ part.translate(str.maketrans('', '', string.punctuation))
28
+ for part in text.split(keep_punctuation_exact_string))
29
+ else:
30
+ text = text.translate(str.maketrans('', '', string.punctuation))
31
+ text = text.lower()
32
+ text = re.sub(r'\s+', ' ', text)
33
+ return text.strip()
34
+
35
+
36
+ class HuggingfaceTokenizer:
37
+
38
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
39
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
40
+ self.name = name
41
+ self.seq_len = seq_len
42
+ self.clean = clean
43
+
44
+ # init tokenizer
45
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
46
+ self.vocab_size = self.tokenizer.vocab_size
47
+
48
+ def __call__(self, sequence, **kwargs):
49
+ return_mask = kwargs.pop('return_mask', False)
50
+
51
+ # arguments
52
+ _kwargs = {'return_tensors': 'pt'}
53
+ if self.seq_len is not None:
54
+ _kwargs.update({
55
+ 'padding': 'max_length',
56
+ 'truncation': True,
57
+ 'max_length': self.seq_len
58
+ })
59
+ _kwargs.update(**kwargs)
60
+
61
+ # tokenization
62
+ if isinstance(sequence, str):
63
+ sequence = [sequence]
64
+ if self.clean:
65
+ sequence = [self._clean(u) for u in sequence]
66
+ ids = self.tokenizer(sequence, **_kwargs)
67
+
68
+ # output
69
+ if return_mask:
70
+ return ids.input_ids, ids.attention_mask
71
+ else:
72
+ return ids.input_ids
73
+
74
+ def _clean(self, text):
75
+ if self.clean == 'whitespace':
76
+ text = whitespace_clean(basic_clean(text))
77
+ elif self.clean == 'lower':
78
+ text = whitespace_clean(basic_clean(text)).lower()
79
+ elif self.clean == 'canonicalize':
80
+ text = canonicalize(basic_clean(text))
81
+ return text
82
+
83
+
84
+ class WanPrompter(BasePrompter):
85
+
86
+ def __init__(self, tokenizer_path=None, text_len=512):
87
+ super().__init__()
88
+ self.text_len = text_len
89
+ self.text_encoder = None
90
+ self.fetch_tokenizer(tokenizer_path)
91
+
92
+ def fetch_tokenizer(self, tokenizer_path=None):
93
+ if tokenizer_path is not None:
94
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
95
+
96
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
97
+ self.text_encoder = text_encoder
98
+
99
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
100
+ prompt = self.process_prompt(prompt, positive=positive)
101
+
102
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
103
+ ids = ids.to(device)
104
+ mask = mask.to(device)
105
+ seq_lens = mask.gt(0).sum(dim=1).long()
106
+ prompt_emb = self.text_encoder(ids, mask)
107
+ for i, v in enumerate(seq_lens):
108
+ prompt_emb[:, v:] = 0
109
+ return prompt_emb
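A hedged usage sketch for WanPrompter (the tokenizer path below is a placeholder, not taken from this diff):

# prompter = WanPrompter(tokenizer_path="path/to/umt5-xxl-tokenizer", text_len=512)
# prompter.fetch_models(text_encoder)                      # a loaded WanTextEncoder
# emb = prompter.encode_prompt("depth", device="cuda")     # -> [1, 512, 4096]
# positions beyond each prompt's true token length are zeroed, so padding does not
# leak into the DiT's cross-attention.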
dkt/schedulers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flow_match import FlowMatchScheduler
dkt/schedulers/flow_match.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, math
2
+
3
+
4
+
5
+ class FlowMatchScheduler():
6
+
7
+ def __init__(
8
+ self,
9
+ num_inference_steps=100,
10
+ num_train_timesteps=1000,
11
+ shift=3.0,
12
+ sigma_max=1.0,
13
+ sigma_min=0.003/1.002,
14
+ inverse_timesteps=False,
15
+ extra_one_step=False,
16
+ reverse_sigmas=False,
17
+ exponential_shift=False,
18
+ exponential_shift_mu=None,
19
+ shift_terminal=None,
20
+ ):
21
+ self.num_train_timesteps = num_train_timesteps
22
+ self.shift = shift
23
+ self.sigma_max = sigma_max
24
+ self.sigma_min = sigma_min
25
+ self.inverse_timesteps = inverse_timesteps
26
+ self.extra_one_step = extra_one_step
27
+ self.reverse_sigmas = reverse_sigmas
28
+ self.exponential_shift = exponential_shift
29
+ self.exponential_shift_mu = exponential_shift_mu
30
+ self.shift_terminal = shift_terminal
31
+ self.set_timesteps(num_inference_steps)
32
+
33
+
34
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
35
+ if shift is not None:
36
+ self.shift = shift
37
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
38
+ if self.extra_one_step:
39
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
40
+ else:
41
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
42
+
43
+ if self.inverse_timesteps:
44
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
45
+ if self.exponential_shift:
46
+ mu = self.calculate_shift(dynamic_shift_len) if dynamic_shift_len is not None else self.exponential_shift_mu
47
+ self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
48
+ else:
49
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
50
+ if self.shift_terminal is not None:
51
+ one_minus_z = 1 - self.sigmas
52
+ scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
53
+ self.sigmas = 1 - (one_minus_z / scale_factor)
54
+ if self.reverse_sigmas:
55
+ self.sigmas = 1 - self.sigmas
56
+ self.timesteps = self.sigmas * self.num_train_timesteps
57
+
58
+ if training:
59
+ x = self.timesteps
60
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
61
+ y_shifted = y - y.min()
62
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
63
+ self.linear_timesteps_weights = bsmntw_weighing
64
+ self.training = True
65
+ else:
66
+ self.training = False
67
+
68
+
69
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
70
+ if isinstance(timestep, torch.Tensor):
71
+ timestep = timestep.cpu()
72
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
73
+
74
+ sigma = self.sigmas[timestep_id] #* the current sigma, i.e. the fraction of noise that has been mixed in
75
+ if to_final or timestep_id + 1 >= len(self.timesteps): #* the noise fraction for the next step
76
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
77
+ else:
78
+ sigma_ = self.sigmas[timestep_id + 1]
79
+ prev_sample = sample + model_output * (sigma_ - sigma)
80
+ return prev_sample
81
+
82
+
83
+ def return_to_timestep(self, timestep, sample, sample_stablized):
84
+ if isinstance(timestep, torch.Tensor):
85
+ timestep = timestep.cpu()
86
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
87
+ sigma = self.sigmas[timestep_id]
88
+ model_output = (sample - sample_stablized) / sigma
89
+ return model_output
90
+
91
+
92
+ def add_noise(self, original_samples, noise, timestep):
93
+ if isinstance(timestep, torch.Tensor):
94
+ timestep = timestep.cpu()
95
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
96
+ sigma = self.sigmas[timestep_id]
97
+ sample = (1 - sigma) * original_samples + sigma * noise
98
+ return sample
99
+
100
+
101
+ def training_target(self, sample, noise, timestep):
102
+ #* so: noise - target = sample
103
+ #* sample + target = noise
104
+ target = noise - sample
105
+ #* noise: is rgb images
106
+ return target
107
+
108
+
109
+ def training_weight(self, timestep):
110
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
111
+ weights = self.linear_timesteps_weights[timestep_id]
112
+ return weights
113
+
114
+
115
+ def calculate_shift(
116
+ self,
117
+ image_seq_len,
118
+ base_seq_len: int = 256,
119
+ max_seq_len: int = 8192,
120
+ base_shift: float = 0.5,
121
+ max_shift: float = 0.9,
122
+ ):
123
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
124
+ b = base_shift - m * base_seq_len
125
+ mu = image_seq_len * m + b
126
+ return mu
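A tiny numeric check of the flow-matching relations implemented above: add_noise gives x_t = (1 - sigma) * x0 + sigma * noise, training_target returns noise - x0, and step() moves the sample by model_output * (sigma_next - sigma):

import torch
x0, noise, sigma, sigma_next = torch.tensor(2.0), torch.tensor(-1.0), 0.8, 0.6
x_t = (1 - sigma) * x0 + sigma * noise            # -0.4
target = noise - x0                               # -3.0
x_prev = x_t + target * (sigma_next - sigma)      # -0.4 + (-3.0) * (-0.2) = 0.2
assert torch.isclose(x_prev, (1 - sigma_next) * x0 + sigma_next * noise)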
dkt/utils/__init__.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, warnings, glob, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from einops import repeat, reduce
5
+ from typing import Optional, Union
6
+ from dataclasses import dataclass
7
+ from modelscope import snapshot_download
8
+ import numpy as np
9
+ from PIL import Image
10
+ from typing import Optional
11
+
12
+
13
+ class BasePipeline(torch.nn.Module):
14
+
15
+ def __init__(
16
+ self,
17
+ device="cuda", torch_dtype=torch.float16,
18
+ height_division_factor=64, width_division_factor=64,
19
+ time_division_factor=None, time_division_remainder=None,
20
+ ):
21
+ super().__init__()
22
+ # The device and torch_dtype are used for the storage of intermediate variables, not models.
23
+ self.device = device
24
+ self.torch_dtype = torch_dtype
25
+ # The following parameters are used for shape check.
26
+ self.height_division_factor = height_division_factor
27
+ self.width_division_factor = width_division_factor
28
+ self.time_division_factor = time_division_factor
29
+ self.time_division_remainder = time_division_remainder
30
+ self.vram_management_enabled = False
31
+
32
+
33
+ def to(self, *args, **kwargs):
34
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
35
+ if device is not None:
36
+ self.device = device
37
+ if dtype is not None:
38
+ self.torch_dtype = dtype
39
+ super().to(*args, **kwargs)
40
+ return self
41
+
42
+
43
+ def check_resize_height_width(self, height, width, num_frames=None):
44
+ # Shape check
45
+ if height % self.height_division_factor != 0:
46
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
47
+ print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
48
+ if width % self.width_division_factor != 0:
49
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
50
+ print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
51
+ if num_frames is None:
52
+ return height, width
53
+ else:
54
+ if num_frames % self.time_division_factor != self.time_division_remainder:
55
+ num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
56
+ print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
57
+ return height, width, num_frames
58
+
59
+
60
+ def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
61
+ # Transform a PIL.Image to torch.Tensor
62
+ image = torch.Tensor(np.array(image, dtype=np.float32))
63
+ image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
64
+ image = image * ((max_value - min_value) / 255) + min_value
65
+ image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
66
+ return image
67
+
68
+
69
+ def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
70
+ # Transform a list of PIL.Image to torch.Tensor
71
+ video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
72
+ video = torch.stack(video, dim=pattern.index("T") // 2)
73
+ return video
74
+
75
+
76
+ def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
77
+ # Transform a torch.Tensor to PIL.Image
78
+ if pattern != "H W C":
79
+ vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
80
+ image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
81
+ image = image.to(device="cpu", dtype=torch.uint8)
82
+ image = Image.fromarray(image.numpy())
83
+ return image
84
+
85
+
86
+ def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
87
+ # Transform a torch.Tensor to list of PIL.Image
88
+ if pattern != "T H W C":
89
+ vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
90
+ video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
91
+ return video
92
+
93
+
94
+ def load_models_to_device(self, model_names=[]):
95
+ if self.vram_management_enabled:
96
+ # offload models
97
+ for name, model in self.named_children():
98
+ if name not in model_names:
99
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
100
+ for module in model.modules():
101
+ if hasattr(module, "offload"):
102
+ module.offload()
103
+ else:
104
+ model.cpu()
105
+ torch.cuda.empty_cache()
106
+ # onload models
107
+ for name, model in self.named_children():
108
+ if name in model_names:
109
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
110
+ for module in model.modules():
111
+ if hasattr(module, "onload"):
112
+ module.onload()
113
+ else:
114
+ model.to(self.device)
115
+
116
+
117
+ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
118
+ # Initialize Gaussian noise
119
+ generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
120
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
121
+ noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
122
+ return noise
123
+
124
+
125
+ def enable_cpu_offload(self):
126
+ warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
127
+ self.vram_management_enabled = True
128
+
129
+
130
+ def get_vram(self):
131
+ return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
132
+
133
+
134
+ def freeze_except(self, model_names):
135
+ for name, model in self.named_children():
136
+ if name in model_names:
137
+ model.train()
138
+ model.requires_grad_(True)
139
+ else:
140
+ model.eval()
141
+ model.requires_grad_(False)
142
+
143
+
144
+ @dataclass
145
+ class ModelConfig:
146
+ path: Union[str, list[str]] = None
147
+ model_id: str = None
148
+ origin_file_pattern: Union[str, list[str]] = None
149
+ download_resource: str = "ModelScope"
150
+ offload_device: Optional[Union[str, torch.device]] = None
151
+ offload_dtype: Optional[torch.dtype] = None
152
+ local_model_path: str = None
153
+ skip_download: bool = False
154
+
155
+ def download_if_necessary(self, use_usp=False):
156
+ if self.path is None:
157
+ # Check model_id and origin_file_pattern
158
+ if self.model_id is None:
159
+ raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
160
+
161
+ # Skip if not in rank 0
162
+ if use_usp:
163
+ import torch.distributed as dist
164
+ skip_download = self.skip_download or dist.get_rank() != 0
165
+ else:
166
+ skip_download = self.skip_download
167
+
168
+ # Check whether the origin path is a folder
169
+ if self.origin_file_pattern is None or self.origin_file_pattern == "":
170
+ self.origin_file_pattern = ""
171
+ allow_file_pattern = None
172
+ is_folder = True
173
+ elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
174
+ allow_file_pattern = self.origin_file_pattern + "*"
175
+ is_folder = True
176
+ else:
177
+ allow_file_pattern = self.origin_file_pattern
178
+ is_folder = False
179
+
180
+ # Download
181
+ if self.local_model_path is None:
182
+ self.local_model_path = "./models"
183
+ if not skip_download:
184
+ downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
185
+ snapshot_download(
186
+ self.model_id,
187
+ local_dir=os.path.join(self.local_model_path, self.model_id),
188
+ allow_file_pattern=allow_file_pattern,
189
+ ignore_file_pattern=downloaded_files,
190
+ local_files_only=False
191
+ )
192
+
193
+ # Let rank 1, 2, ... wait for rank 0
194
+ if use_usp:
195
+ import torch.distributed as dist
196
+ dist.barrier(device_ids=[dist.get_rank()])
197
+
198
+ # Return downloaded files
199
+ if is_folder:
200
+ self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
201
+ else:
202
+ self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
203
+ if isinstance(self.path, list) and len(self.path) == 1:
204
+ self.path = self.path[0]
205
+
206
+
207
+
208
+ class PipelineUnit:
209
+ def __init__(
210
+ self,
211
+ seperate_cfg: bool = False,
212
+ take_over: bool = False,
213
+ input_params: tuple[str] = None,
214
+ input_params_posi: dict[str, str] = None,
215
+ input_params_nega: dict[str, str] = None,
216
+ onload_model_names: tuple[str] = None
217
+ ):
218
+ self.seperate_cfg = seperate_cfg
219
+ self.take_over = take_over
220
+ self.input_params = input_params
221
+ self.input_params_posi = input_params_posi
222
+ self.input_params_nega = input_params_nega
223
+ self.onload_model_names = onload_model_names
224
+
225
+
226
+ def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
227
+ raise NotImplementedError("`process` is not implemented.")
228
+
229
+
230
+
231
+ class PipelineUnitRunner:
232
+ def __init__(self):
233
+ pass
234
+
235
+ def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
236
+ if unit.take_over:
237
+ # Let the pipeline unit take over this function.
238
+ inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
239
+ elif unit.seperate_cfg:
240
+ # Positive side
241
+ processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
242
+ if unit.input_params is not None:
243
+ for name in unit.input_params:
244
+ processor_inputs[name] = inputs_shared.get(name)
245
+ processor_outputs = unit.process(pipe, **processor_inputs)
246
+ inputs_posi.update(processor_outputs)
247
+ # Negative side
248
+ if inputs_shared["cfg_scale"] != 1:
249
+ processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
250
+ if unit.input_params is not None:
251
+ for name in unit.input_params:
252
+ processor_inputs[name] = inputs_shared.get(name)
253
+ processor_outputs = unit.process(pipe, **processor_inputs)
254
+ inputs_nega.update(processor_outputs)
255
+ else:
256
+ inputs_nega.update(processor_outputs)
257
+ else:
258
+ processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
259
+ processor_outputs = unit.process(pipe, **processor_inputs)
260
+ inputs_shared.update(processor_outputs)
261
+ return inputs_shared, inputs_posi, inputs_nega
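A minimal sketch of how these pieces fit together; ShapeCheckUnit below is a hypothetical unit written only for illustration, not part of this diff. A unit declares which shared inputs it reads, and PipelineUnitRunner routes them through process and merges the returned dict back into the shared inputs:

import torch
from dkt.utils import BasePipeline, PipelineUnit, PipelineUnitRunner

class ShapeCheckUnit(PipelineUnit):  # hypothetical example unit
    def __init__(self):
        super().__init__(input_params=("height", "width"))

    def process(self, pipe, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}

pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
runner = PipelineUnitRunner()
shared, posi, nega = runner(
    ShapeCheckUnit(), pipe,
    inputs_shared={"height": 500, "width": 700}, inputs_posi={}, inputs_nega={},
)
print(shared)  # {'height': 512, 'width': 704}, rounded up to the division factor of 64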
dkt/vram_management/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .layers import *
2
+ from .gradient_checkpointing import *
dkt/vram_management/gradient_checkpointing.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+
4
+ def create_custom_forward(module):
5
+ def custom_forward(*inputs, **kwargs):
6
+ return module(*inputs, **kwargs)
7
+ return custom_forward
8
+
9
+
10
+ def gradient_checkpoint_forward(
11
+ model,
12
+ use_gradient_checkpointing,
13
+ use_gradient_checkpointing_offload,
14
+ *args,
15
+ **kwargs,
16
+ ):
17
+ if use_gradient_checkpointing_offload:
18
+ with torch.autograd.graph.save_on_cpu():
19
+ model_output = torch.utils.checkpoint.checkpoint(
20
+ create_custom_forward(model),
21
+ *args,
22
+ **kwargs,
23
+ use_reentrant=False,
24
+ )
25
+ elif use_gradient_checkpointing:
26
+ model_output = torch.utils.checkpoint.checkpoint(
27
+ create_custom_forward(model),
28
+ *args,
29
+ **kwargs,
30
+ use_reentrant=False,
31
+ )
32
+ else:
33
+ model_output = model(*args, **kwargs)
34
+ return model_output
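A short usage sketch (the toy block and tensor shapes are illustrative): the same call site can run a module normally, with activation checkpointing, or with checkpointing plus CPU offload of the saved tensors, depending on the two flags.

import torch
from dkt.vram_management.gradient_checkpointing import gradient_checkpoint_forward

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(2, 64, requires_grad=True)

# With use_gradient_checkpointing=True, activations are recomputed during backward
# instead of being stored; set the third argument to True to also keep saved tensors on CPU.
y = gradient_checkpoint_forward(block, True, False, x)
y.sum().backward()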
dkt/vram_management/layers.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch, copy
2
+ from typing import Union
3
+ from ..models.utils import init_weights_on_device
4
+
5
+
6
+ def cast_to(weight, dtype, device):
7
+ r = torch.empty_like(weight, dtype=dtype, device=device)
8
+ r.copy_(weight)
9
+ return r
10
+
11
+
12
+ class AutoTorchModule(torch.nn.Module):
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ def check_free_vram(self):
17
+ gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
18
+ used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
19
+ return used_memory < self.vram_limit
20
+
21
+ def offload(self):
22
+ if self.state != 0:
23
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
24
+ self.state = 0
25
+
26
+ def onload(self):
27
+ if self.state != 1:
28
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
29
+ self.state = 1
30
+
31
+ def keep(self):
32
+ if self.state != 2:
33
+ self.to(dtype=self.computation_dtype, device=self.computation_device)
34
+ self.state = 2
35
+
36
+
37
+ class AutoWrappedModule(AutoTorchModule):
38
+ def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
39
+ super().__init__()
40
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
41
+ self.offload_dtype = offload_dtype
42
+ self.offload_device = offload_device
43
+ self.onload_dtype = onload_dtype
44
+ self.onload_device = onload_device
45
+ self.computation_dtype = computation_dtype
46
+ self.computation_device = computation_device
47
+ self.vram_limit = vram_limit
48
+ self.state = 0
49
+
50
+ def forward(self, *args, **kwargs):
51
+ if self.state == 2:
52
+ module = self.module
53
+ else:
54
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
55
+ module = self.module
56
+ elif self.vram_limit is not None and self.check_free_vram():
57
+ self.keep()
58
+ module = self.module
59
+ else:
60
+ module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
61
+ return module(*args, **kwargs)
62
+
63
+
64
+ class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
65
+ def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
66
+ with init_weights_on_device(device=torch.device("meta")):
67
+ super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
68
+ self.weight = module.weight
69
+ self.bias = module.bias
70
+ self.offload_dtype = offload_dtype
71
+ self.offload_device = offload_device
72
+ self.onload_dtype = onload_dtype
73
+ self.onload_device = onload_device
74
+ self.computation_dtype = computation_dtype
75
+ self.computation_device = computation_device
76
+ self.vram_limit = vram_limit
77
+ self.state = 0
78
+
79
+ def forward(self, x, *args, **kwargs):
80
+ if self.state == 2:
81
+ weight, bias = self.weight, self.bias
82
+ else:
83
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
84
+ weight, bias = self.weight, self.bias
85
+ elif self.vram_limit is not None and self.check_free_vram():
86
+ self.keep()
87
+ weight, bias = self.weight, self.bias
88
+ else:
89
+ weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
90
+ bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
91
+ with torch.amp.autocast(device_type=x.device.type):
92
+ x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
93
+ return x
94
+
95
+
96
+ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
97
+ def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
98
+ with init_weights_on_device(device=torch.device("meta")):
99
+ super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
100
+ self.weight = module.weight
101
+ self.bias = module.bias
102
+ self.offload_dtype = offload_dtype
103
+ self.offload_device = offload_device
104
+ self.onload_dtype = onload_dtype
105
+ self.onload_device = onload_device
106
+ self.computation_dtype = computation_dtype
107
+ self.computation_device = computation_device
108
+ self.vram_limit = vram_limit
109
+ self.state = 0
110
+ self.name = name
111
+ self.lora_A_weights = []
112
+ self.lora_B_weights = []
113
+ self.lora_merger = None
114
+ self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
115
+
116
+ def fp8_linear(
117
+ self,
118
+ input: torch.Tensor,
119
+ weight: torch.Tensor,
120
+ bias: Union[torch.Tensor, None] = None):
121
+ device = input.device
122
+ origin_dtype = input.dtype
123
+ origin_shape = input.shape
124
+ input = input.reshape(-1, origin_shape[-1])
125
+
126
+ x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
127
+ fp8_max = 448.0
128
+ # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
129
+ # To avoid overflow and ensure numerical compatibility during FP8 computation,
130
+ # we scale down the input by 2.0 in advance.
131
+ # This scaling will be compensated later during the final result scaling.
132
+ if self.computation_dtype == torch.float8_e4m3fnuz:
133
+ fp8_max = fp8_max / 2.0
134
+ scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
135
+ scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
136
+ input = input / (scale_a + 1e-8)
137
+ input = input.to(self.computation_dtype)
138
+ weight = weight.to(self.computation_dtype)
139
+ bias = None if bias is None else bias.to(torch.bfloat16)
140
+
141
+ result = torch._scaled_mm(
142
+ input,
143
+ weight.T,
144
+ scale_a=scale_a,
145
+ scale_b=scale_b.T,
146
+ bias=bias,
147
+ out_dtype=origin_dtype,
148
+ )
149
+ new_shape = origin_shape[:-1] + result.shape[-1:]
150
+ result = result.reshape(new_shape)
151
+ return result
152
+
153
+ def forward(self, x, *args, **kwargs):
154
+ # VRAM management
155
+ if self.state == 2:
156
+ weight, bias = self.weight, self.bias
157
+ else:
158
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
159
+ weight, bias = self.weight, self.bias
160
+ elif self.vram_limit is not None and self.check_free_vram():
161
+ self.keep()
162
+ weight, bias = self.weight, self.bias
163
+ else:
164
+ weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
165
+ bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
166
+
167
+ # Linear forward
168
+ if self.enable_fp8:
169
+ out = self.fp8_linear(x, weight, bias)
170
+ else:
171
+ out = torch.nn.functional.linear(x, weight, bias)
172
+
173
+ # LoRA
174
+ if len(self.lora_A_weights) == 0:
175
+ # No LoRA
176
+ return out
177
+ elif self.lora_merger is None:
178
+ # Native LoRA inference
179
+ for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
180
+ out = out + x @ lora_A.T @ lora_B.T
181
+ else:
182
+ # LoRA fusion
183
+ lora_output = []
184
+ for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
185
+ lora_output.append(x @ lora_A.T @ lora_B.T)
186
+ lora_output = torch.stack(lora_output)
187
+ out = self.lora_merger(out, lora_output)
188
+ return out
189
+
190
+
191
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
192
+ for name, module in model.named_children():
193
+ layer_name = name if name_prefix == "" else name_prefix + "." + name
194
+ for source_module, target_module in module_map.items():
195
+ if isinstance(module, source_module):
196
+ num_param = sum(p.numel() for p in module.parameters())
197
+ if max_num_param is not None and total_num_param + num_param > max_num_param:
198
+ module_config_ = overflow_module_config
199
+ else:
200
+ module_config_ = module_config
201
+ module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name)
202
+ setattr(model, name, module_)
203
+ total_num_param += num_param
204
+ break
205
+ else:
206
+ total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name)
207
+ return total_num_param
208
+
209
+
210
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
211
+ enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
212
+ model.vram_management_enabled = True
213
+
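A short sketch of putting a model under this VRAM management; the toy model and the dtype/device choices are illustrative rather than defaults from this diff. Every torch.nn.Linear is swapped for AutoWrappedLinear, kept on CPU, and cast to the GPU compute dtype only when it is actually called:

import torch
from dkt.vram_management import enable_vram_management, AutoWrappedLinear

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 10))
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.float16, offload_device="cpu",
        onload_dtype=torch.float16, onload_device="cpu",
        computation_dtype=torch.bfloat16, computation_device="cuda",
    ),
    vram_limit=None,  # None: weights are streamed to the GPU per call rather than pinned there
)
# Wrapped layers now expose offload()/onload()/keep(); BasePipeline.load_models_to_device
# drives them whenever vram_management_enabled is True.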
examples/1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68ae88d1729f13d2eba6b9c3ae265af24ba563c725cd6ec7430fa9cf2a8f3584
3
+ size 695111
examples/10.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33faeefd4d21ca9ddc6386f0c7a83523632901d7f791fb5bd307b43739235c3d
3
+ size 3886742
examples/178db6e89ab682bfc612a3290fec58dd.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:803bea53274f55f02463af8855585a0e4950ec1ae498ed1e6ef261d83d38b371
3
+ size 1552729
examples/18.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38a34cc7a7bd060e1cb891d35457b37e2ad91e8fac457273367500b65a8e1eb8
3
+ size 1091805
examples/1b0daeb776471c7389b36cee53049417.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c89828ebf754c762daa916ba764293193986fde8e2ff0e70be66926fbc9a8d07
3
+ size 1447735
examples/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dd9a4527924fa2c26fd5bcc85237951180dcc136bac88d08bcd78635de58848
3
+ size 883548
examples/27.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6959f0dc0f5fde449cab87a0db04f2a14c23bc3df414b9e0474bd45bf901fbb
3
+ size 893079
examples/28.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd389e1fa12d73f8ded9181f9aa2b2586f8e150e855e96bf1df3c6420062bd12
3
+ size 605351
examples/3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da710ac25e12ef740d0bac04a452d1da6078a7acfe35bd456ab8a659a81401ff
3
+ size 628311
examples/30.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0aee9bbb44a52ffd26424b4a0d804ee236e56edddadbcec5b7b6d79b8b2464e
3
+ size 2677102
examples/31.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d0060e4564595ecc6397b7516ffc906beab7ff2a11971be2c9cd0e7807e6772
3
+ size 569935
examples/32.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:209d7f97bd9e5c881789d35859c06840a7321db7a1f5cfe285dadf18d84e847c
3
+ size 1593158
examples/33.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54c8de91a05ceb1c01c3bcc267e527a55adffcf6dee1714629bb5b09850a38fd
3
+ size 918682
examples/35.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9fba530d137d1ebeae598f502c4b59f7693d87738038d6f726512d00b54a6e4
3
+ size 952985
examples/36.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eff2bff79f77dadc26bb74623b50ef6f9e6b37b68247a0779f48b57822e5f74d
3
+ size 1008476
examples/39.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df7ff152c9835a20d95d3ccf8bd0386ed8fd25ed393b7eb2b0f669f7866124e3
3
+ size 740198
examples/4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1eb80b99247cc6ca122bc51b7f255c65fcafe4d35accae84470f8ebd8be58993
3
+ size 244113
examples/40.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deb684b15a71152cc7377957946042e9e8fa8efce619f7801de8cb91eb3c1e82
3
+ size 1021691
examples/5.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:631fc39bbde6f098c4f539f87c7c43de16a0d117ec51d16976d4b4e7e7279bc0
3
+ size 7130292