"""Streamlit app that predicts arXiv categories for an article title/abstract."""

import os
import random
from typing import List, Optional, Tuple

import arxiv
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from model_debug import debug_model_path
import requests

st.title("Article Category Detector")

HF_TOKEN = os.environ.get("HF_TOKEN")
MY_MODEL_ID = "IgorLarin/yasd_2026_articles"
BASE_MODEL = "allenai/scibert_scivocab_uncased"

# Session-state keys that back text shown in the UI; all are reset to "".
_TEXT_STATE_KEYS = ("title", "abstract", "url", "primary_category")


@st.cache_resource
def load_model():
    """Load the classifier and its tokenizer once per server process.

    The classifier weights come from ``MY_MODEL_ID`` while the tokenizer is
    loaded from ``BASE_MODEL``. The model is switched to eval mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        MY_MODEL_ID, token=HF_TOKEN
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
    model.eval()
    return model, tokenizer


def build_text(title: str, abstract: str = "") -> str:
    """Build the model input string from a title and an optional abstract.

    ``None`` and whitespace-only values are treated as empty. The abstract
    section is omitted entirely when the abstract is empty.
    """
    title = (title or "").strip()
    abstract = (abstract or "").strip()
    if abstract:
        return f"[TITLE] {title} [ABSTRACT] {abstract}"
    return f"[TITLE] {title}"


def predict_top95(
    title: str, abstract: str = "", threshold: float = 0.95
) -> List[Tuple[str, float]]:
    """Return the most probable labels whose cumulative probability reaches
    ``threshold``.

    Uses the module-level ``model``, ``tokenizer`` and ``id2label``.

    Args:
        title: Article title.
        abstract: Article abstract (may be empty).
        threshold: Cumulative probability at which to stop adding labels.

    Returns:
        ``(label, probability)`` pairs sorted by descending probability. If
        the threshold is never reached, all labels are returned.
    """
    text = build_text(title, abstract)
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single input, so take the first (only) row and convert to plain floats
    # once instead of indexing the tensor element by element.
    probs = torch.softmax(logits, dim=-1)[0].tolist()

    ranked = sorted(
        ((id2label[idx], prob) for idx, prob in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )

    result = []
    total = 0.0
    for label, prob in ranked:
        result.append((label, prob))
        total += prob
        if total >= threshold:
            break
    return result


def get_random_arxiv_paper(category=None, max_results=50):
    """Pick a random paper among the most recently submitted arXiv results.

    Args:
        category: Optional arXiv category (queried as ``cat:<category>``).
            When omitted a broad one-letter query is used, because the API
            needs some query string.
        max_results: Number of most recent results to sample from.

    Returns:
        A randomly chosen result, or ``None`` when the search yields nothing.
    """
    query = f"cat:{category}" if category else "a"

    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    client = arxiv.Client(delay_seconds=3.0, num_retries=5)

    papers = list(client.results(search))
    if not papers:
        return None
    return random.choice(papers)


def has_content() -> bool:
    """Return True when any input field or a prediction is currently set.

    Empty strings and ``None`` both count as "no content", and missing keys
    are tolerated, so this is safe to call before the widgets are created.
    """
    keys = _TEXT_STATE_KEYS + ("result",)
    return any(st.session_state.get(key) for key in keys)


model, tokenizer = load_model()
id2label = model.config.id2label

# Initialise every key the UI reads so later attribute access never fails.
for _key in _TEXT_STATE_KEYS:
    if _key not in st.session_state:
        st.session_state[_key] = ""
if "result" not in st.session_state:
    st.session_state.result = None

# The buttons are handled before the text areas are instantiated; widget-keyed
# session state may only be assigned before its widget is created in a run.
col1, col2 = st.columns([6, 1])
with col1:
    if st.button("Load random arXiv article"):
        # Broad catch is deliberate: this is the UI boundary and any network
        # or model failure should be reported to the user, not crash the app.
        try:
            paper = get_random_arxiv_paper()
            if paper is None:
                st.warning("No articles were returned by arXiv. Please try again.")
            else:
                st.session_state.title = paper.title
                st.session_state.abstract = paper.summary
                st.session_state.url = paper.entry_id
                st.session_state.primary_category = paper.primary_category
                st.session_state.result = predict_top95(paper.title, paper.summary)
        except Exception as e:
            st.error(f"Failed to load article: {e}")

with col2:
    if has_content():
        if st.button("Clear"):
            # Reset text fields to "" (not None) so the text areas keep
            # returning strings.
            for _key in _TEXT_STATE_KEYS:
                st.session_state[_key] = ""
            st.session_state.result = None

if st.session_state.get("url"):
    url = st.session_state["url"]
    category = st.session_state["primary_category"]
    st.caption(f"{url} ({category})")

# 68 px is the smallest height st.text_area accepts in recent Streamlit
# releases; smaller values raise an error there.
title = st.text_area(
    "Title",
    key="title",
    height=68,
    placeholder="Enter title here...",
)
abstract = st.text_area(
    "Abstract",
    key="abstract",
    height=150,
    placeholder="Enter abstract here...",
)

if st.button("Detect"):
    # Guard against None as well as blank input.
    if not (title or "").strip():
        st.warning("Please enter a title of the article.")
    else:
        st.session_state.result = predict_top95(title, abstract)

# Truthiness check also skips an empty prediction list, which would otherwise
# fail on result[0].
if st.session_state.result:
    result = st.session_state.result
    top_col, all_col = st.columns(2)
    with top_col:
        st.subheader("Top prediction")
        top_label, top_prob = result[0]
        st.write(top_label, f"({top_prob*100:.1f}%)")
    with all_col:
        st.subheader("Top 95%")
        for label, score in result:
            st.write(f"{label}: {score*100:.1f}%")