# eye_backend / app.py — Hugging Face Space backend (commit 547a6b4, author rocky200416)
# app_pytorch_inference.py
"""
Robust Flask inference server for multi-task EfficientNet-B3 model (classification + segmentation).
- Robust checkpoint/state_dict loading
- Tolerant Grad-CAM initialization across versions
- Thread-safe Grad-CAM usage
- Optional skipping of CAM/mask via query params
- SQLite logging of predictions
- CORS configured for dev origins
- Extra: Fundus validation + low-confidence rejection for non-retinal images
"""
import io
import os
import json
import base64
import traceback
import threading
from pathlib import Path
from datetime import datetime
from PIL import Image
import numpy as np
import cv2
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms as T
# tolerant import of pytorch-grad-cam
try:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except Exception:
GradCAM = None
show_cam_on_image = None
preprocess_image = None
ClassifierOutputTarget = None
from sqlalchemy import create_engine, text
import pandas as pd
# ---------------- CONFIG - edit these ----------------
# Checkpoint is expected to sit next to this script.
MODEL_PATH = Path(__file__).parent / "eye_model_lite.pth"
# SQLAlchemy URL for the prediction log (local SQLite file).
LOG_DB_PATH = "sqlite:///predictions_flask.db"
# Square model input resolution (pixels).
IMG_SIZE = 224
# Upload size cap, enforced via Flask's MAX_CONTENT_LENGTH below.
MAX_UPLOAD_MB = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Reject if top class confidence is below this
INVALID_THRESHOLD = 0.5  # 50% – you can tune to 0.7 or 0.8 if needed
# CORS origins (not strictly used in current CORS(app) call, but kept for reference)
CORS_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://ai-eye-disease-detection-chi.vercel.app",
]
# ----------------------------------------------------
# Flask app
app = Flask(__name__)
CORS(app)  # NOTE(review): this allows all origins; pass origins=CORS_ORIGINS to restrict
app.config["MAX_CONTENT_LENGTH"] = MAX_UPLOAD_MB * 1024 * 1024
# DB engine
engine = create_engine(LOG_DB_PATH, echo=False)
# class map - MUST match training order
CLASS_MAP_INV = {
    0: "Normal",
    1: "Cataract",
    2: "Diabetic Retinopathy",
    3: "Glaucoma"
}
NUM_CLASSES = len(CLASS_MAP_INV)
# ---------------- Model definition (must match training) ----------------
class MultiTaskNet(nn.Module):
    """EfficientNet-B3 backbone with a classification head and a transposed-conv
    segmentation head.

    forward(x) returns a pair:
      - cls_logits: (batch, num_classes) raw classification logits
      - seg_logits: (batch, 1, img_size, img_size) raw segmentation logits
    The architecture must match training exactly so the checkpoint loads.
    """

    def __init__(self, num_classes=4, dropout=0.5, img_size=IMG_SIZE):
        super().__init__()
        # weights=None: all parameters come from the checkpoint, not ImageNet.
        self.encoder = models.efficientnet_b3(weights=None)
        try:
            enc_features = self.encoder.classifier[1].in_features
        except (AttributeError, IndexError, TypeError):
            # Fallback for torchvision versions with a different classifier
            # layout; 1536 is EfficientNet-B3's final feature width.
            enc_features = 1536
        # Drop the stock classifier; pooling + heads are applied in forward().
        self.encoder.classifier = nn.Identity()
        self.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(enc_features, num_classes))
        # Stack of stride-2 transposed convs upsamples the encoder feature map;
        # forward() interpolates to img_size if the result differs.
        self.seg_head = nn.Sequential(
            nn.ConvTranspose2d(enc_features, 256, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1)
        )
        # Extra parameter kept so the checkpoint's parameter set matches
        # (presumably per-task uncertainty weights from training — unused here).
        self.log_vars = nn.Parameter(torch.zeros(2))
        self.img_size = img_size

    def forward(self, x):
        feats = self.encoder.features(x)
        pooled = F.adaptive_avg_pool2d(feats, 1).reshape(feats.shape[0], -1)
        cls_out = self.classifier(pooled)
        seg_out = self.seg_head(feats)
        # Guarantee the mask matches the configured input resolution.
        if seg_out.shape[-2:] != (self.img_size, self.img_size):
            seg_out = F.interpolate(seg_out, size=(self.img_size, self.img_size),
                                    mode='bilinear', align_corners=False)
        return cls_out, seg_out
# ---------------- Globals ----------------
# Lazily-initialized singletons, populated by load_model().
_model = None                    # MultiTaskNet on DEVICE, or None until loaded
_gradcam = None                  # GradCAM instance, or None if unavailable
_classification_wrapper = None   # wraps _model to expose only the cls head
_gradcam_lock = threading.Lock()  # serialize Grad-CAM calls across requests
# Preprocess transform
# Standard ImageNet normalization statistics.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
def build_preprocess():
    """Return the transform pipeline mapping a PIL image to a normalized tensor."""
    steps = [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
    return T.Compose(steps)
def pil_to_tensor_for_model(pil_img):
    """Preprocess a PIL image into a 1-image batch tensor placed on DEVICE."""
    batch = build_preprocess()(pil_img).unsqueeze(0)
    return batch.to(DEVICE)
def encode_base64_png_from_pil(pil_img):
    """Serialize a PIL image to PNG bytes and return them base64-encoded (str)."""
    with io.BytesIO() as stream:
        pil_img.save(stream, format="PNG")
        payload = stream.getvalue()
    return base64.b64encode(payload).decode("utf-8")
def overlay_heatmap_on_pil(pil_rgb, cam_mask, alpha=0.4):
    """Blend a [0, 1] CAM heatmap (JET colormap) onto a PIL RGB image.

    alpha controls the heatmap weight; returns a new PIL image.
    """
    base = np.array(pil_rgb).astype(np.float32) / 255.0
    mask_u8 = (np.clip(cam_mask, 0, 1) * 255).astype("uint8")
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; convert before blending with the RGB base.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    blended = np.clip(base * (1 - alpha) + colored * alpha, 0, 1)
    return Image.fromarray((blended * 255).astype("uint8"))
# --- NEW FUNCTION FOR RED MASK OVERLAY ---
def overlay_red_mask_on_pil(pil_rgb, binary_mask_uint8, alpha=0.5):
    """Tint the positive mask region red (alpha-blended); leave the rest untouched."""
    base = np.array(pil_rgb)
    solid_red = np.zeros_like(base)
    solid_red[:, :, 0] = 255  # pure red layer
    selected = binary_mask_uint8 > 0
    result = base.copy()
    mixed = base[selected] * (1 - alpha) + solid_red[selected] * alpha
    result[selected] = mixed.astype(np.uint8)
    return Image.fromarray(result)
# ---------- NEW: simple fundus check (circle detection) ----------
def is_probably_fundus(pil_img):
    """Heuristic fundus detector based on Hough circle detection.

    Converts to BGR, downscales so the longest side is at most 512 px,
    median-blurs the grayscale image, and looks for one large circle
    (roughly 20-50% of the image height). Returns True when a circle is
    found; any internal failure is treated conservatively as "not fundus".
    """
    try:
        frame = np.array(pil_img)
        # Normalize to 3-channel BGR for OpenCV.
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        else:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        height, width = frame.shape[:2]
        longest = max(height, width)
        # Cap working resolution to keep HoughCircles fast and stable.
        if longest > 512:
            factor = 512.0 / longest
            frame = cv2.resize(frame, (int(width * factor), int(height * factor)))
        gray = cv2.medianBlur(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), 5)
        n_rows = gray.shape[0]
        detected = cv2.HoughCircles(
            gray,
            cv2.HOUGH_GRADIENT,
            dp=1.2,
            minDist=n_rows / 4,
            param1=50,
            param2=30,
            minRadius=int(n_rows * 0.2),
            maxRadius=int(n_rows * 0.5),
        )
        # No round bright disc -> likely not a fundus photograph.
        return detected is not None
    except Exception as e:
        print("Fundus check failed:", e)
        # If check itself fails, be conservative and say it's not fundus
        return False
# ---------------- DB utilities ----------------
def init_db():
    """Create the predictions log table if it does not exist (idempotent)."""
    ddl = text("""
        CREATE TABLE IF NOT EXISTS predictions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT,
            predicted_disease TEXT,
            confidence REAL,
            probabilities TEXT,
            heatmap_base64 TEXT,
            mask_base64 TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    """)
    # engine.begin() commits on success, rolls back on error.
    with engine.begin() as conn:
        conn.execute(ddl)
# ---------------- Model loader (robust GradCAM init) ----------------
def _find_target_conv(module: nn.Module):
try:
feats = module.encoder.features
last = None
for m in feats.modules():
if isinstance(m, nn.Conv2d):
last = m
if last is not None:
return last
except Exception:
pass
last = None
for m in module.modules():
if isinstance(m, nn.Conv2d):
last = m
return last
def load_model():
    """Lazily load the checkpoint into the module-global _model and set up Grad-CAM.

    Idempotent: returns immediately if the model is already loaded. Tolerates
    several checkpoint layouts (raw state_dict, or nested under 'model' /
    'model_state' / 'model_state_dict' / 'state_dict') and strips DataParallel
    'module.' prefixes. Grad-CAM initialization is attempted with multiple
    constructor signatures to survive pytorch-grad-cam API changes; on any
    failure CAM is disabled rather than crashing the server.
    """
    global _model, _gradcam, _classification_wrapper
    # Already initialized -> nothing to do.
    if _model is not None:
        return
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH}")
    print("Loading model from:", MODEL_PATH)
    m = MultiTaskNet(num_classes=NUM_CLASSES).to(DEVICE)
    try:
        ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Failed to load checkpoint file: {e}")
        raise e
    # Unwrap the state_dict from common container keys.
    state = ckpt
    if isinstance(ckpt, dict):
        if 'model' in ckpt:
            state = ckpt['model']
        else:
            for key in ("model_state", "model_state_dict", "state_dict"):
                if key in ckpt:
                    state = ckpt[key]
                    break
    # Strip DataParallel's "module." prefix from parameter names.
    if isinstance(state, dict):
        new_state = {}
        for k, v in state.items():
            nk = k.replace("module.", "") if isinstance(k, str) and k.startswith("module.") else k
            new_state[nk] = v
        state = new_state
    try:
        m.load_state_dict(state, strict=True)
        print("✅ Model loaded with strict=True")
    except Exception as e:
        # Fall back to a partial load so a slightly mismatched checkpoint still works.
        print("Warning: strict load_state_dict failed:", e)
        try:
            m.load_state_dict(state, strict=False)
            print("⚠️ Loaded with strict=False.")
        except Exception as e2:
            print("Final load attempt failed:", e2)
            raise e2
    m.eval()
    _model = m.to(DEVICE)
    print("Model loaded to", DEVICE)
    # Grad-CAM is optional: bail out gracefully if the package is missing.
    if GradCAM is None or show_cam_on_image is None or preprocess_image is None:
        print("pytorch-grad-cam not available; Grad-CAM disabled.")
        _gradcam = None
        _classification_wrapper = None
        return
    target_layer = _find_target_conv(_model)
    if target_layer is None:
        print("Could not find a conv layer for Grad-CAM; disabling CAM.")
        _gradcam = None
        _classification_wrapper = None
        return

    class ClassificationOnlyWrapper(nn.Module):
        # Exposes only the classification head so Grad-CAM sees a
        # single-output model (it cannot handle tuple outputs).
        def __init__(self, full_model):
            super().__init__()
            self.full = full_model

        def forward(self, x):
            cls, _ = self.full(x)
            return cls

    _classification_wrapper = ClassificationOnlyWrapper(_model).to(DEVICE)
    _gradcam = None
    # Try the constructor signatures used by successive pytorch-grad-cam
    # releases: use_cuda kwarg -> device kwarg -> no extra kwargs.
    try:
        _gradcam = GradCAM(model=_classification_wrapper, target_layers=[target_layer], use_cuda=(DEVICE=="cuda"))
        print("GradCAM initialized with use_cuda.")
    except TypeError:
        try:
            _gradcam = GradCAM(model=_classification_wrapper, target_layers=[target_layer],
                               device=torch.device(DEVICE))
            print("GradCAM initialized with device arg.")
        except TypeError:
            try:
                _gradcam = GradCAM(model=_classification_wrapper, target_layers=[target_layer])
                print("GradCAM initialized without extra kwargs.")
            except Exception as e:
                print("GradCAM initialization failed; disabling CAM. Error:", e)
                _gradcam = None
    except Exception as e:
        print("Unexpected error initializing GradCAM; disabling CAM. Error:", e)
        _gradcam = None
    if _gradcam is not None:
        print("GradCAM ready.")
    else:
        print("GradCAM not available; continuing without CAM.")
# ---------------- Routes ----------------
@app.route("/", methods=["GET", "HEAD"])
def index():
    """Liveness message for the root path."""
    message = "Backend is running!"
    return message
@app.route("/health", methods=["GET"])
def health():
    """Health probe: report service status and the compute device in use."""
    payload = {"status": "ok", "device": DEVICE}
    return jsonify(payload)
@app.route("/history", methods=["GET"])
def history():
    """Return the 200 most recent prediction log rows as a JSON list."""
    try:
        query = text(
            "SELECT id, filename, predicted_disease, confidence, probabilities, created_at "
            "FROM predictions ORDER BY created_at DESC LIMIT 200"
        )
        with engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [
            {
                "id": row[0],
                "filename": row[1],
                "predicted_disease": row[2],
                "confidence": float(row[3]) if row[3] is not None else None,
                "probabilities": json.loads(row[4]) if row[4] else None,
                "created_at": str(row[5]),
            }
            for row in rows
        ]
        return jsonify(records)
    except Exception as e:
        # Boundary handler: surface DB errors as a 500 payload.
        return jsonify({"error": str(e)}), 500
@app.route("/predict", methods=["POST"])
def predict():
    """
    Run classification (plus optional segmentation mask and Grad-CAM overlay)
    on an uploaded image and log the result to SQLite.

    POST multipart/form-data with key "image"
    Optional query params:
      - no_cam=1  -> skip Grad-CAM generation
      - no_mask=1 -> skip mask generation

    Uploads failing the fundus-shape heuristic, or whose top-class confidence
    is below INVALID_THRESHOLD, are rejected with a 200 "Invalid" payload.
    """
    try:
        # Lazy init: both calls are idempotent, so safe on every request.
        load_model()
        init_db()
        if "image" not in request.files:
            return jsonify({"error": "no image file uploaded under key 'image'"}), 400
        f = request.files["image"]
        if f.filename == "":
            return jsonify({"error": "empty filename"}), 400
        no_cam = request.args.get("no_cam", "0").lower() in ("1", "true", "yes")
        no_mask = request.args.get("no_mask", "0").lower() in ("1", "true", "yes")
        # ---- Open image ----
        pil = Image.open(f.stream).convert("RGB")
        # ---- FUNDUS SHAPE CHECK (reject non-eye images early) ----
        if not is_probably_fundus(pil):
            print("⚠️ Rejected image: does not look like a retinal fundus")
            prob_map = {}
            # Log the rejection so it appears in /history.
            with engine.begin() as conn:
                conn.execute(
                    text("INSERT INTO predictions (filename, predicted_disease, confidence, probabilities) "
                         "VALUES (:fn, :pd, :c, :p)"),
                    {
                        "fn": f.filename,
                        "pd": "Invalid / Non-retinal Image",
                        "c": 0.0,
                        "p": json.dumps(prob_map)
                    }
                )
            return jsonify({
                "predicted_disease": "Invalid / Non-retinal Image",
                "confidence": 0.0,
                "probabilities": prob_map,
                "message": "Please upload a valid retinal fundus image."
            }), 200
        # If passed fundus check → resize for model
        pil_resized = pil.resize((IMG_SIZE, IMG_SIZE))
        inp_tensor = pil_to_tensor_for_model(pil_resized)
        # ---- Forward pass (classification + segmentation) ----
        with torch.inference_mode():
            out = _model(inp_tensor)
        # The model may return (cls, seg) or a bare classification tensor.
        if isinstance(out, (list, tuple)):
            cls_logits = out[0]
            seg_logits = out[1] if len(out) > 1 else None
        else:
            cls_logits = out
            seg_logits = None
        if not isinstance(cls_logits, torch.Tensor):
            raise RuntimeError(f"Unexpected classification output type: {type(cls_logits)}")
        probs = torch.softmax(cls_logits, dim=1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        pred_label = CLASS_MAP_INV.get(pred_idx, str(pred_idx))
        confidence = float(probs[pred_idx])
        # Build probabilities map
        prob_map = {CLASS_MAP_INV[i]: float(round(float(probs[i]), 6)) for i in range(len(probs))}
        # ---- LOW CONFIDENCE CHECK (treat as invalid) ----
        if confidence < INVALID_THRESHOLD:
            print(f"⚠️ Low confidence ({confidence:.3f}) – marking as Invalid / Non-retinal Image")
            pred_label = "Invalid / Non-retinal Image"
            with engine.begin() as conn:
                conn.execute(
                    text("INSERT INTO predictions (filename, predicted_disease, confidence, probabilities) "
                         "VALUES (:fn, :pd, :c, :p)"),
                    {
                        "fn": f.filename,
                        "pd": pred_label,
                        "c": confidence,
                        "p": json.dumps(prob_map)
                    }
                )
            return jsonify({
                "predicted_disease": pred_label,
                "confidence": confidence,
                "probabilities": prob_map,
                "message": "Model is not confident. Please upload a clear retinal fundus image."
            }), 200
        # ---- SEGMENTATION MASK LOGIC ----
        mask_b64 = None
        if (not no_mask) and (seg_logits is not None):
            try:
                seg_prob = torch.sigmoid(seg_logits).detach().cpu().numpy()[0, 0]
                if pred_label == "Normal":
                    # Never show lesion highlights on a Normal prediction.
                    mask_uint8 = np.zeros_like(seg_prob, dtype="uint8")
                else:
                    # Binarize at a 0.25 probability threshold.
                    mask_uint8 = (seg_prob > 0.25).astype("uint8") * 255
                if mask_uint8.max() > 0:
                    mask_pil_overlay = overlay_red_mask_on_pil(pil_resized, mask_uint8)
                    mask_b64 = encode_base64_png_from_pil(mask_pil_overlay)
                else:
                    # Empty mask: return the plain resized image instead.
                    mask_b64 = encode_base64_png_from_pil(pil_resized)
            except Exception as e:
                # Mask generation is best-effort; the prediction still succeeds.
                print(f"Mask generation failed: {e}")
                traceback.print_exc()
                mask_b64 = None
        # ---- GRAD-CAM ----
        overlay_b64 = None
        if (not no_cam) and (_gradcam is not None) and (preprocess_image is not None):
            try:
                rgb_for_cam = np.array(pil_resized).astype(np.float32) / 255.0
                input_for_cam = preprocess_image(rgb_for_cam, mean=MEAN, std=STD).to(DEVICE)
                # Grad-CAM hooks the shared model -> serialize concurrent calls.
                with _gradcam_lock:
                    grayscale_cam = _gradcam(input_for_cam, targets=[ClassifierOutputTarget(pred_idx)])
                cam_np = np.array(grayscale_cam)
                cam_np = np.squeeze(cam_np)
                if cam_np.ndim == 3:
                    cam_np = cam_np[0]
                cam_np = cam_np.astype(np.float32)
                # Min-max normalize to [0, 1] for the colormap overlay.
                if cam_np.max() > 0:
                    cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)
                else:
                    cam_np = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float32)
                if cam_np.shape != (IMG_SIZE, IMG_SIZE):
                    cam_np = cv2.resize(cam_np, (IMG_SIZE, IMG_SIZE))
                overlay_pil = overlay_heatmap_on_pil(pil_resized, cam_np)
                overlay_b64 = encode_base64_png_from_pil(overlay_pil)
            except Exception as e:
                # CAM is best-effort as well.
                print("Grad-CAM generation error:", e)
                traceback.print_exc()
                overlay_b64 = None
        # ---- Store in DB ----
        probabilities_json = json.dumps(prob_map)
        with engine.begin() as conn:
            conn.execute(text(
                "INSERT INTO predictions (filename, predicted_disease, confidence, probabilities, heatmap_base64, mask_base64) "
                "VALUES (:fn,:pd,:c,:p,:h,:m)"
            ), {
                "fn": f.filename,
                "pd": pred_label,
                "c": confidence,
                "p": probabilities_json,
                "h": overlay_b64,
                "m": mask_b64
            })
        # ---- Build response ----
        response = {
            "predicted_disease": pred_label,
            "confidence": confidence,
            "probabilities": prob_map
        }
        if overlay_b64 is not None:
            response["heatmap_png_base64"] = overlay_b64
        if mask_b64 is not None:
            response["mask_png_base64"] = mask_b64
        return jsonify(response)
    except Exception as e:
        # Top-level boundary: log the traceback, report the error to the client.
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# ---------------- Main ----------------
if __name__ == "__main__":
    print("Starting Flask server on 0.0.0.0:8000")
    # Create the log table up front; a model-load failure here is non-fatal
    # because /predict retries load_model() on each request.
    init_db()
    try:
        load_model()
    except Exception as e:
        print("Model failed to load on startup:", e)
    app.run(host="0.0.0.0", port=8000, debug=False)