SF_FastAPI / app.py
COCODEDE04's picture
Update app.py
f95e7a2 verified
raw
history blame
5.98 kB
import json
import os
from typing import Any, Dict
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
# ----------------- CONFIG -----------------
MODEL_PATH = os.getenv("MODEL_PATH", "best_model.h5")
STATS_PATH = os.getenv("STATS_PATH", "means_std.json")
CLASSES = ["Top", "Mid-Top", "Mid", "Mid-Low", "Low"]
# ------------------------------------------
# Debug & decoding control
FORCE_CORAL = os.getenv("FORCE_CORAL", "0") in ("1", "true", "True", "YES", "yes")
RETURN_DEBUG = os.getenv("RETURN_DEBUG", "1") in ("1", "true", "True", "YES", "yes")
print("Loading model and stats...")
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# Per-feature standardisation stats, read as {feature: {"mean": ..., "std": ...}}
# by /predict. JSON is UTF-8: decode it explicitly so non-ASCII feature names
# do not depend on the host locale's default encoding.
with open(STATS_PATH, "r", encoding="utf-8") as f:
    stats: Dict[str, Dict[str, float]] = json.load(f)
# IMPORTANT: FEATURES order must match training!
# The order is the key order of the stats JSON (json.load keeps document order).
FEATURES = list(stats.keys())
print("Feature order:", FEATURES)
# ---------- robust numeric coercion ----------
def coerce_float(val: Any) -> float:
    """Parse a number that may be written with locale-specific separators.

    Numbers (``int``/``float``, hence also ``bool``) go straight through
    ``float()``. Anything else is stringified, and *all* whitespace is removed,
    including the non-breaking spaces some locales use as thousands
    separators. Separators are then resolved as follows:

    * both ``.`` and ``,`` present: whichever occurs last is the decimal
      separator, the other one is a thousands separator
      ("49.709,14" -> 49709.14, "49,709.14" -> 49709.14);
    * only commas: a single comma is the decimal separator
      ("0,005" -> 0.005), several commas are thousands separators
      ("1,234,567" -> 1234567);
    * only dots: a single dot stays the decimal point, several dots are
      thousands separators ("1.234.567" -> 1234567);
    * spaces are always dropped (" 1 234 " -> 1234).

    Note that "1,234" is read as 1.234 (a lone comma is taken as decimal).

    Raises:
        ValueError: if the value is empty or still not a valid float.
        OverflowError: if an ``int`` is too large to convert to float.
    """
    if isinstance(val, (int, float)):
        return float(val)
    # str.split() with no argument splits on every Unicode whitespace char
    # (incl. U+00A0 / U+202F), so joining the pieces strips them all.
    s = "".join(str(val).split())
    if s == "":
        raise ValueError("empty")

    n_dot = s.count(".")
    n_comma = s.count(",")
    if n_dot and n_comma:
        if s.rfind(",") > s.rfind("."):
            # decimal is comma, thousands is dot
            s = s.replace(".", "").replace(",", ".")
        else:
            # decimal is dot, thousands is comma
            s = s.replace(",", "")
    elif n_comma:
        # one comma -> decimal separator; several -> thousands separators
        s = s.replace(",", "." if n_comma == 1 else "")
    elif n_dot > 1:
        # several dots can only be thousands separators
        s = s.replace(".", "")
    # a single dot or pure digits -> leave as is
    return float(s)
def _z(val: Any, mean: float, sd: float) -> float:
    """Return the z-score ``(val - mean) / sd`` of one raw feature value.

    *val* is parsed with ``coerce_float``. The result is 0.0 (the feature is
    scored at its training mean) when the value cannot be parsed, when ``sd``
    is zero/falsy, or when the z-score is not finite (e.g. JSON ``NaN`` /
    ``Infinity``, the string "nan", or an overflowing magnitude).
    """
    try:
        v = coerce_float(val)
    except Exception:
        # Deliberate best effort: an unparseable value gets a neutral z-score.
        return 0.0
    if not sd:
        return 0.0
    z = (v - mean) / sd
    # NaN/inf would propagate into the model output, and strict JSON cannot
    # represent them in the response.
    return float(z) if np.isfinite(z) else 0.0
def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
    """Convert CORAL ordinal logits into per-class probabilities.

    Args:
        logits_np: array-like of shape (N, K-1), one logit per ordinal
            threshold.

    Returns:
        float32 array of shape (N, K). With ``s_j = sigmoid(logit_j)``,
        class k gets ``s_{k-1} - s_k``, taking ``s_{-1} = 1`` and
        ``s_{K-1} = 0``. Entries are clipped to [1e-12, 1]. Each row is then
        renormalised to sum to 1. Without that step, clipping negative
        differences (thresholds that are not rank-consistent) would leave
        rows summing to more than 1.

    Raises:
        ValueError: if *logits_np* is not two-dimensional.
    """
    logits = np.asarray(logits_np, dtype=np.float32)
    if logits.ndim != 2:
        raise ValueError(f"expected (N, K-1) logits, got shape {logits.shape}")
    # Numerically stable logistic sigmoid: no overflow for large |x|.
    sig = 0.5 * (np.tanh(0.5 * logits) + 1.0)
    n_rows = sig.shape[0]
    # Explicit (N, 1) columns so that K-1 == 0 still yields one class.
    left = np.concatenate([np.ones((n_rows, 1), dtype=sig.dtype), sig], axis=1)
    right = np.concatenate([sig, np.zeros((n_rows, 1), dtype=sig.dtype)], axis=1)
    probs = np.clip(left - right, 1e-12, 1.0)
    return probs / probs.sum(axis=1, keepdims=True)
# ------------- FastAPI app ----------------
# Application object; every endpoint below is registered on it via decorators.
app = FastAPI(title="Static Fingerprint API", version="1.0.0")
# Allow Excel / local tools to call the API
# CORS is wide open: any origin, method and header. Credentials stay disabled;
# browsers refuse a wildcard origin on credentialed requests anyway.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def root():
    """Landing endpoint: confirm the service is up and point at the main routes."""
    info = {"message": "Static Fingerprint API is running."}
    info["try"] = ["GET /health", "POST /predict"]
    return info
@app.get("/health")
def health():
    """Health check that also reports the loaded feature order, labels and files."""
    return dict(
        status="ok",
        features=FEATURES,
        classes=CLASSES,
        model_file=MODEL_PATH,
        stats_file=STATS_PATH,
    )
@app.post("/echo")
async def echo(req: Request):
    """Debug helper: return the parsed JSON request body under "received"."""
    body = await req.json()
    return dict(received=body)
def _build_z_vector(payload: Dict[str, Any]):
    """Standardise *payload* into z-scores following the FEATURES order.

    Returns ``(z, z_detail, missing)``:
      * the z-scores as a list in model order;
      * the same values keyed by feature name;
      * the names of the features absent from the payload.
    """
    z = []
    z_detail = {}
    missing = []
    for name in FEATURES:
        mean = stats[name]["mean"]
        sd = stats[name]["std"]
        if name in payload:
            zf = _z(payload[name], mean, sd)
        else:
            missing.append(name)
            # treat missing as 0 input.
            # NOTE(review): this scores a missing feature as a raw value of
            # 0.0 (z = -mean/sd), whereas an unparseable value gets z = 0
            # (the training mean). Kept as-is to avoid changing predictions;
            # confirm which imputation is intended.
            zf = _z(0.0, mean, sd)
        z.append(zf)
        z_detail[name] = zf
    return z, z_detail, missing


def _normalize(vec: np.ndarray) -> np.ndarray:
    """Scale *vec* to sum to 1 when its sum is positive; otherwise return it unchanged."""
    total = float(np.sum(vec))
    return vec / total if total > 0 else vec


def _decode_probs(raw: np.ndarray):
    """Turn raw model output into class probabilities for its first row.

    Returns ``(probs, decode_mode)``. CORAL decoding is used when FORCE_CORAL
    is set, or when the output is 2-D with exactly ``len(CLASSES) - 1``
    columns. Otherwise the first row is treated as softmax output and
    normalised. If that row has negative entries, which probabilities cannot
    have, it is treated as logits and softmaxed. Any decoding error falls back
    to normalising the raw first row.
    """
    try:
        if FORCE_CORAL:
            return coral_probs_from_logits(raw)[0], "forced_coral"
        if raw.ndim == 2 and raw.shape[1] == len(CLASSES) - 1:
            return coral_probs_from_logits(raw)[0], "auto_coral"
        row = raw[0]
        if np.any(row < 0):
            exp = np.exp(row - np.max(row))  # shift for numerical stability
            return exp / np.sum(exp), "auto_logits_softmax"
        return _normalize(row), "auto_softmax_or_logits"
    except Exception:
        # Best effort: keep serving with the plain normalised first row.
        return _normalize(raw[0]), "fallback_raw_norm"


def _first_row_list(raw: np.ndarray) -> list:
    """Return the first row of *raw* (all of it when 1-D) as a flat list of floats."""
    row = raw[0] if raw.ndim >= 2 else raw
    return [float(x) for x in np.ravel(row)]


@app.post("/predict")
async def predict(req: Request):
    """Score one sample and return per-class probabilities.

    Body: a single JSON dict mapping feature -> numeric value.
    Example:
    {
    "autosuf_oper": 1.0,
    "cov_improductiva": 0.9,
    ...
    }

    Values may be numbers or locale-formatted strings (see ``coerce_float``).
    An ``{"error": ...}`` object is returned in three cases: the body is not
    valid JSON, the body is not a JSON object, or the model's class count does
    not match CLASSES.
    """
    try:
        payload = await req.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return {"error": "Request body is not valid JSON."}
    if not isinstance(payload, dict):
        return {"error": "Expected a JSON object mapping feature -> value."}

    # Build z-scores in strict model order
    z, z_detail, missing = _build_z_vector(payload)
    X = np.array([z], dtype=np.float32)
    # NOTE(review): model.predict is synchronous and blocks the event loop
    # while it runs inside this async handler.
    raw = np.asarray(model.predict(X, verbose=0))
    raw_shape = tuple(raw.shape)

    # Decode: CORAL vs Softmax
    probs, decode_mode = _decode_probs(raw)
    probs = np.ravel(probs)
    if probs.size != len(CLASSES):
        # Otherwise labels would be silently truncated or indexed out of range.
        return {
            "error": (
                f"Model produced {probs.size} class scores; "
                f"expected {len(CLASSES)}."
            )
        }

    pred_idx = int(np.argmax(probs))
    resp = {
        "input_ok": not missing,
        "missing": missing,
        "z_scores": z_detail,
        "probabilities": {label: float(p) for label, p in zip(CLASSES, probs)},
        "predicted_state": CLASSES[pred_idx],
    }
    # Include debug fields so we can see shape & decode path
    if RETURN_DEBUG:
        resp["debug"] = {
            "raw_shape": raw_shape,
            "decode_mode": decode_mode,
            "raw_first_row": _first_row_list(raw),
        }
    return resp