| | """ |
| | Model evaluation and testing utilities for TTV-1B |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from video_ttv_1b import VideoTTV1B, create_model |
| | import time |
| | from typing import Dict, Tuple |
| | import numpy as np |
| |
|
| |
|
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Break down the model's parameter count by major component.

    Returns a dict with per-component counts ('text_encoder',
    'patch_embed', 'dit_blocks'), an 'other' bucket covering every
    parameter not attributed to a named component, and a grand 'total'.
    """
    def numel_of(module: nn.Module) -> int:
        # Total element count across all parameters of a submodule.
        return sum(p.numel() for p in module.parameters())

    breakdown = {
        'text_encoder': numel_of(model.text_encoder),
        'patch_embed': numel_of(model.patch_embed),
        'dit_blocks': numel_of(model.blocks),
    }
    counted = sum(breakdown.values())

    # Whatever remains after the named components (heads, embeddings, ...).
    breakdown['other'] = numel_of(model) - counted
    breakdown['total'] = counted + breakdown['other']

    return breakdown
| |
|
| |
|
def measure_inference_speed(
    model: nn.Module,
    batch_size: int = 1,
    num_iterations: int = 10,
    device: str = 'cuda',
    video_shape: Tuple[int, int, int, int] = (3, 16, 256, 256),
    text_len: int = 256,
    vocab_size: int = 50257,
    num_warmup: int = 3,
) -> Dict[str, float]:
    """Measure average forward-pass latency and throughput.

    Runs ``num_warmup`` untimed forward passes, then times
    ``num_iterations`` passes over synthetic inputs. The probe shapes
    were previously hard-coded; they are now keyword parameters with
    the same defaults, so existing callers are unaffected.

    Args:
        model: Model taking (videos, timesteps, text_tokens).
        batch_size: Samples per forward pass.
        num_iterations: Number of timed iterations.
        device: 'cuda' or 'cpu'. CUDA runs are synchronized so wall
            time reflects completed GPU work.
        video_shape: Per-sample video shape (C, T, H, W).
        text_len: Text token sequence length.
        vocab_size: Vocabulary size for the synthetic text tokens.
        num_warmup: Untimed warmup iterations before timing starts.

    Returns:
        Dict with 'total_time', 'avg_time_per_batch', 'throughput'
        (samples/sec) and 'time_per_sample' (seconds).
    """
    model.eval()

    # Synthetic inputs matching the model's expected call signature.
    videos = torch.randn(batch_size, *video_shape, device=device)
    timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    text_tokens = torch.randint(0, vocab_size, (batch_size, text_len), device=device)

    # Warmup: trigger lazy initialization / kernel selection before timing.
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(videos, timesteps, text_tokens)

    if device == 'cuda':
        # Make sure warmup work has drained before starting the clock.
        torch.cuda.synchronize()

    start_time = time.time()

    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(videos, timesteps, text_tokens)
            if device == 'cuda':
                # Per-iteration sync so each pass is fully accounted for.
                torch.cuda.synchronize()

    end_time = time.time()

    total_time = end_time - start_time
    avg_time = total_time / num_iterations

    return {
        'total_time': total_time,
        'avg_time_per_batch': avg_time,
        'throughput': batch_size / avg_time,
        'time_per_sample': avg_time / batch_size,
    }
| |
|
| |
|
def measure_memory_usage(
    model: nn.Module,
    batch_size: int = 1,
    device: str = 'cuda',
    video_shape: Tuple[int, int, int, int] = (3, 16, 256, 256),
    text_len: int = 256,
    vocab_size: int = 50257,
) -> Dict[str, float]:
    """Measure parameter and peak forward-pass memory on CUDA.

    The probe input shapes were previously hard-coded; they are now
    keyword parameters with the same defaults (consistent with
    ``measure_inference_speed``), so existing callers are unaffected.

    Args:
        model: Model taking (videos, timesteps, text_tokens).
        batch_size: Samples in the probe batch.
        device: Must be 'cuda'; any other device returns an error dict.
        video_shape: Per-sample video shape (C, T, H, W).
        text_len: Text token sequence length.
        vocab_size: Vocabulary size for the synthetic text tokens.

    Returns:
        Dict with 'model_memory_mb' (parameter storage),
        'peak_memory_mb' (max allocated during one forward pass) and
        'activation_memory_mb' (their difference) — or a dict with a
        single 'error' entry when not running on CUDA.
    """
    if device != 'cuda':
        # Peak-allocation counters only exist for the CUDA allocator.
        return {'error': 'Memory measurement only available on CUDA'}

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # Bytes occupied by the parameters themselves.
    model_memory = sum(p.numel() * p.element_size() for p in model.parameters())
    model_memory_mb = model_memory / (1024 ** 2)

    # Synthetic probe batch.
    videos = torch.randn(batch_size, *video_shape).to(device)
    timesteps = torch.randint(0, 1000, (batch_size,)).to(device)
    text_tokens = torch.randint(0, vocab_size, (batch_size, text_len)).to(device)

    # Restart peak tracking so the forward pass dominates the reading.
    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        _ = model(videos, timesteps, text_tokens)

    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)

    return {
        'model_memory_mb': model_memory_mb,
        'peak_memory_mb': peak_memory_mb,
        'activation_memory_mb': peak_memory_mb - model_memory_mb,
    }
| |
|
| |
|
| | def test_model_correctness(model: nn.Module, device: str = 'cuda') -> bool: |
| | """Test model correctness with various inputs""" |
| | model.eval() |
| | |
| | tests_passed = 0 |
| | total_tests = 0 |
| | |
| | |
| | total_tests += 1 |
| | x = torch.randn(2, 3, 16, 256, 256).to(device) |
| | t = torch.randint(0, 1000, (2,)).to(device) |
| | tokens = torch.randint(0, 50257, (2, 256)).to(device) |
| | |
| | with torch.no_grad(): |
| | output = model(x, t, tokens) |
| | |
| | if output.shape == x.shape: |
| | tests_passed += 1 |
| | print("β Test 1 passed: Output shape matches input") |
| | else: |
| | print(f"β Test 1 failed: Expected {x.shape}, got {output.shape}") |
| | |
| | |
| | total_tests += 1 |
| | if not torch.isnan(output).any(): |
| | tests_passed += 1 |
| | print("β Test 2 passed: No NaN values in output") |
| | else: |
| | print("β Test 2 failed: NaN values detected in output") |
| | |
| | |
| | total_tests += 1 |
| | t1 = torch.full((2,), 0).to(device) |
| | t2 = torch.full((2,), 999).to(device) |
| | |
| | with torch.no_grad(): |
| | out1 = model(x, t1, tokens) |
| | out2 = model(x, t2, tokens) |
| | |
| | if not torch.allclose(out1, out2, rtol=1e-3): |
| | tests_passed += 1 |
| | print("β Test 3 passed: Different timesteps produce different outputs") |
| | else: |
| | print("β Test 3 failed: Outputs identical for different timesteps") |
| | |
| | |
| | total_tests += 1 |
| | tokens1 = torch.randint(0, 50257, (2, 256)).to(device) |
| | tokens2 = torch.randint(0, 50257, (2, 256)).to(device) |
| | |
| | with torch.no_grad(): |
| | out1 = model(x, t, tokens1) |
| | out2 = model(x, t, tokens2) |
| | |
| | if not torch.allclose(out1, out2, rtol=1e-3): |
| | tests_passed += 1 |
| | print("β Test 4 passed: Different text produces different outputs") |
| | else: |
| | print("β Test 4 failed: Outputs identical for different text") |
| | |
| | |
| | total_tests += 1 |
| | model.train() |
| | x.requires_grad = True |
| | output = model(x, t, tokens) |
| | loss = output.mean() |
| | loss.backward() |
| | |
| | if x.grad is not None and not torch.isnan(x.grad).any(): |
| | tests_passed += 1 |
| | print("β Test 5 passed: Gradients computed correctly") |
| | else: |
| | print("β Test 5 failed: Gradient computation error") |
| | |
| | model.eval() |
| | |
| | print(f"\nTests passed: {tests_passed}/{total_tests}") |
| | return tests_passed == total_tests |
| |
|
| |
|
def benchmark_full_pipeline(device: str = 'cuda'):
    """Run the full benchmark suite on a freshly created model.

    Sections: model creation, parameter breakdown, memory usage
    (CUDA only), inference speed, and correctness smoke tests.
    """
    banner = "=" * 60
    print(banner)
    print("TTV-1B Model Benchmark")
    print(banner)

    print("\n1. Creating model...")
    model = create_model(device)
    print(f" Device: {device}")

    print("\n2. Parameter count:")
    for component, n_params in count_parameters(model).items():
        print(f" {component:20s}: {n_params:>12,} ({n_params/1e6:>6.1f}M)")

    # Memory statistics are only meaningful on the CUDA allocator.
    if device == 'cuda':
        print("\n3. Memory usage:")
        for metric, mb in measure_memory_usage(model, batch_size=1, device=device).items():
            print(f" {metric:25s}: {mb:>8.1f} MB")

    print("\n4. Inference speed:")
    speed = measure_inference_speed(model, batch_size=1, num_iterations=10, device=device)
    print(f" Average time per batch: {speed['avg_time_per_batch']:.3f} seconds")
    print(f" Time per sample: {speed['time_per_sample']:.3f} seconds")
    print(f" Throughput: {speed['throughput']:.2f} samples/sec")

    print("\n5. Correctness tests:")
    all_passed = test_model_correctness(model, device)

    print("\n" + banner)
    print("β All tests passed!" if all_passed else "β Some tests failed")
    print(banner)
| |
|
| |
|
def estimate_training_time(
    num_samples: int = 1_000_000,
    batch_size: int = 16,
    num_epochs: int = 100,
    seconds_per_batch: float = 2.0,
) -> Dict[str, float]:
    """Back-of-the-envelope training wall-time projection.

    Uses floor division for steps per epoch, so a trailing partial
    batch in an epoch is ignored.

    Returns:
        Dict with 'steps_per_epoch', 'total_steps', 'total_hours' and
        'total_days'.
    """
    per_epoch_steps = num_samples // batch_size
    overall_steps = per_epoch_steps * num_epochs
    wall_seconds = overall_steps * seconds_per_batch

    return {
        'steps_per_epoch': per_epoch_steps,
        'total_steps': overall_steps,
        'total_hours': wall_seconds / 3600,
        'total_days': wall_seconds / (3600 * 24),
    }
| |
|
| |
|
if __name__ == "__main__":
    # Prefer the GPU when one is present; everything runs on CPU too.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    benchmark_full_pipeline(run_device)

    print("\n" + "=" * 60)
    print("Training Time Estimates")
    print("=" * 60)

    # Effective batch sizes already account for gradient accumulation.
    scenarios = [
        {'name': 'Single A100 (bs=2, grad_accum=8)', 'batch_size': 16, 'seconds_per_batch': 3.0},
        {'name': '8x A100 (bs=16, grad_accum=8)', 'batch_size': 128, 'seconds_per_batch': 3.0},
    ]

    for scenario in scenarios:
        print(f"\n{scenario['name']}:")
        est = estimate_training_time(
            num_samples=10_000_000,
            batch_size=scenario['batch_size'],
            num_epochs=10,
            seconds_per_batch=scenario['seconds_per_batch'],
        )
        print(f" Steps per epoch: {est['steps_per_epoch']:,}")
        print(f" Total steps: {est['total_steps']:,}")
        print(f" Estimated time: {est['total_days']:.1f} days ({est['total_hours']:.1f} hours)")
| |
|