# Zenderos / quickstart.py
# Author: ASADSANAN
# Commit: "Upload 11 files" (3d8856d, verified)
#!/usr/bin/env python3
"""
Quick Start Script for TTV-1B
Run this to verify installation and test the model
"""
import sys
def check_imports():
    """Report which required third-party packages can be imported.

    Tries to import each module in a fixed list (torch, numpy, tqdm) and
    prints a check mark or a cross for each one. If any are missing, it then
    prints the missing display names and a pip install hint.

    Returns:
        bool: True if every required module imported, False otherwise.
    """
    import importlib

    print("Checking dependencies...")
    # Importable module name -> human-readable display name.
    required = {
        'torch': 'PyTorch',
        'numpy': 'NumPy',
        'tqdm': 'tqdm',
    }
    missing = []
    for module, name in required.items():
        try:
            importlib.import_module(module)
        except ImportError:
            print(f" ✗ {name} - MISSING")
            missing.append(name)
        else:
            # Only reached when the import above succeeded.
            print(f" ✓ {name}")
    if missing:
        print(f"\nMissing packages: {', '.join(missing)}")
        print("Install with: pip install -r requirements.txt")
        return False
    return True
def test_model():
    """Smoke-test model construction and a single forward pass.

    Builds the model with ``video_ttv_1b.create_model``, using CUDA when it
    is available and the CPU otherwise, and prints its parameter count. It
    then runs one forward pass on random inputs under ``torch.no_grad()``
    and prints the input and output shapes.

    Returns:
        bool: True if creation and the forward pass succeed. False if any
        exception is raised; the message and the full traceback are printed.
    """
    import traceback

    print("\nTesting model...")
    try:
        import torch
        from video_ttv_1b import create_model

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f" Using device: {device}")

        # Create model (this will work even without CUDA)
        print(" Creating model...")
        model = create_model(device)
        print(" ✓ Model created successfully")
        print(f" Total parameters: {model.count_parameters():,}")

        # Switch to inference mode (disables dropout, etc.) before the
        # forward pass. Assumes create_model returns a torch.nn.Module.
        model.eval()

        # Test forward pass with small inputs. Tensors are allocated directly
        # on the target device instead of being built on CPU and copied.
        print(" Testing forward pass...")
        batch_size = 1
        # NOTE(review): presumably (batch, channels, frames, height, width);
        # confirm against the model's expected input layout.
        x = torch.randn(batch_size, 3, 16, 256, 256, device=device)
        # NOTE(review): presumably diffusion timesteps in [0, 1000); confirm
        # 1000 against the model's configured number of timesteps.
        t = torch.randint(0, 1000, (batch_size,), device=device)
        # NOTE(review): 50257 looks like the GPT-2 vocabulary size and 256 a
        # prompt length; confirm against the tokenizer/model config.
        tokens = torch.randint(0, 50257, (batch_size, 256), device=device)
        with torch.no_grad():
            output = model(x, t, tokens)
        print(" ✓ Forward pass successful")
        print(f" Input shape: {x.shape}")
        print(f" Output shape: {output.shape}")
        return True
    except Exception as e:
        # Top-level boundary of the smoke test. Report the error and return
        # False so the caller can exit cleanly, but keep the traceback: the
        # bare message alone (e.g. a KeyError's key) is often not enough.
        print(f" ✗ Error: {e}")
        traceback.print_exc()
        return False
def show_next_steps():
    """Print a banner-framed checklist of what to do after the quick start.

    Covers dataset preparation, training, video generation, and the
    documentation files to read. Produces console output only.
    """
    rule = "=" * 60
    # Each entry is (heading, detail lines). Every heading carries its own
    # leading blank line; the details are printed beneath it one per line.
    sections = (
        ("\n1. Prepare your dataset:", (
            " - Create data/videos/ directory",
            " - Add video files (MP4, 256x256, 16 frames)",
            " - Create data/annotations.json",
        )),
        ("\n2. Start training:", (
            " python train.py",
        )),
        ("\n3. Generate videos (after training):", (
            " python inference.py \\",
            " --prompt 'Your prompt here' \\",
            " --checkpoint checkpoints/best.pt \\",
            " --output video.mp4",
        )),
        ("\n4. Read documentation:", (
            " - README.md - Overview and usage",
            " - ARCHITECTURE.md - Model details",
            " - SETUP.md - Installation guide",
        )),
    )

    print("\n" + rule)
    print("Next Steps:")
    print(rule)
    for heading, details in sections:
        print(heading)
        for detail in details:
            print(detail)
    print("\n" + rule)
def main():
    """Run the quick-start checks, then print the next steps.

    Prints a banner, checks that dependencies import, and smoke-tests the
    model. It calls ``sys.exit(1)`` if either check fails. Otherwise it
    prints the next-steps guide and a success message.
    """
    print("="*60)
    print("TTV-1B Quick Start")
    print("1 Billion Parameter Text-to-Video Model")
    print("="*60)
    print()

    # Check dependencies
    if not check_imports():
        print("\nPlease install missing dependencies first.")
        sys.exit(1)

    # Test model
    if not test_model():
        print("\nModel test failed. Check the error messages above.")
        sys.exit(1)

    # Show next steps
    show_next_steps()

    print("\n✓ Quick start completed successfully!")
    print("\nYou're ready to train and generate videos with TTV-1B!")


# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()