Update app.py
Browse files
app.py
CHANGED
|
@@ -25,14 +25,14 @@ def get_top95(labels, probs):
|
|
| 25 |
sorted_labels = [labels[i.item()] for i in sorted_indices]
|
| 26 |
|
| 27 |
cumulative = torch.cumsum(sorted_probs, dim=0)
|
| 28 |
-
cutoff = torch.where(cumulative >= 0.
|
| 29 |
last_idx = cutoff[0].item() + 1 if len(cutoff) > 0 else len(sorted_probs)
|
| 30 |
|
| 31 |
return list(zip(sorted_labels[:last_idx], sorted_probs[:last_idx].tolist()))
|
| 32 |
|
| 33 |
# UI
|
| 34 |
st.set_page_config(page_title="Article Topic Classifier")
|
| 35 |
-
st.title("
|
| 36 |
st.markdown("Enter the **title** and optionally **abstract** of the article.")
|
| 37 |
|
| 38 |
title = st.text_input("Title", placeholder="e.g. Neural Networks for Quantum Physics")
|
|
@@ -52,5 +52,6 @@ if st.button("Classify"):
|
|
| 52 |
|
| 53 |
top_labels = get_top95(id2label, probs)
|
| 54 |
|
|
|
|
| 55 |
for label, prob in top_labels:
|
| 56 |
st.markdown(f"- **{categories[label]} ({label})**: {prob * 100:.1f}%")
|
|
|
|
| 25 |
sorted_labels = [labels[i.item()] for i in sorted_indices]
|
| 26 |
|
| 27 |
cumulative = torch.cumsum(sorted_probs, dim=0)
|
| 28 |
+
cutoff = torch.where(cumulative >= 0.95)[0]
|
| 29 |
last_idx = cutoff[0].item() + 1 if len(cutoff) > 0 else len(sorted_probs)
|
| 30 |
|
| 31 |
return list(zip(sorted_labels[:last_idx], sorted_probs[:last_idx].tolist()))
|
| 32 |
|
| 33 |
# UI
|
| 34 |
st.set_page_config(page_title="Article Topic Classifier")
|
| 35 |
+
st.title("Article Topic Classifier")
|
| 36 |
st.markdown("Enter the **title** and optionally **abstract** of the article.")
|
| 37 |
|
| 38 |
title = st.text_input("Title", placeholder="e.g. Neural Networks for Quantum Physics")
|
|
|
|
| 52 |
|
| 53 |
top_labels = get_top95(id2label, probs)
|
| 54 |
|
| 55 |
+
st.markdown("Top 95% topics:")
|
| 56 |
for label, prob in top_labels:
|
| 57 |
st.markdown(f"- **{categories[label]} ({label})**: {prob * 100:.1f}%")
|