# Source: Zenderos / video_ttv_1b.py
# ASADSANAN's picture
# Upload 11 files
# 3d8856d verified
"""
1B Parameter Text-to-Video Model (TTV-1B)
A production-ready diffusion-based text-to-video generation model
Architecture: DiT (Diffusion Transformer) with 3D spatiotemporal attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
import math
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) table generator.

    Precomputes the inverse frequency spectrum 10000^(-2i/dim) once and,
    on demand, returns cos/sin lookup tables for ``seq_len`` positions.
    """

    def __init__(self, dim: int, max_seq_len: int = 10000):
        super().__init__()
        # Kept as the exact original formula so the buffer is bit-identical.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len: int, device: torch.device):
        """Return (cos, sin) tables, each of shape (seq_len, dim)."""
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        angles = torch.einsum('i,j->ij', positions, self.inv_freq)
        # Duplicate the angles so the table spans the full head dimension.
        full = torch.cat((angles, angles), dim=-1)
        return full.cos(), full.sin()
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to ``x``.

    ``cos``/``sin`` are laid out as ``cat((freqs, freqs), dim=-1)`` (see
    ``RotaryEmbedding``): the angle for channel ``i`` equals the angle for
    channel ``i + dim/2``.  A proper 2D rotation must therefore pair channel
    ``i`` with channel ``i + dim/2`` ("rotate half").

    BUG FIX: the previous implementation split the channels interleaved
    (``x[..., ::2]`` / ``x[..., 1::2]``) but then concatenated the halves,
    pairing mismatched channels — the result was not a rotation and did not
    preserve per-pair norms.

    Args:
        x: input of shape (..., dim), dim even.
        cos, sin: tables broadcastable to x's shape.
    Returns:
        Rotated tensor of the same shape as ``x``.
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos) + (rotated * sin)
class SpatioTemporalAttention(nn.Module):
    """3D Attention mechanism for video data (Time x Height x Width).

    Tokens are assumed to be flattened in (T, H, W) order — temporal index
    varies slowest — matching ``PatchEmbed3D``.  Rotary embeddings encode
    each token's temporal position.
    """

    def __init__(self, dim: int, num_heads: int = 16, qkv_bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, temporal_len: int):
        """Joint attention over all N = T*H*W tokens.

        Args:
            x: (B, N, C) token sequence.
            temporal_len: number of temporal positions T.  Rotary embeddings
                are applied only when N is a positive multiple of T.
        Returns:
            (B, N, C) attended tokens.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # BUG FIX: tokens are (T, H, W)-major, so the temporal position of
        # token n is n // (N // T).  The original `.repeat(...)` tiled the
        # T rows as [0..T-1, 0..T-1, ...], assigning wrong temporal positions;
        # each row must instead be repeated for all spatial tokens of its
        # frame (repeat_interleave).  Broadcasting over (B, heads) replaces
        # the materialized per-batch/per-head copies.
        if temporal_len > 0 and N % temporal_len == 0:
            cos, sin = self.rotary_emb(temporal_len, x.device)
            spatial = N // temporal_len
            cos = cos.repeat_interleave(spatial, dim=0)  # (N, head_dim)
            sin = sin.repeat_interleave(spatial, dim=0)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # Layer order (and hence seeded parameter initialization) matches the
        # original Sequential exactly; state_dict keys stay `net.0`, `net.3`.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Map (..., dim) -> (..., dim) through the hidden expansion."""
        return self.net(x)
class DiTBlock(nn.Module):
    """Diffusion Transformer block: AdaLN-modulated attention + MLP.

    The conditioning vector ``c`` is mapped through SiLU + Linear to six
    per-channel vectors — shift, scale, and residual gate for each of the
    two sub-blocks — following the DiT adaptive-layer-norm design.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = SpatioTemporalAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = FeedForward(dim, int(dim * mlp_ratio))
        # AdaLN modulation head: produces all six conditioning vectors at once.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, temporal_len: int):
        """Run one modulated transformer block.

        Args:
            x: (B, N, dim) tokens.
            c: (B, dim) conditioning (timestep + text embedding).
            temporal_len: temporal extent, forwarded to attention.
        """
        shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = \
            self.adaLN_modulation(c).chunk(6, dim=-1)
        # Gated residual around the attention sub-block.
        attn_in = self.modulate(self.norm1(x), shift_a, scale_a)
        x = x + gate_a.unsqueeze(1) * self.attn(attn_in, temporal_len)
        # Gated residual around the MLP sub-block.
        mlp_in = self.modulate(self.norm2(x), shift_m, scale_m)
        x = x + gate_m.unsqueeze(1) * self.mlp(mlp_in)
        return x

    @staticmethod
    def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
        """Per-channel affine modulation: x * (1 + scale) + shift."""
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class TextEncoder(nn.Module):
    """Simple text encoder: learned token + position embeddings followed by
    a stack of six pre-norm TransformerEncoder layers."""

    def __init__(self, vocab_size: int = 50257, dim: int = 768, max_len: int = 256):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(max_len, dim)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=12,
                dim_feedforward=dim * 4,
                batch_first=True,
                norm_first=True,
            )
            for _ in range(6)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens: torch.Tensor):
        """Encode (B, L) token ids into (B, L, dim) features."""
        _, length = tokens.shape
        # (L,) position ids broadcast across the batch when added below.
        pos_ids = torch.arange(length, device=tokens.device)
        hidden = self.token_embedding(tokens) + self.position_embedding(pos_ids)
        for layer in self.layers:
            hidden = layer(hidden)
        return self.norm(hidden)
class PatchEmbed3D(nn.Module):
    """3D Patch Embedding for video: (B, C, T, H, W) -> (B, N, D).

    A single Conv3d with kernel == stride == patch_size carves the video into
    non-overlapping spatio-temporal patches and projects each to embed_dim.
    """

    def __init__(self, patch_size: Tuple[int, int, int] = (2, 16, 16),
                 in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor):
        """Return (tokens, (T', H', W')); tokens is (B, T'*H'*W', D),
        flattened in (T, H, W)-major order."""
        feats = self.proj(x)  # (B, D, T', H', W')
        _, _, t_out, h_out, w_out = feats.shape
        tokens = feats.flatten(2).transpose(1, 2)
        return tokens, (t_out, h_out, w_out)
class VideoTTV1B(nn.Module):
    """
    1B Parameter Text-to-Video Model
    Architecture:
        - Text Encoder: 6-layer transformer (50M params)
        - DiT Backbone: 24 blocks, 1536 hidden dim, 24 heads (950M params)
        - 3D Patch Embedding & Unpatchify
    Total: ~1.0B parameters

    BUG FIX in forward(): the final AdaLN affine was invoked as
    ``self.final_layer.modulate(...)``, but ``nn.Sequential`` has no
    ``modulate`` method, so every forward pass raised AttributeError.
    The affine is now applied inline between the final norm and linear.
    """
    def __init__(
        self,
        img_size: Tuple[int, int] = (256, 256),
        num_frames: int = 16,
        patch_size: Tuple[int, int, int] = (2, 16, 16),
        in_channels: int = 3,
        hidden_dim: int = 1536,
        depth: int = 24,
        num_heads: int = 24,
        mlp_ratio: float = 4.0,
        text_dim: int = 768,
        vocab_size: int = 50257,
        max_text_len: int = 256,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_frames = num_frames
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        # Patch grid dimensions (assumes patch sizes divide the inputs evenly).
        self.t_patches = num_frames // patch_size[0]
        self.h_patches = img_size[0] // patch_size[1]
        self.w_patches = img_size[1] // patch_size[2]
        self.num_patches = self.t_patches * self.h_patches * self.w_patches
        # Text encoder and projection of pooled text features to hidden dim.
        self.text_encoder = TextEncoder(vocab_size, text_dim, max_text_len)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        # Patch embedding and learned positional embedding.
        self.patch_embed = PatchEmbed3D(patch_size, in_channels, hidden_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))
        # Timestep embedding MLP for diffusion conditioning.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        # Final layer: LayerNorm then projection back to per-patch pixels.
        self.final_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6),
            nn.Linear(hidden_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_channels),
        )
        # AdaLN (shift/scale) conditioning for the final layer.
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize weights: xavier for linears, normal(0.02) for pos_embed."""
        # Initialize patch embedding like nn.Linear (flatten kernel dims first).
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)
        # Initialize positional embedding
        nn.init.normal_(self.pos_embed, std=0.02)
        # Initialize every Linear in the model the same way.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def get_timestep_embedding(self, timesteps: torch.Tensor, dim: int):
        """Sinusoidal timestep embeddings of shape (B, dim); dim must be even."""
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

    def unpatchify(self, x: torch.Tensor):
        """Convert patch predictions (B, N, t*h*w*C) back to video (B, C, T, H, W)."""
        B = x.shape[0]
        t, h, w = self.patch_size
        x = x.reshape(B, self.t_patches, self.h_patches, self.w_patches,
                      t, h, w, self.in_channels)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)  # (B, C, T', t, H', h, W', w)
        x = x.reshape(B, self.in_channels, self.num_frames, self.img_size[0], self.img_size[1])
        return x

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, text_tokens: torch.Tensor):
        """
        Forward pass
        Args:
            x: Noisy video tensor (B, C, T, H, W)
            timesteps: Diffusion timesteps (B,)
            text_tokens: Text token IDs (B, L)
        Returns:
            Predicted noise (B, C, T, H, W)
        """
        # Conditioning: mean-pooled text features plus timestep embedding.
        text_emb = self.text_encoder(text_tokens)        # (B, L, text_dim)
        text_emb = self.text_proj(text_emb.mean(dim=1))  # (B, hidden_dim)
        t_emb = self.get_timestep_embedding(timesteps, self.hidden_dim)
        t_emb = self.time_embed(t_emb)                   # (B, hidden_dim)
        c = text_emb + t_emb                             # (B, hidden_dim)
        # Tokenize the video and add positional embeddings.
        x, _ = self.patch_embed(x)                       # (B, N, hidden_dim)
        x = x + self.pos_embed
        # Apply DiT blocks
        for block in self.blocks:
            x = block(x, c, self.t_patches)
        # Final layer with adaptive layer norm.
        # BUG FIX: nn.Sequential has no `modulate`; apply the AdaLN
        # shift/scale affine inline between the norm and the projection.
        shift, scale = self.final_adaLN(c).chunk(2, dim=-1)
        x = self.final_layer[0](x)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.final_layer[1](x)
        # Unpatchify to video
        return self.unpatchify(x)

    def count_parameters(self):
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class DDPMScheduler:
    """DDPM noise scheduler (linear beta schedule) for training and sampling."""

    def __init__(self, num_steps: int = 1000, beta_start: float = 0.0001,
                 beta_end: float = 0.02):
        self.num_steps = num_steps
        # Linear beta schedule and derived alpha products.
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # \bar{alpha}_{t-1}, with \bar{alpha}_{-1} defined as 1.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # Coefficients of the forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def add_noise(self, x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor):
        """Diffuse clean video: x_t = sqrt(abar_t) x_0 + sqrt(1-abar_t) noise."""
        device = x_0.device
        signal_scale = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1).to(device)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1).to(device)
        return signal_scale * x_0 + noise_scale * noise

    @torch.no_grad()
    def sample_step(self, model: nn.Module, x_t: torch.Tensor, t: int,
                    text_tokens: torch.Tensor):
        """One reverse-diffusion step: predict noise and move x_t toward x_{t-1}."""
        batch = x_t.shape[0]
        t_batch = torch.full((batch,), t, device=x_t.device, dtype=torch.long)
        eps = model(x_t, t_batch, text_tokens)
        # Posterior mean: 1/sqrt(alpha_t) * (x_t - beta_t * eps / sqrt(1-abar_t))
        mean = torch.sqrt(1.0 / self.alphas[t]) * (
            x_t - self.betas[t] * eps / self.sqrt_one_minus_alphas_cumprod[t]
        )
        if t == 0:
            # Last step is deterministic — no noise injected.
            return mean
        fresh_noise = torch.randn_like(x_t)
        return mean + torch.sqrt(self.posterior_variance[t]) * fresh_noise
def create_model(device: str = 'cuda'):
    """Factory function to create the model"""
    # Default 1B-parameter configuration, gathered in one place.
    config = dict(
        img_size=(256, 256),
        num_frames=16,
        patch_size=(2, 16, 16),
        in_channels=3,
        hidden_dim=1536,
        depth=24,
        num_heads=24,
        mlp_ratio=4.0,
    )
    model = VideoTTV1B(**config)
    total_params = model.count_parameters()
    print(f"Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")
    return model.to(device)
def _smoke_test():
    """Build the full model and run a single forward pass as a sanity check."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model = create_model(device)
    # Random inputs matching the model's expected shapes.
    batch_size = 2
    video = torch.randn(batch_size, 3, 16, 256, 256).to(device)
    steps = torch.randint(0, 1000, (batch_size,)).to(device)
    tokens = torch.randint(0, 50257, (batch_size, 128)).to(device)
    print(f"\nInput shape: {video.shape}")
    print(f"Timesteps shape: {steps.shape}")
    print(f"Text tokens shape: {tokens.shape}")
    with torch.no_grad():
        result = model(video, steps, tokens)
    print(f"Output shape: {result.shape}")
    print("\n✓ Model test passed!")


if __name__ == "__main__":
    _smoke_test()