Mulah commited on
Commit
7f64ad2
·
1 Parent(s): 50c4964

Handle 1D r_probs shape (batch=1 case)

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -133,7 +133,8 @@ def predict(sequence: str):
133
  pred = pred_texts[0]
134
  r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
135
  r_probs_list = r_probs.cpu().tolist() if torch.is_tensor(r_probs) else list(r_probs)
136
- p_pos = float(r_probs_list[0][1])
 
137
  return pred, format_reliability(r), f"{p_pos:.4f}"
138
 
139
 
 
133
  pred = pred_texts[0]
134
  r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
135
  r_probs_list = r_probs.cpu().tolist() if torch.is_tensor(r_probs) else list(r_probs)
136
+ first = r_probs_list[0]
137
+ p_pos = float(first[1]) if isinstance(first, (list, tuple)) else float(r_probs_list[1])
138
  return pred, format_reliability(r), f"{p_pos:.4f}"
139
 
140