""" Utility functions for TTV-1B model Data preprocessing, video I/O, and helper functions """ import torch import numpy as np from pathlib import Path from typing import Optional, List, Tuple, Dict import json # ============================================================================ # Video Processing Utilities # ============================================================================ def load_video_frames( video_path: str, num_frames: int = 16, target_size: Tuple[int, int] = (256, 256), ) -> torch.Tensor: """ Load video and extract frames Args: video_path: Path to video file num_frames: Number of frames to extract target_size: Target resolution (H, W) Returns: Video tensor (C, T, H, W) normalized to [-1, 1] """ try: # Try using torchvision from torchvision.io import read_video video, _, _ = read_video(video_path, pts_unit='sec') video = video.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) # Sample frames uniformly total_frames = video.shape[1] indices = torch.linspace(0, total_frames - 1, num_frames).long() video = video[:, indices] # Resize import torch.nn.functional as F video = F.interpolate( video.float(), size=(num_frames, *target_size), mode='trilinear', align_corners=False ) # Normalize to [-1, 1] video = video / 127.5 - 1.0 return video except ImportError: # Fallback to opencv import cv2 cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Calculate frame indices to sample indices = np.linspace(0, total_frames - 1, num_frames).astype(int) frames = [] for idx in indices: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: # Resize and convert BGR to RGB frame = cv2.resize(frame, target_size) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame) cap.release() # Convert to tensor video = np.stack(frames, axis=0) # (T, H, W, C) video = torch.from_numpy(video).permute(3, 0, 1, 2).float() # (C, T, H, W) # Normalize to [-1, 1] video = video / 127.5 - 1.0 return video def save_video_frames( frames: torch.Tensor, output_path: str, fps: int = 8, codec: str = 'libx264', ): """ Save video tensor to file Args: frames: Video tensor (C, T, H, W) or (T, H, W, C) in range [-1, 1] or [0, 1] output_path: Output file path fps: Frames per second codec: Video codec """ # Ensure frames are in [0, 1] range if frames.min() < 0: frames = (frames + 1) / 2 # [-1, 1] -> [0, 1] frames = torch.clamp(frames, 0, 1) # Convert to (T, H, W, C) format if frames.shape[0] == 3: # (C, T, H, W) frames = frames.permute(1, 2, 3, 0) # Scale to [0, 255] frames = (frames * 255).to(torch.uint8).cpu() try: from torchvision.io import write_video write_video(output_path, frames, fps=fps, video_codec=codec) print(f"Video saved to {output_path}") except ImportError: # Fallback to opencv import cv2 height, width = frames.shape[1:3] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for frame in frames: frame_bgr = cv2.cvtColor(frame.numpy(), cv2.COLOR_RGB2BGR) out.write(frame_bgr) out.release() print(f"Video saved to {output_path}") def create_video_grid( videos: List[torch.Tensor], grid_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: """ Create a grid of videos for comparison Args: videos: List of video tensors (C, T, H, W) grid_size: (rows, cols). If None, automatically determined Returns: Grid video tensor (C, T, H_grid, W_grid) """ n_videos = len(videos) if grid_size is None: cols = int(np.ceil(np.sqrt(n_videos))) rows = int(np.ceil(n_videos / cols)) else: rows, cols = grid_size C, T, H, W = videos[0].shape # Pad with blank videos if needed while len(videos) < rows * cols: videos.append(torch.zeros_like(videos[0])) # Arrange in grid grid_rows = [] for i in range(rows): row_videos = videos[i * cols:(i + 1) * cols] row = torch.cat(row_videos, dim=-1) # Concatenate along width grid_rows.append(row) grid = torch.cat(grid_rows, dim=-2) # Concatenate along height return grid # ============================================================================ # Text Processing Utilities # ============================================================================ class SimpleTokenizer: """Simple character-level tokenizer (replace with proper tokenizer in production)""" def __init__(self, vocab_size: int = 50257): self.vocab_size = vocab_size def encode(self, text: str, max_length: int = 256) -> torch.Tensor: """Encode text to token IDs""" # Simple character-level encoding tokens = [ord(c) % self.vocab_size for c in text[:max_length]] # Pad to max length tokens = tokens + [0] * (max_length - len(tokens)) return torch.tensor(tokens, dtype=torch.long) def decode(self, tokens: torch.Tensor) -> str: """Decode token IDs to text""" chars = [chr(t.item()) for t in tokens if t.item() != 0] return ''.join(chars) def batch_encode(self, texts: List[str], max_length: int = 256) -> torch.Tensor: """Encode batch of texts""" return torch.stack([self.encode(text, max_length) for text in texts]) # ============================================================================ # Dataset Utilities # ============================================================================ def create_dataset_split( annotation_file: str, train_ratio: float = 0.9, seed: int = 42, ) -> Tuple[Dict, Dict]: """ Split dataset into train and validation sets Args: annotation_file: Path to annotations JSON train_ratio: Ratio of training data seed: Random seed Returns: train_annotations, val_annotations """ with open(annotation_file, 'r') as f: annotations = json.load(f) # Shuffle keys keys = list(annotations.keys()) np.random.seed(seed) np.random.shuffle(keys) # Split split_idx = int(len(keys) * train_ratio) train_keys = keys[:split_idx] val_keys = keys[split_idx:] train_annotations = {k: annotations[k] for k in train_keys} val_annotations = {k: annotations[k] for k in val_keys} return train_annotations, val_annotations def validate_dataset(video_dir: str, annotation_file: str) -> Dict[str, any]: """ Validate dataset integrity Returns: Dictionary with validation results """ video_dir = Path(video_dir) with open(annotation_file, 'r') as f: annotations = json.load(f) results = { 'total_videos': len(annotations), 'missing_videos': [], 'invalid_captions': [], 'warnings': [], } for video_id, data in annotations.items(): # Check video file exists video_path = video_dir / f"{video_id}.mp4" if not video_path.exists(): results['missing_videos'].append(video_id) # Check caption if 'caption' not in data or not data['caption'].strip(): results['invalid_captions'].append(video_id) # Check caption length if len(data.get('caption', '')) > 256: results['warnings'].append(f"{video_id}: Caption too long") results['valid'] = ( len(results['missing_videos']) == 0 and len(results['invalid_captions']) == 0 ) return results # ============================================================================ # Model Utilities # ============================================================================ def count_model_parameters(model: torch.nn.Module) -> Dict[str, int]: """Count model parameters""" total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) return { 'total': total_params, 'trainable': trainable_params, 'non_trainable': total_params - trainable_params, } def load_checkpoint_safe( model: torch.nn.Module, checkpoint_path: str, strict: bool = True, ) -> Dict[str, any]: """ Safely load checkpoint with error handling Returns: Dictionary with loading results """ try: checkpoint = torch.load(checkpoint_path, map_location='cpu') # Load model state if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict'], strict=strict) else: model.load_state_dict(checkpoint, strict=strict) return { 'success': True, 'step': checkpoint.get('global_step', -1), 'epoch': checkpoint.get('epoch', -1), } except Exception as e: return { 'success': False, 'error': str(e), } # ============================================================================ # Visualization Utilities # ============================================================================ def create_comparison_video( original: torch.Tensor, generated: torch.Tensor, prompt: str, output_path: str, ): """ Create side-by-side comparison video Args: original: Original video (C, T, H, W) generated: Generated video (C, T, H, W) prompt: Text prompt output_path: Where to save """ # Concatenate videos horizontally combined = torch.cat([original, generated], dim=-1) save_video_frames(combined, output_path) print(f"Comparison video saved to {output_path}") print(f"Prompt: {prompt}") # ============================================================================ # Logging Utilities # ============================================================================ class TrainingLogger: """Simple training logger""" def __init__(self, log_dir: str): self.log_dir = Path(log_dir) self.log_dir.mkdir(parents=True, exist_ok=True) self.log_file = self.log_dir / 'training.log' self.metrics = { 'step': [], 'loss': [], 'lr': [], } def log(self, step: int, loss: float, lr: float): """Log training metrics""" self.metrics['step'].append(step) self.metrics['loss'].append(loss) self.metrics['lr'].append(lr) # Write to file with open(self.log_file, 'a') as f: f.write(f"{step},{loss},{lr}\n") def save_metrics(self): """Save metrics to JSON""" output_file = self.log_dir / 'metrics.json' with open(output_file, 'w') as f: json.dump(self.metrics, f, indent=2) # ============================================================================ # Testing Utilities # ============================================================================ def test_video_pipeline(): """Test video loading and saving pipeline""" print("Testing video pipeline...") # Create dummy video video = torch.randn(3, 16, 256, 256) video = (video - video.min()) / (video.max() - video.min()) # Save output_path = "test_video.mp4" save_video_frames(video, output_path) # Load loaded = load_video_frames(output_path, num_frames=16) print(f"Original shape: {video.shape}") print(f"Loaded shape: {loaded.shape}") print("✓ Video pipeline test passed") def test_tokenizer(): """Test tokenizer""" print("Testing tokenizer...") tokenizer = SimpleTokenizer() text = "A beautiful sunset over the ocean" tokens = tokenizer.encode(text, max_length=128) decoded = tokenizer.decode(tokens) print(f"Original: {text}") print(f"Tokens shape: {tokens.shape}") print(f"Decoded: {decoded[:len(text)]}") print("✓ Tokenizer test passed") if __name__ == "__main__": print("Running utility tests...\n") test_tokenizer() print("\n" + "="*60 + "\n") print("Note: Video pipeline test requires torchvision or opencv") print("Run after installing dependencies")