# TTV-1B Model Architecture Specification
## Model Summary
**Name:** TTV-1B (Text-to-Video 1 Billion)
**Type:** Diffusion Transformer for Text-to-Video Generation
**Total Parameters:** 1,003,147,264 (~1.0 Billion)
## Architecture Components
### 1. Text Encoder (50M parameters)
```
Input: Text tokens (batch_size, 256)
Architecture:
- Token Embedding: 50,257 vocab β†’ 768 dim
- Position Embedding: 256 positions β†’ 768 dim
- 6 Transformer Layers:
* Multi-head Attention (12 heads)
* Feed-forward (768 β†’ 3072 β†’ 768)
* Layer Normalization
Output: Text features (batch_size, 256, 768)
```
### 2. Text Projection Layer
```
Linear: 768 β†’ 1536 dimensions
Purpose: Project text features to model hidden dimension
```
### 3. 3D Patch Embedding
```
Input: Video (batch_size, 3, 16, 256, 256)
Patch size: (2, 16, 16) - temporal Γ— height Γ— width
Conv3D: 3 channels β†’ 1536 channels
Output: (batch_size, 2048, 1536) where 2048 = (16/2) × (256/16) × (256/16)
                                             = 8 × 16 × 16
```
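The patch-grid arithmetic can be sanity-checked directly; with a (2, 16, 16) patch over a 16-frame 256×256 clip, the grid works out to 8 × 16 × 16 = 2048 tokens:

```python
# Token count produced by the 3D patch embedding for one video clip.
frames, height, width = 16, 256, 256   # input video shape (T, H, W)
pt, ph, pw = 2, 16, 16                 # patch size (temporal, height, width)

tokens_t = frames // pt                # 8 temporal patches
tokens_h = height // ph                # 16 patches along height
tokens_w = width // pw                 # 16 patches along width

num_tokens = tokens_t * tokens_h * tokens_w   # 8 * 16 * 16 = 2048
values_per_patch = pt * ph * pw * 3           # 2 * 16 * 16 * 3 = 1536 raw values per patch
```

Note that tokens × values-per-token equals the clip's total element count (3 × 16 × 256 × 256), so the patchify/unpatchify pair is lossless.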
### 4. Positional Embedding
```
Learnable position embeddings for all 2048 patch tokens
Shape: (1, 2048, 1536)
```
### 5. Timestep Embedding
```
Sinusoidal timestep encoding β†’ Linear(1536, 6144) β†’ SiLU β†’ Linear(6144, 1536)
Output: Conditioning vector (batch_size, 1536)
```
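A minimal sketch of the sinusoidal encoding that feeds the MLP above (the function name and `max_period` value are illustrative; the two Linear layers then map this 1536-dim vector as described):

```python
import math

def timestep_embedding(t, dim=1536, max_period=10000):
    """Standard sinusoidal encoding: first half sines, second half cosines,
    with log-spaced frequencies, as in DDPM-style timestep conditioning."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]

emb = timestep_embedding(500)   # 1536-dim encoding for diffusion step t=500
```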
### 6. DiT Blocks (24 layers, ~928M parameters)
Each block contains:
#### a) 3D Spatiotemporal Attention
```
- Query, Key, Value projections: Linear(1536, 4608)
- 24 attention heads (64 dimensions each)
- Rotary position embeddings on temporal dimension
- Scaled dot-product attention
- Output projection: Linear(1536, 1536)
```
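Rotary embeddings rotate consecutive feature pairs of each 64-dim head by a position-dependent angle. A minimal pure-Python sketch, assuming the standard RoPE frequency scheme (base 10000):

```python
import math

def apply_rope(vec, pos, base=10000):
    """Rotate consecutive (even, odd) feature pairs of one head's vector
    by an angle that grows with temporal position `pos`."""
    dim = len(vec)                     # head dimension, 64 in this model
    out = [0.0] * dim
    for i in range(0, dim, 2):
        theta = pos * base ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        out[i]     = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

q = [1.0] * 64                 # toy 64-dim query for one attention head
q_rot = apply_rope(q, pos=3)   # query as seen from temporal position 3
```

Because each step is a pure rotation, vector norms are preserved, and the dot product between two rotated vectors depends only on their relative temporal offset, which is what makes RoPE attractive for sequence modeling.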
#### b) Feed-Forward Network
```
- Linear: 1536 β†’ 6144 (4x expansion)
- GELU activation
- Linear: 6144 β†’ 1536
```
#### c) Adaptive Layer Normalization (AdaLN)
```
- Modulation network: SiLU β†’ Linear(1536, 9216)
- Generates 6 modulation parameters:
* scale_msa, shift_msa, gate_msa (for attention)
* scale_mlp, shift_mlp, gate_mlp (for FFN)
```
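The 9216-dim output of the modulation network is just six 1536-dim chunks. A sketch of the split and the usual DiT-style modulation `x * (1 + scale) + shift` (the ordering of the six chunks is an assumption; it varies between implementations):

```python
HIDDEN = 1536

def split_modulation(mod):
    """Split the 6*1536 = 9216-dim modulation vector into six named chunks.
    Chunk order here is assumed, not taken from the reference code."""
    assert len(mod) == 6 * HIDDEN
    names = ["scale_msa", "shift_msa", "gate_msa",
             "scale_mlp", "shift_mlp", "gate_mlp"]
    chunks = [mod[i * HIDDEN:(i + 1) * HIDDEN] for i in range(6)]
    return dict(zip(names, chunks))

def modulate(x, shift, scale):
    # Elementwise x * (1 + scale) + shift, applied after LayerNorm
    return [xi * (1.0 + sc) + sh for xi, sc, sh in zip(x, scale, shift)]

mod = split_modulation([0.0] * (6 * HIDDEN))   # toy all-zero modulation
x = [1.0] * HIDDEN
y = modulate(x, mod["shift_msa"], mod["scale_msa"])   # zeros leave x unchanged
```

The gate chunks scale each branch's residual contribution, which lets the model smoothly switch sub-layers on or off as a function of the timestep and text conditioning.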
### 7. Final Layer
```
- Adaptive LayerNorm
- Linear: 1536 β†’ 1536 (2Γ—16Γ—16Γ—3 values per patch)
Purpose: Map back to patch space
```
### 8. Unpatchify
```
Reshape patches back to video
(batch_size, 2048, 1536) β†’ (batch_size, 3, 16, 256, 256)
```
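An illustrative nested-loop version of the unpatchify reshape on a tiny example (a real implementation is a single tensor reshape/permute; the value ordering within a token is an assumption):

```python
def unpatchify(tokens, grid, patch, channels=3):
    """Reassemble (num_tokens, patch_values) back into (C, T, H, W) nested lists.
    grid = (gt, gh, gw) patch grid; patch = (pt, ph, pw) patch size."""
    gt, gh, gw = grid
    pt, ph, pw = patch
    T, H, W = gt * pt, gh * ph, gw * pw
    video = [[[[0.0] * W for _ in range(H)] for _ in range(T)]
             for _ in range(channels)]
    for n, tok in enumerate(tokens):
        # Token index -> patch coordinates in the (gt, gh, gw) grid
        it, ih, iw = n // (gh * gw), (n // gw) % gh, n % gw
        k = 0   # read the token's values in (t, h, w, c) order (assumed)
        for dt in range(pt):
            for dh in range(ph):
                for dw in range(pw):
                    for c in range(channels):
                        video[c][it * pt + dt][ih * ph + dh][iw * pw + dw] = tok[k]
                        k += 1
    return video

# Toy example: a 2x2x2 grid of (1, 2, 2) patches, single channel.
# Each token is filled with its own index so placement is easy to inspect.
vid = unpatchify([[float(n)] * 4 for n in range(8)],
                 grid=(2, 2, 2), patch=(1, 2, 2), channels=1)
```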
## Parameter Breakdown
| Component | Parameters | Percentage |
|-----------|------------|------------|
| Text Encoder | 50,331,648 | 5.0% |
| Text Projection | 1,180,416 | 0.1% |
| Patch Embedding | 589,824 | 0.1% |
| Position Embedding | 196,608 | 0.02% |
| Timestep Embedding | 14,157,312 | 1.4% |
| DiT Blocks (24Γ—) | 927,711,744 | 92.5% |
| Final Layer | 8,979,712 | 0.9% |
| **Total** | **1,003,147,264** | **100%** |
## Per-Block Parameters (DiT)
Each of 24 DiT blocks contains ~38.7M parameters:
| Sub-component | Parameters |
|---------------|------------|
| Attention QKV | 7,077,888 |
| Attention Proj | 2,362,368 |
| Rotary Embedding | 48 |
| FFN Layer 1 | 9,443,328 |
| FFN Layer 2 | 9,443,328 |
| AdaLN Modulation | 14,155,776 |
| Layer Norms | 0 (no learnable params) |
| **Per Block Total** | **38,654,656** |
## Data Flow
```
1. Text Input (batch, 256 tokens)
↓
2. Text Encoder (6 transformer layers)
↓
3. Text Features (batch, 256, 768) β†’ Pool β†’ (batch, 768)
↓
4. Project to 1536 dim β†’ (batch, 1536)
↓
5. Add Timestep Embedding β†’ Conditioning (batch, 1536)
↓
6. Video Input (batch, 3, 16, 256, 256)
↓
7. 3D Patch Embed β†’ (batch, 2048, 1536)
↓
8. Add Position Embedding
↓
9. 24Γ— DiT Blocks (with conditioning)
↓
10. Final Layer + AdaLN
↓
11. Unpatchify
↓
12. Output: Predicted Noise (batch, 3, 16, 256, 256)
```
## Memory Requirements
### Model Weights
- FP32: ~4.0 GB
- FP16: ~2.0 GB
- INT8: ~1.0 GB
### Activations (per sample, 256Γ—256Γ—16)
- Forward pass: ~8 GB (FP16)
- Backward pass: ~16 GB (FP16)
### Training (batch_size=2, FP16, gradient accumulation=8)
- Model: 2 GB
- Optimizer states (AdamW): 4 GB
- Gradients: 2 GB
- Activations: 16 GB
- **Total: ~24 GB per GPU**
### Inference (batch_size=1, FP16)
- Model: 2 GB
- Activations: 4 GB
- **Total: ~6 GB**
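The weight-memory figures above follow directly from the parameter count (the rounding to "4 / 2 / 1 GB" is approximate; the exact GiB values are slightly lower):

```python
# Back-of-envelope weight memory from the stated parameter count.
params = 1_003_147_264

def weight_gib(n_params, bytes_per_param):
    """Raw weight storage in GiB for a given precision."""
    return n_params * bytes_per_param / 2**30

fp32 = weight_gib(params, 4)   # ~3.7 GiB, rounded to ~4 GB above
fp16 = weight_gib(params, 2)   # ~1.9 GiB
int8 = weight_gib(params, 1)   # ~0.9 GiB
```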
## Computational Complexity
### FLOPs per forward pass (approximate)
- Text Encoder: ~10 GFLOPs
- Patch Embedding: ~5 GFLOPs
- DiT Blocks (24Γ—): ~4,800 GFLOPs
- Unpatchify: ~1 GFLOP
- **Total: ~4,816 GFLOPs per video**
### Training Speed Estimates
- Single A100 80GB: ~2-3 seconds per batch (batch_size=2)
- 8Γ— A100 80GB: ~2-3 seconds per batch (batch_size=16)
### Inference Speed Estimates
- A100 80GB (50 denoising steps): ~15-20 seconds per video
- RTX 4090 (50 denoising steps): ~25-35 seconds per video
## Diffusion Scheduler
### DDPM (Denoising Diffusion Probabilistic Model)
- Training steps: 1000
- Beta schedule: Linear (0.0001 β†’ 0.02)
- Loss: MSE between predicted and actual noise
- Sampling: Iterative denoising from T=999 to T=0
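The schedule above can be written out directly; a sketch using the stated hyperparameters (1000 steps, linear betas from 0.0001 to 0.02) plus the derived cumulative-product terms used to noise samples during training:

```python
# Linear beta schedule and the derived alpha-bar terms.
T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]

# Cumulative product of alphas:
#   x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
alpha_bars = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bars.append(prod)
```

By t = 999 the cumulative product is nearly zero, i.e. the sample is almost pure noise, which is what lets sampling start from Gaussian noise.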
### Classifier-Free Guidance
- Unconditional dropout during training: 10%
- Guidance scale at inference: 7.5 (typical)
- Formula: `noise_pred = noise_uncond + guidance_scale Γ— (noise_cond - noise_uncond)`
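The guidance formula applied elementwise to the two noise predictions (toy vectors stand in for the real prediction tensors):

```python
# Classifier-free guidance: push the conditional prediction away from
# the unconditional one by the guidance scale.
def cfg(noise_uncond, noise_cond, guidance_scale=7.5):
    return [u + guidance_scale * (c - u)
            for u, c in zip(noise_uncond, noise_cond)]

out = cfg([0.0, 1.0], [1.0, 1.0])   # amplifies only where the predictions differ
```

At `guidance_scale=1.0` this reduces to the plain conditional prediction; larger scales trade diversity for stronger prompt adherence.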
## Key Features
1. **3D Spatiotemporal Attention**
- Full attention across time, height, and width
- Captures motion dynamics and spatial relationships
2. **Rotary Position Embeddings**
- Applied to temporal dimension
- Better sequence modeling than learned embeddings
3. **Adaptive Layer Normalization**
- Conditions on text and timestep
- Allows flexible control over generation
4. **Efficient Design**
- Patch-based processing reduces sequence length
- Mixed precision training support
- Gradient checkpointing compatible
## Comparison with Other Models
| Model | Parameters | Resolution | Frames | Architecture |
|-------|------------|------------|--------|--------------|
| TTV-1B (ours) | 1.0B | 256Γ—256 | 16 | DiT |
| Stable Video Diffusion | 1.7B | 512Γ—512 | 25 | U-Net |
| Make-A-Video | 9.7B | 256Γ—256 | 16 | U-Net |
| Imagen Video | 11B | 1280Γ—768 | 128 | U-Net Cascade |
## Optimization Techniques
1. **Mixed Precision (FP16)**
- Reduces memory by 50%
- Faster computation on modern GPUs
2. **Gradient Accumulation**
- Enables larger effective batch sizes
- Improves training stability
3. **Gradient Checkpointing**
- Trades computation for memory
- Enables larger batch sizes
4. **Flash Attention**
- O(N) memory instead of O(NΒ²)
- Faster attention computation
## Future Enhancements
1. **Higher Resolution**: 512Γ—512 or 1024Γ—1024
2. **Longer Videos**: 64 or 128 frames
3. **Better Text Encoding**: CLIP or T5
4. **Temporal Super-Resolution**: Increase frame rate
5. **Motion Control**: Add motion guidance
6. **Video Editing**: Inpainting, style transfer
7. **LoRA Fine-tuning**: Efficient adaptation
8. **Distillation**: Smaller, faster variants