File size: 13,139 Bytes
3d8856d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
"""
Utility functions for TTV-1B model
Data preprocessing, video I/O, and helper functions
"""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch


# ============================================================================
# Video Processing Utilities
# ============================================================================

def _sample_resize_normalize(
    video: torch.Tensor,
    video_path: str,
    num_frames: int,
    target_size: Tuple[int, int],
) -> torch.Tensor:
    """Sample, resize and normalize a decoded (T, H, W, C) clip to (C, T, H, W)."""
    import torch.nn.functional as F

    total_frames = video.shape[0]
    if total_frames == 0:
        raise ValueError(f"No frames could be decoded from {video_path}")

    # Sample frames uniformly *before* converting to float so that only the
    # frames we keep are expanded to 32-bit.
    indices = torch.linspace(0, total_frames - 1, num_frames).long()
    video = video[indices].permute(3, 0, 1, 2).float()  # (T, H, W, C) -> (C, T, H, W)

    # Trilinear interpolation requires a 5-D (N, C, D, H, W) input, so a batch
    # dimension is added for the call and removed again afterwards.
    video = F.interpolate(
        video.unsqueeze(0),
        size=(num_frames, *target_size),
        mode='trilinear',
        align_corners=False,
    ).squeeze(0)

    # Normalize to [-1, 1]
    return video / 127.5 - 1.0


def _load_frames_opencv(
    video_path: str,
    num_frames: int,
    target_size: Tuple[int, int],
) -> torch.Tensor:
    """Decode uniformly spaced frames with OpenCV and return (C, T, H, W) in [-1, 1]."""
    import cv2

    height, width = target_size
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Some containers report a frame count of 0 or less; clamp so that we
        # still attempt to read from the start instead of seeking to -1.
        last_index = max(total_frames - 1, 0)
        indices = np.linspace(0, last_index, num_frames).astype(int)

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                # Frames that fail to decode are skipped, so the result may
                # hold fewer than ``num_frames`` frames.
                continue
            # cv2.resize expects dsize as (width, height), whereas
            # target_size is (H, W).
            frame = cv2.resize(frame, (width, height))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from {video_path}")

    video = np.stack(frames, axis=0)  # (T, H, W, C)
    video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (C, T, H, W)

    # Normalize to [-1, 1]
    return video / 127.5 - 1.0


def load_video_frames(
    video_path: str,
    num_frames: int = 16,
    target_size: Tuple[int, int] = (256, 256),
) -> torch.Tensor:
    """
    Load video and extract frames

    torchvision is tried first; if it (or its decoding backend) raises
    ImportError, OpenCV is used instead.

    Args:
        video_path: Path to video file
        num_frames: Number of frames to extract (must be >= 1)
        target_size: Target resolution (H, W)

    Returns:
        Video tensor (C, T, H, W) normalized to [-1, 1]. On the OpenCV path
        T can be smaller than num_frames when individual frames fail to decode.

    Raises:
        ValueError: If num_frames < 1, the file cannot be opened by OpenCV,
            or no frame could be decoded.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    video_path = str(video_path)

    try:
        # Try using torchvision. The call stays inside the try block because
        # an ImportError raised while decoding should also trigger the
        # OpenCV fallback, exactly like a missing torchvision does.
        from torchvision.io import read_video

        video, _, _ = read_video(video_path, pts_unit='sec')  # (T, H, W, C)
    except ImportError:
        # Fallback to opencv
        return _load_frames_opencv(video_path, num_frames, target_size)

    return _sample_resize_normalize(video, video_path, num_frames, target_size)


def save_video_frames(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 8,
    codec: str = 'libx264',
):
    """
    Save video tensor to file

    torchvision is tried first; if it (or its encoding backend) raises
    ImportError, OpenCV is used instead.

    Args:
        frames: Video tensor (C, T, H, W) or (T, H, W, C) in range [-1, 1] or [0, 1].
            A tensor whose last dimension is 3 is treated as channels-last.
            Any negative value makes the whole tensor be treated as [-1, 1].
        output_path: Output file path
        fps: Frames per second
        codec: Video codec (used by the torchvision writer only; the OpenCV
            fallback always writes 'mp4v')

    Raises:
        ValueError: If frames is not a non-empty 4-D tensor.
    """
    if frames.dim() != 4:
        raise ValueError(
            f"Expected a 4-D video tensor, got shape {tuple(frames.shape)}"
        )
    if frames.numel() == 0:
        raise ValueError("Cannot save an empty video tensor")

    # Detach from autograd and work on the CPU in float32 so that model
    # outputs (possibly half precision, on GPU, requiring grad) are accepted.
    frames = frames.detach().cpu().float()

    # Ensure frames are in [0, 1] range
    if frames.min() < 0:
        frames = (frames + 1) / 2  # [-1, 1] -> [0, 1]

    frames = torch.clamp(frames, 0, 1)

    # Convert to (T, H, W, C) format. Checking the last dimension first keeps
    # a channels-last clip of exactly three frames from being permuted.
    if frames.shape[-1] != 3 and frames.shape[0] == 3:  # (C, T, H, W)
        frames = frames.permute(1, 2, 3, 0)

    # Scale to [0, 255]; round instead of truncating to avoid a darkening
    # bias, and make the memory contiguous for the NumPy/OpenCV writers.
    frames = (frames * 255).round().to(torch.uint8).contiguous()

    try:
        from torchvision.io import write_video
        write_video(output_path, frames, fps=fps, video_codec=codec)

    except ImportError:
        # Fallback to opencv
        import cv2

        height, width = frames.shape[1:3]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        try:
            for frame in frames:
                out.write(cv2.cvtColor(frame.numpy(), cv2.COLOR_RGB2BGR))
        finally:
            # Release the writer even if a frame fails to encode.
            out.release()

    print(f"Video saved to {output_path}")


def create_video_grid(
    videos: List[torch.Tensor],
    grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """
    Create a grid of videos for comparison

    The input list is not modified. Empty cells are filled with zero tensors
    shaped like the first video. When an explicit grid_size has fewer cells
    than there are videos, only the first rows * cols videos are used.

    Args:
        videos: List of video tensors (C, T, H, W), all with the same shape
        grid_size: (rows, cols). If None, automatically determined

    Returns:
        Grid video tensor (C, T, H_grid, W_grid)

    Raises:
        ValueError: If videos is empty.
    """
    if not videos:
        raise ValueError("videos must contain at least one tensor")

    n_videos = len(videos)

    if grid_size is None:
        # Near-square layout: as many columns as the ceiling of sqrt(n).
        cols = int(np.ceil(np.sqrt(n_videos)))
        rows = int(np.ceil(n_videos / cols))
    else:
        rows, cols = grid_size

    n_cells = rows * cols

    # Work on a copy so the caller's list is never padded in place.
    cells = list(videos[:n_cells])
    blank = torch.zeros_like(videos[0])
    cells.extend([blank] * (n_cells - len(cells)))

    # Concatenate each row along width, then stack the rows along height.
    grid_rows = [
        torch.cat(cells[r * cols:(r + 1) * cols], dim=-1)
        for r in range(rows)
    ]

    return torch.cat(grid_rows, dim=-2)


# ============================================================================
# Text Processing Utilities
# ============================================================================

class SimpleTokenizer:
    """
    Character-level stand-in tokenizer (swap for a real tokenizer in production)

    Each character maps to its code point modulo vocab_size; id 0 is used
    as padding and is dropped again when decoding.
    """

    def __init__(self, vocab_size: int = 50257):
        self.vocab_size = vocab_size

    def encode(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Return a long tensor of ids for text, truncated/zero-padded to max_length"""
        modulus = self.vocab_size
        clipped = text[:max_length]

        ids = [ord(ch) % modulus for ch in clipped]

        # Right-pad with the padding id (0) up to the requested length
        padding = [0] * (max_length - len(ids))

        return torch.tensor(ids + padding, dtype=torch.long)

    def decode(self, tokens: torch.Tensor) -> str:
        """Turn ids back into a string, skipping padding (id 0)"""
        pieces = []
        for token in tokens:
            value = token.item()
            if value != 0:
                pieces.append(chr(value))
        return ''.join(pieces)

    def batch_encode(self, texts: List[str], max_length: int = 256) -> torch.Tensor:
        """Encode every text and stack the results into one (batch, max_length) tensor"""
        rows = []
        for text in texts:
            rows.append(self.encode(text, max_length))
        return torch.stack(rows)


# ============================================================================
# Dataset Utilities
# ============================================================================

def create_dataset_split(
    annotation_file: str,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> Tuple[Dict, Dict]:
    """
    Split dataset into train and validation sets

    The shuffle uses a private random generator, so the process-wide NumPy
    random state is left untouched.

    Args:
        annotation_file: Path to annotations JSON (an object keyed by sample id)
        train_ratio: Ratio of training data, between 0 and 1 inclusive
        seed: Random seed

    Returns:
        train_annotations, val_annotations

    Raises:
        ValueError: If train_ratio is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio}")

    with open(annotation_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    # Shuffle keys with a local RandomState. It yields the same permutation
    # as seeding the global generator would, without the global side effect.
    keys = list(annotations)
    np.random.RandomState(seed).shuffle(keys)

    # Split
    split_idx = int(len(keys) * train_ratio)

    train_annotations = {k: annotations[k] for k in keys[:split_idx]}
    val_annotations = {k: annotations[k] for k in keys[split_idx:]}

    return train_annotations, val_annotations


def validate_dataset(
    video_dir: str,
    annotation_file: str,
    max_caption_length: int = 256,
    video_ext: str = ".mp4",
) -> Dict[str, Any]:
    """
    Validate dataset integrity

    Args:
        video_dir: Directory expected to hold one video file per annotation key
        annotation_file: Path to annotations JSON (an object keyed by video id)
        max_caption_length: Captions longer than this produce a warning
        video_ext: File extension appended to each video id

    Returns:
        Dictionary with validation results:
            total_videos: number of annotation entries
            missing_videos: ids whose video file does not exist
            invalid_captions: ids whose caption is absent, not a string,
                or blank
            warnings: messages for captions exceeding max_caption_length
            valid: True when there are no missing videos and no invalid captions
    """
    video_root = Path(video_dir)

    with open(annotation_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    missing_videos = []
    invalid_captions = []
    caption_warnings = []

    for video_id, data in annotations.items():
        # Check video file exists
        if not (video_root / f"{video_id}{video_ext}").exists():
            missing_videos.append(video_id)

        # A malformed entry (not an object) or a non-string caption is
        # reported as invalid rather than aborting the whole validation.
        caption = data.get('caption') if isinstance(data, dict) else None
        is_text = isinstance(caption, str)

        # Check caption
        if not is_text or not caption.strip():
            invalid_captions.append(video_id)

        # Check caption length
        if is_text and len(caption) > max_caption_length:
            caption_warnings.append(f"{video_id}: Caption too long")

    return {
        'total_videos': len(annotations),
        'missing_videos': missing_videos,
        'invalid_captions': invalid_captions,
        'warnings': caption_warnings,
        'valid': not missing_videos and not invalid_captions,
    }


# ============================================================================
# Model Utilities
# ============================================================================

def count_model_parameters(model: torch.nn.Module) -> Dict[str, int]:
    """
    Tally the number of scalar parameters in a model

    Returns:
        Dictionary with 'total', 'trainable' (requires_grad is set) and
        'non_trainable' element counts
    """
    n_all = 0
    n_trainable = 0

    # Single pass over the parameters, accumulating both tallies at once
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size

    counts = {
        'total': n_all,
        'trainable': n_trainable,
        'non_trainable': n_all - n_trainable,
    }
    return counts


def load_checkpoint_safe(
    model: torch.nn.Module,
    checkpoint_path: str,
    strict: bool = True,
) -> Dict[str, Any]:
    """
    Safely load checkpoint with error handling

    Accepts either a training checkpoint that stores the weights under
    'model_state_dict' or a bare state dict. Failures are reported in the
    returned dictionary instead of being raised.

    Args:
        model: Module that receives the weights
        checkpoint_path: Path to the checkpoint file
        strict: Forwarded to load_state_dict

    Returns:
        Dictionary with loading results. On success:
            success: True
            step: the checkpoint's 'global_step', or -1 if absent
            epoch: the checkpoint's 'epoch', or -1 if absent
            missing_keys: keys the model expected but the checkpoint lacked
            unexpected_keys: keys in the checkpoint the model does not have
        On failure:
            success: False
            error: message of the exception that was raised
    """
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True where the installed torch version supports it.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Load model state
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # With strict=False mismatched keys are skipped silently, so surface
        # them to the caller instead of discarding the report.
        load_result = model.load_state_dict(state_dict, strict=strict)

        return {
            'success': True,
            'step': checkpoint.get('global_step', -1),
            'epoch': checkpoint.get('epoch', -1),
            'missing_keys': list(load_result.missing_keys),
            'unexpected_keys': list(load_result.unexpected_keys),
        }

    except Exception as e:
        # Deliberately broad: this helper's contract is to report any failure
        # (missing file, corrupt data, key mismatch) rather than raise.
        return {
            'success': False,
            'error': str(e),
        }


# ============================================================================
# Visualization Utilities
# ============================================================================

def create_comparison_video(
    original: torch.Tensor,
    generated: torch.Tensor,
    prompt: str,
    output_path: str,
):
    """
    Write the original and generated clips next to each other in one video

    The two tensors are joined along their last dimension and passed to
    save_video_frames. The prompt is only echoed to stdout.

    Args:
        original: Original video (C, T, H, W)
        generated: Generated video (C, T, H, W)
        prompt: Text prompt
        output_path: Where to save
    """
    # Original on the left, generated on the right
    side_by_side = torch.cat((original, generated), dim=-1)
    save_video_frames(side_by_side, output_path)

    for message in (
        f"Comparison video saved to {output_path}",
        f"Prompt: {prompt}",
    ):
        print(message)


# ============================================================================
# Logging Utilities
# ============================================================================

class TrainingLogger:
    """
    Simple training logger

    Keeps step/loss/lr histories in memory, appends each record as a
    'step,loss,lr' line to training.log inside log_dir, and can dump the
    full history to metrics.json.
    """

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_file = self.log_dir / 'training.log'

        self.metrics = {
            'step': [],
            'loss': [],
            'lr': [],
        }

    @staticmethod
    def _as_builtin(value, cast):
        """Return value unchanged if it is a Python number, otherwise cast(value)."""
        if isinstance(value, (int, float)):
            return value
        return cast(value)

    def log(self, step: int, loss: float, lr: float):
        """
        Log training metrics

        NumPy and torch scalars are converted to builtin numbers first so
        that the stored history stays JSON-serializable for save_metrics.
        """
        step = self._as_builtin(step, int)
        loss = self._as_builtin(loss, float)
        lr = self._as_builtin(lr, float)

        self.metrics['step'].append(step)
        self.metrics['loss'].append(loss)
        self.metrics['lr'].append(lr)

        # Write to file
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(f"{step},{loss},{lr}\n")

    def save_metrics(self):
        """Save metrics to JSON"""
        output_file = self.log_dir / 'metrics.json'
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(self.metrics, f, indent=2)


# ============================================================================
# Testing Utilities
# ============================================================================

def test_video_pipeline():
    """
    Test video loading and saving pipeline

    Saves a random clip to a temporary directory, loads it back and checks
    that the shape survives the round trip. Nothing is left in the working
    directory.

    Raises:
        AssertionError: If the loaded clip's shape differs from the original.
    """
    import tempfile

    print("Testing video pipeline...")

    # Create dummy video, rescaled into [0, 1]
    video = torch.randn(3, 16, 256, 256)
    video = (video - video.min()) / (video.max() - video.min())

    # The temporary directory removes the test file even when a step fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = str(Path(tmp_dir) / "test_video.mp4")

        # Save
        save_video_frames(video, output_path)

        # Load
        loaded = load_video_frames(output_path, num_frames=16)

    print(f"Original shape: {video.shape}")
    print(f"Loaded shape: {loaded.shape}")

    # Only the shape is compared: encoding is lossy and the loader returns
    # values in [-1, 1] while the dummy clip is in [0, 1].
    assert loaded.shape == video.shape, (
        f"Shape mismatch: saved {tuple(video.shape)}, loaded {tuple(loaded.shape)}"
    )
    print("✓ Video pipeline test passed")


def test_tokenizer():
    """
    Test tokenizer

    Encodes a sample sentence, decodes it again and checks both the padded
    length and the round trip.

    Raises:
        AssertionError: If the token tensor has the wrong shape or the
            decoded text differs from the input.
    """
    print("Testing tokenizer...")

    tokenizer = SimpleTokenizer()

    text = "A beautiful sunset over the ocean"
    max_length = 128
    tokens = tokenizer.encode(text, max_length=max_length)
    decoded = tokenizer.decode(tokens)

    print(f"Original: {text}")
    print(f"Tokens shape: {tokens.shape}")
    print(f"Decoded: {decoded[:len(text)]}")

    # Verify the results before reporting success
    assert tuple(tokens.shape) == (max_length,), (
        f"Expected {max_length} tokens, got shape {tuple(tokens.shape)}"
    )
    assert decoded == text, f"Round trip mismatch: {decoded!r} != {text!r}"
    print("✓ Tokenizer test passed")


if __name__ == "__main__":
    # Script entry point: only the tokenizer test runs here, since the video
    # pipeline test needs an optional video backend to be installed.
    print("Running utility tests...\n")
    test_tokenizer()

    separator = "=" * 60
    print("\n" + separator + "\n")

    for notice in (
        "Note: Video pipeline test requires torchvision or opencv",
        "Run after installing dependencies",
    ):
        print(notice)