Update train script to filter calamity messages
Browse files- README.md +4 -1
- train/rd_dataset_loader.py +14 -6
- train/train_mmbert_dual_soft_f1_simplified.py +2 -2
README.md
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
---
|
| 2 |
language:
|
| 3 |
- nl
|
| 4 |
-
license:
|
| 5 |
library_name: transformers
|
| 6 |
pipeline_tag: text-classification
|
| 7 |
base_model: bert-base-multilingual-cased
|
|
@@ -190,3 +190,6 @@ print("Beleving:", topk(be_probs, labels["beleving"]))
|
|
| 190 |
## Acknowledgements
|
| 191 |
- UWV WIM synthetic RD dataset
|
| 192 |
- Hugging Face Transformers/Datasets
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
language:
|
| 3 |
- nl
|
| 4 |
+
license: apache-2.0
|
| 5 |
library_name: transformers
|
| 6 |
pipeline_tag: text-classification
|
| 7 |
base_model: bert-base-multilingual-cased
|
|
|
|
| 190 |
## Acknowledgements
|
| 191 |
- UWV WIM synthetic RD dataset
|
| 192 |
- Hugging Face Transformers/Datasets
|
| 193 |
+
|
| 194 |
+
## License
|
| 195 |
+
This model is licensed under the Apache License 2.0. See `LICENSE` for details.
|
train/rd_dataset_loader.py
CHANGED
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
| 8 |
from datasets import load_dataset
|
| 9 |
|
| 10 |
|
| 11 |
-
def load_rd_wim_dataset(max_samples=None, split='train'):
|
| 12 |
"""
|
| 13 |
Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
|
| 14 |
|
|
@@ -17,21 +17,29 @@ def load_rd_wim_dataset(max_samples=None, split='train'):
|
|
| 17 |
- beleving: How the citizen experienced the interaction (26 unique labels)
|
| 18 |
|
| 19 |
Args:
|
| 20 |
-
max_samples: Limit number of samples (None = all
|
| 21 |
split: Dataset split to load (default: 'train')
|
|
|
|
| 22 |
|
| 23 |
Returns:
|
| 24 |
texts: List of conversation strings
|
| 25 |
-
onderwerp_encoded: numpy array [n_samples,
|
| 26 |
-
beleving_encoded: numpy array [n_samples,
|
| 27 |
-
onderwerp_labels: List of
|
| 28 |
-
beleving_labels: List of
|
| 29 |
"""
|
| 30 |
|
| 31 |
# Load dataset from HuggingFace
|
| 32 |
print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
|
| 33 |
ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# Limit samples if requested
|
| 36 |
if max_samples is not None:
|
| 37 |
ds = ds.select(range(min(max_samples, len(ds))))
|
|
|
|
| 8 |
from datasets import load_dataset
|
| 9 |
|
| 10 |
|
| 11 |
+
def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
|
| 12 |
"""
|
| 13 |
Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
|
| 14 |
|
|
|
|
| 17 |
- beleving: How the citizen experienced the interaction (26 unique labels)
|
| 18 |
|
| 19 |
Args:
|
| 20 |
+
max_samples: Limit number of samples (None = all samples)
|
| 21 |
split: Dataset split to load (default: 'train')
|
| 22 |
+
filter_calamity: If True, exclude samples with is_calamity=True (default: True)
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
texts: List of conversation strings
|
| 26 |
+
onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
|
| 27 |
+
beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
|
| 28 |
+
onderwerp_labels: List of onderwerp label names (sorted alphabetically)
|
| 29 |
+
beleving_labels: List of beleving label names (sorted alphabetically)
|
| 30 |
"""
|
| 31 |
|
| 32 |
# Load dataset from HuggingFace
|
| 33 |
print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
|
| 34 |
ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
|
| 35 |
|
| 36 |
+
# Filter out calamity samples if requested
|
| 37 |
+
if filter_calamity:
|
| 38 |
+
original_len = len(ds)
|
| 39 |
+
ds = ds.filter(lambda x: not x['is_calamity'])
|
| 40 |
+
filtered_len = len(ds)
|
| 41 |
+
print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")
|
| 42 |
+
|
| 43 |
# Limit samples if requested
|
| 44 |
if max_samples is not None:
|
| 45 |
ds = ds.select(range(min(max_samples, len(ds))))
|
train/train_mmbert_dual_soft_f1_simplified.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Dual-head multi-label PyTorch training script for mmBERT-base.
|
| 4 |
-
Two classification heads: onderwerp (
|
| 5 |
Uses combined F1+BCE loss with weight α (configurable balance).
|
| 6 |
Features: learnable thresholds, warmup + cosine LR, gradient clipping.
|
| 7 |
mmBERT: Modern multilingual encoder (1800+ languages, 2x faster than XLM-R).
|
|
@@ -702,7 +702,7 @@ def main():
|
|
| 702 |
set_seed(cfg.seed)
|
| 703 |
|
| 704 |
# Load RD dataset
|
| 705 |
-
print("\nLoading
|
| 706 |
texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(
|
| 707 |
max_samples=None # Using full dataset for better training
|
| 708 |
)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Dual-head multi-label PyTorch training script for mmBERT-base.
|
| 4 |
+
Two classification heads: onderwerp (topic) and beleving (experience) with dynamic label counts.
|
| 5 |
Uses combined F1+BCE loss with weight α (configurable balance).
|
| 6 |
Features: learnable thresholds, warmup + cosine LR, gradient clipping.
|
| 7 |
mmBERT: Modern multilingual encoder (1800+ languages, 2x faster than XLM-R).
|
|
|
|
| 702 |
set_seed(cfg.seed)
|
| 703 |
|
| 704 |
# Load RD dataset
|
| 705 |
+
print("\nLoading RD dataset...")
|
| 706 |
texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(
|
| 707 |
max_samples=None # Using full dataset for better training
|
| 708 |
)
|