Hugging Face Spaces app (exported page — the Space was last seen in a "Runtime error" state).
import json

import streamlit as st
import streamlit.components.v1 as components
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


if __name__ == '__main__':
    st.markdown("### Arxiv paper classifier (No guarantees provided)")

    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    # Load the fine-tuned classification head from the local checkpoint and
    # the matching base tokenizer. eval() disables dropout so inference is
    # deterministic.
    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # id2label maps class index -> human-readable category name; JSON keys
    # are strings, so convert them back to ints for indexing below.
    with open('checkpoint-3000/config.json', 'r') as f:
        id2label = json.load(f)['id2label']
    id2label = {int(key): value for key, value in id2label.items()}

    title = st.text_area(label='', placeholder='Input title...', height=3)
    abstract = st.text_area(label='', placeholder='Input abstract...', height=10)
    text = '\n'.join([title, abstract])

    if btn and not (title.strip() or abstract.strip()):
        # Both fields empty (or whitespace-only): nothing to classify.
        st.error('Title and abstract are empty!')
    elif btn:
        # truncation=True caps the input at the model's max sequence length
        # (512 for DistilBERT); without it, long abstracts overflow the
        # position embeddings and crash the app at inference time.
        tokenized = tokenizer(text, truncation=True)
        with torch.no_grad():
            out = model(torch.tensor(tokenized['input_ids']).unsqueeze(dim=0))

        # Softmax over the single batch row, then sort classes by
        # descending probability. (Equivalent to the original
        # sort-negated-logits-then-softmax, since softmax is
        # permutation-equivariant.)
        probs = F.softmax(out['logits'][0], dim=0)
        probs, ids = torch.sort(probs, descending=True)

        result = [
            f'{id2label[idx.item()]} (prob = {prob.item()})'
            for idx, prob in zip(ids, probs)
        ]
        output = '<br>'.join(result)

        # Render the ranked list in a scrollable box. (The original markup
        # opened an extra <div> it never closed.)
        components.html(
            '<div style="height:120px;width:680px;'
            'border:1px solid #ccc;border-color: red;'
            'font:16px/26px Georgia, Garamond, Serif;'
            'overflow:scroll;'
            'color:black;">'
            f'{output}</div>'
        )