RottenClass / app.py
MiVaCod's picture
Update app.py
d45f2b8 verified
raw
history blame
1.28 kB
import gradio as gr
import tensorflow as tf
from transformers import BertModel, BertTokenizer, TFBertForSequenceClassification
# NOTE(review): the checkpoint name says "mbart-neutralization", but this app is a
# Rotten Tomatoes sentiment classifier loaded with BERT classes — confirm that this
# repo actually hosts a BERT sequence-classification checkpoint.
name = "MiVaCod/mbart-neutralization"
tokenizer = BertTokenizer.from_pretrained(name)
# predict() tokenizes with return_tensors='tf' and reads `.logits`, so the model must be
# the TensorFlow variant with a classification head; BertModel is the PyTorch base model
# and returns no logits. If the checkpoint only ships PyTorch weights, pass from_pt=True.
model = TFBertForSequenceClassification.from_pretrained(name)
# Define a function to make predictions
def predict(texts):
# Tokenize and preprocess the new text
new_encodings = tokenizer(texts, truncation=True, padding=True, max_length=70, return_tensors='tf')
new_predictions = model(new_encodings)
# Make predictions
new_predictions = model(new_encodings)
new_labels_pred = tf.argmax(new_predictions.logits, axis=1)
new_labels_pred = new_labels_pred.numpy()[0]
labels_list = ["Negative 😠", "Positive 😍"]
emotion = labels_list[new_labels_pred]
return emotion
# Create a Gradio interface: one text box in, the predicted sentiment label out.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    # gr.Label replaces gr.outputs.Label, whose namespace was removed in Gradio 4.
    # The classifier has only two classes, so at most two can be shown.
    outputs=gr.Label(num_top_classes=2),
    examples=[
        ["the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal."],
    ],
    title="Rotten tomatoes classification",
    description="Predict the class associated with a text.",
)
# Launch the interface.
iface.launch()