Spaces:
Sleeping
Sleeping
File size: 4,267 Bytes
3228ab0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import sys
import logging
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_imports():
"""测试所有必需的模块都可以导入"""
try:
import torch
logger.info(f"PyTorch版本: {torch.__version__}")
import transformers
logger.info(f"Transformers版本: {transformers.__version__}")
import numpy as np
logger.info(f"NumPy版本: {np.__version__}")
import PIL
logger.info(f"PIL版本: {PIL.__version__}")
import scipy
logger.info(f"SciPy版本: {scipy.__version__}")
logger.info("所有导入成功")
return True
except ImportError as e:
logger.error(f"导入错误: {str(e)}")
return False
def test_model_loading():
"""测试模型可以加载"""
try:
from model import RadarDetectionModel
# 检查是否设置了HF_TOKEN环境变量
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试")
# 尝试初始化模型,使用较小的公共模型
logger.info("尝试初始化模型(使用较小的公共模型)")
model = RadarDetectionModel(model_name="google/siglip-base-patch16-224")
logger.info("模型初始化成功")
return True
except Exception as e:
logger.error(f"模型加载错误: {str(e)}")
return False
def test_feature_extraction():
"""测试特征提取功能"""
try:
import numpy as np
from PIL import Image
from feature_extraction import extract_features
# 创建一个虚拟图像和检测结果
logger.info("创建虚拟测试数据")
dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
dummy_detection = {
'boxes': [[50, 50, 100, 100]],
'scores': [0.9],
'labels': ['测试']
}
# 提取特征
logger.info("提取特征")
features = extract_features(dummy_image, dummy_detection)
logger.info(f"提取的特征: {features}")
return True
except Exception as e:
logger.error(f"特征提取错误: {str(e)}")
return False
def test_app_initialization():
"""测试应用程序初始化但不加载模型"""
try:
logger.info("测试应用程序初始化")
import app
# 检查应用程序是否已初始化但没有加载模型
logger.info("检查应用程序全局变量")
assert app.model is None, "模型不应该在导入时加载"
assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False"
logger.info("应用程序初始化测试通过")
return True
except Exception as e:
logger.error(f"应用程序初始化错误: {str(e)}")
return False
def run_tests():
"""运行所有测试"""
tests = [
("导入测试", test_imports),
("应用程序初始化测试", test_app_initialization),
("模型加载测试", test_model_loading),
("特征提取测试", test_feature_extraction)
]
results = []
for name, test_func in tests:
logger.info(f"运行{name}...")
try:
result = test_func()
results.append((name, result))
logger.info(f"{name}: {'通过' if result else '失败'}")
except Exception as e:
logger.error(f"{name}失败,错误: {str(e)}")
results.append((name, False))
# 打印摘要
logger.info("\n--- 测试摘要 ---")
passed = sum(1 for _, result in results if result)
total = len(results)
logger.info(f"通过: {passed}/{total} 测试")
for name, result in results:
status = "通过" if result else "失败"
logger.info(f"{name}: {status}")
return passed == total
if __name__ == "__main__":
success = run_tests()
sys.exit(0 if success else 1) |