Mulah commited on
Commit
23f8086
·
1 Parent(s): 7f64ad2

Robust extraction of positive-class prob from r_probs (any shape)

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -132,9 +132,12 @@ def predict(sequence: str):
132
 
133
  pred = pred_texts[0]
134
  r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
135
- r_probs_list = r_probs.cpu().tolist() if torch.is_tensor(r_probs) else list(r_probs)
136
- first = r_probs_list[0]
137
- p_pos = float(first[1]) if isinstance(first, (list, tuple)) else float(r_probs_list[1])
 
 
 
138
  return pred, format_reliability(r), f"{p_pos:.4f}"
139
 
140
 
 
132
 
133
  pred = pred_texts[0]
134
  r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
135
+ if torch.is_tensor(r_probs):
136
+ flat = r_probs.flatten().cpu().tolist()
137
+ else:
138
+ flat = [float(x) for sub in r_probs for x in (sub if isinstance(sub, (list, tuple)) else [sub])]
139
+ print(f"[debug] r_probs raw flat = {flat}") # remove after verifying
140
+ p_pos = float(flat[-1])
141
  return pred, format_reliability(r), f"{p_pos:.4f}"
142
 
143