File size: 3,054 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
"""

Test script for GenVidBench app

"""

import importlib
import os
import sys

def test_imports(requirements=None):
    """Check that the app's runtime dependencies can be imported.

    Every requirement is attempted (instead of stopping at the first failure)
    so a single run reports everything that is missing or broken.

    Args:
        requirements: Optional sequence of ``(label, module_name, attr_names)``
            tuples. ``module_name`` is imported with
            :func:`importlib.import_module`, and each name in ``attr_names``
            must exist on the imported module (the same check that
            ``from module_name import attr`` performs). Defaults to PyTorch,
            Gradio and the MMAction2 recognizer APIs.

    Returns:
        True if every requirement imported cleanly, False otherwise.
    """
    if requirements is None:
        requirements = [
            ("PyTorch", "torch", ()),
            ("Gradio", "gradio", ()),
            (
                "MMAction2 APIs",
                "mmaction.apis",
                ("init_recognizer", "inference_recognizer"),
            ),
        ]

    print("🧪 Testing imports...")

    all_ok = True
    for label, module_name, attr_names in requirements:
        try:
            module = importlib.import_module(module_name)
            for attr in attr_names:
                getattr(module, attr)  # AttributeError if the API is absent
        except Exception as exc:
            # Deliberately broader than ImportError: a package can also fail
            # while importing, e.g. mmaction's __init__ asserts on an
            # incompatible mmcv version (AssertionError), and broken native
            # libraries raise OSError. Report those as failures instead of
            # letting them crash the whole check script.
            print(f"❌ {label} import failed: {exc}")
            all_ok = False
        else:
            print(f"✅ {label} imported successfully")

    return all_ok

def test_files(required_files=None):
    """Check that the files the app needs are present.

    Every path is checked (instead of stopping at the first missing one) so a
    single run lists all missing files.

    Args:
        required_files: Optional iterable of paths. Relative paths are
            resolved against the current working directory, so the script is
            expected to be run from the project root. Defaults to the app
            entry point, requirements file, README and the TSN demo config.

    Returns:
        True if every path is an existing regular file, False otherwise.
    """
    if required_files is None:
        required_files = [
            'app.py',
            'requirements.txt',
            'README.md',
            'demo/demo_configs/tsn_r50_1x1x8_video_infer.py',
        ]

    print("🧪 Testing file structure...")

    all_present = True
    for file_path in required_files:
        # isfile rather than exists: a directory with the same name is not
        # a usable file.
        if os.path.isfile(file_path):
            print(f"✅ {file_path} exists")
        else:
            print(f"❌ {file_path} missing")
            all_present = False

    return all_present

def test_model_checkpoint(
    checkpoint_path='checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth',
):
    """Check that the TSN model checkpoint is present locally.

    This only inspects the filesystem; it does not download anything. When
    the checkpoint is missing or empty it prints a hint to run
    ``download_model.py``.

    Args:
        checkpoint_path: Path to the checkpoint file. Relative paths are
            resolved against the current working directory.

    Returns:
        True if ``checkpoint_path`` is a non-empty regular file, False
        otherwise.
    """
    print("🧪 Testing model checkpoint...")

    if not os.path.isfile(checkpoint_path):
        # isfile rather than exists: a directory at this path is not a
        # checkpoint.
        print(f"⚠️  Model checkpoint not found: {checkpoint_path}")
    elif os.path.getsize(checkpoint_path) == 0:
        # A zero-byte file cannot hold model weights; presumably a download
        # that was interrupted before any data was written.
        print(f"⚠️  Model checkpoint is empty: {checkpoint_path}")
    else:
        print(f"✅ Model checkpoint found: {checkpoint_path}")
        return True

    print("   Run 'python download_model.py' to download it")
    return False

def main(tests=None):
    """Run the setup checks, print a PASS/FAIL summary and return the outcome.

    Args:
        tests: Optional sequence of ``(name, check)`` pairs, where ``check`` is
            a zero-argument callable returning a truthy value on success.
            Defaults to the import, file-structure and model-checkpoint
            checks defined in this script.

    Returns:
        True if every check passed (True for an empty ``tests``), False
        otherwise.
    """
    print("🚀 Testing GenVidBench setup...")
    print("=" * 50)

    if tests is None:
        tests = [
            ("Import Test", test_imports),
            ("File Structure Test", test_files),
            ("Model Checkpoint Test", test_model_checkpoint),
        ]

    results = []
    for test_name, test_func in tests:
        print(f"\n📋 {test_name}")
        print("-" * 30)
        try:
            result = test_func()
        except Exception as exc:
            # Top-level diagnostic boundary: a check that crashes is recorded
            # as a failure so the remaining checks and the summary still run,
            # instead of the whole script dying with a traceback.
            print(f"❌ {test_name} raised {type(exc).__name__}: {exc}")
            result = False
        results.append((test_name, result))

    print("\n" + "=" * 50)
    print("📊 Test Results Summary:")
    print("=" * 50)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{status} {test_name}")

    all_passed = all(result for _, result in results)
    if all_passed:
        print("\n🎉 All tests passed! Your app is ready for deployment.")
    else:
        print("\n⚠️  Some tests failed. Please fix the issues before deploying.")

    return all_passed

if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 when at least one failed.
    sys.exit(int(not main()))