wimbert-space / app.py
yhavinga's picture
Initial Gradio Space implementation for WimBERT Synth v0
85efe28
raw
history blame
9.47 kB
#!/usr/bin/env python3
"""
WimBERT Synth v0 Gradio Space
Dual-head multi-label classifier for Dutch signal messages
"""
import json
import importlib.util
import torch
import gradio as gr
from huggingface_hub import snapshot_download
# Constants
MODEL_REPO = "UWV/wimbert-synth-v0"  # Hub repo id handed to snapshot_download below
# Run on CUDA when torch reports a GPU, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only selected on CUDA; on CPU the model stays in float32.
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
# Length every tokenized input is padded/truncated to (see the tokenizer calls).
MAX_LENGTH = 512  # Default to 512 for better CPU performance

print(f"🔧 Loading model from {MODEL_REPO}...")
print(f"🖥️ Device: {DEVICE} ({DTYPE})")

# Download model files. cache_dir=None leaves the cache location at the
# library default (presumably the standard HF cache, so repeated starts
# reuse the files -- confirm against huggingface_hub docs).
model_dir = snapshot_download(MODEL_REPO, cache_dir=None)

# Dynamic import of model.py from the downloaded directory.
# NOTE(review): exec_module runs model.py from the snapshot as ordinary Python
# code, and the snapshot_download call above passes no revision, so whatever
# the repo currently serves is executed. Consider pinning a revision.
spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
DualHeadModel = model_module.DualHeadModel

# Load model + tokenizer + config in one call; from_pretrained is defined in
# the downloaded model.py and returns the three objects as a tuple.
model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Cast to target dtype (only float16 needs an explicit cast).
if DTYPE == torch.float16:
    model = model.half()

# Warm-up inference: one dummy forward pass with the same padding/truncation
# settings predict() uses, presumably so the first real request does not pay
# one-time initialisation costs. The result is discarded.
with torch.no_grad():
    dummy_input = tokenizer("Warm-up", return_tensors="pt", padding="max_length",
                            max_length=MAX_LENGTH, truncation=True)
    _ = model.predict(
        dummy_input["input_ids"].to(DEVICE),
        dummy_input["attention_mask"].to(DEVICE)
    )
print(f"✅ Model loaded and warmed up")

# Extract label names for the two heads; predict() pairs them positionally
# with the probabilities each head returns.
LABELS_ONDERWERP = config["labels"]["onderwerp"]
LABELS_BELEVING = config["labels"]["beleving"]
def prob_to_color(prob: float, threshold: float) -> str:
    """Return the inline CSS used to visualise a single probability.

    The background is a fixed blue hue whose lightness falls from 95% at
    ``prob == 0`` to 30% at ``prob == 1``. Values at or above ``threshold``
    get a thicker dark-blue border; everything else gets a thin grey one.
    """
    # Darker background for higher probabilities.
    lightness = 95 - int(prob * 65)

    if prob >= threshold:
        border = "2px solid #1e3a8a"
    else:
        border = "1px solid #e5e7eb"

    declarations = (
        f"background: hsl(210, 80%, {lightness}%)",
        f"border: {border}",
        "padding: 6px 12px",
        "border-radius: 4px",
        "margin: 2px 0",
    )
    return "; ".join(declarations) + ";"
def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
    """Render the ``topk`` most probable labels as a vertical stack of HTML rows.

    Args:
        labels: Label names, index-aligned with ``probs``.
        probs: One probability per label.
        threshold: Labels with ``prob >= threshold`` get a " ✓" marker and
            the highlighted border produced by ``prob_to_color``.
        topk: Maximum number of labels to render. Integral floats (as a UI
            slider may deliver them) are accepted; values below 1 render
            no rows.

    Returns:
        An HTML fragment: a flex-column ``<div>`` containing one child
        ``<div>`` per label, ordered by descending probability.
    """
    # Imported locally so the function does not depend on file-level imports.
    from html import escape

    # A float would make the slice below raise TypeError, and a negative value
    # would drop labels from the end instead of limiting the count.
    limit = max(int(topk), 0)

    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    rows = []
    for idx in ranked[:limit]:
        prob = probs[idx]
        style = prob_to_color(prob, threshold)
        predicted = " ✓" if prob >= threshold else ""
        # Escape the label so characters such as "<" or "&" are shown
        # literally instead of being parsed as markup.
        label = escape(str(labels[idx]))
        rows.append(f"<div style='{style}'><b>{label}</b>: {prob:.3f}{predicted}</div>")

    return (
        "<div style='display: flex; flex-direction: column; gap: 6px;'>"
        + "".join(rows)
        + "</div>"
    )
def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str:
    """Render every label of one classifier head as a scrollable HTML table.

    Args:
        head_name: Heading shown above the table (HTML-escaped).
        labels: Label names, index-aligned with ``probs`` (HTML-escaped).
        probs: One probability per label.
        threshold: Labels with ``prob >= threshold`` get a "✓" in the
            "Predicted" column and the highlighted border produced by
            ``prob_to_color``.

    Returns:
        An HTML fragment: an ``<h3>`` heading followed by a height-limited
        ``<div>`` wrapping a table with one row per label, ordered by
        descending probability.
    """
    # Imported locally so the function does not depend on file-level imports.
    from html import escape

    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Collect fragments and join once instead of growing a string with "+=".
    parts = [
        f"<h3>{escape(str(head_name))}</h3><div style='max-height: 500px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 4px;'>",
        "<table style='width: 100%; border-collapse: collapse;'>",
        # Sticky header keeps the column titles visible while scrolling.
        "<thead style='position: sticky; top: 0; background: white; border-bottom: 2px solid #e5e7eb;'>",
        "<tr><th style='text-align: left; padding: 8px;'>Label</th><th style='text-align: right; padding: 8px;'>Probability</th><th style='padding: 8px;'>Predicted</th></tr>",
        "</thead><tbody>",
    ]

    for idx in ranked:
        prob = probs[idx]
        style = prob_to_color(prob, threshold)
        predicted = "✓" if prob >= threshold else ""
        # Escape the label so characters such as "<" or "&" are shown
        # literally instead of being parsed as markup.
        label = escape(str(labels[idx]))
        parts.append(
            f"<tr><td style='{style}'><b>{label}</b></td>"
            f"<td style='text-align: right; padding: 8px;'>{prob:.4f}</td>"
            f"<td style='text-align: center; padding: 8px;'>{predicted}</td></tr>"
        )

    parts.append("</tbody></table></div>")
    return "".join(parts)
@torch.inference_mode()
def predict(text: str, threshold: float, topk: int):
    """Classify one message and build the three UI outputs.

    Args:
        text: Message to classify. Empty or whitespace-only input returns a
            placeholder without tokenizing or calling the model.
        threshold: Probability at or above which a label counts as predicted.
        topk: Number of labels per head in the summary view. Coerced to
            ``int`` because a UI slider may deliver the value as a float.

    Returns:
        A ``(summary_html, all_labels_html, json_output)`` tuple: the top-K
        summary for both heads side by side, the full per-head tables side
        by side, and a dict with the input text, the threshold and, per
        head, all probabilities plus the labels at or above the threshold.
        For empty input both HTML values are the same placeholder paragraph
        and the dict is empty.
    """
    if not text or not text.strip():
        empty_msg = "<p style='color: #666; font-style: italic;'>Voer een bericht in om te classificeren...</p>"
        return empty_msg, empty_msg, {}

    # A float topk would render as "Top-5.0" and break list slicing downstream.
    topk = int(topk)

    # Tokenize with the same fixed-length padding as the warm-up pass.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
    )
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)

    onderwerp_probs, beleving_probs = model.predict(input_ids, attention_mask)

    # Index 0 picks the only message (presumably the batch dimension --
    # confirm against model.py); convert to plain Python floats.
    onderwerp_probs = onderwerp_probs[0].cpu().numpy().tolist()
    beleving_probs = beleving_probs[0].cpu().numpy().tolist()

    # (json key, display name, labels, probabilities) for each head, in
    # display order; drives all three outputs below.
    heads = (
        ("onderwerp", "Onderwerp", LABELS_ONDERWERP, onderwerp_probs),
        ("beleving", "Beleving", LABELS_BELEVING, beleving_probs),
    )
    grid_open = "<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 20px;'>"

    # Summary view: top-K of each head side by side.
    summary_cells = [
        f"<div><h3>{title} (Top-{topk})</h3>{format_topk(labels, probs, threshold, topk)}</div>"
        for _, title, labels, probs in heads
    ]
    summary_html = grid_open + "".join(summary_cells) + "</div>"

    # Full view: every label of each head side by side.
    table_cells = [
        f"<div>{format_all_labels(title, labels, probs, threshold)}</div>"
        for _, title, labels, probs in heads
    ]
    all_labels_html = grid_open + "".join(table_cells) + "</div>"

    # Raw output for the JSON tab.
    json_output = {"text": text, "threshold": threshold}
    for key, _, labels, probs in heads:
        json_output[key] = {
            "probabilities": {label: float(prob) for label, prob in zip(labels, probs)},
            "predicted": [label for label, prob in zip(labels, probs) if prob >= threshold],
        }

    return summary_html, all_labels_html, json_output
def load_examples():
    """Load example texts from ``examples.json`` in the working directory.

    Loading is best-effort: the examples are optional, so any failure
    yields an empty list instead of raising.

    Returns:
        The decoded JSON content of ``examples.json``, or ``[]`` when the
        file is missing, unreadable, wrongly encoded or not valid JSON.
    """
    try:
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open("examples.json", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # No examples shipped: nothing to report.
        return []
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Report the problem but keep the app starting without examples.
        print(f"⚠️ Could not load examples.json: {exc}")
        return []
# Build the Gradio interface. Components are created in display order:
# header, input/controls row, result tabs, examples, footer.
with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🏛️ WimBERT Synth v0: Multi-label Signaal Classifier
Classificeert Nederlandse signaalberichten op **Onderwerp** (64 categorieën) en **Beleving** (33 categorieën).
""")

    with gr.Row():
        # Wider column: message input and action buttons.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                lines=8,
                label="Signaalbericht (Nederlands)",
                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
                info="Voer een bericht in en klik op 'Voorspel'",
            )
            with gr.Row():
                predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                # Clears the input textbox only.
                clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)

        # Narrower column: prediction settings and runtime info.
        with gr.Column(scale=1):
            threshold_slider = gr.Slider(
                label="🎯 Drempel",
                info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd",
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.5,
            )
            topk_slider = gr.Slider(
                label="📊 Top-K",
                info="Aantal top labels om te tonen in samenvatting",
                minimum=1,
                maximum=15,
                step=1,
                value=5,
            )
            gr.Markdown(f"""
**Hardware:** {DEVICE.type.upper()}
**Dtype:** {DTYPE}
**Max length:** {MAX_LENGTH}
""")

    # One tab per value returned by predict().
    with gr.Tabs():
        with gr.Tab("📋 Samenvatting"):
            summary_output = gr.HTML(label="Top voorspellingen per categorie")
        with gr.Tab("📊 Alle labels"):
            all_labels_output = gr.HTML(label="Volledige classificatie")
        with gr.Tab("💾 JSON"):
            json_output = gr.JSON(label="Ruwe output")

    # Clickable sample messages that fill the input textbox.
    gr.Examples(
        label="📝 Voorbeelden",
        examples=load_examples(),
        inputs=input_text,
    )

    gr.Markdown("""
---
### ℹ️ Over dit model
- **Model:** `UWV/wimbert-synth-v0` (dual-head BERT)
- **Licentie:** Apache-2.0
- **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen
[Model Card](https://huggingface.co/UWV/wimbert-synth-v0) • Gebouwd met Gradio
""")

    # Event handlers. The button click and both slider changes run the same
    # prediction with the same inputs and outputs. A slider change always
    # re-runs predict(); with an empty textbox predict() just returns its
    # placeholder message.
    _predict_inputs = [input_text, threshold_slider, topk_slider]
    _predict_outputs = [summary_output, all_labels_output, json_output]
    for _register in (predict_btn.click, threshold_slider.change, topk_slider.change):
        _register(fn=predict, inputs=_predict_inputs, outputs=_predict_outputs)

if __name__ == "__main__":
    demo.launch()