| from fastapi import APIRouter, HTTPException, BackgroundTasks |
| from pydantic import BaseModel |
| import logging |
| import torch |
| import numpy as np |
| import base64 |
| import cv2 |
| import httpx |
| import traceback |
| from PIL import Image |
| import io |
|
|
| from app.models.disease_model import get_model |
| from app.utils.preprocessing import ImagePreprocessor, load_image_from_bytes |
| from app.api.diagnosis import _extract_storage_object_path, _standardize_transform |
| from app.database.supabase_client import get_supabase_client, IMAGES_BUCKET |
| from app.services.explainability_service import ExplainabilityService, build_overlay |
|
|
# Logger named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)


# Router for this module: every route below is mounted under the "/explain"
# prefix and tagged "Explainability".
router = APIRouter(tags=["Explainability"], prefix="/explain")
|
|
class ExplainRequest(BaseModel):
    """Request body accepted by the explainability endpoint (``explain_prediction``)."""

    # Location of the image to explain. explain_prediction passes it to
    # get_image_bytes, which either downloads it from the Supabase IMAGES_BUCKET
    # or fetches it over HTTP.
    image_url: str
    # Diagnosis payload forwarded unchanged to
    # ExplainabilityService.generate_heatmaps and generate_gpt_explanation.
    # NOTE(review): no schema is enforced here - confirm the expected keys.
    diagnosis_data: dict
|
|
class ExplainResponse(BaseModel):
    """Response body returned by the explainability endpoint.

    As populated by ``explain_prediction``, each ``*_base64`` field holds a
    ``data:image/jpeg;base64,...`` string.
    """

    # Encoded output of build_overlay(original image, attention heatmap).
    attention_heatmap_base64: str
    # Encoded ``bbox_overlay`` image from ExplainabilityService.generate_heatmaps.
    attention_bbox_base64: str
    # Encoded ``highest_attention_crop`` image from generate_heatmaps.
    highest_attention_crop_base64: str
    # Encoded ``zone_reference`` image from generate_heatmaps.
    zone_reference_base64: str
    # Text returned by ExplainabilityService.generate_gpt_explanation.
    gpt_statement: str
|
|
def get_image_bytes(image_url: str) -> bytes:
    """Fetch the raw bytes of the image referenced by ``image_url``.

    If ``_extract_storage_object_path`` yields a path and ``image_url`` does
    not start with ``"http"``, the object is downloaded from the Supabase
    ``IMAGES_BUCKET``. Otherwise ``image_url`` is fetched with ``httpx``
    (30 s timeout, redirects followed).

    Args:
        image_url: Storage object path or HTTP(S) URL of the image.

    Returns:
        The image bytes as returned by the storage download or HTTP response.

    Raises:
        ValueError: If the image cannot be fetched for any reason. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        clean_path = _extract_storage_object_path(image_url)
        if clean_path and not image_url.startswith("http"):
            supabase = get_supabase_client()
            return supabase.storage.from_(IMAGES_BUCKET).download(clean_path)

        # SECURITY NOTE(review): explain_prediction passes the caller-supplied
        # request.image_url here, so this is a server-side fetch of an arbitrary
        # URL (SSRF risk) with no response-size cap. Consider an allow-list of
        # hosts (e.g. the Supabase project URL) and a maximum body size.
        with httpx.Client(timeout=30.0) as client:
            response = client.get(image_url, follow_redirects=True)
            response.raise_for_status()
            return response.content
    except Exception as e:
        logger.error("Failed to fetch image: %s", e)
        # Chain the cause so the real failure (HTTP status, storage error, ...)
        # stays visible in tracebacks instead of being replaced.
        raise ValueError(f"Failed to fetch image for explainability: {str(e)}") from e
|
|
def prepare_explainability_input(image_bytes: bytes, preprocessor: ImagePreprocessor) -> tuple[torch.Tensor, np.ndarray]:
    """Build the model input and a matching pixel array from raw image bytes.

    The bytes are decoded with ``load_image_from_bytes`` and standardized with
    ``_standardize_transform``. The same standardized tile is then used for both
    outputs.

    Args:
        image_bytes: Encoded image data.
        preprocessor: Provides the ``transform`` applied to the standardized tile.

    Returns:
        ``(model_batch, rgb_pixels)``:

        - ``model_batch`` is ``preprocessor.transform(tile)`` with a new leading
          axis of size 1 added via ``unsqueeze(0)``.
        - ``rgb_pixels`` is ``np.array(tile.convert('RGB'))``, which callers use
          as the original image for overlays.
    """
    standardized_tile = _standardize_transform(load_image_from_bytes(image_bytes))

    # Add a leading batch axis so a single tile can be fed to the model.
    model_batch = preprocessor.transform(standardized_tile).unsqueeze(0)

    # Pixel array of the very same tile, so overlays line up with the model input.
    rgb_pixels = np.array(standardized_tile.convert('RGB'))

    return model_batch, rgb_pixels
|
|
def tensor_to_base64(overlay: np.ndarray) -> str:
    """Encode an RGB overlay with values in ``[0, 1]`` as a JPEG data URL.

    Despite its name, this takes a NumPy array, not a tensor.

    Args:
        overlay: RGB image array. Values are scaled by 255 and clipped to
            ``[0, 255]`` before being cast to ``uint8``.

    Returns:
        A ``data:image/jpeg;base64,...`` string (JPEG quality 85).

    Raises:
        RuntimeError: If OpenCV fails to encode the image.
    """
    # Clip before casting: converting out-of-range floats to uint8 is undefined
    # in NumPy (it typically wraps, e.g. 256 -> 0), which shows up as speckles.
    overlay_uint8 = np.clip(overlay * 255, 0, 255).astype(np.uint8)
    # OpenCV expects BGR channel order.
    overlay_bgr = cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', overlay_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        # Without this check a failed encode silently yields an empty data URL.
        raise RuntimeError("Failed to JPEG-encode explainability overlay")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"
|
|
@router.post("", response_model=ExplainResponse)
async def explain_prediction(request: ExplainRequest):
    """Generate an explainability report for a diagnosed image.

    Processing steps:

    1. Fetch the image referenced by ``request.image_url``.
    2. Run the loaded model through ``ExplainabilityService`` to obtain the
       attention heatmap and derived images.
    3. Request a text statement from the service.
    4. Return every image as a ``data:image/jpeg;base64,...`` string.

    Raises:
        HTTPException: Status 503 if no model is loaded; 400 if a
            ``ValueError`` is raised while processing (e.g. the image cannot
            be fetched or decoded); 500 for any other failure.
    """
    import asyncio  # local import: only used to offload blocking I/O below

    logger.info("Received explainability request for %s", request.image_url)

    try:
        # get_image_bytes does blocking network I/O with a 30 s timeout; run it
        # in a worker thread so it does not stall the event loop for every
        # other request.
        image_bytes = await asyncio.to_thread(get_image_bytes, request.image_url)

        wrapper = get_model()
        if not wrapper or not wrapper.model:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # NOTE(review): model inference and the GPT call below are still
        # synchronous and run on the event loop. They use the shared model
        # wrapper, so the thread-safety of ExplainabilityService must be
        # confirmed before offloading them to worker threads.
        preprocessor = ImagePreprocessor(target_size=256)
        preprocessed_image, original_array = prepare_explainability_input(image_bytes, preprocessor)

        service = ExplainabilityService(wrapper)
        xai_results = service.generate_heatmaps(
            preprocessed_image,
            original_array,
            request.diagnosis_data,
        )

        features = service.generate_comprehensive_features(
            xai_results['attention_heatmap'], original_array
        )
        attention_overlay = build_overlay(original_array, xai_results['attention_heatmap'])

        gpt_statement = service.generate_gpt_explanation(features, request.diagnosis_data)

        return ExplainResponse(
            attention_heatmap_base64=tensor_to_base64(attention_overlay),
            attention_bbox_base64=_image_to_base64(xai_results['bbox_overlay']),
            highest_attention_crop_base64=_image_to_base64(xai_results['highest_attention_crop']),
            zone_reference_base64=_image_to_base64(xai_results['zone_reference']),
            gpt_statement=gpt_statement,
        )

    except HTTPException:
        # Already carries the intended status (e.g. 503). Without this clause
        # the generic handler below would remap it to a 500.
        raise
    except ValueError as e:
        logger.error("Value Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the full traceback along with the message.
        logger.exception("Internal error during explainability: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate explainability report") from e


def _image_to_base64(img: np.ndarray) -> str:
    """Encode an RGB image array as a JPEG data URL (quality 85).

    Unlike ``tensor_to_base64``, no scaling is applied: ``img`` is passed to
    OpenCV as-is. NOTE(review): presumably uint8 RGB, as produced by
    ExplainabilityService.generate_heatmaps; confirm.

    Raises:
        RuntimeError: If OpenCV fails to encode the image.
    """
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', img_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise RuntimeError("Failed to JPEG-encode explainability image")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"
|
|