Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import plotly.express as px | |
# Hugging Face Hub repository the tokenizer and classifier are loaded from.
MODEL_REPO: str = "ChocoLord/paper-classifier-model"
# Token budget handed to the tokenizer; longer inputs are truncated.
MAX_LENGTH: int = 512
# Cumulative-probability threshold used to pick the "selected" classes.
TOP_P: float = 0.95

# Page configuration is issued before any other Streamlit element is drawn.
st.set_page_config(
    page_title="Paper classifier",
    layout="wide",
)
st.title("Paper classifier")
@st.cache_resource
def load_artifacts():
    """Load the tokenizer and classification model from ``MODEL_REPO``.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the artifacts are
    created once and the same objects are handed back on later reruns
    instead of being re-instantiated each time.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been switched to
        evaluation mode via ``model.eval()``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
    model.eval()
    return tokenizer, model
# Module-level handles: predict() reads `tokenizer` and `model` as globals.
tokenizer, model = load_artifacts()
def predict(title: str, summary: str, *, top_p: float = TOP_P):
    """Classify a paper and return its ranked class probabilities.

    The title and summary are joined with a newline, stripped, tokenized
    (truncated to ``MAX_LENGTH`` tokens) and passed through the module-level
    ``model``; the logits are turned into probabilities with a softmax.

    Args:
        title: Paper title. A falsy value is treated as an empty string.
        summary: Paper summary. A falsy value is treated as an empty string.
        top_p: Cumulative-probability threshold for the selected classes.
            Defaults to the module constant ``TOP_P``.

    Returns:
        A ``(df, selected_df)`` tuple of DataFrames.

        ``df`` holds one row per class with columns ``class_name``,
        ``predicted_proba`` and ``cumsum`` (running total of
        ``predicted_proba``), sorted by descending probability.

        ``selected_df`` is a copy of the shortest leading slice of ``df``
        whose cumulative probability reaches ``top_p``; if the threshold is
        never reached it contains every row.
    """
    text = f"{title or ''}\n{summary or ''}".strip()

    # Only one sequence is encoded per call, so no padding is requested:
    # padding it out to MAX_LENGTH would just add pad tokens for the model
    # to process. Truncation still caps the input at MAX_LENGTH tokens.
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch size is 1, so take the single row of probabilities.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    labels = [model.config.id2label[i] for i in range(len(probs))]

    df = (
        pd.DataFrame({
            "class_name": labels,
            "predicted_proba": probs,
        })
        .sort_values("predicted_proba", ascending=False)
        .reset_index(drop=True)
    )
    df["cumsum"] = df["predicted_proba"].cumsum()

    # First position whose running total is >= top_p; the +1 below makes the
    # slice include that row. iloc clamps the slice if the index is len(df).
    cutoff_idx = int(np.searchsorted(df["cumsum"].values, top_p, side="left"))
    selected_df = df.iloc[:cutoff_idx + 1].copy()
    return df, selected_df
# --- Input widgets ---------------------------------------------------------
title = st.text_input("Title")
summary = st.text_area("Summary", height=250)

# --- Classification and result rendering -----------------------------------
if st.button("Classify", type="primary"):
    if title.strip() or summary.strip():
        df, selected_df = predict(title, summary)

        # Numbered list of the classes that made the cumulative cut-off.
        st.subheader("Selected classes")
        lines = []
        ranked = zip(selected_df["class_name"], selected_df["predicted_proba"])
        for rank, (class_name, proba) in enumerate(ranked, start=1):
            lines.append(f"{rank}. {class_name} — {proba:.4f}")
        st.text("\n".join(lines))

        # Bar chart over every class, in the sorted order of `df`.
        st.subheader("Probability bar chart")
        fig = px.bar(df, x="class_name", y="predicted_proba")
        fig.update_layout(
            xaxis_title="Class",
            yaxis_title="Predicted probability",
            xaxis_tickangle=-45,
        )
        st.plotly_chart(fig, use_container_width=True)

        # Full table is tucked away in a collapsible section.
        with st.expander("Full sorted predictions"):
            st.dataframe(df, use_container_width=True)
    else:
        # Both fields are blank (or whitespace only): nothing to classify.
        st.warning("Enter title and/or summary.")