File size: 3,579 Bytes
7dc197f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from transformers import pipeline

# Maps the text shown in the "Model" dropdown to the model identifier that
# get_pipe() passes to transformers.pipeline(model=...).
# Insertion order matters: the UI picks its default by position.
models = {
    "MoritzLaurer/deberta-v3-large-zeroshot-v2.0 (best, English)": "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (multilingual incl. Dutch)": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    "facebook/bart-large-mnli (classic)": "facebook/bart-large-mnli",
}

# Cache of already-built pipelines, keyed by the same display names as
# `models`. Starts empty and is filled on demand by get_pipe().
pipes = {}


def get_pipe(model_name):
    """Return the zero-shot pipeline for *model_name*, building it on first use.

    Args:
        model_name: A key of ``models`` (the display name from the dropdown).

    Returns:
        The pipeline object stored in ``pipes`` for that name.

    Raises:
        KeyError: If *model_name* is not a key of ``models``.
    """
    # Fast path: this model was already constructed by an earlier call.
    if model_name in pipes:
        return pipes[model_name]

    classifier = pipeline("zero-shot-classification", model=models[model_name])
    pipes[model_name] = classifier
    return classifier


# Label presets offered in the "Label preset" dropdown: preset name ->
# comma-separated candidate labels (the same format the labels textbox uses).
# The "Custom" entry is intentionally empty so the user supplies their own.
PRESETS = {
    "Custom (type your own)": "",
    "News categories": "politics, economy, sports, culture, technology, health, crime, environment",
    "Sentiment": "positive, negative, neutral",
    "Urgency": "urgent, important, routine, not relevant",
    "Story type": "breaking news, investigation, feature, opinion, analysis",
    "Tips inbox triage": "actionable tip, complaint, spam, press release, personal story",
}


def _format_results(labels, scores):
    """Render label/score pairs as text, one line per label.

    Each line holds the label left-aligned and dot-padded to 30 characters,
    the score as a percentage with one decimal place, and a bar of block
    characters whose length is ``int(score * 30)``. Every line ends with a
    newline.
    """
    return "".join(
        f"{label:.<30s} {score:.1%} {'█' * int(score * 30)}\n"
        for label, score in zip(labels, scores)
    )


def classify(text, model_choice, labels_text, preset, multi_label):
    """Classify *text* against typed-in or preset candidate labels.

    Args:
        text: Text to classify. ``None`` is treated like an empty string.
        model_choice: Key of ``models`` selecting which pipeline to use.
        labels_text: Comma-separated candidate labels. ``None`` is treated
            like an empty string. Surrounding whitespace, empty entries and
            repeated labels are dropped.
        preset: Name of a ``PRESETS`` entry. Its labels are used only when
            *labels_text* is blank and the preset is not the custom one; an
            unknown preset contributes no labels.
        multi_label: Forwarded to the pipeline as ``multi_label``.

    Returns:
        A validation message when the input is unusable, otherwise one
        formatted line per label in the order the pipeline returned them.
    """
    # UI components may hand over None instead of "" (e.g. a cleared field).
    text = text or ""
    labels_text = labels_text or ""

    if not text.strip():
        return "Enter some text to classify."

    # Fall back to the preset's labels only when the user typed none.
    # .get() keeps an unknown/None preset from raising KeyError.
    if not labels_text.strip() and preset != "Custom (type your own)":
        labels_text = PRESETS.get(preset, "")

    if not labels_text.strip():
        return "Enter at least two labels separated by commas."

    stripped = (part.strip() for part in labels_text.split(","))
    # dict.fromkeys removes repeated labels while keeping first-seen order,
    # so "a, a" is not mistaken for two distinct candidates.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if len(labels) < 2:
        return "Need at least two labels."

    pipe = get_pipe(model_choice)
    result = pipe(text, candidate_labels=labels, multi_label=multi_label)

    return _format_results(result["labels"], result["scores"])


def update_labels(preset):
    """Return the label text to show for the selected *preset*.

    The custom preset always yields an empty string; any other name yields
    its ``PRESETS`` entry, or an empty string when the name is not present.
    """
    is_custom = preset == "Custom (type your own)"
    return "" if is_custom else PRESETS.get(preset, "")


# Build the Gradio interface. Components are created in the order they are
# declared inside the nested Row/Column contexts, so statement order here
# defines the layout.
with gr.Blocks(title="Zero-Shot Classification — KRO-NCRV Workshop") as demo:
    gr.Markdown("# Zero-Shot Classification")
    gr.Markdown(
        "Classify text into **any categories you define** — no training needed. "
        "Works in Dutch and English."
    )

    with gr.Row():
        # First column: all user inputs plus the submit button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to classify",
                lines=6,
                placeholder="Paste an article, tip, tweet, or paragraph...",
            )
            model_choice = gr.Dropdown(
                choices=list(models.keys()),
                # Index 1 is the second entry of `models`: the multilingual
                # mDeBERTa model.
                value=list(models.keys())[1],
                label="Model",
            )
            preset = gr.Dropdown(
                choices=list(PRESETS.keys()),
                value="Custom (type your own)",
                label="Label preset",
            )
            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                placeholder="politics, economy, sports, culture",
            )
            multi_label = gr.Checkbox(
                label="Multi-label (text can belong to multiple categories)",
                value=False,
            )
            btn = gr.Button("Classify", variant="primary")

        # Second column: the formatted classification result.
        with gr.Column():
            output = gr.Textbox(label="Results", lines=15, show_copy_button=True)

    # Choosing a preset overwrites the labels textbox with that preset's
    # labels (or clears it for the custom preset) via update_labels().
    preset.change(fn=update_labels, inputs=[preset], outputs=[labels_input])
    # The inputs list order must match classify()'s parameter order.
    btn.click(
        fn=classify,
        inputs=[text_input, model_choice, labels_input, preset, multi_label],
        outputs=[output],
    )

# Runs unconditionally at module level (there is no __main__ guard), so
# importing this file also launches the app.
demo.launch()