"""Flask service that classifies images with AlexNet and returns a Grad-CAM overlay."""

import base64
import os
from io import BytesIO
from typing import Any, Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from flask import Flask, abort, jsonify, request, send_from_directory
from flask_cors import CORS
from PIL import Image

from model_loader import load_alexnet_model, preprocess_image

load_dotenv(override=True)

# HF sets PORT dynamically. Fall back to 7860 locally.
PORT = int(os.getenv("PORT", os.getenv("FLASK_PORT", "7860")))
HOST = "0.0.0.0"
MODEL_PATH = os.getenv("MODEL_PATH", "models/alexnext_vsf_bext.pth")

# Preset image paths via ENV
TP_PATH = os.getenv("TP_PATH", "images/TP.jpg")
TN_PATH = os.getenv("TN_PATH", "images/TN.jpg")
FN_PATH = os.getenv("FN_PATH", "images/FN.jpg")
FP_PATH = os.getenv("FP_PATH", "images/FP.jpg")

PRESET_MAP: Dict[str, str] = {
    "TP": TP_PATH,
    "TN": TN_PATH,
    "FN": FN_PATH,
    "FP": FP_PATH,
}

# Restrict PyTorch to a single intra-op CPU thread (the original author's
# rationale: a single worker is safest for GPU inference).
torch.set_num_threads(1)

# Create app and static hosting
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)

# Device selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model once at startup
model, classes = load_alexnet_model(MODEL_PATH, device=DEVICE)
model.to(DEVICE).eval()


@app.get("/")
def root() -> Any:
    """Serve the front-end entry page from the static folder."""
    return send_from_directory(app.static_folder, "index.html")


@app.get("/health")
def health() -> Any:
    """Report liveness and the torch device selected at startup."""
    return jsonify({"status": "ok", "device": str(DEVICE)})


def load_image(file_stream_or_path):
    """Open an image from a filesystem path or a file-like object as RGB.

    The source image is opened in a ``with`` block so that a file handle
    Pillow opened itself (the path case) is closed as soon as the pixel data
    has been converted; a caller-supplied stream is left open for the caller.

    Raises whatever ``PIL.Image.open`` / ``convert`` raise for missing or
    undecodable input.
    """
    with Image.open(file_stream_or_path) as source:
        # convert() loads the pixel data and returns an independent image.
        return source.convert("RGB")


def generate_gradcam(img_pil: Image.Image, target_idx: int) -> str:
    """Return a data URL (PNG) of the Grad-CAM overlay for the target class.

    Args:
        img_pil: Image to explain; the overlay is rendered at its size.
        target_idx: Index of the output logit to back-propagate from.

    Returns:
        A ``data:image/png;base64,...`` string with a red heat map
        alpha-composited over the input image.

    Raises:
        ValueError: ``target_idx`` is outside the model's output dimension.
        RuntimeError: The hooks captured no activations or gradients.
    """
    model.eval()
    orig_w, orig_h = img_pil.size

    # Last conv of standard AlexNet
    target_layer = model.features[10]

    activations = []
    gradients = []

    def fwd_hook(_module, _inputs, out):
        # Save activations (detached) and attach a tensor hook to capture
        # the gradient flowing back into this layer's output.
        activations.append(out.detach())
        out.register_hook(lambda g: gradients.append(g.detach().clone()))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        # Forward
        input_tensor = preprocess_image(img_pil).to(DEVICE)
        output = model(input_tensor)  # [1, C]

        # Backward on the selected class
        if target_idx < 0 or target_idx >= output.shape[1]:
            raise ValueError(f"target_idx {target_idx} out of range for output dim {output.shape[1]}")
        score = output[0, target_idx]
        model.zero_grad(set_to_none=True)
        score.backward()

        # Ensure hooks fired
        if not activations or not gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

        A = activations[-1]  # [1, C, H, W]
        dA = gradients[-1]  # [1, C, H, W]

        # Weights: global-average-pool the gradients
        weights = dA.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        cam = (weights * A).sum(dim=1, keepdim=False)  # [1, H, W]
        cam = torch.relu(cam)[0]  # [H, W]

        # Normalize to [0,1]; skip the division when the map is all zeros.
        cam -= cam.min()
        if cam.max() > 0:
            cam /= cam.max()

        # Resize CAM to original image size. A 2-D uint8 array is read as
        # mode "L" automatically, so no explicit mode argument is needed.
        cam_np = cam.detach().cpu().numpy()
        cam_img = Image.fromarray((cam_np * 255).astype(np.uint8))
        cam_img = cam_img.resize((orig_w, orig_h), resample=Image.BILINEAR)

        # Red alpha overlay: solid red whose opacity follows the CAM.
        heat_rgba = Image.new("RGBA", (orig_w, orig_h), (255, 0, 0, 0))
        heat_rgba.putalpha(cam_img)

        base = img_pil.convert("RGBA")
        overlayed = Image.alpha_composite(base, heat_rgba)

        # Encode to data URL
        buff = BytesIO()
        overlayed.save(buff, format="PNG")
        b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"
    finally:
        # Always detach the forward hook so later inference is unaffected,
        # and drop the parameter gradients that backward() accumulated.
        handle.remove()
        model.zero_grad(set_to_none=True)


def _classify(img: Image.Image) -> Tuple[Dict[str, Any], int]:
    """Run a no-grad forward pass; return the result dict and winning index."""
    input_tensor = preprocess_image(img).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output[0], dim=0).detach().cpu()

    pred_prob, pred_idx = torch.max(probabilities, dim=0)
    index = int(pred_idx)
    result = {
        "class": classes[index],
        "confidence": float(pred_prob),
        "probabilities": {cls: float(prob) for cls, prob in zip(classes, probabilities.tolist())},
    }
    return result, index


def run_inference_with_gradcam(img: Image.Image) -> Dict[str, Any]:
    """Run softmax inference and also compute Grad-CAM for the predicted class.

    Returns the same keys as ``run_inference`` plus ``"gradcam"``, a PNG data
    URL produced by ``generate_gradcam`` for the predicted index.
    """
    result, pred_idx = _classify(img)
    result["gradcam"] = generate_gradcam(img, pred_idx)
    return result


def run_inference(img: Image.Image) -> Dict[str, Any]:
    """Run softmax inference.

    Returns a dict with ``"class"`` (predicted label), ``"confidence"``
    (its probability) and ``"probabilities"`` (label -> probability).
    """
    result, _ = _classify(img)
    return result


# --- Existing upload classification ---
@app.post("/predict_AlexNet")
def predict_alexnet() -> Any:
    """Classify an uploaded image sent as multipart field ``image``.

    Responds 400 when the field is missing/empty or the upload cannot be
    decoded as an image, and 500 when inference itself fails.
    """
    if "image" not in request.files:
        return jsonify({"error": "Missing file field 'image'."}), 400

    file = request.files["image"]
    if not file:
        return jsonify({"error": "Empty file."}), 400

    # Decoding untrusted bytes can fail in many ways; all of them are the
    # client's problem, so report them as a bad request.
    try:
        img = load_image(file.stream)
    except Exception as e:
        return jsonify({"error": f"Failed to process image: {e}"}), 400

    # Anything failing past this point is a server-side fault.
    try:
        result = run_inference_with_gradcam(img)
    except Exception as e:
        app.logger.exception("Inference failed for uploaded image")
        return jsonify({"error": f"Failed to process image: {e}"}), 500
    return jsonify(result)


# --- NEW: classify a preset image ---
@app.post("/predict_preset")
def predict_preset() -> Any:
    """Classify one of the server-side preset images.

    Expects a JSON object with a ``preset`` field naming a key of
    ``PRESET_MAP`` (case-insensitive). Responds 400 for a missing/invalid
    request, 404 when the preset file is absent, and 500 when loading or
    inference fails.
    """
    # silent=True yields None for unparseable bodies instead of raising.
    payload = request.get_json(force=True, silent=True)

    # Valid JSON that is not an object (list, string, number) has no fields.
    if not isinstance(payload, dict) or "preset" not in payload:
        return jsonify({"error": "Missing JSON field 'preset' (TP|TN|FN|FP)."}), 400

    key = str(payload["preset"]).upper()
    if key not in PRESET_MAP:
        return jsonify({"error": f"Invalid preset '{key}'. Use one of: TP, TN, FN, FP."}), 400

    path = PRESET_MAP[key]
    if not os.path.isfile(path):
        return jsonify({"error": f"Preset image not found on server: {path}"}), 404

    # The preset file is server-owned, so any failure here is a server fault.
    try:
        img = load_image(path)
        result = run_inference_with_gradcam(img)
    except Exception as e:
        app.logger.exception("Failed to process preset image %s", path)
        return jsonify({"error": f"Failed to process preset image: {e}"}), 500

    result.update({"preset": key, "path": path})
    return jsonify(result)


# --- NEW: serve preset thumbnails safely --- @app.get("/preset_image/