Spaces:
Sleeping
Sleeping
| from flask import send_from_directory | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| from flask import Flask, jsonify, render_template, request | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
# ----------------------------
# Model definitions
# ----------------------------
@dataclass(frozen=True)
class ModelSpec:
    """Static, immutable description of one deployable classifier.

    Field order is significant: specs may be constructed positionally.
    """

    id: str  # stable identifier; clients select a model by sending it as ``model_id``
    display_name: str  # what the user sees (friendly + technical)
    filename: str  # model file name under ./models/
    arch: str  # "resnet" | "efficientnet" — selects the preprocessing function
    img_size: int  # square input resolution (images are resized to img_size x img_size)
    class_names: Tuple[str, ...]  # output order used during training
    recommended_threshold: float  # per-model uncertainty cutoff (from notebooks)
# NOTE:
# The ResNet training notebooks derived their class order from the sorted unique
# category names, while the EfficientNet notebook used an explicit list. Every model
# therefore carries its own class order so probabilities are never mislabeled.
RESNET_CLASS_ORDER = ("MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented")
EFFICIENTNET_CLASS_ORDER = ("NonDemented", "VeryMildDemented", "MildDemented", "ModerateDemented")

# Registry of selectable models; the first entry is used as the default model.
MODEL_SPECS: List[ModelSpec] = [
    ModelSpec(
        id="atlas",
        display_name="Atlas — ResNet-50",
        filename="resnet50.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="orion",
        display_name="Orion — ResNet-101",
        filename="resnet101.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="pulse",
        display_name="Pulse — EfficientNet-B2",
        filename="efficientnetb2.h5",
        arch="efficientnet",
        img_size=260,
        class_names=EFFICIENTNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
]
# ----------------------------
# Flask app
# ----------------------------
app = Flask(import_name=__name__)

# Cache of loaded Keras models keyed by ModelSpec.id. Models are loaded lazily on
# first request, so a CPU-only Space keeps only the models that were actually used.
_loaded_models: Dict[str, tf.keras.Model] = dict()
def _get_spec(model_id: str) -> ModelSpec:
    """Return the registered ModelSpec whose ``id`` equals *model_id*.

    Raises:
        KeyError: if no entry in MODEL_SPECS has that id.
    """
    found = next((spec for spec in MODEL_SPECS if spec.id == model_id), None)
    if found is None:
        raise KeyError(f"Unknown model_id: {model_id}")
    return found
| def _get_preprocess_fn(arch: str): | |
| if arch == "resnet": | |
| from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess | |
| return resnet_preprocess | |
| if arch == "efficientnet": | |
| from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess | |
| return eff_preprocess | |
| raise ValueError(f"Unknown arch: {arch}") | |
def _load_model(spec: ModelSpec) -> tf.keras.Model:
    """Return the Keras model described by *spec*, loading it from disk on first use.

    The loaded model is stored in ``_loaded_models`` under ``spec.id`` so later
    calls reuse the same instance.

    Raises:
        FileNotFoundError: if ``models/<spec.filename>`` does not exist next to this module.
    """
    if spec.id in _loaded_models:
        return _loaded_models[spec.id]

    models_dir = os.path.join(os.path.dirname(__file__), "models")
    model_path = os.path.join(models_dir, spec.filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}. "
            f"Place it at models/{spec.filename} in your Space."
        )

    # Best-effort CPU thread configuration; any failure here is deliberately ignored.
    # NOTE(review): per the TF docs, 0 means "let the system pick", which may make
    # these calls a no-op — confirm whether they are still wanted.
    try:
        tf.config.threading.set_intra_op_parallelism_threads(0)
        tf.config.threading.set_inter_op_parallelism_threads(0)
    except Exception:
        pass

    # compile=False: only inference is performed, so no training config is restored.
    model = load_model(model_path, compile=False)
    _loaded_models[spec.id] = model
    return model
def _read_image(file_storage, img_size: int, preprocess_fn):
    """Decode an uploaded image into a single-image input batch for the model.

    Args:
        file_storage: the uploaded file object; only its ``read()`` method is used.
        img_size: target height and width; the image is resized to a square.
        preprocess_fn: architecture-specific normalisation applied after resizing.

    Returns:
        A batched tensor, presumably of shape [1, img_size, img_size, 3]
        (assumes *preprocess_fn* preserves the shape).
    """
    pixels = tf.io.decode_image(file_storage.read(), channels=3, expand_animations=False)
    pixels = tf.cast(tf.image.resize(pixels, [img_size, img_size]), tf.float32)
    batch = tf.expand_dims(preprocess_fn(pixels), axis=0)  # prepend the batch axis
    return batch
def _predict(model: tf.keras.Model, image_tensor, class_names: Tuple[str, ...], threshold: float):
    """Run *model* on a one-image batch and build the JSON-ready result payload.

    Args:
        model: Keras classifier; the first row of its ``predict`` output is read
            as one score per class.
        image_tensor: single-image input batch for the model.
        class_names: labels in the model's output order.
        threshold: minimum top-class probability; below it the prediction is
            reported as "Uncertain".

    Returns:
        A dict with "prediction" (id, label, confidence, threshold) and
        "probabilities" (one entry per class, sorted by descending probability).

    Raises:
        ValueError: if the number of model scores differs from ``len(class_names)``.
    """
    raw_scores = model.predict(image_tensor, verbose=0)[0]
    probs = np.asarray(raw_scores, dtype=float).ravel()
    if probs.size == 0 or probs.size != len(class_names):
        raise ValueError(
            f"Model produced {probs.size} scores but {len(class_names)} class names "
            "are configured; check the model spec's class_names."
        )

    # Non-finite scores would slip past the threshold test below (NaN < t is False,
    # so a NaN result would be reported as a confident class) and NaN is not valid
    # JSON, so coerce them to finite values before clipping into [0, 1].
    probs = np.nan_to_num(probs, nan=0.0, posinf=1.0, neginf=0.0)
    probs = np.clip(probs, 0.0, 1.0)

    best_idx = int(np.argmax(probs))
    best_prob = float(probs[best_idx])

    # Add "Uncertain" post-hoc (not a model output class)
    is_uncertain = best_prob < threshold

    by_class = [
        {"id": name, "label": _pretty_label(name), "prob": float(probs[i])}
        for i, name in enumerate(class_names)
    ]
    by_class.sort(key=lambda x: x["prob"], reverse=True)

    return {
        "prediction": {
            "id": "Uncertain" if is_uncertain else class_names[best_idx],
            "label": "Uncertain" if is_uncertain else _pretty_label(class_names[best_idx]),
            "confidence": best_prob,
            "threshold": threshold,
        },
        "probabilities": by_class,
    }
| def _pretty_label(name: str) -> str: | |
| # Internal training labels -> user-facing labels (final wording) | |
| mapping = { | |
| "NonDemented": "Healthy", | |
| "VeryMildDemented": "Very Mildly Demented", | |
| "MildDemented": "Mildly Demented", | |
| "ModerateDemented": "Moderately Demented", | |
| # Post-hoc | |
| "Uncertain": "Uncertain", | |
| } | |
| return mapping.get(name, name) | |
| from flask import send_from_directory | |
| import os | |
@app.route("/", defaults={"path": ""})
@app.route("/<path:path>")
def serve_frontend(path: str = ""):
    """Serve the built React frontend from ``static/clarity``.

    Existing static files under that directory are returned as-is. Any other
    path (the root URL, client-side routes, missing assets) gets index.html so
    the single-page app can route it. Paths under ``api/`` are never handled
    here, so unknown API URLs get a plain 404 rather than the HTML shell.
    """
    # Keep API routes intact
    if path.startswith("api/"):
        return ("Not Found", 404)

    clarity_dir = os.path.join(app.root_path, "static", "clarity")

    if path:
        root = os.path.abspath(clarity_dir)
        requested = os.path.abspath(os.path.join(root, path))
        # Only look for real files that live inside the frontend directory, so a
        # "../" path cannot be used to probe the existence of files elsewhere.
        if os.path.commonpath([root, requested]) == root and os.path.isfile(requested):
            return send_from_directory(clarity_dir, path)

    # Otherwise serve React index.html
    return send_from_directory(clarity_dir, "index.html")
@app.route("/api/models")
def api_models():
    """Describe every registered model and the default model id as JSON.

    Each entry exposes the model's id, display name, input size, class list
    (internal id plus user-facing label) and its recommended threshold.
    """
    return jsonify({
        "models": [
            {
                "id": s.id,
                "name": s.display_name,
                "img_size": s.img_size,
                "classes": [{"id": c, "label": _pretty_label(c)} for c in s.class_names],
                "recommended_threshold": s.recommended_threshold,
            }
            for s in MODEL_SPECS
        ],
        # The first registered spec is the default.
        "default_model_id": MODEL_SPECS[0].id,
    })
@app.route("/api/classify", methods=["POST"])
def api_classify():
    """Classify an uploaded image with the requested model.

    Expects multipart form data with the image in field ``file`` and an optional
    ``model_id`` (defaults to the first registered model).

    Returns the _predict payload plus a "model" entry. Errors are JSON
    ``{"error": ...}``: 400 for client mistakes (missing file, unknown model id,
    undecodable image) and 500 for server-side failures.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded (field name must be 'file')."}), 400

    model_id = request.form.get("model_id", MODEL_SPECS[0].id)
    try:
        spec = _get_spec(model_id)
    except KeyError:
        # A bad model_id is a client error; without this it escaped as an HTML 500.
        return jsonify({"error": f"Unknown model_id: {model_id}"}), 400

    # Threshold is model-specific and not user-adjustable
    threshold = spec.recommended_threshold

    try:
        model = _load_model(spec)
        preprocess_fn = _get_preprocess_fn(spec.arch)
        try:
            image_tensor = _read_image(request.files["file"], spec.img_size, preprocess_fn)
        except tf.errors.InvalidArgumentError as e:
            # Unsupported or corrupt upload: the client's problem, not the server's.
            return jsonify({"error": f"Could not decode the uploaded image: {e.message}"}), 400
        payload = _predict(model, image_tensor, spec.class_names, threshold)
        payload["model"] = {"id": spec.id, "name": spec.display_name}
        return jsonify(payload)
    except FileNotFoundError as e:
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        # Keep the traceback in the server log; the client only sees the summary.
        app.logger.exception("Classification failed for model %s", spec.id)
        return jsonify({"error": f"Failed to classify image: {e}"}), 500
if __name__ == "__main__":
    # Local development entry point (`python app.py`). In Spaces the Dockerfile
    # serves the app with gunicorn, so this branch is not used there.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=False)