File size: 5,331 Bytes
894fa47 ee5617c 894fa47 ee5617c 894fa47 ee5617c 894fa47 ee5617c 894fa47 ee5617c 894fa47 ee5617c 894fa47 ee5617c 894fa47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
import logging
import torch
import numpy as np
import base64
import cv2
import httpx
import traceback
from PIL import Image
import io
from app.models.disease_model import get_model
from app.utils.preprocessing import ImagePreprocessor, load_image_from_bytes
from app.api.diagnosis import _extract_storage_object_path, _standardize_transform
from app.database.supabase_client import get_supabase_client, IMAGES_BUCKET
from app.services.explainability_service import ExplainabilityService, build_overlay
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Every route below is mounted under /explain and tagged "Explainability"
# for grouping in the generated OpenAPI docs.
router = APIRouter(tags=["Explainability"], prefix="/explain")
class ExplainRequest(BaseModel):
    """Request body for ``POST /explain``."""

    # Where to load the image from: an http(s) URL, or a Supabase storage
    # object path resolved by get_image_bytes.
    image_url: str
    # Diagnosis payload forwarded unchanged to ExplainabilityService (heatmap
    # generation and the GPT statement). Its keys are not validated here —
    # NOTE(review): confirm the expected schema against ExplainabilityService.
    diagnosis_data: dict
class ExplainResponse(BaseModel):
    """Response body for ``POST /explain``.

    As populated by explain_prediction, the four ``*_base64`` fields hold
    ``data:image/jpeg;base64,...`` data URIs.
    """

    # build_overlay() of the attention heatmap on the preprocessed input tile.
    attention_heatmap_base64: str
    # The 'bbox_overlay' image produced by ExplainabilityService.generate_heatmaps.
    attention_bbox_base64: str
    # The 'highest_attention_crop' image produced by generate_heatmaps.
    highest_attention_crop_base64: str
    # The 'zone_reference' image produced by generate_heatmaps.
    zone_reference_base64: str
    # Text returned by ExplainabilityService.generate_gpt_explanation.
    gpt_statement: str
def get_image_bytes(image_url: str) -> bytes:
    """Download the raw bytes of the image referenced by *image_url*.

    Two kinds of reference are supported:

    * A Supabase storage object path: any value that is not an
      ``http://``/``https://`` URL and for which
      ``_extract_storage_object_path`` returns a path. It is downloaded from
      ``IMAGES_BUCKET`` with the Supabase client.
    * Anything else: fetched with an HTTP GET (30 s timeout, redirects
      followed).

    Args:
        image_url: URL or storage path of the image.

    Returns:
        The image file contents as bytes.

    Raises:
        ValueError: if the image cannot be fetched for any reason. The
            underlying exception is chained as ``__cause__``.
    """
    # NOTE(review): explain_prediction passes request.image_url straight in,
    # so this performs a server-side fetch of a caller-chosen URL (SSRF risk).
    # Consider restricting hosts to the Supabase project or an allow-list.
    try:
        # Match a real scheme, not just the letters "http", so storage paths
        # such as "http_uploads/x.jpg" are not mistaken for URLs.
        is_http_url = image_url.startswith(("http://", "https://"))
        clean_path = _extract_storage_object_path(image_url)
        if clean_path and not is_http_url:
            supabase = get_supabase_client()
            return supabase.storage.from_(IMAGES_BUCKET).download(clean_path)

        with httpx.Client(timeout=30.0) as client:
            response = client.get(image_url, follow_redirects=True)
            response.raise_for_status()
            return response.content
    except Exception as e:
        # Callers map ValueError to a 400, so every fetch failure is
        # normalised to ValueError while keeping the original cause.
        logger.error("Failed to fetch image: %s", e)
        raise ValueError(f"Failed to fetch image for explainability: {str(e)}") from e
def prepare_explainability_input(image_bytes: bytes, preprocessor: ImagePreprocessor) -> tuple[torch.Tensor, np.ndarray]:
    """Build the model input batch and the RGB array used for overlays.

    The decoded image is passed through ``_standardize_transform`` (shared
    with ``app.api.diagnosis``) so the input matches the model input format,
    then through ``preprocessor.transform``.

    Args:
        image_bytes: Encoded image data.
        preprocessor: Provides the ``transform`` applied to the standardized
            tile.

    Returns:
        ``(batch, rgb_array)``:

        * ``batch`` is the transformed tile with a leading batch dimension of
          size 1.
        * ``rgb_array`` is the standardized tile converted to RGB as a NumPy
          array.
    """
    standardized = _standardize_transform(load_image_from_bytes(image_bytes))
    batch = preprocessor.transform(standardized).unsqueeze(0)
    # Overlays are drawn on the standardized tile rather than the raw upload,
    # presumably so heatmap pixels line up with the image shown.
    rgb_array = np.array(standardized.convert('RGB'))
    return batch, rgb_array
def tensor_to_base64(overlay: np.ndarray) -> str:
    """Encode an RGB overlay array as a base64 JPEG data URI.

    Despite the name, this takes a NumPy array, not a torch tensor. Values are
    scaled by 255, so they are expected to lie in [0, 1]. The scaled values
    are clipped to [0, 255] before the ``uint8`` cast, so slightly
    out-of-range values saturate instead of wrapping around (e.g. 1.01 would
    otherwise become a near-black pixel).

    Args:
        overlay: RGB image with values nominally in [0, 1].

    Returns:
        ``"data:image/jpeg;base64,..."`` at JPEG quality 85.

    Raises:
        RuntimeError: if OpenCV fails to encode the image.
    """
    overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    # cv2.imencode expects BGR channel order.
    overlay_bgr = cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', overlay_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise RuntimeError("JPEG encoding failed")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"
def _image_to_base64(img: np.ndarray) -> str:
    """Encode an RGB image array (no value scaling) as a base64 JPEG data URI.

    The channel order is always swapped RGB -> BGR because ``cv2.imencode``
    expects BGR. Unlike ``tensor_to_base64``, values are not multiplied by
    255. NOTE(review): the inputs are presumably uint8 images from
    ExplainabilityService — confirm.

    Raises:
        RuntimeError: if OpenCV fails to encode the image.
    """
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', img_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise RuntimeError("JPEG encoding failed")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"


@router.post("", response_model=ExplainResponse)
async def explain_prediction(request: ExplainRequest):
    """Generate explainability artefacts for a diagnosed image.

    The steps are:

    1. Check that the model is loaded.
    2. Fetch the image.
    3. Preprocess it like the diagnosis path.
    4. Compute the XAI maps and the features for the GPT explanation.
    5. Render the maps as JPEG data URIs.
    6. Generate the GPT statement.

    Args:
        request: Image location plus the diagnosis payload used by the
            explainability service.

    Returns:
        ExplainResponse with four base64 JPEG data URIs and the GPT statement.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for any
            ``ValueError`` (e.g. the image could not be fetched); 500 for any
            other failure.
    """
    logger.info("Received explainability request for %s", request.image_url)
    # NOTE(review): all work below (network fetch, model inference, GPT call)
    # is synchronous and runs on the event loop, blocking other requests while
    # it executes. Consider offloading it to a worker thread once the
    # thread-safety of the shared model wrapper / ExplainabilityService is
    # confirmed.
    try:
        # 1. Check the model first so a missing model fails fast, without
        #    paying for the image download.
        wrapper = get_model()
        if not wrapper or not wrapper.model:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # 2. Fetch the image (fetch failures surface as ValueError -> 400).
        image_bytes = get_image_bytes(request.image_url)

        # 3. Preprocess exactly like the model input.
        preprocessor = ImagePreprocessor(target_size=256)
        preprocessed_image, original_array = prepare_explainability_input(image_bytes, preprocessor)

        # 4. Generate the XAI maps, then the features used by the GPT step.
        service = ExplainabilityService(wrapper)
        xai_results = service.generate_heatmaps(
            preprocessed_image,
            original_array,
            request.diagnosis_data,
        )
        features = service.generate_comprehensive_features(xai_results['attention_heatmap'], original_array)

        # 5. Blend the attention heatmap onto the input tile.
        attention_overlay = build_overlay(original_array, xai_results['attention_heatmap'])

        # 6. Generate the GPT statement.
        gpt_statement = service.generate_gpt_explanation(features, request.diagnosis_data)

        return ExplainResponse(
            attention_heatmap_base64=tensor_to_base64(attention_overlay),
            attention_bbox_base64=_image_to_base64(xai_results['bbox_overlay']),
            highest_attention_crop_base64=_image_to_base64(xai_results['highest_attention_crop']),
            zone_reference_base64=_image_to_base64(xai_results['zone_reference']),
            gpt_statement=gpt_statement,
        )
    except HTTPException:
        # Already carries the intended status (e.g. 503). Without this clause
        # the generic handler below would rewrite it as a 500.
        raise
    except ValueError as e:
        logger.error("Value Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Internal error during explainability: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate explainability report") from e
|