File size: 3,504 Bytes
3d8856d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Quick Start Script for TTV-1B
Run this to verify installation and test the model
"""

import sys

def check_imports(required=None):
    """Check that the required packages can be imported.

    Args:
        required: Optional mapping of importable module name to the
            human-readable package name shown in the report. When omitted,
            the packages this script needs are checked: ``torch``,
            ``numpy`` and ``tqdm``.

    Returns:
        True if every module imported without raising ImportError;
        False otherwise, after printing the missing package names and an
        install hint.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    if required is None:
        required = {
            'torch': 'PyTorch',
            'numpy': 'NumPy',
            'tqdm': 'tqdm',
        }

    print("Checking dependencies...")

    missing = []
    for module, name in required.items():
        try:
            # import_module is the documented programmatic import API.
            importlib.import_module(module)
        except ImportError:
            print(f"  ✗ {name} - MISSING")
            missing.append(name)
        else:
            print(f"  ✓ {name}")

    if missing:
        print(f"\nMissing packages: {', '.join(missing)}")
        print("Install with: pip install -r requirements.txt")
        return False

    return True


def test_model():
    """Smoke-test the model: build it and run one forward pass.

    Everything, including the imports of ``torch`` and
    ``video_ttv_1b.create_model``, runs inside a single ``try`` so that any
    failure is printed and reported as ``False`` rather than raised.

    Returns:
        True if the model was created and the forward pass completed;
        False if any exception was raised along the way.
    """
    print("\nTesting model...")

    try:
        import torch
        from video_ttv_1b import create_model

        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
        print(f"  Using device: {device}")

        # Falls back to the CPU when CUDA is unavailable.
        print("  Creating model...")
        model = create_model(device)

        print("  ✓ Model created successfully")
        print(f"  Total parameters: {model.count_parameters():,}")

        # Random stand-in inputs for one sample.
        # NOTE(review): presumably (batch, channels, frames, height, width)
        # video, a timestep index below 1000 and 256 token ids below 50257;
        # confirm against video_ttv_1b.
        print("  Testing forward pass...")
        batch_size = 1
        sample_shape = (batch_size, 3, 16, 256, 256)
        sample = torch.randn(sample_shape).to(device)
        steps = torch.randint(0, 1000, (batch_size,)).to(device)
        token_ids = torch.randint(0, 50257, (batch_size, 256)).to(device)

        # Inference only: no gradients are needed for the check.
        with torch.no_grad():
            result = model(sample, steps, token_ids)

        print("  ✓ Forward pass successful")
        print(f"  Input shape:  {sample.shape}")
        print(f"  Output shape: {result.shape}")

    except Exception as exc:
        print(f"  ✗ Error: {exc}")
        return False

    return True


def show_next_steps():
    """Print the post-install checklist: dataset, training, inference, docs."""
    rule = "=" * 60

    # One entry per output line; "" entries are the blank separator lines.
    lines = [
        "",
        rule,
        "Next Steps:",
        rule,
        "",
        "1. Prepare your dataset:",
        "   - Create data/videos/ directory",
        "   - Add video files (MP4, 256x256, 16 frames)",
        "   - Create data/annotations.json",
        "",
        "2. Start training:",
        "   python train.py",
        "",
        "3. Generate videos (after training):",
        "   python inference.py \\",
        "       --prompt 'Your prompt here' \\",
        "       --checkpoint checkpoints/best.pt \\",
        "       --output video.mp4",
        "",
        "4. Read documentation:",
        "   - README.md - Overview and usage",
        "   - ARCHITECTURE.md - Model details",
        "   - SETUP.md - Installation guide",
        "",
        rule,
    ]

    print("\n".join(lines))


def main():
    """Run the quick start: banner, dependency check, model test, next steps.

    Exits the process with status 1 as soon as either the dependency check
    or the model test reports failure.
    """
    rule = "=" * 60
    # The trailing "" yields the blank line that follows the banner.
    print(
        rule,
        "TTV-1B Quick Start",
        "1 Billion Parameter Text-to-Video Model",
        rule,
        "",
        sep="\n",
    )

    # Each gate pairs a check with the hint printed when it fails; they run
    # in order and the first failure stops the script.
    gates = (
        (check_imports, "\nPlease install missing dependencies first."),
        (test_model, "\nModel test failed. Check the error messages above."),
    )
    for gate, failure_hint in gates:
        if gate():
            continue
        print(failure_hint)
        sys.exit(1)

    show_next_steps()

    print("\n✓ Quick start completed successfully!")
    print("\nYou're ready to train and generate videos with TTV-1B!")


# Run the quick-start checks only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()