File size: 3,054 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
#!/usr/bin/env python3
"""
Test script for GenVidBench app
"""
import os
import sys
def test_imports(modules=None):
    """Check that the app's third-party dependencies can be imported.

    Every check is attempted (rather than stopping at the first failure) so a
    single run reports all missing or broken dependencies.

    Args:
        modules: Optional sequence of ``(label, module_name, attr_names)``
            tuples. Each ``module_name`` is imported with
            ``importlib.import_module`` and every name in ``attr_names`` must
            be an attribute of the imported module. ``label`` is used in the
            printed messages. Defaults to PyTorch, Gradio and the MMAction2
            inference APIs (``init_recognizer``, ``inference_recognizer``).

    Returns:
        True if every import (and attribute lookup) succeeded, else False.
    """
    import importlib

    if modules is None:
        modules = (
            ("PyTorch", "torch", ()),
            ("Gradio", "gradio", ()),
            ("MMAction2 APIs", "mmaction.apis",
             ("init_recognizer", "inference_recognizer")),
        )

    print("🧪 Testing imports...")
    all_ok = True
    for label, module_name, attr_names in modules:
        try:
            module = importlib.import_module(module_name)
            for attr_name in attr_names:
                getattr(module, attr_name)
        except Exception as e:  # noqa: BLE001
            # Diagnostic boundary: importing a package can fail with more than
            # ImportError (e.g. a package's own version-compatibility assert or
            # a shared-library load error). Report it as a failure instead of
            # crashing the whole setup check with a traceback.
            print(f"❌ {label} import failed: {e}")
            all_ok = False
        else:
            print(f"✅ {label} imported successfully")
    return all_ok
def test_files(required_files=None):
    """Check that the files the app needs are present.

    Relative paths are resolved against the current working directory, so the
    script is expected to be run from the project root. All paths are checked
    so that every missing file is reported in one run.

    Args:
        required_files: Optional iterable of file paths to check. Defaults to
            ``app.py``, ``requirements.txt``, ``README.md`` and the TSN demo
            inference config.

    Returns:
        True if every path is an existing regular file, otherwise False.
    """
    if required_files is None:
        required_files = [
            'app.py',
            'requirements.txt',
            'README.md',
            'demo/demo_configs/tsn_r50_1x1x8_video_infer.py',
        ]

    print("🧪 Testing file structure...")
    missing = []
    for file_path in required_files:
        # isfile rather than exists: a directory with the same name must not
        # count as the required file.
        if os.path.isfile(file_path):
            print(f"✅ {file_path} exists")
        else:
            print(f"❌ {file_path} missing")
            missing.append(file_path)
    return not missing
def test_model_checkpoint(
    checkpoint_path=(
        'checkpoints/'
        'tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth'
    ),
):
    """Check that the model checkpoint is present locally.

    This only inspects the filesystem; it does not download anything. When the
    checkpoint is missing it prints a hint to run ``download_model.py``.

    Args:
        checkpoint_path: Path to the checkpoint file (relative to the current
            working directory unless absolute). Defaults to the TSN R50
            Kinetics-400 checkpoint under ``checkpoints/``.

    Returns:
        True if the checkpoint exists as a non-empty regular file, else False.
    """
    print("🧪 Testing model checkpoint...")
    if not os.path.isfile(checkpoint_path):
        print(f"⚠️ Model checkpoint not found: {checkpoint_path}")
        print(" Run 'python download_model.py' to download it")
        return False
    # A zero-byte file (e.g. left behind by an interrupted download) cannot be
    # a usable checkpoint, so treat it as missing.
    if os.path.getsize(checkpoint_path) == 0:
        print(f"⚠️ Model checkpoint is empty: {checkpoint_path}")
        print(" Run 'python download_model.py' to download it")
        return False
    print(f"✅ Model checkpoint found: {checkpoint_path}")
    return True
def main(tests=None):
    """Run the setup checks and print a pass/fail summary.

    Args:
        tests: Optional sequence of ``(name, callable)`` pairs. Each callable
            takes no arguments and returns a truthy value on success. Defaults
            to the import, file-structure and model-checkpoint checks defined
            in this script.

    Returns:
        True if every check passed (or no checks were given), else False.
    """
    if tests is None:
        tests = [
            ("Import Test", test_imports),
            ("File Structure Test", test_files),
            ("Model Checkpoint Test", test_model_checkpoint),
        ]

    print("🚀 Testing GenVidBench setup...")
    print("=" * 50)

    results = []
    for test_name, test_func in tests:
        print(f"\n📋 {test_name}")
        print("-" * 30)
        try:
            result = bool(test_func())
        except Exception as e:  # noqa: BLE001
            # One crashing check must not abort the run: record it as a
            # failure so the remaining checks and the summary still happen.
            print(f"❌ {test_name} raised {type(e).__name__}: {e}")
            result = False
        results.append((test_name, result))

    print("\n" + "=" * 50)
    print("📊 Test Results Summary:")
    print("=" * 50)
    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{status} {test_name}")

    all_passed = all(result for _, result in results)
    if all_passed:
        print("\n🎉 All tests passed! Your app is ready for deployment.")
    else:
        print("\n⚠️ Some tests failed. Please fix the issues before deploying.")
    return all_passed
if __name__ == "__main__":
    # Map the overall result onto a process exit status so shells and CI
    # jobs can detect failure: 0 when every check passed, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)
|