# evaluate.py — model evaluation and benchmarking utilities for TTV-1B
"""
Model evaluation and testing utilities for TTV-1B
"""
import torch
import torch.nn as nn
from video_ttv_1b import VideoTTV1B, create_model
import time
from typing import Dict, Tuple
import numpy as np
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Return a per-component parameter-count breakdown for the model.

    The text encoder, patch embedding, and DiT blocks are counted
    individually; all remaining parameters fall under 'other', and
    'total' is the sum of everything.
    """
    def _numel(module: nn.Module) -> int:
        # Total element count across all parameters of one submodule.
        return sum(p.numel() for p in module.parameters())

    breakdown = {
        'text_encoder': _numel(model.text_encoder),
        'patch_embed': _numel(model.patch_embed),
        'dit_blocks': _numel(model.blocks),
    }
    named_total = sum(breakdown.values())
    # Everything not captured by the named components above.
    breakdown['other'] = _numel(model) - named_total
    breakdown['total'] = named_total + breakdown['other']
    return breakdown
def measure_inference_speed(
    model: nn.Module,
    batch_size: int = 1,
    num_iterations: int = 10,
    device: str = 'cuda',
    warmup_iterations: int = 3,
) -> Dict[str, float]:
    """Measure average forward-pass latency and throughput.

    Args:
        model: TTV model called as ``model(videos, timesteps, text_tokens)``.
        batch_size: Samples per forward pass.
        num_iterations: Number of timed iterations (after warmup).
        device: Device string, e.g. 'cuda', 'cuda:1', or 'cpu'.
        warmup_iterations: Untimed iterations run first so lazy CUDA
            initialization / autotuning does not skew the measurement.

    Returns:
        Dict with 'total_time', 'avg_time_per_batch', 'throughput'
        (samples/sec), and 'time_per_sample' (seconds).
    """
    model.eval()
    # FIX: the old `device == 'cuda'` string comparison missed device
    # strings such as 'cuda:0' and silently skipped synchronization,
    # producing timings that only measured kernel launches. Normalize
    # through torch.device instead.
    is_cuda = torch.device(device).type == 'cuda'

    # Dummy inputs matching the model's expected shapes (16 frames of
    # 256x256 RGB; 50257-token vocab with 256 text tokens — shapes
    # mirrored from the correctness tests in this file).
    # Allocate directly on the target device instead of CPU + .to().
    videos = torch.randn(batch_size, 3, 16, 256, 256, device=device)
    timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    text_tokens = torch.randint(0, 50257, (batch_size, 256), device=device)

    # Warmup (untimed).
    with torch.no_grad():
        for _ in range(warmup_iterations):
            _ = model(videos, timesteps, text_tokens)

    # Timed section. perf_counter is monotonic and higher-resolution
    # than time.time(), so short runs are measured more reliably.
    if is_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(videos, timesteps, text_tokens)
    if is_cuda:
        torch.cuda.synchronize()
    total_time = time.perf_counter() - start_time

    avg_time = total_time / num_iterations
    return {
        'total_time': total_time,
        'avg_time_per_batch': avg_time,
        'throughput': batch_size / avg_time,
        'time_per_sample': avg_time / batch_size,
    }
def measure_memory_usage(
    model: nn.Module,
    batch_size: int = 1,
    device: str = 'cuda',
) -> Dict[str, float]:
    """Measure model-weight and peak forward-pass memory on CUDA.

    Args:
        model: TTV model called as ``model(videos, timesteps, text_tokens)``;
            assumed to already live on ``device``.
        batch_size: Samples in the probe forward pass.
        device: Device string; anything whose type is not 'cuda' returns
            an ``{'error': ...}`` dict (kept for backward compatibility).

    Returns:
        Dict with 'model_memory_mb', 'peak_memory_mb', and
        'activation_memory_mb' (peak minus weights).
    """
    # FIX: accept any CUDA device string ('cuda', 'cuda:1', ...) — the
    # old `device != 'cuda'` literal comparison rejected 'cuda:0' etc.
    if torch.device(device).type != 'cuda':
        return {'error': 'Memory measurement only available on CUDA'}

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # Static weight footprint, computed from parameter sizes.
    model_memory = sum(p.numel() * p.element_size() for p in model.parameters())
    model_memory_mb = model_memory / (1024 ** 2)

    # Probe inputs (same shapes as the speed benchmark), allocated
    # directly on the device.
    videos = torch.randn(batch_size, 3, 16, 256, 256, device=device)
    timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    text_tokens = torch.randint(0, 50257, (batch_size, 256), device=device)

    # Reset again so the peak reflects just this forward pass.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        _ = model(videos, timesteps, text_tokens)

    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    return {
        'model_memory_mb': model_memory_mb,
        'peak_memory_mb': peak_memory_mb,
        # NOTE(review): approximation — also includes the probe inputs,
        # not just activations.
        'activation_memory_mb': peak_memory_mb - model_memory_mb,
    }
def test_model_correctness(model: nn.Module, device: str = 'cuda') -> bool:
"""Test model correctness with various inputs"""
model.eval()
tests_passed = 0
total_tests = 0
# Test 1: Output shape
total_tests += 1
x = torch.randn(2, 3, 16, 256, 256).to(device)
t = torch.randint(0, 1000, (2,)).to(device)
tokens = torch.randint(0, 50257, (2, 256)).to(device)
with torch.no_grad():
output = model(x, t, tokens)
if output.shape == x.shape:
tests_passed += 1
print("βœ“ Test 1 passed: Output shape matches input")
else:
print(f"βœ— Test 1 failed: Expected {x.shape}, got {output.shape}")
# Test 2: No NaN values
total_tests += 1
if not torch.isnan(output).any():
tests_passed += 1
print("βœ“ Test 2 passed: No NaN values in output")
else:
print("βœ— Test 2 failed: NaN values detected in output")
# Test 3: Different timesteps produce different outputs
total_tests += 1
t1 = torch.full((2,), 0).to(device)
t2 = torch.full((2,), 999).to(device)
with torch.no_grad():
out1 = model(x, t1, tokens)
out2 = model(x, t2, tokens)
if not torch.allclose(out1, out2, rtol=1e-3):
tests_passed += 1
print("βœ“ Test 3 passed: Different timesteps produce different outputs")
else:
print("βœ— Test 3 failed: Outputs identical for different timesteps")
# Test 4: Different text produces different outputs
total_tests += 1
tokens1 = torch.randint(0, 50257, (2, 256)).to(device)
tokens2 = torch.randint(0, 50257, (2, 256)).to(device)
with torch.no_grad():
out1 = model(x, t, tokens1)
out2 = model(x, t, tokens2)
if not torch.allclose(out1, out2, rtol=1e-3):
tests_passed += 1
print("βœ“ Test 4 passed: Different text produces different outputs")
else:
print("βœ— Test 4 failed: Outputs identical for different text")
# Test 5: Gradient flow (training mode)
total_tests += 1
model.train()
x.requires_grad = True
output = model(x, t, tokens)
loss = output.mean()
loss.backward()
if x.grad is not None and not torch.isnan(x.grad).any():
tests_passed += 1
print("βœ“ Test 5 passed: Gradients computed correctly")
else:
print("βœ— Test 5 failed: Gradient computation error")
model.eval()
print(f"\nTests passed: {tests_passed}/{total_tests}")
return tests_passed == total_tests
def benchmark_full_pipeline(device: str = 'cuda'):
    """Run the full benchmark suite: params, memory, speed, correctness."""
    separator = "=" * 60
    print(separator)
    print("TTV-1B Model Benchmark")
    print(separator)

    # Build the model first; everything below operates on it.
    print("\n1. Creating model...")
    model = create_model(device)
    print(f" Device: {device}")

    print("\n2. Parameter count:")
    for component, n_params in count_parameters(model).items():
        print(f" {component:20s}: {n_params:>12,} ({n_params/1e6:>6.1f}M)")

    # Memory stats are only meaningful on CUDA.
    if device == 'cuda':
        print("\n3. Memory usage:")
        for stat_name, stat_value in measure_memory_usage(model, batch_size=1, device=device).items():
            print(f" {stat_name:25s}: {stat_value:>8.1f} MB")

    print("\n4. Inference speed:")
    speed = measure_inference_speed(model, batch_size=1, num_iterations=10, device=device)
    print(f" Average time per batch: {speed['avg_time_per_batch']:.3f} seconds")
    print(f" Time per sample: {speed['time_per_sample']:.3f} seconds")
    print(f" Throughput: {speed['throughput']:.2f} samples/sec")

    print("\n5. Correctness tests:")
    all_passed = test_model_correctness(model, device)

    print("\n" + separator)
    print("βœ“ All tests passed!" if all_passed else "βœ— Some tests failed")
    print(separator)
def estimate_training_time(
    num_samples: int = 1_000_000,
    batch_size: int = 16,
    num_epochs: int = 100,
    seconds_per_batch: float = 2.0,
) -> Dict[str, float]:
    """Back-of-the-envelope training-time estimate from batch throughput.

    Args:
        num_samples: Dataset size (samples per epoch).
        batch_size: Samples per optimization step.
        num_epochs: Number of passes over the dataset.
        seconds_per_batch: Measured wall time per step.

    Returns:
        Dict with 'steps_per_epoch', 'total_steps', 'total_hours',
        and 'total_days'.
    """
    per_epoch = num_samples // batch_size
    overall_steps = per_epoch * num_epochs
    wall_seconds = overall_steps * seconds_per_batch
    return {
        'steps_per_epoch': per_epoch,
        'total_steps': overall_steps,
        'total_hours': wall_seconds / 3600,
        'total_days': wall_seconds / (3600 * 24),
    }
if __name__ == "__main__":
    # Run the full benchmark on the best available device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    benchmark_full_pipeline(device)

    # Rough wall-clock projections for two hardware setups.
    print("\n" + "=" * 60)
    print("Training Time Estimates")
    print("=" * 60)

    hardware_configs = [
        {'name': 'Single A100 (bs=2, grad_accum=8)', 'batch_size': 16, 'seconds_per_batch': 3.0},
        {'name': '8x A100 (bs=16, grad_accum=8)', 'batch_size': 128, 'seconds_per_batch': 3.0},
    ]
    for cfg in hardware_configs:
        print(f"\n{cfg['name']}:")
        est = estimate_training_time(
            num_samples=10_000_000,
            batch_size=cfg['batch_size'],
            num_epochs=10,
            seconds_per_batch=cfg['seconds_per_batch'],
        )
        print(f" Steps per epoch: {est['steps_per_epoch']:,}")
        print(f" Total steps: {est['total_steps']:,}")
        print(f" Estimated time: {est['total_days']:.1f} days ({est['total_hours']:.1f} hours)")