Spaces:
Runtime error
Runtime error
Commit
·
8aad599
1
Parent(s):
9bcd196
prediction probability feature
Browse files
app.py
CHANGED
|
@@ -72,9 +72,16 @@ id2label = {0: "REJECTED", 1: "ACCEPTED"}
|
|
| 72 |
# when submit button clicked, run the model and get result
|
| 73 |
if st.button("Submit"):
|
| 74 |
with torch.no_grad():
|
| 75 |
-
|
|
|
|
| 76 |
|
| 77 |
-
predicted_class_id =
|
| 78 |
pred_label = id2label[predicted_class_id]
|
| 79 |
st.title("Predicted Patentability")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
st.write(pred_label)
|
|
|
|
| 72 |
# when submit button clicked, run the model and get result
|
| 73 |
if st.button("Submit"):
|
| 74 |
with torch.no_grad():
|
| 75 |
+
outputs = model(**inputs)
|
| 76 |
+
probability = torch.nn.functional.softmax(outputs.logits, dim=1)
|
| 77 |
|
| 78 |
+
predicted_class_id = probability.argmax().item()
|
| 79 |
pred_label = id2label[predicted_class_id]
|
| 80 |
st.title("Predicted Patentability")
|
| 81 |
+
if probability[0][0] > probability[0][1]:
|
| 82 |
+
st.write("Rejection Score:")
|
| 83 |
+
st.write(probability[0][0])
|
| 84 |
+
else:
|
| 85 |
+
st.write("Acceptance Score:")
|
| 86 |
+
st.write(probability[0][1])
|
| 87 |
st.write(pred_label)
|