File size: 801 Bytes
af64c3f
 
 
84eeea5
af64c3f
84eeea5
af64c3f
 
84eeea5
af64c3f
 
 
 
84eeea5
af64c3f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import gradio as gr

# Hugging Face Hub checkpoint used for both the tokenizer and the model; the
# name suggests a DistilBERT fine-tuned on SST-2 sentiment (verify on the Hub).
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Load the tokenizer and classification model once at import time so that
# every call to classify_text reuses the same module-level objects.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)

def classify_text(prompt):
    """Classify *prompt* with the module-level sequence-classification model.

    Args:
        prompt: The input text to classify, as a single string.

    Returns:
        The label name of the highest-scoring class, looked up in
        ``model.config.id2label``.
    """
    # truncation=True caps the encoding at the tokenizer's model_max_length;
    # without it, long inputs overflow the model's position embeddings and
    # the forward pass raises instead of returning a prediction.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # logits has one row per input and one column per label; a single prompt
    # gives a single row, so reducing over the label axis yields one index.
    predicted_class_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]

# Create a Gradio interface: one text box in, the predicted label (text) out.
iface = gr.Interface(fn=classify_text, inputs="text", outputs="text")

# Start the (blocking) web server only when run as a script, so importing this
# module, e.g. to reuse classify_text, does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()