| | """ |
| | Utility functions for TTV-1B model |
| | Data preprocessing, video I/O, and helper functions |
| | """ |
| |
|
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_video_frames(
    video_path: str,
    num_frames: int = 16,
    target_size: Tuple[int, int] = (256, 256),
) -> torch.Tensor:
    """
    Load a video file and extract uniformly spaced frames.

    Tries torchvision first; falls back to OpenCV if torchvision is not
    installed.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frames to extract (uniformly sampled).
        target_size: Target resolution as (H, W).

    Returns:
        Video tensor (C, T, H, W), float, normalized to [-1, 1].
    """
    try:
        from torchvision.io import read_video

        video, _, _ = read_video(video_path, pts_unit='sec')
        # (T, H, W, C) -> (C, T, H, W)
        video = video.permute(3, 0, 1, 2)

        # Uniformly sample the requested number of frames.
        total_frames = video.shape[1]
        indices = torch.linspace(0, total_frames - 1, num_frames).long()
        video = video[:, indices]

        # F.interpolate with mode='trilinear' requires a 5D (N, C, D, H, W)
        # input; add a batch dimension for the resize and drop it after.
        import torch.nn.functional as F
        video = F.interpolate(
            video.float().unsqueeze(0),
            size=(num_frames, *target_size),
            mode='trilinear',
            align_corners=False,
        ).squeeze(0)

        # [0, 255] -> [-1, 1]
        video = video / 127.5 - 1.0

        return video

    except ImportError:
        # Fallback: frame-by-frame extraction with OpenCV.
        import cv2

        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Uniformly sample the requested number of frame indices.
            indices = np.linspace(0, total_frames - 1, num_frames).astype(int)

            frames = []
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    # cv2.resize expects (width, height); target_size is (H, W).
                    frame = cv2.resize(frame, (target_size[1], target_size[0]))
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(frame)
        finally:
            # Always release the capture handle, even on a read error.
            cap.release()

        # (T, H, W, C) -> (C, T, H, W)
        video = np.stack(frames, axis=0)
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()

        # [0, 255] -> [-1, 1]
        video = video / 127.5 - 1.0

        return video
| |
|
| |
|
def save_video_frames(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 8,
    codec: str = 'libx264',
):
    """
    Write a video tensor to a file on disk.

    Args:
        frames: Video tensor, (C, T, H, W) or (T, H, W, C), with values in
            [-1, 1] or [0, 1].
        output_path: Destination file path.
        fps: Frames per second.
        codec: Video codec (used by the torchvision backend only).
    """
    # Bring values into [0, 1]: inputs containing negatives are treated as
    # [-1, 1] and remapped; everything is clamped afterwards.
    vid = frames if frames.min() >= 0 else (frames + 1) / 2
    vid = torch.clamp(vid, 0, 1)

    # Channel-first (C, T, H, W) input is converted to (T, H, W, C).
    if vid.shape[0] == 3:
        vid = vid.permute(1, 2, 3, 0)

    # Quantize to uint8 on the CPU for the writers below.
    vid = (vid * 255).to(torch.uint8).cpu()

    try:
        from torchvision.io import write_video
        write_video(output_path, vid, fps=fps, video_codec=codec)
        print(f"Video saved to {output_path}")

    except ImportError:
        # torchvision unavailable: fall back to OpenCV's writer.
        import cv2

        height, width = vid.shape[1:3]
        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
        )
        for frame in vid:
            writer.write(cv2.cvtColor(frame.numpy(), cv2.COLOR_RGB2BGR))
        writer.release()
        print(f"Video saved to {output_path}")
| |
|
| |
|
def create_video_grid(
    videos: List[torch.Tensor],
    grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """
    Arrange several videos into a single grid video for comparison.

    Args:
        videos: List of video tensors, each (C, T, H, W) with equal shapes.
        grid_size: (rows, cols). If None, a near-square grid is chosen.

    Returns:
        Grid video tensor (C, T, rows * H, cols * W). Empty grid cells are
        filled with black (zeros).
    """
    n_videos = len(videos)

    if grid_size is None:
        # Near-square layout: cols = ceil(sqrt(n)), rows fills the rest.
        cols = int(np.ceil(np.sqrt(n_videos)))
        rows = int(np.ceil(n_videos / cols))
    else:
        rows, cols = grid_size

    # Work on a copy so the caller's list is not mutated by the padding below.
    videos = list(videos)
    while len(videos) < rows * cols:
        videos.append(torch.zeros_like(videos[0]))

    # Concatenate each row along width, then stack rows along height.
    grid_rows = []
    for i in range(rows):
        row_videos = videos[i * cols:(i + 1) * cols]
        grid_rows.append(torch.cat(row_videos, dim=-1))

    return torch.cat(grid_rows, dim=-2)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SimpleTokenizer:
    """Character-level stand-in tokenizer (swap for a real tokenizer in production)."""

    def __init__(self, vocab_size: int = 50257):
        self.vocab_size = vocab_size

    def encode(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Encode text into a fixed-length LongTensor of token IDs, zero-padded."""
        ids = [ord(ch) % self.vocab_size for ch in text[:max_length]]
        ids.extend([0] * (max_length - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, tokens: torch.Tensor) -> str:
        """Decode token IDs back into text, skipping padding (ID 0)."""
        return ''.join(chr(tid) for tid in tokens.tolist() if tid != 0)

    def batch_encode(self, texts: List[str], max_length: int = 256) -> torch.Tensor:
        """Encode a batch of strings into a (B, max_length) LongTensor."""
        return torch.stack([self.encode(t, max_length) for t in texts])
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_dataset_split(
    annotation_file: str,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> Tuple[Dict, Dict]:
    """
    Split an annotation file into train and validation dictionaries.

    Args:
        annotation_file: Path to annotations JSON (a dict keyed by sample ID).
        train_ratio: Fraction of samples assigned to the training split.
        seed: Random seed; the split is deterministic for a given seed.

    Returns:
        (train_annotations, val_annotations) dictionaries.
    """
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)

    # Use a local RandomState instead of np.random.seed so the process-wide
    # RNG state is not clobbered; the shuffle sequence is identical.
    keys = list(annotations.keys())
    rng = np.random.RandomState(seed)
    rng.shuffle(keys)

    split_idx = int(len(keys) * train_ratio)
    train_keys = keys[:split_idx]
    val_keys = keys[split_idx:]

    train_annotations = {k: annotations[k] for k in train_keys}
    val_annotations = {k: annotations[k] for k in val_keys}

    return train_annotations, val_annotations
| |
|
| |
|
def validate_dataset(video_dir: str, annotation_file: str) -> Dict[str, Any]:
    """
    Validate dataset integrity: every annotated video must exist on disk and
    have a non-empty caption.

    Args:
        video_dir: Directory containing `<video_id>.mp4` files.
        annotation_file: Path to annotations JSON (dict keyed by video ID).

    Returns:
        Dict with keys 'total_videos', 'missing_videos', 'invalid_captions',
        'warnings', and 'valid' (True when nothing is missing or invalid).
    """
    video_dir = Path(video_dir)

    with open(annotation_file, 'r') as f:
        annotations = json.load(f)

    results = {
        'total_videos': len(annotations),
        'missing_videos': [],
        'invalid_captions': [],
        'warnings': [],
    }

    for video_id, data in annotations.items():
        # Video file must exist next to its annotation.
        video_path = video_dir / f"{video_id}.mp4"
        if not video_path.exists():
            results['missing_videos'].append(video_id)

        # Caption must be present and non-blank.
        if 'caption' not in data or not data['caption'].strip():
            results['invalid_captions'].append(video_id)

        # Overlong captions are allowed but flagged as a warning.
        if len(data.get('caption', '')) > 256:
            results['warnings'].append(f"{video_id}: Caption too long")

    # Warnings alone do not invalidate the dataset.
    results['valid'] = (
        len(results['missing_videos']) == 0 and
        len(results['invalid_captions']) == 0
    )

    return results
| |
|
| |
|
| | |
| | |
| | |
| |
|
def count_model_parameters(model: torch.nn.Module) -> Dict[str, int]:
    """Return total, trainable, and non-trainable parameter counts for a model."""
    total = 0
    trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n

    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable,
    }
| |
|
| |
|
def load_checkpoint_safe(
    model: torch.nn.Module,
    checkpoint_path: str,
    strict: bool = True,
) -> Dict[str, Any]:
    """
    Load a checkpoint into `model`, returning a result dict instead of raising.

    Args:
        model: Model to load the weights into (mutated in place on success).
        checkpoint_path: Path to a checkpoint saved with torch.save; either a
            bare state_dict or a dict containing 'model_state_dict'.
        strict: Passed through to `load_state_dict`.

    Returns:
        On success: {'success': True, 'step': ..., 'epoch': ...} (step/epoch
        default to -1 when absent from the checkpoint).
        On failure: {'success': False, 'error': <message>}.
    """
    try:
        # NOTE(review): torch.load unpickles arbitrary objects unless
        # weights_only is enabled — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Support both full training checkpoints and bare state_dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        else:
            model.load_state_dict(checkpoint, strict=strict)

        return {
            'success': True,
            'step': checkpoint.get('global_step', -1),
            'epoch': checkpoint.get('epoch', -1),
        }

    except Exception as e:
        # Deliberately broad: this helper's contract is "never raise".
        return {
            'success': False,
            'error': str(e),
        }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_comparison_video(
    original: torch.Tensor,
    generated: torch.Tensor,
    prompt: str,
    output_path: str,
):
    """
    Save original and generated videos side by side for visual comparison.

    Args:
        original: Original video (C, T, H, W).
        generated: Generated video (C, T, H, W); assumed to match `original`
            in C, T, and H — TODO confirm callers guarantee this.
        prompt: Text prompt used for generation (printed for reference).
        output_path: Destination video file.
    """
    # Concatenate along the width axis: original on the left.
    side_by_side = torch.cat([original, generated], dim=-1)
    save_video_frames(side_by_side, output_path)

    print(f"Comparison video saved to {output_path}")
    print(f"Prompt: {prompt}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TrainingLogger:
    """Minimal training logger: in-memory metric history plus an append-only CSV-style log file."""

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_file = self.log_dir / 'training.log'

        # Parallel lists, one per metric, indexed by logging call order.
        self.metrics = {'step': [], 'loss': [], 'lr': []}

    def log(self, step: int, loss: float, lr: float):
        """Record one training step both in memory and on disk."""
        for key, value in (('step', step), ('loss', loss), ('lr', lr)):
            self.metrics[key].append(value)

        # Append immediately so a crash doesn't lose the history.
        with open(self.log_file, 'a') as f:
            f.write(f"{step},{loss},{lr}\n")

    def save_metrics(self):
        """Dump the in-memory metric history to metrics.json."""
        with open(self.log_dir / 'metrics.json', 'w') as f:
            json.dump(self.metrics, f, indent=2)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def test_video_pipeline():
    """Smoke-test the save/load round trip with a random video."""
    print("Testing video pipeline...")

    # Random video, min-max scaled into [0, 1].
    video = torch.randn(3, 16, 256, 256)
    lo, hi = video.min(), video.max()
    video = (video - lo) / (hi - lo)

    # Round trip through disk.
    save_path = "test_video.mp4"
    save_video_frames(video, save_path)
    loaded = load_video_frames(save_path, num_frames=16)

    print(f"Original shape: {video.shape}")
    print(f"Loaded shape: {loaded.shape}")
    print("✓ Video pipeline test passed")
| |
|
| |
|
def test_tokenizer():
    """Smoke-test SimpleTokenizer's encode/decode round trip."""
    print("Testing tokenizer...")

    tok = SimpleTokenizer()
    text = "A beautiful sunset over the ocean"
    tokens = tok.encode(text, max_length=128)
    decoded = tok.decode(tokens)

    print(f"Original: {text}")
    print(f"Tokens shape: {tokens.shape}")
    print(f"Decoded: {decoded[:len(text)]}")
    print("✓ Tokenizer test passed")
| |
|
| |
|
if __name__ == "__main__":
    # Only the tokenizer test is dependency-free; the video pipeline test
    # additionally needs torchvision or OpenCV.
    print("Running utility tests...\n")
    test_tokenizer()
    print("\n" + "=" * 60 + "\n")
    print("Note: Video pipeline test requires torchvision or opencv")
    print("Run after installing dependencies")
| |
|