test / test_quantization.py
sachinchandrankallar's picture
deploy
7d916a4
#!/usr/bin/env python3
"""
Test script for Sinkhorn-Normalized Quantization integration
"""
import sys
import os
sys.path.insert(0, 'services/ai-service/src')
def test_imports():
    """Check that the quantization-related modules can be imported.

    Imports ``quantize_model_weights`` and ``get_quantization_config`` from
    ``ai_med_extract.utils.quantization_utils``, ``QUANTIZATION_CONFIG`` from
    ``ai_med_extract.utils.model_config`` and ``UnifiedModelManager`` from
    ``ai_med_extract.utils.model_manager``.

    Returns:
        True if every import succeeds, False otherwise. Any exception raised
        while importing is caught and reported on stdout, not propagated.
    """
    try:
        # The imported names are intentionally unused: importing them is the test.
        from ai_med_extract.utils.quantization_utils import (  # noqa: F401
            quantize_model_weights,
            get_quantization_config,
        )
        from ai_med_extract.utils.model_config import QUANTIZATION_CONFIG  # noqa: F401
        from ai_med_extract.utils.model_manager import UnifiedModelManager  # noqa: F401
    except Exception as e:
        print(f"✗ Import failed: {e}")
        return False
    else:
        print("✓ All imports successful")
        return True
def test_quantization_config():
    """Check the default quantization configuration.

    Calls ``get_quantization_config()`` and verifies that the ``enabled``,
    ``num_centroids`` and ``num_iterations`` keys are present and that
    ``enabled`` is falsy (quantization disabled by default).

    Failures are raised as explicit ``AssertionError``s rather than ``assert``
    statements, because ``assert`` is stripped when Python runs with ``-O``
    and the checks would then pass vacuously.

    Returns:
        True if all checks pass, False otherwise (the failure reason is
        printed; exceptions are not propagated).
    """
    try:
        # Imported (though unused) so a broken model_config module fails this test.
        from ai_med_extract.utils.model_config import QUANTIZATION_CONFIG  # noqa: F401
        from ai_med_extract.utils.quantization_utils import get_quantization_config

        config = get_quantization_config()

        required_keys = ("enabled", "num_centroids", "num_iterations")
        missing = [key for key in required_keys if key not in config]
        if missing:
            raise AssertionError(f"Config is missing required keys: {missing}")

        if config["enabled"]:
            raise AssertionError(
                "Quantization should be disabled by default "
                f"(got enabled={config['enabled']!r})"
            )
    except Exception as e:
        print(f"✗ Config test failed: {e}")
        return False
    else:
        print("✓ Quantization config test passed")
        return True
def test_model_manager_initialization():
    """Check that ``UnifiedModelManager`` can be constructed with no arguments.

    Returns:
        True if the import and construction succeed, False otherwise (the
        error is printed rather than propagated).
    """
    try:
        from ai_med_extract.utils.model_manager import UnifiedModelManager

        # Only successful construction is under test; the instance is not used.
        UnifiedModelManager()
    except Exception as e:
        print(f"✗ Model manager init failed: {e}")
        return False
    else:
        print("✓ Model manager initialization successful")
        return True
def test_quantization_function():
    """Run ``sinkhorn_normalized_quantization`` on a small random tensor.

    The 10x10 input is drawn from a dedicated, fixed-seed ``torch.Generator``
    so that failures are reproducible without touching torch's global RNG
    state. The test only checks that the returned object's ``shape`` matches
    the input's.

    Returns:
        True if the call succeeds and the shape is preserved, False otherwise
        (the error is printed rather than propagated).
    """
    try:
        import torch
        from ai_med_extract.utils.quantization_utils import sinkhorn_normalized_quantization

        # Seed 0 is arbitrary; it only needs to be fixed for reproducibility.
        generator = torch.Generator().manual_seed(0)
        weights = torch.randn(10, 10, generator=generator)
        quantized = sinkhorn_normalized_quantization(weights, num_centroids=8, num_iterations=5)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if quantized.shape != weights.shape:
            raise AssertionError(
                "Shape should be preserved: "
                f"expected {tuple(weights.shape)}, got {tuple(quantized.shape)}"
            )
    except Exception as e:
        print(f"✗ Quantization function test failed: {e}")
        return False
    else:
        print("✓ Quantization function test passed")
        return True
def main():
    """Run every test function once, in order, and print a summary.

    A blank line is printed after each test. Each test reports its outcome
    through its boolean return value rather than by raising.

    NOTE(review): because the ``test_*`` functions catch their own
    exceptions, collecting this file with pytest will not mark failures as
    failed. Run it as a script to get a meaningful pass/fail result.

    Returns:
        0 if every test returned a truthy value, 1 otherwise. The value is
        meant to be used as the process exit status.
    """
    print("Running Sinkhorn-Normalized Quantization tests...\n")
    tests = [
        test_imports,
        test_quantization_config,
        test_model_manager_initialization,
        test_quantization_function,
    ]
    passed = 0
    total = len(tests)
    for test in tests:
        if test():
            passed += 1
        print()
    print(f"Results: {passed}/{total} tests passed")
    if passed == total:
        print("🎉 All tests passed! Sinkhorn-Normalized Quantization integration is working.")
        return 0
    print("❌ Some tests failed. Please check the implementation.")
    return 1
if __name__ == "__main__":
    # Use main()'s return value (0 = all passed, 1 = some failed) as the exit status.
    raise SystemExit(main())