File size: 801 Bytes
af64c3f 84eeea5 af64c3f 84eeea5 af64c3f 84eeea5 af64c3f 84eeea5 af64c3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr
#generator = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
# Single source of truth for the checkpoint name, so the tokenizer and the
# classifier below are always loaded from the same pretrained model.
_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"

# Both objects are created once, when this module is executed, and are then
# reused by every call to classify_text.
tokenizer = DistilBertTokenizer.from_pretrained(_CHECKPOINT)
model = DistilBertForSequenceClassification.from_pretrained(_CHECKPOINT)
def classify_text(prompt):
    """Return the predicted class label for ``prompt``.

    The text is tokenized with the module-level ``tokenizer``, scored by the
    module-level ``model`` without gradient tracking, and the highest-scoring
    class id is mapped to its name through ``model.config.id2label``.

    Args:
        prompt: The text to classify (a single string, as supplied by the
            Gradio ``"text"`` input component).

    Returns:
        The label string that ``model.config.id2label`` associates with the
        highest-scoring class.
    """
    # truncation=True clips over-long prompts to the tokenizer's maximum
    # length; without it a prompt longer than the model's position-embedding
    # table would raise inside the forward pass instead of being classified.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the argmax along the class dimension rather than over the flattened
    # tensor. Assumes a single prompt, i.e. logits of shape (1, num_labels),
    # so .item() yields one class id.
    predicted_class_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Wire classify_text into a minimal Gradio UI: one free-text input box and
# one text output showing the returned label.
iface = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="text",
)
# Start serving the interface.
iface.launch()
|