Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from pathlib import Path | |
# The fine-tuned classifier weights are read from a "model" directory that
# sits beside the folder containing this file (two levels up, then "model").
MODEL_PATH = str(Path(__file__).resolve().parent.parent.joinpath("model"))
# Identifier handed to AutoTokenizer.from_pretrained; presumably a Hugging Face
# Hub repo id — confirm it matches the tokenizer the local model was trained with.
TOKENIZER_NAME = "oracat/bert-paper-classifier-arxiv"
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned classifier, cached across reruns.

    Streamlit re-executes this script on every widget interaction, and this
    function is called from the "Classify" handler. Without caching, the
    tokenizer and the model weights would be reloaded on every click.
    ``st.cache_resource`` keeps the returned objects alive and reuses them on
    later calls.

    Returns:
        tuple: ``(tokenizer, model)``. The tokenizer is loaded from
        ``TOKENIZER_NAME``. The sequence-classification model is loaded from
        ``MODEL_PATH`` and has already been switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    # Inference only: eval() turns off training-time behavior such as dropout.
    model.eval()
    return tokenizer, model
# Configure the page before any other element is rendered.
st.set_page_config(
    layout="centered",
    page_title="Arxiv Paper Classifier",
)
st.title("Arxiv Paper Classifier")
st.markdown("Classify an academic paper into an arxiv category by its title and/or abstract.")

# Both inputs are individually optional; the "Classify" handler further down
# requires at least one of them to be filled in.
title = st.text_input(
    label="Paper title (optional)",
    placeholder="e.g. Attention Is All You Need",
)
abstract = st.text_area(
    label="Abstract (optional)",
    height=200,
    placeholder="Paste the paper abstract here...",
)
# Maximum number of tokens fed to the model for the title/abstract input.
_MAX_TOKENS = 256
# Show the most likely categories until their combined probability reaches
# this share of the total probability mass.
_TOP_MASS = 0.95

if st.button("Classify", type="primary"):
    # Whitespace-only fields count as empty; otherwise "   " would pass the
    # check below and be classified as a blank input.
    title_text = title.strip()
    abstract_text = abstract.strip()
    if not title_text and not abstract_text:
        st.error("Please enter at least a title or an abstract.")
    else:
        tokenizer, model = load_model()
        # Both fields present: encode as a (title, abstract) sentence pair.
        # Only one present: encode it as a single sequence (text_pair=None).
        if title_text and abstract_text:
            text, text_pair = title_text, abstract_text
        else:
            text, text_pair = title_text or abstract_text, None
        inputs = tokenizer(
            text,
            text_pair=text_pair,
            truncation=True,
            max_length=_MAX_TOKENS,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        # Exactly one example is encoded, so take row 0 explicitly. squeeze()
        # would also collapse the class axis if the head had a single output.
        probs = F.softmax(logits[0], dim=-1)

        # Walk the classes from most to least likely, keeping each one until
        # the kept classes cover at least _TOP_MASS of the probability mass.
        id2label = model.config.id2label
        results = []
        cumulative = 0.0
        for idx in torch.argsort(probs, descending=True).tolist():
            prob = probs[idx].item()
            results.append((id2label[idx], prob))
            cumulative += prob
            if cumulative >= _TOP_MASS:
                break

        # Header text is derived from the threshold so the two cannot drift.
        st.subheader(f"Predicted categories (top-{_TOP_MASS:.0%})")
        for label, prob in results:
            label_col, bar_col = st.columns([3, 7])
            with label_col:
                st.markdown(f"**{label}**")
            with bar_col:
                # Clamp so float rounding can never push the value outside the
                # [0, 1] range that st.progress accepts for floats.
                st.progress(min(max(prob, 0.0), 1.0), text=f"{prob:.1%}")