# Zenderos / inference.py
# Uploaded by ASADSANAN ("Upload 11 files", commit 3d8856d, verified)
"""
Inference script for TTV-1B Text-to-Video Model
Generate videos from text prompts
"""
import torch
import torch.nn as nn
from video_ttv_1b import VideoTTV1B, DDPMScheduler
from pathlib import Path
import numpy as np
from typing import Optional, List
from tqdm import tqdm
import json
class VideoGenerator:
    """Turns text prompts into videos by iteratively denoising random noise.

    Wraps a trained diffusion model together with its DDPM noise scheduler
    and exposes a simple generate()/save_video() API.
    """

    def __init__(
        self,
        model: nn.Module,
        noise_scheduler: DDPMScheduler,
        device: str = 'cuda',
    ):
        # Move the network to the target device and freeze it for inference.
        self.model = model.to(device)
        self.model.eval()
        self.noise_scheduler = noise_scheduler
        self.device = device

    def tokenize(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Character-level tokenization: codepoint mod vocab size, zero-padded.

        Returns a (1, max_length) long tensor.
        """
        ids = [ord(ch) % 50257 for ch in text[:max_length]]
        ids.extend([0] * (max_length - len(ids)))  # right-pad with 0s
        return torch.tensor([ids], dtype=torch.long)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Generate a video from a text prompt.

        Args:
            prompt: Text description of the video.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale (1.0 disables it).
            seed: Random seed for reproducibility.

        Returns:
            Video tensor of shape (C, T, H, W) with values in [0, 1].
        """
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        text_tokens = self.tokenize(prompt).to(self.device)

        # Begin from pure Gaussian noise.
        noise_shape = (1, 3, self.model.num_frames, *self.model.img_size)
        latents = torch.randn(noise_shape, device=self.device)

        # Evenly spaced timesteps from the last training step down to 0.
        timesteps = torch.linspace(
            self.noise_scheduler.num_steps - 1,
            0,
            num_inference_steps,
            dtype=torch.long,
            device=self.device,
        )

        for step in tqdm(timesteps, desc="Generating video"):
            step_batch = step.unsqueeze(0)  # shape (1,) for the model

            noise_pred = self.model(latents, step_batch, text_tokens)

            # Classifier-free guidance: push the conditional prediction away
            # from an unconditional (all-zero tokens) one. NOTE(review): only
            # meaningful if training used unconditional token dropout.
            if guidance_scale != 1.0:
                null_tokens = torch.zeros_like(text_tokens)
                uncond_pred = self.model(latents, step_batch, null_tokens)
                noise_pred = uncond_pred + guidance_scale * (noise_pred - uncond_pred)

            # sample_step is handed a lambda that ignores its arguments and
            # returns the prediction already computed above, so the network
            # runs only the calls made explicitly in this loop.
            latents = self.noise_scheduler.sample_step(
                lambda x_t, ts, txt: noise_pred,
                latents,
                step.item(),
                text_tokens,
            )

        # Map from the model's [-1, 1] range back to [0, 1] pixel values.
        video = ((latents.squeeze(0) + 1) / 2).clamp(0, 1)
        return video

    def save_video(self, video: torch.Tensor, output_path: str, fps: int = 8):
        """Write a (C, T, H, W) video tensor in [0, 1] to *output_path*.

        Falls back to a .npy dump when torchvision is unavailable.

        Args:
            video: Video tensor (C, T, H, W) in range [0, 1].
            output_path: Output file path.
            fps: Frames per second.
        """
        try:
            import torchvision
            from torchvision.io import write_video

            # write_video expects (T, H, W, C) uint8 frames.
            frames = video.permute(1, 2, 3, 0).cpu()
            frames = (frames * 255).to(torch.uint8)
            write_video(output_path, frames, fps=fps)
            print(f"Video saved to {output_path}")
        except ImportError:
            print("torchvision not available, saving as numpy array")
            video_np = video.cpu().numpy()
            np.save(output_path.replace('.mp4', '.npy'), video_np)
            print(f"Video saved as numpy array to {output_path.replace('.mp4', '.npy')}")
def load_model(checkpoint_path: str, device: str = 'cuda') -> VideoTTV1B:
    """Instantiate a VideoTTV1B model and load trained weights.

    Reads an optional ``model_config.json`` next to the checkpoint; any
    recognized hyperparameters found there override the built-in defaults.
    (Previously the config was loaded and printed but silently ignored, so
    checkpoints trained with non-default shapes could not be restored.)

    Args:
        checkpoint_path: Path to a checkpoint containing ``model_state_dict``
            (and optionally ``global_step``).
        device: Target used as ``map_location`` when loading weights.

    Returns:
        The model with weights restored.
    """
    # Built-in defaults; identical to the historical hard-coded values.
    params = {
        'img_size': (256, 256),
        'num_frames': 16,
        'patch_size': (2, 16, 16),
        'in_channels': 3,
        'hidden_dim': 1536,
        'depth': 24,
        'num_heads': 24,
        'mlp_ratio': 4.0,
    }

    config_path = Path(checkpoint_path).parent / 'model_config.json'
    if config_path.exists():
        with open(config_path, 'r') as f:
            config = json.load(f)
        print(f"Loaded model config: {config}")
        # Apply only keys we understand so unrelated entries cannot break
        # construction. NOTE(review): assumes config keys mirror the
        # VideoTTV1B constructor kwargs -- confirm against the training code.
        params.update({k: v for k, v in config.items() if k in params})
        # JSON has no tuples; normalize sequence-valued fields.
        for key in ('img_size', 'patch_size'):
            params[key] = tuple(params[key])

    model = VideoTTV1B(**params)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoints (or pass weights_only=True on torch >= 2.0).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    print(f"Training step: {checkpoint.get('global_step', 'unknown')}")
    return model
def generate_video_from_prompt(
    prompt: str,
    checkpoint_path: str,
    output_path: str = "generated_video.mp4",
    num_steps: int = 50,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
    device: str = 'cuda',
):
    """End-to-end helper: load a checkpoint, sample a video, and save it.

    Args:
        prompt: Text description.
        checkpoint_path: Path to model checkpoint.
        output_path: Where to save the video.
        num_steps: Number of denoising steps.
        guidance_scale: Guidance strength.
        seed: Random seed.
        device: Device to run on.

    Returns:
        The generated video tensor (C, T, H, W) in [0, 1].
    """
    print(f"Generating video for prompt: '{prompt}'")
    print(f"Using {num_steps} inference steps with guidance scale {guidance_scale}")

    # Assemble the model + scheduler pair, then run the sampler.
    net = load_model(checkpoint_path, device)
    scheduler = DDPMScheduler(num_steps=1000)
    sampler = VideoGenerator(net, scheduler, device)

    clip = sampler.generate(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )

    sampler.save_video(clip, output_path)
    return clip
def batch_generate(
    prompts: List[str],
    checkpoint_path: str,
    output_dir: str = "./generated_videos",
    **kwargs
):
    """Generate one video per prompt, numbered video_0000.mp4, video_0001.mp4, ...

    Failures on individual prompts are reported and skipped so the rest of
    the batch still runs. Extra keyword arguments are forwarded to
    generate_video_from_prompt().
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    total = len(prompts)
    for idx, text in enumerate(prompts):
        print(f"\n[{idx+1}/{total}] Generating: {text}")
        destination = out_root / f"video_{idx:04d}.mp4"
        try:
            generate_video_from_prompt(
                prompt=text,
                checkpoint_path=checkpoint_path,
                output_path=str(destination),
                **kwargs
            )
        except Exception as e:
            # Best-effort batch: report the failure and move on.
            print(f"Error generating video {idx}: {e}")
            continue
def main():
    """CLI entry point: parse command-line flags and generate one video."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate videos from text prompts")
    # Flag specs kept in one table; order matters for --help output.
    flag_specs = [
        ('--prompt', dict(type=str, required=True, help='Text prompt')),
        ('--checkpoint', dict(type=str, required=True, help='Model checkpoint path')),
        ('--output', dict(type=str, default='generated_video.mp4', help='Output path')),
        ('--steps', dict(type=int, default=50, help='Number of inference steps')),
        ('--guidance', dict(type=float, default=7.5, help='Guidance scale')),
        ('--seed', dict(type=int, default=None, help='Random seed')),
        ('--device', dict(type=str, default='cuda', help='Device (cuda/cpu)')),
    ]
    for flag, spec in flag_specs:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    generate_video_from_prompt(
        prompt=opts.prompt,
        checkpoint_path=opts.checkpoint,
        output_path=opts.output,
        num_steps=opts.steps,
        guidance_scale=opts.guidance,
        seed=opts.seed,
        device=opts.device,
    )
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        # CLI arguments supplied: run the generator. Previously main() was
        # defined but never called, so the documented
        # "python inference.py --prompt ..." invocation only re-printed the
        # example prompts instead of generating anything.
        main()
    else:
        # Bare invocation: show usage examples (original behavior).
        example_prompts = [
            "A serene sunset over the ocean with gentle waves",
            "A cat playing with a ball of yarn in slow motion",
            "Time-lapse of a flower blooming in spring",
            "Aerial view of a city at night with twinkling lights",
            "Underwater scene with colorful fish swimming",
        ]
        print("Example prompts for video generation:")
        for i, prompt in enumerate(example_prompts, 1):
            print(f"{i}. {prompt}")
        print("\nRun with: python inference.py --prompt 'your prompt' --checkpoint path/to/checkpoint.pt")