yhavinga committed on
Commit
13d4fa0
·
1 Parent(s): 27e4fba

Update train script to filter calamity messages

Browse files
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  language:
3
  - nl
4
- license: other
5
  library_name: transformers
6
  pipeline_tag: text-classification
7
  base_model: bert-base-multilingual-cased
@@ -190,3 +190,6 @@ print("Beleving:", topk(be_probs, labels["beleving"]))
190
  ## Acknowledgements
191
  - UWV WIM synthetic RD dataset
192
  - Hugging Face Transformers/Datasets
 
 
 
 
1
  ---
2
  language:
3
  - nl
4
+ license: apache-2.0
5
  library_name: transformers
6
  pipeline_tag: text-classification
7
  base_model: bert-base-multilingual-cased
 
190
  ## Acknowledgements
191
  - UWV WIM synthetic RD dataset
192
  - Hugging Face Transformers/Datasets
193
+
194
+ ## License
195
+ This model is licensed under the Apache License 2.0. See `LICENSE` for details.
train/rd_dataset_loader.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  from datasets import load_dataset
9
 
10
 
11
- def load_rd_wim_dataset(max_samples=None, split='train'):
12
  """
13
  Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
14
 
@@ -17,21 +17,29 @@ def load_rd_wim_dataset(max_samples=None, split='train'):
17
  - beleving: How the citizen experienced the interaction (26 unique labels)
18
 
19
  Args:
20
- max_samples: Limit number of samples (None = all 9,351 samples)
21
  split: Dataset split to load (default: 'train')
 
22
 
23
  Returns:
24
  texts: List of conversation strings
25
- onderwerp_encoded: numpy array [n_samples, 96] - multi-hot encoded topics
26
- beleving_encoded: numpy array [n_samples, 26] - multi-hot encoded experiences
27
- onderwerp_labels: List of 96 onderwerp label names (sorted alphabetically)
28
- beleving_labels: List of 26 beleving label names (sorted alphabetically)
29
  """
30
 
31
  # Load dataset from HuggingFace
32
  print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
33
  ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
34
 
 
 
 
 
 
 
 
35
  # Limit samples if requested
36
  if max_samples is not None:
37
  ds = ds.select(range(min(max_samples, len(ds))))
 
8
  from datasets import load_dataset
9
 
10
 
11
+ def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
12
  """
13
  Load UWV/wim-synthetic-data-rd dataset and encode multi-labels.
14
 
 
17
  - beleving: How the citizen experienced the interaction (26 unique labels)
18
 
19
  Args:
20
+ max_samples: Limit number of samples (None = all samples)
21
  split: Dataset split to load (default: 'train')
22
+ filter_calamity: If True, exclude samples with is_calamity=True (default: True)
23
 
24
  Returns:
25
  texts: List of conversation strings
26
+ onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
27
+ beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
28
+ onderwerp_labels: List of onderwerp label names (sorted alphabetically)
29
+ beleving_labels: List of beleving label names (sorted alphabetically)
30
  """
31
 
32
  # Load dataset from HuggingFace
33
  print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
34
  ds = load_dataset('UWV/wim-synthetic-data-rd', split=split)
35
 
36
+ # Filter out calamity samples if requested
37
+ if filter_calamity:
38
+ original_len = len(ds)
39
+ ds = ds.filter(lambda x: not x['is_calamity'])
40
+ filtered_len = len(ds)
41
+ print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")
42
+
43
  # Limit samples if requested
44
  if max_samples is not None:
45
  ds = ds.select(range(min(max_samples, len(ds))))
train/train_mmbert_dual_soft_f1_simplified.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
  Dual-head multi-label PyTorch training script for mmBERT-base.
4
- Two classification heads: onderwerp (96 labels) and beleving (26 labels).
5
  Uses combined F1+BCE loss with weight α (configurable balance).
6
  Features: learnable thresholds, warmup + cosine LR, gradient clipping.
7
  mmBERT: Modern multilingual encoder (1800+ languages, 2x faster than XLM-R).
@@ -702,7 +702,7 @@ def main():
702
  set_seed(cfg.seed)
703
 
704
  # Load RD dataset
705
- print("\nLoading FULL RD dataset (9,351 samples)...")
706
  texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(
707
  max_samples=None # Using full dataset for better training
708
  )
 
1
  #!/usr/bin/env python3
2
  """
3
  Dual-head multi-label PyTorch training script for mmBERT-base.
4
+ Two classification heads: onderwerp (topic) and beleving (experience) with dynamic label counts.
5
  Uses combined F1+BCE loss with weight α (configurable balance).
6
  Features: learnable thresholds, warmup + cosine LR, gradient clipping.
7
  mmBERT: Modern multilingual encoder (1800+ languages, 2x faster than XLM-R).
 
702
  set_seed(cfg.seed)
703
 
704
  # Load RD dataset
705
+ print("\nLoading RD dataset...")
706
  texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(
707
  max_samples=None # Using full dataset for better training
708
  )