# Zenderos / train.py
# (uploaded by ASADSANAN — "Upload 11 files", commit 3d8856d, verified)
"""
Training script for TTV-1B Text-to-Video Model
Supports distributed training, mixed precision, and gradient checkpointing
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import json
from pathlib import Path
from tqdm import tqdm
import numpy as np
from typing import Dict, List, Optional
import logging
from video_ttv_1b import VideoTTV1B, DDPMScheduler
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class VideoTextDataset(Dataset):
    """Paired video/caption dataset for diffusion training.

    Captions come from a JSON file mapping video id -> metadata; clips are
    expected under ``video_dir`` as ``<video_id>.mp4``.
    """

    def __init__(self, video_dir: str, annotation_file: str,
                 num_frames: int = 16, img_size: tuple = (256, 256)):
        self.video_dir = Path(video_dir)
        self.num_frames = num_frames
        self.img_size = img_size

        # Annotations: {video_id: {"caption": ...}}
        with open(annotation_file, 'r') as f:
            self.annotations = json.load(f)
        self.video_ids = list(self.annotations.keys())

        logger.info(f"Loaded {len(self.video_ids)} video-text pairs")

    def __len__(self):
        return len(self.video_ids)

    def tokenize(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Simple character-level tokenization (replace with proper tokenizer)"""
        ids = [ord(ch) % 50257 for ch in text[:max_length]]
        padding = [0] * (max_length - len(ids))  # right-pad to fixed length
        return torch.tensor(ids + padding, dtype=torch.long)

    def load_video(self, video_path: Path) -> torch.Tensor:
        """Load and preprocess video (placeholder - implement with actual video loading)"""
        # In production, use libraries like torchvision.io or decord.
        # Placeholder: a synthetic (C, T, H, W) clip rescaled to [-1, 1].
        clip = torch.randn(3, self.num_frames, *self.img_size)
        lo, hi = clip.min(), clip.max()
        return (clip - lo) / (hi - lo) * 2 - 1

    def __getitem__(self, idx: int):
        vid = self.video_ids[idx]
        caption = self.annotations[vid]['caption']
        return {
            'video': self.load_video(self.video_dir / f"{vid}.mp4"),
            'text_tokens': self.tokenize(caption),
            'text': caption  # Keep original text for logging
        }
class Trainer:
    """Trainer class for TTV-1B model.

    Handles the optimization loop, AMP mixed precision, gradient
    accumulation, validation, and checkpointing for a diffusion-style
    text-to-video model.
    """

    def __init__(
        self,
        model: nn.Module,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        batch_size: int = 4,
        num_workers: int = 4,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        gradient_accumulation_steps: int = 4,
        mixed_precision: bool = True,
        gradient_checkpointing: bool = True,
        save_dir: str = './checkpoints',
        log_every: int = 100,
        save_every: int = 5000,
        device: str = 'cuda',
    ):
        """
        Args:
            model: Denoising model called as model(noisy_video, timesteps, tokens).
                NOTE(review): assumed to already live on `device` (e.g. moved by
                `create_model`) — this class never calls model.to(device); confirm.
            train_dataset: Dataset yielding {'video', 'text_tokens', ...} dicts.
            val_dataset: Optional validation dataset; enables validate().
            batch_size: Per-step (micro-)batch size.
            num_workers: DataLoader worker processes.
            learning_rate: Peak AdamW learning rate.
            weight_decay: AdamW weight decay.
            num_epochs: Total training epochs.
            gradient_accumulation_steps: Micro-batches per optimizer step.
            mixed_precision: Use CUDA AMP autocast + GradScaler.
            gradient_checkpointing: Placeholder flag (needs model-side support).
            save_dir: Directory for checkpoints and model config JSON.
            log_every: Log every N optimizer steps.
            save_every: Checkpoint every N optimizer steps.
            device: Target device for batches and noise tensors.
        """
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.mixed_precision = mixed_precision
        self.log_every = log_every
        self.save_every = save_every
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Enable gradient checkpointing to save memory
        if gradient_checkpointing:
            logger.info("Enabling gradient checkpointing")
            # Note: Requires implementing checkpointing in model blocks

        # Create dataloaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True
        )

        self.val_loader = None
        if val_dataset:
            self.val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True
            )

        # Optimizer
        self.optimizer = AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999)
        )

        # Learning rate scheduler.
        # FIX: scheduler.step() is called once per *optimizer* step (i.e. once
        # every gradient_accumulation_steps batches), so T_max must count
        # optimizer steps, not batches. The previous
        # T_max = num_epochs * len(train_loader) made the cosine anneal run
        # gradient_accumulation_steps-times too slowly and never reach eta_min.
        updates_per_epoch = max(1, len(self.train_loader) // gradient_accumulation_steps)
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=num_epochs * updates_per_epoch,
            eta_min=learning_rate * 0.1
        )

        # Mixed precision scaler
        self.scaler = GradScaler() if mixed_precision else None

        # Diffusion noise scheduler (forward process)
        self.noise_scheduler = DDPMScheduler(num_steps=1000)

        # Training state
        self.global_step = 0  # counts optimizer steps, not micro-batches
        self.epoch = 0
        self.best_val_loss = float('inf')
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""Single training step"""
videos = batch['video'].to(self.device)
text_tokens = batch['text_tokens'].to(self.device)
# Sample random timesteps
timesteps = torch.randint(
0, self.noise_scheduler.num_steps,
(videos.shape[0],),
device=self.device
)
# Add noise to videos
noise = torch.randn_like(videos)
noisy_videos = self.noise_scheduler.add_noise(videos, timesteps, noise)
# Forward pass
if self.mixed_precision:
with autocast():
predicted_noise = self.model(noisy_videos, timesteps, text_tokens)
loss = F.mse_loss(predicted_noise, noise)
loss = loss / self.gradient_accumulation_steps
else:
predicted_noise = self.model(noisy_videos, timesteps, text_tokens)
loss = F.mse_loss(predicted_noise, noise)
loss = loss / self.gradient_accumulation_steps
# Backward pass
if self.mixed_precision:
self.scaler.scale(loss).backward()
else:
loss.backward()
return loss.item() * self.gradient_accumulation_steps
@torch.no_grad()
def validate(self) -> float:
"""Validation loop"""
if self.val_loader is None:
return 0.0
self.model.eval()
total_loss = 0.0
num_batches = 0
for batch in tqdm(self.val_loader, desc="Validating"):
videos = batch['video'].to(self.device)
text_tokens = batch['text_tokens'].to(self.device)
timesteps = torch.randint(
0, self.noise_scheduler.num_steps,
(videos.shape[0],),
device=self.device
)
noise = torch.randn_like(videos)
noisy_videos = self.noise_scheduler.add_noise(videos, timesteps, noise)
predicted_noise = self.model(noisy_videos, timesteps, text_tokens)
loss = F.mse_loss(predicted_noise, noise)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
self.model.train()
return avg_loss
def save_checkpoint(self, suffix: str = ""):
"""Save model checkpoint"""
checkpoint_path = self.save_dir / f"checkpoint_step_{self.global_step}{suffix}.pt"
checkpoint = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'global_step': self.global_step,
'epoch': self.epoch,
'best_val_loss': self.best_val_loss,
}
if self.scaler:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
torch.save(checkpoint, checkpoint_path)
logger.info(f"Saved checkpoint to {checkpoint_path}")
# Save model config
config_path = self.save_dir / "model_config.json"
config = {
'architecture': 'VideoTTV1B',
'parameters': self.model.count_parameters(),
'img_size': self.model.img_size,
'num_frames': self.model.num_frames,
'patch_size': self.model.patch_size,
'hidden_dim': self.model.hidden_dim,
}
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.global_step = checkpoint['global_step']
self.epoch = checkpoint['epoch']
self.best_val_loss = checkpoint['best_val_loss']
if self.scaler and 'scaler_state_dict' in checkpoint:
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
logger.info(f"Loaded checkpoint from {checkpoint_path}")
    def train(self):
        """Main training loop"""
        logger.info("Starting training...")
        logger.info(f"Total parameters: {self.model.count_parameters():,}")
        logger.info(f"Batch size: {self.batch_size}")
        logger.info(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        logger.info(f"Effective batch size: {self.batch_size * self.gradient_accumulation_steps}")

        self.model.train()

        # Resume from self.epoch so a loaded checkpoint continues where it left off.
        for epoch in range(self.epoch, self.num_epochs):
            self.epoch = epoch
            epoch_loss = 0.0
            num_batches = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}")
            for step, batch in enumerate(pbar):
                # train_step backpropagates a loss already divided by the
                # accumulation factor and returns the undivided per-batch value.
                loss = self.train_step(batch)
                epoch_loss += loss
                num_batches += 1

                # Gradient accumulation: step the optimizer only every
                # gradient_accumulation_steps micro-batches.
                # NOTE(review): if len(train_loader) is not a multiple of the
                # accumulation factor, the tail micro-batches of an epoch leave
                # un-stepped gradients that carry into the next epoch's first step.
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    # Clip gradients (unscale first so clipping sees true norms)
                    if self.mixed_precision:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                    # Optimizer step (scaler.step skips the step on inf/NaN grads)
                    if self.mixed_precision:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    # Scheduler advances once per optimizer step, not per batch.
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    # Logging: running mean of the epoch loss so far.
                    if self.global_step % self.log_every == 0:
                        avg_loss = epoch_loss / num_batches
                        lr = self.scheduler.get_last_lr()[0]
                        logger.info(
                            f"Step {self.global_step} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"LR: {lr:.2e}"
                        )

                    # Save checkpoint
                    if self.global_step % self.save_every == 0:
                        self.save_checkpoint()

                # Update progress bar with the most recent batch loss.
                pbar.set_postfix({'loss': f'{loss:.4f}'})

            # Validation at the end of each epoch; track the best model.
            if self.val_loader:
                val_loss = self.validate()
                logger.info(f"Epoch {epoch+1} | Validation loss: {val_loss:.4f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(suffix="_best")

            # Save epoch checkpoint
            self.save_checkpoint(suffix=f"_epoch_{epoch+1}")

        logger.info("Training completed!")
def main():
    """Main training script: builds synthetic data, the model, and a Trainer."""
    # Configuration
    config = {
        'data_dir': './data/videos',
        'annotation_file': './data/annotations.json',
        'batch_size': 2,  # Small batch size for 1B model
        'num_workers': 4,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 100,
        'gradient_accumulation_steps': 8,  # Effective batch size = 16
        'mixed_precision': True,
        'gradient_checkpointing': True,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    logger.info("Configuration:")
    for key, value in config.items():
        logger.info(f" {key}: {value}")

    # Create synthetic dataset for demonstration
    # In production, replace with actual video dataset (VideoTextDataset)
    logger.warning("Using synthetic dataset - replace with real data for training")

    class SyntheticDataset(Dataset):
        """Random tensors shaped like real video/text batches."""
        def __init__(self, size=1000):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return {
                'video': torch.randn(3, 16, 256, 256),
                'text_tokens': torch.randint(0, 50257, (256,)),
                'text': f"Sample video {idx}"
            }

    train_dataset = SyntheticDataset(size=10000)
    val_dataset = SyntheticDataset(size=1000)

    # Create model on the configured device
    from video_ttv_1b import create_model
    model = create_model(config['device'])

    # Create trainer.
    # FIX: previously 'device' was filtered out of the kwargs, so Trainer fell
    # back to its 'cuda' default even on CPU-only machines while the model was
    # created on config['device']. Pass the configured device through.
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        **{k: v for k, v in config.items() if k not in ['data_dir', 'annotation_file']}
    )

    # Train
    trainer.train()


if __name__ == "__main__":
    main()