File size: 884 Bytes
af64c3f
 
 
84eeea5
afb9ff0
84eeea5
afb9ff0
 
84eeea5
af64c3f
afb9ff0
 
 
 
 
 
af64c3f
 
afb9ff0
af64c3f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr

# Hugging Face Hub checkpoint used for classification (a DistilBERT model
# fine-tuned on SST-2, per its name). Named once so it is easy to swap.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Build the text-classification pipeline once at import time; classify_text
# calls this shared object, so the model is not reloaded on every request.
myPipe = pipeline("text-classification", model=MODEL_NAME)

def classify_text(prompt):
    """Classify the sentiment of ``prompt`` using the module-level ``myPipe``.

    Args:
        prompt: The text entered by the user in the Gradio input box.

    Returns:
        The first element of the pipeline's result list -- for a single string
        input this is the sole prediction (presumably a dict with ``"label"``
        and ``"score"`` keys, per the transformers pipeline docs).
    """
    # truncation=True lets the tokenizer clip texts longer than the model's
    # maximum sequence length; without it, long inputs error out inside the
    # model instead of producing a classification.
    return myPipe(prompt, truncation=True)[0]

# Wire classify_text into a minimal web UI: a free-text input box and a text
# box that displays whatever classify_text returns.
input_box = gr.Textbox(label="Your Text:")
output_box = gr.Textbox(label="Valence Score:")
iface = gr.Interface(
    fn=classify_text,
    inputs=input_box,
    outputs=output_box,
)

# Launch the interface; this runs at module top level, as in the original.
iface.launch()