import streamlit as st
from model_utils import PaperClassifier

# Basic page chrome: browser-tab title and a centered content column.
st.set_page_config(page_title="Paper Classifier", layout="centered")

# Inline CSS used by show_results(): a dark result card, a translucent
# track (.prob-bar), and a green fill (.prob-fill) whose width encodes
# the predicted probability. Injected once via unsafe_allow_html.
st.markdown("""
<style>
.result-box {
background: #4a5568; padding: 1rem; border-radius: 8px; color: white; margin-bottom: 0.5rem;
}
.prob-bar {
background: rgba(255,255,255,0.2); border-radius: 6px; height: 22px; margin-top: 4px; overflow: hidden;
}
.prob-fill {
background: #68d391; height: 100%; border-radius: 6px;
padding-left: 8px; font-size: 0.85rem; font-weight: 600;
color: #1a202c; display: flex; align-items: center;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource(show_spinner="Loading model...")
def load_model():
    """Build the classifier once; st.cache_resource reuses it across reruns."""
    classifier = PaperClassifier()
    return classifier
# Canned demo papers for the one-click example buttons. Each entry mirrors
# the two user inputs; the last one shows that an empty abstract is allowed.
EXAMPLES = [
    {"title": "Attention Is All You Need",
     "abstract": "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely."},
    {"title": "A Survey on 3D Gaussian Splatting",
     "abstract": "3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation."},
    {"title": "Interior Point Differential Dynamic Programming",
     "abstract": ""},
]
# Seed the widget-backed session keys so both inputs start out empty
# on the first run without clobbering user edits on later reruns.
for _state_key in ("input_title", "input_abstract"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
def set_example(idx):
    """Button callback: copy example *idx* into the title/abstract widgets."""
    example = EXAMPLES[idx]
    st.session_state.input_title = example["title"]
    st.session_state.input_abstract = example["abstract"]
def show_results(results):
    """Render each prediction as an HTML card with a probability bar.

    *results* is a list of dicts with "tag", "name", and "probability"
    (a 0..1 float) keys, as produced by PaperClassifier.predict.
    """
    st.markdown(f"### Predicted {len(results)} categories")
    for entry in results:
        percent = entry["probability"] * 100
        # Floor the bar width at 3% so tiny probabilities stay visible.
        fill_width = max(percent, 3)
        st.markdown(f"""
<div class="result-box">
<b>{entry['tag']}</b> - {entry['name']}
<div class="prob-bar">
<div class="prob-fill" style="width:{fill_width}%">{percent:.1f}%</div>
</div>
</div>""", unsafe_allow_html=True)
def main():
    """Page entry point: gather inputs, run the classifier, render results."""
    st.title("Paper Classifier")
    st.write("Classify papers using fine-tuned SciBERT in one click!")

    # Model load happens first so a broken checkpoint fails fast and visibly.
    try:
        clf = load_model()
    except Exception as err:
        st.error(f"Could not load model: {err}")
        return

    title = st.text_input("**Title:**", key="input_title", placeholder="Paste paper title here")
    abstract = st.text_area("**Abstract**", key="input_abstract", placeholder="You can leave it empty", height=150)

    # One button per canned example; the on_click callback fills the inputs
    # before the rerun, so the widgets above pick up the new values.
    st.write("**Use our examples:**")
    for idx, column in enumerate(st.columns(len(EXAMPLES))):
        with column:
            full_title = EXAMPLES[idx]["title"]
            label = full_title[:20] + "..." if len(full_title) > 20 else full_title
            st.button(label, key=f"ex_{idx}", on_click=set_example, args=(idx,), use_container_width=True)

    if not st.button("Classify", use_container_width=True):
        return
    if not title or not title.strip():
        st.warning("Enter a title first.")
        return

    with st.spinner("Classifying..."):
        try:
            results = clf.predict(title=title, abstract=abstract)
        except Exception as err:
            st.error(f"Error: {err}")
            return
    show_results(results)
# Script entry guard: runs the app when executed by Streamlit, but keeps
# the module importable (e.g. for tests) without side effects.
if __name__ == "__main__":
    main()