# Hugging Face Spaces page-status banner (extraction residue, not app code):
# "Spaces: Sleeping"
| #!/usr/bin/env python3 | |
| """ | |
| WimBERT Synth v0 Gradio Space | |
| Dual-head multi-label classifier for Dutch signal messages | |
| """ | |
| import json | |
| import importlib.util | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
# Constants
MODEL_REPO = "UWV/wimbert-synth-v0"  # HF Hub repo holding weights, tokenizer, config and model.py
# Prefer GPU when available; half precision is only used on CUDA
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
MAX_LENGTH = 512  # Default to 512 for better CPU performance

print(f"🔧 Loading model from {MODEL_REPO}...")
print(f"🖥️ Device: {DEVICE} ({DTYPE})")

# Download model files (uses HF cache)
model_dir = snapshot_download(MODEL_REPO, cache_dir=None)

# Dynamic import of model.py from downloaded dir — the custom DualHeadModel
# class ships with the checkpoint rather than with this app
spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
DualHeadModel = model_module.DualHeadModel

# Load model + tokenizer + config (config presumably a dict; see label
# extraction below)
model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Cast to target dtype (fp16 on GPU only)
if DTYPE == torch.float16:
    model = model.half()

# Warm-up inference so the first real request doesn't pay lazy-init cost
with torch.no_grad():
    dummy_input = tokenizer("Warm-up", return_tensors="pt", padding="max_length",
                            max_length=MAX_LENGTH, truncation=True)
    _ = model.predict(
        dummy_input["input_ids"].to(DEVICE),
        dummy_input["attention_mask"].to(DEVICE)
    )
print(f"✅ Model loaded and warmed up")

# Extract label names for the two classification heads
LABELS_ONDERWERP = config["labels"]["onderwerp"]
LABELS_BELEVING = config["labels"]["beleving"]
def prob_to_color(prob: float, threshold: float) -> str:
    """Map a probability to an inline CSS style (darker blue = higher prob)."""
    # Lightness runs from 95% (prob 0) down to 30% (prob 1).
    light = 95 - int(prob * 65)
    if prob >= threshold:
        edge = "2px solid #1e3a8a"
    else:
        edge = "1px solid #e5e7eb"
    return (
        f"background: hsl(210, 80%, {light}%); border: {edge}; "
        "padding: 6px 12px; border-radius: 4px; margin: 2px 0;"
    )
def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
    """Render the K highest-probability labels as stacked HTML rows.

    Rows above *threshold* get a check mark; styling comes from prob_to_color.
    """
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    rows = []
    for i in ranked[:topk]:
        p = probs[i]
        mark = " ✓" if p >= threshold else ""
        rows.append(
            f"<div style='{prob_to_color(p, threshold)}'>"
            f"<b>{labels[i]}</b>: {p:.3f}{mark}</div>"
        )
    return (
        "<div style='display: flex; flex-direction: column; gap: 6px;'>"
        + "".join(rows)
        + "</div>"
    )
def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str:
    """Render every label of one head as a scrollable, sticky-header HTML table."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    parts = [
        f"<h3>{head_name}</h3><div style='max-height: 500px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 4px;'>",
        "<table style='width: 100%; border-collapse: collapse;'>",
        "<thead style='position: sticky; top: 0; background: white; border-bottom: 2px solid #e5e7eb;'>",
        "<tr><th style='text-align: left; padding: 8px;'>Label</th><th style='text-align: right; padding: 8px;'>Probability</th><th style='padding: 8px;'>Predicted</th></tr>",
        "</thead><tbody>",
    ]
    for i in order:
        p = probs[i]
        tick = "✓" if p >= threshold else ""
        parts.append(
            f"<tr><td style='{prob_to_color(p, threshold)}'><b>{labels[i]}</b></td>"
            f"<td style='text-align: right; padding: 8px;'>{p:.4f}</td>"
            f"<td style='text-align: center; padding: 8px;'>{tick}</td></tr>"
        )
    parts.append("</tbody></table></div>")
    return "".join(parts)
def predict(text: str, threshold: float, topk: int):
    """Classify *text* with both heads and build the three output views.

    Args:
        text: Dutch signal message to classify.
        threshold: Probability cutoff above which a label counts as predicted.
        topk: Number of top labels per head shown in the summary view.

    Returns:
        tuple: (summary HTML, all-labels HTML, raw JSON-serializable dict);
        placeholder HTML and an empty dict when *text* is blank.
    """
    if not text or not text.strip():
        empty_msg = "<p style='color: #666; font-style: italic;'>Voer een bericht in om te classificeren...</p>"
        return empty_msg, empty_msg, {}
    # Gradio sliders can deliver floats; list slicing in format_topk needs an int.
    topk = int(topk)
    # Tokenize to a fixed MAX_LENGTH window
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True
    )
    # Move to device
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)
    # Inference only: disable autograd so no graph is built (matches warm-up)
    with torch.no_grad():
        onderwerp_probs, beleving_probs = model.predict(input_ids, attention_mask)
    # Batch of one -> plain Python lists
    onderwerp_probs = onderwerp_probs[0].cpu().numpy().tolist()
    beleving_probs = beleving_probs[0].cpu().numpy().tolist()
    # Summary view: top-K for each head side by side
    summary_html = "<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 20px;'>"
    summary_html += f"<div><h3>Onderwerp (Top-{topk})</h3>{format_topk(LABELS_ONDERWERP, onderwerp_probs, threshold, topk)}</div>"
    summary_html += f"<div><h3>Beleving (Top-{topk})</h3>{format_topk(LABELS_BELEVING, beleving_probs, threshold, topk)}</div>"
    summary_html += "</div>"
    # All-labels view: full tables for both heads
    all_labels_html = "<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 20px;'>"
    all_labels_html += f"<div>{format_all_labels('Onderwerp', LABELS_ONDERWERP, onderwerp_probs, threshold)}</div>"
    all_labels_html += f"<div>{format_all_labels('Beleving', LABELS_BELEVING, beleving_probs, threshold)}</div>"
    all_labels_html += "</div>"
    # Raw JSON: every probability plus the thresholded label sets
    json_output = {
        "text": text,
        "threshold": threshold,
        "onderwerp": {
            "probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)},
            "predicted": [label for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs) if prob >= threshold]
        },
        "beleving": {
            "probabilities": {label: float(prob) for label, prob in zip(LABELS_BELEVING, beleving_probs)},
            "predicted": [label for label, prob in zip(LABELS_BELEVING, beleving_probs) if prob >= threshold]
        }
    }
    return summary_html, all_labels_html, json_output
def load_examples():
    """Load example texts from examples.json next to the app.

    Returns:
        list: Example entries for the gr.Examples component; empty list when
        the file is missing, unreadable, or contains invalid JSON.
    """
    try:
        with open("examples.json", encoding="utf-8") as f:
            return json.load(f)
    # Narrowed from a bare `except:` which also swallowed KeyboardInterrupt
    # and real bugs; missing/unreadable/malformed file is the expected case.
    except (OSError, json.JSONDecodeError):
        return []
# Build Gradio interface
with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏛️ WimBERT Synth v0: Multi-label Signaal Classifier
    Classificeert Nederlandse signaalberichten op **Onderwerp** (64 categorieën) en **Beleving** (33 categorieën).
    """)
    with gr.Row():
        # Left column: message input + action buttons
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Signaalbericht (Nederlands)",
                lines=8,
                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
                info="Voer een bericht in en klik op 'Voorspel'"
            )
            with gr.Row():
                predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)
        # Right column: inference settings + runtime info
        with gr.Column(scale=1):
            threshold_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.05,
                label="🎯 Drempel",
                info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd"
            )
            topk_slider = gr.Slider(
                minimum=1,
                maximum=15,
                value=5,
                step=1,
                label="📊 Top-K",
                info="Aantal top labels om te tonen in samenvatting"
            )
            gr.Markdown(f"""
            **Hardware:** {DEVICE.type.upper()}
            **Dtype:** {DTYPE}
            **Max length:** {MAX_LENGTH}
            """)
    # Output tabs: summary (top-K), full label tables, raw JSON
    with gr.Tabs():
        with gr.Tab("📋 Samenvatting"):
            summary_output = gr.HTML(label="Top voorspellingen per categorie")
        with gr.Tab("📊 Alle labels"):
            all_labels_output = gr.HTML(label="Volledige classificatie")
        with gr.Tab("💾 JSON"):
            json_output = gr.JSON(label="Ruwe output")
    gr.Examples(
        examples=load_examples(),
        inputs=input_text,
        label="📝 Voorbeelden"
    )
    gr.Markdown("""
    ---
    ### ℹ️ Over dit model
    - **Model:** `UWV/wimbert-synth-v0` (dual-head BERT)
    - **Licentie:** Apache-2.0
    - **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen
    [Model Card](https://huggingface.co/UWV/wimbert-synth-v0) • Gebouwd met Gradio
    """)
    # Event handlers
    predict_btn.click(
        fn=predict,
        inputs=[input_text, threshold_slider, topk_slider],
        outputs=[summary_output, all_labels_output, json_output]
    )
    # Update predictions when threshold/topk changes (if there's existing output)
    # NOTE(review): both handlers call predict, so each slider move re-runs
    # tokenization and model inference rather than just re-rendering.
    threshold_slider.change(
        fn=predict,
        inputs=[input_text, threshold_slider, topk_slider],
        outputs=[summary_output, all_labels_output, json_output]
    )
    topk_slider.change(
        fn=predict,
        inputs=[input_text, threshold_slider, topk_slider],
        outputs=[summary_output, all_labels_output, json_output]
    )

# Launch the Space when run as a script
if __name__ == "__main__":
    demo.launch()