| | """ |
| | Inference script for TTV-1B Text-to-Video Model |
| | Generate videos from text prompts |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from video_ttv_1b import VideoTTV1B, DDPMScheduler |
| | from pathlib import Path |
| | import numpy as np |
| | from typing import Optional, List |
| | from tqdm import tqdm |
| | import json |
| |
|
| |
|
class VideoGenerator:
    """Generate videos from text prompts with a trained diffusion model.

    Wraps a denoising model and a DDPM noise scheduler and runs the
    reverse-diffusion sampling loop with classifier-free guidance.
    """

    def __init__(
        self,
        model: nn.Module,
        noise_scheduler: "DDPMScheduler",
        device: str = 'cuda',
    ):
        # Move to the target device and freeze in eval mode for inference.
        self.model = model.to(device)
        self.model.eval()
        self.noise_scheduler = noise_scheduler
        self.device = device

    def tokenize(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Tokenize text with a simple character-level scheme.

        Each character maps to ``ord(c) % 50257``; the sequence is truncated
        to ``max_length`` and right-padded with 0 (the same token used as the
        "null" prompt for classifier-free guidance).

        Returns:
            Long tensor of shape (1, max_length).
        """
        tokens = [ord(c) % 50257 for c in text[:max_length]]
        tokens += [0] * (max_length - len(tokens))
        return torch.tensor([tokens], dtype=torch.long)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate a video from a text prompt.

        Args:
            prompt: Text description of the video
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale (1.0 skips the
                extra unconditional forward pass)
            seed: Random seed for reproducibility

        Returns:
            Generated video tensor (C, T, H, W) with values in [0, 1]
        """
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                # Seed every CUDA device, not only the current one.
                torch.cuda.manual_seed_all(seed)

        text_tokens = self.tokenize(prompt).to(self.device)

        # Start from pure Gaussian noise in video space.
        # NOTE(review): assumes the model exposes `num_frames` and `img_size`.
        shape = (1, 3, self.model.num_frames, *self.model.img_size)
        x = torch.randn(shape, device=self.device)

        # Evenly spaced timesteps from the end of the training schedule down to 0.
        timesteps = torch.linspace(
            self.noise_scheduler.num_steps - 1,
            0,
            num_inference_steps,
            dtype=torch.long,
            device=self.device,
        )

        for t in tqdm(timesteps, desc="Generating video"):
            t_batch = t.unsqueeze(0)

            # Conditional noise prediction.
            noise_pred = self.model(x, t_batch, text_tokens)

            if guidance_scale != 1.0:
                # Unconditional pass: all-zero tokens act as the null prompt.
                uncond_tokens = torch.zeros_like(text_tokens)
                noise_pred_uncond = self.model(x, t_batch, uncond_tokens)

                # Classifier-free guidance: push the prediction away from
                # the unconditional estimate.
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred - noise_pred_uncond
                )

            # The scheduler expects a model callable; hand it a closure that
            # returns the already-guided prediction so the model is not run
            # again inside the scheduler. Bind via a default argument to
            # avoid the late-binding-closure pitfall.
            x = self.noise_scheduler.sample_step(
                lambda x_t, ts, txt, pred=noise_pred: pred,
                x,
                t.item(),
                text_tokens,
            )

        # Map from the model's [-1, 1] output range to displayable [0, 1].
        video = (x.squeeze(0) + 1) / 2
        return torch.clamp(video, 0, 1)

    def save_video(self, video: torch.Tensor, output_path: str, fps: int = 8):
        """
        Save a video tensor to file.

        Args:
            video: Video tensor (C, T, H, W) in range [0, 1]
            output_path: Output file path (expected to end in '.mp4')
            fps: Frames per second

        Falls back to a .npy dump when torchvision is unavailable.
        """
        # Keep the try body minimal: only the import can raise ImportError.
        try:
            from torchvision.io import write_video
        except ImportError:
            # Best-effort fallback: store the raw float frames as numpy.
            # with_suffix handles paths that do not end in '.mp4' too.
            npy_path = str(Path(output_path).with_suffix('.npy'))
            print("torchvision not available, saving as numpy array")
            np.save(npy_path, video.cpu().numpy())
            print(f"Video saved as numpy array to {npy_path}")
            return

        # write_video expects (T, H, W, C) uint8 frames.
        frames = video.permute(1, 2, 3, 0).cpu()
        frames = (frames * 255).to(torch.uint8)
        write_video(output_path, frames, fps=fps)
        print(f"Video saved to {output_path}")
| |
|
| |
|
def load_model(checkpoint_path: str, device: str = 'cuda') -> "VideoTTV1B":
    """Load a VideoTTV1B model from a training checkpoint.

    Reads an optional ``model_config.json`` next to the checkpoint and uses
    it to override the default architecture hyperparameters, so the rebuilt
    model matches the one that was trained. (Previously the config was
    loaded but silently ignored.)

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: map_location target for loading the weights.

    Returns:
        The model with weights restored.
    """
    # Defaults for the 1B architecture; overridden by model_config.json below.
    config = {
        'img_size': (256, 256),
        'num_frames': 16,
        'patch_size': (2, 16, 16),
        'in_channels': 3,
        'hidden_dim': 1536,
        'depth': 24,
        'num_heads': 24,
        'mlp_ratio': 4.0,
    }

    config_path = Path(checkpoint_path).parent / 'model_config.json'
    if config_path.exists():
        loaded = json.loads(config_path.read_text())
        print(f"Loaded model config: {loaded}")
        # Accept only keys the constructor understands; JSON has no tuples,
        # so convert lists back (img_size / patch_size).
        for key in config:
            if key in loaded:
                value = loaded[key]
                config[key] = tuple(value) if isinstance(value, list) else value

    model = VideoTTV1B(**config)

    # NOTE(review): torch.load without weights_only=True can execute pickled
    # code — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    print(f"Training step: {checkpoint.get('global_step', 'unknown')}")

    return model
| |
|
| |
|
def generate_video_from_prompt(
    prompt: str,
    checkpoint_path: str,
    output_path: str = "generated_video.mp4",
    num_steps: int = 50,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
    device: str = 'cuda',
):
    """
    High-level convenience wrapper: load the model, sample one video for
    ``prompt``, and save it to ``output_path``.

    Args:
        prompt: Text description
        checkpoint_path: Path to model checkpoint
        output_path: Where to save the video
        num_steps: Number of denoising steps
        guidance_scale: Guidance strength
        seed: Random seed
        device: Device to run on

    Returns:
        The generated video tensor.
    """
    print(f"Generating video for prompt: '{prompt}'")
    print(f"Using {num_steps} inference steps with guidance scale {guidance_scale}")

    # Rebuild the model and wire it to a fresh 1000-step DDPM schedule.
    generator = VideoGenerator(
        load_model(checkpoint_path, device),
        DDPMScheduler(num_steps=1000),
        device,
    )

    result = generator.generate(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )

    generator.save_video(result, output_path)
    return result
| |
|
| |
|
def batch_generate(
    prompts: List[str],
    checkpoint_path: str,
    output_dir: str = "./generated_videos",
    **kwargs
):
    """Generate one video per prompt, saved as video_NNNN.mp4 under output_dir.

    Failures on individual prompts are reported and skipped so the rest of
    the batch still runs (best-effort).
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    total = len(prompts)
    for idx, text in enumerate(prompts):
        print(f"\n[{idx+1}/{total}] Generating: {text}")
        destination = target / f"video_{idx:04d}.mp4"

        try:
            generate_video_from_prompt(
                prompt=text,
                checkpoint_path=checkpoint_path,
                output_path=str(destination),
                **kwargs
            )
        except Exception as err:
            # Report and move on to the next prompt.
            print(f"Error generating video {idx}: {err}")
| |
|
| |
|
def main():
    """Command-line entry point: parse flags and run a single generation."""
    import argparse

    cli = argparse.ArgumentParser(description="Generate videos from text prompts")
    # (flag, options) pairs kept in one table for easy scanning.
    flags = [
        ('--prompt', dict(type=str, required=True, help='Text prompt')),
        ('--checkpoint', dict(type=str, required=True, help='Model checkpoint path')),
        ('--output', dict(type=str, default='generated_video.mp4', help='Output path')),
        ('--steps', dict(type=int, default=50, help='Number of inference steps')),
        ('--guidance', dict(type=float, default=7.5, help='Guidance scale')),
        ('--seed', dict(type=int, default=None, help='Random seed')),
        ('--device', dict(type=str, default='cuda', help='Device (cuda/cpu)')),
    ]
    for name, options in flags:
        cli.add_argument(name, **options)

    opts = cli.parse_args()

    generate_video_from_prompt(
        prompt=opts.prompt,
        checkpoint_path=opts.checkpoint,
        output_path=opts.output,
        num_steps=opts.steps,
        guidance_scale=opts.guidance,
        seed=opts.seed,
        device=opts.device,
    )
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | example_prompts = [ |
| | "A serene sunset over the ocean with gentle waves", |
| | "A cat playing with a ball of yarn in slow motion", |
| | "Time-lapse of a flower blooming in spring", |
| | "Aerial view of a city at night with twinkling lights", |
| | "Underwater scene with colorful fish swimming", |
| | ] |
| | |
| | print("Example prompts for video generation:") |
| | for i, prompt in enumerate(example_prompts, 1): |
| | print(f"{i}. {prompt}") |
| | |
| | print("\nRun with: python inference.py --prompt 'your prompt' --checkpoint path/to/checkpoint.pt") |
| |
|