import gradio as gr
import torch
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification

# Checkpoint identifier passed to both from_pretrained calls. Keeping it in one
# place guarantees the tokenizer and the model are loaded from the same
# checkpoint, so the token ids the tokenizer produces match the model's vocab.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Loaded once at import time and reused by every classification request.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
def classify_text(prompt):
    """Classify ``prompt`` with the module-level DistilBERT sequence classifier.

    Args:
        prompt: The input text (supplied by the Gradio text box).

    Returns:
        The label name, looked up in ``model.config.id2label``, of the
        highest-scoring class.
    """
    # truncation=True clips the input to the tokenizer's maximum length;
    # without it, texts longer than the model's position-embedding limit
    # make the forward pass fail instead of returning a prediction.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the argmax over the class dimension (the last axis) rather than
    # over the flattened tensor; with a single prompt this yields one id.
    predicted_class_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Create a Gradio interface: one text input, the predicted label as text output.
iface = gr.Interface(fn=classify_text, inputs="text", outputs="text")

# Start the web app only when executed as a script (e.g. `python app.py`),
# so importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    iface.launch()