| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
|
| | from ovi.modules.attention import flash_attention
|
| | from ovi.distributed_comms.communications import all_gather, all_to_all_4D
|
| | from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state
|
| |
|
class FusionModel(nn.Module):
    """Pair a video ``WanModel`` and an audio ``WanModel`` block by block.

    When both configs are given, every cross-attention module of both
    sub-models receives extra "fusion" projections so that each modality can
    attend to the other one in addition to its own conditioning context.
    When only one config is given, the model degrades to a thin wrapper
    around the single available sub-model.
    """

    def __init__(self, video_config=None, audio_config=None):
        """Build the sub-models and, if both exist, the fusion projections.

        Args:
            video_config: keyword arguments for the video ``WanModel``, or
                ``None`` to build an audio-only model.
            audio_config: keyword arguments for the audio ``WanModel``, or
                ``None`` to build a video-only model.
        """
        super().__init__()
        has_video = True
        has_audio = True
        if video_config is not None:
            self.video_model = WanModel(**video_config)
        else:
            has_video = False
            self.video_model = None
            print("Warning: No video model is provided!")

        if audio_config is not None:
            self.audio_model = WanModel(**audio_config)
        else:
            has_audio = False
            self.audio_model = None
            print("Warning: No audio model is provided!")

        if has_video and has_audio:
            # Blocks are fused pairwise, so both stacks must be equally deep.
            assert len(self.video_model.blocks) == len(self.audio_model.blocks)
            self.num_blocks = len(self.video_model.blocks)

            self.use_sp = get_sequence_parallel_state()
            if self.use_sp:
                self.sp_size = nccl_info.sp_size
                self.sp_rank = nccl_info.rank_within_group
            self.inject_cross_attention_kv_projections()

        self.init_weights()

    def inject_cross_attention_kv_projections(self):
        """Attach the fusion K/V projections and norms to every cross-attention.

        For each block of the video model and then of the audio model, adds
        ``k_fusion``, ``v_fusion``, ``pre_attn_norm_fusion`` and
        ``norm_k_fusion`` to ``block.cross_attn``. Requires both sub-models.
        """
        # Video blocks first, then audio blocks: keeps the module creation
        # order (and therefore the RNG consumption order) stable.
        for model in (self.video_model, self.audio_model):
            for block in model.blocks:
                cross_attn = block.cross_attn
                cross_attn.k_fusion = nn.Linear(block.dim, block.dim)
                cross_attn.v_fusion = nn.Linear(block.dim, block.dim)
                cross_attn.pre_attn_norm_fusion = WanLayerNorm(block.dim, elementwise_affine=True)
                cross_attn.norm_k_fusion = (
                    WanRMSNorm(block.dim, eps=1e-6) if block.qk_norm else nn.Identity()
                )

    def merge_kwargs(self, vid_kwargs, audio_kwargs):
        """Merge per-modality block kwargs into one prefixed dictionary.

        Keys of ``vid_kwargs`` are prefixed with ``vid_`` and keys of
        ``audio_kwargs`` with ``audio_``. The keys expected in each input are:
        ``e``, ``seq_lens``, ``grid_sizes``, ``freqs``, ``context`` and
        ``context_lens``, which after prefixing match the parameter names of
        ``single_fusion_block_forward``.

        Returns:
            dict: the merged mapping (video entries first, then audio).
        """
        merged_kwargs = {f"vid_{key}": value for key, value in vid_kwargs.items()}
        merged_kwargs.update(
            {f"audio_{key}": value for key, value in audio_kwargs.items()}
        )
        return merged_kwargs

    def single_fusion_cross_attention_forward(self,
                                              cross_attn_block,
                                              src_seq,
                                              src_grid_sizes,
                                              src_freqs,
                                              target_seq,
                                              target_seq_lens,
                                              target_grid_sizes,
                                              target_freqs,
                                              context,
                                              context_lens
                                              ):
        """Cross-attend ``src_seq`` to its context and to the other modality.

        The output is the sum of (a) attention over ``context``, (b) attention
        over the image keys/values when the block has a ``k_img`` attribute,
        and (c) attention over ``target_seq`` through the fusion projections,
        passed through the block's output projection ``o``.

        Args:
            cross_attn_block: cross-attention module carrying the injected
                fusion layers.
            src_seq: query-side sequence (already normalised by the caller).
            src_grid_sizes, src_freqs: RoPE arguments for the queries.
            target_seq: sequence of the other modality (key/value side).
            target_seq_lens: valid key lengths for the target attention.
            target_grid_sizes, target_freqs: RoPE arguments for the target keys.
            context, context_lens: conditioning context and its valid lengths.

        Returns:
            The projected attention output for ``src_seq``.
        """
        b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
        if hasattr(cross_attn_block, "k_img"):
            # Block variant that also produces image keys/values.
            q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
        else:
            q, k, v = cross_attn_block.qkv_fn(src_seq, context)
            k_img = v_img = None

        if self.use_sp:
            # Sequence parallel: queries are exchanged across ranks, while the
            # context keys/values are simply split along dim 2 and each rank
            # keeps its own chunk.
            q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
            k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
            v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
            if k_img is not None:
                k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
            if v_img is not None:
                v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]

        # Context attention uses the queries *before* RoPE is applied.
        x = flash_attention(q, k, v, k_lens=context_lens)

        if k_img is not None:
            img_x = flash_attention(q, k_img, v_img, k_lens=None)
            x = x + img_x

        # Keys/values of the other modality through the fusion projections.
        target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
        k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
        v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
        if self.use_sp:
            k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1)
            v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1)

        # RoPE is applied only for the cross-modal attention.
        q = rope_apply(q, src_grid_sizes, src_freqs)
        k_target = rope_apply(k_target, target_grid_sizes, target_freqs)

        target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)

        x = x + target_x
        if self.use_sp:
            # Undo the earlier exchange on the queries.
            x = all_to_all_4D(x, scatter_dim=1, gather_dim=2)

        # Merge the head and head_dim axes before the output projection.
        x = x.flatten(2)

        x = cross_attn_block.o(x)
        return x

    def single_fusion_cross_attention_ffn_forward(self,
                                                  attn_block,
                                                  src_seq,
                                                  src_grid_sizes,
                                                  src_freqs,
                                                  target_seq,
                                                  target_seq_lens,
                                                  target_grid_sizes,
                                                  target_freqs,
                                                  context,
                                                  context_lens,
                                                  src_e):
        """Run the fused cross-attention and the modulated FFN of one block.

        Args:
            attn_block: transformer block providing ``cross_attn``, ``norm3``,
                ``norm2`` and ``ffn``.
            src_seq: sequence being updated.
            src_e: the six modulation chunks of this block; indices 3, 4 and 5
                are used here as FFN shift, scale and gate respectively.
            (remaining arguments are forwarded unchanged to
            ``single_fusion_cross_attention_forward``.)

        Returns:
            The updated ``src_seq``.
        """
        cross_out = self.single_fusion_cross_attention_forward(
            attn_block.cross_attn,
            attn_block.norm3(src_seq),
            src_grid_sizes=src_grid_sizes,
            src_freqs=src_freqs,
            target_seq=target_seq,
            target_seq_lens=target_seq_lens,
            target_grid_sizes=target_grid_sizes,
            target_freqs=target_freqs,
            context=context,
            context_lens=context_lens
        )
        src_seq = src_seq + cross_out

        y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            src_seq = src_seq + y * src_e[5].squeeze(2)
        return src_seq

    def single_fusion_block_forward(self,
                                    vid_block,
                                    audio_block,
                                    vid,
                                    audio,
                                    vid_e,
                                    vid_seq_lens,
                                    vid_grid_sizes,
                                    vid_freqs,
                                    vid_context,
                                    vid_context_lens,
                                    audio_e,
                                    audio_seq_lens,
                                    audio_grid_sizes,
                                    audio_freqs,
                                    audio_context,
                                    audio_context_lens
                                    ):
        """Run one fused (video block, audio block) pair.

        Order of operations: audio self-attention, video self-attention,
        audio cross-attention + FFN (attending to the video), then video
        cross-attention + FFN (attending to the audio as it was *before* its
        own cross-attention step).

        ``vid_e`` / ``audio_e`` must be 4-D with size 6 along dim 2 and the
        same size along dim 1 as the matching sequence; ``audio_e`` must be
        bfloat16.

        Returns:
            tuple: ``(vid, audio)`` after the block.
        """
        # ---- audio modulation ----
        assert audio_e.dtype == torch.bfloat16
        assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
        assert audio_e[0].dtype == torch.bfloat16

        # ---- audio self-attention (chunks 0/1/2 = shift/scale/gate) ----
        audio_y = audio_block.self_attn(
            audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
            audio_freqs)
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio = audio + audio_y * audio_e[2].squeeze(2)

        # ---- video modulation ----
        assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)

        # ---- video self-attention ----
        vid_y = vid_block.self_attn(
            vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
            vid_freqs)

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid = vid + vid_y * vid_e[2].squeeze(2)

        # Keep the pre-cross-attention audio: the video branch attends to this
        # version so that both modalities see each other's same-stage state.
        og_audio = audio

        # ---- audio cross-attention + FFN (target: video) ----
        audio = self.single_fusion_cross_attention_ffn_forward(
            audio_block,
            audio,
            audio_grid_sizes,
            audio_freqs,
            vid,
            vid_seq_lens,
            vid_grid_sizes,
            vid_freqs,
            audio_context,
            audio_context_lens,
            audio_e
        )

        # Sanity check that the rebinding above produced a new tensor value.
        assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"

        # ---- video cross-attention + FFN (target: original audio) ----
        vid = self.single_fusion_cross_attention_ffn_forward(
            vid_block,
            vid,
            vid_grid_sizes,
            vid_freqs,
            og_audio,
            audio_seq_lens,
            audio_grid_sizes,
            audio_freqs,
            vid_context,
            vid_context_lens,
            vid_e
        )

        return vid, audio

    def forward(
        self,
        vid,
        audio,
        t,
        vid_context,
        audio_context,
        vid_seq_len,
        audio_seq_len,
        clip_fea=None,
        clip_fea_audio=None,
        y=None,
        first_frame_is_clean=False,
        slg_layer=False
    ):
        """Run the fused model, or a single sub-model if one input is absent.

        Args:
            vid: video input, or ``None`` / an iterable of ``None`` to run
                audio only.
            audio: audio input, or ``None`` / an iterable of ``None`` to run
                video only.
            t: timestep input forwarded to both sub-models.
            vid_context, audio_context: conditioning context per modality.
            vid_seq_len, audio_seq_len: sequence length per modality.
            clip_fea: must be ``None`` (asserted).
            clip_fea_audio: forwarded to the audio model as ``clip_fea``.
            y: must be ``None`` (asserted).
            first_frame_is_clean: forwarded to the video model.
            slg_layer: index of a fusion block to skip; skipping only happens
                for a value greater than 0, so the default ``False`` disables it.

        Returns:
            tuple: ``(vid_out, audio_out)``; the entry of a missing modality
            is ``None``.
        """
        assert clip_fea is None
        assert y is None

        # Audio-only path.
        if vid is None or all(x is None for x in vid):
            assert vid_context is None
            assert vid_seq_len is None
            assert self.audio_model is not None

            return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)

        # Video-only path.
        if audio is None or all(x is None for x in audio):
            assert clip_fea_audio is None
            assert audio_context is None
            assert audio_seq_len is None
            assert self.video_model is not None

            return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None

        vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
            x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
        )

        audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
            x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
        )

        kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)

        # ``gradient_checkpointing`` is never assigned by this class; fall
        # back to False so training works unless a caller opted in explicitly.
        use_checkpointing = self.training and getattr(self, "gradient_checkpointing", False)

        for i in range(self.num_blocks):
            # One fusion block = one audio block paired with one video block.
            if slg_layer > 0 and i == slg_layer:
                continue
            vid_block = self.video_model.blocks[i]
            audio_block = self.audio_model.blocks[i]
            vid, audio = gradient_checkpointing(
                enabled=use_checkpointing,
                module=self.single_fusion_block_forward,
                vid_block=vid_block,
                audio_block=audio_block,
                vid=vid,
                audio=audio,
                **kwargs
            )

        vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
        audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)

        return vid, audio

    def init_weights(self):
        """Initialise whichever sub-models exist and damp the fusion layers.

        After the video model's own initialisation, the weight of every
        ``nn.Linear`` whose module path contains ``"fusion"`` is divided by 10.
        """
        if self.audio_model is not None:
            self.audio_model.init_weights()

        if self.video_model is not None:
            self.video_model.init_weights()

            # NOTE(review): only the video model's fusion projections are
            # scaled down; the audio model's are left as initialised.
            # Confirm that this asymmetry is intended.
            for name, mod in self.video_model.named_modules():
                if "fusion" in name and isinstance(mod, nn.Linear):
                    with torch.no_grad():
                        mod.weight.div_(10.0)

    def set_rope_params(self):
        """Call ``set_rope_params`` on every sub-model that is present."""
        # Either sub-model may be None (single-modality construction).
        for model in (self.video_model, self.audio_model):
            if model is not None:
                model.set_rope_params()