Spaces:
Running
Running
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| # from model_debug import debug_model_path | |
| import requests | |
| import os | |
| import random | |
| import arxiv | |
# Page header rendered at the top of the app.
st.title("Article Category Detector")

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Checkpoint of the fine-tuned classifier.
MY_MODEL_ID = "IgorLarin/yasd_2026_articles"
# Base checkpoint; its tokenizer is the one loaded by load_model().
BASE_MODEL = "allenai/scibert_scivocab_uncased"
# NOTE: do not render HF_TOKEN in the UI (e.g. st.caption) - it is a secret.
@st.cache_resource
def load_model():
    """Load the classifier and its tokenizer, cached across script reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would rebuild the model each time a button is pressed.
    ``st.cache_resource`` keeps one shared instance instead.

    Returns:
        tuple: ``(model, tokenizer)``; the model is switched to eval mode.
    """
    model = AutoModelForSequenceClassification.from_pretrained(MY_MODEL_ID, token=HF_TOKEN)
    # The tokenizer comes from the base checkpoint, not the fine-tuned repo.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
    model.eval()
    return model, tokenizer
def build_text(title: str, abstract: str = "") -> str:
    """Compose the tagged model input from a title and an optional abstract.

    Both fields are stripped of surrounding whitespace, and a falsy value
    (``None`` or ``""``) is treated as empty. The ``[ABSTRACT]`` segment is
    included only when the abstract is non-empty after stripping.
    """
    head = (title or "").strip()
    body = (abstract or "").strip()
    if not body:
        return f"[TITLE] {head}"
    return f"[TITLE] {head} [ABSTRACT] {body}"
def predict_top95(title: str, abstract: str = "", threshold: float = 0.95):
    """Return the top labels whose cumulative probability reaches ``threshold``.

    The title/abstract pair is formatted with ``build_text``, tokenized with
    truncation at 256 tokens and scored by the module-level ``model``. Labels
    are taken in descending probability order until their running sum is at
    least ``threshold``; the label that crosses the threshold is included.

    Returns:
        list: ``(label, probability)`` tuples, most probable first.
    """
    encoded = tokenizer(
        build_text(title, abstract),
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # Softmax over the class dimension; [0] selects the single input sequence.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()

    ranked = [(id2label[idx], p) for idx, p in enumerate(probabilities)]
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    selected = []
    cumulative = 0.0
    for pair in ranked:
        selected.append(pair)
        cumulative += pair[1]
        if cumulative >= threshold:
            break
    return selected
def get_random_arxiv_paper(category=None, max_results=50):
    """Pick a random paper from the most recently submitted arXiv results.

    Args:
        category: Optional arXiv category code; when given, the search is
            restricted to it with a ``cat:`` query.
        max_results: Number of newest submissions to fetch and draw from.

    Returns:
        One randomly chosen search result, or ``None`` when the search
        returned nothing.
    """
    # Without a category, fall back to the generic term "a" as a broad query.
    query = f"cat:{category}" if category else "a"

    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    client = arxiv.Client(delay_seconds=3.0, num_retries=5)

    papers = list(client.results(search))
    return random.choice(papers) if papers else None
def has_content():
    """Return True when any article field or a stored prediction is present.

    Values are checked by truthiness rather than ``is not None`` because the
    text fields are seeded with empty strings; an empty form must count as
    having no content. ``.get`` is used so that a key which has not been
    created yet is treated as empty instead of raising.
    """
    keys = ("title", "abstract", "url", "primary_category", "result")
    return any(st.session_state.get(key) for key in keys)
# Model, tokenizer and label mapping used by predict_top95().
model, tokenizer = load_model()
id2label = model.config.id2label

# Seed every session-state key the UI reads, including "abstract", so that
# attribute-style access never hits a missing key.
_SESSION_DEFAULTS = {
    "title": "",
    "abstract": "",
    "url": "",
    "primary_category": "",
    "result": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
col1, col2 = st.columns([6, 1])

with col1:
    if st.button("Load random arXiv article"):
        # UI boundary: any failure (network, parsing, inference) is reported
        # to the user instead of crashing the app.
        try:
            paper = get_random_arxiv_paper()
            if paper is None:
                # The helper returns None when the search yields no results.
                st.warning("No arXiv articles were returned. Please try again.")
            else:
                st.session_state.title = paper.title
                st.session_state.abstract = paper.summary
                st.session_state.url = paper.entry_id
                st.session_state.primary_category = paper.primary_category
                st.session_state.result = predict_top95(paper.title, paper.summary)
        except Exception as e:
            st.error(f"Failed to load article: {e}")

with col2:
    if has_content():
        if st.button("Clear"):
            # Reset text fields to empty strings (not None) so the text areas
            # keep returning str and later .strip() calls stay safe.
            st.session_state.title = ""
            st.session_state.abstract = ""
            st.session_state.url = ""
            st.session_state.primary_category = ""
            st.session_state.result = None
# Show the source link and primary category when an article URL is stored.
if st.session_state.get("url"):
    url, category = st.session_state["url"], st.session_state["primary_category"]
    st.caption(f"{url} ({category})")

# Both text areas are bound to session state through their keys, so values
# written to st.session_state before this point pre-fill them.
# NOTE(review): newer Streamlit releases appear to reject text_area heights
# below 68px - confirm against the pinned Streamlit version.
title = st.text_area(
    "Title",
    placeholder="Enter title here...",
    height=30,
    key="title",
)
abstract = st.text_area(
    "Abstract",
    placeholder="Enter abstract here...",
    height=150,
    key="abstract",
)
if st.button("Detect"):
    # A text area can yield None instead of a string when its session-state
    # value was set to None, so normalise before stripping.
    if not (title or "").strip():
        st.warning("Please enter a title of the article.")
    else:
        st.session_state.result = predict_top95(title, abstract)
# Render the stored prediction, if any: best label on the left, and on the
# right every label that was kept by the cumulative-probability cut-off.
if st.session_state.result is not None:
    result = st.session_state.result
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Top prediction")
        best_label, best_prob = result[0]
        st.write(best_label, f"({best_prob*100:.1f}%)")

    with col2:
        st.subheader("Top 95%")
        for label, score in result:
            st.write(f"{label}: {score*100:.1f}%")