COCODEDE04 commited on
Commit
acbe7ed
·
verified ·
1 Parent(s): 6218a6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -41
app.py CHANGED
@@ -45,28 +45,21 @@ def coerce_float(val: Any) -> float:
45
  if s == "":
46
  raise ValueError("empty")
47
 
48
- # remove spaces
49
  s = s.replace(" ", "")
50
-
51
  has_dot = "." in s
52
  has_comma = "," in s
53
 
54
  if has_dot and has_comma:
55
- # Decide which is decimal separator by last occurrence
56
  last_dot = s.rfind(".")
57
  last_comma = s.rfind(",")
58
  if last_comma > last_dot:
59
- # decimal is comma, thousands is dot
60
  s = s.replace(".", "")
61
  s = s.replace(",", ".")
62
  else:
63
- # decimal is dot, thousands is comma
64
  s = s.replace(",", "")
65
  elif has_comma and not has_dot:
66
- # likely decimal is comma
67
  s = s.replace(",", ".")
68
- # dots only or pure digits -> leave as is
69
-
70
  return float(s)
71
 
72
 
@@ -80,14 +73,52 @@ def _z(val: Any, mean: float, sd: float) -> float:
80
  return (v - mean) / sd
81
 
82
 
83
- def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
84
- """(N, K-1) logits -> (N, K) probabilities for CORAL ordinal output."""
85
- logits = tf.convert_to_tensor(logits_np, dtype=tf.float32)
86
- sig = tf.math.sigmoid(logits) # (N, K-1)
87
- left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
88
- right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
89
- probs = tf.clip_by_value(left - right, 1e-12, 1.0)
90
- return probs.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  # ------------- FastAPI app ----------------
@@ -132,18 +163,12 @@ async def echo(req: Request):
132
  async def predict(req: Request):
133
  """
134
  Body: a single JSON dict mapping feature -> numeric value.
135
- Example:
136
- {
137
- "autosuf_oper": 1.0,
138
- "cov_improductiva": 0.9,
139
- ...
140
- }
141
  """
142
  payload = await req.json()
143
  if not isinstance(payload, dict):
144
  return {"error": "Expected a JSON object mapping feature -> value."}
145
 
146
- # Build z-scores in strict model order
147
  z = []
148
  z_detail = {}
149
  missing = []
@@ -154,35 +179,32 @@ async def predict(req: Request):
154
  zf = _z(payload[f], mean, sd)
155
  else:
156
  missing.append(f)
157
- zf = _z(0.0, mean, sd) # treat missing as 0 input
158
  z.append(zf)
159
  z_detail[f] = zf
160
 
161
  X = np.array([z], dtype=np.float32)
162
  raw = model.predict(X, verbose=0)
163
-
164
- # ---------------- DEBUG INFO ----------------
165
  raw_shape = tuple(raw.shape)
166
- # --------------------------------------------
167
 
168
- # Decode: CORAL vs Softmax
169
  probs = None
170
  decode_mode = "auto"
171
  try:
172
  if FORCE_CORAL:
173
- decode_mode = "forced_coral"
174
- probs = coral_probs_from_logits(raw)[0]
175
  else:
176
  if raw.ndim == 2 and raw.shape[1] == (len(CLASSES) - 1):
177
- decode_mode = "auto_coral"
178
- probs = coral_probs_from_logits(raw)[0]
179
  else:
180
  decode_mode = "auto_softmax_or_logits"
181
  probs = raw[0]
182
  s = float(np.sum(probs))
183
- if s > 0: # defensive normalize
184
  probs = probs / s
185
- except Exception as e:
186
  decode_mode = "fallback_raw_norm"
187
  probs = raw[0]
188
  s = float(np.sum(probs))
@@ -191,26 +213,23 @@ async def predict(req: Request):
191
 
192
  pred_idx = int(np.argmax(probs))
193
 
 
194
  resp = {
195
  "input_ok": (len(missing) == 0),
196
  "missing": missing,
197
  "z_scores": z_detail,
198
- "probabilities": {
199
- CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))
200
- },
201
  "predicted_state": CLASSES[pred_idx],
202
  }
203
 
204
- # Include debug fields so we can see shape & decode path
205
  if RETURN_DEBUG:
206
  resp["debug"] = {
207
  "raw_shape": raw_shape,
208
  "decode_mode": decode_mode,
209
  "raw_first_row": [
210
  float(x)
211
- for x in (
212
- raw[0].tolist() if raw.ndim >= 2 else [float(raw)]
213
- )
214
  ],
215
  }
216
 
 
45
  if s == "":
46
  raise ValueError("empty")
47
 
 
48
  s = s.replace(" ", "")
 
49
  has_dot = "." in s
50
  has_comma = "," in s
51
 
52
  if has_dot and has_comma:
 
53
  last_dot = s.rfind(".")
54
  last_comma = s.rfind(",")
55
  if last_comma > last_dot:
 
56
  s = s.replace(".", "")
57
  s = s.replace(",", ".")
58
  else:
 
59
  s = s.replace(",", "")
60
  elif has_comma and not has_dot:
 
61
  s = s.replace(",", ".")
62
+ # dots only or digits -> leave
 
63
  return float(s)
64
 
65
 
 
73
  return (v - mean) / sd
74
 
75
 
76
+ # ---------- CORAL utilities ----------
77
+ def enforce_nonincreasing(sig_vec: np.ndarray) -> np.ndarray:
78
+ """
79
+ Given a 1D array of cumulative probs s (should be non-increasing for CORAL),
80
+ enforce s[0] >= s[1] >= ... >= s[K-1] using a simple PAV algorithm.
81
+ """
82
+ s = sig_vec.astype(float).copy()
83
+ n = len(s)
84
+ blocks = [[i] for i in range(n)]
85
+ vals = s.tolist()
86
+
87
+ i = 0
88
+ while i < len(vals) - 1:
89
+ if vals[i] < vals[i + 1]: # violation: should be non-increasing
90
+ merged_idx = blocks[i] + blocks[i + 1]
91
+ avg = (
92
+ (vals[i] * len(blocks[i]) + vals[i + 1] * len(blocks[i + 1]))
93
+ / (len(blocks[i]) + len(blocks[i + 1]))
94
+ )
95
+ blocks[i] = merged_idx
96
+ vals[i] = avg
97
+ del blocks[i + 1]
98
+ del vals[i + 1]
99
+ if i > 0:
100
+ i -= 1
101
+ else:
102
+ i += 1
103
+
104
+ out = np.zeros(n, dtype=float)
105
+ for v, idxs in zip(vals, blocks):
106
+ for j in idxs:
107
+ out[j] = v
108
+ return np.clip(out, 1e-12, 1 - 1e-12)
109
+
110
+
111
+ def coral_probs_from_logits_monotone(logits_np: np.ndarray) -> np.ndarray:
112
+ """
113
+ CORAL decoding with monotonicity enforcement so class probs are valid (sum=1, nonnegative).
114
+ """
115
+ sig = 1.0 / (1.0 + np.exp(-logits_np)) # sigmoid
116
+ sig_m = enforce_nonincreasing(sig[0]) # enforce order
117
+ left = np.concatenate([np.array([1.0], dtype=float), sig_m])
118
+ right = np.concatenate([sig_m, np.array([0.0], dtype=float)])
119
+ probs = np.clip(left - right, 1e-12, 1.0)
120
+ probs = probs / probs.sum() # normalize
121
+ return probs
122
 
123
 
124
  # ------------- FastAPI app ----------------
 
163
  async def predict(req: Request):
164
  """
165
  Body: a single JSON dict mapping feature -> numeric value.
 
 
 
 
 
 
166
  """
167
  payload = await req.json()
168
  if not isinstance(payload, dict):
169
  return {"error": "Expected a JSON object mapping feature -> value."}
170
 
171
+ # --- Build z-scores in strict model order ---
172
  z = []
173
  z_detail = {}
174
  missing = []
 
179
  zf = _z(payload[f], mean, sd)
180
  else:
181
  missing.append(f)
182
+ zf = _z(0.0, mean, sd)
183
  z.append(zf)
184
  z_detail[f] = zf
185
 
186
  X = np.array([z], dtype=np.float32)
187
  raw = model.predict(X, verbose=0)
 
 
188
  raw_shape = tuple(raw.shape)
 
189
 
190
+ # --- Decode ---
191
  probs = None
192
  decode_mode = "auto"
193
  try:
194
  if FORCE_CORAL:
195
+ decode_mode = "forced_coral_monotone"
196
+ probs = coral_probs_from_logits_monotone(raw)
197
  else:
198
  if raw.ndim == 2 and raw.shape[1] == (len(CLASSES) - 1):
199
+ decode_mode = "auto_coral_monotone"
200
+ probs = coral_probs_from_logits_monotone(raw)
201
  else:
202
  decode_mode = "auto_softmax_or_logits"
203
  probs = raw[0]
204
  s = float(np.sum(probs))
205
+ if s > 0:
206
  probs = probs / s
207
+ except Exception:
208
  decode_mode = "fallback_raw_norm"
209
  probs = raw[0]
210
  s = float(np.sum(probs))
 
213
 
214
  pred_idx = int(np.argmax(probs))
215
 
216
+ # --- Response ---
217
  resp = {
218
  "input_ok": (len(missing) == 0),
219
  "missing": missing,
220
  "z_scores": z_detail,
221
+ "probabilities": {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))},
 
 
222
  "predicted_state": CLASSES[pred_idx],
223
  }
224
 
225
+ # --- Debug block ---
226
  if RETURN_DEBUG:
227
  resp["debug"] = {
228
  "raw_shape": raw_shape,
229
  "decode_mode": decode_mode,
230
  "raw_first_row": [
231
  float(x)
232
+ for x in (raw[0].tolist() if raw.ndim >= 2 else [float(raw)])
 
 
233
  ],
234
  }
235