#!/usr/bin/env python3
"""
Test script for Sinkhorn-Normalized Quantization integration.

Runs a handful of smoke checks against the ``ai_med_extract`` package and
exits with status 0 when all of them pass, 1 otherwise.
"""
import sys
import os

# NOTE(review): this path is resolved against the current working directory,
# so the script presumably has to be launched from the repository root.
sys.path.insert(0, 'services/ai-service/src')


def check(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Used instead of bare ``assert`` so the checks still execute when Python
    runs with ``-O``, which strips ``assert`` statements and would otherwise
    make every check below pass vacuously.
    """
    if not condition:
        raise AssertionError(message)


def test_imports():
    """Test that all modules can be imported.

    Returns:
        True if every import succeeds, False otherwise (the error is printed).
    """
    try:
        # Imported only to prove the modules load; the names are not used.
        from ai_med_extract.utils.quantization_utils import (  # noqa: F401
            quantize_model_weights,
            get_quantization_config,
        )
        from ai_med_extract.utils.model_config import QUANTIZATION_CONFIG  # noqa: F401
        from ai_med_extract.utils.model_manager import UnifiedModelManager  # noqa: F401
        print("✓ All imports successful")
        return True
    except Exception as e:
        print(f"✗ Import failed: {e}")
        return False


def test_quantization_config():
    """Test quantization configuration.

    Checks that ``get_quantization_config()`` reports quantization as
    disabled and exposes ``num_centroids`` and ``num_iterations`` keys.

    Returns:
        True if all checks pass, False otherwise (the reason is printed).
    """
    try:
        # Imported to confirm the config module loads alongside the utils.
        from ai_med_extract.utils.model_config import QUANTIZATION_CONFIG  # noqa: F401
        from ai_med_extract.utils.quantization_utils import get_quantization_config

        config = get_quantization_config()
        check(not config['enabled'], "Quantization should be disabled by default")
        check('num_centroids' in config, "Config is missing 'num_centroids'")
        check('num_iterations' in config, "Config is missing 'num_iterations'")
        print("✓ Quantization config test passed")
        return True
    except Exception as e:
        print(f"✗ Config test failed: {e}")
        return False


def test_model_manager_initialization():
    """Test model manager can be initialized without errors.

    Returns:
        True if ``UnifiedModelManager()`` constructs without raising,
        False otherwise (the error is printed).
    """
    try:
        from ai_med_extract.utils.model_manager import UnifiedModelManager

        UnifiedModelManager()
        print("✓ Model manager initialization successful")
        return True
    except Exception as e:
        print(f"✗ Model manager init failed: {e}")
        return False


def test_quantization_function():
    """Test quantization function with dummy data.

    Quantizes a random 10x10 tensor and checks the output keeps its shape.

    Returns:
        True if the check passes, False otherwise (the reason is printed).
    """
    try:
        import torch
        from ai_med_extract.utils.quantization_utils import sinkhorn_normalized_quantization

        # Create dummy weights
        weights = torch.randn(10, 10)
        quantized = sinkhorn_normalized_quantization(weights, num_centroids=8, num_iterations=5)
        check(quantized.shape == weights.shape, "Shape should be preserved")
        print("✓ Quantization function test passed")
        return True
    except Exception as e:
        print(f"✗ Quantization function test failed: {e}")
        return False


def main():
    """Run every smoke test and report a summary.

    Returns:
        Process exit status: 0 if all tests passed, 1 otherwise.
    """
    print("Running Sinkhorn-Normalized Quantization tests...\n")

    tests = [
        test_imports,
        test_quantization_config,
        test_model_manager_initialization,
        test_quantization_function,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        if test():
            passed += 1
        print()

    print(f"Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! Sinkhorn-Normalized Quantization integration is working.")
        return 0
    print("❌ Some tests failed. Please check the implementation.")
    return 1


if __name__ == "__main__":
    sys.exit(main())