# Spaces: Runtime error
| import gradio as gr | |
| from transformers import pipeline | |
# Dropdown caption -> model identifier handed to `pipeline(model=...)`.
# Insertion order matters: the UI lists the captions in this order and
# preselects the second one.
models = {
    "MoritzLaurer/deberta-v3-large-zeroshot-v2.0 (best, English)": (
        "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
    ),
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (multilingual incl. Dutch)": (
        "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
    ),
    "facebook/bart-large-mnli (classic)": "facebook/bart-large-mnli",
}
# Cache of already-built pipelines, keyed by the dropdown caption.
pipes = {}


def get_pipe(model_name):
    """Return the zero-shot pipeline for ``model_name``, building it on first use.

    Args:
        model_name: A key of ``models`` (the caption shown in the dropdown).

    Returns:
        The pipeline object stored in ``pipes`` for that caption.

    Raises:
        KeyError: If ``model_name`` is not cached and is not a key of ``models``.
    """
    # Fast path: this caption has been requested before.
    if model_name in pipes:
        return pipes[model_name]

    # First request: build the pipeline once and remember it for later calls.
    classifier = pipeline(
        "zero-shot-classification",
        model=models[model_name],
    )
    pipes[model_name] = classifier
    return classifier
# Preset name -> comma-separated label list that is copied into the labels box.
# The "Custom" entry is deliberately empty so the user types their own labels.
PRESETS = {
    "Custom (type your own)": "",
    "News categories": (
        "politics, economy, sports, culture, technology, health, crime, environment"
    ),
    "Sentiment": "positive, negative, neutral",
    "Urgency": "urgent, important, routine, not relevant",
    "Story type": "breaking news, investigation, feature, opinion, analysis",
    "Tips inbox triage": (
        "actionable tip, complaint, spam, press release, personal story"
    ),
}
def classify(text, model_choice, labels_text, preset, multi_label):
    """Score ``text`` against a set of candidate labels and render a text report.

    Args:
        text: The passage to classify. ``None`` or whitespace-only input is
            rejected with a message.
        model_choice: Caption of the model to use; forwarded to ``get_pipe``.
        labels_text: Comma-separated candidate labels. When blank and a
            non-custom preset is selected, the preset's labels are used.
        preset: Name of the selected preset (a key of ``PRESETS``).
        multi_label: Forwarded to the pipeline as ``multi_label``.

    Returns:
        A string. Either a short validation message, or one line per label in
        the order the pipeline returns them: the label padded with dots to 30
        characters, the score as a percentage with one decimal, and a bar of
        up to 30 block characters proportional to the score.
    """
    # Tolerate None for either text field so the .strip() calls below cannot
    # raise AttributeError; None is handled exactly like an empty string.
    text = text or ""
    labels_text = labels_text or ""

    if not text.strip():
        return "Enter some text to classify."

    # Fall back to the preset only when the user left the labels box blank.
    # .get() mirrors update_labels: an unknown preset means "no labels", not a crash.
    if preset != "Custom (type your own)" and not labels_text.strip():
        labels_text = PRESETS.get(preset, "")
    if not labels_text.strip():
        return "Enter at least two labels separated by commas."

    # Strip each entry, skip blanks, and drop repeats while keeping first-seen
    # order, so "a, a" is not mistaken for two distinct categories.
    labels = list(
        dict.fromkeys(
            part.strip() for part in labels_text.split(",") if part.strip()
        )
    )
    if len(labels) < 2:
        return "Need at least two labels."

    pipe = get_pipe(model_choice)
    result = pipe(text, candidate_labels=labels, multi_label=multi_label)

    # One newline-terminated row per label; the bar has int(score * 30) blocks.
    return "".join(
        f"{label:.<30s} {score:.1%} {'█' * int(score * 30)}\n"
        for label, score in zip(result["labels"], result["scores"])
    )
def update_labels(preset):
    """Return the text to place in the labels box when ``preset`` is selected.

    The custom preset always yields an empty string; any other name yields its
    entry in ``PRESETS``, or an empty string when the name is not a key.
    """
    is_custom = preset == "Custom (type your own)"
    return "" if is_custom else PRESETS.get(preset, "")
# Build the page: inputs and the Classify button in the first column, the
# results box in the second. Event wiring must happen inside the Blocks context.
with gr.Blocks(title="Zero-Shot Classification — KRO-NCRV Workshop") as demo:
    gr.Markdown("# Zero-Shot Classification")
    gr.Markdown(
        "Classify text into **any categories you define** — no training needed. "
        "Works in Dutch and English."
    )

    # Dropdown captions in declaration order, computed once.
    _model_captions = list(models)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to classify",
                lines=6,
                placeholder="Paste an article, tip, tweet, or paragraph...",
            )
            model_choice = gr.Dropdown(
                choices=_model_captions,
                # Preselect the second entry of `models`.
                value=_model_captions[1],
                label="Model",
            )
            preset = gr.Dropdown(
                choices=list(PRESETS),
                value="Custom (type your own)",
                label="Label preset",
            )
            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                placeholder="politics, economy, sports, culture",
            )
            multi_label = gr.Checkbox(
                label="Multi-label (text can belong to multiple categories)",
                value=False,
            )
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            # NOTE(review): newer Gradio releases may have replaced
            # `show_copy_button` with a `buttons=[...]` argument; confirm
            # against the Gradio version this app is pinned to.
            output = gr.Textbox(label="Results", lines=15, show_copy_button=True)

    # Choosing a preset overwrites the labels box with that preset's labels.
    preset.change(fn=update_labels, inputs=[preset], outputs=[labels_input])
    # The button runs the classifier and writes the report into the results box.
    btn.click(
        fn=classify,
        inputs=[text_input, model_choice, labels_input, preset, multi_label],
        outputs=[output],
    )

demo.launch()