| import os |
| import gradio as gr |
| from transformers import pipeline |
| import plotly.express as px |
| import pandas as pd |
|
|
# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub model IDs offered in the "Select Model" dropdown (first entry is the default).
MODEL_LIST: list[str] = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# Lazily-populated cache: model ID -> text-classification pipeline built by get_pipeline.
pipeline_cache: dict[str, object] = {}
|
|
def get_pipeline(model_name):
    """Return the text-classification pipeline for ``model_name``, building it once.

    The first call for a given model constructs the pipeline (authenticated
    with ``HF_TOKEN``) and stores it in ``pipeline_cache``; later calls return
    the cached object.
    """
    if model_name in pipeline_cache:
        return pipeline_cache[model_name]

    # NOTE(review): `return_all_scores` is deprecated in recent transformers
    # releases in favour of `top_k=None`; the two options can differ in how
    # single-string results are nested, so check callers before migrating.
    clf = pipeline(
        "text-classification",
        model=model_name,
        token=HF_TOKEN,
        return_all_scores=True,
    )
    pipeline_cache[model_name] = clf
    return clf
|
|
|
|
# Fixed colour per highlight bucket. Without an explicit map, Plotly Express
# assigns colours in order of appearance, so "Low" would change colour
# depending on whether any label happened to reach the threshold.
_HIGHLIGHT_COLORS = {"High": "#EF553B", "Low": "#636EFA"}


def _flatten_scores(results):
    """Return the list of ``{"label": ..., "score": ...}`` dicts for one text.

    A single-string call to a text-classification pipeline may return the
    per-label list either directly or wrapped in a one-element outer list,
    depending on how ``return_all_scores`` / ``top_k`` are configured. Both
    shapes are accepted so the caller does not depend on one of them.
    """
    if results and isinstance(results[0], list):
        return results[0]
    return results


def _build_figure(df, chart_type, title):
    """Return a Plotly figure of ``df["score"]`` per ``df["label"]``.

    ``chart_type == "Radar Chart"`` yields a polar chart; any other value
    yields a bar chart. Bars / markers are coloured by ``df["highlight"]``.
    """
    if chart_type == "Radar Chart":
        # Draw ONE closed polygon through every label. Passing color= to
        # line_polar would split it into one polygon per highlight bucket:
        # a bucket holding a single label degenerates to a zero-length,
        # invisible line, and the other polygon cuts straight past it.
        fig = px.line_polar(
            df,
            r="score",
            theta="label",
            line_close=True,
            title=title,
        )
        # Fill only the outline trace (markers are added afterwards).
        fig.update_traces(fill="toself")
        # Overlay markers coloured by bucket to show which labels pass.
        markers = px.scatter_polar(
            df,
            r="score",
            theta="label",
            color="highlight",
            color_discrete_map=_HIGHLIGHT_COLORS,
        )
        fig.add_traces(markers.data)
        fig.update_layout(
            polar=dict(radialaxis=dict(range=[0, 1])),
            legend_title_text="highlight",
        )
    else:
        fig = px.bar(
            df,
            x="label",
            y="score",
            color="highlight",
            color_discrete_map=_HIGHLIGHT_COLORS,
            title=title,
        )
        fig.update_layout(yaxis=dict(range=[0, 1]))
    return fig


def classify(text, model_name, chart_type, threshold):
    """Score ``text`` with ``model_name`` and chart the per-label probabilities.

    Args:
        text: Input text. ``None``, empty, or whitespace-only input means
            there is nothing to classify.
        model_name: Hub model ID, passed to ``get_pipeline``.
        chart_type: ``"Radar Chart"`` for a polar chart; any other value
            produces a bar chart.
        threshold: Scores ``>= threshold`` are tagged ``"High"``, the rest
            ``"Low"``.

    Returns:
        ``(df, fig)``: ``df`` has ``label``, ``score`` and ``highlight``
        columns sorted by descending score, and ``fig`` is the Plotly chart.
        ``(None, None)`` when there is no text to classify.
    """
    # Reject None as well as blank strings (None.strip() would raise).
    if not text or not text.strip():
        return None, None

    clf = get_pipeline(model_name)
    df = pd.DataFrame(_flatten_scores(clf(text)))

    # Highest-probability label first, with a clean 0..n-1 index for display.
    df = df.sort_values(by="score", ascending=False).reset_index(drop=True)

    df["highlight"] = df["score"].apply(lambda x: "High" if x >= threshold else "Low")

    fig = _build_figure(df, chart_type, f"Label Probabilities - {model_name}")
    return df, fig
|
|
|
|
# Sample inputs shown under the text box; clicking one fills `text_input`.
_EXAMPLE_TEXTS = [
    "I only saw her once and I'm head over heels!",
    "Pay you for what, just standing there?!",
    "Dumbass Broncos fans circa December 2015.",
    "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
    "I'm scared to even ask my mom ,I might get yelled at π",
    "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
    "They're really pushing me into this... once I go, there's no coming back you know?",
    "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
]

with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # NOTE(review): the emoji in the heading below (and in one example text)
    # look mis-encoded (mojibake). They are kept verbatim; confirm and
    # restore the intended characters.
    gr.Markdown("# Ekman Emotions Playground π€πΎπ¦")
    gr.Markdown("### Why let humans play with your emotions when a robot can do it for you for free? ...Cheaper than therapy, 100% less effective!")

    # Controls: model choice, chart style, and the High/Low cut-off.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=MODEL_LIST,
            value=MODEL_LIST[0],
        )
        chart_dropdown = gr.Dropdown(
            label="Chart Type",
            choices=["Radar Chart", "Bar Chart"],
            value="Bar Chart",
        )
        threshold_slider = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            label="Highlight Threshold",
        )

    text_input = gr.Textbox(label="Input Text")

    # Results: score table next to the chart.
    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")
    classify_btn.click(
        fn=classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot],
    )

    gr.Examples(examples=_EXAMPLE_TEXTS, inputs=[text_input])


if __name__ == "__main__":
    demo.launch()