"""Streamlit front end for the fine-tuned SciBERT paper classifier.

The model itself is wrapped by ``model_utils.PaperClassifier``; this module
only builds the page: a title/abstract form, example buttons that prefill the
form, and a results list of predicted categories with probabilities.
"""

import html

import streamlit as st

from model_utils import PaperClassifier

st.set_page_config(page_title="Paper Classifier", layout="centered")

# NOTE(review): the payload of this call is empty in the visible source;
# presumably page-level CSS was meant to go here — confirm against the repo.
st.markdown(""" """, unsafe_allow_html=True)

# Example buttons longer than this many characters are truncated with "...".
_LABEL_MAX = 20


@st.cache_resource(show_spinner="Loading model...")
def load_model() -> PaperClassifier:
    """Construct the classifier; ``st.cache_resource`` reuses it across reruns."""
    return PaperClassifier()


EXAMPLES = [
    {"title": "Attention Is All You Need", "abstract": "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely."},
    {"title": "A Survey on 3D Gaussian Splatting", "abstract": "3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation."},
    {"title": "Interior Point Differential Dynamic Programming", "abstract": ""},
]

# Widget values live in session state under these keys, so the example
# callbacks can overwrite them; seed them once per session.
for _state_key in ("input_title", "input_abstract"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""


def set_example(idx: int) -> None:
    """Button callback: copy ``EXAMPLES[idx]`` into the title/abstract inputs.

    Runs as an ``on_click`` callback, i.e. before the script reruns, so the
    widgets bound to these session-state keys render the new values.
    """
    example = EXAMPLES[idx]
    st.session_state.input_title = example["title"]
    st.session_state.input_abstract = example["abstract"]


def show_results(results) -> None:
    """Render predicted categories.

    Args:
        results: sequence of mappings, each with ``"tag"``, ``"name"`` and a
            ``"probability"`` in [0, 1] (assumed from the ``* 100`` percent
            formatting — TODO confirm against ``PaperClassifier.predict``).
    """
    if not results:
        st.info("The model did not predict any categories.")
        return

    noun = "category" if len(results) == 1 else "categories"
    st.markdown(f"### Predicted {len(results)} {noun}")
    for r in results:
        pct = r["probability"] * 100
        # The block is rendered with unsafe_allow_html, so escape model-derived text.
        tag = html.escape(str(r["tag"]))
        name = html.escape(str(r["name"]))
        st.markdown(f"\n{tag} - {name}\n{pct:.1f}%\n", unsafe_allow_html=True)


def main() -> None:
    """Render the page and classify the entered paper when requested."""
    st.title("Paper Classifier")
    st.write("Classify papers using fine-tuned SciBERT in one click!")

    try:
        clf = load_model()
    except Exception as err:  # UI boundary: surface any load failure to the user
        st.error(f"Could not load model: {err}")
        return

    title = st.text_input(
        "**Title:**", key="input_title", placeholder="Paste paper title here"
    )
    abstract = st.text_area(
        "**Abstract**",
        key="input_abstract",
        placeholder="You can leave it empty",
        height=150,
    )

    st.write("**Use our examples:**")
    cols = st.columns(len(EXAMPLES))
    for i, (col, ex) in enumerate(zip(cols, EXAMPLES)):
        ex_title = ex["title"]
        label = ex_title[:_LABEL_MAX] + "..." if len(ex_title) > _LABEL_MAX else ex_title
        with col:
            st.button(
                label,
                key=f"ex_{i}",
                on_click=set_example,
                args=(i,),
                use_container_width=True,
            )

    if not st.button("Classify", use_container_width=True):
        return

    # `or ""` guards against a None widget value; whitespace-only titles are rejected.
    if not (title or "").strip():
        st.warning("Enter a title first.")
        return

    with st.spinner("Classifying..."):
        try:
            results = clf.predict(title=title, abstract=abstract)
        except Exception as err:  # UI boundary: show inference errors instead of crashing
            st.error(f"Error: {err}")
            return

    # Render outside the spinner so it disappears before results appear.
    show_results(results)


if __name__ == "__main__":
    main()