rookie9 commited on
Commit
79f3e78
·
verified ·
1 Parent(s): cb9f65f

Upload 77 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. config.json +103 -0
  2. ezaudio_vae/1m.pt +3 -0
  3. model.py +151 -0
  4. model_index.json +6 -0
  5. models/__pycache__/common.cpython-310.pyc +0 -0
  6. models/__pycache__/content_adapter.cpython-310.pyc +0 -0
  7. models/__pycache__/diffusion.cpython-310.pyc +0 -0
  8. models/__pycache__/diffusion_cfg.cpython-310.pyc +0 -0
  9. models/__pycache__/diffusion_cfg_new.cpython-310.pyc +0 -0
  10. models/__pycache__/diffusion_content_cfg.cpython-310.pyc +0 -0
  11. models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc +0 -0
  12. models/autoencoder/autoencoder_base.py +22 -0
  13. models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc +0 -0
  14. models/autoencoder/waveform/stable_vae.py +537 -0
  15. models/common.py +69 -0
  16. models/content_encoder/__pycache__/caption_encoder.cpython-310.pyc +0 -0
  17. models/content_encoder/__pycache__/content_encoder.cpython-310.pyc +0 -0
  18. models/content_encoder/__pycache__/content_encoder_add_1024.cpython-310.pyc +0 -0
  19. models/content_encoder/__pycache__/content_encoder_clap.cpython-310.pyc +0 -0
  20. models/content_encoder/__pycache__/content_encoder_clap_test.cpython-310.pyc +0 -0
  21. models/content_encoder/__pycache__/content_encoder_concat.cpython-310.pyc +0 -0
  22. models/content_encoder/__pycache__/content_encoder_concat_4096.cpython-310.pyc +0 -0
  23. models/content_encoder/__pycache__/content_encoder_concat_4096_random.cpython-310.pyc +0 -0
  24. models/content_encoder/__pycache__/content_encoder_full.cpython-310.pyc +0 -0
  25. models/content_encoder/__pycache__/content_encoder_full_non.cpython-310.pyc +0 -0
  26. models/content_encoder/__pycache__/content_encoder_full_non_test.cpython-310.pyc +0 -0
  27. models/content_encoder/__pycache__/content_encoder_full_test.cpython-310.pyc +0 -0
  28. models/content_encoder/__pycache__/content_encoder_full_woonset.cpython-310.pyc +0 -0
  29. models/content_encoder/__pycache__/content_encoder_merge.cpython-310.pyc +0 -0
  30. models/content_encoder/__pycache__/content_encoder_merge_test.cpython-310.pyc +0 -0
  31. models/content_encoder/__pycache__/content_encoder_replace.cpython-310.pyc +0 -0
  32. models/content_encoder/__pycache__/content_encoder_replace_merge.cpython-310.pyc +0 -0
  33. models/content_encoder/__pycache__/content_encoder_replace_new.cpython-310.pyc +0 -0
  34. models/content_encoder/__pycache__/content_encoder_test.cpython-310.pyc +0 -0
  35. models/content_encoder/__pycache__/content_test.cpython-310.pyc +0 -0
  36. models/content_encoder/__pycache__/new_content_encoder.cpython-310.pyc +0 -0
  37. models/content_encoder/__pycache__/text_encoder.cpython-310.pyc +0 -0
  38. models/content_encoder/caption_encoder.py +116 -0
  39. models/content_encoder/text_encoder.py +76 -0
  40. models/diffusion.py +398 -0
  41. models/dit/__pycache__/attention.cpython-310.pyc +0 -0
  42. models/dit/__pycache__/audio_dit.cpython-310.pyc +0 -0
  43. models/dit/__pycache__/mask_dit.cpython-310.pyc +0 -0
  44. models/dit/__pycache__/modules.cpython-310.pyc +0 -0
  45. models/dit/__pycache__/rotary.cpython-310.pyc +0 -0
  46. models/dit/__pycache__/span_mask.cpython-310.pyc +0 -0
  47. models/dit/attention.py +350 -0
  48. models/dit/audio_diffsingernet_dit.py +520 -0
  49. models/dit/audio_dit.py +549 -0
  50. models/dit/mask_dit.py +823 -0
config.json ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "PicoAudio2",
3
+ "autoencoder": {
4
+ "_target_": "models.autoencoder.waveform.stable_vae.StableVAE",
5
+ "encoder": {
6
+ "_target_": "models.autoencoder.waveform.stable_vae.OobleckEncoder",
7
+ "in_channels": 1,
8
+ "channels": 128,
9
+ "c_mults": [
10
+ 1,
11
+ 2,
12
+ 4,
13
+ 8
14
+ ],
15
+ "strides": [
16
+ 2,
17
+ 4,
18
+ 6,
19
+ 10
20
+ ],
21
+ "latent_dim": 256,
22
+ "use_snake": true
23
+ },
24
+ "decoder": {
25
+ "_target_": "models.autoencoder.waveform.stable_vae.OobleckDecoder",
26
+ "out_channels": 1,
27
+ "channels": 128,
28
+ "c_mults": [
29
+ 1,
30
+ 2,
31
+ 4,
32
+ 8
33
+ ],
34
+ "strides": [
35
+ 2,
36
+ 4,
37
+ 6,
38
+ 10
39
+ ],
40
+ "latent_dim": 128,
41
+ "use_snake": true,
42
+ "final_tanh": false
43
+ },
44
+ "io_channels": 1,
45
+ "latent_dim": 128,
46
+ "downsampling_ratio": 480,
47
+ "sample_rate": 24000,
48
+ "pretrained_ckpt": "/mnt/petrelfs/zhengzihao/cache/ezaudio/ckpts/vae/1m.pt",
49
+ "bottleneck": {
50
+ "_target_": "models.autoencoder.waveform.stable_vae.VAEBottleneck"
51
+ }
52
+ },
53
+ "backbone": {
54
+ "_target_": "models.dit.audio_dit.AudioUDiT",
55
+ "img_size": 1000,
56
+ "patch_size": 1,
57
+ "in_chans": 128,
58
+ "out_chans": 128,
59
+ "input_type": "1d",
60
+ "embed_dim": 1024,
61
+ "depth": 24,
62
+ "num_heads": 16,
63
+ "mlp_ratio": 4.0,
64
+ "qkv_bias": false,
65
+ "qk_scale": null,
66
+ "qk_norm": "layernorm",
67
+ "norm_layer": "layernorm",
68
+ "act_layer": "geglu",
69
+ "context_norm": true,
70
+ "use_checkpoint": true,
71
+ "time_fusion": "ada_sola_bias",
72
+ "ada_sola_rank": 32,
73
+ "ada_sola_alpha": 32,
74
+ "cls_dim": null,
75
+ "ta_context_dim": 1024,
76
+ "ta_context_fusion": "concat",
77
+ "ta_context_norm": true,
78
+ "context_dim": 1024,
79
+ "context_fusion": "cross",
80
+ "context_max_length": null,
81
+ "context_pe_method": "none",
82
+ "pe_method": "none",
83
+ "rope_mode": "shared",
84
+ "use_conv": true,
85
+ "skip": true,
86
+ "skip_norm": true
87
+ },
88
+ "_target_": "models.diffusion.AudioDiffusion",
89
+ "content_encoder": {
90
+ "_target_": "models.content_encoder.caption_encoder.ContentEncoder",
91
+ "text_encoder": {
92
+ "_target_": "models.content_encoder.text_encoder.T5TextEncoder",
93
+ "model_name": "/mnt/petrelfs/zhengzihao/cache/google-flan-t5-large"
94
+ }
95
+ },
96
+ "frame_resolution": 0.005,
97
+ "noise_scheduler_name": "/mnt/petrelfs/zhengzihao/cache/stabilityai-stable-diffusion-2-1",
98
+ "snr_gamma": 5.0,
99
+ "classifier_free_guidance": true,
100
+ "cfg_drop_ratio": 0.2,
101
+ "num_steps": 50,
102
+ "guidance_scale": 7.5
103
+ }
ezaudio_vae/1m.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cb13e2699fa922ce6a2b3b4f53c270ec64156e0cc3f3e3645e10cdf98b740dc
3
+ size 183037614
model.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel, PretrainedConfig
3
+ import inspect, importlib
4
+ from safetensors.torch import load_file
5
+ from models.diffusion import AudioDiffusion
6
+
7
+ class PicoAudio2Config(PretrainedConfig):
8
+ model_type = "PicoAudio2"
9
+ def __init__(
10
+ self,
11
+ autoencoder=None,
12
+ content_encoder=None,
13
+ backbone=None,
14
+ frame_resolution: float = 0.005,
15
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
16
+ snr_gamma: float = 5.0,
17
+ classifier_free_guidance: bool = True,
18
+ cfg_drop_ratio: float = 0.2,
19
+ num_steps: int = 50,
20
+ guidance_scale: float = 7.5,
21
+ **kwargs
22
+ ):
23
+ super().__init__(**kwargs)
24
+ self.autoencoder = autoencoder
25
+ self.content_encoder = content_encoder
26
+ self.backbone = backbone
27
+ self.frame_resolution = frame_resolution
28
+ self.noise_scheduler_name = noise_scheduler_name
29
+ self.snr_gamma = snr_gamma
30
+ self.classifier_free_guidance = classifier_free_guidance
31
+ self.cfg_drop_ratio = cfg_drop_ratio
32
+ self.num_steps = num_steps
33
+ self.guidance_scale = guidance_scale
34
+
35
+
36
+ class PicoAudio2HF(PreTrainedModel):
37
+ config_class = PicoAudio2Config
38
+
39
+ def __init__(self, config: PicoAudio2Config):
40
+ super().__init__(config)
41
+
42
+ autoencoder = self._build_submodule(config.autoencoder)
43
+ content_encoder = self.build_content_encoder_from_config(config.content_encoder)
44
+ backbone = self._build_submodule(config.backbone)
45
+
46
+ state_dict = load_file("model.safetensors")
47
+ new_state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
48
+ backbone.load_state_dict(new_state_dict, strict=False, assign=True)
49
+
50
+ self.inner_model = AudioDiffusion(
51
+ autoencoder=autoencoder,
52
+ content_encoder=content_encoder,
53
+ backbone=backbone,
54
+ frame_resolution=config.frame_resolution,
55
+ noise_scheduler_name=config.noise_scheduler_name,
56
+ snr_gamma=config.snr_gamma,
57
+ classifier_free_guidance=config.classifier_free_guidance,
58
+ cfg_drop_ratio=config.cfg_drop_ratio,
59
+ )
60
+ def build_content_encoder_from_config(self, content_encoder_cfg):
61
+ te_cfg = content_encoder_cfg['text_encoder']
62
+ te_mod_path, te_cls_name = te_cfg['_target_'].rsplit('.', 1)
63
+ te_mod = importlib.import_module(te_mod_path)
64
+ TextEncoderClass = getattr(te_mod, te_cls_name)
65
+ text_encoder = TextEncoderClass(model_name=te_cfg['model_name'])
66
+
67
+ ce_mod_path, ce_cls_name = content_encoder_cfg['_target_'].rsplit('.', 1)
68
+ ce_mod = importlib.import_module(ce_mod_path)
69
+ ContentEncoderClass = getattr(ce_mod, ce_cls_name)
70
+ content_encoder = ContentEncoderClass(text_encoder=text_encoder)
71
+
72
+ return content_encoder
73
+
74
+ def _build_submodule(self, sub_config):
75
+ import inspect
76
+ if sub_config is None:
77
+ return None
78
+ if isinstance(sub_config, dict) and "_target_" in sub_config:
79
+ kwargs = {}
80
+ for k, v in sub_config.items():
81
+ if k == "_target_":
82
+ continue
83
+ if isinstance(v, dict) and "_target_" in v:
84
+ kwargs[k] = self._build_submodule(v)
85
+ else:
86
+ kwargs[k] = v
87
+ module_path, class_name = sub_config["_target_"].rsplit(".", 1)
88
+ module = __import__(module_path, fromlist=[class_name])
89
+ cls = getattr(module, class_name)
90
+ obj = cls(**kwargs)
91
+ if "pretrained_ckpt" in sub_config:
92
+ state_dict = torch.load(sub_config["pretrained_ckpt"])
93
+ if "state_dict" in state_dict:
94
+ new_state_dict = state_dict["state_dict"]
95
+ state_dict = {k.replace("autoencoder.", ""): v for k, v in new_state_dict.items()}
96
+
97
+ sig = inspect.signature(obj.load_state_dict)
98
+ if "assign" in sig.parameters:
99
+ result = obj.load_state_dict(state_dict, strict=False, assign=True)
100
+ else:
101
+ result = obj.load_state_dict(state_dict, strict=False)
102
+
103
+ self._check_param_stats(obj, class_name)
104
+ return obj
105
+ else:
106
+ return sub_config
107
+ def _check_weights(self, module, name):
108
+ if hasattr(module, "load_state_dict") and hasattr(module, "state_dict"):
109
+ print(f"[{name}] parameter keys:", list(module.state_dict().keys())[:5], "...")
110
+ for idx, (k, v) in enumerate(module.state_dict().items()):
111
+ print(f"[{name}] {k}: mean={v.float().mean():.5f}, std={v.float().std():.5f}")
112
+ if idx >= 2:
113
+ break
114
+
115
+ def _check_param_stats(self, module, name):
116
+ if hasattr(module, "named_parameters"):
117
+ for idx, (k, v) in enumerate(module.named_parameters()):
118
+ print(f"[{name}] {k}: mean={v.data.float().mean():.5f}, std={v.data.float().std():.5f}")
119
+ if idx >= 2:
120
+ break
121
+
122
+ def forward(
123
+ self,
124
+ content,
125
+ num_steps=None,
126
+ guidance_scale=None,
127
+ guidance_rescale=0.0,
128
+ disable_progress=True,
129
+ num_samples_per_content=1,
130
+ **kwargs
131
+ ):
132
+ num_steps = num_steps if num_steps is not None else self.config.num_steps
133
+ guidance_scale = guidance_scale if guidance_scale is not None else self.config.guidance_scale
134
+ return self.inner_model.inference(
135
+ content=[content],
136
+ num_steps=num_steps,
137
+ guidance_scale=guidance_scale,
138
+ guidance_rescale=guidance_rescale,
139
+ disable_progress=disable_progress,
140
+ num_samples_per_content=num_samples_per_content,
141
+ **kwargs
142
+ )
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
146
+ config = PicoAudio2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
147
+ model = cls(config)
148
+ return model
149
+
150
+ def load_state_dict(self, state_dict, *args, **kwargs):
151
+ pass
model_index.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "model.PicoAudio2Config",
4
+ "AutoModel": "model.PicoAudio2HF"
5
+ }
6
+ }
models/__pycache__/common.cpython-310.pyc ADDED
Binary file (3.1 kB). View file
 
models/__pycache__/content_adapter.cpython-310.pyc ADDED
Binary file (3.87 kB). View file
 
models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
models/__pycache__/diffusion_cfg.cpython-310.pyc ADDED
Binary file (18.9 kB). View file
 
models/__pycache__/diffusion_cfg_new.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
models/__pycache__/diffusion_content_cfg.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
models/autoencoder/autoencoder_base.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod, ABC
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class AutoEncoderBase(ABC):
8
+ def __init__(
9
+ self, downsampling_ratio: int, sample_rate: int,
10
+ latent_shape: Sequence[int | None]
11
+ ):
12
+ self.downsampling_ratio = downsampling_ratio
13
+ self.sample_rate = sample_rate
14
+ self.latent_token_rate = sample_rate // downsampling_ratio
15
+ self.latent_shape = latent_shape
16
+ self.time_dim = latent_shape.index(None) + 1 # the first dim is batch
17
+
18
+ @abstractmethod
19
+ def encode(
20
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
21
+ ) -> tuple[torch.Tensor, torch.Tensor]:
22
+ ...
models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
models/autoencoder/waveform/stable_vae.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Literal, Callable
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils import weight_norm
8
+ import torchaudio
9
+ from alias_free_torch import Activation1d
10
+
11
+ from models.common import LoadPretrainedBase
12
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
13
+ from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
14
+
15
+
16
+ # jit script make it 1.4x faster and save GPU memory
17
+ @torch.jit.script
18
+ def snake_beta(x, alpha, beta):
19
+ return x + (1.0 / (beta+0.000000001)) * pow(torch.sin(x * alpha), 2)
20
+
21
+
22
+ class SnakeBeta(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_features,
26
+ alpha=1.0,
27
+ alpha_trainable=True,
28
+ alpha_logscale=True
29
+ ):
30
+ super(SnakeBeta, self).__init__()
31
+ self.in_features = in_features
32
+
33
+ # initialize alpha
34
+ self.alpha_logscale = alpha_logscale
35
+ if self.alpha_logscale:
36
+ # log scale alphas initialized to zeros
37
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
38
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
39
+ else:
40
+ # linear scale alphas initialized to ones
41
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
42
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+ self.beta.requires_grad = alpha_trainable
46
+
47
+ # self.no_div_by_zero = 0.000000001
48
+
49
+ def forward(self, x):
50
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
51
+ # line up with x to [B, C, T]
52
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
53
+ if self.alpha_logscale:
54
+ alpha = torch.exp(alpha)
55
+ beta = torch.exp(beta)
56
+ x = snake_beta(x, alpha, beta)
57
+
58
+ return x
59
+
60
+
61
+ def WNConv1d(*args, **kwargs):
62
+ return weight_norm(nn.Conv1d(*args, **kwargs))
63
+
64
+
65
+ def WNConvTranspose1d(*args, **kwargs):
66
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
67
+
68
+
69
+ def get_activation(
70
+ activation: Literal["elu", "snake", "none"],
71
+ antialias=False,
72
+ channels=None
73
+ ) -> nn.Module:
74
+ if activation == "elu":
75
+ act = nn.ELU()
76
+ elif activation == "snake":
77
+ act = SnakeBeta(channels)
78
+ elif activation == "none":
79
+ act = nn.Identity()
80
+ else:
81
+ raise ValueError(f"Unknown activation {activation}")
82
+
83
+ if antialias:
84
+ act = Activation1d(act)
85
+
86
+ return act
87
+
88
+
89
+ class ResidualUnit(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels,
93
+ out_channels,
94
+ dilation,
95
+ use_snake=False,
96
+ antialias_activation=False
97
+ ):
98
+ super().__init__()
99
+
100
+ self.dilation = dilation
101
+
102
+ padding = (dilation * (7-1)) // 2
103
+
104
+ self.layers = nn.Sequential(
105
+ get_activation(
106
+ "snake" if use_snake else "elu",
107
+ antialias=antialias_activation,
108
+ channels=out_channels
109
+ ),
110
+ WNConv1d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=7,
114
+ dilation=dilation,
115
+ padding=padding
116
+ ),
117
+ get_activation(
118
+ "snake" if use_snake else "elu",
119
+ antialias=antialias_activation,
120
+ channels=out_channels
121
+ ),
122
+ WNConv1d(
123
+ in_channels=out_channels,
124
+ out_channels=out_channels,
125
+ kernel_size=1
126
+ )
127
+ )
128
+
129
+ def forward(self, x):
130
+ res = x
131
+
132
+ #x = checkpoint(self.layers, x)
133
+ x = self.layers(x)
134
+
135
+ return x + res
136
+
137
+
138
+ class EncoderBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ in_channels,
142
+ out_channels,
143
+ stride,
144
+ use_snake=False,
145
+ antialias_activation=False
146
+ ):
147
+ super().__init__()
148
+
149
+ self.layers = nn.Sequential(
150
+ ResidualUnit(
151
+ in_channels=in_channels,
152
+ out_channels=in_channels,
153
+ dilation=1,
154
+ use_snake=use_snake
155
+ ),
156
+ ResidualUnit(
157
+ in_channels=in_channels,
158
+ out_channels=in_channels,
159
+ dilation=3,
160
+ use_snake=use_snake
161
+ ),
162
+ ResidualUnit(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ dilation=9,
166
+ use_snake=use_snake
167
+ ),
168
+ get_activation(
169
+ "snake" if use_snake else "elu",
170
+ antialias=antialias_activation,
171
+ channels=in_channels
172
+ ),
173
+ WNConv1d(
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ kernel_size=2 * stride,
177
+ stride=stride,
178
+ padding=math.ceil(stride / 2)
179
+ ),
180
+ )
181
+
182
+ def forward(self, x):
183
+ return self.layers(x)
184
+
185
+
186
+ class DecoderBlock(nn.Module):
187
+ def __init__(
188
+ self,
189
+ in_channels,
190
+ out_channels,
191
+ stride,
192
+ use_snake=False,
193
+ antialias_activation=False,
194
+ use_nearest_upsample=False
195
+ ):
196
+ super().__init__()
197
+
198
+ if use_nearest_upsample:
199
+ upsample_layer = nn.Sequential(
200
+ nn.Upsample(scale_factor=stride, mode="nearest"),
201
+ WNConv1d(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ kernel_size=2 * stride,
205
+ stride=1,
206
+ bias=False,
207
+ padding='same'
208
+ )
209
+ )
210
+ else:
211
+ upsample_layer = WNConvTranspose1d(
212
+ in_channels=in_channels,
213
+ out_channels=out_channels,
214
+ kernel_size=2 * stride,
215
+ stride=stride,
216
+ padding=math.ceil(stride / 2)
217
+ )
218
+
219
+ self.layers = nn.Sequential(
220
+ get_activation(
221
+ "snake" if use_snake else "elu",
222
+ antialias=antialias_activation,
223
+ channels=in_channels
224
+ ),
225
+ upsample_layer,
226
+ ResidualUnit(
227
+ in_channels=out_channels,
228
+ out_channels=out_channels,
229
+ dilation=1,
230
+ use_snake=use_snake
231
+ ),
232
+ ResidualUnit(
233
+ in_channels=out_channels,
234
+ out_channels=out_channels,
235
+ dilation=3,
236
+ use_snake=use_snake
237
+ ),
238
+ ResidualUnit(
239
+ in_channels=out_channels,
240
+ out_channels=out_channels,
241
+ dilation=9,
242
+ use_snake=use_snake
243
+ ),
244
+ )
245
+
246
+ def forward(self, x):
247
+ return self.layers(x)
248
+
249
+
250
+ class OobleckEncoder(nn.Module):
251
+ def __init__(
252
+ self,
253
+ in_channels=2,
254
+ channels=128,
255
+ latent_dim=32,
256
+ c_mults=[1, 2, 4, 8],
257
+ strides=[2, 4, 8, 8],
258
+ use_snake=False,
259
+ antialias_activation=False
260
+ ):
261
+ super().__init__()
262
+
263
+ c_mults = [1] + c_mults
264
+
265
+ self.depth = len(c_mults)
266
+
267
+ layers = [
268
+ WNConv1d(
269
+ in_channels=in_channels,
270
+ out_channels=c_mults[0] * channels,
271
+ kernel_size=7,
272
+ padding=3
273
+ )
274
+ ]
275
+
276
+ for i in range(self.depth - 1):
277
+ layers += [
278
+ EncoderBlock(
279
+ in_channels=c_mults[i] * channels,
280
+ out_channels=c_mults[i + 1] * channels,
281
+ stride=strides[i],
282
+ use_snake=use_snake
283
+ )
284
+ ]
285
+
286
+ layers += [
287
+ get_activation(
288
+ "snake" if use_snake else "elu",
289
+ antialias=antialias_activation,
290
+ channels=c_mults[-1] * channels
291
+ ),
292
+ WNConv1d(
293
+ in_channels=c_mults[-1] * channels,
294
+ out_channels=latent_dim,
295
+ kernel_size=3,
296
+ padding=1
297
+ )
298
+ ]
299
+
300
+ self.layers = nn.Sequential(*layers)
301
+
302
+ def forward(self, x):
303
+ return self.layers(x)
304
+
305
+
306
+ class OobleckDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ out_channels=2,
310
+ channels=128,
311
+ latent_dim=32,
312
+ c_mults=[1, 2, 4, 8],
313
+ strides=[2, 4, 8, 8],
314
+ use_snake=False,
315
+ antialias_activation=False,
316
+ use_nearest_upsample=False,
317
+ final_tanh=True
318
+ ):
319
+ super().__init__()
320
+
321
+ c_mults = [1] + c_mults
322
+
323
+ self.depth = len(c_mults)
324
+
325
+ layers = [
326
+ WNConv1d(
327
+ in_channels=latent_dim,
328
+ out_channels=c_mults[-1] * channels,
329
+ kernel_size=7,
330
+ padding=3
331
+ ),
332
+ ]
333
+
334
+ for i in range(self.depth - 1, 0, -1):
335
+ layers += [
336
+ DecoderBlock(
337
+ in_channels=c_mults[i] * channels,
338
+ out_channels=c_mults[i - 1] * channels,
339
+ stride=strides[i - 1],
340
+ use_snake=use_snake,
341
+ antialias_activation=antialias_activation,
342
+ use_nearest_upsample=use_nearest_upsample
343
+ )
344
+ ]
345
+
346
+ layers += [
347
+ get_activation(
348
+ "snake" if use_snake else "elu",
349
+ antialias=antialias_activation,
350
+ channels=c_mults[0] * channels
351
+ ),
352
+ WNConv1d(
353
+ in_channels=c_mults[0] * channels,
354
+ out_channels=out_channels,
355
+ kernel_size=7,
356
+ padding=3,
357
+ bias=False
358
+ ),
359
+ nn.Tanh() if final_tanh else nn.Identity()
360
+ ]
361
+
362
+ self.layers = nn.Sequential(*layers)
363
+
364
+ def forward(self, x):
365
+ return self.layers(x)
366
+
367
+
368
+ class Bottleneck(nn.Module):
369
+ def __init__(self, is_discrete: bool = False):
370
+ super().__init__()
371
+
372
+ self.is_discrete = is_discrete
373
+
374
+ def encode(self, x, return_info=False, **kwargs):
375
+ raise NotImplementedError
376
+
377
+ def decode(self, x):
378
+ raise NotImplementedError
379
+
380
+
381
+ @torch.jit.script
382
+ def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
383
+ stdev = nn.functional.softplus(scale) + 1e-4
384
+ var = stdev * stdev
385
+ logvar = torch.log(var)
386
+ latents = torch.randn_like(mean) * stdev + mean
387
+
388
+ kl = (mean*mean + var - logvar - 1).sum(1).mean()
389
+ return {"latents": latents, "kl": kl}
390
+
391
+
392
+ class VAEBottleneck(Bottleneck):
393
+ def __init__(self):
394
+ super().__init__(is_discrete=False)
395
+
396
+ def encode(self,
397
+ x,
398
+ return_info=False,
399
+ **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
400
+ mean, scale = x.chunk(2, dim=1)
401
+ sampled = vae_sample(mean, scale)
402
+
403
+ if return_info:
404
+ return sampled["latents"], {"kl": sampled["kl"]}
405
+ else:
406
+ return sampled["latents"]
407
+
408
+ def decode(self, x):
409
+ return x
410
+
411
+
412
+ def compute_mean_kernel(x, y):
413
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
414
+ return torch.exp(-kernel_input).mean()
415
+
416
+
417
+ class Pretransform(nn.Module):
418
+ def __init__(self, enable_grad, io_channels, is_discrete):
419
+ super().__init__()
420
+
421
+ self.is_discrete = is_discrete
422
+ self.io_channels = io_channels
423
+ self.encoded_channels = None
424
+ self.downsampling_ratio = None
425
+
426
+ self.enable_grad = enable_grad
427
+
428
+ def encode(self, x):
429
+ raise NotImplementedError
430
+
431
+ def decode(self, z):
432
+ raise NotImplementedError
433
+
434
+ def tokenize(self, x):
435
+ raise NotImplementedError
436
+
437
+ def decode_tokens(self, tokens):
438
+ raise NotImplementedError
439
+
440
+
441
+ class StableVAE(LoadPretrainedBase, AutoEncoderBase):
442
+ def __init__(
443
+ self,
444
+ encoder,
445
+ decoder,
446
+ latent_dim,
447
+ downsampling_ratio,
448
+ sample_rate,
449
+ io_channels=2,
450
+ bottleneck: Bottleneck = None,
451
+ pretransform: Pretransform = None,
452
+ in_channels=None,
453
+ out_channels=None,
454
+ soft_clip=False,
455
+ pretrained_ckpt: str | Path = None
456
+ ):
457
+ LoadPretrainedBase.__init__(self)
458
+ AutoEncoderBase.__init__(
459
+ self,
460
+ downsampling_ratio=downsampling_ratio,
461
+ sample_rate=sample_rate,
462
+ latent_shape=(latent_dim, None)
463
+ )
464
+
465
+ self.latent_dim = latent_dim
466
+ self.io_channels = io_channels
467
+ self.in_channels = io_channels
468
+ self.out_channels = io_channels
469
+ self.min_length = self.downsampling_ratio
470
+
471
+ if in_channels is not None:
472
+ self.in_channels = in_channels
473
+
474
+ if out_channels is not None:
475
+ self.out_channels = out_channels
476
+
477
+ self.bottleneck = bottleneck
478
+ self.encoder = encoder
479
+ self.decoder = decoder
480
+ self.pretransform = pretransform
481
+ self.soft_clip = soft_clip
482
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
483
+
484
+ self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
485
+ "autoencoder."
486
+ )
487
+ if pretrained_ckpt is not None:
488
+ self.load_pretrained(pretrained_ckpt)
489
+
490
+ def process_state_dict(self, model_dict, state_dict):
491
+ state_dict = state_dict["state_dict"]
492
+ state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
493
+ return state_dict
494
+
495
+ def encode(
496
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
497
+ ) -> tuple[torch.Tensor, torch.Tensor]:
498
+ z = self.encoder(waveform)
499
+ z = self.bottleneck.encode(z)
500
+ z_length = waveform_lengths // self.downsampling_ratio
501
+ z_mask = create_mask_from_length(z_length)
502
+ return z, z_mask
503
+
504
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
505
+ waveform = self.decoder(latents)
506
+ return waveform
507
+
508
+
509
+ if __name__ == '__main__':
510
+ import hydra
511
+ from utils.config import generate_config_from_command_line_overrides
512
+ model_config = generate_config_from_command_line_overrides(
513
+ "configs/model/autoencoder/stable_vae.yaml"
514
+ )
515
+ autoencoder: StableVAE = hydra.utils.instantiate(model_config)
516
+ autoencoder.eval()
517
+
518
+ waveform, sr = torchaudio.load(
519
+ "/hpc_stor03/sjtu_home/xuenan.xu/workspace/singing_voice_synthesis/diffsinger/data/raw/opencpop/segments/wavs/2007000230.wav"
520
+ )
521
+ waveform = torchaudio.functional.resample(
522
+ waveform, sr, model_config["sample_rate"]
523
+ )
524
+ print("waveform: ", waveform.shape)
525
+ with torch.no_grad():
526
+ latent, latent_length = autoencoder.encode(
527
+ waveform, torch.as_tensor([waveform.shape[-1]])
528
+ )
529
+ print("latent: ", latent.shape)
530
+ reconstructed = autoencoder.decode(latent)
531
+ print("reconstructed: ", reconstructed.shape)
532
+ import soundfile as sf
533
+ sf.write(
534
+ "./reconstructed.wav",
535
+ reconstructed[0, 0].numpy(),
536
+ samplerate=model_config["sample_rate"]
537
+ )
models/common.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import torch.nn as nn
4
+ from utils.torch_utilities import load_pretrained_model, merge_matched_keys
5
+ import warnings
6
+
7
+ class LoadPretrainedBase(nn.Module):
8
+ def process_state_dict(
9
+ self, model_dict: dict[str, torch.Tensor],
10
+ state_dict: dict[str, torch.Tensor]
11
+ ):
12
+ """
13
+ Custom processing functions of each model that transforms `state_dict` loaded from
14
+ checkpoints to the state that can be used in `load_state_dict`.
15
+ Use `merge_mathced_keys` to update parameters with matched names and shapes by
16
+ default.
17
+
18
+ Args
19
+ model_dict:
20
+ The state dict of the current model, which is going to load pretrained parameters
21
+ state_dict:
22
+ A dictionary of parameters from a pre-trained model.
23
+
24
+ Returns:
25
+ dict[str, torch.Tensor]:
26
+ The updated state dict, where parameters with matched keys and shape are
27
+ updated with values in `state_dict`.
28
+ """
29
+ state_dict = merge_matched_keys(model_dict, state_dict)
30
+ return state_dict
31
+
32
+ def load_pretrained(self, ckpt_path: str | Path):
33
+ load_pretrained_model(
34
+ self, ckpt_path, state_dict_process_fn=self.process_state_dict
35
+ )
36
+
37
+
38
+ class CountParamsBase(nn.Module):
39
+ def count_params(self):
40
+ num_params = 0
41
+ trainable_params = 0
42
+ for param in self.parameters():
43
+ num_params += param.numel()
44
+ if param.requires_grad:
45
+ trainable_params += param.numel()
46
+ return num_params, trainable_params
47
+
48
+
49
+ class SaveTrainableParamsBase(nn.Module):
50
+ @property
51
+ def param_names_to_save(self):
52
+ names = []
53
+ for name, param in self.named_parameters():
54
+ if param.requires_grad:
55
+ names.append(name)
56
+ for name, _ in self.named_buffers():
57
+ names.append(name)
58
+ return names
59
+
60
+ def load_state_dict(self, state_dict, strict=True, assign=True):
61
+ print("State dict keys:", list(state_dict.keys()))
62
+ for key in self.param_names_to_save:
63
+ if key not in state_dict:
64
+ raise Exception(
65
+ f"{key} not found in either pre-trained models (e.g. BERT)"
66
+ " or resumed checkpoints (e.g. epoch_40/model.pt)"
67
+ )
68
+ # 兼容 PyTorch/transformers 的 assign 参数
69
+ return super().load_state_dict(state_dict, strict=strict, assign=assign)
models/content_encoder/__pycache__/caption_encoder.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
models/content_encoder/__pycache__/content_encoder.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
models/content_encoder/__pycache__/content_encoder_add_1024.cpython-310.pyc ADDED
Binary file (4.62 kB). View file
 
models/content_encoder/__pycache__/content_encoder_clap.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
models/content_encoder/__pycache__/content_encoder_clap_test.cpython-310.pyc ADDED
Binary file (6.12 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat_4096.cpython-310.pyc ADDED
Binary file (4.69 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat_4096_random.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full.cpython-310.pyc ADDED
Binary file (5.01 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_non.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_non_test.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_test.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_woonset.cpython-310.pyc ADDED
Binary file (4.59 kB). View file
 
models/content_encoder/__pycache__/content_encoder_merge.cpython-310.pyc ADDED
Binary file (4.78 kB). View file
 
models/content_encoder/__pycache__/content_encoder_merge_test.cpython-310.pyc ADDED
Binary file (4.82 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace_merge.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace_new.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/content_encoder_test.cpython-310.pyc ADDED
Binary file (4.58 kB). View file
 
models/content_encoder/__pycache__/content_test.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/new_content_encoder.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/content_encoder/__pycache__/text_encoder.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
models/content_encoder/caption_encoder.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import torch
3
+ import torch.nn as nn
4
+ import random
5
+ from utils.audiotime_event_merge import replace_event_synonyms
6
+
7
+ def decode_data(line_onset_str, latent_length):
8
+ """
9
+ Extracts a timestamp matrix (event onset indices) from a formatted onset string.
10
+
11
+ Args:
12
+ line_onset_str (str): String containing event names and onset intervals,
13
+ formatted like "event1__start1-end1_start2-end2--event2__start1-end1".
14
+ latent_length (int): Length of the output matrix.
15
+
16
+ Returns:
17
+ line_onset_index (torch.Tensor): Matrix of shape [4, latent_length],
18
+ line_event (list): List of event names extracted from the onset string.
19
+
20
+ Notes:
21
+ - 24000 is the audio sample rate.
22
+ - 480 is the downsample ratio to align with VAE.
23
+ - Each onset interval "start-end" (in seconds) is converted to embedding indices via (time * 24000 / 480).
24
+ """
25
+ line_onset_index = torch.zeros((4, latent_length)) # max for 4 events
26
+ line_event = []
27
+ event_idx = 0
28
+ for event_onset in line_onset_str.split('--'):
29
+ #print(event_onset)
30
+ (event, instance) = event_onset.split('__')
31
+ #print(instance)
32
+ line_event.append(event)
33
+ for start_end in instance.split('_'):
34
+ (start, end) = start_end.split('-')
35
+ start, end = int(float(start)*24000/480), int(float(end)*24000/480)
36
+ if end > (latent_length - 1): break
37
+ line_onset_index[event_idx, start: end] = 1
38
+ event_idx = event_idx + 1
39
+ return line_onset_index, line_event
40
+
41
+
42
+ class ContentEncoder(nn.Module):
43
+ """
44
+ ContentEncoder encodes TCC and TDC information.
45
+ """
46
+ def __init__(
47
+ self,
48
+ text_encoder: nn.Module= None,
49
+ ):
50
+ super().__init__()
51
+ self.text_encoder = text_encoder
52
+ self.pool = nn.AdaptiveAvgPool1d(1)
53
+
54
+ def encode_content(
55
+ self, batch_content: list[Any], device: str | torch.device
56
+ ):
57
+ batch_output = []
58
+ batch_mask = []
59
+ batch_onset = []
60
+ length_list = []
61
+ print(batch_content)
62
+ for content in batch_content:
63
+
64
+ caption = content["caption"]
65
+ onset = content["onset"]
66
+ length = int(float(content["length"]) *24000/480)
67
+ # Replacement for AudioTime
68
+ print(onset)
69
+ replace_label = content.get("replace_label", "False")
70
+ if replace_label == "True":
71
+ caption, onset = replace_event_synonyms(caption, onset)
72
+
73
+ # Handle random onset case for read data without timestamp
74
+ if content["onset"] == "random":
75
+ length_list.append(length)
76
+ """
77
+ fixed embedding. Actually it's a sick sentence, a error during training, kept to match the checkpoint.
78
+ You can change it to sentence that difference to captions in datasets.
79
+ The use of fixed text to obtain encoding is for numerical stability.
80
+ We attempted to use learnable unified encoding during training, but the results were not satisfactory.
81
+ """
82
+ event = "There is no event here"
83
+ event_embed = self.text_encoder([event.replace("_", " ")])["output"]
84
+ event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
85
+ event_embed = event_embed.flatten().unsqueeze(0)
86
+ new_onset = event_embed.repeat(length, 1).T
87
+ else:
88
+ onset_matrix, events = decode_data(onset, length)
89
+ length_list.append(length)
90
+ new_onset = torch.zeros((1024, length), device=device) # 1024 for T5
91
+ # TDC
92
+ for (idx, event) in enumerate(events):
93
+ with torch.no_grad():
94
+ event_embed = self.text_encoder([event.replace("_", " ")])["output"]
95
+ event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
96
+ event_embed = event_embed.flatten().unsqueeze(0)
97
+ mask = (onset_matrix[idx, :] == 0)
98
+ cols = mask.nonzero(as_tuple=True)[0]
99
+ new_onset[:, cols] += event_embed.T.float()
100
+ # TCC
101
+ output_dict = self.text_encoder([caption])
102
+ batch_output.append(output_dict["output"][0])
103
+ batch_mask.append(output_dict["mask"][0])
104
+ batch_onset.append(new_onset)
105
+
106
+ # Pad all sequences in the batch to the same length for batching
107
+ batch_output = nn.utils.rnn.pad_sequence(
108
+ batch_output, batch_first=True, padding_value=0
109
+ )
110
+ batch_mask = nn.utils.rnn.pad_sequence(
111
+ batch_mask, batch_first=True, padding_value=False
112
+ )
113
+ batch_onset = nn.utils.rnn.pad_sequence(
114
+ batch_onset, batch_first=True, padding_value=0
115
+ )
116
+ return batch_output, batch_mask, batch_onset, length_list
models/content_encoder/text_encoder.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+ try:
7
+ import torch_npu
8
+ from torch_npu.contrib import transfer_to_npu
9
+ DEVICE_TYPE = "npu"
10
+ except ModuleNotFoundError:
11
+ DEVICE_TYPE = "cuda"
12
+
13
+
14
+ class TransformersTextEncoderBase(nn.Module):
15
+ """
16
+ Base class for text encoding using HuggingFace Transformers models.
17
+
18
+ """
19
+ def __init__(self, model_name: str):
20
+ super().__init__()
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ self.model = AutoModel.from_pretrained(model_name)
23
+
24
+ def forward(
25
+ self,
26
+ text: list[str],
27
+ ):
28
+ device = self.model.device
29
+ batch = self.tokenizer(
30
+ text,
31
+ max_length=self.tokenizer.model_max_length,
32
+ padding=True,
33
+ truncation=True,
34
+ return_tensors="pt"
35
+ )
36
+ input_ids = batch.input_ids.to(device)
37
+ attention_mask = batch.attention_mask.to(device)
38
+ output: BaseModelOutput = self.model(
39
+ input_ids=input_ids, attention_mask=attention_mask
40
+ )
41
+ output = output.last_hidden_state
42
+ mask = (attention_mask == 1).to(device)
43
+
44
+ return {"output": output, "mask": mask}
45
+
46
+
47
+ class T5TextEncoder(TransformersTextEncoderBase):
48
+ """
49
+ Text encoder using T5 encoder model.
50
+ """
51
+ def __init__(self, model_name: str = "/mnt/petrelfs/zhengzihao/cache/google-flan-t5-large"):
52
+ nn.Module.__init__(self)
53
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
54
+ self.model = T5EncoderModel.from_pretrained(model_name)
55
+ for param in self.model.parameters():
56
+ param.requires_grad = False
57
+ self.eval()
58
+
59
+ def forward(
60
+ self,
61
+ text: list[str],
62
+ ):
63
+ with torch.no_grad(), torch.amp.autocast(
64
+ device_type=DEVICE_TYPE, enabled=False
65
+ ):
66
+ return super().forward(text)
67
+
68
+
69
+ if __name__ == '__main__':
70
+ text_encoder = T5TextEncoder()
71
+ text = ["dog barking and cat moving"]
72
+ text_encoder.eval()
73
+ with torch.no_grad():
74
+ output = text_encoder(text)
75
+ print(output["output"].shape)
76
+ #print(output)
models/diffusion.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ import random
3
+ from typing import Any
4
+
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import diffusers.schedulers as noise_schedulers
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+
13
+ import numpy as np
14
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
15
+ from models.content_encoder.caption_encoder import ContentEncoder
16
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
17
+ from utils.torch_utilities import (
18
+ create_alignment_path, create_mask_from_length, loss_with_mask,
19
+ trim_or_pad_length
20
+ )
21
+
22
+
23
+ class DiffusionMixin:
24
+ def __init__(
25
+ self,
26
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
27
+ snr_gamma: float = None,
28
+ classifier_free_guidance: bool = True,
29
+ cfg_drop_ratio: float = 0.2,
30
+
31
+ ) -> None:
32
+ self.noise_scheduler_name = noise_scheduler_name
33
+ self.snr_gamma = snr_gamma
34
+ self.classifier_free_guidance = classifier_free_guidance
35
+ self.cfg_drop_ratio = cfg_drop_ratio
36
+ self.noise_scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
37
+ self.noise_scheduler_name, subfolder="scheduler"
38
+ )
39
+
40
+ def compute_snr(self, timesteps) -> torch.Tensor:
41
+ """
42
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
43
+ """
44
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
45
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
46
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5
47
+
48
+ # Expand the tensors.
49
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
50
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
51
+ )[timesteps].float()
52
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
53
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
54
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
55
+
56
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
57
+ device=timesteps.device
58
+ )[timesteps].float()
59
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
60
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
61
+ None]
62
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
63
+
64
+ # Compute SNR.
65
+ snr = (alpha / sigma)**2
66
+ return snr
67
+
68
+ def get_timesteps(
69
+ self,
70
+ batch_size: int,
71
+ device: torch.device,
72
+ training: bool = True
73
+ ) -> torch.Tensor:
74
+ if training:
75
+ timesteps = torch.randint(
76
+ 0,
77
+ self.noise_scheduler.config.num_train_timesteps,
78
+ (batch_size, ),
79
+ device=device
80
+ )
81
+ else:
82
+ # validation on half of the total timesteps
83
+ timesteps = (self.noise_scheduler.config.num_train_timesteps //
84
+ 2) * torch.ones((batch_size, ),
85
+ dtype=torch.int64,
86
+ device=device)
87
+
88
+ timesteps = timesteps.long()
89
+ return timesteps
90
+
91
+ def get_target(
92
+ self, latent: torch.Tensor, noise: torch.Tensor,
93
+ timesteps: torch.Tensor
94
+ ) -> torch.Tensor:
95
+ """
96
+ Get the target for loss depending on the prediction type
97
+ """
98
+ if self.noise_scheduler.config.prediction_type == "epsilon":
99
+ target = noise
100
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
101
+ target = self.noise_scheduler.get_velocity(
102
+ latent, noise, timesteps
103
+ )
104
+ else:
105
+ raise ValueError(
106
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
107
+ )
108
+ return target
109
+
110
+ def loss_with_snr(
111
+ self, pred: torch.Tensor, target: torch.Tensor,
112
+ timesteps: torch.Tensor, mask: torch.Tensor
113
+ ) -> torch.Tensor:
114
+ if self.snr_gamma is None:
115
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
116
+ loss = loss_with_mask(loss, mask)
117
+ else:
118
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
119
+ # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
120
+ snr = self.compute_snr(timesteps)
121
+ mse_loss_weights = (
122
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)],
123
+ dim=1).min(dim=1)[0] / snr
124
+ )
125
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
126
+ loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
127
+ loss = loss.mean()
128
+ return loss
129
+
130
+
131
+ class AudioDiffusion(
132
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
133
+ DiffusionMixin
134
+ ):
135
+ """
136
+ Args:
137
+ autoencoder (AutoEncoderBase): Pretrained autoencoder module VAE(frozen).
138
+ content_encoder (ContentEncoder): Encodes TCC and TDC information.
139
+ backbone (nn.Module): Main denoising network.
140
+ frame_resolution (float): Resolution for audio frames.
141
+ noise_scheduler_name (str): Noise scheduler identifier.
142
+ snr_gamma (float, optional): SNR gamma for noise scheduler.
143
+ classifier_free_guidance (bool): Enable classifier-free guidance.
144
+ cfg_drop_ratio (float): Ratio for randomly dropping context for classifier-free guidance.
145
+ """
146
+ def __init__(
147
+ self,
148
+ autoencoder: AutoEncoderBase,
149
+ content_encoder: ContentEncoder,
150
+ backbone: nn.Module,
151
+ frame_resolution:float,
152
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
153
+ snr_gamma: float = None,
154
+ classifier_free_guidance: bool = True,
155
+ cfg_drop_ratio: float = 0.2,
156
+ ):
157
+ nn.Module.__init__(self)
158
+ DiffusionMixin.__init__(
159
+ self, noise_scheduler_name, snr_gamma, classifier_free_guidance, cfg_drop_ratio
160
+ )
161
+
162
+ self.autoencoder = autoencoder
163
+ # Freeze autoencoder parameters
164
+ for param in self.autoencoder.parameters():
165
+ param.requires_grad = False
166
+
167
+ self.content_encoder = content_encoder
168
+ self.backbone = backbone
169
+ self.frame_resolution = frame_resolution
170
+ self.dummy_param = nn.Parameter(torch.empty(0))
171
+
172
+ def forward(
173
+ self, content: list[Any], condition: list[Any], task: list[str],
174
+ waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
175
+ ):
176
+ """
177
+ Training forward pass.
178
+
179
+ Args:
180
+ content (list[Any]): List of content dicts for each sample.
181
+ condition (list[Any]): Conditioning information (unused here).
182
+ task (list[str]): List of task types.
183
+ waveform (Tensor): Batch of waveform tensors.
184
+ waveform_lengths (Tensor): Lengths for each waveform sample.
185
+
186
+ Returns:
187
+ dict: Dictionary containing the diffusion loss.
188
+ """
189
+ device = self.dummy_param.device
190
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
191
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
192
+
193
+ self.autoencoder.eval()
194
+ with torch.no_grad():
195
+ latent, latent_mask = self.autoencoder.encode(
196
+ waveform.unsqueeze(1), waveform_lengths
197
+ )
198
+ # content(non_time_aligned_content) for TCC and time_aligned_content for TDC
199
+ content, content_mask, onset, _= self.content_encoder.encode_content(
200
+ content, device=device
201
+ )
202
+
203
+ # prepare latent and diffusion-related noise
204
+ time_aligned_content = onset.permute(0,2,1)
205
+ if self.training and self.classifier_free_guidance:
206
+ mask_indices = [
207
+ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio
208
+ ]
209
+ if len(mask_indices) > 0:
210
+ content[mask_indices] = 0
211
+ time_aligned_content[mask_indices] = 0
212
+
213
+ batch_size = latent.shape[0]
214
+ timesteps = self.get_timesteps(batch_size, device, self.training)
215
+ noise = torch.randn_like(latent)
216
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
217
+ target = self.get_target(latent, noise, timesteps)
218
+
219
+ # Denoising prediction
220
+ pred: torch.Tensor = self.backbone(
221
+ x=noisy_latent,
222
+ timesteps=timesteps,
223
+ time_aligned_context=time_aligned_content,
224
+ context=content,
225
+ x_mask=latent_mask,
226
+ context_mask=content_mask
227
+ )
228
+ pred = pred.transpose(1, self.autoencoder.time_dim)
229
+ target = target.transpose(1, self.autoencoder.time_dim)
230
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
231
+ return {
232
+ "diff_loss": diff_loss,
233
+ }
234
+
235
+ @torch.no_grad()
236
+ def inference(
237
+ self,
238
+ content: list[Any],
239
+ num_steps: int = 20,
240
+ guidance_scale: float = 3.0,
241
+ guidance_rescale: float = 0.0,
242
+ disable_progress: bool = True,
243
+ num_samples_per_content: int = 1,
244
+ **kwargs
245
+ ):
246
+ """
247
+ Inference/generation method for audio diffusion.
248
+
249
+ Args:
250
+ content (list[Any]): List of content dicts.
251
+ scheduler (SchedulerMixin): Scheduler for timesteps and noise.
252
+ num_steps (int): Number of denoising steps.
253
+ guidance_scale (float): Classifier-free guidance scale.
254
+ guidance_rescale (float): Rescale factor for guidance.
255
+ disable_progress (bool): Disable progress bar.
256
+ num_samples_per_content (int): How many samples to generate per content.
257
+
258
+ Returns:
259
+ waveform (Tensor): Generated waveform.
260
+ """
261
+ device = self.dummy_param.device
262
+ classifier_free_guidance = guidance_scale > 1.0
263
+ batch_size = len(content) * num_samples_per_content
264
+ print(content)
265
+ if classifier_free_guidance:
266
+ content, content_mask, onset, length_list = self.encode_content_classifier_free(
267
+ content, num_samples_per_content
268
+ )
269
+ else:
270
+ content, content_mask, onset, length_list = self.content_encoder.encode_content(
271
+ content, device=device
272
+ )
273
+ content = content.repeat_interleave(num_samples_per_content, 0)
274
+ content_mask = content_mask.repeat_interleave(
275
+ num_samples_per_content, 0
276
+ )
277
+
278
+ self.noise_scheduler.set_timesteps(num_steps, device=device)
279
+ timesteps = self.noise_scheduler.timesteps
280
+
281
+
282
+ # prepare input latent and context for the backbone
283
+ shape = (batch_size, 128, onset.shape[2]) # 128 for StableVAE channels
284
+ time_aligned_content = onset.permute(0,2,1)
285
+ latent = randn_tensor(
286
+ shape, generator=None, device=device, dtype=content.dtype
287
+ )
288
+
289
+ # scale the initial noise by the standard deviation required by the scheduler
290
+ latent = latent * self.noise_scheduler.init_noise_sigma
291
+ latent_mask = torch.full((batch_size, onset.shape[2]), False, device=device)
292
+
293
+ for i, length in enumerate(length_list):
294
+ # Set latent mask True for valid time steps for each sample
295
+ latent_mask[i, :length] = True
296
+ num_warmup_steps = len(timesteps) - num_steps * self.noise_scheduler.order
297
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
298
+
299
+ if classifier_free_guidance:
300
+ uncond_time_aligned_content = torch.zeros_like(
301
+ time_aligned_content
302
+ )
303
+ time_aligned_content = torch.cat(
304
+ [uncond_time_aligned_content, time_aligned_content]
305
+ )
306
+ latent_mask = torch.cat(
307
+ [latent_mask, latent_mask.detach().clone()]
308
+ )
309
+
310
+ # iteratively denoising
311
+
312
+ for i, timestep in enumerate(timesteps):
313
+
314
+ latent_input = torch.cat(
315
+ [latent, latent]
316
+ ) if classifier_free_guidance else latent
317
+ latent_input = self.noise_scheduler.scale_model_input(latent_input, timestep)
318
+
319
+ noise_pred = self.backbone(
320
+ x=latent_input,
321
+ x_mask=latent_mask,
322
+ timesteps=timestep,
323
+ time_aligned_context=time_aligned_content,
324
+ context=content,
325
+ context_mask=content_mask,
326
+ )
327
+
328
+ if classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_content - noise_pred_uncond
332
+ )
333
+ if guidance_rescale != 0.0:
334
+ noise_pred = self.rescale_cfg(
335
+ noise_pred_content, noise_pred, guidance_rescale
336
+ )
337
+ # compute the previous noisy sample x_t -> x_t-1
338
+ latent = self.noise_scheduler.step(noise_pred, timestep, latent).prev_sample
339
+
340
+ # call the callback, if provided
341
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
342
+ (i+1) % self.noise_scheduler.order == 0):
343
+ progress_bar.update(1)
344
+ #latent = latent.to(next(self.autoencoder.parameters()).device)
345
+ waveform = self.autoencoder.decode(latent)
346
+ return waveform
347
+
348
+ def encode_content_classifier_free(
349
+ self,
350
+ content: list[Any],
351
+ task: list[str],
352
+ num_samples_per_content: int = 1
353
+ ):
354
+ device = self.dummy_param.device
355
+
356
+ content, content_mask, onset, length_list = self.content_encoder.encode_content(
357
+ content, device=device
358
+ )
359
+ content = content.repeat_interleave(num_samples_per_content, 0)
360
+ content_mask = content_mask.repeat_interleave(
361
+ num_samples_per_content, 0
362
+ )
363
+
364
+ # get unconditional embeddings for classifier free guidance
365
+ uncond_content = torch.zeros_like(content)
366
+ uncond_content_mask = content_mask.detach().clone()
367
+
368
+ uncond_content = uncond_content.repeat_interleave(
369
+ num_samples_per_content, 0
370
+ )
371
+ uncond_content_mask = uncond_content_mask.repeat_interleave(
372
+ num_samples_per_content, 0
373
+ )
374
+
375
+ # For classifier free guidance, we need to do two forward passes.
376
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
377
+ content = torch.cat([uncond_content, content])
378
+ content_mask = torch.cat([uncond_content_mask, content_mask])
379
+
380
+ return content, content_mask, onset, length_list
381
+
382
+ def rescale_cfg(
383
+ self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
384
+ guidance_rescale: float
385
+ ):
386
+ """
387
+ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
388
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
389
+ """
390
+ std_cond = pred_cond.std(
391
+ dim=list(range(1, pred_cond.ndim)), keepdim=True
392
+ )
393
+ std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
394
+
395
+ pred_rescaled = pred_cfg * (std_cond / std_cfg)
396
+ pred_cfg = guidance_rescale * pred_rescaled + (
397
+ 1 - guidance_rescale
398
+ ) * pred_cfg
models/dit/__pycache__/attention.cpython-310.pyc ADDED
Binary file (7.7 kB). View file
 
models/dit/__pycache__/audio_dit.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
models/dit/__pycache__/mask_dit.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
models/dit/__pycache__/modules.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
models/dit/__pycache__/rotary.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
models/dit/__pycache__/span_mask.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
models/dit/attention.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
12
+ ATTENTION_MODE = 'flash'
13
+ else:
14
+ ATTENTION_MODE = 'math'
15
+ print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
+ def add_mask(sim, mask):
19
+ b, ndim = sim.shape[0], mask.ndim
20
+ if ndim == 3:
21
+ mask = rearrange(mask, "b n m -> b 1 n m")
22
+ if ndim == 2:
23
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
24
+ max_neg_value = -torch.finfo(sim.dtype).max
25
+ sim = sim.masked_fill(~mask, max_neg_value)
26
+ return sim
27
+
28
+
29
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
30
+ def default(val, d):
31
+ return val if val is not None else (d() if isfunction(d) else d)
32
+
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ #print(q_mask)
35
+ q_mask = default(
36
+ q_mask, torch.ones((b, i), device=device, dtype=torch.bool)
37
+ )
38
+ k_mask = default(
39
+ k_mask, torch.ones((b, j), device=device, dtype=torch.bool)
40
+ )
41
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1'
42
+ ) * rearrange(k_mask, 'b j -> b 1 1 j')
43
+ return attn_mask
44
+
45
+
46
+ class Attention(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim,
50
+ context_dim=None,
51
+ num_heads=8,
52
+ qkv_bias=False,
53
+ qk_scale=None,
54
+ qk_norm=None,
55
+ attn_drop=0.,
56
+ proj_drop=0.,
57
+ rope_mode='none'
58
+ ):
59
+ super().__init__()
60
+ self.num_heads = num_heads
61
+ head_dim = dim // num_heads
62
+ self.scale = qk_scale or head_dim**-0.5
63
+
64
+ if context_dim is None:
65
+ self.cross_attn = False
66
+ else:
67
+ self.cross_attn = True
68
+
69
+ context_dim = dim if context_dim is None else context_dim
70
+
71
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
72
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
73
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
74
+
75
+ if qk_norm is None:
76
+ self.norm_q = nn.Identity()
77
+ self.norm_k = nn.Identity()
78
+ elif qk_norm == 'layernorm':
79
+ self.norm_q = nn.LayerNorm(head_dim)
80
+ self.norm_k = nn.LayerNorm(head_dim)
81
+ elif qk_norm == 'rmsnorm':
82
+ self.norm_q = RMSNorm(head_dim)
83
+ self.norm_k = RMSNorm(head_dim)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ self.attn_drop_p = attn_drop
88
+ self.attn_drop = nn.Dropout(attn_drop)
89
+ self.proj = nn.Linear(dim, dim)
90
+ self.proj_drop = nn.Dropout(proj_drop)
91
+
92
+ if self.cross_attn:
93
+ assert rope_mode == 'none'
94
+ self.rope_mode = rope_mode
95
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
96
+ self.rotary = RotaryEmbedding(dim=head_dim)
97
+ elif self.rope_mode == 'dual':
98
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
99
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
100
+
101
+ def _rotary(self, q, k, extras):
102
+ if self.rope_mode == 'shared':
103
+ q, k = self.rotary(q=q, k=k)
104
+ elif self.rope_mode == 'x_only':
105
+ q_x, k_x = self.rotary(
106
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
107
+ )
108
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
109
+ q = torch.cat((q_c, q_x), dim=2)
110
+ k = torch.cat((k_c, k_x), dim=2)
111
+ elif self.rope_mode == 'dual':
112
+ q_x, k_x = self.rotary_x(
113
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
114
+ )
115
+ q_c, k_c = self.rotary_c(
116
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
117
+ )
118
+ q = torch.cat((q_c, q_x), dim=2)
119
+ k = torch.cat((k_c, k_x), dim=2)
120
+ elif self.rope_mode == 'none':
121
+ pass
122
+ else:
123
+ raise NotImplementedError
124
+ return q, k
125
+
126
+ def _attn(self, q, k, v, mask_binary):
127
+ if ATTENTION_MODE == 'flash':
128
+ x = F.scaled_dot_product_attention(
129
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
130
+ )
131
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
132
+ elif ATTENTION_MODE == 'math':
133
+ attn = (q @ k.transpose(-2, -1)) * self.scale
134
+ attn = add_mask(
135
+ attn, mask_binary
136
+ ) if mask_binary is not None else attn
137
+ attn = attn.softmax(dim=-1)
138
+ attn = self.attn_drop(attn)
139
+ x = (attn @ v).transpose(1, 2)
140
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
141
+ else:
142
+ raise NotImplementedError
143
+ return x
144
+
145
+ def forward(self, x, context=None, context_mask=None, extras=0):
146
+ B, L, C = x.shape
147
+ if context is None:
148
+ context = x
149
+
150
+ q = self.to_q(x)
151
+ k = self.to_k(context)
152
+ v = self.to_v(context)
153
+
154
+ if context_mask is not None:
155
+ mask_binary = create_mask(
156
+ x.shape, context.shape, x.device, None, context_mask
157
+ )
158
+ else:
159
+ mask_binary = None
160
+
161
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
162
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
163
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
164
+
165
+ q = self.norm_q(q)
166
+ k = self.norm_k(k)
167
+
168
+ q, k = self._rotary(q, k, extras)
169
+
170
+ x = self._attn(q, k, v, mask_binary)
171
+
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class JointAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ dim,
181
+ num_heads=8,
182
+ qkv_bias=False,
183
+ qk_scale=None,
184
+ qk_norm=None,
185
+ attn_drop=0.,
186
+ proj_drop=0.,
187
+ rope_mode='none'
188
+ ):
189
+ super().__init__()
190
+ self.num_heads = num_heads
191
+ head_dim = dim // num_heads
192
+ self.scale = qk_scale or head_dim**-0.5
193
+
194
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
195
+ dim, qkv_bias
196
+ )
197
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
198
+ dim, qkv_bias
199
+ )
200
+
201
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
202
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
203
+
204
+ self.attn_drop_p = attn_drop
205
+ self.attn_drop = nn.Dropout(attn_drop)
206
+
207
+ self.proj_x = nn.Linear(dim, dim)
208
+ self.proj_drop_x = nn.Dropout(proj_drop)
209
+
210
+ self.proj_c = nn.Linear(dim, dim)
211
+ self.proj_drop_c = nn.Dropout(proj_drop)
212
+
213
+ self.rope_mode = rope_mode
214
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
215
+ self.rotary = RotaryEmbedding(dim=head_dim)
216
+ elif self.rope_mode == 'dual':
217
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
218
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
219
+
220
+ def _make_qkv_layers(self, dim, qkv_bias):
221
+ return (
222
+ nn.Linear(dim, dim,
223
+ bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias),
224
+ nn.Linear(dim, dim, bias=qkv_bias)
225
+ )
226
+
227
+ def _make_norm_layers(self, qk_norm, head_dim):
228
+ if qk_norm is None:
229
+ norm_q = nn.Identity()
230
+ norm_k = nn.Identity()
231
+ elif qk_norm == 'layernorm':
232
+ norm_q = nn.LayerNorm(head_dim)
233
+ norm_k = nn.LayerNorm(head_dim)
234
+ elif qk_norm == 'rmsnorm':
235
+ norm_q = RMSNorm(head_dim)
236
+ norm_k = RMSNorm(head_dim)
237
+ else:
238
+ raise NotImplementedError
239
+ return norm_q, norm_k
240
+
241
+ def _rotary(self, q, k, extras):
242
+ if self.rope_mode == 'shared':
243
+ q, k = self.rotary(q=q, k=k)
244
+ elif self.rope_mode == 'x_only':
245
+ q_x, k_x = self.rotary(
246
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
247
+ )
248
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
249
+ q = torch.cat((q_c, q_x), dim=2)
250
+ k = torch.cat((k_c, k_x), dim=2)
251
+ elif self.rope_mode == 'dual':
252
+ q_x, k_x = self.rotary_x(
253
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
254
+ )
255
+ q_c, k_c = self.rotary_c(
256
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
257
+ )
258
+ q = torch.cat((q_c, q_x), dim=2)
259
+ k = torch.cat((k_c, k_x), dim=2)
260
+ elif self.rope_mode == 'none':
261
+ pass
262
+ else:
263
+ raise NotImplementedError
264
+ return q, k
265
+
266
+ def _attn(self, q, k, v, mask_binary):
267
+ if ATTENTION_MODE == 'flash':
268
+ x = F.scaled_dot_product_attention(
269
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
270
+ )
271
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
272
+ elif ATTENTION_MODE == 'math':
273
+ attn = (q @ k.transpose(-2, -1)) * self.scale
274
+ attn = add_mask(
275
+ attn, mask_binary
276
+ ) if mask_binary is not None else attn
277
+ attn = attn.softmax(dim=-1)
278
+ attn = self.attn_drop(attn)
279
+ x = (attn @ v).transpose(1, 2)
280
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
281
+ else:
282
+ raise NotImplementedError
283
+ return x
284
+
285
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
286
+ B = x.shape[0]
287
+ if x_mask is None:
288
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
289
+ if context_mask is None:
290
+ context_mask = torch.ones(
291
+ B, context.shape[-2], device=context.device
292
+ ).bool()
293
+ mask = torch.cat([context_mask, x_mask], dim=1)
294
+ return mask
295
+
296
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
297
+ B, Lx, C = x.shape
298
+ _, Lc, _ = context.shape
299
+ if x_mask is not None or context_mask is not None:
300
+ mask = self._cat_mask(
301
+ x, context, x_mask=x_mask, context_mask=context_mask
302
+ )
303
+ shape = [B, Lx + Lc, C]
304
+ mask_binary = create_mask(
305
+ q_shape=shape,
306
+ k_shape=shape,
307
+ device=x.device,
308
+ q_mask=None,
309
+ k_mask=mask
310
+ )
311
+ else:
312
+ mask_binary = None
313
+
314
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
315
+ qc, kc, vc = self.to_qc(context), self.to_kc(context
316
+ ), self.to_vc(context)
317
+
318
+ qx, kx, vx = map(
319
+ lambda t: einops.
320
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
321
+ [qx, kx, vx]
322
+ )
323
+ qc, kc, vc = map(
324
+ lambda t: einops.
325
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
326
+ [qc, kc, vc]
327
+ )
328
+
329
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
330
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
331
+
332
+ q, k, v = (
333
+ torch.cat([qc, qx],
334
+ dim=2), torch.cat([kc, kx],
335
+ dim=2), torch.cat([vc, vx], dim=2)
336
+ )
337
+
338
+ q, k = self._rotary(q, k, extras)
339
+
340
+ x = self._attn(q, k, v, mask_binary)
341
+
342
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
343
+
344
+ x = self.proj_x(x)
345
+ x = self.proj_drop_x(x)
346
+
347
+ context = self.proj_c(context)
348
+ context = self.proj_drop_c(context)
349
+
350
+ return x, context
models/dit/audio_diffsingernet_dit.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block with time_aligned_context add to latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ time_aligned_context_dim,
23
+ dilation,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ time_fusion='none',
33
+ ada_sola_rank=None,
34
+ ada_sola_alpha=None,
35
+ skip=False,
36
+ skip_norm=False,
37
+ rope_mode='none',
38
+ context_norm=False,
39
+ use_checkpoint=False
40
+ ):
41
+ super().__init__(
42
+ dim=dim,
43
+ context_dim=context_dim,
44
+ num_heads=num_heads,
45
+ mlp_ratio=mlp_ratio,
46
+ qkv_bias=qkv_bias,
47
+ qk_scale=qk_scale,
48
+ qk_norm=qk_norm,
49
+ act_layer=act_layer,
50
+ norm_layer=norm_layer,
51
+ time_fusion=time_fusion,
52
+ ada_sola_rank=ada_sola_rank,
53
+ ada_sola_alpha=ada_sola_alpha,
54
+ skip=skip,
55
+ skip_norm=skip_norm,
56
+ rope_mode=rope_mode,
57
+ context_norm=context_norm,
58
+ use_checkpoint=use_checkpoint
59
+ )
60
+ # time-aligned context projection
61
+ self.ta_context_projection = nn.Linear(
62
+ time_aligned_context_dim, 2 * dim
63
+ )
64
+ self.dilated_conv = nn.Conv1d(
65
+ dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation
66
+ )
67
+
68
+ def forward(
69
+ self,
70
+ x,
71
+ time_aligned_context,
72
+ time_token=None,
73
+ time_ada=None,
74
+ skip=None,
75
+ context=None,
76
+ x_mask=None,
77
+ context_mask=None,
78
+ extras=None
79
+ ):
80
+ if self.use_checkpoint:
81
+ return checkpoint(
82
+ self._forward,
83
+ x,
84
+ time_aligned_context,
85
+ time_token,
86
+ time_ada,
87
+ skip,
88
+ context,
89
+ x_mask,
90
+ context_mask,
91
+ extras,
92
+ use_reentrant=False
93
+ )
94
+ else:
95
+ return self._forward(
96
+ x,
97
+ time_aligned_context,
98
+ time_token,
99
+ time_ada,
100
+ skip,
101
+ context,
102
+ x_mask,
103
+ context_mask,
104
+ extras,
105
+ )
106
+
107
+ def _forward(
108
+ self,
109
+ x,
110
+ time_aligned_context,
111
+ time_token=None,
112
+ time_ada=None,
113
+ skip=None,
114
+ context=None,
115
+ x_mask=None,
116
+ context_mask=None,
117
+ extras=None
118
+ ):
119
+ B, T, C = x.shape
120
+ if self.skip_linear is not None:
121
+ assert skip is not None
122
+ cat = torch.cat([x, skip], dim=-1)
123
+ cat = self.skip_norm(cat)
124
+ x = self.skip_linear(cat)
125
+
126
+ if self.use_adanorm:
127
+ time_ada = self.adaln(time_token, time_ada)
128
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
129
+ gate_mlp) = time_ada.chunk(6, dim=1)
130
+
131
+ # self attention
132
+ if self.use_adanorm:
133
+ x_norm = film_modulate(
134
+ self.norm1(x), shift=shift_msa, scale=scale_msa
135
+ )
136
+ x = x + (1-gate_msa) * self.attn(
137
+ x_norm, context=None, context_mask=x_mask, extras=extras
138
+ )
139
+ else:
140
+ # TODO diffusion timestep input is not fused here
141
+ x = x + self.attn(
142
+ self.norm1(x),
143
+ context=None,
144
+ context_mask=x_mask,
145
+ extras=extras
146
+ )
147
+
148
+ # time-aligned context
149
+ time_aligned_context = self.ta_context_projection(time_aligned_context)
150
+ x = self.dilated_conv(x.transpose(1, 2)
151
+ ).transpose(1, 2) + time_aligned_context
152
+
153
+ gate, filter = torch.chunk(x, 2, dim=-1)
154
+ x = torch.sigmoid(gate) * torch.tanh(filter)
155
+
156
+ # cross attention
157
+ if self.use_context:
158
+ assert context is not None
159
+ x = x + self.cross_attn(
160
+ x=self.norm2(x),
161
+ context=self.norm_context(context),
162
+ context_mask=context_mask,
163
+ extras=extras
164
+ )
165
+
166
+ # mlp
167
+ if self.use_adanorm:
168
+ x_norm = film_modulate(
169
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
170
+ )
171
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
172
+ else:
173
+ x = x + self.mlp(self.norm3(x))
174
+
175
+ return x
176
+
177
+
178
+ class AudioUDiT(UDiT):
179
+ def __init__(
180
+ self,
181
+ img_size=224,
182
+ patch_size=16,
183
+ in_chans=3,
184
+ input_type='2d',
185
+ out_chans=None,
186
+ embed_dim=768,
187
+ depth=12,
188
+ dilation_cycle_length=4,
189
+ num_heads=12,
190
+ mlp_ratio=4,
191
+ qkv_bias=False,
192
+ qk_scale=None,
193
+ qk_norm=None,
194
+ act_layer='gelu',
195
+ norm_layer='layernorm',
196
+ context_norm=False,
197
+ use_checkpoint=False,
198
+ time_fusion='token',
199
+ ada_sola_rank=None,
200
+ ada_sola_alpha=None,
201
+ cls_dim=None,
202
+ time_aligned_context_dim=768,
203
+ context_dim=768,
204
+ context_fusion='concat',
205
+ context_max_length=128,
206
+ context_pe_method='sinu',
207
+ pe_method='abs',
208
+ rope_mode='none',
209
+ use_conv=True,
210
+ skip=True,
211
+ skip_norm=True
212
+ ):
213
+ nn.Module.__init__(self)
214
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
215
+
216
+ # input
217
+ self.in_chans = in_chans
218
+ self.input_type = input_type
219
+ if self.input_type == '2d':
220
+ num_patches = (img_size[0] //
221
+ patch_size) * (img_size[1] // patch_size)
222
+ elif self.input_type == '1d':
223
+ num_patches = img_size // patch_size
224
+ self.patch_embed = PatchEmbed(
225
+ patch_size=patch_size,
226
+ in_chans=in_chans,
227
+ embed_dim=embed_dim,
228
+ input_type=input_type
229
+ )
230
+ out_chans = in_chans if out_chans is None else out_chans
231
+ self.out_chans = out_chans
232
+
233
+ # position embedding
234
+ self.rope = rope_mode
235
+ self.x_pe = PE_wrapper(
236
+ dim=embed_dim, method=pe_method, length=num_patches
237
+ )
238
+
239
+ # time embed
240
+ self.time_embed = TimestepEmbedder(embed_dim)
241
+ self.time_fusion = time_fusion
242
+ self.use_adanorm = False
243
+
244
+ # cls embed
245
+ if cls_dim is not None:
246
+ self.cls_embed = nn.Sequential(
247
+ nn.Linear(cls_dim, embed_dim, bias=True),
248
+ nn.SiLU(),
249
+ nn.Linear(embed_dim, embed_dim, bias=True),
250
+ )
251
+ else:
252
+ self.cls_embed = None
253
+
254
+ # time fusion
255
+ if time_fusion == 'token':
256
+ # put token at the beginning of sequence
257
+ self.extras = 2 if self.cls_embed else 1
258
+ self.time_pe = PE_wrapper(
259
+ dim=embed_dim, method='abs', length=self.extras
260
+ )
261
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
262
+ self.use_adanorm = True
263
+ # aviod repetitive silu for each adaln block
264
+ self.time_act = nn.SiLU()
265
+ self.extras = 0
266
+ self.time_ada_final = nn.Linear(
267
+ embed_dim, 2 * embed_dim, bias=True
268
+ )
269
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
270
+ # shared adaln
271
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
272
+ else:
273
+ self.time_ada = None
274
+ else:
275
+ raise NotImplementedError
276
+
277
+ # context
278
+ # use a simple projection
279
+ self.use_context = False
280
+ self.context_cross = False
281
+ self.context_max_length = context_max_length
282
+ self.context_fusion = 'none'
283
+ if context_dim is not None:
284
+ self.use_context = True
285
+ self.context_embed = nn.Sequential(
286
+ nn.Linear(context_dim, embed_dim, bias=True),
287
+ nn.SiLU(),
288
+ nn.Linear(embed_dim, embed_dim, bias=True),
289
+ )
290
+ self.context_fusion = context_fusion
291
+ if context_fusion == 'concat' or context_fusion == 'joint':
292
+ self.extras += context_max_length
293
+ self.context_pe = PE_wrapper(
294
+ dim=embed_dim,
295
+ method=context_pe_method,
296
+ length=context_max_length
297
+ )
298
+ # no cross attention layers
299
+ context_dim = None
300
+ elif context_fusion == 'cross':
301
+ self.context_pe = PE_wrapper(
302
+ dim=embed_dim,
303
+ method=context_pe_method,
304
+ length=context_max_length
305
+ )
306
+ self.context_cross = True
307
+ context_dim = embed_dim
308
+ else:
309
+ raise NotImplementedError
310
+
311
+ self.use_skip = skip
312
+
313
+ # norm layers
314
+ if norm_layer == 'layernorm':
315
+ norm_layer = nn.LayerNorm
316
+ elif norm_layer == 'rmsnorm':
317
+ norm_layer = RMSNorm
318
+ else:
319
+ raise NotImplementedError
320
+
321
+ self.in_blocks = nn.ModuleList([
322
+ AudioDiTBlock(
323
+ dim=embed_dim,
324
+ time_aligned_context_dim=time_aligned_context_dim,
325
+ dilation=2**(i % dilation_cycle_length),
326
+ context_dim=context_dim,
327
+ num_heads=num_heads,
328
+ mlp_ratio=mlp_ratio,
329
+ qkv_bias=qkv_bias,
330
+ qk_scale=qk_scale,
331
+ qk_norm=qk_norm,
332
+ act_layer=act_layer,
333
+ norm_layer=norm_layer,
334
+ time_fusion=time_fusion,
335
+ ada_sola_rank=ada_sola_rank,
336
+ ada_sola_alpha=ada_sola_alpha,
337
+ skip=False,
338
+ skip_norm=False,
339
+ rope_mode=self.rope,
340
+ context_norm=context_norm,
341
+ use_checkpoint=use_checkpoint
342
+ ) for i in range(depth // 2)
343
+ ])
344
+
345
+ self.mid_block = AudioDiTBlock(
346
+ dim=embed_dim,
347
+ time_aligned_context_dim=time_aligned_context_dim,
348
+ dilation=1,
349
+ context_dim=context_dim,
350
+ num_heads=num_heads,
351
+ mlp_ratio=mlp_ratio,
352
+ qkv_bias=qkv_bias,
353
+ qk_scale=qk_scale,
354
+ qk_norm=qk_norm,
355
+ act_layer=act_layer,
356
+ norm_layer=norm_layer,
357
+ time_fusion=time_fusion,
358
+ ada_sola_rank=ada_sola_rank,
359
+ ada_sola_alpha=ada_sola_alpha,
360
+ skip=False,
361
+ skip_norm=False,
362
+ rope_mode=self.rope,
363
+ context_norm=context_norm,
364
+ use_checkpoint=use_checkpoint
365
+ )
366
+
367
+ self.out_blocks = nn.ModuleList([
368
+ AudioDiTBlock(
369
+ dim=embed_dim,
370
+ time_aligned_context_dim=time_aligned_context_dim,
371
+ dilation=2**(i % dilation_cycle_length),
372
+ context_dim=context_dim,
373
+ num_heads=num_heads,
374
+ mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias,
376
+ qk_scale=qk_scale,
377
+ qk_norm=qk_norm,
378
+ act_layer=act_layer,
379
+ norm_layer=norm_layer,
380
+ time_fusion=time_fusion,
381
+ ada_sola_rank=ada_sola_rank,
382
+ ada_sola_alpha=ada_sola_alpha,
383
+ skip=skip,
384
+ skip_norm=skip_norm,
385
+ rope_mode=self.rope,
386
+ context_norm=context_norm,
387
+ use_checkpoint=use_checkpoint
388
+ ) for i in range(depth // 2)
389
+ ])
390
+
391
+ # FinalLayer block
392
+ self.use_conv = use_conv
393
+ self.final_block = FinalBlock(
394
+ embed_dim=embed_dim,
395
+ patch_size=patch_size,
396
+ img_size=img_size,
397
+ in_chans=out_chans,
398
+ input_type=input_type,
399
+ norm_layer=norm_layer,
400
+ use_conv=use_conv,
401
+ use_adanorm=self.use_adanorm
402
+ )
403
+ self.initialize_weights()
404
+
405
+ def forward(
406
+ self,
407
+ x,
408
+ timesteps,
409
+ time_aligned_context,
410
+ context,
411
+ x_mask=None,
412
+ context_mask=None,
413
+ cls_token=None,
414
+ controlnet_skips=None,
415
+ ):
416
+ # make it compatible with int time step during inference
417
+ if timesteps.dim() == 0:
418
+ timesteps = timesteps.expand(x.shape[0]
419
+ ).to(x.device, dtype=torch.long)
420
+
421
+ x = self.patch_embed(x)
422
+ x = self.x_pe(x)
423
+
424
+ B, L, D = x.shape
425
+
426
+ if self.use_context:
427
+ context_token = self.context_embed(context)
428
+ context_token = self.context_pe(context_token)
429
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
430
+ x, x_mask = self._concat_x_context(
431
+ x=x,
432
+ context=context_token,
433
+ x_mask=x_mask,
434
+ context_mask=context_mask
435
+ )
436
+ context_token, context_mask = None, None
437
+ else:
438
+ context_token, context_mask = None, None
439
+
440
+ time_token = self.time_embed(timesteps)
441
+ if self.cls_embed:
442
+ cls_token = self.cls_embed(cls_token)
443
+ time_ada = None
444
+ time_ada_final = None
445
+ if self.use_adanorm:
446
+ if self.cls_embed:
447
+ time_token = time_token + cls_token
448
+ time_token = self.time_act(time_token)
449
+ time_ada_final = self.time_ada_final(time_token)
450
+ if self.time_ada is not None:
451
+ time_ada = self.time_ada(time_token)
452
+ else:
453
+ time_token = time_token.unsqueeze(dim=1)
454
+ if self.cls_embed:
455
+ cls_token = cls_token.unsqueeze(dim=1)
456
+ time_token = torch.cat([time_token, cls_token], dim=1)
457
+ time_token = self.time_pe(time_token)
458
+ x = torch.cat((time_token, x), dim=1)
459
+ if x_mask is not None:
460
+ x_mask = torch.cat([
461
+ torch.ones(B, time_token.shape[1],
462
+ device=x_mask.device).bool(), x_mask
463
+ ],
464
+ dim=1)
465
+ time_token = None
466
+
467
+ skips = []
468
+ for blk in self.in_blocks:
469
+ x = blk(
470
+ x=x,
471
+ time_aligned_context=time_aligned_context,
472
+ time_token=time_token,
473
+ time_ada=time_ada,
474
+ skip=None,
475
+ context=context_token,
476
+ x_mask=x_mask,
477
+ context_mask=context_mask,
478
+ extras=self.extras
479
+ )
480
+ if self.use_skip:
481
+ skips.append(x)
482
+
483
+ x = self.mid_block(
484
+ x=x,
485
+ time_aligned_context=time_aligned_context,
486
+ time_token=time_token,
487
+ time_ada=time_ada,
488
+ skip=None,
489
+ context=context_token,
490
+ x_mask=x_mask,
491
+ context_mask=context_mask,
492
+ extras=self.extras
493
+ )
494
+ for blk in self.out_blocks:
495
+ if self.use_skip:
496
+ skip = skips.pop()
497
+ if controlnet_skips:
498
+ # add to skip like u-net controlnet
499
+ skip = skip + controlnet_skips.pop()
500
+ else:
501
+ skip = None
502
+ if controlnet_skips:
503
+ # directly add to x
504
+ x = x + controlnet_skips.pop()
505
+
506
+ x = blk(
507
+ x=x,
508
+ time_aligned_context=time_aligned_context,
509
+ time_token=time_token,
510
+ time_ada=time_ada,
511
+ skip=skip,
512
+ context=context_token,
513
+ x_mask=x_mask,
514
+ context_mask=context_mask,
515
+ extras=self.extras
516
+ )
517
+
518
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
519
+
520
+ return x
models/dit/audio_dit.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block with time aligned context add to latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ ta_context_dim,
23
+ ta_context_norm=False,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ ta_context_fusion='add',
33
+ time_fusion='none',
34
+ ada_sola_rank=None,
35
+ ada_sola_alpha=None,
36
+ skip=False,
37
+ skip_norm=False,
38
+ rope_mode='none',
39
+ context_norm=False,
40
+ use_checkpoint=False
41
+ ):
42
+ super().__init__(
43
+ dim=dim,
44
+ context_dim=context_dim,
45
+ num_heads=num_heads,
46
+ mlp_ratio=mlp_ratio,
47
+ qkv_bias=qkv_bias,
48
+ qk_scale=qk_scale,
49
+ qk_norm=qk_norm,
50
+ act_layer=act_layer,
51
+ norm_layer=norm_layer,
52
+ time_fusion=time_fusion,
53
+ ada_sola_rank=ada_sola_rank,
54
+ ada_sola_alpha=ada_sola_alpha,
55
+ skip=skip,
56
+ skip_norm=skip_norm,
57
+ rope_mode=rope_mode,
58
+ context_norm=context_norm,
59
+ use_checkpoint=use_checkpoint
60
+ )
61
+ self.ta_context_fusion = ta_context_fusion
62
+ self.ta_context_norm = ta_context_norm
63
+ if self.ta_context_fusion == "add":
64
+ self.ta_context_projection = nn.Linear(ta_context_dim, dim)
65
+ self.ta_context_norm = norm_layer(
66
+ ta_context_dim
67
+ ) if self.ta_context_norm else nn.Identity()
68
+ elif self.ta_context_fusion == "concat":
69
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
70
+ self.ta_context_norm = norm_layer(
71
+ ta_context_dim + dim
72
+ ) if self.ta_context_norm else nn.Identity()
73
+
74
+ def forward(
75
+ self,
76
+ x,
77
+ time_aligned_context,
78
+ time_token=None,
79
+ time_ada=None,
80
+ skip=None,
81
+ context=None,
82
+ x_mask=None,
83
+ context_mask=None,
84
+ extras=None
85
+ ):
86
+ if self.use_checkpoint:
87
+ return checkpoint(
88
+ self._forward,
89
+ x,
90
+ time_aligned_context,
91
+ time_token,
92
+ time_ada,
93
+ skip,
94
+ context,
95
+ x_mask,
96
+ context_mask,
97
+ extras,
98
+ use_reentrant=False
99
+ )
100
+ else:
101
+ return self._forward(
102
+ x,
103
+ time_aligned_context,
104
+ time_token,
105
+ time_ada,
106
+ skip,
107
+ context,
108
+ x_mask,
109
+ context_mask,
110
+ extras,
111
+ )
112
+
113
+ def _forward(
114
+ self,
115
+ x,
116
+ time_aligned_context,
117
+ time_token=None,
118
+ time_ada=None,
119
+ skip=None,
120
+ context=None,
121
+ x_mask=None,
122
+ context_mask=None,
123
+ extras=None
124
+ ):
125
+ B, T, C = x.shape
126
+
127
+ # # time aligned context
128
+ # if self.ta_context_fusion == "add":
129
+ # time_aligned_context = self.ta_context_projection(
130
+ # self.ta_context_norm(time_aligned_context)
131
+ # )
132
+ # x = x + time_aligned_context
133
+ # elif self.ta_context_fusion == "concat":
134
+ # cat = torch.cat([x, time_aligned_context], dim=-1)
135
+ # cat = self.ta_context_norm(cat)
136
+ # x = self.ta_context_projection(cat)
137
+
138
+ # skip connection
139
+ if self.skip_linear is not None:
140
+ assert skip is not None
141
+ cat = torch.cat([x, skip], dim=-1)
142
+ cat = self.skip_norm(cat)
143
+ x = self.skip_linear(cat)
144
+ #print('skip')
145
+ #print(x)
146
+ if self.use_adanorm:
147
+ time_ada = self.adaln(time_token, time_ada)
148
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
149
+ gate_mlp) = time_ada.chunk(6, dim=1)
150
+
151
+ # self attention
152
+ if self.use_adanorm:
153
+ x_norm = film_modulate(
154
+ self.norm1(x), shift=shift_msa, scale=scale_msa
155
+ )
156
+ x = x + (1-gate_msa) * self.attn(
157
+ x_norm, context=None, context_mask=x_mask, extras=extras
158
+ )
159
+ else:
160
+ # TODO diffusion timestep input is not fused here
161
+ x = x + self.attn(
162
+ self.norm1(x),
163
+ context=None,
164
+ context_mask=x_mask,
165
+ extras=extras
166
+ )
167
+
168
+ # time aligned context fusion
169
+ if self.ta_context_fusion == "add":
170
+ time_aligned_context = self.ta_context_projection(
171
+ self.ta_context_norm(time_aligned_context)
172
+ )
173
+ x = x + time_aligned_context
174
+ elif self.ta_context_fusion == "concat":
175
+ cat = torch.cat([x, time_aligned_context], dim=-1)
176
+ cat = self.ta_context_norm(cat)
177
+ x = self.ta_context_projection(cat)
178
+
179
+ # cross attention
180
+ if self.use_context:
181
+ assert context is not None
182
+ x = x + self.cross_attn(
183
+ x=self.norm2(x),
184
+ context=self.norm_context(context),
185
+ context_mask=context_mask,
186
+ extras=extras
187
+ )
188
+
189
+ # mlp
190
+ if self.use_adanorm:
191
+ x_norm = film_modulate(
192
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
193
+ )
194
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
195
+ else:
196
+ x = x + self.mlp(self.norm3(x))
197
+
198
+ return x
199
+
200
+
201
+ class AudioUDiT(UDiT):
202
+ def __init__(
203
+ self,
204
+ img_size=224,
205
+ patch_size=16,
206
+ in_chans=3,
207
+ input_type='2d',
208
+ out_chans=None,
209
+ embed_dim=768,
210
+ depth=12,
211
+ num_heads=12,
212
+ mlp_ratio=4,
213
+ qkv_bias=False,
214
+ qk_scale=None,
215
+ qk_norm=None,
216
+ act_layer='gelu',
217
+ norm_layer='layernorm',
218
+ context_norm=False,
219
+ use_checkpoint=False,
220
+ time_fusion='token',
221
+ ada_sola_rank=None,
222
+ ada_sola_alpha=None,
223
+ cls_dim=None,
224
+ ta_context_dim=768,
225
+ ta_context_fusion='concat',
226
+ ta_context_norm=True,
227
+ context_dim=768,
228
+ context_fusion='concat',
229
+ context_max_length=128,
230
+ context_pe_method='sinu',
231
+ pe_method='abs',
232
+ rope_mode='none',
233
+ use_conv=True,
234
+ skip=True,
235
+ skip_norm=True
236
+ ):
237
+ nn.Module.__init__(self)
238
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
239
+
240
+ # input
241
+ self.in_chans = in_chans
242
+ self.input_type = input_type
243
+ if self.input_type == '2d':
244
+ num_patches = (img_size[0] //
245
+ patch_size) * (img_size[1] // patch_size)
246
+ elif self.input_type == '1d':
247
+ num_patches = img_size // patch_size
248
+ self.patch_embed = PatchEmbed(
249
+ patch_size=patch_size,
250
+ in_chans=in_chans,
251
+ embed_dim=embed_dim,
252
+ input_type=input_type
253
+ )
254
+ out_chans = in_chans if out_chans is None else out_chans
255
+ self.out_chans = out_chans
256
+
257
+ # position embedding
258
+ self.rope = rope_mode
259
+ self.x_pe = PE_wrapper(
260
+ dim=embed_dim, method=pe_method, length=num_patches
261
+ )
262
+
263
+ # time embed
264
+ self.time_embed = TimestepEmbedder(embed_dim)
265
+ self.time_fusion = time_fusion
266
+ self.use_adanorm = False
267
+
268
+ # cls embed
269
+ if cls_dim is not None:
270
+ self.cls_embed = nn.Sequential(
271
+ nn.Linear(cls_dim, embed_dim, bias=True),
272
+ nn.SiLU(),
273
+ nn.Linear(embed_dim, embed_dim, bias=True),
274
+ )
275
+ else:
276
+ self.cls_embed = None
277
+
278
+ # time fusion
279
+ if time_fusion == 'token':
280
+ # put token at the beginning of sequence
281
+ self.extras = 2 if self.cls_embed else 1
282
+ self.time_pe = PE_wrapper(
283
+ dim=embed_dim, method='abs', length=self.extras
284
+ )
285
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
286
+ self.use_adanorm = True
287
+ # aviod repetitive silu for each adaln block
288
+ self.time_act = nn.SiLU()
289
+ self.extras = 0
290
+ self.time_ada_final = nn.Linear(
291
+ embed_dim, 2 * embed_dim, bias=True
292
+ )
293
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
294
+ # shared adaln
295
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
296
+ else:
297
+ self.time_ada = None
298
+ else:
299
+ raise NotImplementedError
300
+
301
+ # context
302
+ # use a simple projection
303
+ self.use_context = False
304
+ self.context_cross = False
305
+ self.context_max_length = context_max_length
306
+ self.context_fusion = 'none'
307
+ if context_dim is not None:
308
+ self.use_context = True
309
+ self.context_embed = nn.Sequential(
310
+ nn.Linear(context_dim, embed_dim, bias=True),
311
+ nn.SiLU(),
312
+ nn.Linear(embed_dim, embed_dim, bias=True),
313
+ )
314
+ self.context_fusion = context_fusion
315
+ if context_fusion == 'concat' or context_fusion == 'joint':
316
+ self.extras += context_max_length
317
+ self.context_pe = PE_wrapper(
318
+ dim=embed_dim,
319
+ method=context_pe_method,
320
+ length=context_max_length
321
+ )
322
+ # no cross attention layers
323
+ context_dim = None
324
+ elif context_fusion == 'cross':
325
+ self.context_pe = PE_wrapper(
326
+ dim=embed_dim,
327
+ method=context_pe_method,
328
+ length=context_max_length
329
+ )
330
+ self.context_cross = True
331
+ context_dim = embed_dim
332
+ else:
333
+ raise NotImplementedError
334
+
335
+ self.use_skip = skip
336
+
337
+ # norm layers
338
+ if norm_layer == 'layernorm':
339
+ norm_layer = nn.LayerNorm
340
+ elif norm_layer == 'rmsnorm':
341
+ norm_layer = RMSNorm
342
+ else:
343
+ raise NotImplementedError
344
+
345
+ self.in_blocks = nn.ModuleList([
346
+ AudioDiTBlock(
347
+ dim=embed_dim,
348
+ ta_context_dim=ta_context_dim,
349
+ ta_context_fusion=ta_context_fusion,
350
+ ta_context_norm=ta_context_norm,
351
+ context_dim=context_dim,
352
+ num_heads=num_heads,
353
+ mlp_ratio=mlp_ratio,
354
+ qkv_bias=qkv_bias,
355
+ qk_scale=qk_scale,
356
+ qk_norm=qk_norm,
357
+ act_layer=act_layer,
358
+ norm_layer=norm_layer,
359
+ time_fusion=time_fusion,
360
+ ada_sola_rank=ada_sola_rank,
361
+ ada_sola_alpha=ada_sola_alpha,
362
+ skip=False,
363
+ skip_norm=False,
364
+ rope_mode=self.rope,
365
+ context_norm=context_norm,
366
+ use_checkpoint=use_checkpoint
367
+ ) for i in range(depth // 2)
368
+ ])
369
+
370
+ self.mid_block = AudioDiTBlock(
371
+ dim=embed_dim,
372
+ ta_context_dim=ta_context_dim,
373
+ context_dim=context_dim,
374
+ num_heads=num_heads,
375
+ mlp_ratio=mlp_ratio,
376
+ qkv_bias=qkv_bias,
377
+ qk_scale=qk_scale,
378
+ qk_norm=qk_norm,
379
+ act_layer=act_layer,
380
+ norm_layer=norm_layer,
381
+ time_fusion=time_fusion,
382
+ ada_sola_rank=ada_sola_rank,
383
+ ada_sola_alpha=ada_sola_alpha,
384
+ ta_context_fusion=ta_context_fusion,
385
+ ta_context_norm=ta_context_norm,
386
+ skip=False,
387
+ skip_norm=False,
388
+ rope_mode=self.rope,
389
+ context_norm=context_norm,
390
+ use_checkpoint=use_checkpoint
391
+ )
392
+
393
+ self.out_blocks = nn.ModuleList([
394
+ AudioDiTBlock(
395
+ dim=embed_dim,
396
+ ta_context_dim=ta_context_dim,
397
+ context_dim=context_dim,
398
+ num_heads=num_heads,
399
+ mlp_ratio=mlp_ratio,
400
+ qkv_bias=qkv_bias,
401
+ qk_scale=qk_scale,
402
+ qk_norm=qk_norm,
403
+ act_layer=act_layer,
404
+ norm_layer=norm_layer,
405
+ time_fusion=time_fusion,
406
+ ada_sola_rank=ada_sola_rank,
407
+ ada_sola_alpha=ada_sola_alpha,
408
+ ta_context_fusion=ta_context_fusion,
409
+ ta_context_norm=ta_context_norm,
410
+ skip=skip,
411
+ skip_norm=skip_norm,
412
+ rope_mode=self.rope,
413
+ context_norm=context_norm,
414
+ use_checkpoint=use_checkpoint
415
+ ) for i in range(depth // 2)
416
+ ])
417
+
418
+ # FinalLayer block
419
+ self.use_conv = use_conv
420
+ self.final_block = FinalBlock(
421
+ embed_dim=embed_dim,
422
+ patch_size=patch_size,
423
+ img_size=img_size,
424
+ in_chans=out_chans,
425
+ input_type=input_type,
426
+ norm_layer=norm_layer,
427
+ use_conv=use_conv,
428
+ use_adanorm=self.use_adanorm
429
+ )
430
+ self.initialize_weights()
431
+
432
+ def forward(
433
+ self,
434
+ x,
435
+ timesteps,
436
+ time_aligned_context,
437
+ context,
438
+ x_mask=None,
439
+ context_mask=None,
440
+ cls_token=None,
441
+ controlnet_skips=None,
442
+ ):
443
+ # make it compatible with int time step during inference
444
+ if timesteps.dim() == 0:
445
+ timesteps = timesteps.expand(x.shape[0]
446
+ ).to(x.device, dtype=torch.long)
447
+
448
+ x = self.patch_embed(x)
449
+ x = self.x_pe(x)
450
+
451
+ B, L, D = x.shape
452
+
453
+ if self.use_context:
454
+ context_token = self.context_embed(context)
455
+ context_token = self.context_pe(context_token)
456
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
457
+ x, x_mask = self._concat_x_context(
458
+ x=x,
459
+ context=context_token,
460
+ x_mask=x_mask,
461
+ context_mask=context_mask
462
+ )
463
+ context_token, context_mask = None, None
464
+ else:
465
+ context_token, context_mask = None, None
466
+
467
+ time_token = self.time_embed(timesteps)
468
+ if self.cls_embed:
469
+ cls_token = self.cls_embed(cls_token)
470
+ time_ada = None
471
+ time_ada_final = None
472
+ if self.use_adanorm:
473
+ if self.cls_embed:
474
+ time_token = time_token + cls_token
475
+ time_token = self.time_act(time_token)
476
+ time_ada_final = self.time_ada_final(time_token)
477
+ if self.time_ada is not None:
478
+ time_ada = self.time_ada(time_token)
479
+ else:
480
+ time_token = time_token.unsqueeze(dim=1)
481
+ if self.cls_embed:
482
+ cls_token = cls_token.unsqueeze(dim=1)
483
+ time_token = torch.cat([time_token, cls_token], dim=1)
484
+ time_token = self.time_pe(time_token)
485
+ x = torch.cat((time_token, x), dim=1)
486
+ if x_mask is not None:
487
+ x_mask = torch.cat([
488
+ torch.ones(B, time_token.shape[1],
489
+ device=x_mask.device).bool(), x_mask
490
+ ],
491
+ dim=1)
492
+ time_token = None
493
+
494
+ skips = []
495
+ for blk in self.in_blocks:
496
+ x = blk(
497
+ x=x,
498
+ time_aligned_context=time_aligned_context,
499
+ time_token=time_token,
500
+ time_ada=time_ada,
501
+ skip=None,
502
+ context=context_token,
503
+ x_mask=x_mask,
504
+ context_mask=context_mask,
505
+ extras=self.extras
506
+ )
507
+
508
+ if self.use_skip:
509
+ skips.append(x)
510
+
511
+ x = self.mid_block(
512
+ x=x,
513
+ time_aligned_context=time_aligned_context,
514
+ time_token=time_token,
515
+ time_ada=time_ada,
516
+ skip=None,
517
+ context=context_token,
518
+ x_mask=x_mask,
519
+ context_mask=context_mask,
520
+ extras=self.extras
521
+ )
522
+
523
+ for blk in self.out_blocks:
524
+ if self.use_skip:
525
+ skip = skips.pop()
526
+ if controlnet_skips:
527
+ # add to skip like u-net controlnet
528
+ skip = skip + controlnet_skips.pop()
529
+ else:
530
+ skip = None
531
+ if controlnet_skips:
532
+ # directly add to x
533
+ x = x + controlnet_skips.pop()
534
+
535
+ x = blk(
536
+ x=x,
537
+ time_aligned_context=time_aligned_context,
538
+ time_token=time_token,
539
+ time_ada=time_ada,
540
+ skip=skip,
541
+ context=context_token,
542
+ x_mask=x_mask,
543
+ context_mask=context_mask,
544
+ extras=self.extras
545
+ )
546
+
547
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
548
+
549
+ return x
models/dit/mask_dit.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from .modules import (
8
+ film_modulate,
9
+ unpatchify,
10
+ PatchEmbed,
11
+ PE_wrapper,
12
+ TimestepEmbedder,
13
+ FeedForward,
14
+ RMSNorm,
15
+ )
16
+ from .span_mask import compute_mask_indices
17
+ from .attention import Attention
18
+
19
+ logger = logging.Logger(__file__)
20
+
21
+
22
+ class AdaLN(nn.Module):
23
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
24
+ super().__init__()
25
+ self.ada_mode = ada_mode
26
+ self.scale_shift_table = None
27
+ if ada_mode == 'ada':
28
+ # move nn.silu outside
29
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
30
+ elif ada_mode == 'ada_single':
31
+ # adaln used in pixel-art alpha
32
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
33
+ elif ada_mode in ['ada_solo', 'ada_sola_bias']:
34
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
35
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
36
+ self.scaling = alpha / r
37
+ if ada_mode == 'ada_sola_bias':
38
+ # take bias out for consistency
39
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
40
+ else:
41
+ raise NotImplementedError
42
+
43
+ def forward(self, time_token=None, time_ada=None):
44
+ if self.ada_mode == 'ada':
45
+ assert time_ada is None
46
+ B = time_token.shape[0]
47
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
48
+ elif self.ada_mode == 'ada_single':
49
+ B = time_ada.shape[0]
50
+ time_ada = time_ada.reshape(B, 6, -1)
51
+ time_ada = self.scale_shift_table[None] + time_ada
52
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
53
+ B = time_ada.shape[0]
54
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
55
+ time_ada = time_ada + time_ada_lora
56
+ time_ada = time_ada.reshape(B, 6, -1)
57
+ if self.scale_shift_table is not None:
58
+ time_ada = self.scale_shift_table[None] + time_ada
59
+ else:
60
+ raise NotImplementedError
61
+ return time_ada
62
+
63
+
64
+ class DiTBlock(nn.Module):
65
+ """
66
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
67
+ """
68
+ def __init__(
69
+ self,
70
+ dim,
71
+ context_dim=None,
72
+ num_heads=8,
73
+ mlp_ratio=4.,
74
+ qkv_bias=False,
75
+ qk_scale=None,
76
+ qk_norm=None,
77
+ act_layer='gelu',
78
+ norm_layer=nn.LayerNorm,
79
+ time_fusion='none',
80
+ ada_sola_rank=None,
81
+ ada_sola_alpha=None,
82
+ skip=False,
83
+ skip_norm=False,
84
+ rope_mode='none',
85
+ context_norm=False,
86
+ use_checkpoint=False
87
+ ):
88
+
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim=dim,
93
+ num_heads=num_heads,
94
+ qkv_bias=qkv_bias,
95
+ qk_scale=qk_scale,
96
+ qk_norm=qk_norm,
97
+ rope_mode=rope_mode
98
+ )
99
+
100
+ if context_dim is not None:
101
+ self.use_context = True
102
+ self.cross_attn = Attention(
103
+ dim=dim,
104
+ num_heads=num_heads,
105
+ context_dim=context_dim,
106
+ qkv_bias=qkv_bias,
107
+ qk_scale=qk_scale,
108
+ qk_norm=qk_norm,
109
+ rope_mode='none'
110
+ )
111
+ self.norm2 = norm_layer(dim)
112
+ if context_norm:
113
+ self.norm_context = norm_layer(context_dim)
114
+ else:
115
+ self.norm_context = nn.Identity()
116
+ else:
117
+ self.use_context = False
118
+
119
+ self.norm3 = norm_layer(dim)
120
+ self.mlp = FeedForward(
121
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
122
+ )
123
+
124
+ self.use_adanorm = True if time_fusion != 'token' else False
125
+ if self.use_adanorm:
126
+ self.adaln = AdaLN(
127
+ dim,
128
+ ada_mode=time_fusion,
129
+ r=ada_sola_rank,
130
+ alpha=ada_sola_alpha
131
+ )
132
+ if skip:
133
+ self.skip_norm = norm_layer(2 *
134
+ dim) if skip_norm else nn.Identity()
135
+ self.skip_linear = nn.Linear(2 * dim, dim)
136
+ else:
137
+ self.skip_linear = None
138
+
139
+ self.use_checkpoint = use_checkpoint
140
+
141
+ def forward(
142
+ self,
143
+ x,
144
+ time_token=None,
145
+ time_ada=None,
146
+ skip=None,
147
+ context=None,
148
+ x_mask=None,
149
+ context_mask=None,
150
+ extras=None
151
+ ):
152
+ if self.use_checkpoint:
153
+ return checkpoint(
154
+ self._forward,
155
+ x,
156
+ time_token,
157
+ time_ada,
158
+ skip,
159
+ context,
160
+ x_mask,
161
+ context_mask,
162
+ extras,
163
+ use_reentrant=False
164
+ )
165
+ else:
166
+ return self._forward(
167
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
168
+ extras
169
+ )
170
+
171
+ def _forward(
172
+ self,
173
+ x,
174
+ time_token=None,
175
+ time_ada=None,
176
+ skip=None,
177
+ context=None,
178
+ x_mask=None,
179
+ context_mask=None,
180
+ extras=None
181
+ ):
182
+ B, T, C = x.shape
183
+ if self.skip_linear is not None:
184
+ assert skip is not None
185
+ cat = torch.cat([x, skip], dim=-1)
186
+ cat = self.skip_norm(cat)
187
+ x = self.skip_linear(cat)
188
+
189
+ if self.use_adanorm:
190
+ time_ada = self.adaln(time_token, time_ada)
191
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
192
+ gate_mlp) = time_ada.chunk(6, dim=1)
193
+
194
+ # self attention
195
+ if self.use_adanorm:
196
+ x_norm = film_modulate(
197
+ self.norm1(x), shift=shift_msa, scale=scale_msa
198
+ )
199
+ x = x + (1-gate_msa) * self.attn(
200
+ x_norm, context=None, context_mask=x_mask, extras=extras
201
+ )
202
+ else:
203
+ x = x + self.attn(
204
+ self.norm1(x),
205
+ context=None,
206
+ context_mask=x_mask,
207
+ extras=extras
208
+ )
209
+
210
+ # cross attention
211
+ if self.use_context:
212
+ assert context is not None
213
+ x = x + self.cross_attn(
214
+ x=self.norm2(x),
215
+ context=self.norm_context(context),
216
+ context_mask=context_mask,
217
+ extras=extras
218
+ )
219
+
220
+ # mlp
221
+ if self.use_adanorm:
222
+ x_norm = film_modulate(
223
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
224
+ )
225
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
226
+ else:
227
+ x = x + self.mlp(self.norm3(x))
228
+
229
+ return x
230
+
231
+
232
+ class FinalBlock(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ patch_size,
237
+ in_chans,
238
+ img_size,
239
+ input_type='2d',
240
+ norm_layer=nn.LayerNorm,
241
+ use_conv=True,
242
+ use_adanorm=True
243
+ ):
244
+ super().__init__()
245
+ self.in_chans = in_chans
246
+ self.img_size = img_size
247
+ self.input_type = input_type
248
+
249
+ self.norm = norm_layer(embed_dim)
250
+ if use_adanorm:
251
+ self.use_adanorm = True
252
+ else:
253
+ self.use_adanorm = False
254
+
255
+ if input_type == '2d':
256
+ self.patch_dim = patch_size**2 * in_chans
257
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
258
+ if use_conv:
259
+ self.final_layer = nn.Conv2d(
260
+ self.in_chans, self.in_chans, 3, padding=1
261
+ )
262
+ else:
263
+ self.final_layer = nn.Identity()
264
+
265
+ elif input_type == '1d':
266
+ self.patch_dim = patch_size * in_chans
267
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
268
+ if use_conv:
269
+ self.final_layer = nn.Conv1d(
270
+ self.in_chans, self.in_chans, 3, padding=1
271
+ )
272
+ else:
273
+ self.final_layer = nn.Identity()
274
+
275
+ def forward(self, x, time_ada=None, extras=0):
276
+ B, T, C = x.shape
277
+ x = x[:, extras:, :]
278
+ # only handle generation target
279
+ if self.use_adanorm:
280
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
281
+ x = film_modulate(self.norm(x), shift, scale)
282
+ else:
283
+ x = self.norm(x)
284
+ x = self.linear(x)
285
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
286
+ x = self.final_layer(x)
287
+ return x
288
+
289
+
290
+ class UDiT(nn.Module):
291
+ def __init__(
292
+ self,
293
+ img_size=224,
294
+ patch_size=16,
295
+ in_chans=3,
296
+ input_type='2d',
297
+ out_chans=None,
298
+ embed_dim=768,
299
+ depth=12,
300
+ num_heads=12,
301
+ mlp_ratio=4.,
302
+ qkv_bias=False,
303
+ qk_scale=None,
304
+ qk_norm=None,
305
+ act_layer='gelu',
306
+ norm_layer='layernorm',
307
+ context_norm=False,
308
+ use_checkpoint=False,
309
+ # time fusion ada or token
310
+ time_fusion='token',
311
+ ada_sola_rank=None,
312
+ ada_sola_alpha=None,
313
+ cls_dim=None,
314
+ # max length is only used for concat
315
+ context_dim=768,
316
+ context_fusion='concat',
317
+ context_max_length=128,
318
+ context_pe_method='sinu',
319
+ pe_method='abs',
320
+ rope_mode='none',
321
+ use_conv=True,
322
+ skip=True,
323
+ skip_norm=True
324
+ ):
325
+ super().__init__()
326
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
327
+
328
+ # input
329
+ self.in_chans = in_chans
330
+ self.input_type = input_type
331
+ if self.input_type == '2d':
332
+ num_patches = (img_size[0] //
333
+ patch_size) * (img_size[1] // patch_size)
334
+ elif self.input_type == '1d':
335
+ num_patches = img_size // patch_size
336
+ self.patch_embed = PatchEmbed(
337
+ patch_size=patch_size,
338
+ in_chans=in_chans,
339
+ embed_dim=embed_dim,
340
+ input_type=input_type
341
+ )
342
+ out_chans = in_chans if out_chans is None else out_chans
343
+ self.out_chans = out_chans
344
+
345
+ # position embedding
346
+ self.rope = rope_mode
347
+ self.x_pe = PE_wrapper(
348
+ dim=embed_dim, method=pe_method, length=num_patches
349
+ )
350
+
351
+ logger.info(f'x position embedding: {pe_method}')
352
+ logger.info(f'rope mode: {self.rope}')
353
+
354
+ # time embed
355
+ self.time_embed = TimestepEmbedder(embed_dim)
356
+ self.time_fusion = time_fusion
357
+ self.use_adanorm = False
358
+
359
+ # cls embed
360
+ if cls_dim is not None:
361
+ self.cls_embed = nn.Sequential(
362
+ nn.Linear(cls_dim, embed_dim, bias=True),
363
+ nn.SiLU(),
364
+ nn.Linear(embed_dim, embed_dim, bias=True),
365
+ )
366
+ else:
367
+ self.cls_embed = None
368
+
369
+ # time fusion
370
+ if time_fusion == 'token':
371
+ # put token at the beginning of sequence
372
+ self.extras = 2 if self.cls_embed else 1
373
+ self.time_pe = PE_wrapper(
374
+ dim=embed_dim, method='abs', length=self.extras
375
+ )
376
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
377
+ self.use_adanorm = True
378
+ # aviod repetitive silu for each adaln block
379
+ self.time_act = nn.SiLU()
380
+ self.extras = 0
381
+ self.time_ada_final = nn.Linear(
382
+ embed_dim, 2 * embed_dim, bias=True
383
+ )
384
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
385
+ # shared adaln
386
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
387
+ else:
388
+ self.time_ada = None
389
+ else:
390
+ raise NotImplementedError
391
+ logger.info(f'time fusion mode: {self.time_fusion}')
392
+
393
+ # context
394
+ # use a simple projection
395
+ self.use_context = False
396
+ self.context_cross = False
397
+ self.context_max_length = context_max_length
398
+ self.context_fusion = 'none'
399
+ if context_dim is not None:
400
+ self.use_context = True
401
+ self.context_embed = nn.Sequential(
402
+ nn.Linear(context_dim, embed_dim, bias=True),
403
+ nn.SiLU(),
404
+ nn.Linear(embed_dim, embed_dim, bias=True),
405
+ )
406
+ self.context_fusion = context_fusion
407
+ if context_fusion == 'concat' or context_fusion == 'joint':
408
+ self.extras += context_max_length
409
+ self.context_pe = PE_wrapper(
410
+ dim=embed_dim,
411
+ method=context_pe_method,
412
+ length=context_max_length
413
+ )
414
+ # no cross attention layers
415
+ context_dim = None
416
+ elif context_fusion == 'cross':
417
+ self.context_pe = PE_wrapper(
418
+ dim=embed_dim,
419
+ method=context_pe_method,
420
+ length=context_max_length
421
+ )
422
+ self.context_cross = True
423
+ context_dim = embed_dim
424
+ else:
425
+ raise NotImplementedError
426
+ logger.info(f'context fusion mode: {context_fusion}')
427
+ logger.info(f'context position embedding: {context_pe_method}')
428
+
429
+ self.use_skip = skip
430
+
431
+ # norm layers
432
+ if norm_layer == 'layernorm':
433
+ norm_layer = nn.LayerNorm
434
+ elif norm_layer == 'rmsnorm':
435
+ norm_layer = RMSNorm
436
+ else:
437
+ raise NotImplementedError
438
+
439
+ logger.info(f'use long skip connection: {skip}')
440
+ self.in_blocks = nn.ModuleList([
441
+ DiTBlock(
442
+ dim=embed_dim,
443
+ context_dim=context_dim,
444
+ num_heads=num_heads,
445
+ mlp_ratio=mlp_ratio,
446
+ qkv_bias=qkv_bias,
447
+ qk_scale=qk_scale,
448
+ qk_norm=qk_norm,
449
+ act_layer=act_layer,
450
+ norm_layer=norm_layer,
451
+ time_fusion=time_fusion,
452
+ ada_sola_rank=ada_sola_rank,
453
+ ada_sola_alpha=ada_sola_alpha,
454
+ skip=False,
455
+ skip_norm=False,
456
+ rope_mode=self.rope,
457
+ context_norm=context_norm,
458
+ use_checkpoint=use_checkpoint
459
+ ) for _ in range(depth // 2)
460
+ ])
461
+
462
+ self.mid_block = DiTBlock(
463
+ dim=embed_dim,
464
+ context_dim=context_dim,
465
+ num_heads=num_heads,
466
+ mlp_ratio=mlp_ratio,
467
+ qkv_bias=qkv_bias,
468
+ qk_scale=qk_scale,
469
+ qk_norm=qk_norm,
470
+ act_layer=act_layer,
471
+ norm_layer=norm_layer,
472
+ time_fusion=time_fusion,
473
+ ada_sola_rank=ada_sola_rank,
474
+ ada_sola_alpha=ada_sola_alpha,
475
+ skip=False,
476
+ skip_norm=False,
477
+ rope_mode=self.rope,
478
+ context_norm=context_norm,
479
+ use_checkpoint=use_checkpoint
480
+ )
481
+
482
+ self.out_blocks = nn.ModuleList([
483
+ DiTBlock(
484
+ dim=embed_dim,
485
+ context_dim=context_dim,
486
+ num_heads=num_heads,
487
+ mlp_ratio=mlp_ratio,
488
+ qkv_bias=qkv_bias,
489
+ qk_scale=qk_scale,
490
+ qk_norm=qk_norm,
491
+ act_layer=act_layer,
492
+ norm_layer=norm_layer,
493
+ time_fusion=time_fusion,
494
+ ada_sola_rank=ada_sola_rank,
495
+ ada_sola_alpha=ada_sola_alpha,
496
+ skip=skip,
497
+ skip_norm=skip_norm,
498
+ rope_mode=self.rope,
499
+ context_norm=context_norm,
500
+ use_checkpoint=use_checkpoint
501
+ ) for _ in range(depth // 2)
502
+ ])
503
+
504
+ # FinalLayer block
505
+ self.use_conv = use_conv
506
+ self.final_block = FinalBlock(
507
+ embed_dim=embed_dim,
508
+ patch_size=patch_size,
509
+ img_size=img_size,
510
+ in_chans=out_chans,
511
+ input_type=input_type,
512
+ norm_layer=norm_layer,
513
+ use_conv=use_conv,
514
+ use_adanorm=self.use_adanorm
515
+ )
516
+ self.initialize_weights()
517
+
518
+ def _init_ada(self):
519
+ if self.time_fusion == 'ada':
520
+ nn.init.constant_(self.time_ada_final.weight, 0)
521
+ nn.init.constant_(self.time_ada_final.bias, 0)
522
+ for block in self.in_blocks:
523
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
524
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
525
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
526
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
527
+ for block in self.out_blocks:
528
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
529
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
530
+ elif self.time_fusion == 'ada_single':
531
+ nn.init.constant_(self.time_ada.weight, 0)
532
+ nn.init.constant_(self.time_ada.bias, 0)
533
+ nn.init.constant_(self.time_ada_final.weight, 0)
534
+ nn.init.constant_(self.time_ada_final.bias, 0)
535
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
536
+ nn.init.constant_(self.time_ada.weight, 0)
537
+ nn.init.constant_(self.time_ada.bias, 0)
538
+ nn.init.constant_(self.time_ada_final.weight, 0)
539
+ nn.init.constant_(self.time_ada_final.bias, 0)
540
+ for block in self.in_blocks:
541
+ nn.init.kaiming_uniform_(
542
+ block.adaln.lora_a.weight, a=math.sqrt(5)
543
+ )
544
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
545
+ nn.init.kaiming_uniform_(
546
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
547
+ )
548
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
549
+ for block in self.out_blocks:
550
+ nn.init.kaiming_uniform_(
551
+ block.adaln.lora_a.weight, a=math.sqrt(5)
552
+ )
553
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
554
+
555
+ def initialize_weights(self):
556
+ # Basic init for all layers
557
+ def _basic_init(module):
558
+ if isinstance(module, nn.Linear):
559
+ torch.nn.init.xavier_uniform_(module.weight)
560
+ if module.bias is not None:
561
+ nn.init.constant_(module.bias, 0)
562
+
563
+ self.apply(_basic_init)
564
+
565
+ # init patch Conv like Linear
566
+ w = self.patch_embed.proj.weight.data
567
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
568
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
569
+
570
+ # Zero-out AdaLN
571
+ if self.use_adanorm:
572
+ self._init_ada()
573
+
574
+ # Zero-out Cross Attention
575
+ if self.context_cross:
576
+ for block in self.in_blocks:
577
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
578
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
579
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
580
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
581
+ for block in self.out_blocks:
582
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
583
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
584
+
585
+ # Zero-out cls embedding
586
+ if self.cls_embed:
587
+ if self.use_adanorm:
588
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
589
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
590
+
591
+ # Zero-out Output
592
+ # might not zero-out this when using v-prediction
593
+ # it could be good when using noise-prediction
594
+ # nn.init.constant_(self.final_block.linear.weight, 0)
595
+ # nn.init.constant_(self.final_block.linear.bias, 0)
596
+ # if self.use_conv:
597
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
598
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
599
+
600
+ # init out Conv
601
+ if self.use_conv:
602
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
603
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
604
+
605
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
606
+ assert context.shape[-2] == self.context_max_length
607
+ # Check if either x_mask or context_mask is provided
608
+ B = x.shape[0]
609
+ # Create default masks if they are not provided
610
+ if x_mask is None:
611
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
612
+ if context_mask is None:
613
+ context_mask = torch.ones(
614
+ B, context.shape[-2], device=context.device
615
+ ).bool()
616
+ # Concatenate the masks along the second dimension (dim=1)
617
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
618
+ # Concatenate context and x along the second dimension (dim=1)
619
+ x = torch.cat((context, x), dim=1)
620
+ return x, x_mask
621
+
622
+ def forward(
623
+ self,
624
+ x,
625
+ timesteps,
626
+ context,
627
+ x_mask=None,
628
+ context_mask=None,
629
+ cls_token=None,
630
+ controlnet_skips=None,
631
+ ):
632
+ # make it compatible with int time step during inference
633
+ if timesteps.dim() == 0:
634
+ timesteps = timesteps.expand(x.shape[0]
635
+ ).to(x.device, dtype=torch.long)
636
+
637
+ x = self.patch_embed(x)
638
+ x = self.x_pe(x)
639
+
640
+ B, L, D = x.shape
641
+
642
+ if self.use_context:
643
+ context_token = self.context_embed(context)
644
+ context_token = self.context_pe(context_token)
645
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
646
+ x, x_mask = self._concat_x_context(
647
+ x=x,
648
+ context=context_token,
649
+ x_mask=x_mask,
650
+ context_mask=context_mask
651
+ )
652
+ context_token, context_mask = None, None
653
+ else:
654
+ context_token, context_mask = None, None
655
+
656
+ time_token = self.time_embed(timesteps)
657
+ if self.cls_embed:
658
+ cls_token = self.cls_embed(cls_token)
659
+ time_ada = None
660
+ time_ada_final = None
661
+ if self.use_adanorm:
662
+ if self.cls_embed:
663
+ time_token = time_token + cls_token
664
+ time_token = self.time_act(time_token)
665
+ time_ada_final = self.time_ada_final(time_token)
666
+ if self.time_ada is not None:
667
+ time_ada = self.time_ada(time_token)
668
+ else:
669
+ time_token = time_token.unsqueeze(dim=1)
670
+ if self.cls_embed:
671
+ cls_token = cls_token.unsqueeze(dim=1)
672
+ time_token = torch.cat([time_token, cls_token], dim=1)
673
+ time_token = self.time_pe(time_token)
674
+ x = torch.cat((time_token, x), dim=1)
675
+ if x_mask is not None:
676
+ x_mask = torch.cat([
677
+ torch.ones(B, time_token.shape[1],
678
+ device=x_mask.device).bool(), x_mask
679
+ ],
680
+ dim=1)
681
+ time_token = None
682
+
683
+ skips = []
684
+ for blk in self.in_blocks:
685
+ x = blk(
686
+ x=x,
687
+ time_token=time_token,
688
+ time_ada=time_ada,
689
+ skip=None,
690
+ context=context_token,
691
+ x_mask=x_mask,
692
+ context_mask=context_mask,
693
+ extras=self.extras
694
+ )
695
+ if self.use_skip:
696
+ skips.append(x)
697
+
698
+ x = self.mid_block(
699
+ x=x,
700
+ time_token=time_token,
701
+ time_ada=time_ada,
702
+ skip=None,
703
+ context=context_token,
704
+ x_mask=x_mask,
705
+ context_mask=context_mask,
706
+ extras=self.extras
707
+ )
708
+ for blk in self.out_blocks:
709
+ if self.use_skip:
710
+ skip = skips.pop()
711
+ if controlnet_skips:
712
+ # add to skip like u-net controlnet
713
+ skip = skip + controlnet_skips.pop()
714
+ else:
715
+ skip = None
716
+ if controlnet_skips:
717
+ # directly add to x
718
+ x = x + controlnet_skips.pop()
719
+
720
+ x = blk(
721
+ x=x,
722
+ time_token=time_token,
723
+ time_ada=time_ada,
724
+ skip=skip,
725
+ context=context_token,
726
+ x_mask=x_mask,
727
+ context_mask=context_mask,
728
+ extras=self.extras
729
+ )
730
+
731
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
732
+
733
+ return x
734
+
735
+
736
+ class MaskDiT(nn.Module):
737
+ def __init__(
738
+ self,
739
+ model: UDiT,
740
+ mae=False,
741
+ mae_prob=0.5,
742
+ mask_ratio=[0.25, 1.0],
743
+ mask_span=10,
744
+ ):
745
+ super().__init__()
746
+ self.model = model
747
+ self.mae = mae
748
+ if self.mae:
749
+ out_channel = model.out_chans
750
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
751
+ self.mae_prob = mae_prob
752
+ self.mask_ratio = mask_ratio
753
+ self.mask_span = mask_span
754
+
755
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
756
+ B, D, L = gt.shape
757
+ if mae_mask_infer is None:
758
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
759
+ mask_ratios = mask_ratios.cpu().numpy()
760
+ mask = compute_mask_indices(
761
+ shape=[B, L],
762
+ padding_mask=None,
763
+ mask_prob=mask_ratios,
764
+ mask_length=self.mask_span,
765
+ mask_type="static",
766
+ mask_other=0.0,
767
+ min_masks=1,
768
+ no_overlap=False,
769
+ min_space=0,
770
+ )
771
+ mask = mask.unsqueeze(1).expand_as(gt)
772
+ else:
773
+ mask = mae_mask_infer
774
+ mask = mask.expand_as(gt)
775
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
776
+ return gt, mask.type_as(gt)
777
+
778
+ def forward(
779
+ self,
780
+ x,
781
+ timesteps,
782
+ context,
783
+ x_mask=None,
784
+ context_mask=None,
785
+ cls_token=None,
786
+ gt=None,
787
+ mae_mask_infer=None,
788
+ forward_model=True
789
+ ):
790
+ # todo: handle controlnet inside
791
+ mae_mask = torch.ones_like(x)
792
+ if self.mae:
793
+ if gt is not None:
794
+ B, D, L = gt.shape
795
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio
796
+ ).to(gt.device)
797
+ gt, mae_mask = self.random_masking(
798
+ gt, mask_ratios, mae_mask_infer
799
+ )
800
+ # apply mae only to the selected batches
801
+ if mae_mask_infer is None:
802
+ # determine mae batch
803
+ mae_batch = torch.rand(B) < self.mae_prob
804
+ gt[~mae_batch] = self.mask_embed.view(
805
+ 1, D, 1
806
+ ).expand_as(gt)[~mae_batch]
807
+ mae_mask[~mae_batch] = 1.0
808
+ else:
809
+ B, D, L = x.shape
810
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
811
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
812
+
813
+ if forward_model:
814
+ x = self.model(
815
+ x=x,
816
+ timesteps=timesteps,
817
+ context=context,
818
+ x_mask=x_mask,
819
+ context_mask=context_mask,
820
+ cls_token=cls_token
821
+ )
822
+ # logger.info(mae_mask[:, 0, :].sum(dim=-1))
823
+ return x, mae_mask