Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| from transformers import AutoTokenizer | |
| from transformers import AutoModelForSequenceClassification | |
| import gradio as gr | |
# Directory holding the fine-tuned tokenizer and classifier weights.
CHECKPOINT_DIR = 'checkpoints/checkpoint-5000/'

# Mapping from category tag to human-readable category name.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding.
with open('tag_to_name.json', 'r', encoding='utf-8') as fin:
    tag_to_name = json.load(fin)

# Class index -> readable name, in the order the names appear in the JSON file.
# NOTE(review): this assumes the JSON key order matches the label ids the
# model was trained with -- confirm against the training code.
id_to_name = dict(enumerate(tag_to_name.values()))

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
# Inference only: put the model in evaluation mode.
model.eval()

# Title used when the user submits an empty title; also shown as placeholder.
TITLE_DEFAULT = "Attention Is All You Need"
def predict(title, abstract, top_p, top_k):
    """Classify a paper and return its most probable categories.

    The title (optionally joined to the abstract with a literal "[SEP]")
    is tokenized, truncated to 512 tokens and passed through the classifier.
    Classes are then sorted by probability and kept while both limits hold:

    * nucleus limit: the shortest prefix whose cumulative probability
      reaches ``top_p`` (the class that crosses the threshold is included);
    * count limit: at most ``top_k`` classes.

    The single most probable class is always returned.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty or ``None``.
        top_p: Cumulative-probability threshold.
        top_k: Maximum number of classes to return; coerced to an ``int``
            and clamped to at least 1.

    Returns:
        Dict mapping category name to probability, most probable first.
    """
    text = title + "[SEP]" + abstract if abstract else title
    tokenized = tokenizer(text, truncation=True,
                          return_tensors='pt', max_length=512)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        probs = model(**tokenized).logits[0].softmax(0)

    top_probs, top_inds = probs.sort(descending=True)

    # Keep every class whose running total is still within top_p ...
    mask = top_probs.cumsum(0) <= top_p
    if not mask.all():
        # ... plus the first class that pushes the total past top_p.
        mask[int(mask.sum())] = True

    # UI sliders may deliver floats; slicing requires an int. Clamping to 1
    # keeps the slice well-defined for zero or negative values.
    max_classes = max(int(top_k), 1)
    mask[max_classes:] = False
    # Always report at least the best class, even if top_p is tiny.
    mask[0] = True

    predicted_ids = top_inds[mask].tolist()
    predicted_probs = top_probs[mask].tolist()
    predicted_names = [id_to_name[class_id] for class_id in predicted_ids]
    return dict(zip(predicted_names, predicted_probs))
def inference(title, abstract, top_p, top_k):
    """Gradio callback: run ``predict``, substituting the default title.

    An empty (falsy) ``title`` is replaced by ``TITLE_DEFAULT``; all other
    arguments are forwarded to ``predict`` unchanged, and its result is
    returned as-is.
    """
    effective_title = title or TITLE_DEFAULT
    return predict(effective_title, abstract, top_p, top_k)
# Web UI: two text boxes (title, abstract) and two sliders controlling how
# many categories are shown; the result is rendered as a label widget.
g = gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Title", placeholder=TITLE_DEFAULT
        ),
        gr.components.Textbox(lines=4, label="Abstract", placeholder=""),
        gr.components.Slider(minimum=0, maximum=1, value=0.95, label="Top p"),
        # Upper bound is the number of known categories.
        gr.components.Slider(minimum=1, maximum=len(tag_to_name),
                             step=1, value=10, label="Top n"),
    ],
    # Use the same `gr.components` namespace as the inputs; the legacy
    # `gr.outputs` namespace is deprecated and absent from newer Gradio.
    outputs=gr.components.Label(label="Predicted categories"),
    title="🪄 arXiv classifier 🪄",
)
g.launch()