# Flask inference server for a fine-tuned AlexNet classifier with Grad-CAM
# visualisation, deployed as a Hugging Face Space.
| import os | |
| from typing import Any, Dict | |
| from flask import Flask, jsonify, request, send_from_directory, abort | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from dotenv import load_dotenv | |
| from model_loader import load_alexnet_model, preprocess_image | |
| from flask_cors import CORS | |
| from io import BytesIO | |
| import base64 | |
| import numpy as np | |
# Load .env values; override=True lets the file win over already-exported vars.
load_dotenv(override=True)

# HF sets PORT dynamically. Fall back to 7860 locally.
PORT = int(os.getenv("PORT", os.getenv("FLASK_PORT", "7860")))
HOST = "0.0.0.0"  # bind all interfaces — required inside containers

# Checkpoint consumed by load_alexnet_model below.
MODEL_PATH = os.getenv("MODEL_PATH", "models/alexnext_vsf_bext.pth")

# Preset image paths via ENV
TP_PATH = os.getenv("TP_PATH", "images/TP.jpg")
TN_PATH = os.getenv("TN_PATH", "images/TN.jpg")
FN_PATH = os.getenv("FN_PATH", "images/FN.jpg")
FP_PATH = os.getenv("FP_PATH", "images/FP.jpg")

# Confusion-matrix label -> demo image path, used by the preset endpoints.
PRESET_MAP: Dict[str, str] = {
    "TP": TP_PATH,
    "TN": TN_PATH,
    "FN": FN_PATH,
    "FP": FP_PATH,
}

# Single worker is safest for GPU inference
torch.set_num_threads(1)

# Create app and static hosting; static_url_path="" serves static files at /.
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)  # allow cross-origin calls from the demo frontend

# Device selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model once at startup (shared by all requests; model is set to eval mode)
model, classes = load_alexnet_model(MODEL_PATH, device=DEVICE)
model.to(DEVICE).eval()
@app.route("/")
def root() -> Any:
    """Serve the single-page frontend (static/index.html).

    Fix: the handler was defined but never registered with Flask, so the
    app served nothing at "/". NOTE(review): the decorator appears to have
    been lost in extraction — confirm the original path was "/".
    """
    return send_from_directory(app.static_folder, "index.html")
@app.route("/health")
def health() -> Any:
    """Liveness probe: report status and the inference device.

    Fix: the handler was defined but never registered with Flask.
    NOTE(review): route path "/health" is assumed — confirm against the
    deployment's health-check configuration.
    """
    return jsonify({"status": "ok", "device": str(DEVICE)})
def load_image(file_stream_or_path) -> Image.Image:
    """Open an image from a filesystem path or a binary file-like stream as RGB.

    Fix: the original ``isinstance(file_stream_or_path, str)`` branch executed
    exactly the same call as the fallthrough — ``PIL.Image.open`` natively
    accepts both paths and streams — so the dead branch is removed.

    Raises:
        PIL.UnidentifiedImageError / OSError: if the input is not a readable image.
    """
    return Image.open(file_stream_or_path).convert("RGB")
def generate_gradcam(img_pil: Image.Image, target_idx: int) -> str:
    """
    Returns a data URL (PNG) of the Grad-CAM overlay for the target class.

    Runs a full forward + backward pass through the module-level ``model``,
    capturing the last conv layer's activations and their gradients via hooks,
    then composites a red heatmap (CAM as alpha) over the original image.

    Args:
        img_pil: Input image (any size; resized internally by preprocess_image).
        target_idx: Class index whose score is back-propagated.

    Returns:
        "data:image/png;base64,..." string for direct use in an <img> src.

    Raises:
        ValueError: if target_idx is outside the model's output dimension.
        RuntimeError: if the hooks failed to capture activations/gradients.
    """
    model.eval()
    orig_w, orig_h = img_pil.size
    # Last conv of standard AlexNet
    target_layer = model.features[10]
    activations = []
    gradients = []

    def fwd_hook(_, __, out):
        # Save activations (detached) and attach a tensor hook to capture gradients
        activations.append(out.detach())
        out.register_hook(lambda g: gradients.append(g.detach().clone()))

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        # Forward — note: grad must stay enabled here (no torch.no_grad()),
        # otherwise the tensor hook never fires.
        input_tensor = preprocess_image(img_pil).to(DEVICE)
        output = model(input_tensor)  # [1, C]
        # Backward on the selected class
        if target_idx < 0 or target_idx >= output.shape[1]:
            raise ValueError(f"target_idx {target_idx} out of range for output dim {output.shape[1]}")
        score = output[0, target_idx]
        model.zero_grad(set_to_none=True)
        score.backward()
        # Ensure hooks fired
        if not activations or not gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")
        A = activations[-1]   # [1, C, H, W]
        dA = gradients[-1]    # [1, C, H, W]
        # Weights: global-average-pool the gradients
        weights = dA.mean(dim=(2, 3), keepdim=True)    # [1, C, 1, 1]
        cam = (weights * A).sum(dim=1, keepdim=False)  # [1, H, W]
        cam = torch.relu(cam)[0]  # [H, W]
        # Normalize to [0,1]; guard against an all-zero CAM before dividing
        cam -= cam.min()
        if cam.max() > 0:
            cam /= cam.max()
        # Resize CAM to original image size
        cam_np = cam.detach().cpu().numpy()
        cam_img = Image.fromarray((cam_np * 255).astype(np.uint8), mode="L")
        cam_img = cam_img.resize((orig_w, orig_h), resample=Image.BILINEAR)
        # Red alpha overlay: CAM intensity becomes the alpha channel
        heat_rgba = Image.new("RGBA", (orig_w, orig_h), (255, 0, 0, 0))
        heat_rgba.putalpha(cam_img)
        base = img_pil.convert("RGBA")
        overlayed = Image.alpha_composite(base, heat_rgba)
        # Encode to data URL
        buff = BytesIO()
        overlayed.save(buff, format="PNG")
        b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"
    finally:
        handle.remove()  # <-- remove the actual handle you registered
def run_inference_with_gradcam(img: Image.Image) -> Dict[str, Any]:
    """Run softmax inference and also compute Grad-CAM for the predicted class.

    Fix: the original duplicated ``run_inference``'s preprocessing/softmax/
    argmax code verbatim; it now delegates to it so the two code paths cannot
    drift apart, then augments the result with the Grad-CAM overlay.

    Returns:
        dict with "class", "confidence", "probabilities" (per-class floats)
        and "gradcam" (PNG data URL for the argmax class).
    """
    result = run_inference(img)
    # Recover the predicted index from the class label (classes is the
    # sequence returned by load_alexnet_model; labels are assumed unique).
    pred_idx = list(classes).index(result["class"])
    result["gradcam"] = generate_gradcam(img, pred_idx)
    return result
def run_inference(img: Image.Image) -> Dict[str, Any]:
    """Classify *img* with the module-level model.

    Returns a dict with the predicted "class", its "confidence", and the
    full per-class "probabilities" mapping.
    """
    batch = preprocess_image(img).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits[0], dim=0).detach().cpu()
    prob_values = probs.tolist()
    top = int(torch.argmax(probs))
    return {
        "class": classes[top],
        "confidence": float(prob_values[top]),
        "probabilities": {name: float(p) for name, p in zip(classes, prob_values)},
    }
# --- Existing upload classification ---
@app.route("/predict", methods=["POST"])
def predict_alexnet() -> Any:
    """Classify an uploaded image (multipart field "image") with Grad-CAM.

    Fix: the handler was defined but never registered with Flask.
    NOTE(review): the decorator was likely lost in extraction — confirm the
    path "/predict" against the frontend's fetch calls.

    Returns:
        200 JSON from run_inference_with_gradcam, or 400 JSON {"error": ...}.
    """
    if "image" not in request.files:
        return jsonify({"error": "Missing file field 'image'."}), 400
    file = request.files["image"]
    if not file:
        return jsonify({"error": "Empty file."}), 400
    try:
        img = load_image(file.stream)
        result = run_inference_with_gradcam(img)  # << changed
        return jsonify(result)
    except Exception as e:
        # Broad catch is deliberate at this HTTP boundary: any decode or
        # inference failure becomes a 400 instead of a 500.
        return jsonify({"error": f"Failed to process image: {e}"}), 400
# --- NEW: classify a preset image ---
@app.route("/predict_preset", methods=["POST"])
def predict_preset() -> Any:
    """Classify one of the bundled preset images (TP/TN/FN/FP) with Grad-CAM.

    Fix: the handler was defined but never registered with Flask.
    NOTE(review): path "/predict_preset" is assumed — confirm against the
    frontend.

    Expects JSON {"preset": "TP"|"TN"|"FN"|"FP"} (case-insensitive).
    """
    try:
        # force=True parses JSON regardless of Content-Type; a malformed
        # body raises and is treated as a missing payload below.
        payload = request.get_json(force=True, silent=False)
    except Exception:
        payload = None
    if not payload or "preset" not in payload:
        return jsonify({"error": "Missing JSON field 'preset' (TP|TN|FN|FP)."}), 400
    key = str(payload["preset"]).upper()
    if key not in PRESET_MAP:
        return jsonify({"error": f"Invalid preset '{key}'. Use one of: TP, TN, FN, FP."}), 400
    path = PRESET_MAP[key]
    if not os.path.exists(path):
        return jsonify({"error": f"Preset image not found on server: {path}"}), 404
    try:
        img = load_image(path)
        result = run_inference_with_gradcam(img)  # << changed
        result.update({"preset": key, "path": path})
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": f"Failed to process preset image: {e}"}), 400
# --- NEW: serve preset thumbnails safely ---
@app.route("/preset-image/<label>")
def preset_image(label: str):
    """Serve the thumbnail for a preset label (TP/TN/FN/FP), 404 otherwise.

    Fix: the handler was defined but never registered with Flask; its
    ``label`` parameter implies a URL path variable. NOTE(review): path
    "/preset-image/<label>" is assumed — confirm against the frontend.

    Only paths from PRESET_MAP are ever served (no user-controlled paths),
    and send_from_directory guards against traversal within the directory.
    """
    key = str(label).upper()
    if key not in PRESET_MAP:
        abort(404)
    path = PRESET_MAP[key]
    if not os.path.exists(path):
        abort(404)
    directory, filename = os.path.split(os.path.abspath(path))
    # Let Flask serve the actual file bytes
    return send_from_directory(directory, filename)
if __name__ == "__main__":
    # FLASK_DEBUG must be "0"/"1" (or another int literal); bool(int(...))
    # raises ValueError on anything else, which is a loud, early failure.
    debug = bool(int(os.getenv("FLASK_DEBUG", "0")))
    # Flask's built-in dev server; fine for a single-worker HF Space.
    app.run(host=HOST, port=PORT, debug=debug)