# Captured from a Hugging Face Space file view (Space status: Sleeping; file size: 2,409 bytes).
# Revisions shown in the viewer's blame gutter: ae436d1, cc2e31c.
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import plotly.express as px
# Identifier that load_artifacts() passes to both from_pretrained() calls
# (tokenizer and sequence-classification model).
MODEL_REPO = "ChocoLord/paper-classifier-model"
# Maximum token count handed to the tokenizer as max_length in predict().
MAX_LENGTH = 512
# Cumulative-probability threshold predict() uses to cut the sorted class
# list down to the "selected" classes.
TOP_P = 0.95
# Page-level setup, issued before any other Streamlit rendering call.
# NOTE(review): some Streamlit releases require set_page_config to be the
# first Streamlit command in the script — keep this ordering.
st.set_page_config(page_title="Paper classifier", layout="wide")
st.title("Paper classifier")
@st.cache_resource
def load_artifacts():
    """Load the tokenizer and classifier identified by ``MODEL_REPO``.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    hand back the already-loaded objects instead of rebuilding them each
    time the script is re-executed.

    Returns:
        A ``(tokenizer, model)`` tuple. ``eval()`` has been called on the
        model before it is returned.
    """
    paper_tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    # eval() returns the module itself, so the call can be chained onto the
    # loader; it puts the network into evaluation (non-training) mode.
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).eval()
    return paper_tokenizer, classifier
# Module-level tokenizer/model pair that predict() reads as globals.
# load_artifacts() is decorated with st.cache_resource, so this call is
# expected to be served from Streamlit's cache after the first execution.
tokenizer, model = load_artifacts()
def predict(title: str, summary: str):
    """Classify a paper from its title and summary.

    The two fields are joined with a newline, tokenized (truncated to at
    most ``MAX_LENGTH`` tokens) and passed through the module-level
    ``model``. Softmax turns the logits into per-class probabilities.

    Args:
        title: Paper title. ``None`` or ``""`` is treated as empty text.
        summary: Paper summary. ``None`` or ``""`` is treated as empty text.

    Returns:
        A ``(df, selected_df)`` tuple of pandas DataFrames.

        ``df`` holds one row per class with the columns ``class_name``,
        ``predicted_proba`` and ``cumsum`` (running total of
        ``predicted_proba``), ordered by descending probability with a
        fresh 0-based index.

        ``selected_df`` is a copy of the leading rows of ``df`` up to and
        including the first row whose ``cumsum`` reaches ``TOP_P``.
    """
    text = f"{title or ''}\n{summary or ''}".strip()

    # Only one sequence is encoded per call, so there is nothing to align it
    # with: padding to MAX_LENGTH would just make the model process hundreds
    # of masked pad tokens on every request. Truncation still caps the length.
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch size is 1, so row 0 is the probability vector for this paper.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    id2label = model.config.id2label
    labels = [id2label[class_id] for class_id in range(probs.shape[0])]

    df = (
        pd.DataFrame({"class_name": labels, "predicted_proba": probs})
        .sort_values("predicted_proba", ascending=False)
        .reset_index(drop=True)
    )
    df["cumsum"] = df["predicted_proba"].cumsum()

    # side="left" yields the position of the first cumulative value that is
    # >= TOP_P; the +1 below keeps that row in the selection. If rounding
    # keeps the total under TOP_P the index equals len(df) and the slice
    # simply returns every row.
    cutoff_idx = int(np.searchsorted(df["cumsum"].values, TOP_P, side="left"))
    selected_df = df.iloc[:cutoff_idx + 1].copy()
    return df, selected_df
# --- Input widgets ---------------------------------------------------------
title = st.text_input("Title")
summary = st.text_area("Summary", height=250)
classify_clicked = st.button("Classify", type="primary")

# --- Classification and result rendering ------------------------------------
if classify_clicked:
    if title.strip() or summary.strip():
        ranking, top_classes = predict(title, summary)

        # Numbered text list of the classes chosen by predict(); the number
        # is the row's index label plus one.
        st.subheader("Selected classes")
        st.text(
            "\n".join(
                f"{position + 1}. {name} — {proba:.4f}"
                for position, name, proba in zip(
                    top_classes.index,
                    top_classes["class_name"],
                    top_classes["predicted_proba"],
                )
            )
        )

        # Bar chart over every class, in the order returned by predict().
        st.subheader("Probability bar chart")
        fig = px.bar(ranking, x="class_name", y="predicted_proba")
        layout_options = {
            "xaxis_title": "Class",
            "yaxis_title": "Predicted probability",
            "xaxis_tickangle": -45,
        }
        fig.update_layout(**layout_options)
        st.plotly_chart(fig, use_container_width=True)

        # Full table, collapsed by default inside an expander.
        with st.expander("Full sorted predictions"):
            st.dataframe(ranking, use_container_width=True)
    else:
        # Both fields are blank (or whitespace only): nothing to classify.
        st.warning("Enter title and/or summary.")