File size: 1,111 Bytes
40e0339 ca5d70b 40e0339 ca5d70b 40e0339 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import gradio as gr
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
# Fine-tuned DistilBERT sentiment checkpoint on the Hugging Face Hub. The
# tokenizer and the sequence-classification model are both loaded from this
# repo id once, at import time, and shared by every prediction request.
MODEL_NAME = "Kaiyeee/fine_tuned_distilbert_imdb"
tokenizer, model = (
    DistilBertTokenizer.from_pretrained(MODEL_NAME),
    DistilBertForSequenceClassification.from_pretrained(MODEL_NAME),
)
def predict_sentiment(text):
    """Classify the sentiment of ``text`` with the fine-tuned DistilBERT model.

    Args:
        text: Free-form input, e.g. the contents of the Gradio textbox.

    Returns:
        "positive" if the model's argmax class id is 1, otherwise "negative".
        For ``None``, empty, or whitespace-only input, a short prompt asking
        for text is returned instead of a prediction, because classifying
        content-free input would produce a meaningless label.
    """
    if text is None or not text.strip():
        return "Please enter some text to analyze."

    # Keep the 128-token truncation limit, but do not pad: a single sequence
    # needs no padding, and padding every request to max_length only adds
    # work on padded positions.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=-1).item()
    # NOTE(review): class id 1 -> "positive" is hard-coded here; confirm it
    # matches the label order used when the model was fine-tuned.
    return "positive" if predicted_class_id == 1 else "negative"
# Web UI: one multi-line textbox in, the predicted sentiment label out as text.
demo = gr.Interface(
    title="Sentiment Analysis with DistilBERT",
    description="Enter text to predict sentiment (positive or negative).",
    fn=predict_sentiment,
    inputs=gr.Textbox(
        placeholder="Enter text for sentiment analysis...",
        lines=5,
    ),
    outputs="text",
)

# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()