from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # complex-valued RoPE table: freqs_cis[t, j] = exp(i * t * theta^(-2j / dim))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
    return freqs_cis


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    # x: (batch, seq, heads, head_dim); freqs_cis: (seq, head_dim // 2), complex
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
    x_ = x_ * freqs_cis[..., None, :]
    x_ = torch.view_as_real(x_).reshape(x.shape)
    return x_.type_as(x)


def get_timestep_embedding(
    timestep: torch.Tensor,
    embed_size: int,
) -> torch.Tensor:
    # sinusoidal diffusion-timestep embedding (timesteps are effectively scaled by 1000)
    assert embed_size % 2 == 0
    half = embed_size // 2
    freqs = 1000 * torch.exp(
        -torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(timestep.device)
    args = timestep[..., None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding.to(timestep.dtype)


class LowRankAdaLN(nn.Module):
    # adaLN-style modulation: RMS-normalize x, apply shift/scale, and return a tanh gate;
    # shift/scale/gate are refined from the conditioning embedding via low-rank residual projections
    def __init__(
        self,
        model_size: int,
        rank: int,
        eps: float
    ):
        super().__init__()
        self.eps = eps
        self.shift_down = nn.Linear(model_size, rank, bias=False)
        self.scale_down = nn.Linear(model_size, rank, bias=False)
        self.gate_down = nn.Linear(model_size, rank, bias=False)
        self.shift_up = nn.Linear(rank, model_size, bias=True)
        self.scale_up = nn.Linear(rank, model_size, bias=True)
        self.gate_up = nn.Linear(rank, model_size, bias=True)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = cond_embed.chunk(3, dim=-1)
        shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
        scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
        gate = self.gate_up(self.gate_down(F.silu(gate))) + gate
        x_dtype = x.dtype
        x = x.float()
        # weightless RMS norm followed by the adaptive shift/scale
        x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * (scale + 1) + shift
        gate = torch.tanh(gate)
        return x.to(x_dtype), gate


class RMSNorm(nn.Module):  # could also just use torch rmsnorm
    def __init__(
        self,
        model_size: int | Tuple[int, int],
        eps: float
    ):
        super().__init__()
        self.eps = eps
        if isinstance(model_size, int):
            model_size = (model_size,)
        self.weight = nn.Parameter(torch.ones(model_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * self.weight
        return x.to(x_dtype)


class SelfAttention(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)
        self.wo = nn.Linear(model_size, model_size, bias=False)
        self.gate = nn.Linear(model_size, model_size, bias=False)  # per-channel output gate
        assert model_size % num_heads == 0
        self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]
        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
        gate = self.gate(x)
        xq = self.q_norm(xq)
        xk = self.k_norm(xk)
        xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
        xk = apply_rotary_emb(xk, freqs_cis[:seq_len])
        if mask is not None:
            assert mask.ndim == 2  # (b, s)
            mask = mask[:, None, None]  # -> (b, 1, 1, s) for SDPA broadcasting
        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=self.is_causal
        ).transpose(1, 2)
        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)
        output = self.wo(output)
        return output


class JointAttention(nn.Module):
    # queries come from the latent stream; keys/values are the stream itself concatenated with
    # cached text, speaker, and latent-prefix states
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        norm_eps: float
    ):
        super().__init__()
        self.speaker_patch_size = speaker_patch_size
        self.num_heads = num_heads
        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)
        self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
        self.wv_text = nn.Linear(text_model_size, model_size, bias=False)
        self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
        self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
        self.wk_latent = nn.Linear(speaker_model_size, model_size, bias=False)
        self.wv_latent = nn.Linear(speaker_model_size, model_size, bias=False)
        assert model_size % num_heads == 0
        self.head_dim = model_size // num_heads
        self.q_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
        self.gate = nn.Linear(model_size, model_size, bias=False)
        self.wo = nn.Linear(model_size, model_size, bias=False)

    def _apply_rotary_half(self, y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
        # rotary embedding on the first half of the heads only; the other half stays position-free
        y1, y2 = y.chunk(2, dim=-2)
        y1 = apply_rotary_emb(y1, fc)
        return torch.cat([y1, y2], dim=-2)

    def forward(
        self,
        x: torch.Tensor,
        text_mask: torch.Tensor,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
        kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
        start_pos: int | None,
        kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None
    ) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]
        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xq = self.q_norm(xq)
        xk_self = self.k_norm(xk_self)
        gate = self.gate(x)
        if start_pos is None:
            start_pos = 0
        freqs_q = freqs_cis[start_pos : start_pos + seq_len]
        xq = self._apply_rotary_half(xq, freqs_q)
        xk_self = self._apply_rotary_half(xk_self, freqs_q)
        xk_text, xv_text = kv_cache_text
        xk_speaker, xv_speaker = kv_cache_speaker
        if kv_cache_latent is None or kv_cache_latent[0].shape[1] == 0:
            # no latent prefix: empty keys/values and an empty mask slot
            xk_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
            xv_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
            latent_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=x.device)
        else:
            xk_latent, xv_latent = kv_cache_latent
            # a cached latent-prefix patch is only visible if it starts before the current position
            latent_positions = torch.arange(xk_latent.shape[1], device=x.device, dtype=torch.long) * self.speaker_patch_size
            latent_mask = (latent_positions[None, :] < start_pos).expand(batch_size, xk_latent.shape[1])
        xk = torch.cat([xk_self, xk_latent, xk_text, xk_speaker], dim=1)
        xv = torch.cat([xv_self, xv_latent, xv_text, xv_speaker], dim=1)
        self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
        mask = torch.cat([self_mask, latent_mask, text_mask, speaker_mask], dim=1)
        mask = mask[:, None, None]
        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=False
        ).transpose(1, 2)
        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)
        output = self.wo(output)
        return output

    def get_kv_cache_text(self, text_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = text_state.shape[0]
        xk = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
        xv = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
        xk = self.k_norm(xk)
        return xk, xv

    def get_kv_cache_speaker(self, speaker_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = speaker_state.shape[0]
        xk = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
        xv = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
        xk = self.k_norm(xk)
        return xk, xv

    def get_kv_cache_latent(self, latent_state: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = latent_state.shape[0]
        seq_len = latent_state.shape[1]
        xk = self.wk_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
        xv = self.wv_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
        xk = self.k_norm(xk)
        xk = self._apply_rotary_half(xk, freqs_cis)
        return xk, xv


class MLP(nn.Module):
    def __init__(
        self,
        model_size: int,
        intermediate_size: int
    ):
        super().__init__()
        self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, model_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class EncoderTransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.attention = SelfAttention(
            model_size=model_size,
            num_heads=num_heads,
            is_causal=is_causal,
            norm_eps=norm_eps
        )
        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )
        self.attention_norm = RMSNorm(model_size, norm_eps)
        self.mlp_norm = RMSNorm(model_size, norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x


class TransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        adaln_rank: int,
    ):
        super().__init__()
        self.attention = JointAttention(
            model_size=model_size,
            num_heads=num_heads,
            text_model_size=text_model_size,
            speaker_model_size=speaker_model_size,
            speaker_patch_size=speaker_patch_size,
            norm_eps=norm_eps
        )
        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )
        self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
        self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
        text_mask: torch.Tensor,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
        kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
        start_pos: int | None,
        kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None,
    ) -> torch.Tensor:
        x_norm, attention_gate = self.attention_adaln(x, cond_embed)
        x = x + attention_gate * self.attention(x_norm, text_mask, speaker_mask, freqs_cis, kv_cache_text, kv_cache_speaker, start_pos, kv_cache_latent)
        x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
        x = x + mlp_gate * self.mlp(x_norm)
        return x


class TextEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
    ):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, model_size)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=False,
                norm_eps=norm_eps
            )
            self.blocks.append(block)
        self.head_dim = model_size // num_heads

    def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.text_embedding(input_ids)
        freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device)  # could cache
        for block in self.blocks:
            x = block(x, mask, freqs_cis)
        return x


class SpeakerEncoder(nn.Module):
    def __init__(
        self,
        latent_size: int,
        patch_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=True,
                norm_eps=norm_eps
            )
            self.blocks.append(block)
        self.head_dim = model_size // num_heads

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        # group patch_size consecutive latent frames into one token
        x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)
        x = self.in_proj(x)
        x = x / 6.
        # the division by 6 above helped with initial activation dynamics in early ablations,
        # could also bake it into in_proj
        freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)  # could cache
        for block in self.blocks:
            x = block(x, None, freqs_cis)
        return x


class EchoDiT(nn.Module):
    def __init__(
        self,
        latent_size: int,
        #
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        #
        text_vocab_size: int,
        text_model_size: int,
        text_num_layers: int,
        text_num_heads: int,
        text_intermediate_size: int,
        #
        speaker_patch_size: int,
        speaker_model_size: int,
        speaker_num_layers: int,
        speaker_num_heads: int,
        speaker_intermediate_size: int,
        #
        timestep_embed_size: int,
        adaln_rank: int,
    ):
        super().__init__()
        self.speaker_patch_size = speaker_patch_size
        self.timestep_embed_size = timestep_embed_size
        self.text_encoder = TextEncoder(
            vocab_size=text_vocab_size,
            model_size=text_model_size,
            num_layers=text_num_layers,
            num_heads=text_num_heads,
            intermediate_size=text_intermediate_size,
            norm_eps=norm_eps,
        )
        self.speaker_encoder = SpeakerEncoder(
            latent_size=latent_size,
            patch_size=speaker_patch_size,
            model_size=speaker_model_size,
            num_layers=speaker_num_layers,
            num_heads=speaker_num_heads,
            intermediate_size=speaker_intermediate_size,
            norm_eps=norm_eps,
        )
        self.latent_encoder = SpeakerEncoder(
            latent_size=latent_size,
            patch_size=speaker_patch_size,
            model_size=speaker_model_size,
            num_layers=speaker_num_layers,
            num_heads=speaker_num_heads,
            intermediate_size=speaker_intermediate_size,
            norm_eps=norm_eps,
        )
        self.text_norm = RMSNorm(text_model_size, norm_eps)
        self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
        self.latent_norm = RMSNorm(speaker_model_size, norm_eps)
        self.cond_module = nn.Sequential(
            nn.Linear(timestep_embed_size, model_size, bias=False),
            nn.SiLU(),
            nn.Linear(model_size, model_size, bias=False),
            nn.SiLU(),
            nn.Linear(model_size, model_size * 3, bias=False),
        )
        self.in_proj = nn.Linear(latent_size, model_size, bias=True)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = TransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                norm_eps=norm_eps,
                text_model_size=text_model_size,
                speaker_model_size=speaker_model_size,
                speaker_patch_size=speaker_patch_size,
                adaln_rank=adaln_rank,
            )
            self.blocks.append(block)
        self.out_norm = RMSNorm(model_size, norm_eps)
        self.out_proj = nn.Linear(model_size, latent_size, bias=True)
        self.head_dim = model_size // num_heads

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_mask: torch.Tensor,
        speaker_mask: torch.Tensor,
        kv_cache_text: List[Tuple[torch.Tensor, torch.Tensor]],
        kv_cache_speaker: List[Tuple[torch.Tensor, torch.Tensor]],
        start_pos: int | None = None,
        kv_cache_latent: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        if start_pos is None:
            start_pos = 0
        max_pos = start_pos + x.shape[1]
        freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(x.device)  # could cache
        speaker_mask = speaker_mask[..., ::self.speaker_patch_size]  # per-frame mask -> one entry per patch
        cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))
        cond_embed = cond_embed[:, None]  # broadcast over the sequence dimension
        x = self.in_proj(x)
        for i, block in enumerate(self.blocks):
            x = block(
                x=x,
                cond_embed=cond_embed,
                text_mask=text_mask,
                speaker_mask=speaker_mask,
                freqs_cis=freqs_cis,
                kv_cache_text=kv_cache_text[i],
                kv_cache_speaker=kv_cache_speaker[i],
                start_pos=start_pos,
                kv_cache_latent=kv_cache_latent[i] if kv_cache_latent is not None else None,
            )
        x = self.out_norm(x)
        x = self.out_proj(x)
        return x.float()

    def get_kv_cache_text(
        self,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor | None,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        text_state = self.text_encoder(text_input_ids, text_mask)
        text_state = self.text_norm(text_state)
        return [block.attention.get_kv_cache_text(text_state) for block in self.blocks]

    def get_kv_cache_speaker(
        self,
        speaker_latent: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        speaker_state = self.speaker_encoder(speaker_latent)
        speaker_state = self.speaker_norm(speaker_state)
        return [block.attention.get_kv_cache_speaker(speaker_state) for block in self.blocks]

    def get_kv_cache_latent(
        self,
        prefix_latent: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        latent_state = self.latent_encoder(prefix_latent)
        latent_state = self.latent_norm(latent_state)
        seq_len = latent_state.shape[1]
        max_pos = seq_len * self.speaker_patch_size
        freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(latent_state.device)  # could cache
        positions = torch.arange(seq_len, device=latent_state.device) * self.speaker_patch_size
        freqs_latent = freqs_cis[positions]  # each prefix patch sits at the position of its first frame
        return [block.attention.get_kv_cache_latent(latent_state, freqs_latent) for block in self.blocks]

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype
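

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module). The hyperparameters, sequence lengths,
    # and the two-chunk "continuation" below are arbitrary placeholders chosen only to exercise the
    # shapes and the KV-cache API; they are not the real configuration or a prescribed sampling loop.
    torch.manual_seed(0)
    model = EchoDiT(
        latent_size=8,
        model_size=64, num_layers=2, num_heads=4, intermediate_size=128, norm_eps=1e-5,
        text_vocab_size=256, text_model_size=32, text_num_layers=1, text_num_heads=4,
        text_intermediate_size=64,
        speaker_patch_size=4, speaker_model_size=32, speaker_num_layers=1, speaker_num_heads=4,
        speaker_intermediate_size=64,
        timestep_embed_size=32, adaln_rank=8,
    )
    batch_size, text_len, speaker_frames = 2, 12, 16
    text_ids = torch.randint(0, 256, (batch_size, text_len))
    text_mask = torch.ones(batch_size, text_len, dtype=torch.bool)
    speaker_latent = torch.randn(batch_size, speaker_frames, 8)
    speaker_mask = torch.ones(batch_size, speaker_frames, dtype=torch.bool)
    with torch.no_grad():
        # conditioning is encoded once and reused as per-layer KV caches
        kv_text = model.get_kv_cache_text(text_ids, text_mask)
        kv_speaker = model.get_kv_cache_speaker(speaker_latent)
        x = torch.randn(batch_size, 10, 8)  # noisy latent frames being denoised
        t = torch.rand(batch_size)          # diffusion timesteps, assumed here to lie in [0, 1]
        out = model(x, t, text_mask, speaker_mask, kv_text, kv_speaker)
        print(out.shape)  # (2, 10, 8)
        # continuation chunk: cache already generated frames as a latent prefix (the previous output
        # is only a stand-in for clean latents here) and offset positions via start_pos
        prefix_latent = out[:, :8]  # prefix length must be a multiple of speaker_patch_size
        kv_latent = model.get_kv_cache_latent(prefix_latent)
        x_next = torch.randn(batch_size, 6, 8)
        out_next = model(x_next, t, text_mask, speaker_mask, kv_text, kv_speaker,
                         start_pos=8, kv_cache_latent=kv_latent)
        print(out_next.shape)  # (2, 6, 8)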