File size: 4,056 Bytes
2afb296
26e52f6
2afb296
 
f58df3f
32903ec
26e52f6
 
 
 
 
f58df3f
 
 
 
 
26e52f6
 
 
 
 
 
 
32903ec
 
b9bfaab
32903ec
 
 
26e52f6
32903ec
26e52f6
2afb296
 
 
 
 
 
 
 
 
 
 
26e52f6
 
 
f58df3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26e52f6
32903ec
 
 
 
 
f58df3f
 
 
26e52f6
 
 
 
 
 
32903ec
 
 
 
 
 
 
 
2afb296
 
26e52f6
2afb296
26e52f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2afb296
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
import os
import time

import arxiv
import joblib
import streamlit as st
from transformers import pipeline

# Configure the browser tab; must be the first Streamlit call in the script.
st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="📄")

# Seed the widget-backed session keys so the Title/Abstract inputs start empty
# and survive reruns. setdefault only writes a key that is not already present.
for _key in ("auto_title", "auto_abstract"):
    st.session_state.setdefault(_key, "")


@st.cache_resource
def load_pipeline():
    """Load the fine-tuned text-classification pipeline (cached per process).

    The model weights are read from the ``model`` directory that sits next to
    this script. ``top_k=None`` makes the pipeline return a score for every
    label instead of just the best one.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(script_dir, "model")
    return pipeline("text-classification", model=model_dir, top_k=None)


@st.cache_resource
def load_gatekeeper():
    """Load the out-of-distribution "gatekeeper" model (cached per process).

    ``ood_detector.pkl`` is expected one directory above this script.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(script_dir)
    return joblib.load(os.path.join(parent_dir, "ood_detector.pkl"))


# Both loaders are @st.cache_resource-backed, so these calls are only
# expensive on the very first run of the script.
classifier = load_pipeline()
gatekeeper = load_gatekeeper()

# Static sidebar describing the model; rendered on every rerun.
with st.sidebar:
    st.subheader("About the Model")
    st.markdown(
        """
        - **Base model:** `distilbert-base-uncased`
        - **Fine-tuning:** Balanced ArXiv dataset  (ccdv/arxiv-classification)
        - **Task:** Classification
        """
    )
    st.info("The model is cached after the first load for fast inference on subsequent requests.")

# Main page header.
st.title("ArXiv Paper Classifier")
st.write("Enter a paper's title and abstract to predict its subject category.")

with st.expander("Load from link", expanded=True):
    arxiv_url = st.text_input("ArXiv URL", placeholder="https://arxiv.org/abs/1706.03762")
    if st.button("Fetch paper data"):
        if "arxiv.org/abs/" not in arxiv_url:
            st.warning("Please enter a valid ArXiv URL containing 'arxiv.org/abs/'.")
        else:
            # Drop any query string / fragment before extracting the ID, then
            # take the trailing path segment and strip an optional version
            # suffix ("2103.00020v2" -> "2103.00020").
            path = arxiv_url.split("?")[0].split("#")[0].rstrip("/")
            paper_id = path.split("/")[-1].split("v")[0]
            with st.spinner("Fetching from ArXiv..."):
                try:
                    # arxiv >= 2.0 deprecates Search.results(); query through
                    # a Client instead.
                    search = arxiv.Search(id_list=[paper_id])
                    paper = next(arxiv.Client().results(search))
                    st.session_state["auto_title"] = paper.title
                    st.session_state["auto_abstract"] = paper.summary
                    st.success(f"Loaded: {paper.title}")
                except StopIteration:
                    # next() on an empty result set; give a clearer message
                    # than the generic branch below would.
                    st.error(f"No ArXiv paper found for id '{paper_id}'.")
                except Exception as e:
                    st.error(f"Failed to fetch paper: {e}")

# key= binds these widgets to the session-state entries that the
# "Fetch paper data" button fills in, so fetched metadata appears here.
st.text_input("Title", key="auto_title")
st.text_area("Abstract", height=200, key="auto_abstract")

# Wide "Classify" button beside a narrow gatekeeper-bypass toggle.
button_col, toggle_col = st.columns([3, 1])
classify_clicked = button_col.button("Classify", use_container_width=True)
bypass_gatekeeper = toggle_col.toggle("⚡ Bypass Gatekeeper")

if classify_clicked:
    # Pull the (possibly auto-filled) inputs straight from session state.
    paper_title = st.session_state["auto_title"].strip()
    paper_abstract = st.session_state["auto_abstract"].strip()

    if not (paper_title or paper_abstract):
        st.error("Please provide at least a title or an abstract.")
        st.stop()

    # Join title and abstract into one passage; fall back to the abstract
    # alone when no title was given.
    text = f"{paper_title}. {paper_abstract}" if paper_title else paper_abstract

    # Out-of-distribution screen: label 0 means "not a scientific paper".
    # Short-circuit keeps predict() from running when the bypass is on.
    if not bypass_gatekeeper and gatekeeper.predict([text])[0] == 0:
        st.warning(
            "This text is NOT a scientific paper. Please enter a valid scientific abstract."
        )
        st.stop()

    with st.spinner("Classifying paper"):
        t0 = time.time()
        predictions = classifier(text)[0]
        t1 = time.time()

    # Highest-confidence labels first.
    predictions = sorted(predictions, key=lambda p: p["score"], reverse=True)

    # Keep the smallest prefix of labels whose scores sum to at least 95%.
    top_predictions = []
    cumulative_score = 0.0
    for entry in predictions:
        top_predictions.append(entry)
        cumulative_score += entry["score"]
        if cumulative_score >= 0.95:
            break

    st.subheader("Results")
    for entry in top_predictions:
        st.write(f"**{entry['label']}** — {entry['score'] * 100:.1f}%")
        st.progress(entry["score"])

    st.caption(f"Inference time: {t1 - t0:.3f} seconds")

    # Offer the shown predictions as a downloadable JSON document.
    serializable = [
        {"label": entry["label"], "score": round(entry["score"], 6)}
        for entry in top_predictions
    ]
    st.download_button(
        label="Download Results JSON",
        data=json.dumps(serializable, indent=2),
        file_name="predictions.json",
        mime="application/json",
    )