# Source: Zenderos/utils.py — Hugging Face upload "Upload 11 files", revision 3d8856d (verified)
"""
Utility functions for TTV-1B model
Data preprocessing, video I/O, and helper functions
"""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
# ============================================================================
# Video Processing Utilities
# ============================================================================
def load_video_frames(
    video_path: str,
    num_frames: int = 16,
    target_size: Tuple[int, int] = (256, 256),
) -> torch.Tensor:
    """
    Load a video file and extract uniformly sampled, resized frames.

    Args:
        video_path: Path to video file.
        num_frames: Number of frames to extract.
        target_size: Target resolution (H, W).

    Returns:
        Video tensor (C, T, H, W) normalized to [-1, 1].

    Raises:
        ValueError: If no frames could be read from the video.
    """
    try:
        # Preferred path: torchvision
        from torchvision.io import read_video
        import torch.nn.functional as F

        video, _, _ = read_video(video_path, pts_unit='sec')
        if video.shape[0] == 0:
            raise ValueError(f"No frames read from {video_path}")
        video = video.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)

        # Sample frames uniformly across the clip
        total_frames = video.shape[1]
        indices = torch.linspace(0, total_frames - 1, num_frames).long()
        video = video[:, indices]

        # BUGFIX: trilinear interpolation requires a 5D (N, C, T, H, W) input;
        # wrap the 4D tensor in a batch dimension for the resize.
        video = F.interpolate(
            video.float().unsqueeze(0),
            size=(num_frames, *target_size),
            mode='trilinear',
            align_corners=False,
        ).squeeze(0)

        # Map uint8 range [0, 255] to [-1, 1]
        video = video / 127.5 - 1.0
        return video
    except ImportError:
        # Fallback path: opencv
        import cv2
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Frame indices to sample, uniformly spread over the clip
            indices = np.linspace(0, max(total_frames - 1, 0), num_frames).astype(int)
            frames = []
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    continue  # skip unreadable frames; padded below
                # BUGFIX: cv2.resize takes (W, H) while target_size is (H, W)
                frame = cv2.resize(frame, (target_size[1], target_size[0]))
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            cap.release()
        if not frames:
            raise ValueError(f"No frames read from {video_path}")
        # Pad by repeating the last frame so the output always has num_frames
        while len(frames) < num_frames:
            frames.append(frames[-1])
        video = np.stack(frames, axis=0)  # (T, H, W, C)
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (C, T, H, W)
        # Map uint8 range [0, 255] to [-1, 1]
        video = video / 127.5 - 1.0
        return video
def save_video_frames(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 8,
    codec: str = 'libx264',
):
    """
    Save a video tensor to a file.

    Args:
        frames: Video tensor (C, T, H, W) or (T, H, W, C) in range [-1, 1] or [0, 1].
        output_path: Output file path.
        fps: Frames per second.
        codec: Video codec. NOTE: only honored on the torchvision path;
            the OpenCV fallback always uses the 'mp4v' fourcc.
    """
    # Heuristic: a negative minimum implies [-1, 1] input; rescale to [0, 1]
    if frames.min() < 0:
        frames = (frames + 1) / 2
    frames = torch.clamp(frames, 0, 1)
    # Convert (C, T, H, W) to (T, H, W, C) for the video writers
    if frames.shape[0] == 3:
        frames = frames.permute(1, 2, 3, 0)
    # Scale to [0, 255] uint8 on CPU
    frames = (frames * 255).to(torch.uint8).cpu()
    try:
        from torchvision.io import write_video
        write_video(output_path, frames, fps=fps, video_codec=codec)
        print(f"Video saved to {output_path}")
    except ImportError:
        # Fallback to opencv
        import cv2
        height, width = frames.shape[1:3]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            for frame in frames:
                # BUGFIX: tensors that went through permute() yield
                # non-contiguous numpy arrays, which OpenCV rejects.
                frame_np = np.ascontiguousarray(frame.numpy())
                out.write(cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR))
        finally:
            out.release()
        print(f"Video saved to {output_path}")
def create_video_grid(
    videos: List[torch.Tensor],
    grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """
    Create a grid of videos for comparison.

    Args:
        videos: List of video tensors (C, T, H, W), all the same shape.
        grid_size: (rows, cols). If None, a near-square layout is chosen.

    Returns:
        Grid video tensor (C, T, H_grid, W_grid).
    """
    n_videos = len(videos)
    if grid_size is None:
        # Near-square layout: cols = ceil(sqrt(n)), rows fill the rest
        cols = int(np.ceil(np.sqrt(n_videos)))
        rows = int(np.ceil(n_videos / cols))
    else:
        rows, cols = grid_size
    # BUGFIX: work on a copy so padding does not mutate the caller's list
    videos = list(videos)
    while len(videos) < rows * cols:
        videos.append(torch.zeros_like(videos[0]))
    # Concatenate each row along width, then rows along height
    grid_rows = []
    for i in range(rows):
        row = torch.cat(videos[i * cols:(i + 1) * cols], dim=-1)
        grid_rows.append(row)
    return torch.cat(grid_rows, dim=-2)
# ============================================================================
# Text Processing Utilities
# ============================================================================
class SimpleTokenizer:
    """Character-level stand-in tokenizer (swap for a real tokenizer in production)."""

    def __init__(self, vocab_size: int = 50257):
        self.vocab_size = vocab_size

    def encode(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Map text to a fixed-length (max_length,) tensor of IDs, zero-padded."""
        ids = [ord(ch) % self.vocab_size for ch in text[:max_length]]
        padding = [0] * (max_length - len(ids))
        return torch.tensor(ids + padding, dtype=torch.long)

    def decode(self, tokens: torch.Tensor) -> str:
        """Map token IDs back to text, skipping zero padding tokens."""
        return ''.join(chr(int(t)) for t in tokens if int(t) != 0)

    def batch_encode(self, texts: List[str], max_length: int = 256) -> torch.Tensor:
        """Encode several texts into a stacked (B, max_length) tensor."""
        return torch.stack([self.encode(t, max_length) for t in texts])
# ============================================================================
# Dataset Utilities
# ============================================================================
def create_dataset_split(
    annotation_file: str,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> Tuple[Dict, Dict]:
    """
    Split an annotations JSON into train and validation dictionaries.

    Args:
        annotation_file: Path to annotations JSON.
        train_ratio: Fraction of items assigned to the training split.
        seed: Random seed for the shuffle (global numpy RNG, kept for
            reproducibility with existing splits).

    Returns:
        (train_annotations, val_annotations)
    """
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)

    # Deterministic shuffle of the annotation keys
    keys = list(annotations.keys())
    np.random.seed(seed)
    np.random.shuffle(keys)

    n_train = int(len(keys) * train_ratio)
    train_annotations = {k: annotations[k] for k in keys[:n_train]}
    val_annotations = {k: annotations[k] for k in keys[n_train:]}
    return train_annotations, val_annotations
def validate_dataset(video_dir: str, annotation_file: str) -> Dict[str, Any]:
    """
    Validate dataset integrity: every annotated video file exists and has a
    usable caption.

    Args:
        video_dir: Directory expected to contain one `<video_id>.mp4` per entry.
        annotation_file: Path to annotations JSON mapping video_id -> metadata.

    Returns:
        Dictionary with 'total_videos', 'missing_videos', 'invalid_captions',
        'warnings', and a boolean 'valid' flag (True when nothing is missing
        or invalid).
    """
    video_dir = Path(video_dir)
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)

    results = {
        'total_videos': len(annotations),
        'missing_videos': [],
        'invalid_captions': [],
        'warnings': [],
    }
    for video_id, data in annotations.items():
        # Check video file exists
        if not (video_dir / f"{video_id}.mp4").exists():
            results['missing_videos'].append(video_id)
        # BUGFIX: the original called .strip() on the raw value and crashed
        # on non-string captions; treat those as invalid instead.
        caption = data.get('caption', '')
        if not isinstance(caption, str) or not caption.strip():
            results['invalid_captions'].append(video_id)
        elif len(caption) > 256:
            results['warnings'].append(f"{video_id}: Caption too long")
    results['valid'] = (
        not results['missing_videos'] and not results['invalid_captions']
    )
    return results
# ============================================================================
# Model Utilities
# ============================================================================
def count_model_parameters(model: torch.nn.Module) -> Dict[str, int]:
    """Return total, trainable, and frozen parameter counts for a model."""
    total = 0
    trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable,
    }
def load_checkpoint_safe(
    model: torch.nn.Module,
    checkpoint_path: str,
    strict: bool = True,
) -> Dict[str, Any]:
    """
    Load a checkpoint into `model`, reporting errors instead of raising.

    Args:
        model: Model to load weights into.
        checkpoint_path: Path to a torch checkpoint; either a training dict
            containing 'model_state_dict' or a bare state_dict.
        strict: Forwarded to `load_state_dict`; require exact key match.

    Returns:
        {'success': True, 'step': ..., 'epoch': ...} on success,
        {'success': False, 'error': ...} on any failure.
    """
    try:
        # NOTE(review): torch.load can execute arbitrary code via pickle on
        # untrusted files; consider weights_only=True for plain state dicts.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Accept both full training checkpoints and bare state dicts
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        else:
            model.load_state_dict(checkpoint, strict=strict)
        return {
            'success': True,
            'step': checkpoint.get('global_step', -1),
            'epoch': checkpoint.get('epoch', -1),
        }
    except Exception as e:
        # Deliberately broad: this helper exists to report, not propagate, errors
        return {
            'success': False,
            'error': str(e),
        }
# ============================================================================
# Visualization Utilities
# ============================================================================
def create_comparison_video(
    original: torch.Tensor,
    generated: torch.Tensor,
    prompt: str,
    output_path: str,
):
    """
    Write a side-by-side (original | generated) comparison video.

    Args:
        original: Original video (C, T, H, W).
        generated: Generated video (C, T, H, W).
        prompt: Text prompt, printed for reference.
        output_path: Destination video path.
    """
    # Join along the width axis so the two clips play side by side
    side_by_side = torch.cat([original, generated], dim=-1)
    save_video_frames(side_by_side, output_path)
    print(f"Comparison video saved to {output_path}")
    print(f"Prompt: {prompt}")
# ============================================================================
# Logging Utilities
# ============================================================================
class TrainingLogger:
    """Minimal training logger: CSV-style log file plus a JSON metrics dump."""

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_file = self.log_dir / 'training.log'
        # In-memory history, mirrored line-by-line into training.log
        self.metrics = {'step': [], 'loss': [], 'lr': []}

    def log(self, step: int, loss: float, lr: float):
        """Record one training step in memory and append it to the log file."""
        for key, value in (('step', step), ('loss', loss), ('lr', lr)):
            self.metrics[key].append(value)
        with open(self.log_file, 'a') as f:
            f.write(f"{step},{loss},{lr}\n")

    def save_metrics(self):
        """Dump the accumulated metrics dict to metrics.json in the log dir."""
        with open(self.log_dir / 'metrics.json', 'w') as f:
            json.dump(self.metrics, f, indent=2)
# ============================================================================
# Testing Utilities
# ============================================================================
def test_video_pipeline():
    """Smoke-test the save/load round trip for video tensors."""
    print("Testing video pipeline...")
    # Random dummy clip rescaled into [0, 1]
    clip = torch.randn(3, 16, 256, 256)
    clip = (clip - clip.min()) / (clip.max() - clip.min())
    output_path = "test_video.mp4"
    save_video_frames(clip, output_path)
    reloaded = load_video_frames(output_path, num_frames=16)
    print(f"Original shape: {clip.shape}")
    print(f"Loaded shape: {reloaded.shape}")
    print("✓ Video pipeline test passed")
def test_tokenizer():
    """Smoke-test the SimpleTokenizer encode/decode round trip."""
    print("Testing tokenizer...")
    tok = SimpleTokenizer()
    sample = "A beautiful sunset over the ocean"
    ids = tok.encode(sample, max_length=128)
    roundtrip = tok.decode(ids)
    print(f"Original: {sample}")
    print(f"Tokens shape: {ids.shape}")
    print(f"Decoded: {roundtrip[:len(sample)]}")
    print("✓ Tokenizer test passed")
if __name__ == "__main__":
    # Run only the dependency-free tokenizer test by default; the video
    # pipeline test needs torchvision or opencv installed.
    print("Running utility tests...\n")
    test_tokenizer()
    print("\n" + "=" * 60 + "\n")
    print("Note: Video pipeline test requires torchvision or opencv")
    print("Run after installing dependencies")