# Uploaded by fdaudens ("Upload folder using huggingface_hub"),
# commit 7dc197f (verified).
import gradio as gr
from transformers import pipeline
# Dropdown display label -> Hugging Face Hub model id that get_pipe loads.
models = dict(
    [
        (
            "MoritzLaurer/deberta-v3-large-zeroshot-v2.0 (best, English)",
            "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
        ),
        (
            "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (multilingual incl. Dutch)",
            "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
        ),
        (
            "facebook/bart-large-mnli (classic)",
            "facebook/bart-large-mnli",
        ),
    ]
)
# Cache of already-built pipelines, keyed by the display label above;
# starts empty and is filled on demand.
pipes = {}
def get_pipe(model_name):
    """Return the zero-shot-classification pipeline for a dropdown label.

    ``model_name`` is a key of ``models``. The matching Hub model is built
    on the first request for that label and stored in ``pipes``, so later
    calls reuse the same pipeline object.

    Raises:
        KeyError: if ``model_name`` is not a key of ``models``.
    """
    cached = pipes.get(model_name)
    if cached is None:
        cached = pipeline("zero-shot-classification", model=models[model_name])
        pipes[model_name] = cached
    return cached
# Preset name -> comma-separated candidate labels. The "Custom" entry maps
# to an empty string, i.e. it supplies no labels of its own.
PRESETS = dict(
    [
        ("Custom (type your own)", ""),
        (
            "News categories",
            "politics, economy, sports, culture, technology, health, crime, environment",
        ),
        ("Sentiment", "positive, negative, neutral"),
        ("Urgency", "urgent, important, routine, not relevant"),
        (
            "Story type",
            "breaking news, investigation, feature, opinion, analysis",
        ),
        (
            "Tips inbox triage",
            "actionable tip, complaint, spam, press release, personal story",
        ),
    ]
)
def classify(text, model_choice, labels_text, preset, multi_label):
    """Classify ``text`` against user-defined labels and format the scores.

    Args:
        text: The text to classify.
        model_choice: Display label of the model (a key of ``models``).
        labels_text: Comma-separated candidate labels typed by the user.
        preset: Selected key of ``PRESETS``. Its labels are used only when
            ``labels_text`` is empty and the preset is not the custom entry.
        multi_label: Forwarded to the pipeline's ``multi_label`` option
            (the UI describes it as letting text belong to several
            categories).

    Returns:
        Either a validation message for the user, or one line per label
        (in the order the pipeline returns them) consisting of the label
        dot-padded to 30 characters, the score as a percentage, and a bar
        of ``int(score * 30)`` full-block characters.
    """
    # Guard against None (e.g. a cleared component) so .strip() can't raise.
    text = text or ""
    labels_text = labels_text or ""
    if not text.strip():
        return "Enter some text to classify."
    # Fall back to the preset's labels only when the user typed none;
    # .get() avoids a KeyError for a missing/unknown preset.
    if preset != "Custom (type your own)" and not labels_text.strip():
        labels_text = PRESETS.get(preset, "")
    if not labels_text.strip():
        return "Enter at least two labels separated by commas."
    # dict.fromkeys drops duplicate labels while keeping first-seen order,
    # so "a, a" is not counted as two distinct labels.
    labels = list(
        dict.fromkeys(
            label.strip() for label in labels_text.split(",") if label.strip()
        )
    )
    if len(labels) < 2:
        return "Need at least two labels."
    pipe = get_pipe(model_choice)
    result = pipe(text, candidate_labels=labels, multi_label=multi_label)
    lines = []
    for label, score in zip(result["labels"], result["scores"]):
        # U+2588 FULL BLOCK; the previous literal was mis-encoded mojibake.
        bar = "█" * int(score * 30)
        lines.append(f"{label:.<30s} {score:.1%} {bar}\n")
    return "".join(lines)
def update_labels(preset):
    """Return the text to place in the labels box when the preset changes.

    The custom entry yields an empty string; any other preset yields its
    comma-separated labels from ``PRESETS`` ("" if the preset is unknown).
    """
    return "" if preset == "Custom (type your own)" else PRESETS.get(preset, "")
# Gradio UI: inputs in the left column, results in the right column.
# The title and intro text previously contained mis-encoded em dashes.
with gr.Blocks(title="Zero-Shot Classification — KRO-NCRV Workshop") as demo:
    gr.Markdown("# Zero-Shot Classification")
    gr.Markdown(
        "Classify text into **any categories you define** — no training needed. "
        "Works in Dutch and English."
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to classify",
                lines=6,
                placeholder="Paste an article, tip, tweet, or paragraph...",
            )
            # Index 1 of ``models`` is the multilingual (incl. Dutch) model.
            model_choice = gr.Dropdown(
                choices=list(models.keys()),
                value=list(models.keys())[1],
                label="Model",
            )
            preset = gr.Dropdown(
                choices=list(PRESETS.keys()),
                value="Custom (type your own)",
                label="Label preset",
            )
            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                placeholder="politics, economy, sports, culture",
            )
            multi_label = gr.Checkbox(
                label="Multi-label (text can belong to multiple categories)",
                value=False,
            )
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Results", lines=15, show_copy_button=True)
    # Choosing a preset overwrites the labels box via update_labels.
    preset.change(fn=update_labels, inputs=[preset], outputs=[labels_input])
    btn.click(
        fn=classify,
        inputs=[text_input, model_choice, labels_input, preset, multi_label],
        outputs=[output],
    )
demo.launch()