Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,7 @@ def predict_author(text: str):
|
|
| 21 |
with torch.no_grad():
|
| 22 |
outputs = model(**inputs)
|
| 23 |
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 24 |
-
|
| 25 |
confidence = probs[0][pred].item()
|
| 26 |
predicted_author = label_map[pred]
|
| 27 |
return predicted_author, round(confidence * 100, 2)
|
|
@@ -45,4 +45,3 @@ async def whatsapp_reply(Body: str = Form(...)):
|
|
| 45 |
async def predict(text: str):
|
| 46 |
author, confidence = predict_author(text)
|
| 47 |
return {"prediction": author, "confidence": confidence}
|
| 48 |
-
|
|
|
|
| 21 |
with torch.no_grad():
|
| 22 |
outputs = model(**inputs)
|
| 23 |
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 24 |
+
pred = torch.argmax(probs, dim=1).item()
|
| 25 |
confidence = probs[0][pred].item()
|
| 26 |
predicted_author = label_map[pred]
|
| 27 |
return predicted_author, round(confidence * 100, 2)
|
|
|
|
| 45 |
async def predict(text: str):
|
| 46 |
author, confidence = predict_author(text)
|
| 47 |
return {"prediction": author, "confidence": confidence}
|
|
|