# app/app.py
# (Hosting-page residue kept for provenance: "itsLu's picture", "Update app.py", commit 5102e58, verified.)
from flask import send_from_directory
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
from flask import Flask, jsonify, render_template, request
import tensorflow as tf
from tensorflow.keras.models import load_model
# ----------------------------
# Model definitions
# ----------------------------
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one classifier that the app can serve.

    Fields, in positional order:
        id: Short identifier used to select the model.
        display_name: What the user sees (friendly + technical name).
        filename: Weights file name, looked up under ./models/.
        arch: Architecture family, "resnet" or "efficientnet".
        img_size: Input resolution (images are resized to img_size x img_size).
        class_names: Output order used during training.
        recommended_threshold: Per-model uncertainty cutoff (from notebooks).
    """

    id: str
    display_name: str
    filename: str
    arch: str
    img_size: int
    class_names: Tuple[str, ...]
    recommended_threshold: float
# NOTE:
# The training notebooks do not share one class ordering: the ResNet notebooks
# use the sorted unique categories, while the EfficientNet notebook uses an
# explicit list. Each model therefore carries its own class order so that a
# probability is never attached to the wrong label.
RESNET_CLASS_ORDER = (
    "MildDemented",
    "ModerateDemented",
    "NonDemented",
    "VeryMildDemented",
)
EFFICIENTNET_CLASS_ORDER = (
    "NonDemented",
    "VeryMildDemented",
    "MildDemented",
    "ModerateDemented",
)

# Registry of servable models. The first entry is what the API uses as the
# default model.
MODEL_SPECS: List[ModelSpec] = [
    ModelSpec(
        id="atlas",
        display_name="Atlas — ResNet-50",
        filename="resnet50.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="orion",
        display_name="Orion — ResNet-101",
        filename="resnet101.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="pulse",
        display_name="Pulse — EfficientNet-B2",
        filename="efficientnetb2.h5",
        arch="efficientnet",
        img_size=260,
        class_names=EFFICIENTNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
]
# ----------------------------
# Flask app
# ----------------------------
# WSGI application object; the route decorators further down register their
# handlers on it.
app = Flask(__name__)
# Lazy-loaded models (load on first use), keyed by ModelSpec.id. Starts empty so
# that only the models actually requested are held in memory on CPU Spaces.
_loaded_models: Dict[str, tf.keras.Model] = {}
def _get_spec(model_id: str) -> ModelSpec:
    """Return the registered ModelSpec whose ``id`` equals ``model_id``.

    Raises:
        KeyError: if no entry in MODEL_SPECS has that id.
    """
    found = next((spec for spec in MODEL_SPECS if spec.id == model_id), None)
    if found is None:
        raise KeyError(f"Unknown model_id: {model_id}")
    return found
def _get_preprocess_fn(arch: str):
if arch == "resnet":
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
return resnet_preprocess
if arch == "efficientnet":
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
return eff_preprocess
raise ValueError(f"Unknown arch: {arch}")
def _load_model(spec: ModelSpec) -> tf.keras.Model:
    """Return the Keras model for ``spec``, loading it from disk on first use.

    Loaded models are cached in ``_loaded_models`` under ``spec.id``; later
    calls for the same spec return the cached instance.

    Args:
        spec: Model description; ``spec.filename`` is resolved under the
            ``models`` directory next to this file.

    Returns:
        The loaded (uncompiled) Keras model.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    if spec.id in _loaded_models:
        return _loaded_models[spec.id]

    model_path = os.path.join(os.path.dirname(__file__), "models", spec.filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}. "
            f"Place it at models/{spec.filename} in your Space."
        )

    # CPU-friendly TF settings (small wins on free Spaces). This is best-effort:
    # a failure here must never prevent the model from loading, but it is
    # logged rather than discarded so a misconfiguration stays diagnosable.
    # NOTE(review): TensorFlow may refuse to change threading once its runtime
    # is initialised, so this can fail on every load after the first — confirm.
    try:
        tf.config.threading.set_intra_op_parallelism_threads(0)
        tf.config.threading.set_inter_op_parallelism_threads(0)
    except Exception as exc:
        app.logger.debug("Skipping TensorFlow threading configuration: %s", exc)

    # compile=False: the model is only used for inference here.
    model = load_model(model_path, compile=False)
    _loaded_models[spec.id] = model
    return model
def _read_image(file_storage, img_size: int, preprocess_fn):
    """Turn an uploaded image file into a single-image batch for a model.

    Args:
        file_storage: Uploaded file object; only its ``read()`` method is used.
        img_size: Target height and width after resizing.
        preprocess_fn: Architecture-specific preprocessing applied to the
            resized float32 image.

    Returns:
        The preprocessed image with a leading batch axis, i.e. [1, H, W, 3].
    """
    encoded = file_storage.read()
    # Always decode to 3 channels; expand_animations=False is passed so the
    # decoder is not asked to expand animated inputs into a frame stack.
    pixels = tf.io.decode_image(encoded, channels=3, expand_animations=False)
    resized = tf.image.resize(pixels, [img_size, img_size])
    prepared = preprocess_fn(tf.cast(resized, tf.float32))
    return tf.expand_dims(prepared, axis=0)
def _predict(model: tf.keras.Model, image_tensor, class_names: Tuple[str, ...], threshold: float):
    """Run the model on one image batch and build the JSON-ready result.

    Args:
        model: Keras model; the first row of its ``predict`` output is used.
        image_tensor: Preprocessed input batch.
        class_names: Label for each output index, in training order.
        threshold: Minimum top probability required to report a real class.

    Returns:
        A dict with ``prediction`` (id, label, confidence, threshold) and
        ``probabilities`` (one entry per class, highest probability first).
    """
    raw_scores = model.predict(image_tensor, verbose=0)[0].astype(float)
    scores = np.clip(raw_scores, 0.0, 1.0)

    top_index = int(np.argmax(scores))
    top_prob = float(np.max(scores))

    # "Uncertain" is decided post-hoc; it is not a class the model outputs.
    if top_prob < threshold:
        winner_id = "Uncertain"
        winner_label = "Uncertain"
    else:
        winner_id = class_names[top_index]
        winner_label = _pretty_label(class_names[top_index])

    # Per-class breakdown, highest probability first (ties keep class order).
    ranked = sorted(
        (
            {"id": name, "label": _pretty_label(name), "prob": float(scores[i])}
            for i, name in enumerate(class_names)
        ),
        key=lambda entry: entry["prob"],
        reverse=True,
    )

    prediction = {
        "id": winner_id,
        "label": winner_label,
        "confidence": top_prob,
        "threshold": threshold,
    }
    return {"prediction": prediction, "probabilities": ranked}
def _pretty_label(name: str) -> str:
# Internal training labels -> user-facing labels (final wording)
mapping = {
"NonDemented": "Healthy",
"VeryMildDemented": "Very Mildly Demented",
"MildDemented": "Mildly Demented",
"ModerateDemented": "Moderately Demented",
# Post-hoc
"Uncertain": "Uncertain",
}
return mapping.get(name, name)
from flask import send_from_directory
import os
@app.route("/", defaults={"path": ""})
@app.route("/<path:path>")
def serve_frontend(path):
    """Serve the bundled frontend from ``static/clarity``.

    Real files inside that directory are served as-is. Every other non-API
    path gets ``index.html``, so client-side routes resolve in the browser.
    Unmatched ``api/`` paths return a plain 404 rather than the HTML shell.

    Args:
        path: URL path after the leading slash ("" for the site root).
    """
    # Keep API routes intact
    if path.startswith("api/"):
        return ("Not Found", 404)

    clarity_dir = os.path.join(app.root_path, "static", "clarity")

    # `path` comes straight from the URL and may contain ".." segments.
    # Normalise it and only consult the filesystem when the result is still
    # inside the frontend directory; otherwise the isfile() check would reveal
    # whether arbitrary files elsewhere on disk exist.
    root = os.path.abspath(clarity_dir)
    requested = os.path.abspath(os.path.join(root, path))
    is_inside = requested.startswith(root.rstrip(os.sep) + os.sep)

    # Serve real static files if they exist
    if path and is_inside and os.path.isfile(requested):
        return send_from_directory(clarity_dir, path)

    # Otherwise serve React index.html
    return send_from_directory(clarity_dir, "index.html")
@app.get("/api/models")
def api_models():
    """List the available models and which one is the default.

    Returns a JSON object with ``models`` (one entry per MODEL_SPECS item) and
    ``default_model_id`` (the id of the first registered model).
    """
    catalog = []
    for spec in MODEL_SPECS:
        class_entries = [
            {"id": class_id, "label": _pretty_label(class_id)}
            for class_id in spec.class_names
        ]
        catalog.append(
            {
                "id": spec.id,
                "name": spec.display_name,
                "img_size": spec.img_size,
                "classes": class_entries,
                "recommended_threshold": spec.recommended_threshold,
            }
        )

    return jsonify({"models": catalog, "default_model_id": MODEL_SPECS[0].id})
@app.post("/api/classify")
def api_classify():
    """Classify an uploaded image with the requested (or default) model.

    Expects a multipart form with a ``file`` field and an optional
    ``model_id`` field; the first registered model is used when it is absent.

    Returns:
        200 with the prediction payload plus a ``model`` entry on success;
        400 with ``{"error": ...}`` when the upload is missing or ``model_id``
        is not registered; 500 with ``{"error": ...}`` when the model file is
        missing or classification fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded (field name must be 'file')."}), 400

    model_id = request.form.get("model_id", MODEL_SPECS[0].id)
    try:
        spec = _get_spec(model_id)
    except KeyError:
        # An unregistered id is a client mistake: answer with a JSON 400
        # instead of letting the KeyError escape as a generic server error.
        return jsonify({"error": f"Unknown model_id: {model_id}"}), 400

    # Threshold is model-specific and not user-adjustable
    threshold = spec.recommended_threshold

    try:
        model = _load_model(spec)
        preprocess_fn = _get_preprocess_fn(spec.arch)
        image_tensor = _read_image(request.files["file"], spec.img_size, preprocess_fn)
        payload = _predict(model, image_tensor, spec.class_names, threshold)
        payload["model"] = {"id": spec.id, "name": spec.display_name}
        return jsonify(payload)
    except FileNotFoundError as e:
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        # Request boundary: record the traceback so the failure is visible in
        # the server log, then report it to the client as JSON.
        app.logger.exception("Classification failed for model %s", spec.id)
        return jsonify({"error": f"Failed to classify image: {e}"}), 500
if __name__ == "__main__":
    # Local dev entry point: python app.py
    # In Spaces (Dockerfile), gunicorn is used instead.
    # Port comes from the PORT environment variable, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, debug=False)