"""Flask service exposing dementia-stage MRI classifiers.

Serves a built React frontend plus a small JSON API backed by lazily
loaded Keras models (ResNet-50/101 and EfficientNet-B2).
"""

import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import tensorflow as tf
# Consolidated: the original imported from flask twice; all names kept.
from flask import Flask, jsonify, render_template, request, send_from_directory
from tensorflow.keras.models import load_model

# ----------------------------
# Model definitions
# ----------------------------


@dataclass(frozen=True)
class ModelSpec:
    """Static description of one deployable classifier."""

    id: str                          # stable identifier used by the API
    display_name: str                # what the user sees (friendly + technical)
    filename: str                    # under ./models/
    arch: str                        # "resnet" | "efficientnet"
    img_size: int                    # input resolution
    class_names: Tuple[str, ...]     # output order used during training
    recommended_threshold: float     # per-model uncertainty cutoff (from notebooks)


# NOTE:
# The training notebooks use *different* class ordering between the ResNet
# notebooks (sorted unique categories) and the EfficientNet notebook
# (explicit list). We keep per-model class order to avoid mislabeling
# probabilities.
RESNET_CLASS_ORDER = ("MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented")
EFFICIENTNET_CLASS_ORDER = ("NonDemented", "VeryMildDemented", "MildDemented", "ModerateDemented")

MODEL_SPECS: List[ModelSpec] = [
    ModelSpec("atlas", "Atlas — ResNet-50", "resnet50.h5", "resnet", 224,
              RESNET_CLASS_ORDER, 0.95),
    ModelSpec("orion", "Orion — ResNet-101", "resnet101.h5", "resnet", 224,
              RESNET_CLASS_ORDER, 0.95),
    ModelSpec("pulse", "Pulse — EfficientNet-B2", "efficientnetb2.h5", "efficientnet", 260,
              EFFICIENTNET_CLASS_ORDER, 0.95),
]

# ----------------------------
# Flask app
# ----------------------------
app = Flask(__name__)

# Lazy-loaded models (load on first use). Keep only what we need in CPU Spaces.
# model_id -> loaded Keras model; populated on first request so startup stays fast.
_loaded_models: Dict[str, tf.keras.Model] = {}


def _get_spec(model_id: str) -> ModelSpec:
    """Return the ModelSpec whose id matches *model_id*.

    Raises:
        KeyError: if no spec with that id exists.
    """
    for spec in MODEL_SPECS:
        if spec.id == model_id:
            return spec
    raise KeyError(f"Unknown model_id: {model_id}")


def _get_preprocess_fn(arch: str):
    """Return the Keras preprocessing function for *arch*.

    Imported lazily so the heavy application modules are only pulled in
    when actually used.

    Raises:
        ValueError: for an unrecognized architecture name.
    """
    if arch == "resnet":
        from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
        return resnet_preprocess
    if arch == "efficientnet":
        from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
        return eff_preprocess
    raise ValueError(f"Unknown arch: {arch}")


def _load_model(spec: ModelSpec) -> tf.keras.Model:
    """Load (and memoize) the Keras model described by *spec*.

    Raises:
        FileNotFoundError: if the .h5 file is missing from ./models/.
    """
    if spec.id in _loaded_models:
        return _loaded_models[spec.id]

    model_path = os.path.join(os.path.dirname(__file__), "models", spec.filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}. "
            f"Place it at models/{spec.filename} in your Space."
        )

    # CPU-friendly TF settings (small wins on free Spaces); 0 lets TF choose.
    try:
        tf.config.threading.set_intra_op_parallelism_threads(0)
        tf.config.threading.set_inter_op_parallelism_threads(0)
    except Exception:
        # Best-effort: thread settings can't change once TF has initialized.
        pass

    model = load_model(model_path, compile=False)
    _loaded_models[spec.id] = model
    return model


def _read_image(file_storage, img_size: int, preprocess_fn):
    """Decode an uploaded image into a preprocessed [1, H, W, 3] float tensor."""
    raw = file_storage.read()
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.resize(image, [img_size, img_size])
    image = tf.cast(image, tf.float32)
    image = preprocess_fn(image)
    return tf.expand_dims(image, axis=0)  # [1, H, W, 3]


def _predict(model: tf.keras.Model, image_tensor, class_names: Tuple[str, ...], threshold: float):
    """Run inference and build the JSON-serializable response payload.

    The "Uncertain" label is added post-hoc when the top probability falls
    below *threshold*; it is not a model output class.
    """
    probs = model.predict(image_tensor, verbose=0)[0].astype(float)
    probs = np.clip(probs, 0.0, 1.0)

    best_idx = int(np.argmax(probs))
    best_prob = float(np.max(probs))
    is_uncertain = best_prob < threshold

    # Per-class probabilities, sorted most-likely first for the UI.
    by_class = [
        {"id": name, "label": _pretty_label(name), "prob": float(probs[i])}
        for i, name in enumerate(class_names)
    ]
    by_class.sort(key=lambda x: x["prob"], reverse=True)

    return {
        "prediction": {
            "id": "Uncertain" if is_uncertain else class_names[best_idx],
            "label": "Uncertain" if is_uncertain else _pretty_label(class_names[best_idx]),
            "confidence": best_prob,
            "threshold": threshold,
        },
        "probabilities": by_class,
    }


def _pretty_label(name: str) -> str:
    """Map internal training labels to user-facing wording (final wording)."""
    mapping = {
        "NonDemented": "Healthy",
        "VeryMildDemented": "Very Mildly Demented",
        "MildDemented": "Mildly Demented",
        "ModerateDemented": "Moderately Demented",
        # Post-hoc
        "Uncertain": "Uncertain",
    }
    return mapping.get(name, name)


# (Removed a redundant mid-file `from flask import send_from_directory` /
# `import os` — both are already imported at the top of the module.)

# BUGFIX: the catch-all route was registered as @app.route("/") instead of
# @app.route("/<path:path>"), so non-root URLs never reached this handler
# and the duplicate "/" rule conflicted with the first decorator.
@app.route("/", defaults={"path": ""})
@app.route("/<path:path>")
def serve_frontend(path):
    """Serve the built React frontend; fall back to index.html (SPA routing)."""
    # Keep API routes intact
    if path.startswith("api/"):
        return ("Not Found", 404)

    clarity_dir = os.path.join(app.root_path, "static", "clarity")
    requested = os.path.join(clarity_dir, path)

    # Serve real static files if they exist. send_from_directory itself
    # rejects paths that escape clarity_dir.
    if path and os.path.isfile(requested):
        return send_from_directory(clarity_dir, path)

    # Otherwise serve React index.html
    return send_from_directory(clarity_dir, "index.html")


@app.get("/api/models")
def api_models():
    """List the available models plus the default choice for the UI."""
    return jsonify({
        "models": [
            {
                "id": s.id,
                "name": s.display_name,
                "img_size": s.img_size,
                "classes": [{"id": c, "label": _pretty_label(c)} for c in s.class_names],
                "recommended_threshold": s.recommended_threshold,
            }
            for s in MODEL_SPECS
        ],
        "default_model_id": MODEL_SPECS[0].id,
    })


@app.post("/api/classify")
def api_classify():
    """Classify an uploaded MRI with the requested model.

    Expects multipart form data with a 'file' field and an optional
    'model_id'; responds with the prediction payload or a JSON error.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded (field name must be 'file')."}), 400

    model_id = request.form.get("model_id", MODEL_SPECS[0].id)
    try:
        spec = _get_spec(model_id)
    except KeyError as e:
        # BUGFIX: an unknown model_id previously escaped as an unhandled
        # KeyError (bare HTML 500); report it as a JSON client error instead.
        return jsonify({"error": str(e)}), 400

    # Threshold is model-specific and not user-adjustable
    threshold = spec.recommended_threshold

    try:
        model = _load_model(spec)
        preprocess_fn = _get_preprocess_fn(spec.arch)
        image_tensor = _read_image(request.files["file"], spec.img_size, preprocess_fn)
        payload = _predict(model, image_tensor, spec.class_names, threshold)
        payload["model"] = {"id": spec.id, "name": spec.display_name}
        return jsonify(payload)
    except FileNotFoundError as e:
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        return jsonify({"error": f"Failed to classify image: {e}"}), 500


if __name__ == "__main__":
    # Local dev: python app.py
    # In Spaces (Dockerfile), gunicorn is used.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")), debug=False)