# PythonProject1 / .venv / PythonProjectFile1.py
# Committed by DrDavis
# Last commit: "Update .venv/PythonProjectFile1.py" (afb9ff0, verified)
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr
# Hugging Face Hub checkpoint used for classification (per its name, a
# DistilBERT model fine-tuned on the English SST-2 sentiment dataset).
_SST2_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"

# Built once at import time; classify_text below reuses this single pipeline
# object for every request instead of reloading the model per call.
myPipe = pipeline(task="text-classification", model=_SST2_CHECKPOINT)
def classify_text(prompt):
    """Classify *prompt* with the module-level ``myPipe`` pipeline.

    This is the callback Gradio invokes with the contents of the input box.

    Args:
        prompt: The text to classify.

    Returns:
        The first entry of the pipeline's result list. For a single string
        passed to a ``text-classification`` pipeline, the transformers docs
        describe this as a dict like ``{"label": "POSITIVE", "score": 0.99}``.
    """
    # A single string input yields a one-element list; unwrap it so the UI
    # shows the prediction itself rather than a list around it.
    return myPipe(prompt)[0]
# Web UI: a free-text box feeds classify_text, and its raw result is shown
# as text in a second box.
_text_in = gr.Textbox(label="Your Text:")
_score_out = gr.Textbox(label="Valence Score:")
iface = gr.Interface(fn=classify_text, inputs=_text_in, outputs=_score_out)

# Start the Gradio server (runs as soon as this module executes).
iface.launch()