from sklearn.metrics import classification_report, precision_recall_fscore_support
from transformers import (
DebertaV2TokenizerFast,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
)
import logging
import numpy as np
import torch
from multi_head_model import MultiHeadModel, MultiHeadModelConfig
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Tokenize with max_length=512, stride=128, and subword alignment
# ------------------------------------------------------------------------------
class ExampleAligner:
    """Convert word-level labelled examples into chunked, subword-aligned features.

    The tokenizer may split a long example into several overlapping chunks
    (max_length=512, stride=128); every chunk becomes its own training row.
    Within each chunk, only the first subword of an original word keeps that
    word's label id for every feature head; continuation subwords and special
    tokens receive -100 so the loss ignores them.
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: feature/head names (e.g. "xpos", "deprel").
        :param label2id: mapping {feature_name: {label_str: label_id}}.
        :param tokenizer: optional fast tokenizer to use.  When None, falls
            back to the module-level ``tokenizer`` global for backward
            compatibility with existing callers.
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self.tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """
        Tokenize a batch of pre-split examples and align each feature's
        word-level labels onto the resulting subword chunks.

        For each example, the tokenizer may produce multiple overlapping
        chunks if the tokens exceed 512 subwords (length=512, stride=128).
        Subwords beyond the first of a word are labelled -100.

        :param examples: batch dict with "tokens" (list of word lists) plus
            one key per feature name holding parallel word-level label strings.
        :return: dict with "input_ids", "attention_mask" and one
            "labels_<feature>" column per feature.  One row per chunk, so the
            output may contain more rows than the input batch.
        """
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # is_split_into_words=True because examples["tokens"] is already a
        # list of word strings, not raw text.
        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,  # offsets not needed for word-id alignment
            padding="max_length",
        )
        # "overflow_to_sample_mapping" maps each chunk back to its original
        # example index.  If nothing overflowed the key may be absent, in
        # which case chunks map 1:1 onto the input batch.
        if "overflow_to_sample_mapping" in tokenized_batch:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]
        else:
            sample_map = list(range(len(tokenized_batch["input_ids"])))

        final_input_ids = []
        final_attention_mask = []
        # One label-sequence list per feature head.
        final_labels_columns = {feat: [] for feat in self.all_labels}

        for i in range(len(tokenized_batch["input_ids"])):
            original_index = sample_map[i]  # which example this chunk came from
            word_ids = tokenized_batch.word_ids(batch_index=i)

            for feat_name in self.all_labels:
                # Word-level labels of the *original* example for this head.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []
                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special token (CLS, SEP, padding) — ignored by loss.
                        chunk_label_ids.append(-100)
                        continue
                    if w_id == previous_word_id:
                        # Continuation subword — only the first keeps the label.
                        chunk_label_ids.append(-100)
                    else:
                        # First subword of a new word — use the real label id.
                        chunk_label_ids.append(
                            self.label2id[feat_name][token_labels[w_id]]
                        )
                    previous_word_id = w_id
                final_labels_columns[feat_name].append(chunk_label_ids)

            final_input_ids.append(tokenized_batch["input_ids"][i])
            final_attention_mask.append(tokenized_batch["attention_mask"][i])

        # Flattened chunk-level output: datasets.map will expand each example
        # into as many rows as it produced chunks.
        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]
        return result
# ------------------------------------------------------------------------------
# Trainer Setup
# ------------------------------------------------------------------------------
class MultiHeadTrainer(Trainer):
    """HF Trainer variant that routes per-feature label columns into a
    multi-head model via its ``labels_dict`` keyword argument."""

    def __init__(self, all_labels, **kwargs):
        # Feature/head names whose "labels_<name>" columns we expect in inputs.
        self.all_labels = all_labels
        super().__init__(**kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Pull the per-head labels out of ``inputs`` and forward them
        separately, since the model does not accept ``labels_*`` kwargs."""
        # Collect label tensors for the heads we know about.
        labels_dict = {
            feat: inputs[f"labels_{feat}"]
            for feat in self.all_labels
            if f"labels_{feat}" in inputs
        }
        # Strip every labels_* column (known or not) before the forward pass.
        for stale_key in [k for k in inputs if k.startswith("labels_")]:
            del inputs[stale_key]
        # Model returns (loss, logits_dict) in training/eval mode.
        loss, logits_dict = model(**inputs, labels_dict=labels_dict)
        # num_items_in_batch is accepted for HF-API compatibility; unused here.
        return (loss, logits_dict) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Evaluation counterpart of compute_loss; returns the Trainer's
        expected (loss, predictions, labels) triple with dict payloads."""
        labels_dict = {}
        for feat in self.all_labels:
            column = f"labels_{feat}"
            if column in inputs:
                labels_dict[feat] = inputs.pop(column)
        with torch.no_grad():
            # (loss, dict-of-logits) just like in training mode.
            loss, logits_dict = model(**inputs, labels_dict=labels_dict)
        if prediction_loss_only:
            return (loss, None, None)
        # predictions = logits_dict, labels = labels_dict (both per-head dicts).
        return loss, logits_dict, labels_dict
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    Build one sklearn classification report (per-class precision/recall/F1)
    for every head that has labels.

    :param logits_dict: {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: {head_name: {id: label_str}}
    :return: {head_name: report string}
    """
    reports = {}
    for head, head_logits in logits_dict.items():
        if head not in labels_dict:
            continue
        pred_ids = np.argmax(head_logits, axis=-1)
        # Flatten to (pred, gold) pairs, dropping ignored (-100) positions.
        kept_pairs = [
            (p, g)
            for pred_row, gold_row in zip(pred_ids, labels_dict[head])
            for p, g in zip(pred_row, gold_row)
            if g != -100
        ]
        if not kept_pairs:
            reports[head] = "No valid predictions."
            continue
        # Map numeric ids back to label strings before reporting.
        id2label = id2label_dict[head]
        pred_names = [id2label[p] for p, _ in kept_pairs]
        gold_names = [id2label[g] for _, g in kept_pairs]
        reports[head] = classification_report(
            gold_names,
            pred_names,
            zero_division=0,
        )
    return reports
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
      - Accuracy
      - Precision (macro/micro)
      - Recall (macro/micro)
      - F1 (macro/micro)

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: A dict with aggregated metrics, keyed "{head_name}_{metric}",
        e.g. "xpos_accuracy", "xpos_f1_macro", etc.
    """
    results = {}
    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            # Mismatch or a head we didn't provide labels for — skip it.
            continue
        # (batch_size, seq_len, num_classes) -> (batch_size, seq_len)
        predictions = np.argmax(logits, axis=-1)
        labels = np.asarray(labels_dict[head_name])
        # Vectorized flatten: keep only positions whose label != -100
        # (special tokens / continuation subwords are ignored by the loss).
        keep = labels != -100
        valid_preds = np.asarray(predictions)[keep]
        valid_labels = labels[keep]
        if valid_labels.size == 0:
            # No valid data for this head — skip
            continue
        # Overall token-level accuracy
        accuracy = (valid_preds == valid_labels).mean()
        # Macro average => treat each class equally
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="macro", zero_division=0
        )
        # Micro average => aggregate across all classes
        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="micro", zero_division=0
        )
        results[f"{head_name}_accuracy"] = accuracy
        results[f"{head_name}_precision_macro"] = precision_macro
        results[f"{head_name}_recall_macro"] = recall_macro
        results[f"{head_name}_f1_macro"] = f1_macro
        results[f"{head_name}_precision_micro"] = precision_micro
        results[f"{head_name}_recall_micro"] = recall_micro
        results[f"{head_name}_f1_micro"] = f1_micro
    return results
def multi_head_compute_metrics_aggregate_f1(logits_dict, labels_dict):
    """Average every head's macro-F1 into one "f1_macro" score.

    The per-head metrics are kept in the returned dict as well, so nothing
    is lost for logging; "f1_macro" is the single aggregate the Trainer can
    use for model selection / early stopping.
    """
    per_head = multi_head_compute_metrics(logits_dict, labels_dict)
    macro_scores = [v for k, v in per_head.items() if k.endswith("_f1_macro")]
    # Fall back to 0.0 when no head produced an F1 score at all.
    aggregate = np.mean(macro_scores) if macro_scores else 0.0
    summary = {"f1_macro": aggregate}
    summary.update(per_head)
    return summary
def compute_metrics_for_trainer(eval_pred):
    """Adapter matching the HF Trainer ``compute_metrics`` signature.

    ``eval_pred`` carries (predictions, label_ids); here both are per-head
    dicts produced by MultiHeadTrainer.prediction_step.
    """
    preds = eval_pred.predictions
    golds = eval_pred.label_ids
    return multi_head_compute_metrics_aggregate_f1(preds, golds)
if __name__ == "__main__":
from datasets import DatasetDict, load_from_disk
import argparse
import logging.config
from utils import default_logging_config, get_torch_device, get_uniq_training_labels, show_examples
arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
action="store", type=int, default=12)
arg_parser.add_argument("--data-only", help='Show training data info and exit.',
action="store_true", default=False)
arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
action="store", default="./training_data")
arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
action="store", type=int, default=10)
arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
action="store", type=int, default=2)
arg_parser.add_argument("--from-base", help="Load a base model.",
action="store", default=None,
choices=[
"microsoft/deberta-v3-base", # Requires --deberta-v3
"microsoft/deberta-v3-large", # Requires --deberta-v3
# More?
])
arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
action="store", type=float, default=2e-5)
arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
action="store_true", default=False)
arg_parser.add_argument("--save-path", help="Save final model to specified path.",
action="store", default="./final")
arg_parser.add_argument("--show", help="Show examples: /