| """ |
| 1B Parameter Text-to-Video Model (TTV-1B) |
| A production-ready diffusion-based text-to-video generation model |
| Architecture: DiT (Diffusion Transformer) with 3D spatiotemporal attention |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, List |
| import math |
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables.

    Produces cos/sin tables of shape (seq_len, dim) in which every
    frequency appears twice — once in the first half and once in the
    second half of the last dimension (half-split convention).
    """
    def __init__(self, dim: int, max_seq_len: int = 10000):
        super().__init__()
        self.max_seq_len = max_seq_len
        # Geometric frequency ladder: 10000^(-2k/dim) for k = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len: int, device: torch.device):
        """Return (cos, sin) tables, each of shape (seq_len, dim)."""
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Outer product position x frequency, then duplicate across halves.
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to ``x``.

    The cos/sin tables (from ``RotaryEmbedding``) duplicate each frequency
    across the two contiguous halves of the last dimension, i.e. they use the
    half-split ("rotate half") convention. The input must therefore be split
    into contiguous halves as well. The original even/odd interleaved split
    (``x[..., ::2]`` / ``x[..., 1::2]``) combined with a half-concat rotation
    paired each lane with the wrong frequency, breaking the relative-position
    property of RoPE.

    Args:
        x: Tensor whose last dimension (even-sized) is rotated.
        cos, sin: Tables broadcastable against ``x`` over the last dimension.

    Returns:
        Tensor of the same shape as ``x`` with rotary embedding applied.
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # rotate_half: (x1, x2) -> (-x2, x1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos) + (rotated * sin)
|
|
|
|
class SpatioTemporalAttention(nn.Module):
    """3D Attention mechanism for video data (Time x Height x Width).

    Full self-attention over all spatio-temporal tokens, with rotary
    position embeddings applied along the temporal axis only.
    """
    def __init__(self, dim: int, num_heads: int = 16, qkv_bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, temporal_len: int) -> torch.Tensor:
        """Attend over all tokens.

        Args:
            x: (B, N, C) tokens, laid out time-major (as produced by
               PatchEmbed3D's flatten): each temporal index owns a
               contiguous run of N // temporal_len spatial tokens.
            temporal_len: number of temporal patch positions.

        Returns:
            (B, N, C) attended tokens.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Temporal rotary embedding. Tokens are time-major, so each temporal
        # index must be repeated over its contiguous block of spatial tokens:
        # repeat_interleave, not repeat. (The previous `repeat` tiled the
        # table, assigning token n the temporal position n % temporal_len —
        # the wrong frame — and raised a shape mismatch whenever
        # temporal_len did not evenly divide N.)
        if temporal_len > 0 and N % temporal_len == 0:
            cos, sin = self.rotary_emb(temporal_len, x.device)  # (T, head_dim)
            spatial = N // temporal_len
            cos = cos.repeat_interleave(spatial, dim=0)  # (N, head_dim); broadcasts over B, heads
            sin = sin.repeat_interleave(spatial, dim=0)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # Scaled dot-product attention.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
|
|
|
|
| class FeedForward(nn.Module): |
| """Feed-forward network with GELU activation""" |
| def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.Linear(dim, hidden_dim), |
| nn.GELU(), |
| nn.Dropout(dropout), |
| nn.Linear(hidden_dim, dim), |
| nn.Dropout(dropout) |
| ) |
| |
| def forward(self, x: torch.Tensor): |
| return self.net(x) |
|
|
|
|
class DiTBlock(nn.Module):
    """Diffusion Transformer block with adaptive LayerNorm (adaLN) conditioning.

    The conditioning vector ``c`` is mapped to six per-sample vectors:
    a (shift, scale) pair applied after each non-affine LayerNorm, and a
    gate scaling each residual branch (attention and MLP).
    """
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = SpatioTemporalAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = FeedForward(dim, int(dim * mlp_ratio))

        # Projects conditioning into the six modulation vectors.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, temporal_len: int) -> torch.Tensor:
        """Run one attention + MLP step under adaLN conditioning.

        Args:
            x: (B, N, dim) tokens.
            c: (B, dim) conditioning vector (text + timestep).
            temporal_len: temporal patch count, forwarded to attention.
        """
        mods = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        # Attention branch, gated residual.
        attn_in = self.modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in, temporal_len)

        # MLP branch, gated residual.
        mlp_in = self.modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x

    @staticmethod
    def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Per-sample affine modulation: x * (1 + scale) + shift."""
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
|
|
| class TextEncoder(nn.Module): |
| """Simple text encoder using transformer architecture""" |
| def __init__(self, vocab_size: int = 50257, dim: int = 768, max_len: int = 256): |
| super().__init__() |
| self.token_embedding = nn.Embedding(vocab_size, dim) |
| self.position_embedding = nn.Embedding(max_len, dim) |
| self.layers = nn.ModuleList([ |
| nn.TransformerEncoderLayer(d_model=dim, nhead=12, dim_feedforward=dim*4, |
| batch_first=True, norm_first=True) |
| for _ in range(6) |
| ]) |
| self.norm = nn.LayerNorm(dim) |
| |
| def forward(self, tokens: torch.Tensor): |
| B, L = tokens.shape |
| positions = torch.arange(L, device=tokens.device).unsqueeze(0).expand(B, -1) |
| x = self.token_embedding(tokens) + self.position_embedding(positions) |
| |
| for layer in self.layers: |
| x = layer(x) |
| |
| return self.norm(x) |
|
|
|
|
class PatchEmbed3D(nn.Module):
    """3D Patch Embedding for video: (B, C, T, H, W) -> (B, N, D).

    A strided Conv3d carves the clip into non-overlapping (t, h, w)
    patches and linearly projects each to ``embed_dim`` in a single op
    (kernel == stride == patch size). Input dimensions are assumed to be
    divisible by the patch size.
    """
    def __init__(self, patch_size: Tuple[int, int, int] = (2, 16, 16),
                 in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # (Removed unused t_patch/h_patch/w_patch unpacking from the
        # original — the grid is recovered from the conv output instead.)
        self.proj = nn.Conv3d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Patchify a video tensor.

        Args:
            x: (B, C, T, H, W) video.

        Returns:
            tokens: (B, N, embed_dim) with N = T'*H'*W', flattened
                time-major (all spatial patches of frame 0 first, etc.).
            grid: the (T', H', W') patch-grid dimensions.
        """
        x = self.proj(x)                   # (B, D, T', H', W')
        B, D, T, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)   # (B, N, D), time-major order
        return x, (T, H, W)
|
|
|
|
class VideoTTV1B(nn.Module):
    """
    1B Parameter Text-to-Video Model

    Architecture:
    - Text Encoder: 6-layer transformer (50M params)
    - DiT Backbone: 24 blocks, 1536 hidden dim, 24 heads (950M params)
    - 3D Patch Embedding & Unpatchify

    Total: ~1.0B parameters (with the default configuration)

    Args:
        img_size: Spatial resolution (H, W); must be divisible by the
            spatial patch sizes.
        num_frames: Frames per clip; must be divisible by patch_size[0].
        patch_size: (t, h, w) patch dimensions.
        in_channels: Input channels (3 for RGB pixels).
        hidden_dim: DiT token width.
        depth: Number of DiT blocks.
        num_heads: Attention heads per block (hidden_dim must divide evenly).
        mlp_ratio: FFN expansion ratio inside each block.
        text_dim: Width of the text encoder.
        vocab_size: Text vocabulary size.
        max_text_len: Maximum text sequence length.
    """
    def __init__(
        self,
        img_size: Tuple[int, int] = (256, 256),
        num_frames: int = 16,
        patch_size: Tuple[int, int, int] = (2, 16, 16),
        in_channels: int = 3,
        hidden_dim: int = 1536,
        depth: int = 24,
        num_heads: int = 24,
        mlp_ratio: float = 4.0,
        text_dim: int = 768,
        vocab_size: int = 50257,
        max_text_len: int = 256,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_frames = num_frames
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim

        # Patch-grid geometry (assumes exact divisibility).
        self.t_patches = num_frames // patch_size[0]
        self.h_patches = img_size[0] // patch_size[1]
        self.w_patches = img_size[1] // patch_size[2]
        self.num_patches = self.t_patches * self.h_patches * self.w_patches

        # Text conditioning: encode tokens, then project the mean-pooled
        # text embedding into the DiT width.
        self.text_encoder = TextEncoder(vocab_size, text_dim, max_text_len)
        self.text_proj = nn.Linear(text_dim, hidden_dim)

        # Video tokenization.
        self.patch_embed = PatchEmbed3D(patch_size, in_channels, hidden_dim)

        # Learned absolute position embedding over the full 3D patch grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))

        # MLP applied to the sinusoidal timestep features.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer backbone.
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        # Output head: non-affine LayerNorm (modulated via final_adaLN in
        # forward) followed by a linear projection to per-patch pixels.
        self.final_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6),
            nn.Linear(hidden_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_channels),
        )

        # Produces the (shift, scale) pair for the final modulated LayerNorm.
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize weights (ViT-style patch embed, xavier linears, and
        DiT-style zero-init of all adaLN projections and the output head)."""
        # Patch embedding: xavier on the flattened conv kernel (ViT-style).
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        nn.init.normal_(self.pos_embed, std=0.02)

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # DiT convention: zero-init every adaLN modulation projection and
        # the output projection so each residual branch starts as identity
        # and the model initially predicts zero noise — stabilizes early
        # diffusion training (Peebles & Xie, "Scalable Diffusion Models
        # with Transformers").
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_adaLN[-1].weight, 0)
        nn.init.constant_(self.final_adaLN[-1].bias, 0)
        nn.init.constant_(self.final_layer[1].weight, 0)
        nn.init.constant_(self.final_layer[1].bias, 0)

    def get_timestep_embedding(self, timesteps: torch.Tensor, dim: int):
        """Sinusoidal timestep embeddings.

        Args:
            timesteps: (B,) diffusion timesteps.
            dim: Embedding width; assumed even (sin and cos halves).

        Returns:
            (B, dim) embedding tensor.
        """
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

    def unpatchify(self, x: torch.Tensor):
        """Convert patches back to video: (B, N, patch_dim) -> (B, C, T, H, W).

        Inverts PatchEmbed3D's time-major flattening by reassembling the
        (t_patches, h_patches, w_patches) grid and interleaving each
        patch's local (t, h, w) extent back into full frames.
        """
        B = x.shape[0]
        t, h, w = self.patch_size

        x = x.reshape(B, self.t_patches, self.h_patches, self.w_patches,
                      t, h, w, self.in_channels)
        # (B, C, tP, t, hP, h, wP, w) so adjacent dims merge into T, H, W.
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        x = x.reshape(B, self.in_channels, self.num_frames, self.img_size[0], self.img_size[1])
        return x

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, text_tokens: torch.Tensor):
        """
        Forward pass

        Args:
            x: Noisy video tensor (B, C, T, H, W)
            timesteps: Diffusion timesteps (B,)
            text_tokens: Text token IDs (B, L)

        Returns:
            Predicted noise (B, C, T, H, W)
        """
        # Conditioning vector: mean-pooled text embedding + timestep embedding.
        text_emb = self.text_encoder(text_tokens)
        text_emb = self.text_proj(text_emb.mean(dim=1))

        t_emb = self.get_timestep_embedding(timesteps, self.hidden_dim)
        t_emb = self.time_embed(t_emb)

        c = text_emb + t_emb

        # Tokenize video and add positional embedding.
        x, _grid = self.patch_embed(x)
        x = x + self.pos_embed

        # Each block receives the temporal patch count for rotary positions.
        for block in self.blocks:
            x = block(x, c, self.t_patches)

        # Final modulated LayerNorm + linear head. BUGFIX: the original
        # called `self.final_layer.modulate(...)`, but nn.Sequential has no
        # `modulate` attribute, so forward() always raised AttributeError;
        # the adaLN modulation is applied inline instead.
        shift, scale = self.final_adaLN(c).chunk(2, dim=-1)
        x = self.final_layer[0](x)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.final_layer[1](x)

        return self.unpatchify(x)

    def count_parameters(self):
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
class DDPMScheduler:
    """DDPM noise scheduler for training and sampling (Ho et al., 2020).

    Precomputes a linear beta schedule and the derived alpha products on
    CPU; gathered per-timestep scalars are moved to the data's device on
    demand.

    Args:
        num_steps: Number of diffusion timesteps.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
    """
    def __init__(self, num_steps: int = 1000, beta_start: float = 0.0001,
                 beta_end: float = 0.02):
        self.num_steps = num_steps

        # Linear beta schedule and derived quantities.
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0); zero at t=0 by construction.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def add_noise(self, x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor):
        """Sample q(x_t | x_0): sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise.

        Args:
            x_0: Clean data (B, C, T, H, W).
            t: (B,) integer timesteps.
            noise: Gaussian noise, same shape as x_0.
        """
        # BUGFIX: the schedule buffers live on CPU; indexing them with a
        # CUDA `t` raises, so bring `t` to the buffers' device first, then
        # move the gathered per-sample scalars to the data device.
        t = t.to(self.sqrt_alphas_cumprod.device)
        sqrt_alpha_prod = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1).to(x_0.device)
        sqrt_one_minus_alpha_prod = (
            self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1).to(x_0.device)
        )

        return sqrt_alpha_prod * x_0 + sqrt_one_minus_alpha_prod * noise

    @torch.no_grad()
    def sample_step(self, model: nn.Module, x_t: torch.Tensor, t: int,
                    text_tokens: torch.Tensor):
        """Single reverse (denoising) step p(x_{t-1} | x_t).

        Args:
            model: Noise-prediction network called as model(x, timesteps, tokens).
            x_t: Current noisy sample.
            t: Python int timestep (0-dim CPU scalars broadcast fine below).
            text_tokens: Conditioning token ids.

        Returns:
            x_{t-1}; the posterior mean alone at t == 0 (no noise added).
        """
        betas_t = self.betas[t]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]
        sqrt_recip_alphas_t = torch.sqrt(1.0 / self.alphas[t])

        # Predict the noise for the whole batch at this timestep.
        timesteps = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
        predicted_noise = model(x_t, timesteps, text_tokens)

        # Posterior mean (DDPM Eq. 11).
        model_mean = sqrt_recip_alphas_t * (
            x_t - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        if t == 0:
            return model_mean
        else:
            posterior_variance_t = self.posterior_variance[t]
            noise = torch.randn_like(x_t)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
|
|
|
|
def create_model(device: str = 'cuda'):
    """Factory: build the default 1B text-to-video model, report its size,
    and move it to `device`."""
    # Default 1B configuration, gathered in one place for readability.
    config = dict(
        img_size=(256, 256),
        num_frames=16,
        patch_size=(2, 16, 16),
        in_channels=3,
        hidden_dim=1536,
        depth=24,
        num_heads=24,
        mlp_ratio=4.0,
    )
    model = VideoTTV1B(**config)

    total_params = model.count_parameters()
    print(f"Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")

    return model.to(device)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model and run one forward pass on a
    # random batch shaped like real training data.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Builds the ~1B-parameter default configuration; needs substantial
    # memory even on CPU.
    model = create_model(device)

    # Random inputs matching the defaults: a 16-frame 256x256 RGB clip,
    # per-sample diffusion timesteps in [0, 1000), and GPT-2-range token ids.
    batch_size = 2
    x = torch.randn(batch_size, 3, 16, 256, 256).to(device)
    timesteps = torch.randint(0, 1000, (batch_size,)).to(device)
    text_tokens = torch.randint(0, 50257, (batch_size, 128)).to(device)

    print(f"\nInput shape: {x.shape}")
    print(f"Timesteps shape: {timesteps.shape}")
    print(f"Text tokens shape: {text_tokens.shape}")

    # Forward only; no gradients needed for a shape check.
    with torch.no_grad():
        output = model(x, timesteps, text_tokens)

    # The model predicts noise with the same shape as the input video.
    print(f"Output shape: {output.shape}")
    print("\n✓ Model test passed!")
|
|