yasd_2026_articles / src /streamlit_app.py
IgorLarin's picture
Update src/streamlit_app.py
b07c106 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# from model_debug import debug_model_path
import requests
import os
import random
import arxiv
# Page heading shown at the top of the app.
st.title("Article Category Detector")

# Hugging Face access token taken from the environment (None when unset).
# NOTE: this is a credential; avoid rendering it in the page.
HF_TOKEN = os.getenv("HF_TOKEN")
# Hub repository the sequence-classification model is loaded from.
MY_MODEL_ID = "IgorLarin/yasd_2026_articles"
# Checkpoint the tokenizer is loaded from.
BASE_MODEL = "allenai/scibert_scivocab_uncased"
@st.cache_resource
def load_model():
    """Load the classifier and its tokenizer.

    The classifier comes from ``MY_MODEL_ID`` and the tokenizer from
    ``BASE_MODEL``; both requests pass ``HF_TOKEN``. The function is wrapped
    in ``st.cache_resource`` so the loaded objects are cached by Streamlit.

    Returns:
        tuple: ``(model, tokenizer)`` with the model switched to eval mode.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        MY_MODEL_ID,
        token=HF_TOKEN,
    )
    classifier.eval()
    text_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
    return classifier, text_tokenizer
def build_text(title: str, abstract: str = "") -> str:
    """Compose the tagged model input from a title and an optional abstract.

    Both parts are stripped of surrounding whitespace, and ``None`` is treated
    like an empty string. The ``[ABSTRACT]`` section is emitted only when the
    stripped abstract is non-empty.

    Args:
        title: Article title.
        abstract: Article abstract; may be empty.

    Returns:
        ``"[TITLE] <title>"`` or ``"[TITLE] <title> [ABSTRACT] <abstract>"``.
    """
    heading = (title or "").strip()
    body = (abstract or "").strip()
    if not body:
        return f"[TITLE] {heading}"
    return f"[TITLE] {heading} [ABSTRACT] {body}"
def predict_top95(title: str, abstract: str = "", threshold: float = 0.95):
    """Return the top labels whose cumulative probability reaches ``threshold``.

    The title/abstract pair is formatted with ``build_text``, tokenized
    (truncated to 256 tokens) and passed through the module-level ``model``.
    Labels are ranked by softmax probability, highest first, and collected
    until their running total is at least ``threshold``.

    Args:
        title: Article title.
        abstract: Article abstract; may be empty.
        threshold: Cumulative probability at which collection stops.

    Returns:
        list[tuple[str, float]]: ``(label, probability)`` pairs in descending
        order of probability.
    """
    encoded = tokenizer(
        build_text(title, abstract),
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )

    with torch.no_grad():
        # Only the first (and only) row of the batch is of interest.
        scores = torch.softmax(model(**encoded).logits, dim=-1)[0]

    ranking = [(id2label[idx], prob) for idx, prob in enumerate(scores.tolist())]
    ranking.sort(key=lambda pair: pair[1], reverse=True)

    selected = []
    cumulative = 0.0
    for name, prob in ranking:
        selected.append((name, prob))
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected
def get_random_arxiv_paper(category=None, max_results=50):
    """Pick a random paper from the most recently submitted arXiv results.

    Args:
        category: Optional arXiv category code; when given, the search is
            restricted to it with a ``cat:<category>`` query.
        max_results: How many of the newest results to sample from.

    Returns:
        A randomly chosen search result, or ``None`` when the search returns
        no papers.
    """
    # Without a category, fall back to the broad one-letter query "a"
    # (an arbitrary/empty query can't be used here).
    query = f"cat:{category}" if category else "a"

    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    # Pause between requests and retry a few times on failure.
    client = arxiv.Client(delay_seconds=3.0, num_retries=5)

    papers = list(client.results(search))
    return random.choice(papers) if papers else None
def has_content():
    """Tell whether the form currently holds anything worth clearing.

    The text fields live in ``st.session_state`` and may be ``""``, ``None``
    or missing altogether, so they are checked by truthiness through
    ``.get()`` instead of ``is not None`` attribute access. A stored
    prediction counts as content whenever it is not ``None``.

    Returns:
        bool: ``True`` if any text field is non-empty or a result is stored.
    """
    state = st.session_state
    text_keys = ("title", "abstract", "url", "primary_category")
    if any(state.get(key) for key in text_keys):
        return True
    return state.get("result") is not None
# Fetch the model/tokenizer pair and keep the index -> label mapping at hand.
model, tokenizer = load_model()
id2label = model.config.id2label

# Seed the session state so every key read later in the script exists.
# "abstract" is included because has_content() reads it before the abstract
# text area (the widget that owns that key) has been created.
for _key in ("title", "abstract", "url", "primary_category"):
    if _key not in st.session_state:
        st.session_state[_key] = ""
if "result" not in st.session_state:
    st.session_state.result = None
# Wide column for the "load" action, narrow one for "Clear".
col1, col2 = st.columns([6, 1])

with col1:
    if st.button("Load random arXiv article"):
        try:
            paper = get_random_arxiv_paper()
            if paper is None:
                # get_random_arxiv_paper() returns None for an empty result set.
                st.warning("No articles were found on arXiv. Please try again.")
            else:
                # Run the prediction first so that a failure leaves the form
                # untouched instead of half-filled.
                prediction = predict_top95(paper.title, paper.summary)
                st.session_state.title = paper.title
                st.session_state.abstract = paper.summary
                st.session_state.url = paper.entry_id
                st.session_state.primary_category = paper.primary_category
                st.session_state.result = prediction
        except Exception as e:
            # UI boundary: report the problem in the page rather than crash.
            st.error(f"Failed to load article: {e}")

with col2:
    # The button is only offered when there is something to reset.
    if has_content() and st.button("Clear"):
        for key in ("title", "abstract", "url", "primary_category", "result"):
            st.session_state[key] = None
# Show where the loaded article came from (URL and primary category).
url = st.session_state.get("url")
if url:
    category = st.session_state["primary_category"]
    st.caption(f"{url} ({category})")
# Both inputs are bound to session-state keys, so the buttons above can
# pre-fill or reset them.
title = st.text_area(
    "Title",
    key="title",
    # Streamlit documents a 68 px minimum for text areas; smaller values are
    # rejected by some releases.
    height=68,
    placeholder="Enter title here...",
)
abstract = st.text_area(
    "Abstract",
    key="abstract",
    height=150,
    placeholder="Enter abstract here...",
)

if st.button("Detect"):
    # "Clear" stores None under the widget keys, so the value may be None
    # here; normalise it before stripping.
    if not (title or "").strip():
        st.warning("Please enter a title of the article.")
    else:
        st.session_state.result = predict_top95(title, abstract)
# Render the stored prediction: best label on the left, full list on the right.
result = st.session_state.result
if result is not None:
    top_col, all_col = st.columns(2)

    with top_col:
        st.subheader("Top prediction")
        top_label, top_score = result[0]
        st.write(top_label, f"({top_score*100:.1f}%)")

    with all_col:
        st.subheader("Top 95%")
        for label, score in result:
            st.write(f"{label}: {score*100:.1f}%")