|
|
import inspect

import gradio as gr
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repositories whose predictions are averaged by the ensemble.
MODEL_IDS = [
    "shri171981/genai_model_dberta_base",
    "shri171981/genai_model_dberta_large",
    "shri171981/genai_model_roberta_base"
]

# Emotion names; predict() indexes this list with the numeric suffix of each
# model's "LABEL_<i>" output.
LABELS = ["Anger", "Fear", "Joy", "Sadness", "Surprise"]


def _load_pipelines(model_ids):
    """Build one text-classification pipeline per model id, best-effort.

    Each pipeline is created with ``top_k=None`` so it reports a score for
    every label. A model that fails to load is reported on stdout and skipped,
    so the app can still start with whichever models did load.

    Args:
        model_ids: Iterable of model identifiers passed to ``pipeline``.

    Returns:
        List of the pipelines that loaded successfully, in input order.
    """
    loaded = []
    for repo_id in model_ids:
        try:
            print(f"Loading {repo_id}...")
            classifier = pipeline("text-classification", model=repo_id, top_k=None)
        except Exception as err:
            # Deliberately broad: any load failure just drops this model.
            print(f"Failed to load {repo_id}: {err}")
        else:
            loaded.append(classifier)
    return loaded


# Loaded once at import time; shared by every prediction request.
pipelines = _load_pipelines(MODEL_IDS)
|
|
|
|
|
def _as_score_list(output):
    """Return the flat list of ``{'label', 'score'}`` dicts for one input text.

    A single-string pipeline call may return either a nested ``[[{...}, ...]]``
    or a flat ``[{...}, ...]``, depending on the transformers version and on
    whether ``top_k`` was set at construction or call time. Both are accepted.
    """
    if not output or isinstance(output[0], dict):
        return output
    return output[0]


def _label_to_name(raw_label, labels):
    """Map a raw model label onto an entry of *labels*.

    Generic ``LABEL_<i>`` ids are resolved by index. Any other label is matched
    case-insensitively against *labels* (e.g. ``"joy"`` -> ``"Joy"``).

    Raises:
        ValueError: if the label is an out-of-range index or matches no name.
    """
    suffix = raw_label.rsplit("_", 1)[-1]
    if suffix.isdecimal():
        index = int(suffix)
        if index < len(labels):
            return labels[index]
    else:
        wanted = raw_label.casefold()
        for name in labels:
            if name.casefold() == wanted:
                return name
    raise ValueError(f"Cannot map model label {raw_label!r} onto {list(labels)}")


def predict(text, *, models=None, labels=None):
    """Average each emotion's score across all ensemble models (soft voting).

    Args:
        text: Input text to classify.
        models: Callables behaving like the text-classification pipelines.
            Defaults to the module-level ``pipelines``.
        labels: Emotion names addressed by the models' ``LABEL_<i>`` ids.
            Defaults to the module-level ``LABELS``.

    Returns:
        Dict mapping every label name to its mean score over the models, in the
        ``{label: confidence}`` form accepted by ``gr.Label``.

    Raises:
        RuntimeError: if there are no models to ensemble (e.g. all failed to
            load), instead of dividing by zero.
        ValueError: if a model emits a label that cannot be mapped to *labels*.
    """
    if models is None:
        models = pipelines
    if labels is None:
        labels = LABELS
    if not models:
        raise RuntimeError("No emotion models are loaded; cannot run a prediction.")

    # Start every label at 0.0 so labels a model omits still appear in the output.
    totals = dict.fromkeys(labels, 0.0)
    for pipe in models:
        for result in _as_score_list(pipe(text)):
            totals[_label_to_name(result["label"], labels)] += result["score"]

    return {name: total / len(models) for name, total in totals.items()}
|
|
|
|
|
|
|
|
theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="slate",
)

# Apply the theme above. Gradio 6 moved ``theme`` from the ``gr.Blocks``
# constructor to ``launch()``; older releases only accept it on ``gr.Blocks``.
# Detect which one this Gradio install supports instead of silently dropping it.
_THEME_ON_LAUNCH = "theme" in inspect.signature(gr.Blocks.launch).parameters
_blocks_kwargs = {} if _THEME_ON_LAUNCH else {"theme": theme}
_launch_kwargs = {"theme": theme} if _THEME_ON_LAUNCH else {}

with gr.Blocks(**_blocks_kwargs) as demo:
    gr.Markdown(
        """

# DL GenAI Emotion Classifier


"""
    )

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type something emotional here...",
                lines=3
            )
            submit_btn = gr.Button("Analyze with Ensemble", variant="primary")

            # Examples only populate the textbox; the user still clicks the button.
            gr.Examples(
                examples=[
                    ["I can't believe you would betray me like this!"],
                    ["I heard a strange noise outside and I'm scared to look."],
                    ["I finally got the promotion! This is the best day ever!"],
                    ["I feel so lonely and empty inside."],
                    ["Wow! I never expected a surprise party!"]
                ],
                inputs=input_text
            )

        with gr.Column():
            # Show every emotion's averaged score, tied to the label list so it
            # stays correct if labels are added or removed.
            output_chart = gr.Label(label="Sentiment Analysis", num_top_classes=len(LABELS))

    submit_btn.click(fn=predict, inputs=input_text, outputs=output_chart)


if __name__ == "__main__":
    demo.launch(**_launch_kwargs)