Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -72,8 +72,12 @@ def get_prediction(inputs):
|
|
| 72 |
outputs = model(**inputs)
|
| 73 |
logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
|
| 74 |
pred_prob = torch.softmax(logits, dim=1)
|
| 75 |
-
pred = torch.argmax(pred_prob, dim=1)
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
# vectorizer= nltk_u.vectorizer()
|
| 79 |
# vectorizer.fit(train_data.text)
|
|
|
|
| 72 |
outputs = model(**inputs)
|
| 73 |
logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
|
| 74 |
pred_prob = torch.softmax(logits, dim=1)
|
| 75 |
+
pred = torch.argmax(pred_prob, dim=1).item()
|
| 76 |
+
if pred in class_names:
|
| 77 |
+
return class_names[pred]
|
| 78 |
+
else:
|
| 79 |
+
print(f"Warning: Prediction index {pred} not found in class_names.")
|
| 80 |
+
return "Unknown" # 或者其他默认的响应
|
| 81 |
|
| 82 |
# vectorizer= nltk_u.vectorizer()
|
| 83 |
# vectorizer.fit(train_data.text)
|