File size: 2,409 Bytes
ae436d1
 
 
 
 
 
 
cc2e31c
 
 
ae436d1
 
 
 
 
 
 
 
 
cc2e31c
ae436d1
cc2e31c
ae436d1
 
cc2e31c
ae436d1
 
 
 
 
 
 
 
 
 
 
 
 
cc2e31c
 
ae436d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2e31c
ae436d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import plotly.express as px

# Model identifier handed to from_pretrained() for both the tokenizer and
# the sequence-classification model.
MODEL_REPO = "ChocoLord/paper-classifier-model"
# Passed to the tokenizer as max_length (together with truncation=True).
MAX_LENGTH = 512
# Cumulative-probability threshold used to pick the "selected classes".
TOP_P = 0.95

# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* call of a script run — confirm before moving it below st.title.
st.set_page_config(page_title="Paper classifier", layout="wide")
st.title("Paper classifier")

@st.cache_resource
def load_artifacts():
    """Load the tokenizer and classifier identified by ``MODEL_REPO``.

    The function is wrapped in ``st.cache_resource``, and the model is
    switched to evaluation mode before being returned.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_REPO)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    clf = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).eval()
    return tok, clf

# Module-level handles used by predict(); load_artifacts is decorated with
# st.cache_resource.
tokenizer, model = load_artifacts()

def predict(title: str, summary: str):
    """Classify a paper and return per-class probabilities.

    The title and summary are joined with a newline, stripped, tokenized
    (truncated to ``MAX_LENGTH`` tokens) and run through the module-level
    ``model`` without gradient tracking.

    Args:
        title: Paper title; a falsy value is treated as an empty string.
        summary: Paper summary; a falsy value is treated as an empty string.

    Returns:
        A ``(df, selected_df)`` tuple. ``df`` holds one row per class with
        the columns ``class_name``, ``predicted_proba`` and ``cumsum``,
        sorted by descending probability. ``selected_df`` is a copy of the
        leading rows of ``df`` up to and including the first row whose
        cumulative probability reaches ``TOP_P``.
    """
    text = f"{title or ''}\n{summary or ''}".strip()

    # Only one sequence is encoded, so no padding is needed. Padding to
    # MAX_LENGTH would make every forward pass process masked filler tokens.
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits
        # Batch of one: take row 0 of the softmax output.
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # id2label is keyed by class index, hence the positional lookup.
    labels = [model.config.id2label[i] for i in range(len(probs))]

    df = pd.DataFrame({
        "class_name": labels,
        "predicted_proba": probs,
    }).sort_values("predicted_proba", ascending=False).reset_index(drop=True)

    df["cumsum"] = df["predicted_proba"].cumsum()
    # First position whose cumulative probability is >= TOP_P; the +1 below
    # keeps that row in the selection.
    cutoff_idx = int(np.searchsorted(df["cumsum"].values, TOP_P, side="left"))
    selected_df = df.iloc[:cutoff_idx + 1].copy()

    return df, selected_df

# Page body: collect the inputs, then render results when "Classify" is pressed.
title = st.text_input("Title")
summary = st.text_area("Summary", height=250)

if st.button("Classify", type="primary"):
    if not title.strip() and not summary.strip():
        st.warning("Enter title and/or summary.")
    else:
        df, selected_df = predict(title, summary)

        st.subheader("Selected classes")
        # One numbered line per selected class. The ": " separator keeps the
        # class name from running straight into its probability.
        lines = [
            f"{rank}. {class_name}: {proba:.4f}"
            for rank, (class_name, proba) in enumerate(
                zip(selected_df["class_name"], selected_df["predicted_proba"]),
                start=1,
            )
        ]
        st.text("\n".join(lines))

        st.subheader("Probability bar chart")
        fig = px.bar(
            df,
            x="class_name",
            y="predicted_proba",
        )
        fig.update_layout(
            xaxis_title="Class",
            yaxis_title="Predicted probability",
            xaxis_tickangle=-45,
        )
        st.plotly_chart(fig, use_container_width=True)

        with st.expander("Full sorted predictions"):
            st.dataframe(df, use_container_width=True)