File size: 3,393 Bytes
ada4299
e0b0f3b
ada4299
e0b0f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import html

import streamlit as st

from model_utils import PaperClassifier

# Kept as the first Streamlit call: older Streamlit versions require
# set_page_config to run before any other st.* command.
st.set_page_config(page_title="Paper Classifier", layout="centered")

# Page-wide CSS for the prediction cards built in show_results:
# .result-box is the card, .prob-bar the bar track, .prob-fill its filled part.
_PAGE_CSS = """
<style>
    .result-box {
        background: #4a5568; padding: 1rem; border-radius: 8px; color: white; margin-bottom: 0.5rem;
    }
    .prob-bar {
        background: rgba(255,255,255,0.2); border-radius: 6px; height: 22px; margin-top: 4px; overflow: hidden;
    }
    .prob-fill {
        background: #68d391; height: 100%; border-radius: 6px;
        padding-left: 8px; font-size: 0.85rem; font-weight: 600;
        color: #1a202c; display: flex; align-items: center;
    }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)


@st.cache_resource(show_spinner="Loading model...")
def load_model():
    """Return the shared ``PaperClassifier`` instance.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same object is reused on later reruns instead of being reloaded on every
    interaction. The decorator shows "Loading model..." while the first
    construction runs. Any exception raised by ``PaperClassifier()`` propagates
    to the caller.
    """
    classifier = PaperClassifier()
    return classifier


# Sample papers offered as one-click inputs in main(). Each entry maps
# "title" / "abstract" to text; the last one has an empty abstract, matching
# the abstract input being optional.
EXAMPLES = [
    dict(
        title="Attention Is All You Need",
        abstract="We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.",
    ),
    dict(
        title="A Survey on 3D Gaussian Splatting",
        abstract="3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation.",
    ),
    dict(
        title="Interior Point Differential Dynamic Programming",
        abstract="",
    ),
]

# Seed the state keys backing the title/abstract widgets with empty strings,
# without clobbering values already set on earlier reruns (e.g. by set_example).
_SESSION_DEFAULTS = {"input_title": "", "input_abstract": ""}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default


def set_example(idx):
    """Copy example number ``idx`` from ``EXAMPLES`` into the input widgets' state.

    Used as the example buttons' ``on_click`` callback in ``main``. Streamlit
    runs callbacks before the rerun, which is why writing the widget-keyed
    ``input_title`` / ``input_abstract`` state here is allowed.

    Raises:
        IndexError: if ``idx`` is out of range for ``EXAMPLES``; this happens
            before any state is written.
    """
    chosen = EXAMPLES[idx]
    st.session_state.input_title = chosen["title"]
    st.session_state.input_abstract = chosen["abstract"]


def show_results(results):
    """Render classifier predictions as labelled probability bars.

    Args:
        results: Sized sequence of dicts with ``"tag"``, ``"name"`` and
            ``"probability"`` keys, as returned by ``clf.predict`` in ``main``.
            ``probability`` is multiplied by 100 for display, so it is
            presumably a fraction in [0, 1] -- TODO confirm against
            PaperClassifier.

    An empty ``results`` shows an informational message instead of an empty
    "Predicted 0 ..." heading.
    """
    if not results:
        st.info("No categories predicted.")
        return

    count = len(results)
    noun = "category" if count == 1 else "categories"
    st.markdown(f"### Predicted {count} {noun}")
    for r in results:
        pct = r["probability"] * 100
        # Keep a visible sliver for tiny probabilities, but never overflow the bar.
        width = min(max(pct, 3), 100)
        # This markdown is rendered with unsafe_allow_html, so escape the
        # model-provided labels before interpolating them into HTML.
        tag = html.escape(str(r["tag"]))
        name = html.escape(str(r["name"]))
        st.markdown(f"""
        <div class="result-box">
            <b>{tag}</b> - {name}
            <div class="prob-bar">
                <div class="prob-fill" style="width:{width:.1f}%">{pct:.1f}%</div>
            </div>
        </div>""", unsafe_allow_html=True)


def main():
    """Render the Paper Classifier page.

    Draws the header, loads the cached model, and shows the title/abstract
    inputs, one shortcut button per entry in ``EXAMPLES``, and the Classify
    button. The run ends early (after showing a message) when:

    - the model fails to load (``st.error``);
    - Classify is pressed with a blank title (``st.warning``);
    - prediction raises (``st.error``).

    Otherwise the predictions are passed to ``show_results``.
    """
    st.title("Paper Classifier")
    st.write("Classify papers using fine-tuned SciBERT in one click!")

    try:
        classifier = load_model()
    except Exception as exc:
        st.error(f"Could not load model: {exc}")
        return

    paper_title = st.text_input("**Title:**", key="input_title", placeholder="Paste paper title here")
    paper_abstract = st.text_area("**Abstract**", key="input_abstract", placeholder="You can leave it empty", height=150)

    st.write("**Use our examples:**")
    columns = st.columns(len(EXAMPLES))
    for idx, example in enumerate(EXAMPLES):
        heading = example["title"]
        # Long titles are truncated to 20 characters so the buttons fit side by side.
        label = heading if len(heading) <= 20 else heading[:20] + "..."
        with columns[idx]:
            st.button(label, key=f"ex_{idx}", on_click=set_example, args=(idx,), use_container_width=True)

    if not st.button("Classify", use_container_width=True):
        return

    if not (paper_title and paper_title.strip()):
        st.warning("Enter a title first.")
        return

    with st.spinner("Classifying..."):
        try:
            predictions = classifier.predict(title=paper_title, abstract=paper_abstract)
        except Exception as exc:
            st.error(f"Error: {exc}")
            return

    show_results(predictions)


# Entry point: render the page when this file is executed as a script
# (e.g. via `streamlit run`), while keeping it importable without side effects from main().
if __name__ == "__main__":
    main()