File size: 5,331 Bytes
894fa47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee5617c
 
 
894fa47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee5617c
894fa47
ee5617c
894fa47
 
 
 
ee5617c
894fa47
ee5617c
 
894fa47
ee5617c
 
 
 
 
 
 
 
894fa47
 
 
 
 
ee5617c
 
 
894fa47
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import asyncio
import base64
import io
import logging
import threading
import traceback

from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
import torch
import numpy as np
import cv2
import httpx
from PIL import Image

from app.models.disease_model import get_model
from app.utils.preprocessing import ImagePreprocessor, load_image_from_bytes
from app.api.diagnosis import _extract_storage_object_path, _standardize_transform
from app.database.supabase_client import get_supabase_client, IMAGES_BUCKET
from app.services.explainability_service import ExplainabilityService, build_overlay

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/explain", tags=["Explainability"])

class ExplainRequest(BaseModel):
    """Request body for ``POST /explain``."""

    # Either an http(s) URL or a Supabase storage object path; both forms are
    # resolved by get_image_bytes.
    image_url: str
    # Diagnosis payload forwarded as-is to ExplainabilityService
    # (generate_heatmaps / generate_gpt_explanation). Its keys are not
    # validated here — NOTE(review): confirm the expected schema in that service.
    diagnosis_data: dict

class ExplainResponse(BaseModel):
    """Response body of ``POST /explain``.

    Every ``*_base64`` field holds a complete JPEG data URI
    (``data:image/jpeg;base64,...``), not a bare base64 payload.
    """

    # Attention heatmap overlaid on the preprocessed input tile (build_overlay).
    attention_heatmap_base64: str
    # The 'bbox_overlay' image returned by ExplainabilityService.generate_heatmaps.
    attention_bbox_base64: str
    # The 'highest_attention_crop' image returned by generate_heatmaps.
    highest_attention_crop_base64: str
    # The 'zone_reference' image returned by generate_heatmaps.
    zone_reference_base64: str
    # Natural-language explanation from ExplainabilityService.generate_gpt_explanation.
    gpt_statement: str

def get_image_bytes(image_url: str) -> bytes:
    """Fetch the raw bytes of the image to explain.

    ``image_url`` is treated as a Supabase storage object path when
    ``_extract_storage_object_path`` resolves it and it does not start with
    ``"http"``; that object is downloaded from ``IMAGES_BUCKET``. Anything else
    is fetched over HTTP with a 30 s timeout, following redirects.

    Raises:
        ValueError: on any fetch failure (network, HTTP status, storage error)
            or when the source returns no data. The underlying exception is
            chained as ``__cause__`` so the original traceback is preserved.
    """
    try:
        storage_path = _extract_storage_object_path(image_url)
        if storage_path and not image_url.startswith("http"):
            supabase = get_supabase_client()
            data = supabase.storage.from_(IMAGES_BUCKET).download(storage_path)
        else:
            # NOTE(review): image_url comes straight from the request body and is
            # fetched server-side with redirects enabled — a potential SSRF
            # vector. Consider restricting the allowed hosts/schemes.
            with httpx.Client(timeout=30.0) as client:
                response = client.get(image_url, follow_redirects=True)
                response.raise_for_status()
                data = response.content

        # An empty payload would otherwise surface later as an opaque image
        # decoding error; fail here with a clear message instead.
        if not data:
            raise ValueError("image source returned no data")
        return data
    except Exception as e:
        logger.error("Failed to fetch image: %s", e)
        raise ValueError(f"Failed to fetch image for explainability: {str(e)}") from e

def prepare_explainability_input(image_bytes: bytes, preprocessor: ImagePreprocessor) -> tuple[torch.Tensor, np.ndarray]:
    """Build the model input and the display image from raw image bytes.

    The decoded image goes through the same ``_standardize_transform`` used by
    the diagnosis endpoint, so the explained input matches what the model saw.

    Returns:
        A ``(batch, rgb_pixels)`` pair. ``batch`` is the preprocessor's tensor
        with a leading batch axis added via ``unsqueeze(0)``. ``rgb_pixels``
        is the standardized tile converted to RGB as a NumPy array, used for
        drawing overlays.
    """
    standardized = _standardize_transform(load_image_from_bytes(image_bytes))

    batch = preprocessor.transform(standardized).unsqueeze(0)
    rgb_pixels = np.array(standardized.convert('RGB'))

    return batch, rgb_pixels

def tensor_to_base64(overlay: np.ndarray) -> str:
    """Encode an RGB overlay array as a JPEG data URI.

    Despite the name, this takes a NumPy array, not a torch tensor.

    Args:
        overlay: RGB image that is multiplied by 255 before encoding, so its
            values are presumably floats in [0, 1] (as produced by
            ``build_overlay`` — TODO confirm).

    Returns:
        ``"data:image/jpeg;base64,<payload>"`` encoded at JPEG quality 85.

    Raises:
        RuntimeError: if OpenCV fails to encode the image.
    """
    # Clip before the uint8 cast. Values marginally outside [0, 1] (blending or
    # normalization round-off) would otherwise wrap around or hit
    # platform-dependent float->uint8 casting and show up as speckles.
    overlay_uint8 = np.clip(overlay * 255, 0, 255).astype(np.uint8)
    # OpenCV encodes from BGR channel order.
    overlay_bgr = cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', overlay_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise RuntimeError("Failed to encode overlay image as JPEG")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"

# Serializes the model-facing part of the pipeline. The original endpoint ran
# everything directly on the event loop, so two requests never built
# ExplainabilityService / generated heatmaps at the same time. That work now
# runs in worker threads, and the service may mutate shared model state (e.g.
# forward hooks — NOTE(review): confirm), so keep that guarantee explicitly.
_XAI_MODEL_LOCK = threading.Lock()


def _image_to_base64(img: np.ndarray) -> str:
    """Encode an RGB image array as a JPEG data URI (quality 85).

    The array is always converted RGB->BGR before encoding, so it is assumed
    to be RGB — TODO confirm for the images returned by generate_heatmaps.

    Raises:
        RuntimeError: if OpenCV fails to encode the image.
    """
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', img_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise RuntimeError("Failed to encode image as JPEG")
    b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{b64}"


def _run_explainability(request: ExplainRequest) -> ExplainResponse:
    """Run the blocking explainability pipeline for one request.

    Steps: fetch the image, load the model, preprocess, generate XAI maps and
    features, build the overlay, ask for the GPT statement, and encode every
    image as a JPEG data URI.

    Raises:
        HTTPException: 503 if the model is not loaded.
        ValueError: propagated from fetching or preprocessing.
    """
    image_bytes = get_image_bytes(request.image_url)

    wrapper = get_model()
    if not wrapper or not wrapper.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    preprocessor = ImagePreprocessor(target_size=256)
    preprocessed_image, original_array = prepare_explainability_input(image_bytes, preprocessor)

    with _XAI_MODEL_LOCK:
        service = ExplainabilityService(wrapper)
        xai_results = service.generate_heatmaps(
            preprocessed_image,
            original_array,
            request.diagnosis_data,
        )
        features = service.generate_comprehensive_features(xai_results['attention_heatmap'], original_array)

    attention_overlay = build_overlay(original_array, xai_results['attention_heatmap'])
    gpt_statement = service.generate_gpt_explanation(features, request.diagnosis_data)

    return ExplainResponse(
        attention_heatmap_base64=tensor_to_base64(attention_overlay),
        attention_bbox_base64=_image_to_base64(xai_results['bbox_overlay']),
        highest_attention_crop_base64=_image_to_base64(xai_results['highest_attention_crop']),
        zone_reference_base64=_image_to_base64(xai_results['zone_reference']),
        gpt_statement=gpt_statement,
    )


@router.post("", response_model=ExplainResponse)
async def explain_prediction(request: ExplainRequest):
    """Generate attention visualizations and a GPT explanation for a diagnosis.

    The pipeline is entirely blocking (HTTP download, model inference, GPT
    call), so it runs in a worker thread rather than stalling the event loop.

    Error mapping:
        * HTTPException raised by the pipeline (e.g. 503 "Model not loaded")
          is passed through unchanged.
        * ValueError -> 400 carrying the error message.
        * Anything else -> 500 with a generic message; the traceback is logged.
    """
    logger.info("Received explainability request for %s", request.image_url)

    try:
        return await asyncio.to_thread(_run_explainability, request)
    except HTTPException:
        # Without this, the generic handler below turned the deliberate 503
        # into a 500.
        raise
    except ValueError as e:
        logger.error("Value Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Internal error during explainability: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate explainability report") from e