Spaces:
Sleeping
Sleeping
import os

import streamlit as st
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig
# Page-level settings (browser tab title, full-width layout); Streamlit expects
# this to run before any other element is rendered.
st.set_page_config(
    page_title="Paper Classifier",
    layout="wide",
)
class BERTClass(BertPreTrainedModel):
    """BERT encoder with a dropout + linear head for multi-label classification.

    The head is applied to BERT's pooled output and yields one logit per label
    (``config.num_labels``). When ``labels`` are supplied, a
    ``BCEWithLogitsLoss`` (an independent sigmoid per label) is computed.
    """

    def __init__(self, config, p=0.3):
        """Build the encoder and the classification head.

        Args:
            config: ``BertConfig``; ``hidden_size`` and ``num_labels`` size the head.
            p: Dropout probability applied to the pooled output.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(p)
        self.linear = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run the encoder and head.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded to ``BertModel``.
            labels: Optional multi-hot targets, one column per label. Any
                numeric dtype is accepted.

        Returns:
            dict with ``"logits"`` (one logit per label) and ``"loss"``
            (``None`` unless ``labels`` is given).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.linear(pooled_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss requires float targets matching the logits'
            # dtype; multi-hot label tensors are frequently stored as int/long.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.to(logits.dtype))
        return {"loss": loss, "logits": logits}
# Directory holding the fine-tuned ``model.safetensors`` checkpoint.
# Defaults to the current directory; override via PAPER_CLASSIFIER_MODEL_PATH.
MODEL_PATH = os.environ.get("PAPER_CLASSIFIER_MODEL_PATH", ".")

# arXiv categories, in the same order as the classifier's output logits
# (predict() maps output index i to LABELS[i]).
LABELS = [
    'astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc',
    'hep-ex', 'hep-lat', 'hep-ph', 'hep-th', 'math', 'math-ph', 'nlin',
    'nucl-ex', 'nucl-th', 'physics', 'q-bio', 'quant-ph', 'stat',
]

# Maximum number of tokens passed to BERT; longer inputs are truncated.
MAX_LEN = 512
@st.cache_resource(show_spinner="Loading model...")
def _load_model_cached():
    """Build the classifier, load its fine-tuned weights and the tokenizer.

    Memoised with ``st.cache_resource`` so the work happens once per server
    process rather than on every Streamlit rerun. Loader errors propagate to
    the caller; only successful results are stored in the cache.
    """
    config = BertConfig.from_pretrained("bert-base-cased")
    config.num_labels = len(LABELS)
    model = BERTClass(config)
    weights_path = os.path.join(MODEL_PATH, "model.safetensors")
    with safe_open(weights_path, framework="pt") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    model.load_state_dict(state_dict)
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    return model.eval(), tokenizer


def load_model():
    """Return ``(model, tokenizer)`` with the model in eval mode.

    Streamlit re-executes the whole script on every widget interaction, so the
    heavy loading is delegated to a cached helper. On any failure the error is
    shown in the UI and the script run is halted with ``st.stop()``.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()
def predict(title, abstract):
    """Classify a paper and return a probability for every arXiv category.

    Args:
        title: Paper title (may be blank if an abstract is given).
        abstract: Paper abstract (may be blank).

    Returns:
        dict mapping each label in ``LABELS`` to its sigmoid probability.
        Labels are scored independently, so the values need not sum to 1.

    Raises:
        ValueError: if both fields are blank, or the combined text is shorter
            than 10 characters.
    """
    title, abstract = title.strip(), abstract.strip()
    if not title and not abstract:
        raise ValueError("Bro, do you want me to guess?) Give me at least the title!")
    # "Title. Abstract" when a title exists; the bare abstract otherwise, so an
    # abstract-only input does not start with a stray ". ".
    text = f"{title}. {abstract}".strip() if title else abstract
    if len(text) < 10:
        raise ValueError("Too short text to say anything sensible")

    device = next(model.parameters()).device
    # A single sequence needs no padding; padding every input to MAX_LEN made
    # even a short title pay for a full 512-token forward pass.
    inputs = tokenizer(
        text,
        max_length=MAX_LEN,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs)["logits"]
    probs = torch.sigmoid(logits)[0].cpu().tolist()
    return dict(zip(LABELS, probs))
# Model and tokenizer used by predict() for the rest of this script run.
model, tokenizer = load_model()

# Sidebar controls deciding how many categories the results list shows.
with st.sidebar:
    st.header("Display Settings")
    display_mode = st.radio(
        label="Result filtering mode",
        options=["Top-k categories", "Top-% confidence"],
        index=0,
    )
    if display_mode != "Top-k categories":
        # Cumulative mode: list categories until the chosen share is reached.
        selected_percent = st.selectbox(
            label="Confidence threshold",
            options=["50%", "75%", "95%"],
            index=2,
            help="Display categories until reaching this cumulative confidence",
        )
    else:
        # Fixed-count mode: list the k highest-scoring categories.
        top_k = st.slider(
            label="Number of categories to show",
            min_value=1,
            max_value=10,
            value=3,
            help="Select how many top categories to display",
        )
st.title("π Academic Paper Classifier")

# Widgets inside a form do not trigger a rerun until the submit button is pressed.
input_form = st.form("input_form")
title = input_form.text_input("Paper Title", placeholder="Enter paper title...")
abstract = input_form.text_area(
    "Abstract",
    placeholder="Paste paper abstract here...",
    height=200,
)
submitted = input_form.form_submit_button("Classify")
# Cumulative-confidence options offered in the sidebar, as fractions.
_CONFIDENCE_THRESHOLDS = {"50%": 0.5, "75%": 0.75, "95%": 0.95}
# Upper bound on how many categories the cumulative-confidence mode may list.
_MAX_CATEGORIES_SHOWN = 10


def _filter_by_coverage(sorted_preds, threshold, max_items=_MAX_CATEGORIES_SHOWN):
    """Keep top-scoring labels until they cover ``threshold`` of the total score.

    ``sorted_preds`` is a list of ``(label, score)`` pairs in descending score
    order. The scores are independent sigmoid outputs that do not sum to 1, so
    the running sum is normalised by the total before comparing it with
    ``threshold``. At most ``max_items`` labels are returned.
    """
    total = sum(score for _, score in sorted_preds)
    selected = {}
    covered = 0.0
    for label, score in sorted_preds:
        selected[label] = score
        covered += score
        if total > 0 and covered / total >= threshold:
            break
        if len(selected) >= max_items:
            break
    return selected


def _render_predictions(filtered, total_score):
    """Show the best category, one bar per selected category, and the coverage.

    ``total_score`` is the sum of scores over *all* labels; the caption reports
    what fraction of it the displayed categories account for.
    """
    if not filtered:
        st.warning("No categories meet the selected criteria")
        return
    top_class = max(filtered, key=filtered.get)
    st.success(f"Most likely category: **{top_class}**")
    st.subheader("Category Confidence Scores:")
    for label, score in filtered.items():
        # The bar shows the same probability printed in its label; clamped
        # because st.progress only accepts values in [0, 1].
        st.progress(min(max(score, 0.0), 1.0), text=f"{label}: {score:.1%}")
    coverage = sum(filtered.values()) / total_score if total_score > 0 else 0.0
    st.caption(f"Coverage: {coverage:.1%} of total confidence")


if submitted:
    try:
        with st.spinner("Analyzing paper..."):
            full_predictions = predict(title, abstract)
        sorted_preds = sorted(
            full_predictions.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        if display_mode == "Top-k categories":
            filtered = dict(sorted_preds[:top_k])
        else:
            filtered = _filter_by_coverage(
                sorted_preds, _CONFIDENCE_THRESHOLDS[selected_percent]
            )
        _render_predictions(filtered, sum(full_predictions.values()))
    except Exception as e:
        # UI boundary: predict() raises ValueError with user-facing messages.
        st.error(f"Error: {str(e)}")
# Sidebar blurb describing the tool.
with st.sidebar:
    st.header("About")
    st.markdown("""
This tool predicts the arXiv tag of research papers by their title and abstract via fine-tuned BERT.

- Enter title and abstract
- Enjoy the magnificent classification results
""")