Spaces:
Sleeping
Sleeping
File size: 7,053 Bytes
5102e58 6a47082 86821de 2649a67 86821de 2649a67 6a47082 86821de 2649a67 6a47082 86821de 6a47082 86821de 00a6594 86821de 6a47082 86821de 5102e58 86821de 6a47082 86821de 6a47082 86821de 6a47082 86821de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 | from flask import send_from_directory
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
from flask import Flask, jsonify, render_template, request
import tensorflow as tf
from tensorflow.keras.models import load_model
# ----------------------------
# Model definitions
# ----------------------------
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one selectable classifier.

    Instances are built positionally in MODEL_SPECS, so the field order below
    is part of the interface. ``frozen=True`` makes instances read-only.
    """

    id: str  # identifier compared against the request's "model_id" form value
    display_name: str  # what the user sees (friendly + technical)
    filename: str  # file name under ./models/ next to this module
    arch: str  # "resnet" | "efficientnet"; selects the preprocess_input function
    img_size: int  # images are resized to img_size x img_size before inference
    class_names: Tuple[str, ...]  # label for each model output index, in training order
    recommended_threshold: float  # top probability below this is reported as "Uncertain"
# NOTE:
# The training notebooks order classes differently: the ResNet notebooks use
# the sorted unique categories, while the EfficientNet notebook uses an explicit
# list. Each model therefore carries its own class order so that probabilities
# are never attached to the wrong label.
RESNET_CLASS_ORDER = (
    "MildDemented",
    "ModerateDemented",
    "NonDemented",
    "VeryMildDemented",
)
EFFICIENTNET_CLASS_ORDER = (
    "NonDemented",
    "VeryMildDemented",
    "MildDemented",
    "ModerateDemented",
)

# Registry of selectable models; the first entry is the default.
MODEL_SPECS: List[ModelSpec] = [
    ModelSpec(
        id="atlas",
        display_name="Atlas — ResNet-50",
        filename="resnet50.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="orion",
        display_name="Orion — ResNet-101",
        filename="resnet101.h5",
        arch="resnet",
        img_size=224,
        class_names=RESNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
    ModelSpec(
        id="pulse",
        display_name="Pulse — EfficientNet-B2",
        filename="efficientnetb2.h5",
        arch="efficientnet",
        img_size=260,
        class_names=EFFICIENTNET_CLASS_ORDER,
        recommended_threshold=0.95,
    ),
]
# ----------------------------
# Flask app
# ----------------------------
# Application object; presumably this is what gunicorn serves in the Space
# (see the __main__ note at the bottom) — verify against the Dockerfile.
app = Flask(__name__)
# Lazy model cache: model id -> loaded Keras model, filled by _load_model() the
# first time each model is requested. Entries are never evicted, so every model
# that has been requested at least once stays resident (at most one per entry
# in MODEL_SPECS).
_loaded_models: Dict[str, tf.keras.Model] = {}
def _get_spec(model_id: str) -> ModelSpec:
    """Look up the registered ModelSpec with the given id.

    Raises:
        KeyError: if no entry in MODEL_SPECS has ``id == model_id``.
    """
    found = next((candidate for candidate in MODEL_SPECS if candidate.id == model_id), None)
    if found is None:
        raise KeyError(f"Unknown model_id: {model_id}")
    return found
def _get_preprocess_fn(arch: str):
if arch == "resnet":
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
return resnet_preprocess
if arch == "efficientnet":
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
return eff_preprocess
raise ValueError(f"Unknown arch: {arch}")
def _load_model(spec: ModelSpec) -> tf.keras.Model:
    """Return the Keras model for *spec*, loading it from disk on first use.

    Loaded models are cached in ``_loaded_models`` under ``spec.id``; later
    calls return the cached instance without touching the disk.

    Raises:
        FileNotFoundError: if ``models/<spec.filename>`` (relative to this
            module's directory) is missing or is not a regular file.
    """
    if spec.id in _loaded_models:
        return _loaded_models[spec.id]

    model_path = os.path.join(os.path.dirname(__file__), "models", spec.filename)
    # isfile rather than exists: a directory with the model's name should fail
    # here with the friendly message, not deep inside load_model().
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}. "
            f"Place it at models/{spec.filename} in your Space."
        )

    # Thread-pool sizing. A value of 0 tells TensorFlow to choose the thread
    # count itself, i.e. its default behaviour, so this is not a tuning knob as
    # written. TensorFlow rejects changes to these settings once its runtime is
    # initialised (which loading a model does). It is therefore only attempted
    # while no model has been loaded yet. Failure is deliberately ignored
    # because the defaults remain in effect.
    if not _loaded_models:
        try:
            tf.config.threading.set_intra_op_parallelism_threads(0)
            tf.config.threading.set_inter_op_parallelism_threads(0)
        except Exception:  # best-effort only; see comment above
            pass

    # compile=False: the model is only used for inference here, so no
    # optimizer/loss configuration needs to be restored.
    model = load_model(model_path, compile=False)
    _loaded_models[spec.id] = model
    return model
def _read_image(file_storage, img_size: int, preprocess_fn):
    """Decode an uploaded image into a one-image batch for the model.

    Args:
        file_storage: upload object whose ``read()`` returns the encoded image
            bytes (e.g. an entry of ``request.files``).
        img_size: side length of the square the image is resized to.
        preprocess_fn: architecture-specific ``preprocess_input`` function.

    Returns:
        A tensor of shape ``[1, img_size, img_size, 3]``, assuming
        *preprocess_fn* preserves shape.
    """
    encoded = file_storage.read()
    # channels=3 forces RGB output; expand_animations=False keeps the result
    # rank 3 (H, W, C) even for animated inputs.
    pixels = tf.io.decode_image(encoded, channels=3, expand_animations=False)
    pixels = tf.cast(tf.image.resize(pixels, [img_size, img_size]), tf.float32)
    batch = tf.expand_dims(preprocess_fn(pixels), axis=0)  # [1, H, W, 3]
    return batch
def _predict(model: tf.keras.Model, image_tensor, class_names: Tuple[str, ...], threshold: float):
    """Run *model* on a single-image batch and build the JSON-ready payload.

    Args:
        model: classifier; the first row of ``model.predict`` output is taken
            as one probability per class.
        image_tensor: batch holding exactly one preprocessed image.
        class_names: training label for each output index, in output order.
        threshold: when the top probability is below this value, the
            prediction is reported as "Uncertain" (a post-hoc label, not a
            model class).

    Returns:
        dict with ``prediction`` (id, label, confidence, threshold) and
        ``probabilities`` (one entry per class, sorted by descending prob).

    Raises:
        ValueError: if the model's output length differs from
            ``len(class_names)``. Pairing them would mislabel probabilities,
            or crash with an opaque IndexError.
    """
    probs = model.predict(image_tensor, verbose=0)[0].astype(float)
    probs = np.clip(probs, 0.0, 1.0)
    if probs.shape != (len(class_names),):
        raise ValueError(
            f"Model produced {probs.size} outputs but {len(class_names)} class "
            "names are configured; check the model's class order."
        )

    best_idx = int(np.argmax(probs))
    best_prob = float(probs[best_idx])
    # "Uncertain" is added post hoc; it is not one of the model's outputs.
    is_uncertain = best_prob < threshold

    by_class = [
        {"id": name, "label": _pretty_label(name), "prob": float(prob)}
        for name, prob in zip(class_names, probs)
    ]
    by_class.sort(key=lambda x: x["prob"], reverse=True)
    return {
        "prediction": {
            "id": "Uncertain" if is_uncertain else class_names[best_idx],
            "label": "Uncertain" if is_uncertain else _pretty_label(class_names[best_idx]),
            "confidence": best_prob,
            "threshold": threshold,
        },
        "probabilities": by_class,
    }
def _pretty_label(name: str) -> str:
# Internal training labels -> user-facing labels (final wording)
mapping = {
"NonDemented": "Healthy",
"VeryMildDemented": "Very Mildly Demented",
"MildDemented": "Mildly Demented",
"ModerateDemented": "Moderately Demented",
# Post-hoc
"Uncertain": "Uncertain",
}
return mapping.get(name, name)
from flask import send_from_directory
import os
@app.route("/", defaults={"path": ""})
@app.route("/<path:path>")
def serve_frontend(path):
    """Serve the built React frontend from static/clarity.

    Files that exist under static/clarity are returned as-is. Any other path
    gets index.html so client-side routing can handle it. Paths under ``api/``
    that reach this catch-all get a plain 404 instead of the app shell.
    """
    if path.startswith("api/"):
        return ("Not Found", 404)

    frontend_root = os.path.join(app.root_path, "static", "clarity")
    is_real_file = bool(path) and os.path.isfile(os.path.join(frontend_root, path))
    target = path if is_real_file else "index.html"
    return send_from_directory(frontend_root, target)
def _model_summary(spec: ModelSpec) -> dict:
    """Describe one ModelSpec for the /api/models response."""
    classes = [{"id": c, "label": _pretty_label(c)} for c in spec.class_names]
    return {
        "id": spec.id,
        "name": spec.display_name,
        "img_size": spec.img_size,
        "classes": classes,
        "recommended_threshold": spec.recommended_threshold,
    }


@app.get("/api/models")
def api_models():
    """List the selectable models and the default model id as JSON."""
    summaries = [_model_summary(spec) for spec in MODEL_SPECS]
    default_id = MODEL_SPECS[0].id
    return jsonify({"models": summaries, "default_model_id": default_id})
@app.post("/api/classify")
def api_classify():
    """Classify an uploaded image with the selected model.

    Expects a multipart form with the image in the ``file`` field and an
    optional ``model_id``. A missing or blank id selects the first entry of
    MODEL_SPECS.

    Returns:
        200 with the _predict() payload plus a ``model`` entry;
        400 with ``{"error": ...}`` for a missing/empty upload, an unknown
            model_id, or bytes TensorFlow cannot decode as an image;
        500 with ``{"error": ...}`` when the model file is missing or
            loading/inference fails.
    """
    upload = request.files.get("file")
    if upload is None:
        return jsonify({"error": "No file uploaded (field name must be 'file')."}), 400
    # Browsers submit an empty part with filename "" when no file was chosen.
    if not upload.filename:
        return jsonify({"error": "No file selected."}), 400

    model_id = request.form.get("model_id") or MODEL_SPECS[0].id
    try:
        spec = _get_spec(model_id)
    except KeyError:
        # Previously this escaped the handler as an unhandled 500 HTML page.
        return jsonify({"error": f"Unknown model_id: {model_id}"}), 400

    # Threshold is model-specific and not user-adjustable.
    threshold = spec.recommended_threshold

    # Decode first: bad uploads are rejected before a (possibly slow) model
    # load, and decode failures are the client's fault, hence 400.
    try:
        preprocess_fn = _get_preprocess_fn(spec.arch)
        image_tensor = _read_image(upload, spec.img_size, preprocess_fn)
    except tf.errors.InvalidArgumentError:
        return jsonify({"error": "Could not decode the uploaded file as an image."}), 400
    except Exception as e:
        app.logger.exception("Preprocessing failed (model_id=%s)", spec.id)
        return jsonify({"error": f"Failed to classify image: {e}"}), 500

    try:
        model = _load_model(spec)
        payload = _predict(model, image_tensor, spec.class_names, threshold)
    except FileNotFoundError as e:
        app.logger.error("%s", e)
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        app.logger.exception("Inference failed (model_id=%s)", spec.id)
        return jsonify({"error": f"Failed to classify image: {e}"}), 500

    payload["model"] = {"id": spec.id, "name": spec.display_name}
    return jsonify(payload)
if __name__ == "__main__":
    # Local development entry point: `python app.py`.
    # Inside the Space (Dockerfile) gunicorn is used instead.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, debug=False)