Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import logging | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| def test_imports(): | |
| """测试所有必需的模块都可以导入""" | |
| try: | |
| import torch | |
| logger.info(f"PyTorch版本: {torch.__version__}") | |
| import transformers | |
| logger.info(f"Transformers版本: {transformers.__version__}") | |
| import numpy as np | |
| logger.info(f"NumPy版本: {np.__version__}") | |
| import PIL | |
| logger.info(f"PIL版本: {PIL.__version__}") | |
| import scipy | |
| logger.info(f"SciPy版本: {scipy.__version__}") | |
| logger.info("所有导入成功") | |
| return True | |
| except ImportError as e: | |
| logger.error(f"导入错误: {str(e)}") | |
| return False | |
| def test_model_loading(): | |
| """测试模型可以加载""" | |
| try: | |
| from model import RadarDetectionModel | |
| # 检查是否设置了HF_TOKEN环境变量 | |
| hf_token = os.environ.get("HF_TOKEN") | |
| if not hf_token: | |
| logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试") | |
| # 尝试初始化模型,使用较小的公共模型 | |
| logger.info("尝试初始化模型(使用较小的公共模型)") | |
| model = RadarDetectionModel(model_name="google/siglip-base-patch16-224") | |
| logger.info("模型初始化成功") | |
| return True | |
| except Exception as e: | |
| logger.error(f"模型加载错误: {str(e)}") | |
| return False | |
| def test_feature_extraction(): | |
| """测试特征提取功能""" | |
| try: | |
| import numpy as np | |
| from PIL import Image | |
| from feature_extraction import extract_features | |
| # 创建一个虚拟图像和检测结果 | |
| logger.info("创建虚拟测试数据") | |
| dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) | |
| dummy_detection = { | |
| 'boxes': [[50, 50, 100, 100]], | |
| 'scores': [0.9], | |
| 'labels': ['测试'] | |
| } | |
| # 提取特征 | |
| logger.info("提取特征") | |
| features = extract_features(dummy_image, dummy_detection) | |
| logger.info(f"提取的特征: {features}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"特征提取错误: {str(e)}") | |
| return False | |
| def test_app_initialization(): | |
| """测试应用程序初始化但不加载模型""" | |
| try: | |
| logger.info("测试应用程序初始化") | |
| import app | |
| # 检查应用程序是否已初始化但没有加载模型 | |
| logger.info("检查应用程序全局变量") | |
| assert app.model is None, "模型不应该在导入时加载" | |
| assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False" | |
| logger.info("应用程序初始化测试通过") | |
| return True | |
| except Exception as e: | |
| logger.error(f"应用程序初始化错误: {str(e)}") | |
| return False | |
| def run_tests(): | |
| """运行所有测试""" | |
| tests = [ | |
| ("导入测试", test_imports), | |
| ("应用程序初始化测试", test_app_initialization), | |
| ("模型加载测试", test_model_loading), | |
| ("特征提取测试", test_feature_extraction) | |
| ] | |
| results = [] | |
| for name, test_func in tests: | |
| logger.info(f"运行{name}...") | |
| try: | |
| result = test_func() | |
| results.append((name, result)) | |
| logger.info(f"{name}: {'通过' if result else '失败'}") | |
| except Exception as e: | |
| logger.error(f"{name}失败,错误: {str(e)}") | |
| results.append((name, False)) | |
| # 打印摘要 | |
| logger.info("\n--- 测试摘要 ---") | |
| passed = sum(1 for _, result in results if result) | |
| total = len(results) | |
| logger.info(f"通过: {passed}/{total} 测试") | |
| for name, result in results: | |
| status = "通过" if result else "失败" | |
| logger.info(f"{name}: {status}") | |
| return passed == total | |
| if __name__ == "__main__": | |
| success = run_tests() | |
| sys.exit(0 if success else 1) |