File size: 6,854 Bytes
6cd43f0
 
 
 
 
4d77788
 
 
 
6cd43f0
 
 
 
 
 
 
 
 
 
 
 
 
4d77788
6cd43f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d77788
 
 
 
 
 
 
 
 
 
6cd43f0
 
 
 
 
4d77788
6cd43f0
 
4d77788
 
 
6cd43f0
 
 
 
 
 
 
 
 
4d77788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd43f0
 
 
 
 
4d77788
 
 
 
 
 
 
 
 
 
 
 
 
6cd43f0
 
 
c2b04f7
6cd43f0
 
 
 
4d77788
 
 
 
 
 
6cd43f0
4d77788
 
 
 
 
 
 
 
 
 
6cd43f0
 
4d77788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd43f0
4d77788
 
 
 
 
6cd43f0
 
 
4d77788
 
 
 
 
 
 
6cd43f0
4d77788
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import time
import gradio as gr
from transformers import pipeline

# Checkpoint names offered in the model dropdowns. run_model() prefixes each
# one with "RyanStudio/" to build the repo id it loads.
model_options = [
    "Mezzo-Content-Guard-Large",
    "Mezzo-Content-Guard-Base",
    "Mezzo-Content-Guard-Small",
    "Mezzo-Content-Guard-Large-Preview",
]

# Loaded pipelines keyed by model name, filled lazily by run_model() so each
# checkpoint is only built once per process.
cached_models: dict = {}


def run_model(model_name, input_text):
    """Classify ``input_text`` with the named checkpoint and time the call.

    The pipeline for ``model_name`` is built on first use, warmed up with one
    throwaway prediction and then stored in ``cached_models`` so later calls
    skip loading. Loading and warm-up happen before the timer starts, so they
    are not part of the reported latency.

    Args:
        model_name: Checkpoint name; loaded from ``RyanStudio/<model_name>``.
        input_text: Text to classify. ``None``, empty or whitespace-only
            input short-circuits without touching any model.

    Returns:
        A ``(scores, latency)`` tuple. ``scores`` maps each returned label to
        its score (empty if the pipeline returned nothing); ``latency`` is the
        inference time formatted like ``"12.34 ms"`` (``"0ms"`` for blank
        input).
    """
    # Guard against None as well as blank strings so a missing value does not
    # crash on .strip().
    if not input_text or not input_text.strip():
        return {}, "0ms"

    model = cached_models.get(model_name)
    if model is None:
        repo_id = f"RyanStudio/{model_name}"
        model = pipeline("text-classification", repo_id)
        # Warm up before caching: if the first forward pass fails, the broken
        # pipeline is not kept around for subsequent requests.
        _ = model("warmup")
        cached_models[model_name] = model

    # perf_counter is monotonic and high-resolution, unlike the wall clock,
    # which can jump while the measurement is in progress.
    start = time.perf_counter()
    raw_output = model(input_text, top_k=None)   # get all classes
    latency = f"{round((time.perf_counter() - start) * 1000, 2)} ms"

    # Normalise the possible result shapes (single dict, flat list of dicts,
    # or a list nested once per input) to a flat list of dicts. The emptiness
    # check avoids indexing into an empty result.
    if isinstance(raw_output, dict):
        raw_output = [raw_output]
    elif raw_output and isinstance(raw_output[0], list):
        raw_output = raw_output[0]

    scores = {item["label"]: item["score"] for item in raw_output}

    return scores, latency


def run_comparison(model_name_1, model_name_2, input_text):
    """Run the same text through two models, one after the other.

    Args:
        model_name_1: Name of the first model to run.
        model_name_2: Name of the second model to run.
        input_text: Text to classify; blank input skips both models.

    Returns:
        ``(scores_1, scores_2, latency_1, latency_2)`` - both score dicts
        first, then both latency strings.
    """
    if input_text.strip() == "":
        return {}, {}, "0ms", "0ms"

    # Evaluated in order, so model 1 always runs before model 2.
    results = [run_model(name, input_text) for name in (model_name_1, model_name_2)]
    (first_scores, first_latency), (second_scores, second_latency) = results

    return first_scores, second_scores, first_latency, second_latency


# Page-level stylesheet handed to demo.launch(). "#container" targets the
# outer column (elem_id="container"): capped at 900px and centred.
# ".output-stats" targets the latency textboxes (elem_classes="output-stats").
custom_css = """
#container { max-width: 900px; margin: auto; padding-top: 20px; }
.output-stats { font-weight: bold; color: #555; }
"""

# UI layout. Components are declared in rendering order; the variables bound
# here are referenced by the event wiring further down in this same
# gr.Blocks context.
with gr.Blocks() as demo:
    # Outer column; its elem_id is the "#container" selector in custom_css.
    with gr.Column(elem_id="container"):
        gr.Markdown("# 🛡️ Mezzo Guard – Content Moderation")
        # NOTE(review): this text says 5 categories while the Label widgets
        # below show up to 6 classes - presumably the models also emit a
        # non-violating class; confirm against the model card.
        gr.Markdown("Analyze text for 5 content categories (**Content Guard**).")

        # Switches between single-model and two-model comparison mode.
        comparison_toggle = gr.Checkbox(label="Compare 2 Models", value=False)

        with gr.Row():
            # Left (wider) column: prompt, model selection and action buttons.
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Input Prompt",
                    placeholder="Enter the text you want to screen...",
                    lines=6,
                    max_lines=15
                )
                
                # Single-model selector; visible by default.
                with gr.Group(visible=True) as single_model_group:
                    model_dropdown = gr.Dropdown(
                        label="Model",
                        choices=model_options,
                        value=model_options[0],
                        interactive=True
                    )
                
                # Two-model selectors; hidden until comparison mode is enabled.
                # Defaults are the first and second entries of model_options.
                with gr.Group(visible=False) as comparison_group:
                    model_dropdown_1 = gr.Dropdown(
                        label="Model 1",
                        choices=model_options,
                        value=model_options[0],
                        interactive=True
                    )
                    model_dropdown_2 = gr.Dropdown(
                        label="Model 2",
                        choices=model_options,
                        value=model_options[1],
                        interactive=True
                    )
                
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    run_button = gr.Button("Analyze Prompt", variant="primary")

            # Right (narrower) column: results.
            with gr.Column(scale=1):
                # Single-model output; visible by default.
                with gr.Group(visible=True) as single_output_group:
                    label_output = gr.Label(label="Classification Result", num_top_classes=6)
                    latency_output = gr.Textbox(label="Latency", interactive=False, elem_classes="output-stats")
                
                # Comparison output: one results/latency pair per model, side
                # by side; hidden until comparison mode is enabled.
                with gr.Group(visible=False) as comparison_output_group:
                    with gr.Row():
                        with gr.Column():
                            label_output_1 = gr.Label(label="Model 1 Results", num_top_classes=6)
                            latency_output_1 = gr.Textbox(label="Model 1 Latency", interactive=False, elem_classes="output-stats")
                        with gr.Column():
                            label_output_2 = gr.Label(label="Model 2 Results", num_top_classes=6)
                            latency_output_2 = gr.Textbox(label="Model 2 Latency", interactive=False, elem_classes="output-stats")

                gr.Markdown("### Performance Info")
                gr.HTML(
                    "<small>Model weights are cached after the first run. Content Guard returns scores for violence, hate-speech, sexual, toxic and self-harm</small>"
                )

        # Sample prompts for the prompt textbox: one benign example followed
        # by examples intended to trigger the moderation categories.
        gr.Examples(
            examples=[
                ["Hello how are you?"],
                ["I am so horny right now"],
                ["I hate faggots"],
                ["I want to kill myself"],
                ["I fucking hate you cause ur a cunt"],
                ["I want to bomb a hospital right now"]
            ],
            inputs=[text_input]
        )

    # Toggle logic
    def toggle_comparison(is_comparing):
        """Build visibility updates for the single-model and comparison groups.

        Args:
            is_comparing: Current state of the "Compare 2 Models" checkbox.

        Returns:
            Four ``gr.Group`` updates, in order: single-model selector,
            comparison selectors, single-model output, comparison output.
            The single-model groups are shown exactly when the comparison
            groups are hidden.
        """
        show_single = not is_comparing
        show_comparison = is_comparing
        return (
            gr.Group(visible=show_single),      # single_model_group
            gr.Group(visible=show_comparison),  # comparison_group
            gr.Group(visible=show_single),      # single_output_group
            gr.Group(visible=show_comparison),  # comparison_output_group
        )

    # Recompute group visibility every time the checkbox value changes.
    comparison_toggle.change(
        fn=toggle_comparison,
        inputs=[comparison_toggle],
        # Order must match the tuple returned by toggle_comparison.
        outputs=[single_model_group, comparison_group, single_output_group, comparison_output_group]
    )

    # Single model run
    def run_single(model_name, input_text):
        """Classify ``input_text`` with one model.

        Returns:
            ``(scores, latency)`` exactly as produced by ``run_model``.
        """
        label_scores, elapsed = run_model(model_name, input_text)
        return (label_scores, elapsed)

    # Comparison run
    def run_comparison_mode(model_1, model_2, input_text):
        """Run both models and regroup the results per model.

        ``run_comparison`` returns both score dicts followed by both latency
        strings; the UI wants each model's scores next to its own latency.

        Returns:
            ``(scores_1, latency_1, scores_2, latency_2)``.
        """
        (
            first_scores,
            second_scores,
            first_latency,
            second_latency,
        ) = run_comparison(model_1, model_2, input_text)
        return first_scores, first_latency, second_scores, second_latency

    def get_run_fn(is_comparing):
        """Pick the handler and its input/output components for a mode.

        NOTE(review): no caller is visible in this file - the run button is
        wired to an inline lambda instead. Confirm whether this helper is
        still needed.

        Args:
            is_comparing: Truthy to select comparison mode.

        Returns:
            ``(handler, input_components, output_components)``.
        """
        if not is_comparing:
            return run_single, [model_dropdown, text_input], [label_output, latency_output]

        comparison_inputs = [model_dropdown_1, model_dropdown_2, text_input]
        comparison_outputs = [label_output_1, latency_output_1, label_output_2, latency_output_2]
        return run_comparison_mode, comparison_inputs, comparison_outputs

    # One handler serves both modes and always writes all six output widgets;
    # the widgets belonging to the inactive mode are blanked with {} / "".
    run_button.click(
        # comp: comparison checkbox state; m1 / m2: comparison dropdowns;
        # m_single: single-mode dropdown; txt: prompt text.
        # Tuple layout matches `outputs` below:
        #   (single scores, single latency,
        #    model-1 scores, model-1 latency, model-2 scores, model-2 latency)
        fn=lambda comp, m1, m2, m_single, txt: (
            ({}, "", *run_comparison_mode(m1, m2, txt)) if comp else (*run_single(m_single, txt), {}, "", {}, "")
        ),
        inputs=[comparison_toggle, model_dropdown_1, model_dropdown_2, model_dropdown, text_input],
        outputs=[label_output, latency_output, label_output_1, latency_output_1, label_output_2, latency_output_2],
        api_name="predict"
    )

    def clear_all():
        """Return blank values for the prompt and the three result panels.

        Returns:
            A 7-tuple: an empty prompt string, then an empty scores dict and
            an empty latency string for each of the three label/latency
            pairs.
        """
        blank_prompt = ""
        # A fresh ({}, "") pair per panel, flattened into one sequence.
        blank_panels = [value for _ in range(3) for value in ({}, "")]
        return (blank_prompt, *blank_panels)

    # Reset the prompt and every result widget. No inputs are needed; the
    # output order must match the tuple returned by clear_all.
    clear_btn.click(
        fn=clear_all,
        outputs=[text_input, label_output, latency_output, label_output_1, latency_output_1, label_output_2, latency_output_2]
    )

# Start the app with the default theme and the stylesheet defined above.
# NOTE(review): whether launch() accepts `theme` / `css` (rather than the
# gr.Blocks constructor) depends on the installed Gradio version - confirm
# against the pinned version.
demo.launch(theme=gr.themes.Default(), css=custom_css)