| """Test script for explainability features.""" |
| import asyncio |
| import traceback |
| import numpy as np |
| from PIL import Image |
| import io |
|
|
async def main() -> int:
    """Smoke-test explainability output for each listed sub-model.

    Loads the fusion models through the model registry, then for each
    sub-model runs one ``predict(..., explain=True)`` call on a synthetic
    224x224 RGB PNG. It prints the prediction and, when the result contains
    ``heatmap_base64``, the size of the decoded heatmap.

    Returns:
        The number of sub-models whose test raised an exception (0 when all
        passed), so the caller can use it as a process exit status.
    """
    # Deferred imports: these are only needed once main() actually runs.
    import base64

    from app.services.model_registry import get_model_registry
    from app.core.config import settings

    registry = get_model_registry()

    print("Loading models...")
    await registry.load_from_fusion_repo(settings.HF_FUSION_REPO_ID)
    print("Models loaded!")

    # Synthetic input image. The generator is seeded so every run tests the
    # same pixels. The upper bound is exclusive, so 256 is needed to cover
    # the full uint8 range 0..255.
    rng = np.random.default_rng(0)
    pixels = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    buf = io.BytesIO()
    Image.fromarray(pixels).save(buf, format='PNG')
    img_bytes = buf.getvalue()

    models = ('cnn-transfer', 'gradfield-cnn', 'vit-base', 'deit-distilled')
    failures = []

    for model_name in models:
        print(f"\nTesting {model_name}...")
        try:
            model = registry.get_submodel(model_name)
            result = model.predict(image_bytes=img_bytes, explain=True)
            has_heatmap = 'heatmap_base64' in result
            print(f" Success! pred={result['pred']}, has_heatmap={has_heatmap}")
            if has_heatmap:
                decoded = base64.b64decode(result['heatmap_base64'])
                print(f" Heatmap size: {len(decoded)} bytes")
        except Exception as e:
            # Per-model boundary: record and report the failure, then keep
            # testing the remaining models.
            failures.append(model_name)
            print(f" ERROR: {e}")
            traceback.print_exc()

    if failures:
        print(f"\n{len(failures)} model(s) failed: {', '.join(failures)}")
    else:
        print("\nAll models passed.")
    return len(failures)


if __name__ == "__main__":
    # Use the failure count as the exit status so callers/CI can detect errors
    # (SystemExit(0) means success).
    raise SystemExit(asyncio.run(main()))
|
|