#!/usr/bin/env python3
"""
WimBERT Synth v0 Gradio Space
Dual-head multi-label classifier for Dutch signal messages
"""

import html
import importlib.util
import json
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import snapshot_download

# Constants
MODEL_REPO = "UWV/wimbert-synth-v0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32  # fp16 only on GPU
# Every input is padded AND truncated to exactly this many tokens.
# NOTE(review): because padding="max_length" is used, each request runs a full
# MAX_LENGTH-token forward pass; dynamic padding would be cheaper on CPU if
# DualHeadModel.predict accepts variable-length input — TODO confirm.
MAX_LENGTH = 512

print(f"🔧 Loading model from {MODEL_REPO}...")
print(f"🖥️ Device: {DEVICE} ({DTYPE})")

# Download model files (uses the default HF cache)
model_dir = snapshot_download(MODEL_REPO)

# Dynamic import of model.py shipped inside the downloaded model repo.
# NOTE: this executes Python code fetched from MODEL_REPO — only use trusted repos.
_model_py = Path(model_dir) / "model.py"
spec = importlib.util.spec_from_file_location("model", str(_model_py))
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot import model definition from {_model_py}")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
DualHeadModel = model_module.DualHeadModel

# Load model + tokenizer + config
model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Cast to target dtype
if DTYPE == torch.float16:
    model = model.half()


def _tokenize(text: str):
    """Tokenize *text* into PyTorch tensors, padded/truncated to MAX_LENGTH."""
    return tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
    )


# Warm-up inference: one dummy forward pass at startup, result discarded.
with torch.no_grad():
    _dummy = _tokenize("Warm-up")
    model.predict(_dummy["input_ids"].to(DEVICE), _dummy["attention_mask"].to(DEVICE))
print("✅ Model loaded and warmed up")

# Extract label names. The code below zips these with the model's output
# columns, so their order presumably matches model.py's heads — verify there.
LABELS_ONDERWERP = config["labels"]["onderwerp"]
LABELS_BELEVING = config["labels"]["beleving"]


def prob_to_color(prob: float, threshold: float) -> str:
    """Return an inline CSS style visualising one label probability.

    Higher probabilities get a darker blue background; labels at or above
    ``threshold`` get a thick dark-blue border instead of a thin grey one.
    """
    lightness = 95 - int(prob * 65)  # 95% at prob 0 down to 30% at prob 1
    border = "2px solid #1e3a8a" if prob >= threshold else "1px solid #e5e7eb"
    return (
        f"background: hsl(210, 80%, {lightness}%); border: {border}; "
        "padding: 6px 12px; border-radius: 4px; margin: 2px 0;"
    )


def _ranked_indices(probs: list) -> list:
    """Return indices into *probs*, ordered from highest to lowest probability."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)


def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
    """Generate HTML listing the ``topk`` most probable labels of one head.

    Each entry shows the label, its probability (3 decimals) and a check mark
    when the probability is at or above ``threshold``.
    """
    topk = int(topk)  # guard against the slider delivering a float; slicing needs an int
    rows = []
    for idx in _ranked_indices(probs)[:topk]:
        prob = probs[idx]
        style = prob_to_color(prob, threshold)
        predicted = " ✓" if prob >= threshold else ""
        rows.append(
            f"<div style='{style}'><strong>{html.escape(str(labels[idx]))}</strong>: "
            f"{prob:.3f}{predicted}</div>"
        )
    return "<div style='display: flex; flex-direction: column;'>" + "".join(rows) + "</div>"


def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str:
    """Generate a scrollable HTML table with every label of one head.

    Rows are sorted by descending probability (4 decimals); labels at or
    above ``threshold`` are marked with a check mark.
    """
    rows = []
    for idx in _ranked_indices(probs):
        prob = probs[idx]
        style = prob_to_color(prob, threshold)
        predicted = "✓" if prob >= threshold else ""
        rows.append(
            f"<tr style='{style}'>"
            f"<td style='padding: 4px 8px;'>{html.escape(str(labels[idx]))}</td>"
            f"<td style='padding: 4px 8px; text-align: right;'>{prob:.4f}</td>"
            f"<td style='padding: 4px 8px; text-align: center;'>{predicted}</td>"
            "</tr>"
        )
    body = "".join(rows)
    return (
        "<div style='margin-bottom: 16px;'>"
        f"<h4 style='margin: 8px 0;'>{html.escape(head_name)}</h4>"
        "<div style='max-height: 400px; overflow-y: auto;'>"
        "<table style='width: 100%; border-collapse: collapse;'>"
        "<thead><tr>"
        "<th style='text-align: left; padding: 4px 8px;'>Label</th>"
        "<th style='text-align: right; padding: 4px 8px;'>Probability</th>"
        "<th style='text-align: center; padding: 4px 8px;'>Predicted</th>"
        "</tr></thead>"
        f"<tbody>{body}</tbody>"
        "</table></div></div>"
    )


def _head_result(labels: list, probs: list, threshold: float) -> dict:
    """Build the JSON section for one head: all probabilities plus predicted labels."""
    return {
        "probabilities": {label: float(prob) for label, prob in zip(labels, probs)},
        "predicted": [label for label, prob in zip(labels, probs) if prob >= threshold],
    }


@torch.inference_mode()
def predict(text: str, threshold: float, topk: int):
    """Run inference on *text* and build the three UI outputs.

    Returns:
        A tuple ``(summary_html, all_labels_html, result)``: the side-by-side
        top-K view, the full per-label tables, and a JSON-serialisable dict
        with every probability and the labels at or above ``threshold``.
        For empty/whitespace-only input both HTML outputs are a placeholder
        message and the dict is empty.
    """
    if not text or not text.strip():
        empty_msg = (
            "<div style='padding: 20px; text-align: center; color: #6b7280;'>"
            "Voer een bericht in om te classificeren..."
            "</div>"
        )
        return empty_msg, empty_msg, {}

    topk = int(topk)  # also used in the "Top-{topk}" titles

    inputs = _tokenize(text)
    onderwerp_t, beleving_t = model.predict(
        inputs["input_ids"].to(DEVICE), inputs["attention_mask"].to(DEVICE)
    )

    # Single-item batch: take row 0 and convert to plain Python floats.
    onderwerp_probs = onderwerp_t[0].cpu().tolist()
    beleving_probs = beleving_t[0].cpu().tolist()

    # Summary view: top-K for each head, side by side.
    summary_html = (
        "<div style='display: flex; gap: 16px; flex-wrap: wrap;'>"
        "<div style='flex: 1; min-width: 280px;'>"
        f"<h3 style='margin: 0 0 8px 0;'>Onderwerp (Top-{topk})</h3>"
        f"{format_topk(LABELS_ONDERWERP, onderwerp_probs, threshold, topk)}"
        "</div>"
        "<div style='flex: 1; min-width: 280px;'>"
        f"<h3 style='margin: 0 0 8px 0;'>Beleving (Top-{topk})</h3>"
        f"{format_topk(LABELS_BELEVING, beleving_probs, threshold, topk)}"
        "</div>"
        "</div>"
    )

    # All-labels view: one full table per head.
    all_labels_html = (
        "<div style='display: flex; gap: 16px; flex-wrap: wrap;'>"
        "<div style='flex: 1; min-width: 280px;'>"
        f"{format_all_labels('Onderwerp', LABELS_ONDERWERP, onderwerp_probs, threshold)}"
        "</div>"
        "<div style='flex: 1; min-width: 280px;'>"
        f"{format_all_labels('Beleving', LABELS_BELEVING, beleving_probs, threshold)}"
        "</div>"
        "</div>"
    )

    # Raw JSON output
    result = {
        "text": text,
        "threshold": threshold,
        "onderwerp": _head_result(LABELS_ONDERWERP, onderwerp_probs, threshold),
        "beleving": _head_result(LABELS_BELEVING, beleving_probs, threshold),
    }

    return summary_html, all_labels_html, result


def load_examples() -> list:
    """Load example texts from ``examples.json`` next to this script.

    Returns an empty list when the file is missing or cannot be read/parsed,
    so the app still starts (without examples).
    """
    path = Path(__file__).with_name("examples.json")
    try:
        with path.open(encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, json.JSONDecodeError) as err:
        print(f"⚠️ Could not load examples from {path}: {err}")
        return []


# Build Gradio interface
with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🏛️ WimBERT Synth v0: Multi-label Signaal Classifier

Classificeert Nederlandse signaalberichten op **Onderwerp** ({len(LABELS_ONDERWERP)} categorieën) en **Beleving** ({len(LABELS_BELEVING)} categorieën).
""")

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Signaalbericht (Nederlands)",
                lines=8,
                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
                info="Voer een bericht in en klik op 'Voorspel'",
            )
            with gr.Row():
                predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)

        with gr.Column(scale=1):
            threshold_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.05,
                label="🎯 Drempel",
                info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd",
            )
            topk_slider = gr.Slider(
                minimum=1,
                maximum=15,
                value=5,
                step=1,
                label="📊 Top-K",
                info="Aantal top labels om te tonen in samenvatting",
            )
            gr.Markdown(f"""
**Hardware:** {DEVICE.type.upper()}
**Dtype:** {DTYPE}
**Max length:** {MAX_LENGTH}
""")

    with gr.Tabs():
        with gr.Tab("📋 Samenvatting"):
            summary_output = gr.HTML(label="Top voorspellingen per categorie")
        with gr.Tab("📊 Alle labels"):
            all_labels_output = gr.HTML(label="Volledige classificatie")
        with gr.Tab("💾 JSON"):
            json_output = gr.JSON(label="Ruwe output")

    # Clearing the input should also clear stale predictions.
    clear_btn.add([summary_output, all_labels_output, json_output])

    examples = load_examples()
    if examples:  # skip the examples widget entirely when there is nothing to show
        gr.Examples(examples=examples, inputs=input_text, label="📝 Voorbeelden")

    gr.Markdown(f"""
---
### ℹ️ Over dit model

- **Model:** `{MODEL_REPO}` (dual-head BERT)
- **Licentie:** Apache-2.0
- **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen

[Model Card](https://huggingface.co/{MODEL_REPO}) • Gebouwd met Gradio
""")

    # Event handlers: re-run inference on button click and whenever either
    # slider changes. predict() returns the placeholder when the text is empty.
    predict_inputs = [input_text, threshold_slider, topk_slider]
    predict_outputs = [summary_output, all_labels_output, json_output]
    for register in (predict_btn.click, threshold_slider.change, topk_slider.change):
        register(fn=predict, inputs=predict_inputs, outputs=predict_outputs)


if __name__ == "__main__":
    demo.launch()