Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from datasets import load_dataset | |
# Inference device. The model is moved here after loading and the tokenized
# inputs are moved here before the forward pass, so the two always agree.
device = 'cpu'


@st.cache_resource
def get_model_and_tokenizer():
    """Load the RoBERTa sequence classifier and its tokenizer.

    Instantiates ``FacebookAI/roberta-base`` with a 157-way classification
    head, then restores the weights stored under the ``'model'`` key of the
    local checkpoint ``arxiv_roberta_final.pt``.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, the pretrained model
    and the checkpoint would be reloaded on each rerun.

    Returns:
        tuple: ``(model, tokenizer)``, with the model on ``device`` and in
        eval mode.
    """
    model_name = "FacebookAI/roberta-base"
    # Must equal the head size stored in the checkpoint; load_state_dict
    # (strict by default) rejects a mismatch.
    num_labels = 157
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    chkp = torch.load("arxiv_roberta_final.pt", map_location=device)
    model.load_state_dict(chkp['model'])
    # Put the weights on the same device the inputs are sent to, and switch
    # to eval mode (disables dropout) once at load time.
    model.to(device)
    model.eval()
    return model, tokenizer
@st.cache_data
def get_categories():
    """Load the arXiv category taxonomy used to decode classifier outputs.

    Reads the ``arxiv_category_descriptions`` split of the
    ``TimSchopf/arxiv_categories`` dataset. Classifier output index ``i`` is
    decoded as ``id2cat[i]`` / ``names[i]``, so the dataset row order defines
    the label order.

    Wrapped in ``st.cache_data`` so the dataset is not re-fetched on every
    Streamlit rerun.

    Returns:
        tuple: ``(cat2id, id2cat, names)`` -- ``cat2id`` maps a tag to its row
        index, ``id2cat`` is the list of ``tag`` values and ``names`` the list
        of ``name`` values, both in dataset order.
    """
    categories = load_dataset("TimSchopf/arxiv_categories", "arxiv_category_descriptions")
    descriptions = categories['arxiv_category_descriptions']
    # Materialize plain lists: indexable by position and picklable, which
    # st.cache_data requires for the cached return value.
    id2cat = list(descriptions['tag'])
    names = list(descriptions['name'])
    cat2id = {cat: idx for idx, cat in enumerate(id2cat)}
    return cat2id, id2cat, names
# Load the classifier and the label taxonomy (model first, then categories).
# predict_and_decode reads `tokenizer`, `id2cat` and `cat_names` from these
# module-level names.
(model, tokenizer), (cat2id, id2cat, cat_names) = (
    get_model_and_tokenizer(),
    get_categories(),
)
def predict_and_decode(model, title='', abstract=''):
    """Classify a paper and return per-category probabilities, best first.

    Title and abstract are encoded as a text pair with the module-level
    ``tokenizer`` (truncated to 512 tokens). A sigmoid is applied to each logit
    independently, so the probabilities are per-category scores that need not
    sum to 1.

    Args:
        model: sequence-classification model; assumed to return ``'logits'``
            of shape (1, num_categories) in ``id2cat`` order -- TODO confirm.
        title: paper title (may be empty).
        abstract: paper abstract (may be empty).

    Returns:
        pandas.DataFrame with columns ``tag``, ``name``, ``probability``,
        sorted by descending probability and re-indexed from 0.
    """
    model.eval()
    inputs = tokenizer(title, abstract, return_tensors='pt', truncation=True, max_length=512).to(device)
    # Inference only: skip autograd graph construction to save time and memory.
    with torch.no_grad():
        logits = model(**inputs)['logits'][0].cpu()
    # One bulk conversion to Python floats instead of a per-element .item().
    probs = F.sigmoid(logits).tolist()
    df = pd.DataFrame(
        [(id2cat[cat_id], cat_names[cat_id], prob) for cat_id, prob in enumerate(probs)],
        columns=("tag", "name", "probability"),
    )
    return df.sort_values(by="probability", ascending=False).reset_index(drop=True)
# At most this many runner-up categories are listed under the top prediction.
MAX_OTHER_CATEGORIES = 5
# Runner-ups are listed while their probability exceeds this value; at least
# one runner-up is always shown regardless of the threshold.
OTHER_CATEGORY_THRESHOLD = 0.55


def _format_percent(p):
    """Render a probability in [0, 1] as a percentage string, e.g. '12.34%'."""
    return f"{p * 100:.2f}%"


st.header("Paper Category Classifier")
st.text("Input a title and/or an abstract of a scientific paper, and get classification according to arxiv.org categories")

input_container = st.container(border=True)
with input_container:
    title_default = "Attention Is All You Need"
    abstract_default = (
        "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks "
        "in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through "
        "an attention mechanism. We propose a new simple network architecture, the Transformer..."
    )
    # Text-area height in pixels: room for about `n_lines` lines of text.
    line_height = 34
    n_lines = 10
    title = st.text_input("Paper title", value=title_default, help="Type in paper's title")
    abstract = st.text_area("Paper abstract", value=abstract_default, height=line_height*n_lines, help="Type in paper's abstract")

# Whitespace-only fields count as empty, so blank input shows the hint
# instead of running the model on nothing.
if title.strip() or abstract.strip():
    result = predict_and_decode(model, title=title, abstract=abstract)

    main_cnt = st.container(border=True)
    with main_cnt:
        top = result.iloc[0]
        st.markdown("#### Top category")
        # Bracket access: on a row Series `.name` is the index label, not the
        # "name" column.
        st.markdown(f"**{top['tag']}** -- {top['name']}")
        st.markdown(f"Probability: {_format_percent(top['probability'])}")

    rest_cnt = st.container(border=True)
    with rest_cnt:
        st.text("Other top categories:")
        runner_ups = result.iloc[1:]
        n_above = int((runner_ups["probability"] > OTHER_CATEGORY_THRESHOLD).sum())
        n_shown = min(max(1, n_above), MAX_OTHER_CATEGORIES)
        # Copy before reformatting so `result` itself keeps numeric probabilities.
        shown = runner_ups.iloc[:n_shown].copy()
        shown["probability"] = shown["probability"].map(_format_percent)
        st.table(shown)
else:
    st.warning("Type a title and/or an abstract to get started!")