File size: 4,267 Bytes
3228ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, 
                   format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def test_imports():
    """测试所有必需的模块都可以导入"""
    try:
        import torch
        logger.info(f"PyTorch版本: {torch.__version__}")
        
        import transformers
        logger.info(f"Transformers版本: {transformers.__version__}")
        
        import numpy as np
        logger.info(f"NumPy版本: {np.__version__}")
        
        import PIL
        logger.info(f"PIL版本: {PIL.__version__}")
        
        import scipy
        logger.info(f"SciPy版本: {scipy.__version__}")
        
        logger.info("所有导入成功")
        return True
    except ImportError as e:
        logger.error(f"导入错误: {str(e)}")
        return False

def test_model_loading():
    """测试模型可以加载"""
    try:
        from model import RadarDetectionModel
        
        # 检查是否设置了HF_TOKEN环境变量
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试")
        
        # 尝试初始化模型,使用较小的公共模型
        logger.info("尝试初始化模型(使用较小的公共模型)")
        model = RadarDetectionModel(model_name="google/siglip-base-patch16-224")
        logger.info("模型初始化成功")
        return True
    except Exception as e:
        logger.error(f"模型加载错误: {str(e)}")
        return False

def test_feature_extraction():
    """测试特征提取功能"""
    try:
        import numpy as np
        from PIL import Image
        from feature_extraction import extract_features
        
        # 创建一个虚拟图像和检测结果
        logger.info("创建虚拟测试数据")
        dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
        dummy_detection = {
            'boxes': [[50, 50, 100, 100]],
            'scores': [0.9],
            'labels': ['测试']
        }
        
        # 提取特征
        logger.info("提取特征")
        features = extract_features(dummy_image, dummy_detection)
        logger.info(f"提取的特征: {features}")
        return True
    except Exception as e:
        logger.error(f"特征提取错误: {str(e)}")
        return False

def test_app_initialization():
    """测试应用程序初始化但不加载模型"""
    try:
        logger.info("测试应用程序初始化")
        import app
        
        # 检查应用程序是否已初始化但没有加载模型
        logger.info("检查应用程序全局变量")
        assert app.model is None, "模型不应该在导入时加载"
        assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False"
        
        logger.info("应用程序初始化测试通过")
        return True
    except Exception as e:
        logger.error(f"应用程序初始化错误: {str(e)}")
        return False

def run_tests():
    """运行所有测试"""
    tests = [
        ("导入测试", test_imports),
        ("应用程序初始化测试", test_app_initialization),
        ("模型加载测试", test_model_loading),
        ("特征提取测试", test_feature_extraction)
    ]
    
    results = []
    for name, test_func in tests:
        logger.info(f"运行{name}...")
        try:
            result = test_func()
            results.append((name, result))
            logger.info(f"{name}: {'通过' if result else '失败'}")
        except Exception as e:
            logger.error(f"{name}失败,错误: {str(e)}")
            results.append((name, False))
    
    # 打印摘要
    logger.info("\n--- 测试摘要 ---")
    passed = sum(1 for _, result in results if result)
    total = len(results)
    logger.info(f"通过: {passed}/{total} 测试")
    
    for name, result in results:
        status = "通过" if result else "失败"
        logger.info(f"{name}: {status}")
    
    return passed == total

if __name__ == "__main__":
    success = run_tests()
    sys.exit(0 if success else 1)