diss-space / app.py
heican's picture
Update app.py
afc854e verified
raw
history blame contribute delete
989 Bytes
import gradio as gr
from transformers import pipeline
# Display name shown in the UI -> model id passed to transformers.pipeline().
# The default entry's display name is bound once and reused as its dict key,
# so the two spellings cannot drift apart.
DEFAULT_MODEL = "RoBERTa-base (best, 89.1%)"
MODEL_IDS = {
    DEFAULT_MODEL: "heican/sentiment-roberta-base",
    "BERT-base (87.5%)": "heican/sentiment-bert-base",
}
# Filled lazily by get_model(): display name -> loaded pipeline object.
_cache = {}
def get_model(name):
    """Return the sentiment-analysis pipeline for display name ``name``.

    The pipeline for ``MODEL_IDS[name]`` is built on the first request and
    stored in ``_cache``; subsequent calls with the same name reuse it.
    Raises ``KeyError`` if an uncached ``name`` is not a key of ``MODEL_IDS``.
    """
    clf = _cache.get(name)
    if clf is None:
        clf = pipeline("sentiment-analysis", model=MODEL_IDS[name])
        _cache[name] = clf
    return clf
def predict(text, model_choice):
clf = get_model(model_choice or DEFAULT_MODEL)
result = clf(text)[0]
label = "Positive" if result["label"] == "LABEL_1" else "Negative"
return f"{label} ({result['score']:.2%})"
# Input components, created in the order the Interface receives them:
# free text, then a selector over MODEL_IDS' display names with
# DEFAULT_MODEL preselected.
_tweet_box = gr.Textbox(label="Tweet")
_model_picker = gr.Radio(list(MODEL_IDS), label="Model", value=DEFAULT_MODEL)

# predict(text, model_choice) receives the two inputs positionally and its
# string result is shown in the "Prediction" textbox.
demo = gr.Interface(
    fn=predict,
    inputs=[_tweet_box, _model_picker],
    outputs=gr.Textbox(label="Prediction"),
    title="Sentiment Analysis Demo",
    description="Compare two fine-tuned transformer models.",
)

if __name__ == "__main__":
    demo.launch()