# Hugging Face Space app (page-scrape residue removed: "Spaces" header, status "Sleeping")
import json
from typing import Any, Dict

import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import Body, FastAPI, HTTPException
# =========================
# Config
# =========================
MODEL_PATH = "best_model.h5"    # your uploaded model
STATS_PATH = "means_std.json"   # {"feature": {"mean": x, "std": y}, ...}
CLASSES = ["Top", "Mid-Top", "Mid", "Mid-Low", "Low"]

# =========================
# Load artifacts
# =========================
print("Loading model and stats...")
# compile=False: we only run inference, so the training config need not resolve.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
with open(STATS_PATH, "r") as f:
    stats = json.load(f)

# Feature order comes from the JSON key order (dicts preserve insertion order);
# _predict_core builds its input vector in exactly this order, so the order in
# means_std.json must match what the model was trained on.
FEATURES = list(stats.keys())
print("Feature order:", FEATURES)
| # ========================= | |
| # Helpers | |
| # ========================= | |
| def _zscore(val: Any, mean: float, sd: float) -> float: | |
| try: | |
| v = float(val) | |
| except Exception: | |
| return 0.0 | |
| if sd is None or sd == 0: | |
| return 0.0 | |
| return (v - mean) / sd | |
| def _coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray: | |
| """ | |
| logits_np: (N, K-1) linear outputs. | |
| Returns probabilities (N, K) with p_k = σ(z_{k-1}) - σ(z_k), and boundaries 1/0. | |
| """ | |
| logits = tf.convert_to_tensor(logits_np, dtype=tf.float32) | |
| sig = tf.math.sigmoid(logits) # (N, K-1) | |
| left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1) | |
| right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1) | |
| probs = tf.clip_by_value(left - right, 1e-12, 1.0) | |
| return probs.numpy() | |
def _predict_core(ratios: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run one prediction from raw feature values.

    ratios: dict mapping feature name -> raw numeric value.
    Returns: dict with predicted_state, probabilities, z_scores, missing, input_ok.
    Raises ValueError when the model's output shape is incompatible with CLASSES.
    """
    # Validate presence (missing features are tolerated and reported).
    missing = [f for f in FEATURES if f not in ratios]

    # Build the z-score vector in exact FEATURES order. A missing feature gets
    # z = 0.0 (i.e. the feature mean), matching the documented intent of
    # "fill 0.0 after z-score". Previously the raw value 0.0 was z-scored,
    # which silently produced -mean/sd instead of mean-imputation.
    z_list, z_scores = [], {}
    for f in FEATURES:
        if f in ratios:
            z = _zscore(ratios[f], stats[f]["mean"], stats[f]["std"])
        else:
            z = 0.0
        z_list.append(z)
        z_scores[f] = z

    X = np.array([z_list], dtype=np.float32)  # (1, D)
    raw = model.predict(X, verbose=0)

    # Dispatch on head width: softmax head emits K columns, CORAL emits K-1.
    if raw.ndim != 2:
        raise ValueError(f"Unexpected model output shape: {raw.shape}")
    if raw.shape[1] == len(CLASSES) - 1:
        probs = _coral_probs_from_logits(raw)[0]  # (K,)
    elif raw.shape[1] == len(CLASSES):
        probs = raw[0]  # (K,)
    else:
        raise ValueError(f"Model output width {raw.shape[1]} incompatible with classes {len(CLASSES)}")

    # Safety: renormalize in case the head does not emit a clean prob. vector.
    probs = np.maximum(probs, 0.0)
    s = probs.sum()
    if s <= 0:
        # Fallback to uniform if something pathological happens.
        probs = np.ones(len(CLASSES), dtype=np.float32) / float(len(CLASSES))
    else:
        probs = probs / s

    pred_idx = int(np.argmax(probs))
    return {
        "input_ok": len(missing) == 0,
        "missing": missing,
        "z_scores": z_scores,
        "probabilities": {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))},
        "predicted_state": CLASSES[pred_idx],
    }
# =========================
# Gradio adapter (UI)
# =========================
def _gradio_adapter(payload):
    """
    Normalize UI input and delegate to _predict_core.

    Accepts either:
      - a dict {feature: value, ...}
      - a list with one dict [ {feature: value, ...} ]
    Returns an error dict (instead of raising) so the UI displays a message.
    """
    if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
        payload = payload[0]
    if not isinstance(payload, dict):
        return {"error": "Expected JSON object mapping feature -> value."}
    return _predict_core(payload)
# Quick manual-check UI; programmatic clients should use the FastAPI routes.
demo = gr.Interface(
    fn=_gradio_adapter,
    inputs=gr.JSON(label="ratios JSON (dict of feature -> value)"),
    outputs="json",
    title="Static Fingerprint Model API",
    description="Programmatic use: POST a raw dict to /predict. UI here is for quick manual checks.",
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x (renamed
    # flagging_mode) and removed in 5.x — confirm the pinned gradio version.
    allow_flagging="never",
)
# =========================
# FastAPI app (sync endpoints)
# =========================
api = FastAPI()


# BUG FIX: health and predict_endpoint were defined but never registered as
# routes (no decorators), so /predict — promised in the UI description — 404'd.
@api.get("/health")
def health():
    """Liveness probe: reports status plus the feature order and class labels."""
    return {"status": "ok", "features": FEATURES, "classes": CLASSES}


@api.post("/predict")
def predict_endpoint(payload: Any = Body(...)):
    """
    Predict from a raw JSON body.

    Accepts a dict {feature: value, ...} or a list containing one such dict
    (Body(...) binds the entire JSON body to `payload`).
    Raises 400 on malformed payloads, 500 with the error message on failures.
    """
    # Allow list-of-one and dict.
    if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
        payload = payload[0]
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Expected JSON object mapping feature -> value.")
    try:
        return _predict_core(payload)
    except Exception as e:
        # Surface the failure message to API clients rather than a bare 500.
        raise HTTPException(status_code=500, detail=str(e))


# Mount Gradio UI at "/" and expose FastAPI routes alongside it
app = gr.mount_gradio_app(api, demo, path="/")
if __name__ == "__main__":
    # Local dev run (HF Spaces ignores this and serves `app` itself).
    # NOTE(review): demo.launch() serves only the Gradio UI; to exercise the
    # /health and /predict FastAPI routes locally, run the mounted app via
    # `uvicorn app:app` instead.
    demo.launch()