Ptato commited on
Commit
30ec224
·
1 Parent(s): e6df219

more_stuff

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -59,7 +59,7 @@ if not st.session_state.filled:
59
  logits = predictions.logits
60
  sigmoid = torch.nn.Sigmoid()
61
  probs = sigmoid(logits.squeeze().cpu())
62
- predictions = np.zeros(probs.shape)
63
  predictions[np.where(probs >= 0.5)] = 1
64
  predicted_labels = [st.session_state.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
65
  log = []
@@ -106,7 +106,7 @@ if not st.session_state.filled:
106
  else:
107
  log = [0] * 6
108
  log[1] = text
109
- if max(predictions) == 0:
110
  log[0] = 0
111
  log[2] = ("NO TOXICITY")
112
  log[3] = (f"{100 - round(probs[0].item() * 100, 1)}%")
@@ -116,7 +116,7 @@ if not st.session_state.filled:
116
  log[0] = 1
117
  _max = 0
118
  _max2 = 2
119
- for i in range(1, len(predictions)):
120
  if probs[i].item() > probs[_max].item():
121
  _max = i
122
  if i > 2 and probs[i].item() > probs[_max2].item():
 
59
  logits = predictions.logits
60
  sigmoid = torch.nn.Sigmoid()
61
  probs = sigmoid(logits.squeeze().cpu())
62
+ predicts = np.zeros(probs.shape)
63
  predictions[np.where(probs >= 0.5)] = 1
64
  predicted_labels = [st.session_state.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
65
  log = []
 
106
  else:
107
  log = [0] * 6
108
  log[1] = text
109
+ if max(predicts) == 0:
110
  log[0] = 0
111
  log[2] = ("NO TOXICITY")
112
  log[3] = (f"{100 - round(probs[0].item() * 100, 1)}%")
 
116
  log[0] = 1
117
  _max = 0
118
  _max2 = 2
119
+ for i in range(1, len(predicts)):
120
  if probs[i].item() > probs[_max].item():
121
  _max = i
122
  if i > 2 and probs[i].item() > probs[_max2].item():