# app.py — Gallbladder classification FastAPI service (revision 9ae5bc2)
import os
import io
import cv2
import base64
import logging
import numpy as np
import copy
from io import BytesIO
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
# -------- Optional XAI libraries ----------
# SHAP and LIME are optional dependencies; these flags gate the /explain
# endpoint so the service still starts when either library is absent.
try:
    import shap
    _HAS_SHAP = True
except Exception:
    _HAS_SHAP = False
try:
    from lime import lime_image
    from skimage.segmentation import slic
    _HAS_LIME = True
except Exception:
    _HAS_LIME = False
# ---------------- Your model ----------------
from model import ModifiedMobileNetV2
# ====================== Logging setup ======================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ====================== FastAPI setup ======================
app = FastAPI(
    title="Gallbladder Classification API (Fast XAI)",
    description="Gallbladder ultrasound classifier + Optimized Grad-CAM, SHAP, LIME explanations"
)
# ====================== CORS ======================
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ====================== Class Names ======================
# Index order must match the label order used when the checkpoint was trained.
class_names = [
    'Gallstones', 'Intra-abdominal&Retroperitoneum', 'Cholecystitis', 'Gangrenous_Cholecystitis',
    'Perforation', 'Polyps&Cholesterol_Crystal', 'Adenomyomatosis',
    'Carcinoma', 'WallThickening', 'Normal'
]
# ====================== Device & Model ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
# Populated by load_model() below at import time.
model = None
def _strip_module_prefix(state_dict):
"""Remove leading 'module.' from keys if present (DataParallel artifacts)."""
new_state = {}
for k, v in state_dict.items():
new_key = k
if k.startswith("module."):
new_key = k[len("module."):]
new_state[new_key] = v
return new_state
def load_model(model_path="GB_stu_mob.pth"):
    """Load classifier weights from *model_path* into the global ``model``.

    Accepts either a raw state_dict or a checkpoint dict that nests the
    weights under a 'state_dict' key, and strips DataParallel 'module.'
    prefixes before loading.

    Args:
        model_path: path to the .pth checkpoint file.

    Raises:
        FileNotFoundError: if *model_path* does not exist.
    """
    global model
    if not os.path.exists(model_path):
        logger.error(f"Model file not found at {model_path}")
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    # Some training scripts save {'state_dict': ...} rather than the raw dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(_strip_module_prefix(checkpoint))
    # FIX: removed redundant second model.to(device) — already done above.
    model.eval()
    logger.info("✅ Model loaded and set to eval()")
# Attempt the model load once at import time; _MODEL_READY gates every
# endpoint so a failed load yields 500s instead of crashing the server.
try:
    load_model()
    _MODEL_READY = True
except Exception as e:
    logger.exception("Model load failed")
    _MODEL_READY = False
# ====================== Preprocessing ======================
IMG_SIZE = 224
# ImageNet normalization statistics (standard for MobileNetV2 pretraining).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize
])
def pil_to_tensor(pil_img):
    """Preprocess a PIL image into a normalized 1 x C x H x W tensor."""
    return preprocess(pil_img).unsqueeze(0)
# ====================== Utils ======================
def _encode_png(np_img_bgr_uint8) -> str:
    """Encode a BGR uint8 numpy array as a base64 PNG string.

    Raises:
        ValueError: if the input array is None or empty.
        RuntimeError: if OpenCV fails to encode the PNG.
    """
    img = np_img_bgr_uint8
    if img is None or img.size == 0:
        raise ValueError("Image provided to _encode_png is empty.")
    # OpenCV/libpng require a contiguous uint8 buffer.
    if not img.flags['C_CONTIGUOUS']:
        img = np.ascontiguousarray(img)
    if img.dtype != np.uint8:
        img = img.astype(np.uint8)
    success, encoded = cv2.imencode(".png", img)
    if not success:
        raise RuntimeError("PNG encoding failed")
    return base64.b64encode(encoded.tobytes()).decode("utf-8")
def _overlay_heatmap(rgb_uint8, heatmap_float01):
    """Additively blend a Jet-colored heatmap onto an RGB image.

    The heatmap is squeezed/averaged down to 2-D, resized to the image size,
    colorized with COLORMAP_JET, scaled by its own intensity and a fixed glow
    strength, then ADDED to the image (converted to BGR) — so low-activation
    regions keep the original pixels at full brightness.

    Returns:
        BGR uint8 array of the blended overlay.
    """
    heat = heatmap_float01
    # Collapse any extra dimensions to a single (H, W) channel.
    if heat.ndim > 2:
        heat = heat.squeeze()
    if heat.ndim == 3:
        heat = heat.mean(axis=2)
    height, width = rgb_uint8.shape[:2]
    if heat.shape != (height, width):
        heat = cv2.resize(heat, (width, height))
    # Work in float32 so the additive blend is computed accurately.
    base_bgr = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR).astype(np.float32)
    heat = heat.astype(np.float32).clip(0, 1)
    jet = cv2.applyColorMap((heat * 255).astype(np.uint8),
                            cv2.COLORMAP_JET).astype(np.float32)
    # 0.7 keeps the glow strong without washing colors out (1.0 = very intense).
    glow_strength = 0.7
    glow = jet * heat[..., None] * glow_strength
    # Clip after the addition to avoid wrap-around color artifacts.
    return (base_bgr + glow).clip(0, 255).astype(np.uint8)
def _last_conv_name_for_mobilenet_v2(model: nn.Module) -> str:
last_name = None
for name, m in model.named_modules():
if isinstance(m, nn.Conv2d):
last_name = name
if last_name is None:
raise RuntimeError("No Conv2d layer found for Grad-CAM")
return last_name
# ====================== Inference ======================
def _predict_logits(tensor_b1chw: torch.Tensor) -> torch.Tensor:
    """Forward a (B, C, H, W) batch through the global model; returns raw logits.

    Note: does not disable gradients itself — callers wrap in torch.no_grad().
    """
    model.eval()
    batch = tensor_b1chw.to(device)
    return model(batch)
def _predict_proba_numpy(np_bhwc_uint8: np.ndarray) -> np.ndarray:
    """Classifier function for LIME/SHAP: (B, H, W, C) uint8 -> (B, n_classes) probs."""
    model.eval()
    tensors = [preprocess(Image.fromarray(arr.astype(np.uint8), mode="RGB"))
               for arr in np_bhwc_uint8]
    batch = torch.stack(tensors).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1).cpu().numpy()
    return probabilities
def predict_pil(pil_img: Image.Image):
    """Classify a PIL image; returns (class_name, confidence, class_index).

    Raises:
        HTTPException: 500 when the model failed to load at startup.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    tensor = pil_to_tensor(pil_img).to(device)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(_predict_logits(tensor), dim=1)
    best = int(torch.argmax(probs, dim=1).item())
    confidence = float(probs[0, best].item())
    return class_names[best], confidence, best
# ====================== Grad-CAM ======================
def gradcam_explain(pil_img: Image.Image, class_idx: int = None) -> str:
    """Grad-CAM explanation for *pil_img*; returns a base64 PNG overlay.

    Hooks the last Conv2d layer, backprops the target class score, and
    weights the activations by their spatially-averaged gradients.

    Args:
        pil_img: input image (any size; resized to IMG_SIZE internally).
        class_idx: class to explain; defaults to the model's argmax.
    """
    model.zero_grad()
    gradients = []
    activations = []
    target_layer = _last_conv_name_for_mobilenet_v2(model)
    layer_ref = dict(model.named_modules())[target_layer]
    def fwd_hook(_, __, output):
        activations.append(output)
    def bwd_hook(_, grad_in, grad_out):
        gradients.append(grad_out[0])
    fwd_h = layer_ref.register_forward_hook(fwd_hook)
    # FIX: register_backward_hook is deprecated and may report incorrect
    # gradients; register_full_backward_hook is the supported replacement.
    bwd_h = layer_ref.register_full_backward_hook(bwd_hook)
    try:
        x = pil_to_tensor(pil_img).to(device)
        logits = model(x)
        if class_idx is None:
            class_idx = int(torch.argmax(logits, dim=1).item())
        score = logits[0, class_idx]
        model.zero_grad()
        score.backward()
        grad = gradients[0].detach().cpu().numpy()[0]
        act = activations[0].detach().cpu().numpy()[0]
        # Channel weights = global-average-pooled gradients (Grad-CAM).
        weights = grad.mean(axis=(1, 2))
        cam = np.maximum(np.sum(weights[:, None, None] * act, axis=0), 0)
        cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
        # Normalize to [0, 1]; a flat map becomes all zeros instead of NaN.
        cam -= cam.min()
        if cam.max() > 1e-12:
            cam /= cam.max()
        else:
            cam[:] = 0.0
        rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
        return _encode_png(_overlay_heatmap(rgb, cam))
    finally:
        # Always detach the hooks so repeated calls don't accumulate them.
        fwd_h.remove()
        bwd_h.remove()
# ====================== SHAP ======================
def shap_explain(pil_img: Image.Image, top_label: int) -> str:
    """SHAP partition-explainer heatmap for *pil_img*; returns a base64 PNG.

    Runs on a deep copy of the model with in-place ops disabled and ReLU6
    replaced by ReLU, so SHAP's many perturbed forward passes cannot
    disturb the shared global model.

    Args:
        pil_img: input image (resized to IMG_SIZE internally).
        top_label: class index whose attribution is computed.

    Raises:
        HTTPException: 400 when the SHAP library is not installed.
    """
    if not _HAS_SHAP:
        raise HTTPException(status_code=400, detail="SHAP missing")
    # --- Model Copy Logic ---
    temp_model = copy.deepcopy(model)
    temp_model.eval()

    def disable_inplace(m):
        if hasattr(m, "inplace"):
            m.inplace = False

    def replace_relu6(m):
        for name, child in list(m.named_children()):
            if isinstance(child, nn.ReLU6):
                setattr(m, name, nn.ReLU(inplace=False))
            else:
                replace_relu6(child)

    temp_model.apply(disable_inplace)
    replace_relu6(temp_model)

    def predict_numpy(np_imgs):
        batch = [preprocess(Image.fromarray(arr.astype(np.uint8))) for arr in np_imgs]
        x = torch.stack(batch).to(device)
        with torch.no_grad():
            return torch.softmax(temp_model(x), dim=1).cpu().numpy()

    # FIX: wrap the whole run in try/finally so the model copy is released
    # even when SHAP raises (previously it leaked on any exception).
    try:
        img_224 = pil_img.resize((IMG_SIZE, IMG_SIZE))
        x_nhwc = np.asarray(img_224, dtype=np.uint8)[None, ...]
        masker = shap.maskers.Image("blur(16,16)", x_nhwc[0].shape)
        explainer = shap.Explainer(predict_numpy, masker, algorithm="partition")
        shap_values = explainer(x_nhwc, max_evals=300, batch_size=50, outputs=[top_label])
        sv = shap_values.values if hasattr(shap_values, 'values') else shap_values
        sv = np.array(sv)
        # Normalize shape variations down to strictly (H, W, C):
        if sv.ndim >= 4 and sv.shape[0] == 1:   # strip batch dim
            sv = sv[0]
        if sv.ndim == 4 and sv.shape[-1] == 1:  # strip single-output dim
            sv = sv[..., 0]
        if sv.ndim != 3:
            logger.warning(f"SHAP unexpected shape: {sv.shape}. Trying to squeeze.")
            sv = sv.squeeze()
        if sv.ndim != 3:
            logger.error("Could not resolve SHAP values to (H,W,C).")
            # FIX: _encode_png expects BGR; convert the RGB fallback first
            # (previously the fallback image had swapped color channels).
            return _encode_png(cv2.cvtColor(np.array(img_224), cv2.COLOR_RGB2BGR))
        sv = np.nan_to_num(sv, nan=0.0)
        # Collapse channels into a grayscale attribution-magnitude map (H, W).
        sv_gray = np.mean(np.abs(sv), axis=2)
        sv_gray -= sv_gray.min()
        if sv_gray.max() > 1e-8:
            sv_gray /= sv_gray.max()
        img_rgb = np.array(img_224).astype(np.uint8)
        return _encode_png(_overlay_heatmap(img_rgb, sv_gray))
    finally:
        del temp_model
        # FIX: only touch the CUDA allocator when CUDA actually exists.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ====================== LIME ======================
def lime_explain(pil_img: Image.Image, top_label: int) -> str:
    """LIME superpixel explanation for *pil_img*; returns a base64 PNG overlay.

    Raises:
        HTTPException: 400 when LIME is not installed, 500 when the
            explainer itself fails.
    """
    if not _HAS_LIME:
        raise HTTPException(status_code=400, detail="LIME missing")
    rgb = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.uint8)
    explainer = lime_image.LimeImageExplainer()
    try:
        explanation = explainer.explain_instance(
            image=rgb,
            classifier_fn=_predict_proba_numpy,
            top_labels=3,
            hide_color=0,
            num_samples=300,
            segmentation_fn=lambda img: slic(img, n_segments=200, compactness=0.01, sigma=1),
        )
    except Exception as e:
        logger.exception("LIME failed")
        raise HTTPException(status_code=500, detail=str(e))
    if explanation.top_labels:
        chosen_label = int(explanation.top_labels[0])
    else:
        chosen_label = int(top_label)
    _, mask = explanation.get_image_and_mask(
        label=chosen_label, positive_only=False, num_features=10, hide_rest=False
    )
    # Turn the integer segment mask into a smooth [0, 1] heatmap.
    heat = np.abs(mask).astype(np.float32)
    if heat.max() > 1e-12:
        heat /= heat.max()
    heat = cv2.GaussianBlur(heat, (3, 3), 0)
    return _encode_png(_overlay_heatmap(rgb, heat))
# ====================== API ROUTES ======================
@app.get("/")
async def root():
    """Liveness message for the service root."""
    return {"message": "Gallbladder Classification API (XAI Optimized)"}
@app.get("/health")
async def health():
    """Report model readiness, compute device, and optional XAI availability."""
    return {
        "status": "healthy" if _MODEL_READY else "unhealthy",
        "device": str(device),
        "has_shap": _HAS_SHAP,
        "has_lime": _HAS_LIME
    }
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded ultrasound image.

    Returns the predicted class and confidence; 400 for non-image uploads
    or undecodable image data, 500 if the model is unavailable.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # FIX: content_type can be None for some clients; guard before startswith.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    data = await file.read()
    try:
        pil = Image.open(BytesIO(data)).convert("RGB")
    except Exception:
        # FIX: corrupt/undecodable bytes are a client error, not a server 500.
        raise HTTPException(status_code=400, detail="Could not decode image data")
    cls, conf, _ = predict_pil(pil)
    return {
        "filename": file.filename,
        "predicted_class": cls,
        "confidence_score": round(conf, 4)
    }
@app.post("/explain")
async def explain_image(
    file: UploadFile = File(...),
    method: str = Query("all", regex="^(gradcam|shap|lime|all)$")
):
    """Run the requested XAI method(s) and return base64 PNG overlays.

    Each explainer failure is recorded in the response under
    'xai'/<method>_error instead of failing the whole request.
    """
    if not _MODEL_READY:
        raise HTTPException(status_code=500, detail="Model not loaded")
    payload = await file.read()
    pil = Image.open(BytesIO(payload)).convert("RGB")
    cls, conf, top_idx = predict_pil(pil)
    result = {
        "predicted_class": cls,
        "confidence_score": round(conf, 4),
        "xai": {}
    }
    # "all" expands to every method; otherwise only the requested one runs.
    selected = {"gradcam", "shap", "lime"} if method == "all" else {method}
    if "gradcam" in selected:
        try:
            result["xai"]["gradcam"] = await run_in_threadpool(gradcam_explain, pil, top_idx)
        except Exception as e:
            logger.exception("Grad-CAM failed")
            result["xai"]["gradcam_error"] = str(e)
    if "shap" in selected:
        if not _HAS_SHAP:
            result["xai"]["shap_error"] = "SHAP not installed"
        else:
            try:
                result["xai"]["shap"] = await run_in_threadpool(shap_explain, pil, top_idx)
            except Exception as e:
                logger.exception("SHAP failed")
                result["xai"]["shap_error"] = str(e)
    if "lime" in selected:
        if not _HAS_LIME:
            result["xai"]["lime_error"] = "LIME not installed"
        else:
            try:
                result["xai"]["lime"] = await run_in_threadpool(lime_explain, pil, top_idx)
            except Exception as e:
                logger.exception("LIME failed")
                result["xai"]["lime_error"] = str(e)
    return result
if __name__ == "__main__":
    # Standalone launch; port 7860 matches the Hugging Face Spaces
    # convention — confirm before deploying elsewhere.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)