File size: 884 Bytes
af64c3f 84eeea5 afb9ff0 84eeea5 afb9ff0 84eeea5 af64c3f afb9ff0 af64c3f afb9ff0 af64c3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr
# Module-level text-classification pipeline backed by the
# "distilbert-base-uncased-finetuned-sst-2-english" checkpoint. It is built once
# at import time, and every call to classify_text reuses this same instance.
myPipe = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
def classify_text(prompt):
    """Classify ``prompt`` with the module-level text-classification pipeline.

    Args:
        prompt: Text entered by the user in the Gradio input textbox.

    Returns:
        The first element of the list returned by ``myPipe`` for this single
        input. Presumably this is a dict with ``"label"`` and ``"score"`` keys,
        as documented for transformers' text-classification pipeline
        (TODO confirm against the installed transformers version).
    """
    # truncation=True is forwarded to the tokenizer. Inputs longer than the
    # model's maximum sequence length are then clipped instead of making the
    # forward pass fail.
    return myPipe(prompt, truncation=True)[0]
# Gradio UI: one free-text input box feeds classify_text, and the function's
# return value is shown in a second textbox. Launching starts the web server.
_text_input = gr.Textbox(label="Your Text:")
_result_output = gr.Textbox(label="Valence Score:")
iface = gr.Interface(fn=classify_text, inputs=_text_input, outputs=_result_output)
iface.launch()
|