# NOTE: "Spaces: Sleeping" lines below were scrape residue from the Hugging
# Face Spaces status page, not application code; converted to a comment so
# the file parses as Python.
import base64
import copy
import io
import logging
import os
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
| # -------- Optional XAI libraries ---------- | |
| try: | |
| import shap | |
| _HAS_SHAP = True | |
| except Exception: | |
| _HAS_SHAP = False | |
| try: | |
| from lime import lime_image | |
| from skimage.segmentation import slic | |
| _HAS_LIME = True | |
| except Exception: | |
| _HAS_LIME = False | |
| # ---------------- Your model ---------------- | |
| from model import ModifiedMobileNetV2 | |
# ====================== Logging setup ======================
# Module-wide logger at INFO so model-load and device messages appear in
# the container/Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ====================== FastAPI setup ======================
app = FastAPI(
    title="Gallbladder Classification API (Fast XAI)",
    description="Gallbladder ultrasound classifier + Optimized Grad-CAM, SHAP, LIME explanations"
)
# ====================== CORS ======================
# Wide-open CORS (any origin/method/header) — acceptable for a public demo.
# NOTE(review): tighten allow_origins before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ====================== Class Names ======================
# Index order must match the label order used when the checkpoint was trained.
class_names = [
    'Gallstones', 'Intra-abdominal&Retroperitoneum', 'Cholecystitis', 'Gangrenous_Cholecystitis',
    'Perforation', 'Polyps&Cholesterol_Crystal', 'Adenomyomatosis',
    'Carcinoma', 'WallThickening', 'Normal'
]

# ====================== Device & Model ======================
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info(f"Using device: {device}")

# Populated by load_model() during import; endpoints check _MODEL_READY.
model = None
| def _strip_module_prefix(state_dict): | |
| """Remove leading 'module.' from keys if present (DataParallel artifacts).""" | |
| new_state = {} | |
| for k, v in state_dict.items(): | |
| new_key = k | |
| if k.startswith("module."): | |
| new_key = k[len("module."):] | |
| new_state[new_key] = v | |
| return new_state | |
def load_model(model_path="GB_stu_mob.pth"):
    """Load classifier weights into the global `model` and switch to eval().

    Args:
        model_path: Path to a .pth checkpoint; accepts either a bare
            state_dict or a wrapper dict with a 'state_dict' /
            'model_state_dict' key.

    Raises:
        FileNotFoundError: if `model_path` does not exist.
    """
    global model
    if not os.path.exists(model_path):
        logger.error(f"Model file not found at {model_path}")
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    # Accept common wrapper formats in addition to a raw state_dict.
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break
    state_dict = _strip_module_prefix(checkpoint)
    model.load_state_dict(state_dict)
    # The model was already moved to `device` at construction; the original
    # second .to(device) call was redundant and has been dropped.
    model.eval()
    logger.info("✅ Model loaded and set to eval()")
# Load the model once at import time. Endpoints consult _MODEL_READY and
# answer with HTTP 500 rather than crashing when the load failed.
try:
    load_model()
    _MODEL_READY = True
except Exception:  # unused `as e` binding removed; .exception logs the traceback
    logger.exception("Model load failed")
    _MODEL_READY = False
# ====================== Preprocessing ======================
IMG_SIZE = 224

# ImageNet statistics — matches the MobileNetV2 backbone's pretraining.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize,
])


def pil_to_tensor(pil_img):
    """Preprocess a PIL image into a normalized 1 x C x H x W float tensor."""
    chw = preprocess(pil_img)
    return chw.unsqueeze(0)
# ====================== Utils ======================
def _encode_png(np_img_bgr_uint8) -> str:
    """Encodes a BGR uint8 numpy array to a base64 PNG string."""
    img = np_img_bgr_uint8
    if img is None or img.size == 0:
        raise ValueError("Image provided to _encode_png is empty.")
    # OpenCV/libpng need a C-contiguous uint8 buffer; this single call
    # copies/converts only when either property is missing.
    img = np.ascontiguousarray(img, dtype=np.uint8)
    ok, buf = cv2.imencode(".png", img)
    if not ok:
        raise RuntimeError("PNG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("utf-8")
def _overlay_heatmap(rgb_uint8, heatmap_float01):
    """Additively blend a Jet-colored heatmap onto an RGB image; returns BGR uint8.

    Low heatmap values contribute (almost) nothing, so the background keeps
    its full original brightness while hot regions glow.
    """
    # Collapse any extra heatmap dimensions down to a single (H, W) plane.
    hm = heatmap_float01
    if hm.ndim > 2:
        hm = hm.squeeze()
    if hm.ndim == 3:
        hm = hm.mean(axis=2)
    height, width = rgb_uint8.shape[:2]
    if hm.shape != (height, width):
        hm = cv2.resize(hm, (width, height))
    # Float32 BGR working copies for accurate blending math.
    base_bgr = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR).astype(np.float32)
    hm = hm.astype(np.float32).clip(0, 1)
    jet = cv2.applyColorMap((hm * 255).astype(np.uint8), cv2.COLORMAP_JET).astype(np.float32)
    # "Glow": scale the colormap by the heatmap itself so weak activations
    # stay dark. 0.7 is strong; 1.0 is very intense.
    glow_strength = 0.7
    glow = jet * hm[..., None] * glow_strength
    # Additive blend — the original image keeps full brightness.
    blended = base_bgr + glow
    # Clip before the uint8 cast to prevent channel wrap-around artifacts.
    return blended.clip(0, 255).astype(np.uint8)
| def _last_conv_name_for_mobilenet_v2(model: nn.Module) -> str: | |
| last_name = None | |
| for name, m in model.named_modules(): | |
| if isinstance(m, nn.Conv2d): | |
| last_name = name | |
| if last_name is None: | |
| raise RuntimeError("No Conv2d layer found for Grad-CAM") | |
| return last_name | |
# ====================== Inference ======================
def _predict_logits(tensor_b1chw: torch.Tensor) -> torch.Tensor:
    """Forward a preprocessed 1 x C x H x W tensor through the global model.

    Returns raw (pre-softmax) logits. Gradients are NOT disabled here, so
    the same path can serve Grad-CAM, which needs backprop; probability-only
    callers wrap this in torch.no_grad() themselves.
    """
    model.eval()
    return model(tensor_b1chw.to(device))
def _predict_proba_numpy(np_bhwc_uint8: np.ndarray) -> np.ndarray:
    """Classifier function for LIME: batch of HWC uint8 RGB arrays -> softmax probs."""
    model.eval()
    tensors = [
        preprocess(Image.fromarray(arr.astype(np.uint8), mode="RGB"))
        for arr in np_bhwc_uint8
    ]
    batch = torch.stack(tensors).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1)
    return probabilities.cpu().numpy()
def predict_pil(pil_img: Image.Image):
    """Classify a PIL image; returns (class_name, confidence, class_index).

    Raises HTTPException(500) when the model failed to load at startup.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    tensor = pil_to_tensor(pil_img).to(device)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(_predict_logits(tensor), dim=1)
    top_idx = int(torch.argmax(probs, dim=1).item())
    confidence = float(probs[0, top_idx].item())
    return class_names[top_idx], confidence, top_idx
# ====================== Grad-CAM ======================
def gradcam_explain(pil_img: Image.Image, class_idx: Optional[int] = None) -> str:
    """Generate a Grad-CAM overlay for `pil_img`; returns a base64 PNG string.

    Args:
        pil_img: Input image (any size; resized to IMG_SIZE internally).
        class_idx: Class to explain; defaults to the model's top prediction.
    """
    model.zero_grad()
    gradients = []
    activations = []
    target_layer = _last_conv_name_for_mobilenet_v2(model)
    layer_ref = dict(model.named_modules())[target_layer]

    def fwd_hook(_, __, output):
        activations.append(output)

    def bwd_hook(_, grad_in, grad_out):
        gradients.append(grad_out[0])

    fwd_h = layer_ref.register_forward_hook(fwd_hook)
    # FIX: register_backward_hook is deprecated and can report incorrect
    # gradients for modules that own multiple autograd nodes; the "full"
    # variant is the supported, correct API with the same hook signature.
    bwd_h = layer_ref.register_full_backward_hook(bwd_hook)
    try:
        x = pil_to_tensor(pil_img).to(device)
        logits = model(x)
        if class_idx is None:
            class_idx = int(torch.argmax(logits, dim=1).item())
        score = logits[0, class_idx]
        model.zero_grad()
        score.backward()
        grad = gradients[0].detach().cpu().numpy()[0]
        act = activations[0].detach().cpu().numpy()[0]
        # Channel weights = global-average-pooled gradients (Grad-CAM paper).
        weights = grad.mean(axis=(1, 2))
        cam = np.maximum(np.sum(weights[:, None, None] * act, axis=0), 0)
        cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
        # Min-max normalize to [0, 1]; an all-zero map stays zero rather
        # than dividing by a near-zero max.
        cam -= cam.min()
        if cam.max() > 1e-12:
            cam /= cam.max()
        else:
            cam[:] = 0.0
        rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
        out_bgr = _overlay_heatmap(rgb, cam)
        return _encode_png(out_bgr)
    finally:
        # Always detach hooks so repeated requests don't accumulate them.
        fwd_h.remove()
        bwd_h.remove()
# ====================== SHAP ======================
def shap_explain(pil_img: Image.Image, top_label: int) -> str:
    """Produce a SHAP partition-explainer heatmap overlay as a base64 PNG.

    Args:
        pil_img: Input image (resized to IMG_SIZE x IMG_SIZE internally).
        top_label: Class index whose SHAP attributions are computed.

    Raises:
        HTTPException(400): when the shap package is not installed.
    """
    if not _HAS_SHAP: raise HTTPException(status_code=400, detail="SHAP missing")
    # --- Model Copy Logic ---
    # Work on a deep copy so the in-place/ReLU6 surgery below never mutates
    # the serving model.
    temp_model = copy.deepcopy(model)
    temp_model.eval()
    def disable_inplace(m):
        # Force in-place ops off on every submodule that has the flag.
        if hasattr(m, "inplace"): m.inplace = False
    def replace_relu6(m):
        # Recursively swap ReLU6 for plain ReLU.
        # NOTE(review): ReLU6 clamps at 6 while ReLU does not, so this copy's
        # outputs can differ slightly from the serving model — presumably
        # acceptable for attribution purposes; confirm.
        for name, child in list(m.named_children()):
            if isinstance(child, nn.ReLU6): setattr(m, name, nn.ReLU(inplace=False))
            else: replace_relu6(child)
    temp_model.apply(disable_inplace)
    replace_relu6(temp_model)
    def predict_numpy(np_imgs):
        # SHAP calls this with batches of HWC uint8 arrays; return softmax probs.
        batch = []
        for arr in np_imgs:
            pil = Image.fromarray(arr.astype(np.uint8))
            batch.append(preprocess(pil))
        x = torch.stack(batch).to(device)
        with torch.no_grad():
            prob = torch.softmax(temp_model(x), dim=1).cpu().numpy()
        return prob
    # -----------------------
    img_224 = pil_img.resize((IMG_SIZE, IMG_SIZE))
    x_nhwc = np.asarray(img_224, dtype=np.uint8)[None, ...]
    # Blur masker: hidden regions are blurred (16x16) instead of solid-filled.
    masker = shap.maskers.Image("blur(16,16)", x_nhwc[0].shape)
    explainer = shap.Explainer(predict_numpy, masker, algorithm="partition")
    # Run SHAP for the single requested class to cap the evaluation count.
    shap_values = explainer(x_nhwc, max_evals=300, batch_size=50, outputs=[top_label])
    # --- Robust Shape Handling ---
    if hasattr(shap_values, 'values'):
        sv = shap_values.values
    else:
        sv = shap_values
    # Handle shape variations (Batch, H, W, C, Outputs) or (Batch, H, W, C).
    # We want strictly (H, W, C) at the end.
    sv = np.array(sv)  # Ensure numpy
    # 1. Strip batch dim if present
    if sv.ndim >= 4 and sv.shape[0] == 1:
        sv = sv[0]
    # 2. Strip output dim if present (e.g., shape is H, W, C, 1)
    if sv.ndim == 4 and sv.shape[-1] == 1:
        sv = sv[..., 0]
    # 3. Final validation
    if sv.ndim != 3:
        logger.warning(f"SHAP unexpected shape: {sv.shape}. Trying to squeeze.")
        sv = sv.squeeze()
    if sv.ndim != 3:
        # Last resort: if we can't find H,W,C, fail gracefully by returning
        # the plain input image.
        # NOTE(review): np.array(img_224) is RGB but _encode_png expects BGR,
        # so this fallback encodes with red/blue swapped — confirm intent.
        logger.error("Could not resolve SHAP values to (H,W,C).")
        return _encode_png(np.array(img_224))
    # Handle NaNs
    sv = np.nan_to_num(sv, nan=0.0)
    # Convert to a grayscale heatmap: mean of absolute attributions per pixel.
    sv_gray = np.mean(np.abs(sv), axis=2)  # (H, W)
    # Min-max normalize to [0, 1] (all-zero maps stay zero).
    sv_gray -= sv_gray.min()
    if sv_gray.max() > 1e-8:
        sv_gray /= sv_gray.max()
    img_rgb = np.array(img_224).astype(np.uint8)
    # Create overlay
    overlay = _overlay_heatmap(img_rgb, sv_gray)
    # Free the deep copy promptly; empty_cache is a no-op on CPU-only runs.
    del temp_model
    torch.cuda.empty_cache()
    return _encode_png(overlay)
# ====================== LIME ======================
def lime_explain(pil_img: Image.Image, top_label: int) -> str:
    """Produce a LIME superpixel-importance overlay as a base64 PNG.

    Args:
        pil_img: Input image (resized to IMG_SIZE x IMG_SIZE internally).
        top_label: Fallback class index used only when LIME reports no
            top_labels of its own.

    Raises:
        HTTPException(400): when the lime package is not installed.
        HTTPException(500): when explain_instance itself fails.
    """
    if not _HAS_LIME: raise HTTPException(status_code=400, detail="LIME missing")
    rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
    explainer = lime_image.LimeImageExplainer()
    try:
        # 300 perturbation samples keeps latency low; SLIC with very low
        # compactness hugs intensity boundaries closely.
        explanation = explainer.explain_instance(
            image=rgb, classifier_fn=_predict_proba_numpy, top_labels=3,
            hide_color=0, num_samples=300,
            segmentation_fn=lambda x: slic(x, n_segments=200, compactness=0.01, sigma=1)
        )
    except Exception as e:
        logger.exception("LIME failed")
        raise HTTPException(status_code=500, detail=str(e))
    # NOTE(review): this explains LIME's own top label, not necessarily the
    # `top_label` argument — which serves only as a fallback; confirm intent.
    chosen_label = int(explanation.top_labels[0]) if explanation.top_labels else int(top_label)
    temp_img, mask = explanation.get_image_and_mask(
        label=chosen_label, positive_only=False, num_features=10, hide_rest=False
    )
    # The mask holds per-segment labels; |.| then max-normalize to [0, 1].
    mask_float = np.abs(mask).astype(np.float32)
    if mask_float.max() > 1e-12:
        mask_float /= mask_float.max()
    # Light blur softens blocky superpixel edges before overlaying.
    mask_float = cv2.GaussianBlur(mask_float, (3, 3), 0)
    out_bgr = _overlay_heatmap(rgb, mask_float)
    return _encode_png(out_bgr)
# ====================== API ROUTES ======================
# FIX: the endpoint coroutines were never registered with the app (no route
# decorators), so the API exposed no routes at all.
@app.get("/")
async def root():
    """Landing route with a short service description."""
    return {"message": "Gallbladder Classification API (XAI Optimized)"}
@app.get("/health")  # FIX: decorator was missing — route was unreachable
async def health():
    """Liveness probe: reports model readiness, device, and XAI availability."""
    return {
        "status": "healthy" if _MODEL_READY else "unhealthy",
        "device": str(device),
        "has_shap": _HAS_SHAP,
        "has_lime": _HAS_LIME
    }
@app.post("/predict")  # FIX: decorator was missing — route was unreachable
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded ultrasound image.

    Returns the predicted class name and its softmax confidence.

    Raises:
        HTTPException(500): model not loaded.
        HTTPException(400): non-image upload or undecodable image data.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # FIX: content_type can be None for some clients; guard before startswith().
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    data = await file.read()
    try:
        pil = Image.open(BytesIO(data)).convert("RGB")
    except Exception:
        # A corrupt/undecodable body is a client error, not a server crash.
        raise HTTPException(status_code=400, detail="Invalid image data")
    cls, conf, _ = predict_pil(pil)
    return {
        "filename": file.filename,
        "predicted_class": cls,
        "confidence_score": round(conf, 4)
    }
@app.post("/explain")  # FIX: decorator was missing — route was unreachable
async def explain_image(
    file: UploadFile = File(...),
    method: str = Query("all", regex="^(gradcam|shap|lime|all)$")
):
    """Classify the uploaded image and attach base64-PNG XAI overlays.

    Args:
        file: Uploaded image.
        method: Which explainer(s) to run: gradcam, shap, lime, or all.

    Individual explainer failures are reported per-method in the response
    instead of failing the whole request.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    data = await file.read()
    try:
        pil = Image.open(BytesIO(data)).convert("RGB")
    except Exception:
        # A corrupt/undecodable body is a client error, not a server crash.
        raise HTTPException(status_code=400, detail="Invalid image data")
    cls, conf, top_idx = predict_pil(pil)
    out = {
        "predicted_class": cls,
        "confidence_score": round(conf, 4),
        "xai": {}
    }
    # The explainers are heavy and synchronous; run them in the threadpool
    # so the event loop stays responsive.
    if method in ("gradcam", "all"):
        try:
            out["xai"]["gradcam"] = await run_in_threadpool(gradcam_explain, pil, top_idx)
        except Exception as e:
            logger.exception("Grad-CAM failed")
            out["xai"]["gradcam_error"] = str(e)
    if method in ("shap", "all"):
        if not _HAS_SHAP:
            out["xai"]["shap_error"] = "SHAP not installed"
        else:
            try:
                out["xai"]["shap"] = await run_in_threadpool(shap_explain, pil, top_idx)
            except Exception as e:
                logger.exception("SHAP failed")
                out["xai"]["shap_error"] = str(e)
    if method in ("lime", "all"):
        if not _HAS_LIME:
            out["xai"]["lime_error"] = "LIME not installed"
        else:
            try:
                out["xai"]["lime"] = await run_in_threadpool(lime_explain, pil, top_idx)
            except Exception as e:
                logger.exception("LIME failed")
                out["xai"]["lime_error"] = str(e)
    return out
if __name__ == "__main__":
    import uvicorn
    # Bind all interfaces on port 7860 — presumably for a Hugging Face
    # Space deployment; confirm against the hosting config.
    uvicorn.run(app, host="0.0.0.0", port=7860)