"""Streamlit app: classify ArXiv papers by subject category.

Loads a fine-tuned text-classification pipeline plus an out-of-distribution
"gatekeeper" detector, lets the user type a title/abstract or fetch them from
an ArXiv URL, and displays the top predicted categories (cumulative >= 95%).
"""

import json
import os
import re
import time

import arxiv
import joblib
import streamlit as st
from transformers import pipeline

st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="📄")

# Session-state keys back the title/abstract widgets so that the
# "Fetch paper data" button can pre-populate them.
if "auto_title" not in st.session_state:
    st.session_state["auto_title"] = ""
if "auto_abstract" not in st.session_state:
    st.session_state["auto_abstract"] = ""


@st.cache_resource
def load_pipeline():
    """Load the fine-tuned classifier from the sibling ./model directory (cached)."""
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model")
    # top_k=None returns a score for every label, not just the argmax.
    return pipeline("text-classification", model=model_path, top_k=None)


@st.cache_resource
def load_gatekeeper():
    """Load the out-of-distribution detector pickle from the parent directory (cached)."""
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    return joblib.load(os.path.join(base_dir, "ood_detector.pkl"))


classifier = load_pipeline()
gatekeeper = load_gatekeeper()

with st.sidebar:
    st.subheader("About the Model")
    st.markdown(
        """
- **Base model:** `distilbert-base-uncased`
- **Fine-tuning:** Balanced ArXiv dataset (ccdv/arxiv-classification)
- **Task:** Classification
"""
    )
    st.info("The model is cached after the first load for fast inference on subsequent requests.")

st.title("ArXiv Paper Classifier")
st.write("Enter a paper's title and abstract to predict its subject category.")

with st.expander("Load from link", expanded=True):
    arxiv_url = st.text_input("ArXiv URL", placeholder="https://arxiv.org/abs/1706.03762")
    if st.button("Fetch paper data"):
        if "arxiv.org/abs/" not in arxiv_url:
            st.warning("Please enter a valid ArXiv URL containing 'arxiv.org/abs/'.")
        else:
            # Take everything after "arxiv.org/abs/" so old-style IDs with a
            # category prefix (e.g. "hep-th/9901001") survive, then strip a
            # trailing version suffix ("v2").  The previous
            # split("/")[-1].split("v")[0] dropped the category prefix and
            # truncated at the first "v" anywhere in the segment.
            paper_id = arxiv_url.rstrip("/").split("arxiv.org/abs/")[-1]
            paper_id = re.sub(r"v\d+$", "", paper_id)
            with st.spinner("Fetching from ArXiv..."):
                try:
                    search = arxiv.Search(id_list=[paper_id])
                    # Client().results() is the non-deprecated replacement
                    # for Search.results() in arxiv >= 1.4.
                    paper = next(arxiv.Client().results(search), None)
                    if paper is None:
                        st.error(f"Failed to fetch paper: no result for ID '{paper_id}'.")
                    else:
                        st.session_state["auto_title"] = paper.title
                        st.session_state["auto_abstract"] = paper.summary
                        st.success(f"Loaded: {paper.title}")
                except Exception as e:  # network / API failures
                    st.error(f"Failed to fetch paper: {e}")

st.text_input("Title", key="auto_title")
st.text_area("Abstract", height=200, key="auto_abstract")

col_btn, col_bypass = st.columns([3, 1])
classify_clicked = col_btn.button("Classify", use_container_width=True)
bypass_gatekeeper = col_bypass.toggle("⚡ Bypass Gatekeeper")

if classify_clicked:
    title = st.session_state["auto_title"]
    abstract = st.session_state["auto_abstract"]
    if not title.strip() and not abstract.strip():
        st.error("Please provide at least a title or an abstract.")
        st.stop()

    text = f"{title.strip()}. {abstract.strip()}" if title.strip() else abstract.strip()

    # Gatekeeper: reject text the OOD detector labels 0 (non-scientific),
    # unless the user explicitly bypasses it.
    if not bypass_gatekeeper:
        is_science = gatekeeper.predict([text])[0]
        if is_science == 0:
            st.warning(
                "This text is NOT a scientific paper. Please enter a valid scientific abstract."
            )
            st.stop()

    with st.spinner("Classifying paper"):
        start_time = time.time()
        predictions = classifier(text)[0]
        end_time = time.time()

    # Keep the smallest set of top labels whose scores sum to >= 0.95.
    predictions.sort(key=lambda x: x["score"], reverse=True)
    top_predictions = []
    cumulative = 0.0
    for pred in predictions:
        top_predictions.append(pred)
        cumulative += pred["score"]
        if cumulative >= 0.95:
            break

    st.subheader("Results")
    for pred in top_predictions:
        label = pred["label"]
        score = pred["score"]
        st.write(f"**{label}** — {score * 100:.1f}%")
        st.progress(score)
    st.caption(f"Inference time: {end_time - start_time:.3f} seconds")

    results_json = json.dumps(
        [{"label": p["label"], "score": round(p["score"], 6)} for p in top_predictions],
        indent=2,
    )
    st.download_button(
        label="Download Results JSON",
        data=results_json,
        file_name="predictions.json",
        mime="application/json",
    )