gajavegs's picture
Added Grad-CAM
259a0f2
import os
from typing import Any, Dict
from flask import Flask, jsonify, request, send_from_directory, abort
from PIL import Image
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from model_loader import load_alexnet_model, preprocess_image
from flask_cors import CORS
from io import BytesIO
import base64
import numpy as np
# Load variables from a local .env file before any os.getenv() call below.
# override=True asks python-dotenv to let .env values replace variables that
# are already set in the process environment.
load_dotenv(override=True)
# HF sets PORT dynamically. Fall back to FLASK_PORT, then to 7860 locally.
PORT = int(os.getenv("PORT", os.getenv("FLASK_PORT", "7860")))
# Bind address handed to app.run(): listen on all interfaces.
HOST = "0.0.0.0"
# Checkpoint handed to load_alexnet_model() at import time.
# NOTE(review): the default file name is spelled "alexnext_vsf_bext" -- confirm
# it matches the file actually shipped under models/.
MODEL_PATH = os.getenv("MODEL_PATH", "models/alexnext_vsf_bext.pth")
# Preset image paths via ENV (each one has a default under images/).
TP_PATH = os.getenv("TP_PATH", "images/TP.jpg")
TN_PATH = os.getenv("TN_PATH", "images/TN.jpg")
FN_PATH = os.getenv("FN_PATH", "images/FN.jpg")
FP_PATH = os.getenv("FP_PATH", "images/FP.jpg")
# Upper-case preset label -> image path. The preset routes upper-case the
# client-supplied label before looking it up here, so only these four files
# can ever be read or served through them.
PRESET_MAP: Dict[str, str] = {
"TP": TP_PATH,
"TN": TN_PATH,
"FN": FN_PATH,
"FP": FP_PATH,
}
# Limit PyTorch to a single intra-op CPU thread. This does not change how many
# server workers or processes handle requests.
torch.set_num_threads(1)
# Create app and static hosting: files in ./static are served from the URL
# root because static_url_path is empty.
app = Flask(__name__, static_folder="static", static_url_path="")
# Enable CORS with flask_cors defaults.
# NOTE(review): the defaults are permissive -- confirm that is intended.
CORS(app)
# Device selection: use CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model once at startup. `classes` is indexed by predicted output index
# and zipped with the softmax probabilities by the inference helpers.
model, classes = load_alexnet_model(MODEL_PATH, device=DEVICE)
# Move the weights to DEVICE and switch to evaluation mode.
model.to(DEVICE).eval()
@app.get("/")
def root() -> Any:
    """Serve the front-end entry page ``index.html`` from the app's static folder."""
    static_dir = app.static_folder
    return send_from_directory(static_dir, "index.html")
@app.get("/health")
def health() -> Any:
    """Liveness probe: report ``status`` "ok" and the torch device in use."""
    return jsonify(status="ok", device=str(DEVICE))
def load_image(file_stream_or_path) -> "Image.Image":
    """Open an image and return an independent RGB copy of it.

    Args:
        file_stream_or_path: Either a filesystem path (``str``) or a binary
            file-like object. ``Image.open`` accepts both, so no type
            dispatch is needed.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        Whatever ``Image.open`` / ``Image.convert`` raise for missing,
        unreadable or undecodable input.
    """
    # The context manager releases the file handle Pillow opened for a path
    # (also on error and for multi-frame formats); a caller-supplied stream is
    # left open for its owner. convert() returns a separate image, so the
    # result stays valid after the source has been closed.
    with Image.open(file_stream_or_path) as source:
        return source.convert("RGB")
def _render_cam_overlay(img_pil: Image.Image, cam_np: np.ndarray) -> str:
    """Blend a 2-D heat map with values in [0, 1] over ``img_pil`` in red; return a PNG data URL."""
    size = img_pil.size  # (width, height)
    # A 2-D uint8 array is interpreted as mode "L", so no explicit mode
    # argument is needed.
    cam_img = Image.fromarray((cam_np * 255).astype(np.uint8))
    # Resize CAM to original image size
    cam_img = cam_img.resize(size, resample=Image.BILINEAR)
    # Red layer whose per-pixel alpha is the CAM intensity
    heat_rgba = Image.new("RGBA", size, (255, 0, 0, 0))
    heat_rgba.putalpha(cam_img)
    overlayed = Image.alpha_composite(img_pil.convert("RGBA"), heat_rgba)
    # Encode to data URL
    buff = BytesIO()
    overlayed.save(buff, format="PNG")
    b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{b64}"


def generate_gradcam(img_pil: Image.Image, target_idx: int) -> str:
    """Return a data URL (PNG) of the Grad-CAM overlay for the target class.

    A forward hook on ``model.features[10]`` records that layer's output and
    the gradient flowing back into it. The gradient is averaged over the
    spatial dimensions to weight each channel, the weighted channels are
    summed, passed through ReLU, scaled to [0, 1] and drawn over the image.

    Args:
        img_pil: Image to explain; the overlay has the same pixel size.
        target_idx: Index into the model output whose score is back-propagated.

    Returns:
        ``"data:image/png;base64,..."`` string containing the overlay.

    Raises:
        ValueError: ``target_idx`` is outside the model's output dimension.
        RuntimeError: the hooks captured no activations or no gradients.
    """
    model.eval()
    # Last conv of standard AlexNet
    target_layer = model.features[10]
    activations = []
    gradients = []

    def fwd_hook(_module, _inputs, out):
        # Save activations (detached) and attach a tensor hook to capture gradients.
        # NOTE(review): detach() shares storage with `out`; if the next module
        # is an in-place ReLU the saved values reflect that in-place change --
        # confirm this is the intended activation.
        activations.append(out.detach())
        out.register_hook(lambda g: gradients.append(g.detach().clone()))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        input_tensor = preprocess_image(img_pil).to(DEVICE)
        # Grad-CAM needs autograd; force it on even when the caller is
        # running inside a torch.no_grad() context.
        with torch.enable_grad():
            output = model(input_tensor)  # [1, C]
            # Backward on the selected class
            if target_idx < 0 or target_idx >= output.shape[1]:
                raise ValueError(f"target_idx {target_idx} out of range for output dim {output.shape[1]}")
            model.zero_grad(set_to_none=True)
            output[0, target_idx].backward()

        # Ensure hooks fired
        if not activations or not gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")
        A = activations[-1]  # [1, C, H, W]
        dA = gradients[-1]  # [1, C, H, W]

        # Weights: global-average-pool the gradients
        weights = dA.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        cam = torch.relu((weights * A).sum(dim=1))[0]  # [H, W]

        # Normalize to [0,1]; an all-zero map is left as zeros.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return _render_cam_overlay(img_pil, cam.detach().cpu().numpy())
    finally:
        handle.remove()
        # backward() also filled .grad on every model parameter; drop those
        # tensors so they do not stay allocated between requests.
        model.zero_grad(set_to_none=True)
def run_inference_with_gradcam(img: Image.Image) -> Dict[str, Any]:
    """Classify ``img`` and attach a Grad-CAM overlay for the winning class.

    Returns a dict with the predicted ``class`` label, its ``confidence``
    (softmax probability), the full ``probabilities`` mapping (label ->
    probability) and ``gradcam``, a PNG data URL from ``generate_gradcam``.
    """
    # Gradient-free forward pass for the probability vector.
    batch = preprocess_image(img).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits[0], dim=0).detach().cpu()

    top_prob, top_idx = torch.max(probs, dim=0)
    class_index = int(top_idx)
    label = classes[class_index]

    # Explain the class that was just predicted (runs its own forward pass).
    overlay_url = generate_gradcam(img, class_index)

    per_class = dict(zip(classes, map(float, probs.tolist())))
    return {
        "class": label,
        "confidence": float(top_prob),
        "probabilities": per_class,
        "gradcam": overlay_url,
    }
def run_inference(img: Image.Image) -> Dict[str, Any]:
    """Classify ``img`` without computing Grad-CAM.

    Returns a dict with the predicted ``class`` label, its ``confidence``
    (softmax probability) and the ``probabilities`` mapping for every label.
    """
    batch = preprocess_image(img).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits[0], dim=0).detach().cpu()

    best = torch.max(probs, dim=0)  # (values, indices) pair
    label = classes[int(best[1])]

    result: Dict[str, Any] = {}
    result["class"] = label
    result["confidence"] = float(best[0])
    result["probabilities"] = dict(zip(classes, map(float, probs.tolist())))
    return result
# --- Existing upload classification ---
@app.post("/predict_AlexNet")
def predict_alexnet() -> Any:
    """Classify an uploaded image and return the prediction with its Grad-CAM.

    Expects a multipart form upload with the file in field ``image``.

    Responses (all JSON):
        200: the dict produced by ``run_inference_with_gradcam``.
        400: the field is missing, the upload is empty, or the bytes cannot
            be decoded as an image.
        500: the image decoded but inference / Grad-CAM failed server-side.
    """
    file = request.files.get("image")
    if file is None:
        return jsonify({"error": "Missing file field 'image'."}), 400
    if not file:
        return jsonify({"error": "Empty file."}), 400

    # Decoding problems are the client's fault -> 400.
    try:
        img = load_image(file.stream)
    except Exception as e:
        return jsonify({"error": f"Failed to process image: {e}"}), 400

    # Anything failing past this point is a server-side problem -> 500, and
    # the traceback is logged instead of being silently discarded.
    try:
        result = run_inference_with_gradcam(img)
    except Exception as e:
        app.logger.exception("Inference failed for uploaded image")
        return jsonify({"error": f"Failed to process image: {e}"}), 500
    return jsonify(result)
# --- NEW: classify a preset image ---
@app.post("/predict_preset")
def predict_preset() -> Any:
    """Classify one of the server-side preset images.

    Expects a JSON object body such as ``{"preset": "TP"}``; the label is
    matched case-insensitively against ``PRESET_MAP``.

    Responses (all JSON):
        200: the dict from ``run_inference_with_gradcam`` plus ``preset``
            (normalized label) and ``path`` (server path of the image).
        400: the body is not a JSON object with a ``preset`` field, or the
            label is unknown.
        404: the configured preset file does not exist on the server.
        500: the preset image could not be loaded or inference failed.
    """
    # force=True ignores the Content-Type header; silent=True yields None on
    # malformed JSON instead of raising.
    payload = request.get_json(force=True, silent=True)
    # Require a JSON *object*: a list or string body could otherwise pass the
    # membership test and then fail on payload["preset"].
    if not isinstance(payload, dict) or "preset" not in payload:
        return jsonify({"error": "Missing JSON field 'preset' (TP|TN|FN|FP)."}), 400

    key = str(payload["preset"]).upper()
    if key not in PRESET_MAP:
        return jsonify({"error": f"Invalid preset '{key}'. Use one of: TP, TN, FN, FP."}), 400

    path = PRESET_MAP[key]
    if not os.path.exists(path):
        return jsonify({"error": f"Preset image not found on server: {path}"}), 404

    # The image is owned by the server, so any failure here is a server-side
    # error rather than a bad request; log the traceback for diagnosis.
    try:
        img = load_image(path)
        result = run_inference_with_gradcam(img)
    except Exception as e:
        app.logger.exception("Failed to process preset image %s", path)
        return jsonify({"error": f"Failed to process preset image: {e}"}), 500

    result.update({"preset": key, "path": path})
    return jsonify(result)
# --- NEW: serve preset thumbnails safely ---
@app.get("/preset_image/<label>")
def preset_image(label: str):
    """Serve the image file registered for a preset label.

    The label is upper-cased and must be a key of ``PRESET_MAP``; unknown
    labels and missing files both answer 404. Only paths stored in
    ``PRESET_MAP`` are ever served -- the label itself is never used to build
    a filesystem path.
    """
    key = str(label).upper()
    if key not in PRESET_MAP or not os.path.exists(PRESET_MAP[key]):
        abort(404)

    # send_from_directory takes a directory plus a file name, so split the
    # absolute path into those two parts and let Flask stream the bytes.
    absolute = os.path.abspath(PRESET_MAP[key])
    return send_from_directory(os.path.dirname(absolute), os.path.basename(absolute))
if __name__ == "__main__":
    # Parse FLASK_DEBUG leniently. Numeric strings keep their previous meaning
    # (0 -> off, any other integer -> on); word values such as
    # "true"/"yes"/"on" are accepted too, and anything else (including an
    # empty value) disables debug instead of crashing at startup.
    raw_debug = os.getenv("FLASK_DEBUG", "0")
    try:
        debug = bool(int(raw_debug))
    except ValueError:
        debug = raw_debug.strip().lower() in {"true", "yes", "on"}
    # NOTE(review): HOST is 0.0.0.0, so debug=True exposes the interactive
    # Werkzeug debugger to the network -- never enable it in production.
    app.run(host=HOST, port=PORT, debug=debug)