diff --git a/Dockerfile b/Dockerfile index 719098724057d06b794a610cb3ec73e109ef35e8..0a8887e4b0bb32b8b4df1064d04b22c3d33a5101 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,6 +25,11 @@ COPY . /app # Hugging Face sets $PORT; Gunicorn will bind to it. ENV PORT=7860 ENV MODEL_PATH=models/alexnext_vsf_bext.pth +ENV CONFUSION_PATH=images/TP.jpg +ENV TP_PATH=images/TP.jpg +ENV TN_PATH=images/TN.jpg +ENV FN_PATH=images/FN.jpg +ENV FP_PATH=images/FP.jpg ENV FLASK_DEBUG=0 # Single worker (GPU inference), thread worker for simplicity diff --git a/__pycache__/model_loader.cpython-39.pyc b/__pycache__/model_loader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ded52a3233645c5830b64e9e0252a0caea3c71ac Binary files /dev/null and b/__pycache__/model_loader.cpython-39.pyc differ diff --git a/app.py b/app.py index 241eafe0dbafc300f0c6b844077a620e05655448..0f7af844a79490b80860a3b009209f7dce5cd013 100644 --- a/app.py +++ b/app.py @@ -1,11 +1,12 @@ import os -from typing import Any -from flask import Flask, jsonify, request, send_from_directory +from typing import Any, Dict +from flask import Flask, jsonify, request, send_from_directory, abort from PIL import Image import torch import torch.nn.functional as F from dotenv import load_dotenv from model_loader import load_alexnet_model, preprocess_image +from flask_cors import CORS load_dotenv(override=True) @@ -14,11 +15,25 @@ PORT = int(os.getenv("PORT", os.getenv("FLASK_PORT", "7860"))) HOST = "0.0.0.0" MODEL_PATH = os.getenv("MODEL_PATH", "models/alexnext_vsf_bext.pth") +# Preset image paths via ENV +TP_PATH = os.getenv("TP_PATH", "images/TP.jpg") +TN_PATH = os.getenv("TN_PATH", "images/TN.jpg") +FN_PATH = os.getenv("FN_PATH", "images/FN.jpg") +FP_PATH = os.getenv("FP_PATH", "images/FP.jpg") + +PRESET_MAP: Dict[str, str] = { + "TP": TP_PATH, + "TN": TN_PATH, + "FN": FN_PATH, + "FP": FP_PATH, +} + # Single worker is safest for GPU inference torch.set_num_threads(1) # Create app and static hosting 
app = Flask(__name__, static_folder="static", static_url_path="") +CORS(app) # Device selection DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -29,48 +44,84 @@ model.to(DEVICE).eval() @app.get("/") def root() -> Any: - # serve your frontend return send_from_directory(app.static_folder, "index.html") @app.get("/health") def health() -> Any: return jsonify({"status": "ok", "device": str(DEVICE)}) -def load_image(file_stream): - return Image.open(file_stream).convert("RGB") - +def load_image(file_stream_or_path): + if isinstance(file_stream_or_path, str): + return Image.open(file_stream_or_path).convert("RGB") + return Image.open(file_stream_or_path).convert("RGB") + +def run_inference(img: Image.Image) -> Dict[str, Any]: + input_tensor = preprocess_image(img).to(DEVICE) + with torch.no_grad(): + output = model(input_tensor) + probabilities = F.softmax(output[0], dim=0).detach().cpu() + pred_prob, pred_idx = torch.max(probabilities, dim=0) + predicted_class = classes[int(pred_idx)] + return { + "class": predicted_class, + "confidence": float(pred_prob), + "probabilities": {cls: float(prob) for cls, prob in zip(classes, probabilities.tolist())}, + } + +# --- Existing upload classification --- @app.post("/predict_AlexNet") def predict_alexnet() -> Any: if "image" not in request.files: return jsonify({"error": "Missing file field 'image'."}), 400 - file = request.files["image"] if not file: return jsonify({"error": "Empty file."}), 400 - try: img = load_image(file.stream) - input_tensor = preprocess_image(img).to(DEVICE) - - with torch.no_grad(): - output = model(input_tensor) - probabilities = F.softmax(output[0], dim=0).detach().cpu() - - pred_prob, pred_idx = torch.max(probabilities, dim=0) - predicted_class = classes[int(pred_idx)] - - result = { - "class": predicted_class, - "confidence": float(pred_prob), - "probabilities": { - cls: float(prob) for cls, prob in zip(classes, probabilities.tolist()) - }, - } + result = run_inference(img) 
return jsonify(result) - except Exception as e: return jsonify({"error": f"Failed to process image: {e}"}), 400 +# --- NEW: classify a preset image --- +@app.post("/predict_preset") +def predict_preset() -> Any: + try: + payload = request.get_json(force=True, silent=False) + except Exception: + payload = None + if not payload or "preset" not in payload: + return jsonify({"error": "Missing JSON field 'preset' (TP|TN|FN|FP)."}), 400 + + key = str(payload["preset"]).upper() + if key not in PRESET_MAP: + return jsonify({"error": f"Invalid preset '{key}'. Use one of: TP, TN, FN, FP."}), 400 + + path = PRESET_MAP[key] + if not os.path.exists(path): + return jsonify({"error": f"Preset image not found on server: {path}"}), 404 + + try: + img = load_image(path) + result = run_inference(img) + result.update({"preset": key, "path": path}) + return jsonify(result) + except Exception as e: + return jsonify({"error": f"Failed to process preset image: {e}"}), 400 + +# --- NEW: serve preset thumbnails safely --- +@app.get("/preset_image/