COCODEDE04 committed on
Commit
59369dd
·
verified ·
1 Parent(s): 1f02925

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -66
app.py CHANGED
@@ -1,120 +1,155 @@
 
1
  import json
 
 
2
  import numpy as np
3
  import tensorflow as tf
4
  import gradio as gr
5
  from fastapi import FastAPI, Request
6
  from fastapi.middleware.cors import CORSMiddleware
7
 
8
- # ---------- CONFIG ----------
9
- MODEL_PATH = "best_model.h5" # or best_model.keras if that’s what you uploaded
10
- STATS_PATH = "Means & Std for Excel.json" # exact filename
11
  CLASSES = ["Top", "Mid-Top", "Mid", "Mid-Low", "Low"]
12
- # ----------------------------
13
 
14
- print("Loading model and stats...")
15
- model = tf.keras.models.load_model(MODEL_PATH, compile=False)
16
- with open(STATS_PATH, "r") as f:
17
- stats = json.load(f)
18
 
19
- FEATURES = list(stats.keys())
20
- print("Feature order:", FEATURES)
21
 
22
- def _z(val): # safe z-score
23
  try:
24
- v = float(val)
25
  except Exception:
26
- v = 0.0
27
- return v
28
 
29
  def _zscore(val, mean, sd):
30
- v = _z(val)
31
  return 0.0 if (sd is None or sd == 0) else (v - mean) / sd
32
 
33
  def coral_probs_from_logits(logits_np):
34
- import tensorflow as tf
35
- logits = tf.convert_to_tensor(logits_np, dtype=tf.float32) # (1, K-1)
36
- sig = tf.math.sigmoid(logits) # (1, K-1)
37
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
38
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
39
  probs = tf.clip_by_value(left - right, 1e-12, 1.0)
40
  return probs.numpy()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def predict_core(ratios: dict):
43
- # build z vector in fixed order
44
- zscores = []
45
- z_map = {}
46
- for f in FEATURES:
47
- mean = stats[f]["mean"]
48
- sd = stats[f]["std"]
49
- val = ratios.get(f, 0.0)
50
- z = _zscore(val, mean, sd)
51
- zscores.append(z)
52
- z_map[f] = z
53
-
54
- X = np.array([zscores], dtype=np.float32)
55
- y = model.predict(X, verbose=0)
56
-
57
- # handle either softmax K or CORAL K-1
58
  if y.ndim == 2 and y.shape[1] == len(CLASSES):
59
  probs = y[0]
60
  elif y.ndim == 2 and y.shape[1] == len(CLASSES) - 1:
61
  probs = coral_probs_from_logits(y)[0]
62
  else:
63
- # fallback: normalize positive scores
64
- s = y[0].astype(np.float64)
65
- if s.ndim == 0:
66
- s = np.array([float(s)], dtype=np.float64)
67
- s = np.maximum(s, 0.0)
68
  probs = s / s.sum() if s.sum() > 0 else np.ones(len(CLASSES)) / len(CLASSES)
69
 
70
  pred_idx = int(np.argmax(probs))
71
- pred_state = CLASSES[pred_idx]
72
  return {
73
  "input_ok": True,
74
- "missing": [f for f in FEATURES if f not in ratios],
75
- "z_scores": z_map,
76
  "probabilities": {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))},
77
- "predicted_state": pred_state
78
  }
79
 
80
  def predict_from_json(payload):
81
- # accept raw dict OR list-of-one dict
82
  if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
83
  payload = payload[0]
84
  if not isinstance(payload, dict):
85
- return {"error": "Invalid payload: send a JSON object mapping feature->value."}
86
- return predict_core(payload)
 
 
 
 
 
 
87
 
88
  # ------------------ FastAPI + Gradio ------------------
89
- # ------------------ FastAPI + Gradio ------------------
90
- from fastapi import FastAPI, Request
91
- from fastapi.middleware.cors import CORSMiddleware
92
- import gradio as gr
93
-
94
  app = FastAPI()
95
  app.add_middleware(
96
  CORSMiddleware,
97
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
98
  )
99
 
100
- # Plain handler we’ll reuse
 
 
 
 
 
 
 
 
 
 
 
 
101
  async def _handle_predict(req: Request):
102
- body = await req.json()
103
- # accept either raw dict or {"data":[{...}]}
 
 
 
104
  if isinstance(body, dict) and "data" in body and isinstance(body["data"], list) and body["data"]:
105
  body = body["data"][0]
106
- if not isinstance(body, dict):
107
- return {"error": "Invalid payload. Send a JSON object of feature->value or {'data':[that_object]}."}
108
- try:
109
- return predict_from_json(body)
110
- except Exception as e:
111
- return {"error": f"{type(e).__name__}: {e}"}
112
 
113
  @app.post("/predict")
114
  async def predict_main(req: Request):
115
  return await _handle_predict(req)
116
 
117
- # Be generous: also accept your older paths
118
  @app.post("/run/predict")
119
  async def predict_compat1(req: Request):
120
  return await _handle_predict(req)
@@ -123,11 +158,7 @@ async def predict_compat1(req: Request):
123
  async def predict_compat2(req: Request):
124
  return await _handle_predict(req)
125
 
126
- @app.get("/health")
127
- def health():
128
- return {"ok": True}
129
-
130
- # Mount the Gradio UI at root
131
  ui = gr.Interface(
132
  fn=predict_from_json,
133
  inputs=gr.JSON(label="ratios JSON (dict of feature -> value)"),
@@ -137,7 +168,7 @@ ui = gr.Interface(
137
  )
138
  app = gr.mount_gradio_app(app, ui, path="/")
139
 
140
- # DEBUG: print available routes so we can see them in the logs
141
  for r in app.router.routes:
142
  try:
143
  print("ROUTE:", r.path)
 
1
+ import os
2
  import json
3
+ from pathlib import Path
4
+
5
  import numpy as np
6
  import tensorflow as tf
7
  import gradio as gr
8
  from fastapi import FastAPI, Request
9
  from fastapi.middleware.cors import CORSMiddleware
10
 
11
+ # ---------- CONFIG (edit these names if yours differ) ----------
12
+ MODEL_PATH = os.environ.get("MODEL_PATH", "best_model.h5") # or "best_model.keras"
13
+ STATS_PATH = os.environ.get("STATS_PATH", "Means & Std for Excel.json")
14
  CLASSES = ["Top", "Mid-Top", "Mid", "Mid-Low", "Low"]
15
+ # --------------------------------------------------------------
16
 
17
+ # Global handles (lazy init)
18
+ _model = None
19
+ _stats = None
20
+ _features = None
21
 
22
+ def _exists(p): return Path(p).exists()
 
23
 
24
+ def _safe_float(x, default=0.0):
25
  try:
26
+ return float(x)
27
  except Exception:
28
+ return default
 
29
 
30
  def _zscore(val, mean, sd):
31
+ v = _safe_float(val, 0.0)
32
  return 0.0 if (sd is None or sd == 0) else (v - mean) / sd
33
 
34
def coral_probs_from_logits(logits_np):
    """Convert CORAL cumulative logits of shape (N, K-1) to class
    probabilities of shape (N, K).

    CORAL parameterizes P(y > k) = sigmoid(logit_k); adjacent differences
    of the cumulative curve [1, sig..., 0] yield per-class probabilities.
    Implemented in pure NumPy — this is small element-wise math and does
    not need TensorFlow graph ops (the result was converted to a NumPy
    array anyway).
    """
    logits = np.asarray(logits_np, dtype=np.float32)   # (N, K-1)
    sig = 1.0 / (1.0 + np.exp(-logits))                # P(y > k)
    ones = np.ones_like(sig[:, :1])
    zeros = np.zeros_like(sig[:, :1])
    left = np.concatenate([ones, sig], axis=1)         # (N, K)
    right = np.concatenate([sig, zeros], axis=1)       # (N, K)
    # Clip so downstream normalization/argmax never sees a zero or a tiny
    # negative produced by floating-point rounding.
    return np.clip(left - right, 1e-12, 1.0)
41
 
42
def lazy_init():
    """Load the model and normalization stats on first use.

    Populates the module-level _model, _stats and _features handles.
    Raises RuntimeError with a human-readable message on any failure so
    callers (predict_from_json) can surface it in the JSON response
    instead of crashing the worker at import time.
    """
    global _model, _stats, _features
    if _model is not None and _stats is not None and _features is not None:
        return

    problems = []
    if not _exists(MODEL_PATH):
        problems.append(f"Model file not found: {MODEL_PATH}")
    if not _exists(STATS_PATH):
        problems.append(f"Stats JSON not found: {STATS_PATH}")
    if problems:
        # One readable message; predict_from_json converts it into an
        # {"error": ...} payload rather than a 500.
        raise RuntimeError("; ".join(problems))

    try:
        _model = tf.keras.models.load_model(MODEL_PATH, compile=False)
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise RuntimeError(f"Failed to load model: {type(e).__name__}: {e}") from e

    try:
        with open(STATS_PATH, "r") as f:
            _stats = json.load(f)
    except Exception as e:
        raise RuntimeError(f"Failed to read stats JSON: {type(e).__name__}: {e}") from e

    # Fixed feature order = key order in the stats JSON (dicts preserve
    # insertion order); predict_core builds its input vector in this order.
    _features = list(_stats.keys())
    print("Feature order:", _features)
71
+
72
def predict_core(ratios: dict):
    """Z-score *ratios* in the stats' feature order, run the model, and
    return a result dict with probabilities and the arg-max class.

    Features absent from *ratios* are treated as 0.0 and listed under
    "missing".  May raise RuntimeError from lazy_init with a clear message.
    """
    lazy_init()

    # Standardize every feature in the fixed JSON key order.
    zmap = {
        name: _zscore(ratios.get(name, 0.0), _stats[name]["mean"], _stats[name]["std"])
        for name in _features
    }
    X = np.array([[zmap[name] for name in _features]], dtype=np.float32)
    y = _model.predict(X, verbose=0)

    n_classes = len(CLASSES)
    if y.ndim == 2 and y.shape[1] == n_classes:
        # Softmax head: one probability per class.
        probs = y[0]
    elif y.ndim == 2 and y.shape[1] == n_classes - 1:
        # CORAL ordinal head: K-1 cumulative logits.
        probs = coral_probs_from_logits(y)[0]
    else:
        # Unknown head shape: normalize clipped scores, or fall back to uniform.
        s = np.maximum(y[0].astype(np.float64).ravel(), 0.0)
        probs = s / s.sum() if s.sum() > 0 else np.ones(n_classes) / n_classes

    winner = int(np.argmax(probs))
    return {
        "input_ok": True,
        "missing": [name for name in _features if name not in ratios],
        "z_scores": zmap,
        "probabilities": {CLASSES[i]: float(probs[i]) for i in range(n_classes)},
        "predicted_state": CLASSES[winner]
    }
104
 
105
def predict_from_json(payload):
    """Validate *payload* and delegate to predict_core, mapping every
    failure to an {"error": ...} dict (the response is always 200 JSON)."""
    # Gradio may wrap the dict in a single-element list; unwrap it first.
    if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
        payload = payload[0]
    if not isinstance(payload, dict):
        return {"error": "Invalid payload. Send a JSON object mapping feature->value."}
    try:
        result = predict_core(payload)
    except RuntimeError as e:
        # File/boot issues raised by lazy_init surface here.
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"{type(e).__name__}: {e}"}
    return result
118
 
119
  # ------------------ FastAPI + Gradio ------------------
 
 
 
 
 
120
  app = FastAPI()
121
  app.add_middleware(
122
  CORSMiddleware,
123
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
124
  )
125
 
126
@app.get("/health")
def health():
    """Liveness/readiness probe: reports file presence and lazy-load status."""
    report = {"ok": True}
    report["model_exists"] = _exists(MODEL_PATH)
    report["stats_exists"] = _exists(STATS_PATH)
    report["model_loaded"] = _model is not None
    report["stats_loaded"] = _stats is not None
    return report
135
+
136
# Plain REST endpoints for Excel (several routes share this handler).
# The redundant mid-file `from fastapi import Request` was removed:
# Request is already imported at the top of the file.

async def _handle_predict(req: Request):
    """Parse the request body and route it through predict_from_json.

    Accepts either a raw feature->value dict or a Gradio-style
    {"data": [ {...} ]} envelope.  Always returns a JSON-serializable
    dict; malformed JSON yields an {"error": ...} payload.
    """
    try:
        body = await req.json()
    except Exception:
        return {"error": "Invalid JSON"}
    # Unwrap {"data": [payload]} envelopes.
    if isinstance(body, dict) and "data" in body and isinstance(body["data"], list) and body["data"]:
        body = body["data"][0]
    return predict_from_json(body)
 
 
 
 
 
148
 
149
@app.post("/predict")
async def predict_main(req: Request):
    """Primary prediction endpoint; delegates to _handle_predict."""
    return await _handle_predict(req)
152
 
 
153
@app.post("/run/predict")
async def predict_compat1(req: Request):
    """Compatibility alias for older Gradio-style /run/predict callers."""
    return await _handle_predict(req)
 
158
  async def predict_compat2(req: Request):
159
  return await _handle_predict(req)
160
 
161
+ # UI at root (keeps your browser demo)
 
 
 
 
162
  ui = gr.Interface(
163
  fn=predict_from_json,
164
  inputs=gr.JSON(label="ratios JSON (dict of feature -> value)"),
 
168
  )
169
  app = gr.mount_gradio_app(app, ui, path="/")
170
 
171
+ # Print routes in logs for visibility
172
  for r in app.router.routes:
173
  try:
174
  print("ROUTE:", r.path)