""" Model evaluation and testing utilities for TTV-1B """ import torch import torch.nn as nn from video_ttv_1b import VideoTTV1B, create_model import time from typing import Dict, Tuple import numpy as np def count_parameters(model: nn.Module) -> Dict[str, int]: """Count parameters by component""" total = 0 breakdown = {} # Text encoder text_params = sum(p.numel() for p in model.text_encoder.parameters()) breakdown['text_encoder'] = text_params total += text_params # Patch embedding patch_params = sum(p.numel() for p in model.patch_embed.parameters()) breakdown['patch_embed'] = patch_params total += patch_params # DiT blocks dit_params = sum(p.numel() for p in model.blocks.parameters()) breakdown['dit_blocks'] = dit_params total += dit_params # Other other_params = sum(p.numel() for p in model.parameters()) - total breakdown['other'] = other_params total += other_params breakdown['total'] = total return breakdown def measure_inference_speed( model: nn.Module, batch_size: int = 1, num_iterations: int = 10, device: str = 'cuda', ) -> Dict[str, float]: """Measure inference speed""" model.eval() # Prepare dummy inputs videos = torch.randn(batch_size, 3, 16, 256, 256).to(device) timesteps = torch.randint(0, 1000, (batch_size,)).to(device) text_tokens = torch.randint(0, 50257, (batch_size, 256)).to(device) # Warmup with torch.no_grad(): for _ in range(3): _ = model(videos, timesteps, text_tokens) # Measure if device == 'cuda': torch.cuda.synchronize() start_time = time.time() with torch.no_grad(): for _ in range(num_iterations): _ = model(videos, timesteps, text_tokens) if device == 'cuda': torch.cuda.synchronize() end_time = time.time() total_time = end_time - start_time avg_time = total_time / num_iterations throughput = batch_size / avg_time return { 'total_time': total_time, 'avg_time_per_batch': avg_time, 'throughput': throughput, 'time_per_sample': avg_time / batch_size, } def measure_memory_usage( model: nn.Module, batch_size: int = 1, device: str = 'cuda', ) -> Dict[str, float]: """Measure memory usage""" if device != 'cuda': return {'error': 'Memory measurement only available on CUDA'} torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() # Model memory model_memory = sum(p.numel() * p.element_size() for p in model.parameters()) model_memory_mb = model_memory / (1024 ** 2) # Forward pass memory videos = torch.randn(batch_size, 3, 16, 256, 256).to(device) timesteps = torch.randint(0, 1000, (batch_size,)).to(device) text_tokens = torch.randint(0, 50257, (batch_size, 256)).to(device) torch.cuda.reset_peak_memory_stats() with torch.no_grad(): _ = model(videos, timesteps, text_tokens) peak_memory = torch.cuda.max_memory_allocated() peak_memory_mb = peak_memory / (1024 ** 2) return { 'model_memory_mb': model_memory_mb, 'peak_memory_mb': peak_memory_mb, 'activation_memory_mb': peak_memory_mb - model_memory_mb, } def test_model_correctness(model: nn.Module, device: str = 'cuda') -> bool: """Test model correctness with various inputs""" model.eval() tests_passed = 0 total_tests = 0 # Test 1: Output shape total_tests += 1 x = torch.randn(2, 3, 16, 256, 256).to(device) t = torch.randint(0, 1000, (2,)).to(device) tokens = torch.randint(0, 50257, (2, 256)).to(device) with torch.no_grad(): output = model(x, t, tokens) if output.shape == x.shape: tests_passed += 1 print("✓ Test 1 passed: Output shape matches input") else: print(f"✗ Test 1 failed: Expected {x.shape}, got {output.shape}") # Test 2: No NaN values total_tests += 1 if not torch.isnan(output).any(): tests_passed += 1 print("✓ Test 2 passed: No NaN values in output") else: print("✗ Test 2 failed: NaN values detected in output") # Test 3: Different timesteps produce different outputs total_tests += 1 t1 = torch.full((2,), 0).to(device) t2 = torch.full((2,), 999).to(device) with torch.no_grad(): out1 = model(x, t1, tokens) out2 = model(x, t2, tokens) if not torch.allclose(out1, out2, rtol=1e-3): tests_passed += 1 print("✓ Test 3 passed: Different timesteps produce different outputs") else: print("✗ Test 3 failed: Outputs identical for different timesteps") # Test 4: Different text produces different outputs total_tests += 1 tokens1 = torch.randint(0, 50257, (2, 256)).to(device) tokens2 = torch.randint(0, 50257, (2, 256)).to(device) with torch.no_grad(): out1 = model(x, t, tokens1) out2 = model(x, t, tokens2) if not torch.allclose(out1, out2, rtol=1e-3): tests_passed += 1 print("✓ Test 4 passed: Different text produces different outputs") else: print("✗ Test 4 failed: Outputs identical for different text") # Test 5: Gradient flow (training mode) total_tests += 1 model.train() x.requires_grad = True output = model(x, t, tokens) loss = output.mean() loss.backward() if x.grad is not None and not torch.isnan(x.grad).any(): tests_passed += 1 print("✓ Test 5 passed: Gradients computed correctly") else: print("✗ Test 5 failed: Gradient computation error") model.eval() print(f"\nTests passed: {tests_passed}/{total_tests}") return tests_passed == total_tests def benchmark_full_pipeline(device: str = 'cuda'): """Comprehensive benchmark of the model""" print("="*60) print("TTV-1B Model Benchmark") print("="*60) # Create model print("\n1. Creating model...") model = create_model(device) print(f" Device: {device}") # Count parameters print("\n2. Parameter count:") param_counts = count_parameters(model) for name, count in param_counts.items(): print(f" {name:20s}: {count:>12,} ({count/1e6:>6.1f}M)") # Memory usage if device == 'cuda': print("\n3. Memory usage:") mem_stats = measure_memory_usage(model, batch_size=1, device=device) for name, value in mem_stats.items(): print(f" {name:25s}: {value:>8.1f} MB") # Inference speed print("\n4. Inference speed:") speed_stats = measure_inference_speed(model, batch_size=1, num_iterations=10, device=device) print(f" Average time per batch: {speed_stats['avg_time_per_batch']:.3f} seconds") print(f" Time per sample: {speed_stats['time_per_sample']:.3f} seconds") print(f" Throughput: {speed_stats['throughput']:.2f} samples/sec") # Correctness tests print("\n5. Correctness tests:") all_passed = test_model_correctness(model, device) print("\n" + "="*60) if all_passed: print("✓ All tests passed!") else: print("✗ Some tests failed") print("="*60) def estimate_training_time( num_samples: int = 1_000_000, batch_size: int = 16, num_epochs: int = 100, seconds_per_batch: float = 2.0, ) -> Dict[str, float]: """Estimate training time""" steps_per_epoch = num_samples // batch_size total_steps = steps_per_epoch * num_epochs total_seconds = total_steps * seconds_per_batch return { 'steps_per_epoch': steps_per_epoch, 'total_steps': total_steps, 'total_hours': total_seconds / 3600, 'total_days': total_seconds / (3600 * 24), } if __name__ == "__main__": # Run full benchmark device = 'cuda' if torch.cuda.is_available() else 'cpu' benchmark_full_pipeline(device) # Training time estimates print("\n" + "="*60) print("Training Time Estimates") print("="*60) configs = [ {'name': 'Single A100 (bs=2, grad_accum=8)', 'batch_size': 16, 'seconds_per_batch': 3.0}, {'name': '8x A100 (bs=16, grad_accum=8)', 'batch_size': 128, 'seconds_per_batch': 3.0}, ] for config in configs: print(f"\n{config['name']}:") estimates = estimate_training_time( num_samples=10_000_000, batch_size=config['batch_size'], num_epochs=10, seconds_per_batch=config['seconds_per_batch'], ) print(f" Steps per epoch: {estimates['steps_per_epoch']:,}") print(f" Total steps: {estimates['total_steps']:,}") print(f" Estimated time: {estimates['total_days']:.1f} days ({estimates['total_hours']:.1f} hours)")