veryfansome committed on
Commit
c5081c8
·
1 Parent(s): ab4b5ab

feat: working end-to-end

Browse files
dataset_maker.py CHANGED
@@ -5,19 +5,14 @@ from openai import AsyncOpenAI
5
  from traceback import format_exc
6
  from typing import Union
7
  import asyncio
8
- import itertools
9
  import json
10
  import logging
11
- import sentencepiece as spm
12
 
13
- from utils import default_logging_config
14
 
15
  client = AsyncOpenAI()
16
  logger = logging.getLogger(__name__)
17
 
18
- sp = spm.SentencePieceProcessor()
19
- sp.LoadFromFile(f"sp.model")
20
-
21
  features = {
22
  "adj": {"JJ": "adjective",
23
  "JJR": "comparative adjective",
@@ -180,7 +175,7 @@ async def classify_with_retry(prompt, labels, tokens, model="gpt-4o", retry=10):
180
  await asyncio.sleep(i)
181
 
182
  async def generate_token_labels(case, model="gpt-4o"):
183
- tokens = list(itertools.chain.from_iterable([s.strip("▁").split("▁") for s in sp.EncodeAsPieces(case)]))
184
  sorted_cols = list(sorted(features.keys()))
185
  example = {}
186
  for idx, labels in enumerate(list(await asyncio.gather(
 
5
  from traceback import format_exc
6
  from typing import Union
7
  import asyncio
 
8
  import json
9
  import logging
 
10
 
11
+ from utils import default_logging_config, sp_tokenize
12
 
13
  client = AsyncOpenAI()
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
16
  features = {
17
  "adj": {"JJ": "adjective",
18
  "JJR": "comparative adjective",
 
175
  await asyncio.sleep(i)
176
 
177
  async def generate_token_labels(case, model="gpt-4o"):
178
+ tokens = sp_tokenize(case)
179
  sorted_cols = list(sorted(features.keys()))
180
  example = {}
181
  for idx, labels in enumerate(list(await asyncio.gather(
multi_head_model.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
2
+ import torch.nn as nn
3
+
4
+
5
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    Alongside the base config this stores the label names for every
    classification head (``label_maps``) and the number of labels each head
    predicts (``num_labels_dict``), so a saved checkpoint can be reloaded
    with a consistent head layout.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to empty mappings so a bare config still serializes.
        self.label_maps = {} if not label_maps else label_maps
        self.num_labels_dict = {} if not num_labels_dict else num_labels_dict

    def to_dict(self):
        """Serialize the base config plus the multi-head extras."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized
16
+
17
+
18
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per feature.

    Each entry in ``config.num_labels_dict`` gets its own linear classifier
    over the encoder's last hidden state.  With labels, the forward pass
    returns ``(summed_loss, logits_dict)``; without labels it returns just
    ``logits_dict``.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        # One linear head per feature, keyed by feature name.
        self.classifiers = nn.ModuleDict()

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)

        # Initialize newly added weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """
        Run the encoder and every classification head.

        labels_dict: a dict of { label_name: (batch_size, seq_len) } with label ids.
        If provided, we compute and return the sum of CE losses as
        (total_loss, logits_dict); otherwise only logits_dict is returned.
        Label positions equal to -100 are excluded from the loss.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            # Inference: just return predictions.
            return logits_dict

        # Training/eval: sum cross-entropy losses over every head that has
        # labels, ignoring -100 (special tokens and continuation subwords).
        # FIX: the previous version also filled a `loss_dict` with
        # `loss.item()` per head per step — never returned or read, and each
        # .item() forces a host/device sync — so that dead work is removed.
        loss_fct = nn.CrossEntropyLoss()
        total_loss = 0.0

        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Positions that actually carry a label (shape (bs, seq_len)).
            active_mask = (label_ids != -100).view(-1)

            # Flatten to (n_active, n_classes) / (n_active,) before CE.
            active_logits = logits.view(-1, logits.shape[-1])[active_mask]
            active_labels = label_ids.view(-1)[active_mask]

            total_loss += loss_fct(active_logits, active_labels)

        # return (loss, predictions)
        return total_loss, logits_dict
multi_task_classifier.py → multi_head_trainer.py RENAMED
@@ -1,275 +1,126 @@
1
- from datasets import DatasetDict, load_from_disk
2
  from sklearn.metrics import classification_report, precision_recall_fscore_support
3
  from transformers import (
4
- DebertaV2Config,
5
- DebertaV2Model,
6
- DebertaV2PreTrainedModel,
7
  DebertaV2TokenizerFast,
8
  Trainer,
9
  TrainingArguments,
10
  )
11
- import argparse
12
- import logging.config
13
  import numpy as np
14
  import torch
15
- import torch.nn as nn
16
 
17
- from utils import default_logging_config, get_uniq_training_labels, show_examples
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
- arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
22
- arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
23
- action="store", type=int, default=8)
24
- arg_parser.add_argument("--data-only", help='Show training data info and exit.',
25
- action="store_true", default=False)
26
- arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
27
- action="store", default="./training_data")
28
- arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
29
- action="store", type=int, default=3)
30
- arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
31
- action="store", type=int, default=2)
32
- arg_parser.add_argument("--from-base", help="Load a base model.",
33
- action="store", default=None,
34
- choices=[
35
- "microsoft/deberta-v3-base", # Requires --deberta-v3
36
- "microsoft/deberta-v3-large", # Requires --deberta-v3
37
- # More?
38
- ])
39
- arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
40
- action="store", type=float, default=5e-5)
41
- arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
42
- action="store_true", default=False)
43
- arg_parser.add_argument("--save-path", help="Save final model to specified path.",
44
- action="store", default="./final")
45
- arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
46
- action="store", default=None)
47
- arg_parser.add_argument("--train", help='Train model using loaded examples.',
48
- action="store_true", default=False)
49
- arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
50
- action="store", type=int, default=2)
51
- args = arg_parser.parse_args()
52
- logging.config.dictConfig(default_logging_config)
53
- logger.info(f"Args {args}")
54
 
55
  # ------------------------------------------------------------------------------
56
- # Load dataset and show examples for manual inspection
57
- # ------------------------------------------------------------------------------
58
-
59
- loaded_dataset = load_from_disk(args.data_path)
60
- show_examples(loaded_dataset, args.show)
61
-
62
- # ------------------------------------------------------------------------------
63
- # Convert label analysis data into label sets for each head
64
- # ------------------------------------------------------------------------------
65
-
66
- ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
67
-
68
- LABEL2ID = {
69
- feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
70
- for feat_name in ALL_LABELS
71
- }
72
- ID2LABEL = {
73
- feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
74
- for feat_name in LABEL2ID
75
- }
76
-
77
- # Each head's number of labels:
78
- NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
79
-
80
- if args.data_only:
81
- exit()
82
-
83
- # ------------------------------------------------------------------------------
84
- # Create a custom config that can store our multi-label info
85
- # ------------------------------------------------------------------------------
86
-
87
- class MultiHeadModelConfig(DebertaV2Config):
88
- def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
89
- super().__init__(**kwargs)
90
- self.label_maps = label_maps or {}
91
- self.num_labels_dict = num_labels_dict or {}
92
-
93
- def to_dict(self):
94
- output = super().to_dict()
95
- output["label_maps"] = self.label_maps
96
- output["num_labels_dict"] = self.num_labels_dict
97
- return output
98
-
99
- # ------------------------------------------------------------------------------
100
- # Define a multi-head model
101
  # ------------------------------------------------------------------------------
102
 
103
- class MultiHeadModel(DebertaV2PreTrainedModel):
104
- def __init__(self, config: MultiHeadModelConfig):
105
- super().__init__(config)
106
-
107
- self.deberta = DebertaV2Model(config)
108
- self.classifiers = nn.ModuleDict()
109
-
110
- hidden_size = config.hidden_size
111
- for label_name, n_labels in config.num_labels_dict.items():
112
- self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)
113
-
114
- # Initialize newly added weights
115
- self.post_init()
116
 
117
- def forward(
118
- self,
119
- input_ids=None,
120
- attention_mask=None,
121
- token_type_ids=None,
122
- labels_dict=None,
123
- **kwargs
124
- ):
125
  """
126
- labels_dict: a dict of { label_name: (batch_size, seq_len) } with label ids.
127
- If provided, we compute and return the sum of CE losses.
 
 
128
  """
129
- outputs = self.deberta(
130
- input_ids=input_ids,
131
- attention_mask=attention_mask,
132
- token_type_ids=token_type_ids,
133
- **kwargs
 
 
 
 
 
134
  )
135
 
136
- sequence_output = outputs.last_hidden_state # (batch_size, seq_len, hidden_size)
137
-
138
- logits_dict = {}
139
- for label_name, classifier in self.classifiers.items():
140
- logits_dict[label_name] = classifier(sequence_output)
141
-
142
- total_loss = None
143
- loss_dict = {}
144
- if labels_dict is not None:
145
- # We'll sum the losses from each head
146
- loss_fct = nn.CrossEntropyLoss()
147
- total_loss = 0.0
148
-
149
- for label_name, logits in logits_dict.items():
150
- if label_name not in labels_dict:
151
- continue
152
- label_ids = labels_dict[label_name]
153
-
154
- # A typical approach for token classification:
155
- # We ignore positions where label_ids == -100
156
- active_loss = label_ids != -100 # shape (bs, seq_len)
157
-
158
- # flatten everything
159
- active_logits = logits.view(-1, logits.shape[-1])[active_loss.view(-1)]
160
- active_labels = label_ids.view(-1)[active_loss.view(-1)]
161
-
162
- loss = loss_fct(active_logits, active_labels)
163
- loss_dict[label_name] = loss.item()
164
- total_loss += loss
165
-
166
- if labels_dict is not None:
167
- # return (loss, predictions)
168
- return total_loss, logits_dict
169
  else:
170
- # just return predictions
171
- return logits_dict
172
-
173
- # ------------------------------------------------------------------------------
174
- # Tokenize with max_length=512, stride=128, and subword alignment
175
- # ------------------------------------------------------------------------------
176
-
177
- def tokenize_and_align_labels(examples):
178
- """
179
- For each example, the tokenizer may produce multiple overlapping
180
- chunks if the tokens exceed 512 subwords. Each chunk will be
181
- length=512, with a stride=128 for the next chunk.
182
- We'll align labels so that subwords beyond the first in a token get -100.
183
- """
184
- # We rely on is_split_into_words=True because examples["tokens"] is a list of token strings.
185
- tokenized_batch = tokenizer(
186
- examples["tokens"],
187
- is_split_into_words=True,
188
- max_length=512,
189
- stride=128,
190
- truncation=True,
191
- return_overflowing_tokens=True,
192
- return_offsets_mapping=False, # not mandatory for basic alignment
193
- padding="max_length"
194
- )
195
-
196
- # The tokenizer returns "overflow_to_sample_mapping", telling us
197
- # which original example index each chunk corresponds to.
198
- # If the tokenizer didn't need to create overflows, the key might be missing
199
- if "overflow_to_sample_mapping" not in tokenized_batch:
200
- # No overflow => each input corresponds 1:1 with the original example
201
- sample_map = [i for i in range(len(tokenized_batch["input_ids"]))]
202
- else:
203
- sample_map = tokenized_batch["overflow_to_sample_mapping"]
204
-
205
- # We'll build lists for final outputs.
206
- # For each chunk i, we produce:
207
- # "input_ids"[i], "attention_mask"[i], plus per-feature label IDs.
208
- final_input_ids = []
209
- final_attention_mask = []
210
- final_labels_columns = {feat: [] for feat in ALL_LABELS} # store one label-sequence per chunk
211
-
212
- for i in range(len(tokenized_batch["input_ids"])):
213
- # chunk i
214
- chunk_input_ids = tokenized_batch["input_ids"][i]
215
- chunk_attn_mask = tokenized_batch["attention_mask"][i]
216
-
217
- original_index = sample_map[i] # which example in the original batch
218
- word_ids = tokenized_batch.word_ids(batch_index=i)
219
-
220
- # We'll build label arrays for each feature
221
- chunk_labels_dict = {}
222
-
223
- for feat_name in ALL_LABELS:
224
- # The UD token-level labels for the *original* example
225
- token_labels = examples[feat_name][original_index] # e.g. length T
226
- chunk_label_ids = []
227
-
228
- previous_word_id = None
229
- for w_id in word_ids:
230
- if w_id is None:
231
- # special token (CLS, SEP, padding)
232
- chunk_label_ids.append(-100)
233
- else:
234
- # If it's the same word_id as before, it's a subword => label = -100
235
- if w_id == previous_word_id:
236
  chunk_label_ids.append(-100)
237
  else:
238
- # New token => use the actual label
239
- label_str = token_labels[w_id]
240
- label_id = LABEL2ID[feat_name][label_str]
241
- chunk_label_ids.append(label_id)
242
- previous_word_id = w_id
243
-
244
- chunk_labels_dict[feat_name] = chunk_label_ids
245
-
246
- final_input_ids.append(chunk_input_ids)
247
- final_attention_mask.append(chunk_attn_mask)
248
- for feat_name in ALL_LABELS:
249
- final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])
250
-
251
- # Return the new "flattened" set of chunks
252
- # So the "map" call will expand each example → multiple chunk examples.
253
- result = {
254
- "input_ids": final_input_ids,
255
- "attention_mask": final_attention_mask,
256
- }
257
- # We'll store each feature's label IDs in separate columns (e.g. labels_xpos, labels_deprel, etc.)
258
- for feat_name in ALL_LABELS:
259
- result[f"labels_{feat_name}"] = final_labels_columns[feat_name]
260
-
261
- return result
 
 
 
 
262
 
263
  # ------------------------------------------------------------------------------
264
  # Trainer Setup
265
  # ------------------------------------------------------------------------------
266
 
267
  class MultiHeadTrainer(Trainer):
 
 
 
268
 
269
  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
270
  # 1) Gather all your per-feature labels from inputs
271
  _labels_dict = {}
272
- for feat_name in ALL_LABELS:
273
  key = f"labels_{feat_name}"
274
  if key in inputs:
275
  _labels_dict[feat_name] = inputs[key]
@@ -299,7 +150,7 @@ class MultiHeadTrainer(Trainer):
299
  def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
300
  # 1) gather the "labels_xxx" columns
301
  _labels_dict = {}
302
- for feat_name in ALL_LABELS:
303
  key = f"labels_{feat_name}"
304
  if key in inputs:
305
  _labels_dict[feat_name] = inputs[key]
@@ -317,7 +168,7 @@ class MultiHeadTrainer(Trainer):
317
  # The trainer expects a triple: (loss, predictions, labels)
318
  # - 'predictions' can be the dictionary
319
  # - 'labels' can be the dictionary of label IDs
320
- return (loss, logits_dict, _labels_dict)
321
 
322
 
323
  def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
@@ -423,127 +274,182 @@ def multi_head_compute_metrics(logits_dict, labels_dict):
423
 
424
  return results
425
 
426
- # ------------------------------------------------------------------------------
427
- # Instantiate model and tokenizer
428
- # ------------------------------------------------------------------------------
429
 
430
- if args.from_base:
431
- model_name_or_path = args.from_base
432
- multi_head_model = MultiHeadModel.from_pretrained(
433
- model_name_or_path,
434
- config=MultiHeadModelConfig.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  model_name_or_path,
436
- num_labels_dict=NUM_LABELS_DICT,
437
- label_maps=ALL_LABELS
 
 
 
438
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  )
440
- else:
441
- model_name_or_path = args.save_path
442
- # For evaluation, always load the saved checkpoint without overriding the config.
443
- multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
444
- # EXTREMELY IMPORTANT!
445
- # Override the label mapping based on the stored config to ensure consistency with training time ordering.
446
- ALL_LABELS = multi_head_model.config.label_maps
447
- LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
448
- ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
449
- logger.info(f"using {model_name_or_path}")
450
-
451
- # Check if GPU is usable
452
- if torch.cuda.is_available():
453
- device = torch.device("cuda")
454
- elif torch.backends.mps.is_available(): # For Apple Silicon MPS
455
- device = torch.device("mps")
456
- else:
457
- device = torch.device("cpu")
458
- logger.info(f"using {device}")
459
- multi_head_model.to(device)
460
-
461
- tokenizer = DebertaV2TokenizerFast.from_pretrained(
462
- model_name_or_path,
463
- add_prefix_space=True,
464
- )
465
-
466
- # ------------------------------------------------------------------------------
467
- # Shuffle, (optionally) sample, and tokenize final merged dataset
468
- # ------------------------------------------------------------------------------
469
-
470
- if args.mini:
471
- loaded_dataset = DatasetDict({
472
- "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
473
- "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
474
- "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
475
- })
476
-
477
- # remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
478
- tokenized_dataset = loaded_dataset.map(
479
- tokenize_and_align_labels,
480
- batched=True,
481
- remove_columns=loaded_dataset["train"].column_names,
482
- )
483
 
484
- # ------------------------------------------------------------------------------
485
- # Train the model!
486
- # ------------------------------------------------------------------------------
487
-
488
- """
489
- Current bests:
490
-
491
- deberta-v3-base:
492
- num_train_epochs=3,
493
- learning_rate=5e-5,
494
- per_device_train_batch_size=2,
495
- gradient_accumulation_steps=8,
496
- """
497
-
498
- training_args = TrainingArguments(
499
- # Evaluate less frequently or keep the same
500
- eval_strategy="epoch",
501
- num_train_epochs=args.train_epochs,
502
- learning_rate=args.learning_rate,
503
-
504
- output_dir="training_output",
505
- overwrite_output_dir=True,
506
- remove_unused_columns=False, # important to keep the labels_xxx columns
507
-
508
- logging_dir="training_logs",
509
- logging_steps=100,
510
-
511
- # Effective batch size = train_batch_size x gradient_accumulation_steps
512
- per_device_train_batch_size=args.train_batch_size,
513
- gradient_accumulation_steps=args.accumulation_steps,
514
 
515
- per_device_eval_batch_size=args.eval_batch_size,
516
- )
 
517
 
518
- trainer = MultiHeadTrainer(
519
- model=multi_head_model,
520
- args=training_args,
521
- train_dataset=tokenized_dataset["train"],
522
- eval_dataset=tokenized_dataset["validation"],
523
- )
524
 
525
- if args.train:
526
- trainer.train()
527
- trainer.evaluate()
528
- trainer.save_model(args.save_path)
529
- tokenizer.save_pretrained(args.save_path)
 
530
 
531
- # ------------------------------------------------------------------------------
532
- # Evaluate the model!
533
- # ------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
- pred_output = trainer.predict(tokenized_dataset["test"])
536
- pred_logits_dict = pred_output.predictions
537
- pred_labels_dict = pred_output.label_ids
538
- id2label_dict = ID2LABEL # from earlier definitions
539
-
540
- # 1) Calculate metrics
541
- metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
542
- for k,v in metrics.items():
543
- print(f"{k}: {v:.4f}")
544
-
545
- # 2) Print classification reports
546
- reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
547
- for head_name, rstr in reports.items():
548
- print(f"----- {head_name} classification report -----")
549
- print(rstr)
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sklearn.metrics import classification_report, precision_recall_fscore_support
2
  from transformers import (
 
 
 
3
  DebertaV2TokenizerFast,
4
  Trainer,
5
  TrainingArguments,
6
  )
7
+ import logging
 
8
  import numpy as np
9
  import torch
 
10
 
11
+ from multi_head_model import MultiHeadModel, MultiHeadModelConfig
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # ------------------------------------------------------------------------------
17
+ # Tokenize with max_length=512, stride=128, and subword alignment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # ------------------------------------------------------------------------------
19
 
20
class ExampleAligner:
    """Tokenizes pre-split examples and aligns per-feature token labels.

    For each example, the tokenizer may produce multiple overlapping chunks
    if the tokens exceed 512 subwords; each chunk is length 512 with a
    stride of 128 for the next chunk.  Labels are aligned so that special
    tokens and subwords beyond the first piece of a word get -100 (ignored
    by the loss).
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: {feature_name: [label, ...]} for every head.
        :param label2id: {feature_name: {label: id}} lookup tables.
        :param tokenizer: optional tokenizer to use.  FIX: the previous
            version hard-required a module-global ``tokenizer`` that is only
            defined under the ``__main__`` guard, so the class broke when
            imported.  Passing one here makes the class self-contained; when
            omitted we still fall back to the global for backward
            compatibility.
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self.tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """
        For each example, the tokenizer may produce multiple overlapping
        chunks if the tokens exceed 512 subwords. Each chunk will be
        length=512, with a stride=128 for the next chunk.
        We'll align labels so that subwords beyond the first in a token get -100.
        """
        # Prefer the injected tokenizer; fall back to the module-level one so
        # existing call sites keep working.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # We rely on is_split_into_words=True because examples["tokens"] is a list of token strings.
        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,  # not mandatory for basic alignment
            padding="max_length"
        )

        # The tokenizer returns "overflow_to_sample_mapping", telling us
        # which original example index each chunk corresponds to.
        # If the tokenizer didn't need to create overflows, the key might be missing.
        if "overflow_to_sample_mapping" not in tokenized_batch:
            # No overflow => each chunk corresponds 1:1 with the original example
            sample_map = list(range(len(tokenized_batch["input_ids"])))
        else:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]

        # We'll build lists for final outputs. For each chunk i we produce
        # "input_ids"[i], "attention_mask"[i], plus per-feature label IDs.
        final_input_ids = []
        final_attention_mask = []
        final_labels_columns = {feat: [] for feat in self.all_labels}  # one label-sequence per chunk

        for i in range(len(tokenized_batch["input_ids"])):
            chunk_input_ids = tokenized_batch["input_ids"][i]
            chunk_attn_mask = tokenized_batch["attention_mask"][i]

            original_index = sample_map[i]  # which example in the original batch
            word_ids = tokenized_batch.word_ids(batch_index=i)

            # Label arrays for each feature, for this chunk.
            chunk_labels_dict = {}

            for feat_name in self.all_labels:
                # The token-level labels for the *original* example.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []

                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special token (CLS, SEP, padding); note we deliberately
                        # do NOT update previous_word_id here, matching the
                        # original alignment behavior.
                        chunk_label_ids.append(-100)
                        continue
                    if w_id == previous_word_id:
                        # Continuation subword of the same word => label = -100
                        chunk_label_ids.append(-100)
                    else:
                        # First subword of a new word => use the actual label.
                        label_str = token_labels[w_id]
                        chunk_label_ids.append(self.label2id[feat_name][label_str])
                    previous_word_id = w_id

                chunk_labels_dict[feat_name] = chunk_label_ids

            final_input_ids.append(chunk_input_ids)
            final_attention_mask.append(chunk_attn_mask)
            for feat_name in self.all_labels:
                final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

        # Return the new "flattened" set of chunks, so the .map call expands
        # each example into multiple chunk examples.
        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        # Each feature's label IDs go in separate columns (e.g. labels_xpos, labels_deprel).
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

        return result
110
 
111
  # ------------------------------------------------------------------------------
112
  # Trainer Setup
113
  # ------------------------------------------------------------------------------
114
 
115
  class MultiHeadTrainer(Trainer):
116
+ def __init__(self, all_labels, **kwargs):
117
+ self.all_labels = all_labels
118
+ super().__init__(**kwargs)
119
 
120
  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
121
  # 1) Gather all your per-feature labels from inputs
122
  _labels_dict = {}
123
+ for feat_name in self.all_labels:
124
  key = f"labels_{feat_name}"
125
  if key in inputs:
126
  _labels_dict[feat_name] = inputs[key]
 
150
  def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
151
  # 1) gather the "labels_xxx" columns
152
  _labels_dict = {}
153
+ for feat_name in self.all_labels:
154
  key = f"labels_{feat_name}"
155
  if key in inputs:
156
  _labels_dict[feat_name] = inputs[key]
 
168
  # The trainer expects a triple: (loss, predictions, labels)
169
  # - 'predictions' can be the dictionary
170
  # - 'labels' can be the dictionary of label IDs
171
+ return loss, logits_dict, _labels_dict
172
 
173
 
174
  def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
 
274
 
275
  return results
276
 
 
 
 
277
 
278
if __name__ == "__main__":
    from datasets import DatasetDict, load_from_disk
    import argparse
    import logging.config

    from utils import default_logging_config, get_torch_device, get_uniq_training_labels, show_examples

    arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
    arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
                            action="store", type=int, default=8)
    arg_parser.add_argument("--data-only", help='Show training data info and exit.',
                            action="store_true", default=False)
    arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
                            action="store", default="./training_data")
    arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
                            action="store", type=int, default=3)
    arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
                            action="store", type=int, default=2)
    arg_parser.add_argument("--from-base", help="Load a base model.",
                            action="store", default=None,
                            choices=[
                                "microsoft/deberta-v3-base",  # Requires --deberta-v3
                                "microsoft/deberta-v3-large",  # Requires --deberta-v3
                                # More?
                            ])
    arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
                            action="store", type=float, default=5e-5)
    arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                            action="store", default="./final")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    arg_parser.add_argument("--train", help='Train model using loaded examples.',
                            action="store_true", default=False)
    arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
                            action="store", type=int, default=2)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)
    logger.info(f"Args {args}")

    # ------------------------------------------------------------------------------
    # Load dataset and show examples for manual inspection
    # ------------------------------------------------------------------------------

    loaded_dataset = load_from_disk(args.data_path)
    show_examples(loaded_dataset, args.show)

    # FIX: --data-only is documented as "show training data info and exit",
    # but the refactor dropped the early exit, so the flag silently did
    # nothing. Restore it here, right after the data inspection step.
    if args.data_only:
        exit()

    # ------------------------------------------------------------------------------
    # Instantiate model and tokenizer
    # ------------------------------------------------------------------------------

    if args.from_base:
        # Convert label analysis data into label sets for each head
        ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
        LABEL2ID = {
            feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
            for feat_name in ALL_LABELS
        }
        ID2LABEL = {
            feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
            for feat_name in LABEL2ID
        }
        # Each head's number of labels:
        NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
        model_name_or_path = args.from_base
        multi_head_model = MultiHeadModel.from_pretrained(
            model_name_or_path,
            config=MultiHeadModelConfig.from_pretrained(
                model_name_or_path,
                num_labels_dict=NUM_LABELS_DICT,
                label_maps=ALL_LABELS
            )
        )
    else:
        model_name_or_path = args.save_path
        # For evaluation, always load the saved checkpoint without overriding the config.
        multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
        # EXTREMELY IMPORTANT!
        # Override the label mapping based on the stored config to ensure
        # consistency with training-time ordering.
        ALL_LABELS = multi_head_model.config.label_maps
        LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
        ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
    logger.info(f"using {model_name_or_path}")

    # Check if GPU is usable
    device = get_torch_device()
    multi_head_model.to(device)

    tokenizer = DebertaV2TokenizerFast.from_pretrained(
        model_name_or_path,
        add_prefix_space=True,
    )

    # ------------------------------------------------------------------------------
    # Shuffle, (optionally) sample, and tokenize final merged dataset
    # ------------------------------------------------------------------------------

    if args.mini:
        loaded_dataset = DatasetDict({
            "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
            "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
            "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
        })

    # remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
    example_aligner = ExampleAligner(ALL_LABELS, LABEL2ID)
    tokenized_dataset = loaded_dataset.map(
        example_aligner.tokenize_and_align_labels,
        batched=True,
        remove_columns=loaded_dataset["train"].column_names,
    )

    # ------------------------------------------------------------------------------
    # Train the model!
    # ------------------------------------------------------------------------------

    """
    Current bests:

    deberta-v3-base:
        num_train_epochs=3,
        learning_rate=5e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
    """

    trainer = MultiHeadTrainer(
        ALL_LABELS,
        model=multi_head_model,
        args=TrainingArguments(
            # Evaluate less frequently or keep the same
            eval_strategy="epoch",
            num_train_epochs=args.train_epochs,
            learning_rate=args.learning_rate,

            output_dir="training_output",
            overwrite_output_dir=True,
            remove_unused_columns=False,  # important to keep the labels_xxx columns

            logging_dir="training_logs",
            logging_steps=100,

            # Effective batch size = train_batch_size x gradient_accumulation_steps
            per_device_train_batch_size=args.train_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,

            per_device_eval_batch_size=args.eval_batch_size,
        ),
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
    )

    if args.train:
        trainer.train()
        trainer.evaluate()
        trainer.save_model(args.save_path)
        tokenizer.save_pretrained(args.save_path)

    # ------------------------------------------------------------------------------
    # Evaluate the model!
    # ------------------------------------------------------------------------------

    pred_output = trainer.predict(tokenized_dataset["test"])
    pred_logits_dict = pred_output.predictions
    pred_labels_dict = pred_output.label_ids
    id2label_dict = ID2LABEL  # from earlier definitions

    # 1) Calculate metrics
    metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    # 2) Print classification reports
    reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
    for head_name, rstr in reports.items():
        print(f"----- {head_name} classification report -----")
        print(rstr)
multi_predict.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DebertaV2TokenizerFast
2
+ import torch
3
+
4
+ from multi_head_model import MultiHeadModel
5
+ from utils import get_torch_device, sp_tokenize
6
+
7
+
8
class MultiHeadPredictor:
    """Run multi-headed token classification with a saved MultiHeadModel checkpoint.

    Loads the tokenizer and model from disk, moves the model to the best
    available device, and exposes :meth:`predict` for single-text inference.
    """

    def __init__(self, model_name_or_path: str):
        """
        :param model_name_or_path: Directory (or hub id) containing the trained
            multi-head checkpoint and its tokenizer files.
        """
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name_or_path, add_prefix_space=True)
        self.model = MultiHeadModel.from_pretrained(model_name_or_path)
        # {head_name: [label_str, ...]} stored in the config at training time;
        # list position == predicted class id, so it doubles as id->label map.
        self.id2label = self.model.config.label_maps

        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str):
        """
        Perform multi-headed token classification on a single piece of text.

        :param text: The raw text string.

        :return: A dict containing the original ``text``, the SentencePiece
            ``tokens``, and one ``{head_name: [label per token]}`` entry per
            classification head. Tokens never covered by any chunk keep the
            ``"O"`` placeholder.
        """
        raw_tokens = sp_tokenize(text)

        # Single-example batch to replicate the training chunking logic.
        # is_split_into_words=True => we pass a list of tokens, not a string.
        # Long inputs overflow into multiple 512-token chunks with stride 128.
        encoded = self.tokenizer(
            raw_tokens,
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length"
        )

        # Per-chunk predicted ids and the word_ids needed to align subwords
        # back to the original token positions.
        chunk_preds = []
        chunk_word_ids = []

        for i in range(len(encoded["input_ids"])):
            # Build a batch of size 1 for chunk i.
            input_ids_tensor = torch.tensor([encoded["input_ids"][i]], dtype=torch.long).to(self.device)
            attention_mask_tensor = torch.tensor([encoded["attention_mask"][i]], dtype=torch.long).to(self.device)

            # Without labels_dict, the model forward returns only logits_dict:
            # {head_name: Tensor of shape (1, seq_len, num_labels)}.
            with torch.no_grad():
                logits_dict = self.model(
                    input_ids=input_ids_tensor,
                    attention_mask=attention_mask_tensor
                )

            # Convert each head's logits to predicted class ids.
            pred_ids_dict = {}
            for head_name, logits in logits_dict.items():
                preds = torch.argmax(logits, dim=-1)  # => shape (1, seq_len)
                # .cpu().tolist() suffices; the previous .numpy() hop added an
                # implicit numpy dependency for an identical result.
                pred_ids_dict[head_name] = preds[0].cpu().tolist()

            chunk_preds.append(pred_ids_dict)
            # word_ids maps each subword position to its source-token index
            # (None for special/padding positions). Only available on a batched
            # fast-tokenizer encoding, hence batch_index=i.
            chunk_word_ids.append(encoded.word_ids(batch_index=i))

        # Reconcile overlapping chunks into one label per original token.
        # Chunks are read left-to-right; the first chunk position that covers a
        # token wins, and later subwords/chunks for that token are skipped.
        final_pred_labels = {**{
            "text": text,
            "tokens": raw_tokens,
        }, **{
            head: ["O"] * len(raw_tokens)  # placeholder until a chunk covers the token
            for head in self.id2label.keys()
        }}

        # Token indices that already received a label from an earlier chunk.
        assigned_tokens = set()

        for i, pred_dict in enumerate(chunk_preds):
            w_ids = chunk_word_ids[i]
            for pos, w_id in enumerate(w_ids):
                if w_id is None:
                    # Special token (CLS, SEP, or padding)
                    continue
                if w_id in assigned_tokens:
                    # Already assigned from a previous chunk or subword
                    continue

                # First subword occurrence of this token: record every head's
                # predicted label string for it.
                for head_name, pred_ids in pred_dict.items():
                    label_id = pred_ids[pos]
                    label_str = self.id2label[head_name][label_id]
                    final_pred_labels[head_name][w_id] = label_str

                assigned_tokens.add(w_id)

        return final_pred_labels
126
+
127
+
128
if __name__ == "__main__":
    # Smoke-test the predictor against a saved checkpoint.
    predictor = MultiHeadPredictor("./o3-mini_20250218_final")

    for case in (
        "How to convince my parents to let me get a Ball python?",
    ):
        for head_name, labels in predictor.predict(case).items():
            print(f"{head_name}: {labels}")
ud_multi_task_classifier.py DELETED
@@ -1,551 +0,0 @@
1
- from datasets import DatasetDict, load_from_disk
2
- from sklearn.metrics import classification_report, precision_recall_fscore_support
3
- from transformers import (
4
- DebertaV2Config,
5
- DebertaV2Model,
6
- DebertaV2PreTrainedModel,
7
- DebertaV2TokenizerFast,
8
- Trainer,
9
- TrainingArguments,
10
- )
11
- import argparse
12
- import logging.config
13
- import numpy as np
14
- import torch
15
- import torch.nn as nn
16
-
17
- from utils import default_logging_config, get_uniq_training_labels, show_examples
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
- arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
22
- arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
23
- action="store", type=int, default=8)
24
- arg_parser.add_argument("--data-only", help='Show training data info and exit.',
25
- action="store_true", default=False)
26
- arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
27
- action="store", default="./training_data")
28
- arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
29
- action="store", type=int, default=3)
30
- arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
31
- action="store", type=int, default=2)
32
- arg_parser.add_argument("--from-base", help="Load a base model.",
33
- action="store", default=None,
34
- choices=[
35
- "microsoft/deberta-v3-base", # Requires --deberta-v3
36
- "microsoft/deberta-v3-large", # Requires --deberta-v3
37
- # More?
38
- ])
39
- arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
40
- action="store", type=float, default=5e-5)
41
- arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
42
- action="store_true", default=False)
43
- arg_parser.add_argument("--save-path", help="Save final model to specified path.",
44
- action="store", default="./final")
45
- arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
46
- action="store", default=None)
47
- arg_parser.add_argument("--train", help='Train model using loaded examples.',
48
- action="store_true", default=False)
49
- arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
50
- action="store", type=int, default=2)
51
- args = arg_parser.parse_args()
52
- logging.config.dictConfig(default_logging_config)
53
- logger.info(f"Args {args}")
54
-
55
-
56
-
57
- # ------------------------------------------------------------------------------
58
- # Load dataset and show examples for manual inspection
59
- # ------------------------------------------------------------------------------
60
-
61
- loaded_dataset = load_from_disk(args.data_path)
62
- show_examples(loaded_dataset, args.show)
63
-
64
- # ------------------------------------------------------------------------------
65
- # Convert label analysis data into label sets for each head
66
- # ------------------------------------------------------------------------------
67
-
68
- ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
69
-
70
- LABEL2ID = {
71
- feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
72
- for feat_name in ALL_LABELS
73
- }
74
- ID2LABEL = {
75
- feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
76
- for feat_name in LABEL2ID
77
- }
78
-
79
- # Each head's number of labels:
80
- NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
81
-
82
- if args.data_only:
83
- exit()
84
-
85
- # ------------------------------------------------------------------------------
86
- # Create a custom config that can store our multi-label info
87
- # ------------------------------------------------------------------------------
88
-
89
- class MultiHeadModelConfig(DebertaV2Config):
90
- def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
91
- super().__init__(**kwargs)
92
- self.label_maps = label_maps or {}
93
- self.num_labels_dict = num_labels_dict or {}
94
-
95
- def to_dict(self):
96
- output = super().to_dict()
97
- output["label_maps"] = self.label_maps
98
- output["num_labels_dict"] = self.num_labels_dict
99
- return output
100
-
101
- # ------------------------------------------------------------------------------
102
- # Define a multi-head model
103
- # ------------------------------------------------------------------------------
104
-
105
- class MultiHeadModel(DebertaV2PreTrainedModel):
106
- def __init__(self, config: MultiHeadModelConfig):
107
- super().__init__(config)
108
-
109
- self.deberta = DebertaV2Model(config)
110
- self.classifiers = nn.ModuleDict()
111
-
112
- hidden_size = config.hidden_size
113
- for label_name, n_labels in config.num_labels_dict.items():
114
- self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)
115
-
116
- # Initialize newly added weights
117
- self.post_init()
118
-
119
- def forward(
120
- self,
121
- input_ids=None,
122
- attention_mask=None,
123
- token_type_ids=None,
124
- labels_dict=None,
125
- **kwargs
126
- ):
127
- """
128
- labels_dict: a dict of { label_name: (batch_size, seq_len) } with label ids.
129
- If provided, we compute and return the sum of CE losses.
130
- """
131
- outputs = self.deberta(
132
- input_ids=input_ids,
133
- attention_mask=attention_mask,
134
- token_type_ids=token_type_ids,
135
- **kwargs
136
- )
137
-
138
- sequence_output = outputs.last_hidden_state # (batch_size, seq_len, hidden_size)
139
-
140
- logits_dict = {}
141
- for label_name, classifier in self.classifiers.items():
142
- logits_dict[label_name] = classifier(sequence_output)
143
-
144
- total_loss = None
145
- loss_dict = {}
146
- if labels_dict is not None:
147
- # We'll sum the losses from each head
148
- loss_fct = nn.CrossEntropyLoss()
149
- total_loss = 0.0
150
-
151
- for label_name, logits in logits_dict.items():
152
- if label_name not in labels_dict:
153
- continue
154
- label_ids = labels_dict[label_name]
155
-
156
- # A typical approach for token classification:
157
- # We ignore positions where label_ids == -100
158
- active_loss = label_ids != -100 # shape (bs, seq_len)
159
-
160
- # flatten everything
161
- active_logits = logits.view(-1, logits.shape[-1])[active_loss.view(-1)]
162
- active_labels = label_ids.view(-1)[active_loss.view(-1)]
163
-
164
- loss = loss_fct(active_logits, active_labels)
165
- loss_dict[label_name] = loss.item()
166
- total_loss += loss
167
-
168
- if labels_dict is not None:
169
- # return (loss, predictions)
170
- return total_loss, logits_dict
171
- else:
172
- # just return predictions
173
- return logits_dict
174
-
175
- # ------------------------------------------------------------------------------
176
- # Tokenize with max_length=512, stride=128, and subword alignment
177
- # ------------------------------------------------------------------------------
178
-
179
- def tokenize_and_align_labels(examples):
180
- """
181
- For each example, the tokenizer may produce multiple overlapping
182
- chunks if the tokens exceed 512 subwords. Each chunk will be
183
- length=512, with a stride=128 for the next chunk.
184
- We'll align labels so that subwords beyond the first in a token get -100.
185
- """
186
- # We rely on is_split_into_words=True because examples["tokens"] is a list of token strings.
187
- tokenized_batch = tokenizer(
188
- examples["tokens"],
189
- is_split_into_words=True,
190
- max_length=512,
191
- stride=128,
192
- truncation=True,
193
- return_overflowing_tokens=True,
194
- return_offsets_mapping=False, # not mandatory for basic alignment
195
- padding="max_length"
196
- )
197
-
198
- # The tokenizer returns "overflow_to_sample_mapping", telling us
199
- # which original example index each chunk corresponds to.
200
- # If the tokenizer didn't need to create overflows, the key might be missing
201
- if "overflow_to_sample_mapping" not in tokenized_batch:
202
- # No overflow => each input corresponds 1:1 with the original example
203
- sample_map = [i for i in range(len(tokenized_batch["input_ids"]))]
204
- else:
205
- sample_map = tokenized_batch["overflow_to_sample_mapping"]
206
-
207
- # We'll build lists for final outputs.
208
- # For each chunk i, we produce:
209
- # "input_ids"[i], "attention_mask"[i], plus per-feature label IDs.
210
- final_input_ids = []
211
- final_attention_mask = []
212
- final_labels_columns = {feat: [] for feat in ALL_LABELS} # store one label-sequence per chunk
213
-
214
- for i in range(len(tokenized_batch["input_ids"])):
215
- # chunk i
216
- chunk_input_ids = tokenized_batch["input_ids"][i]
217
- chunk_attn_mask = tokenized_batch["attention_mask"][i]
218
-
219
- original_index = sample_map[i] # which example in the original batch
220
- word_ids = tokenized_batch.word_ids(batch_index=i)
221
-
222
- # We'll build label arrays for each feature
223
- chunk_labels_dict = {}
224
-
225
- for feat_name in ALL_LABELS:
226
- # The UD token-level labels for the *original* example
227
- token_labels = examples[feat_name][original_index] # e.g. length T
228
- chunk_label_ids = []
229
-
230
- previous_word_id = None
231
- for w_id in word_ids:
232
- if w_id is None:
233
- # special token (CLS, SEP, padding)
234
- chunk_label_ids.append(-100)
235
- else:
236
- # If it's the same word_id as before, it's a subword => label = -100
237
- if w_id == previous_word_id:
238
- chunk_label_ids.append(-100)
239
- else:
240
- # New token => use the actual label
241
- label_str = token_labels[w_id]
242
- label_id = LABEL2ID[feat_name][label_str]
243
- chunk_label_ids.append(label_id)
244
- previous_word_id = w_id
245
-
246
- chunk_labels_dict[feat_name] = chunk_label_ids
247
-
248
- final_input_ids.append(chunk_input_ids)
249
- final_attention_mask.append(chunk_attn_mask)
250
- for feat_name in ALL_LABELS:
251
- final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])
252
-
253
- # Return the new "flattened" set of chunks
254
- # So the "map" call will expand each example → multiple chunk examples.
255
- result = {
256
- "input_ids": final_input_ids,
257
- "attention_mask": final_attention_mask,
258
- }
259
- # We'll store each feature's label IDs in separate columns (e.g. labels_xpos, labels_deprel, etc.)
260
- for feat_name in ALL_LABELS:
261
- result[f"labels_{feat_name}"] = final_labels_columns[feat_name]
262
-
263
- return result
264
-
265
- # ------------------------------------------------------------------------------
266
- # Trainer Setup
267
- # ------------------------------------------------------------------------------
268
-
269
- class MultiHeadTrainer(Trainer):
270
-
271
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
272
- # 1) Gather all your per-feature labels from inputs
273
- _labels_dict = {}
274
- for feat_name in ALL_LABELS:
275
- key = f"labels_{feat_name}"
276
- if key in inputs:
277
- _labels_dict[feat_name] = inputs[key]
278
-
279
- # 2) Remove them so they don't get passed incorrectly to the model
280
- for key in list(inputs.keys()):
281
- if key.startswith("labels_"):
282
- del inputs[key]
283
-
284
- # 3) Call model(...) with _labels_dict
285
- outputs = model(**inputs, labels_dict=_labels_dict)
286
- # 'outputs' is (loss, logits_dict) in training/eval mode
287
- loss, logits_dict = outputs
288
-
289
- # Optional: if your special param is used upstream for some logic,
290
- # you can handle it here or pass it along. For example:
291
- if num_items_in_batch is not None:
292
- # ... do something if needed ...
293
- pass
294
-
295
- if return_outputs:
296
- # Return (loss, logits_dict) so Trainer sees logits_dict as predictions
297
- return (loss, logits_dict)
298
- else:
299
- return loss
300
-
301
- def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
302
- # 1) gather the "labels_xxx" columns
303
- _labels_dict = {}
304
- for feat_name in ALL_LABELS:
305
- key = f"labels_{feat_name}"
306
- if key in inputs:
307
- _labels_dict[feat_name] = inputs[key]
308
- del inputs[key]
309
-
310
- # 2) forward pass without those keys
311
- with torch.no_grad():
312
- outputs = model(**inputs, labels_dict=_labels_dict)
313
-
314
- loss, logits_dict = outputs # you are returning (loss, dict-of-arrays)
315
-
316
- if prediction_loss_only:
317
- return (loss, None, None)
318
-
319
- # The trainer expects a triple: (loss, predictions, labels)
320
- # - 'predictions' can be the dictionary
321
- # - 'labels' can be the dictionary of label IDs
322
- return (loss, logits_dict, _labels_dict)
323
-
324
-
325
- def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
326
- """
327
- For each head, generate a classification report (precision, recall, f1, etc. per class).
328
- Return them as a dict: {head_name: "string report"}.
329
- :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
330
- :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
331
- :param id2label_dict: dict of {head_name: {id: label_str}}
332
- :return: A dict of classification-report strings, one per head.
333
- """
334
- reports = {}
335
-
336
- for head_name, logits in logits_dict.items():
337
- if head_name not in labels_dict:
338
- continue
339
-
340
- predictions = np.argmax(logits, axis=-1)
341
- valid_preds, valid_labels = [], []
342
- for pred_seq, label_seq in zip(predictions, labels_dict[head_name]):
343
- for p, lab in zip(pred_seq, label_seq):
344
- if lab != -100:
345
- valid_preds.append(p)
346
- valid_labels.append(lab)
347
-
348
- if len(valid_preds) == 0:
349
- reports[head_name] = "No valid predictions."
350
- continue
351
-
352
- # Convert numeric IDs to string labels
353
- valid_preds_str = [id2label_dict[head_name][p] for p in valid_preds]
354
- valid_labels_str = [id2label_dict[head_name][l] for l in valid_labels]
355
-
356
- # Generate the per-class classification report
357
- report_str = classification_report(
358
- valid_labels_str,
359
- valid_preds_str,
360
- zero_division=0
361
- )
362
- reports[head_name] = report_str
363
-
364
- return reports
365
-
366
-
367
- def multi_head_compute_metrics(logits_dict, labels_dict):
368
- """
369
- For each head (e.g. xpos, deprel, Case, etc.), computes:
370
- - Accuracy
371
- - Precision (macro/micro)
372
- - Recall (macro/micro)
373
- - F1 (macro/micro)
374
-
375
- :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
376
- :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
377
- :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy", "xpos_f1_macro", etc.
378
- """
379
- # We'll accumulate metrics in one big dictionary, keyed by "<head>_<metric>"
380
- results = {}
381
-
382
- for head_name, logits in logits_dict.items():
383
- if head_name not in labels_dict:
384
- # In case there's a mismatch or a head we didn't provide labels for
385
- continue
386
-
387
- # (batch_size, seq_len, num_classes)
388
- predictions = np.argmax(logits, axis=-1) # => (batch_size, seq_len)
389
-
390
- # Flatten ignoring positions where label == -100
391
- valid_preds, valid_labels = [], []
392
- for pred_seq, label_seq in zip(predictions, labels_dict[head_name]):
393
- for p, lab in zip(pred_seq, label_seq):
394
- if lab != -100:
395
- valid_preds.append(p)
396
- valid_labels.append(lab)
397
-
398
- valid_preds = np.array(valid_preds)
399
- valid_labels = np.array(valid_labels)
400
-
401
- if len(valid_preds) == 0:
402
- # No valid data for this head—skip
403
- continue
404
-
405
- # Overall token-level accuracy
406
- accuracy = (valid_preds == valid_labels).mean()
407
-
408
- # Macro average => treat each class equally
409
- precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
410
- valid_labels, valid_preds, average="macro", zero_division=0
411
- )
412
-
413
- # Micro average => aggregate across all classes
414
- precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
415
- valid_labels, valid_preds, average="micro", zero_division=0
416
- )
417
-
418
- results[f"{head_name}_accuracy"] = accuracy
419
- results[f"{head_name}_precision_macro"] = precision_macro
420
- results[f"{head_name}_recall_macro"] = recall_macro
421
- results[f"{head_name}_f1_macro"] = f1_macro
422
- results[f"{head_name}_precision_micro"] = precision_micro
423
- results[f"{head_name}_recall_micro"] = recall_micro
424
- results[f"{head_name}_f1_micro"] = f1_micro
425
-
426
- return results
427
-
428
- # ------------------------------------------------------------------------------
429
- # Instantiate model and tokenizer
430
- # ------------------------------------------------------------------------------
431
-
432
- if args.from_base:
433
- model_name_or_path = args.from_base
434
- multi_head_model = MultiHeadModel.from_pretrained(
435
- model_name_or_path,
436
- config=MultiHeadModelConfig.from_pretrained(
437
- model_name_or_path,
438
- num_labels_dict=NUM_LABELS_DICT,
439
- label_maps=ALL_LABELS
440
- )
441
- )
442
- else:
443
- model_name_or_path = args.save_path
444
- # For evaluation, always load the saved checkpoint without overriding the config.
445
- multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
446
- # EXTREMELY IMPORTANT!
447
- # Override the label mapping based on the stored config to ensure consistency with training time ordering.
448
- ALL_LABELS = multi_head_model.config.label_maps
449
- LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
450
- ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
451
- logger.info(f"using {model_name_or_path}")
452
-
453
- # Check if GPU is usable
454
- if torch.cuda.is_available():
455
- device = torch.device("cuda")
456
- elif torch.backends.mps.is_available(): # For Apple Silicon MPS
457
- device = torch.device("mps")
458
- else:
459
- device = torch.device("cpu")
460
- logger.info(f"using {device}")
461
- multi_head_model.to(device)
462
-
463
- tokenizer = DebertaV2TokenizerFast.from_pretrained(
464
- model_name_or_path,
465
- add_prefix_space=True,
466
- )
467
-
468
- # ------------------------------------------------------------------------------
469
- # Shuffle, (optionally) sample, and tokenize final merged dataset
470
- # ------------------------------------------------------------------------------
471
-
472
- if args.mini:
473
- loaded_dataset = DatasetDict({
474
- "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
475
- "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
476
- "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
477
- })
478
-
479
- # remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
480
- tokenized_dataset = loaded_dataset.map(
481
- tokenize_and_align_labels,
482
- batched=True,
483
- remove_columns=loaded_dataset["train"].column_names,
484
- )
485
-
486
- # ------------------------------------------------------------------------------
487
- # Train the model!
488
- # ------------------------------------------------------------------------------
489
-
490
- """
491
- Current bests:
492
-
493
- deberta-v3-base:
494
- num_train_epochs=3,
495
- learning_rate=5e-5,
496
- per_device_train_batch_size=2,
497
- gradient_accumulation_steps=8,
498
- """
499
-
500
- training_args = TrainingArguments(
501
- # Evaluate less frequently or keep the same
502
- eval_strategy="epoch",
503
- num_train_epochs=args.train_epochs,
504
- learning_rate=args.learning_rate,
505
-
506
- output_dir="training_output",
507
- overwrite_output_dir=True,
508
- remove_unused_columns=False, # important to keep the labels_xxx columns
509
-
510
- logging_dir="training_logs",
511
- logging_steps=100,
512
-
513
- # Effective batch size = train_batch_size x gradient_accumulation_steps
514
- per_device_train_batch_size=args.train_batch_size,
515
- gradient_accumulation_steps=args.accumulation_steps,
516
-
517
- per_device_eval_batch_size=args.eval_batch_size,
518
- )
519
-
520
- trainer = MultiHeadTrainer(
521
- model=multi_head_model,
522
- args=training_args,
523
- train_dataset=tokenized_dataset["train"],
524
- eval_dataset=tokenized_dataset["validation"],
525
- )
526
-
527
- if args.train:
528
- trainer.train()
529
- trainer.evaluate()
530
- trainer.save_model(args.save_path)
531
- tokenizer.save_pretrained(args.save_path)
532
-
533
- # ------------------------------------------------------------------------------
534
- # Evaluate the model!
535
- # ------------------------------------------------------------------------------
536
-
537
- pred_output = trainer.predict(tokenized_dataset["test"])
538
- pred_logits_dict = pred_output.predictions
539
- pred_labels_dict = pred_output.label_ids
540
- id2label_dict = ID2LABEL # from earlier definitions
541
-
542
- # 1) Calculate metrics
543
- metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
544
- for k,v in metrics.items():
545
- print(f"{k}: {v:.4f}")
546
-
547
- # 2) Print classification reports
548
- reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
549
- for head_name, rstr in reports.items():
550
- print(f"----- {head_name} classification report -----")
551
- print(rstr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__init__.py CHANGED
@@ -1,9 +1,15 @@
1
  from datasets import DatasetDict
2
  from typing import Optional
 
3
  import logging
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
 
 
7
  default_logging_config = {
8
  "version": 1,
9
  "disable_existing_loggers": False,
@@ -27,6 +33,17 @@ default_logging_config = {
27
  }
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
30
  def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: set[str] = None):
31
  columns_to_train_on = [k for k in ds["train"].features.keys() if k not in (
32
  {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude)]
@@ -72,3 +89,7 @@ def show_examples(ds: DatasetDict, show_expr: Optional[str]):
72
  logger.info(f"Example {i}:")
73
  for feature in examples_to_show.keys():
74
  logger.info(f" {feature}: {examples_to_show[feature][i]}")
 
 
 
 
 
1
  from datasets import DatasetDict
2
  from typing import Optional
3
+ import itertools
4
  import logging
5
+ import sentencepiece as spm
6
+ import torch
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
+ sp = spm.SentencePieceProcessor()
11
+ sp.LoadFromFile(f"sp.model")
12
+
13
  default_logging_config = {
14
  "version": 1,
15
  "disable_existing_loggers": False,
 
33
  }
34
 
35
 
36
def get_torch_device():
    """Select the best available accelerator and log which one was chosen.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then plain CPU.
    """
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    device = torch.device(name)
    logger.info(f"using {device}")
    return device
45
+
46
+
47
  def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: set[str] = None):
48
  columns_to_train_on = [k for k in ds["train"].features.keys() if k not in (
49
  {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude)]
 
89
  logger.info(f"Example {i}:")
90
  for feature in examples_to_show.keys():
91
  logger.info(f" {feature}: {examples_to_show[feature][i]}")
92
+
93
+
94
def sp_tokenize(text: str):
    """Tokenize *text* with the module-level SentencePiece model.

    Each piece is stripped of its leading/trailing "▁" word-boundary marker,
    split on any internal "▁", and the results are flattened into one list
    (empty strings from marker-only pieces are preserved, as before).
    """
    flattened = []
    for piece in sp.EncodeAsPieces(text):
        flattened.extend(piece.strip("▁").split("▁"))
    return flattened