File size: 3,021 Bytes
7d916a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
"""
Test script for Sinkhorn-Normalized Quantization integration
"""

import sys
import os
# Make the ``ai_med_extract`` package importable. The source directory is
# resolved relative to this script instead of the current working directory,
# so the checks work no matter where the script is launched from.
# NOTE(review): assumes this script lives at the repository root (the
# directory containing ``services/``) -- confirm if the file is moved.
_SRC_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'services', 'ai-service', 'src'
)
sys.path.insert(0, _SRC_DIR)

def test_imports():
    """Check that the quantization utilities, config and model manager import.

    Returns:
        True if every import succeeds; False otherwise, after printing the
        error.
    """
    try:
        # The names are intentionally unused: binding them is the test.
        from ai_med_extract.utils.quantization_utils import quantize_model_weights, get_quantization_config  # noqa: F401
        from ai_med_extract.utils.model_config import QUANTIZATION_CONFIG  # noqa: F401
        from ai_med_extract.utils.model_manager import UnifiedModelManager  # noqa: F401
    except Exception as e:
        # Importing runs each module's top-level code, which can raise more
        # than ImportError, so every failure is reported here.
        print(f"✗ Import failed: {e}")
        return False
    print("✓ All imports successful")
    return True

def _check_quantization_config(config):
    """Raise AssertionError unless *config* looks like the expected default.

    The config must contain the keys ``enabled``, ``num_centroids`` and
    ``num_iterations``, and ``enabled`` must be falsy.

    An explicit ``raise`` is used instead of ``assert`` so the checks still
    run under ``python -O``, and every failure carries a descriptive message.

    Raises:
        AssertionError: if a required key is missing or quantization is
            enabled.
    """
    required = ('enabled', 'num_centroids', 'num_iterations')
    missing = [key for key in required if key not in config]
    if missing:
        raise AssertionError(f"Quantization config is missing keys: {', '.join(missing)}")
    if config['enabled']:
        raise AssertionError("Quantization should be disabled by default")


def test_quantization_config():
    """Check the default quantization configuration.

    Returns:
        True if ``get_quantization_config()`` returns a config with the
        required keys and quantization disabled; False otherwise, after
        printing the error.
    """
    try:
        from ai_med_extract.utils.quantization_utils import get_quantization_config

        _check_quantization_config(get_quantization_config())
    except Exception as e:
        print(f"✗ Config test failed: {e}")
        return False
    print("✓ Quantization config test passed")
    return True

def test_model_manager_initialization():
    """Check that ``UnifiedModelManager`` can be constructed with no arguments.

    Returns:
        True if construction succeeds; False otherwise, after printing the
        error.
    """
    try:
        from ai_med_extract.utils.model_manager import UnifiedModelManager

        # Only construction is under test; the instance itself is not needed.
        UnifiedModelManager()
    except Exception as e:
        print(f"✗ Model manager init failed: {e}")
        return False
    print("✓ Model manager initialization successful")
    return True

def test_quantization_function():
    """Run ``sinkhorn_normalized_quantization`` on a small random matrix.

    The input is a seeded 10x10 ``torch.randn`` tensor. It is quantized with
    ``num_centroids=8`` and ``num_iterations=5``, and the only property
    checked is that the output shape equals the input shape.

    Returns:
        True if the call succeeds and the shape is preserved; False otherwise,
        after printing the error.
    """
    try:
        import torch
        from ai_med_extract.utils.quantization_utils import sinkhorn_normalized_quantization

        # Seed the global RNG so a failure is reproducible: the dummy weights
        # (and any global-RNG use inside the quantizer, if present) are then
        # identical on every run.
        torch.manual_seed(0)
        weights = torch.randn(10, 10)
        quantized = sinkhorn_normalized_quantization(weights, num_centroids=8, num_iterations=5)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if quantized.shape != weights.shape:
            raise AssertionError(
                f"Shape should be preserved: expected {tuple(weights.shape)}, "
                f"got {tuple(quantized.shape)}"
            )
    except Exception as e:
        print(f"✗ Quantization function test failed: {e}")
        return False
    print("✓ Quantization function test passed")
    return True

def main():
    """Run every integration check in order and print a summary.

    Each check prints its own pass/fail line; a blank line separates them.

    Returns:
        Process exit status: 0 if every check passed, 1 otherwise.
    """
    print("Running Sinkhorn-Normalized Quantization tests...\n")

    tests = [
        test_imports,
        test_quantization_config,
        test_model_manager_initialization,
        test_quantization_function,
    ]

    failed = []
    for test in tests:
        if not test():
            failed.append(test.__name__)
        print()

    passed = len(tests) - len(failed)
    print(f"Results: {passed}/{len(tests)} tests passed")

    if not failed:
        print("🎉 All tests passed! Sinkhorn-Normalized Quantization integration is working.")
        return 0

    # Name the failing checks so the user knows where to look.
    print(f"Failed: {', '.join(failed)}")
    print("❌ Some tests failed. Please check the implementation.")
    return 1

if __name__ == "__main__":
    sys.exit(main())