|
|
|
|
|
"""
|
|
|
Test script for GenVidBench app
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
|
|
|
def test_imports(*, modules=None):
    """Test if all required modules can be imported.

    Every module is attempted even after a failure, so a single run lists
    all missing or broken dependencies instead of only the first one.

    Args:
        modules: Optional iterable of ``(label, module_name, attr_names)``
            tuples. ``label`` is the name shown in the report,
            ``module_name`` is passed to ``importlib.import_module``, and
            ``attr_names`` are names that must exist on the imported module
            (the equivalent of ``from module_name import attr``). Defaults to
            PyTorch, Gradio and the MMAction2 inference APIs.

    Returns:
        True if every module imported and exposed the requested names,
        False otherwise.
    """
    import importlib

    if modules is None:
        modules = (
            ("PyTorch", "torch", ()),
            ("Gradio", "gradio", ()),
            (
                "MMAction2 APIs",
                "mmaction.apis",
                ("init_recognizer", "inference_recognizer"),
            ),
        )

    print("🧪 Testing imports...")

    all_ok = True
    for label, module_name, attr_names in modules:
        try:
            module = importlib.import_module(module_name)
            for attr in attr_names:
                getattr(module, attr)
        # Deliberately broader than ImportError: large packages may raise other
        # errors at import time (e.g. AssertionError from a version check, or
        # OSError from a missing native library). Any of these means the
        # dependency is unusable, so report it instead of crashing the run.
        except Exception as exc:
            print(f"❌ {label} import failed: {type(exc).__name__}: {exc}")
            all_ok = False
        else:
            print(f"✅ {label} imported successfully")

    return all_ok
|
|
|
|
|
|
def test_files(*, base_dir=None):
    """Test if required files exist.

    Every required file is checked, not just those up to the first missing
    one, so the report lists everything that needs fixing.

    Args:
        base_dir: Directory the required relative paths are resolved against.
            Defaults to None, which means the current working directory (the
            script is then expected to run from the app root).

    Returns:
        True if every required path exists and is a regular file,
        False otherwise.
    """
    print("🧪 Testing file structure...")

    required_files = [
        'app.py',
        'requirements.txt',
        'README.md',
        'demo/demo_configs/tsn_r50_1x1x8_video_infer.py',
    ]

    all_present = True
    for file_path in required_files:
        full_path = (
            file_path if base_dir is None else os.path.join(base_dir, file_path)
        )
        # isfile rather than exists: a directory with the same name must not pass.
        if os.path.isfile(full_path):
            print(f"✅ {file_path} exists")
        else:
            print(f"❌ {file_path} missing")
            all_present = False

    return all_present
|
|
|
|
|
|
def test_model_checkpoint(*, checkpoint_path=None):
    """Test if the model checkpoint exists and is non-empty.

    A zero-byte file (for example one left behind by an interrupted
    download) or a directory at the checkpoint path counts as a failure.

    Args:
        checkpoint_path: Path of the checkpoint file. Relative paths are
            resolved against the current working directory. Defaults to the
            TSN checkpoint that ``download_model.py`` is expected to fetch.

    Returns:
        True if the checkpoint is a non-empty regular file, False otherwise.
    """
    if checkpoint_path is None:
        checkpoint_path = (
            'checkpoints/'
            'tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth'
        )

    print("🧪 Testing model checkpoint...")

    if not os.path.isfile(checkpoint_path):
        print(f"⚠️ Model checkpoint not found: {checkpoint_path}")
        print("   Run 'python download_model.py' to download it")
        return False

    if os.path.getsize(checkpoint_path) == 0:
        print(f"⚠️ Model checkpoint is empty: {checkpoint_path}")
        print("   Run 'python download_model.py' to download it again")
        return False

    print(f"✅ Model checkpoint found: {checkpoint_path}")
    return True
|
|
|
|
|
|
def main(tests=None):
    """Run all tests and print a pass/fail summary.

    A test that raises is reported and counted as a failure. The remaining
    tests still run, so the summary always covers every test.

    Args:
        tests: Optional sequence of ``(name, callable)`` pairs. Each callable
            takes no arguments and returns a truthy value on success.
            Defaults to the import, file-structure and model-checkpoint tests.

    Returns:
        True if every test passed (vacuously True for an empty list),
        False otherwise.
    """
    print("🚀 Testing GenVidBench setup...")
    print("=" * 50)

    if tests is None:
        tests = [
            ("Import Test", test_imports),
            ("File Structure Test", test_files),
            ("Model Checkpoint Test", test_model_checkpoint),
        ]

    results = []
    for test_name, test_func in tests:
        print(f"\n📋 {test_name}")
        print("-" * 30)
        try:
            passed = bool(test_func())
        # Top-level boundary: a crashing check is a failed check, and the
        # run continues so the summary below is still printed.
        except Exception as exc:
            print(f"❌ {test_name} raised {type(exc).__name__}: {exc}")
            passed = False
        results.append((test_name, passed))

    print("\n" + "=" * 50)
    print("📊 Test Results Summary:")
    print("=" * 50)

    for test_name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status} {test_name}")

    all_passed = all(passed for _, passed in results)
    if all_passed:
        print("\n🎉 All tests passed! Your app is ready for deployment.")
    else:
        print("\n⚠️ Some tests failed. Please fix the issues before deploying.")

    return all_passed
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
success = main()
|
|
|
sys.exit(0 if success else 1)
|
|
|
|