Spaces:
Running
Running
| import streamlit as st | |
| from model_utils import PaperClassifier | |
# Browser-tab title and a centered page layout for the whole app.
st.set_page_config(layout="centered", page_title="Paper Classifier")

# Stylesheet for the prediction cards rendered by show_results():
#   .result-box - the card holding one predicted category
#   .prob-bar   - the translucent track behind the probability bar
#   .prob-fill  - the green filled portion carrying the percentage label
_RESULT_CARD_CSS = """
<style>
.result-box {
background: #4a5568; padding: 1rem; border-radius: 8px; color: white; margin-bottom: 0.5rem;
}
.prob-bar {
background: rgba(255,255,255,0.2); border-radius: 6px; height: 22px; margin-top: 4px; overflow: hidden;
}
.prob-fill {
background: #68d391; height: 100%; border-radius: 6px;
padding-left: 8px; font-size: 0.85rem; font-weight: 600;
color: #1a202c; display: flex; align-items: center;
}
</style>
"""

# Raw HTML has to be explicitly allowed for the <style> tag to be emitted.
st.markdown(_RESULT_CARD_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Build the paper classifier once and reuse it on later calls.

    Streamlit re-executes the script on every widget interaction, so an
    uncached loader would construct a fresh ``PaperClassifier`` on each
    click or keystroke. ``st.cache_resource`` keeps a single instance alive
    and hands the same object back on subsequent calls.

    Returns:
        The cached ``PaperClassifier`` instance.

    Raises:
        Whatever ``PaperClassifier()`` raises; a failed construction is not
        cached, so the next call retries.
    """
    # NOTE(review): the cached object is shared between sessions; this
    # assumes PaperClassifier.predict is safe to call concurrently - confirm.
    return PaperClassifier()
# Sample papers offered as one-click inputs. Each entry is a dict with a
# "title" and an "abstract"; the last one deliberately has an empty abstract.
EXAMPLES = [
    {"title": paper_title, "abstract": paper_abstract}
    for paper_title, paper_abstract in (
        (
            "Attention Is All You Need",
            "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.",
        ),
        (
            "A Survey on 3D Gaussian Splatting",
            "3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation.",
        ),
        (
            "Interior Point Differential Dynamic Programming",
            "",
        ),
    )
]
# Seed the session-state slots that back the title and abstract widgets.
# Only missing keys are written, so text already entered survives a rerun.
for _state_key in ("input_title", "input_abstract"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
def set_example(idx):
    """Load the example at position ``idx`` into the input fields.

    Writes the example's title and abstract into the ``input_title`` and
    ``input_abstract`` session-state keys, which are the keys the title and
    abstract widgets are bound to.

    Args:
        idx: Index into ``EXAMPLES``.
    """
    chosen = EXAMPLES[idx]
    st.session_state["input_title"] = chosen["title"]
    st.session_state["input_abstract"] = chosen["abstract"]
def show_results(results):
    """Render the predicted categories as cards with probability bars.

    Args:
        results: Sized iterable of mappings, each providing ``"tag"``,
            ``"name"`` and ``"probability"``. The probability is multiplied
            by 100 for display.
    """
    st.markdown(f"### Predicted {len(results)} categories")

    for entry in results:
        percent = entry["probability"] * 100
        # The fill is never narrower than 3% of the track.
        fill_width = max(percent, 3)
        # The markup is kept flush-left inside the string on purpose so the
        # emitted HTML carries no leading indentation.
        card_html = f"""
<div class="result-box">
<b>{entry['tag']}</b> - {entry['name']}
<div class="prob-bar">
<div class="prob-fill" style="width:{fill_width}%">{percent:.1f}%</div>
</div>
</div>"""
        st.markdown(card_html, unsafe_allow_html=True)
def main():
    """Draw the page and run a classification when the user asks for one.

    Flow: load the model (an error message is shown and rendering stops if
    that fails), draw the title and abstract inputs, draw one button per
    entry in ``EXAMPLES`` that fills the inputs, and finally - when the
    "Classify" button is pressed with a non-blank title - run the
    prediction and display the results.
    """
    st.title("Paper Classifier")
    st.write("Classify papers using fine-tuned SciBERT in one click!")

    try:
        clf = load_model()
    except Exception as err:
        st.error(f"Could not load model: {err}")
        return

    # Both widgets are keyed to session state so set_example() can fill them.
    title = st.text_input("**Title:**", key="input_title", placeholder="Paste paper title here")
    abstract = st.text_area("**Abstract**", key="input_abstract", placeholder="You can leave it empty", height=150)

    st.write("**Use our examples:**")
    columns = st.columns(len(EXAMPLES))
    for position, (column, example) in enumerate(zip(columns, EXAMPLES)):
        full_title = example["title"]
        # Button captions are cut to 20 characters plus an ellipsis.
        if len(full_title) > 20:
            caption = full_title[:20] + "..."
        else:
            caption = full_title
        with column:
            st.button(caption, key=f"ex_{position}", on_click=set_example, args=(position,), use_container_width=True)

    # Nothing more to render unless the user pressed "Classify" on this run.
    if not st.button("Classify", use_container_width=True):
        return

    if not (title and title.strip()):
        st.warning("Enter a title first.")
        return

    with st.spinner("Classifying..."):
        try:
            results = clf.predict(title=title, abstract=abstract)
        except Exception as err:
            st.error(f"Error: {err}")
            return

    show_results(results)
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()