"""Gallbladder ultrasound classification API with optimized XAI explanations.

Serves a MobileNetV2-based classifier over FastAPI and exposes three post-hoc
explanation methods: Grad-CAM, SHAP (partition explainer) and LIME. Heavy
explanation work is pushed to a threadpool so the event loop stays responsive.
"""

import os
import io
import cv2
import base64
import logging
import numpy as np
import copy
from io import BytesIO
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool

# -------- Optional XAI libraries ----------
try:
    import shap
    _HAS_SHAP = True
except Exception:
    _HAS_SHAP = False

try:
    from lime import lime_image
    from skimage.segmentation import slic
    _HAS_LIME = True
except Exception:
    _HAS_LIME = False

# ---------------- Your model ----------------
from model import ModifiedMobileNetV2

# ====================== Logging setup ======================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ====================== FastAPI setup ======================
app = FastAPI(
    title="Gallbladder Classification API (Fast XAI)",
    description="Gallbladder ultrasound classifier + Optimized Grad-CAM, SHAP, LIME explanations"
)

# ====================== CORS ======================
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ====================== Class Names ======================
class_names = [
    'Gallstones',
    'Intra-abdominal&Retroperitoneum',
    'Cholecystitis',
    'Gangrenous_Cholecystitis',
    'Perforation',
    'Polyps&Cholesterol_Crystal',
    'Adenomyomatosis',
    'Carcinoma',
    'WallThickening',
    'Normal'
]

# ====================== Device & Model ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

model = None


def _strip_module_prefix(state_dict):
    """Remove leading 'module.' from keys if present (DataParallel artifacts)."""
    new_state = {}
    for k, v in state_dict.items():
        new_key = k[len("module."):] if k.startswith("module.") else k
        new_state[new_key] = v
    return new_state


def load_model(model_path="GB_stu_mob.pth"):
    """Load checkpoint weights into the module-level ``model`` and set eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    global model
    if not os.path.exists(model_path):
        logger.error(f"Model file not found at {model_path}")
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
    # NOTE(review): checkpoint is assumed to be a raw state_dict (possibly
    # saved from DataParallel). Consider torch.load(..., weights_only=True)
    # once the minimum torch version allows it.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = _strip_module_prefix(checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    logger.info("✅ Model loaded and set to eval()")


try:
    load_model()
    _MODEL_READY = True
except Exception:
    logger.exception("Model load failed")
    _MODEL_READY = False

# ====================== Preprocessing ======================
IMG_SIZE = 224
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize
])


def pil_to_tensor(pil_img):
    """Preprocess a PIL image into a 1 x C x H x W normalized tensor."""
    return preprocess(pil_img).unsqueeze(0)


# ====================== Utils ======================
def _encode_png(np_img_bgr_uint8) -> str:
    """Encodes a BGR uint8 numpy array to a base64 PNG string."""
    if np_img_bgr_uint8 is None or np_img_bgr_uint8.size == 0:
        raise ValueError("Image provided to _encode_png is empty.")
    # Ensure array is contiguous and uint8 to satisfy OpenCV/libpng.
    if not np_img_bgr_uint8.flags['C_CONTIGUOUS']:
        np_img_bgr_uint8 = np.ascontiguousarray(np_img_bgr_uint8)
    if np_img_bgr_uint8.dtype != np.uint8:
        np_img_bgr_uint8 = np_img_bgr_uint8.astype(np.uint8)
    ok, buf = cv2.imencode(".png", np_img_bgr_uint8)
    if not ok:
        raise RuntimeError("PNG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("utf-8")


def _overlay_heatmap(rgb_uint8, heatmap_float01):
    """Overlay a [0, 1] heatmap on an RGB image; returns a BGR uint8 array.

    Applies the Jet colormap scaled by the heatmap itself (low-confidence
    regions stay dark) and blends additively so the base image keeps full
    brightness and hot regions "glow".
    """
    # 1. Sanitize inputs: collapse any extra dims down to a 2-D heatmap.
    if heatmap_float01.ndim > 2:
        heatmap_float01 = heatmap_float01.squeeze()
    if heatmap_float01.ndim == 3:
        heatmap_float01 = heatmap_float01.mean(axis=2)
    h, w = rgb_uint8.shape[:2]
    if heatmap_float01.shape != (h, w):
        heatmap_float01 = cv2.resize(heatmap_float01, (w, h))

    # 2. Prepare for OpenCV (float32 for accurate math).
    bgr = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR).astype(np.float32)
    heatmap = heatmap_float01.astype(np.float32).clip(0, 1)

    # 3. Apply Jet colormap.
    heatmap_uint8 = (heatmap * 255).astype(np.uint8)
    jet_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET).astype(np.float32)

    # 4. Prepare heatmap glow: mask the colormap by intensity so weak
    # activations contribute (almost) nothing.
    mask = heatmap[..., None]  # (H, W, 1)
    # 0.7 is strong, 1.0 is very intense.
    glow_strength = 0.7
    jet_masked = jet_bgr * mask * glow_strength

    # 5. Blend additively: the background stays at full brightness.
    overlay = bgr + jet_masked

    # 6. Finalize (crucial: clip to 255 to prevent color weirdness).
    overlay = overlay.clip(0, 255).astype(np.uint8)
    return overlay


def _last_conv_name_for_mobilenet_v2(model: nn.Module) -> str:
    """Return the name of the last Conv2d module (Grad-CAM target layer)."""
    last_name = None
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            last_name = name
    if last_name is None:
        raise RuntimeError("No Conv2d layer found for Grad-CAM")
    return last_name


# ====================== Inference ======================
def _predict_logits(tensor_b1chw: torch.Tensor) -> torch.Tensor:
    """Forward pass on the global model; returns raw logits."""
    model.eval()
    return model(tensor_b1chw.to(device))


def _predict_proba_numpy(np_bhwc_uint8: np.ndarray) -> np.ndarray:
    """Batch-predict softmax probabilities from uint8 HWC images (LIME/SHAP API)."""
    model.eval()
    imgs = []
    for img in np_bhwc_uint8:
        pil = Image.fromarray(img.astype(np.uint8), mode="RGB")
        imgs.append(preprocess(pil))
    batch = torch.stack(imgs).to(device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    return probs


def predict_pil(pil_img: Image.Image):
    """Classify a PIL image; returns (class_name, confidence, class_index)."""
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    x = pil_to_tensor(pil_img).to(device)
    model.eval()
    with torch.no_grad():
        logits = _predict_logits(x)
        probs = torch.softmax(logits, dim=1)
        idx = int(torch.argmax(probs, dim=1).item())
        conf = float(probs[0, idx].item())
    return class_names[idx], conf, idx


# ====================== Grad-CAM ======================
def gradcam_explain(pil_img: Image.Image, class_idx: int = None) -> str:
    """Grad-CAM overlay for ``class_idx`` (default: predicted class).

    Returns a base64-encoded PNG. Hooks are always removed, even on error.
    """
    gradients = []
    activations = []
    target_layer = _last_conv_name_for_mobilenet_v2(model)
    layer_ref = dict(model.named_modules())[target_layer]

    def fwd_hook(_, __, output):
        activations.append(output)

    def bwd_hook(_, grad_in, grad_out):
        gradients.append(grad_out[0])

    fwd_h = layer_ref.register_forward_hook(fwd_hook)
    # FIX: register_backward_hook is deprecated (and can misreport gradients
    # for some module types); register_full_backward_hook is the supported API.
    bwd_h = layer_ref.register_full_backward_hook(bwd_hook)
    try:
        x = pil_to_tensor(pil_img).to(device)
        logits = model(x)
        if class_idx is None:
            class_idx = int(torch.argmax(logits, dim=1).item())
        score = logits[0, class_idx]
        model.zero_grad()
        score.backward()

        grad = gradients[0].detach().cpu().numpy()[0]
        act = activations[0].detach().cpu().numpy()[0]
        # Channel weights = spatially averaged gradients; CAM = ReLU of the
        # weighted activation sum.
        weights = grad.mean(axis=(1, 2))
        cam = np.maximum(np.sum(weights[:, None, None] * act, axis=0), 0)
        cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
        cam -= cam.min()
        if cam.max() > 1e-12:
            cam /= cam.max()
        else:
            cam[:] = 0.0

        rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
        out_bgr = _overlay_heatmap(rgb, cam)
        return _encode_png(out_bgr)
    finally:
        fwd_h.remove()
        bwd_h.remove()


# ====================== SHAP ======================
def shap_explain(pil_img: Image.Image, top_label: int) -> str:
    """SHAP partition-explainer overlay for ``top_label`` as a base64 PNG.

    Works on a deep copy of the model with in-place ops disabled and ReLU6
    swapped for ReLU, since SHAP's perturbation passes dislike in-place ops.
    """
    if not _HAS_SHAP:
        raise HTTPException(status_code=400, detail="SHAP missing")

    # --- Model copy logic ---
    temp_model = copy.deepcopy(model)
    temp_model.eval()

    def disable_inplace(m):
        if hasattr(m, "inplace"):
            m.inplace = False

    def replace_relu6(m):
        for name, child in list(m.named_children()):
            if isinstance(child, nn.ReLU6):
                setattr(m, name, nn.ReLU(inplace=False))
            else:
                replace_relu6(child)

    temp_model.apply(disable_inplace)
    replace_relu6(temp_model)

    def predict_numpy(np_imgs):
        batch = []
        for arr in np_imgs:
            pil = Image.fromarray(arr.astype(np.uint8))
            batch.append(preprocess(pil))
        x = torch.stack(batch).to(device)
        with torch.no_grad():
            prob = torch.softmax(temp_model(x), dim=1).cpu().numpy()
        return prob

    # FIX: cleanup moved into try/finally so the deep-copied model is freed
    # even when SHAP raises (previously it leaked on every failure).
    try:
        img_224 = pil_img.resize((IMG_SIZE, IMG_SIZE))
        x_nhwc = np.asarray(img_224, dtype=np.uint8)[None, ...]

        masker = shap.maskers.Image("blur(16,16)", x_nhwc[0].shape)
        explainer = shap.Explainer(predict_numpy, masker, algorithm="partition")

        # Run SHAP.
        shap_values = explainer(x_nhwc, max_evals=300, batch_size=50,
                                outputs=[top_label])

        # --- Robust shape handling ---
        # SHAP may return (Batch, H, W, C, Outputs) or (Batch, H, W, C);
        # we want strictly (H, W, C) at the end.
        sv = shap_values.values if hasattr(shap_values, 'values') else shap_values
        sv = np.array(sv)  # ensure numpy

        # 1. Strip batch dim if present.
        if sv.ndim >= 4 and sv.shape[0] == 1:
            sv = sv[0]
        # 2. Strip output dim if present (e.g. shape is H, W, C, 1).
        if sv.ndim == 4 and sv.shape[-1] == 1:
            sv = sv[..., 0]
        # 3. Final validation.
        if sv.ndim != 3:
            logger.warning(f"SHAP unexpected shape: {sv.shape}. Trying to squeeze.")
            sv = sv.squeeze()
        if sv.ndim != 3:
            # Last resort: if we can't find H,W,C, fail gracefully by
            # returning the plain image.
            logger.error("Could not resolve SHAP values to (H,W,C).")
            # FIX: _encode_png expects BGR; previously the RGB array was
            # passed directly, swapping red/blue in the fallback image.
            return _encode_png(cv2.cvtColor(np.array(img_224), cv2.COLOR_RGB2BGR))

        # Handle NaNs.
        sv = np.nan_to_num(sv, nan=0.0)

        # Convert to grayscale heatmap and normalize to [0, 1].
        sv_gray = np.mean(np.abs(sv), axis=2)  # (H, W)
        sv_gray -= sv_gray.min()
        if sv_gray.max() > 1e-8:
            sv_gray /= sv_gray.max()

        img_rgb = np.array(img_224).astype(np.uint8)
        overlay = _overlay_heatmap(img_rgb, sv_gray)
        return _encode_png(overlay)
    finally:
        del temp_model
        # FIX: empty_cache() can raise on CPU-only builds; guard it.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# ====================== LIME ======================
def lime_explain(pil_img: Image.Image, top_label: int) -> str:
    """LIME superpixel-importance overlay as a base64 PNG."""
    if not _HAS_LIME:
        raise HTTPException(status_code=400, detail="LIME missing")

    rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
    explainer = lime_image.LimeImageExplainer()
    try:
        explanation = explainer.explain_instance(
            image=rgb,
            classifier_fn=_predict_proba_numpy,
            top_labels=3,
            hide_color=0,
            num_samples=300,
            segmentation_fn=lambda x: slic(x, n_segments=200, compactness=0.01, sigma=1)
        )
    except Exception as e:
        logger.exception("LIME failed")
        raise HTTPException(status_code=500, detail=str(e))

    # Prefer LIME's own top label; fall back to the classifier's prediction.
    chosen_label = int(explanation.top_labels[0]) if explanation.top_labels else int(top_label)
    _, mask = explanation.get_image_and_mask(
        label=chosen_label,
        positive_only=False,
        num_features=10,
        hide_rest=False
    )

    mask_float = np.abs(mask).astype(np.float32)
    if mask_float.max() > 1e-12:
        mask_float /= mask_float.max()
    # Soften hard superpixel edges slightly before overlaying.
    mask_float = cv2.GaussianBlur(mask_float, (3, 3), 0)

    out_bgr = _overlay_heatmap(rgb, mask_float)
    return _encode_png(out_bgr)


# ====================== API ROUTES ======================
@app.get("/")
async def root():
    return {"message": "Gallbladder Classification API (XAI Optimized)"}


@app.get("/health")
async def health():
    return {
        "status": "healthy" if _MODEL_READY else "unhealthy",
        "device": str(device),
        "has_shap": _HAS_SHAP,
        "has_lime": _HAS_LIME
    }


def _read_upload_as_rgb(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image, or raise a 400."""
    try:
        return Image.open(BytesIO(data)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid or unreadable image file")


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # FIX: content_type can be None, which previously crashed .startswith().
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    data = await file.read()
    pil = _read_upload_as_rgb(data)
    cls, conf, _ = predict_pil(pil)
    return {
        "filename": file.filename,
        "predicted_class": cls,
        "confidence_score": round(conf, 4)
    }


@app.post("/explain")
async def explain_image(
    file: UploadFile = File(...),
    # NOTE(review): `regex=` is deprecated in newer FastAPI in favor of
    # `pattern=`; kept for compatibility with the pinned version.
    method: str = Query("all", regex="^(gradcam|shap|lime|all)$")
):
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    data = await file.read()
    pil = _read_upload_as_rgb(data)
    cls, conf, top_idx = predict_pil(pil)
    out = {
        "predicted_class": cls,
        "confidence_score": round(conf, 4),
        "xai": {}
    }

    if method in ("gradcam", "all"):
        try:
            out["xai"]["gradcam"] = await run_in_threadpool(gradcam_explain, pil, top_idx)
        except Exception as e:
            logger.exception("Grad-CAM failed")
            out["xai"]["gradcam_error"] = str(e)

    if method in ("shap", "all"):
        if not _HAS_SHAP:
            out["xai"]["shap_error"] = "SHAP not installed"
        else:
            try:
                out["xai"]["shap"] = await run_in_threadpool(shap_explain, pil, top_idx)
            except Exception as e:
                logger.exception("SHAP failed")
                out["xai"]["shap_error"] = str(e)

    if method in ("lime", "all"):
        if not _HAS_LIME:
            out["xai"]["lime_error"] = "LIME not installed"
        else:
            try:
                out["xai"]["lime"] = await run_in_threadpool(lime_explain, pil, top_idx)
            except Exception as e:
                logger.exception("LIME failed")
                out["xai"]["lime_error"] = str(e)

    return out


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)