# app_pytorch_inference.py
"""
Robust Flask inference server for multi-task EfficientNet-B3 model
(classification + segmentation).

- Robust checkpoint/state_dict loading
- Tolerant Grad-CAM initialization across versions
- Thread-safe Grad-CAM usage
- Optional skipping of CAM/mask via query params
- SQLite logging of predictions
- CORS configured for dev origins
- Extra: Fundus validation + low-confidence rejection for non-retinal images
"""

import io
import os
import json
import base64
import traceback
import threading
from pathlib import Path
from datetime import datetime

from PIL import Image
import numpy as np
import cv2

from flask import Flask, request, jsonify
from flask_cors import CORS

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms as T

# Tolerant import of pytorch-grad-cam: the server must still boot (with CAM
# disabled) when the package is missing or incompatible.
try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except Exception:
    GradCAM = None
    show_cam_on_image = None
    preprocess_image = None
    ClassifierOutputTarget = None

from sqlalchemy import create_engine, text
import pandas as pd

# ---------------- CONFIG - edit these ----------------
MODEL_PATH = Path(__file__).parent / "eye_model_lite.pth"
LOG_DB_PATH = "sqlite:///predictions_flask.db"
IMG_SIZE = 224
MAX_UPLOAD_MB = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reject if top class confidence is below this
INVALID_THRESHOLD = 0.5  # 50% – you can tune to 0.7 or 0.8 if needed

# CORS origins (not strictly used in current CORS(app) call, but kept for reference)
CORS_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://ai-eye-disease-detection-chi.vercel.app",
]
# ----------------------------------------------------

# Flask app
app = Flask(__name__)
CORS(app)
app.config["MAX_CONTENT_LENGTH"] = MAX_UPLOAD_MB * 1024 * 1024

# DB engine
engine = create_engine(LOG_DB_PATH, echo=False)

# class map - MUST match training order
CLASS_MAP_INV = {
    0: "Normal",
    1: "Cataract",
    2: "Diabetic Retinopathy",
    3: "Glaucoma",
}
NUM_CLASSES = len(CLASS_MAP_INV)


# ---------------- Model definition (must match training) ----------------
class MultiTaskNet(nn.Module):
    """EfficientNet-B3 encoder with a classification head and a ConvTranspose
    segmentation head. Architecture must exactly match the training script so
    the checkpoint's state_dict loads."""

    def __init__(self, num_classes=4, dropout=0.5, img_size=IMG_SIZE):
        super().__init__()
        self.encoder = models.efficientnet_b3(weights=None)
        try:
            enc_features = self.encoder.classifier[1].in_features
        except (AttributeError, IndexError, TypeError):
            # Known feature width of EfficientNet-B3's final stage; fallback
            # in case torchvision changes the classifier layout.
            enc_features = 1536
        self.encoder.classifier = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(enc_features, num_classes),
        )
        # 5 stride-2 transposed convs upsample the 7x7 feature map ~32x.
        self.seg_head = nn.Sequential(
            nn.ConvTranspose2d(enc_features, 256, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
        )
        # Learned uncertainty weights from training (unused at inference, but
        # required so the checkpoint loads strictly).
        self.log_vars = nn.Parameter(torch.zeros(2))
        self.img_size = img_size

    def forward(self, x):
        """Return (classification logits [B, num_classes],
        segmentation logits [B, 1, img_size, img_size])."""
        feats = self.encoder.features(x)
        pooled = F.adaptive_avg_pool2d(feats, 1).reshape(feats.shape[0], -1)
        cls_out = self.classifier(pooled)
        seg_out = self.seg_head(feats)
        # Guarantee a fixed output resolution regardless of rounding in the
        # transposed-conv stack.
        if seg_out.shape[-2:] != (self.img_size, self.img_size):
            seg_out = F.interpolate(
                seg_out, size=(self.img_size, self.img_size),
                mode='bilinear', align_corners=False,
            )
        return cls_out, seg_out


# ---------------- Globals ----------------
_model = None
_gradcam = None
_classification_wrapper = None
# GradCAM registers hooks and mutates internal state; serialize access.
_gradcam_lock = threading.Lock()

# Preprocess transform (ImageNet normalization)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def build_preprocess():
    """Return the torchvision transform used for model input."""
    return T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ])


def pil_to_tensor_for_model(pil_img):
    """Preprocess a PIL image into a 1xCxHxW tensor on DEVICE."""
    tf = build_preprocess()
    return tf(pil_img).unsqueeze(0).to(DEVICE)


def encode_base64_png_from_pil(pil_img):
    """Encode a PIL image as a base64 PNG string (no data: prefix)."""
    buff = io.BytesIO()
    pil_img.save(buff, format="PNG")
    buff.seek(0)
    return base64.b64encode(buff.read()).decode("utf-8")


def overlay_heatmap_on_pil(pil_rgb, cam_mask, alpha=0.4):
    """Blend a [0,1] CAM mask (JET colormap) onto an RGB PIL image.

    Returns a new PIL image; ``alpha`` is the heatmap opacity.
    """
    rgb = np.array(pil_rgb).astype(np.float32) / 255.0
    cam_uint8 = (np.clip(cam_mask, 0, 1) * 255).astype("uint8")
    heatmap = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert before blending with the RGB image.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    overlay = (1 - alpha) * rgb + alpha * heatmap
    overlay = np.clip(overlay, 0, 1)
    overlay_img = Image.fromarray((overlay * 255).astype("uint8"))
    return overlay_img


# --- NEW FUNCTION FOR RED MASK OVERLAY ---
def overlay_red_mask_on_pil(pil_rgb, binary_mask_uint8, alpha=0.5):
    """
    Overlays a red color where the mask is 1 (255), transparent elsewhere.
    """
    rgb = np.array(pil_rgb)
    red_layer = np.zeros_like(rgb)
    red_layer[:, :, 0] = 255  # Red channel full
    mask_bool = binary_mask_uint8 > 0
    output = rgb.copy()
    output[mask_bool] = (
        rgb[mask_bool] * (1 - alpha) + red_layer[mask_bool] * alpha
    ).astype(np.uint8)
    return Image.fromarray(output)


# ---------- NEW: simple fundus check (circle detection) ----------
def is_probably_fundus(pil_img):
    """
    Heuristic check to see if the image looks like a retinal fundus.
    Uses Hough circle detection on a grayscale version.
    Returns True if it likely is a fundus image, False otherwise.
    """
    try:
        img = np.array(pil_img)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        h, w = img.shape[:2]
        # Downscale so HoughCircles parameters behave consistently.
        max_dim = max(h, w)
        scale = 512.0 / max_dim if max_dim > 512 else 1.0
        if scale != 1.0:
            img = cv2.resize(img, (int(w * scale), int(h * scale)))
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.medianBlur(gray, 5)
        rows = gray.shape[0]
        # Look for one large circular disc (the retinal field of view).
        circles = cv2.HoughCircles(
            gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=rows / 4,
            param1=50, param2=30,
            minRadius=int(rows * 0.2), maxRadius=int(rows * 0.5),
        )
        if circles is None:
            # No round bright disc -> likely not fundus
            return False
        return True
    except Exception as e:
        print("Fundus check failed:", e)
        # If check itself fails, be conservative and say it's not fundus
        return False


# ---------------- DB utilities ----------------
def init_db():
    """Create the predictions table if it does not exist (idempotent)."""
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE IF NOT EXISTS predictions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT,
                predicted_disease TEXT,
                confidence REAL,
                probabilities TEXT,
                heatmap_base64 TEXT,
                mask_base64 TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """))


# ---------------- Model loader (robust GradCAM init) ----------------
def _find_target_conv(module: nn.Module):
    """Return the last Conv2d of the encoder (preferred) or of the whole
    model, for use as the Grad-CAM target layer. None if none found."""
    try:
        feats = module.encoder.features
        last = None
        for m in feats.modules():
            if isinstance(m, nn.Conv2d):
                last = m
        if last is not None:
            return last
    except Exception:
        pass
    last = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last = m
    return last


def load_model():
    """Load the checkpoint into the global _model (once) and initialize
    Grad-CAM if the library is available.

    Tolerates several checkpoint layouts ('model', 'model_state', ...,
    raw state_dict) and DataParallel 'module.' prefixes. Raises
    FileNotFoundError if the checkpoint is missing and re-raises loader
    errors so the caller can report them.
    """
    global _model, _gradcam, _classification_wrapper
    if _model is not None:
        return
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH}")
    print("Loading model from:", MODEL_PATH)
    m = MultiTaskNet(num_classes=NUM_CLASSES).to(DEVICE)
    try:
        ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Failed to load checkpoint file: {e}")
        raise e

    # Unwrap common checkpoint container layouts.
    state = ckpt
    if isinstance(ckpt, dict):
        if 'model' in ckpt:
            state = ckpt['model']
        else:
            for key in ("model_state", "model_state_dict", "state_dict"):
                if key in ckpt:
                    state = ckpt[key]
                    break

    # Strip DataParallel's 'module.' prefix if present.
    if isinstance(state, dict):
        new_state = {}
        for k, v in state.items():
            nk = k.replace("module.", "") if isinstance(k, str) and k.startswith("module.") else k
            new_state[nk] = v
        state = new_state

    try:
        m.load_state_dict(state, strict=True)
        print("✅ Model loaded with strict=True")
    except Exception as e:
        print("Warning: strict load_state_dict failed:", e)
        try:
            m.load_state_dict(state, strict=False)
            print("⚠️ Loaded with strict=False.")
        except Exception as e2:
            print("Final load attempt failed:", e2)
            raise e2

    m.eval()
    _model = m.to(DEVICE)
    print("Model loaded to", DEVICE)

    if GradCAM is None or show_cam_on_image is None or preprocess_image is None:
        print("pytorch-grad-cam not available; Grad-CAM disabled.")
        _gradcam = None
        _classification_wrapper = None
        return

    target_layer = _find_target_conv(_model)
    if target_layer is None:
        print("Could not find a conv layer for Grad-CAM; disabling CAM.")
        _gradcam = None
        _classification_wrapper = None
        return

    class ClassificationOnlyWrapper(nn.Module):
        """Expose only the classification logits so GradCAM sees a single
        tensor output rather than the (cls, seg) tuple."""

        def __init__(self, full_model):
            super().__init__()
            self.full = full_model

        def forward(self, x):
            cls, _ = self.full(x)
            return cls

    _classification_wrapper = ClassificationOnlyWrapper(_model).to(DEVICE)
    _gradcam = None
    # GradCAM's constructor signature has changed across releases
    # (use_cuda kwarg -> device kwarg -> neither); try each in turn.
    try:
        _gradcam = GradCAM(model=_classification_wrapper,
                           target_layers=[target_layer],
                           use_cuda=(DEVICE == "cuda"))
        print("GradCAM initialized with use_cuda.")
    except TypeError:
        try:
            _gradcam = GradCAM(model=_classification_wrapper,
                               target_layers=[target_layer],
                               device=torch.device(DEVICE))
            print("GradCAM initialized with device arg.")
        except TypeError:
            try:
                _gradcam = GradCAM(model=_classification_wrapper,
                                   target_layers=[target_layer])
                print("GradCAM initialized without extra kwargs.")
            except Exception as e:
                print("GradCAM initialization failed; disabling CAM. Error:", e)
                _gradcam = None
    except Exception as e:
        print("Unexpected error initializing GradCAM; disabling CAM. Error:", e)
        _gradcam = None

    if _gradcam is not None:
        print("GradCAM ready.")
    else:
        print("GradCAM not available; continuing without CAM.")


# ---------------- Routes ----------------
@app.route("/", methods=["GET", "HEAD"])
def index():
    return "Backend is running!"


@app.route("/health", methods=["GET"])
def health():
    return jsonify({"status": "ok", "device": DEVICE})


@app.route("/history", methods=["GET"])
def history():
    """Return the 200 most recent prediction rows (without image blobs)."""
    try:
        with engine.connect() as conn:
            rows = conn.execute(text(
                "SELECT id, filename, predicted_disease, confidence, probabilities, created_at "
                "FROM predictions ORDER BY created_at DESC LIMIT 200"
            )).fetchall()
        out = []
        for r in rows:
            out.append({
                "id": r[0],
                "filename": r[1],
                "predicted_disease": r[2],
                "confidence": float(r[3]) if r[3] is not None else None,
                "probabilities": json.loads(r[4]) if r[4] else None,
                "created_at": str(r[5]),
            })
        return jsonify(out)
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/predict", methods=["POST"])
def predict():
    """
    POST multipart/form-data with key "image"
    Optional query params:
      - no_cam=1  -> skip Grad-CAM generation
      - no_mask=1 -> skip mask generation
    """
    try:
        # Lazy init: both are no-ops after the first successful call.
        load_model()
        init_db()

        if "image" not in request.files:
            return jsonify({"error": "no image file uploaded under key 'image'"}), 400
        f = request.files["image"]
        if f.filename == "":
            return jsonify({"error": "empty filename"}), 400

        no_cam = request.args.get("no_cam", "0").lower() in ("1", "true", "yes")
        no_mask = request.args.get("no_mask", "0").lower() in ("1", "true", "yes")

        # ---- Open image ----
        pil = Image.open(f.stream).convert("RGB")

        # ---- FUNDUS SHAPE CHECK (reject non-eye images early) ----
        if not is_probably_fundus(pil):
            print("⚠️ Rejected image: does not look like a retinal fundus")
            prob_map = {}
            with engine.begin() as conn:
                conn.execute(
                    text("INSERT INTO predictions (filename, predicted_disease, confidence, probabilities) "
                         "VALUES (:fn, :pd, :c, :p)"),
                    {
                        "fn": f.filename,
                        "pd": "Invalid / Non-retinal Image",
                        "c": 0.0,
                        "p": json.dumps(prob_map),
                    }
                )
            return jsonify({
                "predicted_disease": "Invalid / Non-retinal Image",
                "confidence": 0.0,
                "probabilities": prob_map,
                "message": "Please upload a valid retinal fundus image."
            }), 200

        # If passed fundus check → resize for model
        pil_resized = pil.resize((IMG_SIZE, IMG_SIZE))
        inp_tensor = pil_to_tensor_for_model(pil_resized)

        # ---- Forward pass (classification + segmentation) ----
        with torch.inference_mode():
            out = _model(inp_tensor)
            if isinstance(out, (list, tuple)):
                cls_logits = out[0]
                seg_logits = out[1] if len(out) > 1 else None
            else:
                cls_logits = out
                seg_logits = None

        if not isinstance(cls_logits, torch.Tensor):
            raise RuntimeError(f"Unexpected classification output type: {type(cls_logits)}")

        probs = torch.softmax(cls_logits, dim=1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        pred_label = CLASS_MAP_INV.get(pred_idx, str(pred_idx))
        confidence = float(probs[pred_idx])

        # Build probabilities map
        prob_map = {CLASS_MAP_INV[i]: float(round(float(probs[i]), 6)) for i in range(len(probs))}

        # ---- LOW CONFIDENCE CHECK (treat as invalid) ----
        if confidence < INVALID_THRESHOLD:
            print(f"⚠️ Low confidence ({confidence:.3f}) – marking as Invalid / Non-retinal Image")
            pred_label = "Invalid / Non-retinal Image"
            with engine.begin() as conn:
                conn.execute(
                    text("INSERT INTO predictions (filename, predicted_disease, confidence, probabilities) "
                         "VALUES (:fn, :pd, :c, :p)"),
                    {
                        "fn": f.filename,
                        "pd": pred_label,
                        "c": confidence,
                        "p": json.dumps(prob_map),
                    }
                )
            return jsonify({
                "predicted_disease": pred_label,
                "confidence": confidence,
                "probabilities": prob_map,
                "message": "Model is not confident. Please upload a clear retinal fundus image."
            }), 200

        # ---- SEGMENTATION MASK LOGIC ----
        mask_b64 = None
        if (not no_mask) and (seg_logits is not None):
            try:
                seg_prob = torch.sigmoid(seg_logits).detach().cpu().numpy()[0, 0]
                if pred_label == "Normal":
                    # No lesion overlay for healthy eyes.
                    mask_uint8 = np.zeros_like(seg_prob, dtype="uint8")
                else:
                    mask_uint8 = (seg_prob > 0.25).astype("uint8") * 255
                if mask_uint8.max() > 0:
                    mask_pil_overlay = overlay_red_mask_on_pil(pil_resized, mask_uint8)
                    mask_b64 = encode_base64_png_from_pil(mask_pil_overlay)
                else:
                    # Empty mask: return the plain resized image.
                    mask_b64 = encode_base64_png_from_pil(pil_resized)
            except Exception as e:
                print(f"Mask generation failed: {e}")
                traceback.print_exc()
                mask_b64 = None

        # ---- GRAD-CAM ----
        overlay_b64 = None
        if (not no_cam) and (_gradcam is not None) and (preprocess_image is not None):
            try:
                rgb_for_cam = np.array(pil_resized).astype(np.float32) / 255.0
                input_for_cam = preprocess_image(rgb_for_cam, mean=MEAN, std=STD).to(DEVICE)
                # GradCAM mutates hook state; serialize concurrent requests.
                with _gradcam_lock:
                    grayscale_cam = _gradcam(input_for_cam, targets=[ClassifierOutputTarget(pred_idx)])
                cam_np = np.array(grayscale_cam)
                cam_np = np.squeeze(cam_np)
                if cam_np.ndim == 3:
                    cam_np = cam_np[0]
                cam_np = cam_np.astype(np.float32)
                # Min-max normalize to [0,1] for the colormap.
                if cam_np.max() > 0:
                    cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)
                else:
                    cam_np = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float32)
                if cam_np.shape != (IMG_SIZE, IMG_SIZE):
                    cam_np = cv2.resize(cam_np, (IMG_SIZE, IMG_SIZE))
                overlay_pil = overlay_heatmap_on_pil(pil_resized, cam_np)
                overlay_b64 = encode_base64_png_from_pil(overlay_pil)
            except Exception as e:
                print("Grad-CAM generation error:", e)
                traceback.print_exc()
                overlay_b64 = None

        # ---- Store in DB ----
        probabilities_json = json.dumps(prob_map)
        with engine.begin() as conn:
            conn.execute(text(
                "INSERT INTO predictions (filename, predicted_disease, confidence, probabilities, heatmap_base64, mask_base64) "
                "VALUES (:fn,:pd,:c,:p,:h,:m)"
            ), {
                "fn": f.filename,
                "pd": pred_label,
                "c": confidence,
                "p": probabilities_json,
                "h": overlay_b64,
                "m": mask_b64,
            })

        # ---- Build response ----
        response = {
            "predicted_disease": pred_label,
            "confidence": confidence,
            "probabilities": prob_map,
        }
        if overlay_b64 is not None:
            response["heatmap_png_base64"] = overlay_b64
        if mask_b64 is not None:
            response["mask_png_base64"] = mask_b64
        return jsonify(response)

    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


# ---------------- Main ----------------
if __name__ == "__main__":
    print("Starting Flask server on 0.0.0.0:8000")
    init_db()
    try:
        load_model()
    except Exception as e:
        # Keep the server up so /health works; /predict will retry loading.
        print("Model failed to load on startup:", e)
    app.run(host="0.0.0.0", port=8000, debug=False)