Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# coding: utf-8
"""Gradio demo that predicts a movie genre label from a plot description.

Loads the sequence-classification checkpoint named in ``model_ckpt`` (plus its
tokenizer and config) with Hugging Face ``transformers`` at import time.
"""
import gradio as gr
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Hub repository id of the fine-tuned classifier used by this app.
model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"

# Tokenizer, config and weights all come from the same checkpoint so the
# vocabulary and label mapping match the trained model.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)

# Class-index -> label-name mapping stored in the checkpoint's config.
id2label = model.config.id2label
def predict(plot):
    """Return the highest-scoring genre label for one piece of plot text.

    Args:
        plot: Plot text to classify (the Gradio ``"text"`` input passes a str).

    Returns:
        The label name from ``id2label`` for the class with the largest logit.
    """
    encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
    # Put every input tensor on the same device as the model weights.
    encoding = {name: tensor.to(model.device) for name, tensor in encoding.items()}

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**encoding).logits

    # squeeze() drops the size-1 batch dimension of a single input. Softmax is
    # monotonic, so taking argmax directly over the logits selects the same
    # class without the extra work (and without float ties after softmax).
    predicted_id = int(logits.squeeze().argmax())
    return id2label[predicted_id]
# Wrap predict() in a simple web UI: one text box in, one text box out.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Movie Plot Genre Predictor",
)
# share=True asks Gradio for a temporary public share link when run locally;
# NOTE: per the Gradio docs, share links are not supported on Hugging Face Spaces.
iface.launch(share=True)