jpmendes commited on
Commit
1faf52d
·
verified ·
1 Parent(s): 2ff098a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -32,16 +32,23 @@ def respond(message, history, system_message=None, max_tokens=None, temperature=
32
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
35
- probs = F.softmax(logits, dim=-1).squeeze().tolist() # probabilidades
36
- pred_index = torch.argmax(logits, dim=-1).item()
37
- pred_label = labels[pred_index]
38
 
39
- # Monta mensagem de saída detalhada
40
- output = f"Predicted label: {pred_label}\n\n"
41
- output += "Probabilities:\n"
42
- for i, p in enumerate(probs):
43
- output += f"{labels[i]}: {p:.4f}\n"
44
- output += f"\nLogits: {logits.squeeze().tolist()}"
 
 
 
 
 
 
 
 
 
 
45
  return output
46
 
47
  # -------------------------------
 
32
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
 
 
 
35
 
36
+ # Número de classes
37
+ num_labels = model.config.num_labels
38
+
39
+ if num_labels == 1:
40
+ prob = torch.sigmoid(logits).item()
41
+ output = f"Predicted logit: {logits.item():.4f}\nProbability (sigmoid): {prob:.4f}"
42
+ else:
43
+ probs = F.softmax(logits, dim=-1).squeeze().tolist()
44
+ pred_index = torch.argmax(logits, dim=-1).item()
45
+ pred_label = labels[pred_index]
46
+
47
+ output = f"Predicted label: {pred_label}\nProbabilities:\n"
48
+ for i, p in enumerate(probs):
49
+ output += f"{labels[i]}: {p:.4f}\n"
50
+ output += f"\nLogits: {logits.squeeze().tolist()}"
51
+
52
  return output
53
 
54
  # -------------------------------