# Author: sirunchained
# Last commit: 2f7a316 ("bug fix")
import torch
import transformers
import gradio as gr
# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fast (Rust-backed) tokenizer for the base DistilBERT uncased vocabulary.
# NOTE(review): the classifier below comes from a different repo; this assumes it
# was fine-tuned from distilbert-base-uncased so the vocabularies match — confirm.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "distilbert/distilbert-base-uncased",
    use_fast=True,
)

# Fine-tuned sequence-classification checkpoint pulled from the Hugging Face Hub.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "sirunchained/imdb-text-classifier"
)

# Startup diagnostics: print the first named parameter with its mean value
# (a cheap sanity check that real weights were loaded), then the model config.
first_named_param = next(iter(model.named_parameters()), None)
if first_named_param is not None:
    name, param = first_named_param
    print(name, param.mean().item())
print(model.config)

# Move the weights onto the chosen device and put the module in evaluation
# mode (disables training-only behaviour such as dropout).
model.to(device).eval()
def predict_sentiment(text):
    """Classify the sentiment of a single movie review.

    Tokenizes *text* with the module-level ``tokenizer`` (truncating to at most
    512 tokens). It then runs the module-level ``model`` on ``device`` without
    gradient tracking and turns the logits into class probabilities with a
    softmax.

    Args:
        text: The review text to classify.

    Returns:
        A 5-tuple whose order must match the event's ``outputs`` list:
          0. ``{"Negative": p_neg, "Positive": p_pos}`` probability dict,
          1. a verdict string (``"POSITIVE ..."`` / ``"NEGATIVE ..."``),
          2. the probability of the winning class,
          3. ``gr.update`` making the confidence slider visible with that value,
          4. ``gr.update`` setting the CSS class ``"orange"`` or ``"red"``.
        Ties (``positive == negative``) are reported as NEGATIVE.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # One input text -> row 0 of the batch; a single .tolist() copies the whole
    # probability row to the host at once instead of once per element.
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    # NOTE(review): assumes label index 0 = negative and 1 = positive (the IMDb
    # convention) — verify against model.config.id2label.
    negative = probs[0]
    positive = probs[1]

    if positive > negative:
        # Previously mojibake ("πŸ˜„"): UTF-8 bytes of the emoji mis-decoded.
        sentiment = "POSITIVE 😄"
        confidence = positive
        color = "orange"
    else:
        sentiment = "NEGATIVE 😑"
        confidence = negative
        color = "red"

    return (
        {
            "Negative": negative,
            "Positive": positive,
        },
        sentiment,
        confidence,
        gr.update(visible=True, value=confidence),
        gr.update(elem_classes=color),
    )
# Tints for the label's bars; predict_sentiment switches between these two
# classes via elem_classes.
custom_css = """
.orange .bar {background-color: orange !important;}
.red .bar {background-color: red !important;}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🎬 IMDb Sentiment Analyzer")

    text_input = gr.Textbox(
        placeholder="This movie was absolutely amazing...",
        lines=4,
        label="Enter movie review (press Shift+Enter to predict)",
    )

    with gr.Row():
        label_output = gr.Label(num_top_classes=2)
        sentiment_text = gr.Markdown()

    # NOTE(review): the original indentation was lost; this assumes the slider
    # sits below the row rather than inside it — confirm against the layout.
    # Read-only display; hidden until the first prediction reveals it.
    confidence_bar = gr.Slider(
        label="Model Confidence",
        minimum=0,
        maximum=1,
        step=0.01,
        visible=False,
        interactive=False,
    )

    # One entry per value returned by predict_sentiment, in the same order.
    # label_output and confidence_bar each appear twice: once for their value,
    # once for the gr.update carrying visibility / CSS-class changes.
    submit_outputs = [
        label_output,    # probability dict
        sentiment_text,  # verdict text
        confidence_bar,  # winning-class probability
        confidence_bar,  # visible=True + value update
        label_output,    # elem_classes update ("orange" / "red")
    ]
    text_input.submit(
        fn=predict_sentiment,
        inputs=text_input,
        outputs=submit_outputs,
    )

demo.launch()