PythonProject1 / .venv /PythonProjectFile1.py
DrDavis's picture
Update .venv/PythonProjectFile1.py
af64c3f verified
raw
history blame
801 Bytes
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr
#generator = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
# Single source of truth for the checkpoint ID, so the tokenizer and the model
# cannot accidentally be loaded from two different checkpoints.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Loaded once at import time and reused by every classify_text call.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
def classify_text(prompt: str) -> str:
    """Return the predicted class label for *prompt*.

    The text is tokenized with the module-level ``tokenizer``, scored by the
    module-level ``model`` with gradient tracking disabled, and the
    highest-scoring class index is mapped to its name through
    ``model.config.id2label``.

    Args:
        prompt: The text to classify.

    Returns:
        The label string that ``model.config.id2label`` associates with the
        highest-scoring class.
    """
    # truncation=True clips over-long prompts to the tokenizer's maximum
    # length; without it a long input produces more tokens than the model's
    # position embeddings cover and the forward pass raises.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single prompt yields one row of logits; take the argmax over the class
    # dimension of that row rather than over the flattened tensor.
    predicted_class_id = logits[0].argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Wrap classify_text in a Gradio UI with one text input and one text output,
# then start serving it.
iface = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="text",
)
iface.launch()