| | """ |
| | Training script for TTV-1B Text-to-Video Model |
| | Supports distributed training, mixed precision, and gradient checkpointing |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.cuda.amp import autocast, GradScaler |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import CosineAnnealingLR |
| | import os |
| | import json |
| | from pathlib import Path |
| | from tqdm import tqdm |
| | import numpy as np |
| | from typing import Dict, List, Optional |
| | import logging |
| |
|
| | from video_ttv_1b import VideoTTV1B, DDPMScheduler |
| |
|
| |
|
| | |
# Configure process-wide logging once at import time: INFO level, with a
# timestamp and severity in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
| |
|
| |
|
class VideoTextDataset(Dataset):
    """Dataset of video/caption pairs described by a JSON annotation file.

    The annotation file is loaded as a mapping whose keys are video ids and
    whose values are entries containing a ``'caption'`` string. The path of
    each video is built as ``<video_dir>/<video_id>.mp4``.

    NOTE: ``load_video`` is still a placeholder that returns random data and
    never reads the file at that path.
    """

    def __init__(self, video_dir: str, annotation_file: str,
                 num_frames: int = 16, img_size: tuple = (256, 256)):
        """Read the annotation file and index the available video ids.

        Args:
            video_dir: Directory used to build the per-video ``.mp4`` paths.
            annotation_file: Path to the JSON annotation file.
            num_frames: Number of frames in every returned video tensor.
            img_size: Two trailing dimensions of every returned video tensor.
        """
        self.video_dir = Path(video_dir)
        self.num_frames = num_frames
        self.img_size = img_size

        # JSON is UTF-8 by specification; being explicit keeps non-ASCII
        # captions from depending on the platform's locale encoding.
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        self.video_ids = list(self.annotations.keys())
        logger.info(f"Loaded {len(self.video_ids)} video-text pairs")

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_ids)

    def tokenize(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Simple character-level tokenization (replace with proper tokenizer).

        Each of the first ``max_length`` characters becomes its code point
        modulo 50257; shorter texts are right-padded with 0.

        Returns:
            A ``torch.long`` tensor of exactly ``max_length`` token ids.
        """
        tokens = [ord(c) % 50257 for c in text[:max_length]]
        tokens += [0] * (max_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def load_video(self, video_path: Path) -> torch.Tensor:
        """Load and preprocess video (placeholder - implement with actual video loading).

        ``video_path`` is currently ignored: a random tensor of shape
        ``(3, num_frames, *img_size)`` is generated and min-max rescaled to
        the range [-1, 1].
        """
        video = torch.randn(3, self.num_frames, *self.img_size)

        lowest = video.min()
        # Guard the divisor so a constant-valued clip cannot yield NaN/inf.
        value_range = (video.max() - lowest).clamp_min(1e-8)
        return (video - lowest) / value_range * 2 - 1

    def __getitem__(self, idx: int):
        """Return the sample at ``idx``.

        Returns:
            A dict with ``'video'`` (output of ``load_video``),
            ``'text_tokens'`` (output of ``tokenize`` on the caption) and
            ``'text'`` (the raw caption string).
        """
        video_id = self.video_ids[idx]
        annotation = self.annotations[video_id]

        video_path = self.video_dir / f"{video_id}.mp4"
        video = self.load_video(video_path)

        text = annotation['caption']
        text_tokens = self.tokenize(text)

        return {
            'video': video,
            'text_tokens': text_tokens,
            'text': text,
        }
| |
|
| |
|
class Trainer:
    """Trainer class for TTV-1B model.

    Builds the data loaders, an AdamW optimizer, a cosine learning-rate
    schedule, an optional CUDA gradient scaler and a DDPM noise scheduler,
    and runs a noise-prediction (MSE) training loop with gradient
    accumulation, gradient clipping, periodic checkpointing and per-epoch
    validation.
    """

    def __init__(
        self,
        model: nn.Module,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        batch_size: int = 4,
        num_workers: int = 4,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        gradient_accumulation_steps: int = 4,
        mixed_precision: bool = True,
        gradient_checkpointing: bool = True,
        save_dir: str = './checkpoints',
        log_every: int = 100,
        save_every: int = 5000,
        device: str = 'cuda',
    ):
        """Set up loaders, optimizer, schedulers and bookkeeping state.

        Args:
            model: Network called as ``model(noisy_videos, timesteps, text_tokens)``.
            train_dataset: Dataset yielding dicts with ``'video'`` and ``'text_tokens'``.
            val_dataset: Optional dataset with the same sample layout.
            batch_size: Per-step batch size for both loaders.
            num_workers: Worker processes for both loaders.
            learning_rate: Initial AdamW learning rate; the cosine schedule
                decays towards 10% of it.
            weight_decay: AdamW weight decay.
            num_epochs: Number of passes over ``train_dataset``.
            gradient_accumulation_steps: Micro-batches per optimizer update.
            mixed_precision: Use autocast + GradScaler (only honoured on CUDA).
            gradient_checkpointing: Currently only logged; nothing is enabled here.
            save_dir: Directory for checkpoints (created if missing).
            log_every: Log the running loss every this many optimizer updates.
            save_every: Save a checkpoint every this many optimizer updates.
            device: Device that batches are moved to.

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
            )

        on_cuda = torch.device(device).type == 'cuda'
        if mixed_precision and not on_cuda:
            # The imported autocast/GradScaler are the CUDA variants; on any
            # other device they only warn and disable themselves.
            logger.warning(
                f"Mixed precision requires a CUDA device; disabling it for device '{device}'"
            )
            mixed_precision = False

        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.mixed_precision = mixed_precision
        self.log_every = log_every
        self.save_every = save_every
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        if gradient_checkpointing:
            # NOTE(review): this flag is only reported; no checkpointing hook
            # is called on the model here. Presumably the model must
            # implement it itself -- confirm against video_ttv_1b.
            logger.info("Enabling gradient checkpointing")

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=on_cuda,  # pinning only helps host-to-GPU copies
            drop_last=True,
        )

        self.val_loader = None
        # Explicit None check: a Dataset that defines __len__ is falsy when empty.
        if val_dataset is not None:
            self.val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=on_cuda,
            )

        self.optimizer = AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
        )

        # The LR scheduler is stepped once per optimizer update (not once per
        # micro-batch), so the cosine period must be counted in updates.
        updates_per_epoch = self._updates_per_epoch(
            len(self.train_loader), gradient_accumulation_steps
        )
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, num_epochs * updates_per_epoch),
            eta_min=learning_rate * 0.1,
        )

        self.scaler = GradScaler() if mixed_precision else None

        self.noise_scheduler = DDPMScheduler(num_steps=1000)

        # Progress state; all three values are stored in checkpoints.
        self.global_step = 0  # number of optimizer updates performed
        self.epoch = 0  # index of the next epoch that train() will run
        self.best_val_loss = float('inf')

    @staticmethod
    def _updates_per_epoch(num_batches: int, accumulation_steps: int) -> int:
        """Return how many optimizer updates one epoch performs (ceil division)."""
        return (num_batches + accumulation_steps - 1) // accumulation_steps

    def _make_noisy_batch(self, batch: Dict[str, torch.Tensor]):
        """Move a batch to the device and corrupt its videos with random noise.

        Returns:
            ``(noisy_videos, timesteps, text_tokens, noise)`` where one
            timestep per sample is drawn uniformly from
            ``[0, noise_scheduler.num_steps)``.
        """
        videos = batch['video'].to(self.device)
        text_tokens = batch['text_tokens'].to(self.device)

        timesteps = torch.randint(
            0, self.noise_scheduler.num_steps,
            (videos.shape[0],),
            device=self.device,
        )

        noise = torch.randn_like(videos)
        noisy_videos = self.noise_scheduler.add_noise(videos, timesteps, noise)
        return noisy_videos, timesteps, text_tokens, noise

    def _noise_prediction_loss(self, noisy_videos, timesteps, text_tokens, noise):
        """Return the MSE between the model's output and the injected noise."""
        predicted_noise = self.model(noisy_videos, timesteps, text_tokens)
        return F.mse_loss(predicted_noise, noise)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """Single training step.

        Runs forward and backward for one micro-batch. The loss is divided by
        ``gradient_accumulation_steps`` before ``backward()`` so gradients
        summed over a full window equal the gradient of the window's mean
        loss. The optimizer is not stepped here.

        Returns:
            The undivided loss of this micro-batch as a Python float.
        """
        noisy_videos, timesteps, text_tokens, noise = self._make_noisy_batch(batch)

        if self.mixed_precision:
            with autocast():
                loss = self._noise_prediction_loss(noisy_videos, timesteps, text_tokens, noise)
        else:
            loss = self._noise_prediction_loss(noisy_videos, timesteps, text_tokens, noise)
        loss = loss / self.gradient_accumulation_steps

        if self.mixed_precision:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        return loss.item() * self.gradient_accumulation_steps

    def _optimizer_update(self):
        """Clip gradients, apply one optimizer/scheduler step and reset gradients."""
        if self.mixed_precision:
            # Unscale first so the clipping threshold applies to the real
            # gradient magnitudes rather than the loss-scaled ones.
            self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        if self.mixed_precision:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    @torch.no_grad()
    def validate(self) -> float:
        """Validation loop.

        Returns:
            The mean per-batch loss over the validation loader; ``0.0`` when
            no validation loader exists; ``inf`` when the loader yields no
            batches (so such a run can never be recorded as the best).
        """
        if self.val_loader is None:
            return 0.0

        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        try:
            for batch in tqdm(self.val_loader, desc="Validating"):
                noisy_videos, timesteps, text_tokens, noise = self._make_noisy_batch(batch)
                loss = self._noise_prediction_loss(noisy_videos, timesteps, text_tokens, noise)

                total_loss += loss.item()
                num_batches += 1
        finally:
            # Put the model back into the mode it was in, even on error.
            self.model.train(was_training)

        if num_batches == 0:
            logger.warning("Validation loader produced no batches")
            return float('inf')

        return total_loss / num_batches

    def save_checkpoint(self, suffix: str = ""):
        """Save model checkpoint.

        Writes ``checkpoint_step_<global_step><suffix>.pt`` containing the
        model, optimizer, scheduler (and scaler, if any) state plus the
        progress counters, and rewrites ``model_config.json`` with attributes
        read from the model.
        """
        checkpoint_path = self.save_dir / f"checkpoint_step_{self.global_step}{suffix}.pt"

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'epoch': self.epoch,
            'best_val_loss': self.best_val_loss,
        }

        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")

        config_path = self.save_dir / "model_config.json"
        config = {
            'architecture': 'VideoTTV1B',
            'parameters': self.model.count_parameters(),
            'img_size': self.model.img_size,
            'num_frames': self.model.num_frames,
            'patch_size': self.model.patch_size,
            'hidden_dim': self.model.hidden_dim,
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint.

        Restores model, optimizer, scheduler and (when both sides have one)
        scaler state, together with ``global_step``, ``epoch`` and
        ``best_val_loss``, from a file written by ``save_checkpoint``.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']

        if self.scaler and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        logger.info(f"Loaded checkpoint from {checkpoint_path}")

    def train(self):
        """Main training loop.

        Runs epochs ``self.epoch`` .. ``num_epochs - 1``. Within an epoch the
        optimizer is updated after every ``gradient_accumulation_steps``
        micro-batches and once more on the final batch, so gradients of a
        trailing partial window are applied instead of leaking into the next
        epoch. After each epoch the model is validated (if a validation
        loader exists), the best model is checkpointed, and an epoch
        checkpoint is written.

        Checkpoints taken mid-epoch record the current epoch, so resuming
        from one restarts that epoch from its first batch.
        """
        logger.info("Starting training...")
        logger.info(f"Total parameters: {self.model.count_parameters():,}")
        logger.info(f"Batch size: {self.batch_size}")
        logger.info(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        logger.info(f"Effective batch size: {self.batch_size * self.gradient_accumulation_steps}")

        self.model.train()
        batches_per_epoch = len(self.train_loader)

        for epoch in range(self.epoch, self.num_epochs):
            self.epoch = epoch
            epoch_loss = 0.0
            num_batches = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}")

            for step, batch in enumerate(pbar):
                loss = self.train_step(batch)
                epoch_loss += loss
                num_batches += 1

                window_full = (step + 1) % self.gradient_accumulation_steps == 0
                last_batch = (step + 1) == batches_per_epoch
                if window_full or last_batch:
                    self._optimizer_update()

                    if self.global_step % self.log_every == 0:
                        avg_loss = epoch_loss / num_batches
                        lr = self.scheduler.get_last_lr()[0]
                        logger.info(
                            f"Step {self.global_step} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"LR: {lr:.2e}"
                        )

                    if self.global_step % self.save_every == 0:
                        self.save_checkpoint()

                pbar.set_postfix({'loss': f'{loss:.4f}'})

            # The epoch is complete: checkpoints written from here on must
            # resume at the following epoch instead of repeating this one.
            self.epoch = epoch + 1

            if self.val_loader is not None:
                val_loss = self.validate()
                logger.info(f"Epoch {epoch+1} | Validation loss: {val_loss:.4f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(suffix="_best")

            self.save_checkpoint(suffix=f"_epoch_{epoch+1}")

        logger.info("Training completed!")
| |
|
| |
|
class SyntheticDataset(Dataset):
    """Random stand-in dataset used by ``main`` until real data is wired in.

    Defined at module level (not inside ``main``) so DataLoader worker
    processes can pickle it when the multiprocessing start method is
    ``spawn``.
    """

    def __init__(self, size=1000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Every access draws fresh random data; idx only appears in the text.
        return {
            'video': torch.randn(3, 16, 256, 256),
            'text_tokens': torch.randint(0, 50257, (256,)),
            'text': f"Sample video {idx}",
        }


def main():
    """Main training script.

    Logs the hard-coded configuration, builds synthetic train/validation
    datasets, creates the model on the configured device and runs the
    ``Trainer``.
    """
    config = {
        'data_dir': './data/videos',
        'annotation_file': './data/annotations.json',
        'batch_size': 2,
        'num_workers': 4,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 100,
        'gradient_accumulation_steps': 8,
        'mixed_precision': True,
        'gradient_checkpointing': True,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    logger.info("Configuration:")
    for key, value in config.items():
        logger.info(f" {key}: {value}")

    logger.warning("Using synthetic dataset - replace with real data for training")

    train_dataset = SyntheticDataset(size=10000)
    val_dataset = SyntheticDataset(size=1000)

    from video_ttv_1b import create_model
    model = create_model(config['device'])

    # Only the dataset-location keys are not Trainer arguments. 'device' must
    # be forwarded: otherwise the trainer falls back to its 'cuda' default
    # even when the model was created on the CPU.
    dataset_only_keys = ('data_dir', 'annotation_file')
    trainer_kwargs = {k: v for k, v in config.items() if k not in dataset_only_keys}

    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        **trainer_kwargs,
    )

    trainer.train()
| |
|
| |
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|