import logging

import numpy as np
import torch
from sklearn.metrics import classification_report, precision_recall_fscore_support
from transformers import (
    DebertaV2TokenizerFast,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

from multi_head_model import MultiHeadModel, MultiHeadModelConfig

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Tokenize with max_length=512, stride=128, and subword alignment
# ------------------------------------------------------------------------------
class ExampleAligner:
    """Tokenizes pre-split examples and aligns token-level labels to subwords.

    Long examples are split into overlapping chunks (max_length=512,
    stride=128) and every chunk becomes its own example. Within a chunk,
    only the first subword of each original token keeps that token's label;
    continuation subwords and special/padding positions get -100 so the
    loss ignores them.
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: iterable of feature/head names (e.g. "xpos", "deprel").
        :param label2id: dict of {feature_name: {label_str: label_id}}.
        :param tokenizer: optional fast tokenizer instance to use. When None,
            falls back to a module-global ``tokenizer`` (the original behavior,
            which relied on a global defined elsewhere); passing it explicitly
            removes that hidden dependency.
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self.tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """Tokenize a batch of pre-split examples and align labels per subword.

        Each example may produce several overlapping chunks when its tokens
        exceed 512 subwords; each chunk is length 512 with stride 128 to the
        next chunk.

        :param examples: batch dict with "tokens" (a list of token-string
            lists, hence is_split_into_words=True) plus one column per
            feature in ``self.all_labels`` holding token-level label strings.
        :return: dict with "input_ids", "attention_mask", and one
            "labels_<feature>" column per feature, flattened over chunks so
            a `map` call expands each example into multiple chunk examples.
        """
        # Prefer the injected tokenizer; fall back to the module-global one
        # for backward compatibility with callers that set it externally.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,  # not mandatory for basic alignment
            padding="max_length",
        )

        # "overflow_to_sample_mapping" tells us which original example index
        # each chunk corresponds to. If the tokenizer didn't need to create
        # overflows, the key might be missing => chunks map 1:1 onto examples.
        if "overflow_to_sample_mapping" in tokenized_batch:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]
        else:
            sample_map = list(range(len(tokenized_batch["input_ids"])))

        final_input_ids = []
        final_attention_mask = []
        # One label sequence per chunk, per feature.
        final_labels_columns = {feat: [] for feat in self.all_labels}

        for i in range(len(tokenized_batch["input_ids"])):  # chunk i
            chunk_input_ids = tokenized_batch["input_ids"][i]
            chunk_attn_mask = tokenized_batch["attention_mask"][i]
            original_index = sample_map[i]  # which example in the original batch
            word_ids = tokenized_batch.word_ids(batch_index=i)

            # Build label arrays for each feature.
            chunk_labels_dict = {}
            for feat_name in self.all_labels:
                # The token-level labels of the *original* example.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []
                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special token (CLS, SEP, padding).
                        chunk_label_ids.append(-100)
                    elif w_id == previous_word_id:
                        # Continuation subword of the same token => -100.
                        chunk_label_ids.append(-100)
                    else:
                        # First subword of a new token => use the actual label.
                        label_str = token_labels[w_id]
                        chunk_label_ids.append(self.label2id[feat_name][label_str])
                    previous_word_id = w_id
                chunk_labels_dict[feat_name] = chunk_label_ids

            final_input_ids.append(chunk_input_ids)
            final_attention_mask.append(chunk_attn_mask)
            for feat_name in self.all_labels:
                final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

        # Return the new "flattened" set of chunks; each feature's label IDs
        # live in a separate column (e.g. labels_xpos, labels_deprel, ...).
        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]
        return result
# ------------------------------------------------------------------------------
# Trainer Setup
# ------------------------------------------------------------------------------
class MultiHeadTrainer(Trainer):
    """Hugging Face Trainer specialized for a multi-head token classifier.

    The model is expected to accept a ``labels_dict`` keyword and, when it is
    provided, return a ``(loss, logits_dict)`` tuple. Per-head labels arrive
    in the batch as ``labels_<feature>`` columns and are stripped from the
    inputs before the forward pass.
    """

    def __init__(self, all_labels, **kwargs):
        """:param all_labels: feature/head names whose labels_<feat> columns to collect."""
        self.all_labels = all_labels
        super().__init__(**kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Collect per-head labels, forward the model, and return its loss.

        ``num_items_in_batch`` is accepted for Trainer-API compatibility but
        is not used here.
        """
        # 1) Gather every per-feature label column present in this batch.
        labels_dict = {
            feat: inputs[f"labels_{feat}"]
            for feat in self.all_labels
            if f"labels_{feat}" in inputs
        }
        # 2) Strip them so they are not passed to the model as regular kwargs.
        for key in [k for k in inputs if k.startswith("labels_")]:
            del inputs[key]
        # 3) Forward pass; the model returns (loss, logits_dict) in train/eval mode.
        loss, logits_dict = model(**inputs, labels_dict=labels_dict)
        return (loss, logits_dict) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Run a no-grad forward pass and return (loss, predictions, labels).

        Predictions and labels are both dicts keyed by head name, which is
        what the multi-head metrics functions downstream expect.
        """
        # 1) Pull the "labels_xxx" columns out of the inputs.
        labels_dict = {}
        for feat in self.all_labels:
            key = f"labels_{feat}"
            if key in inputs:
                labels_dict[feat] = inputs.pop(key)
        # 2) Forward pass without those keys.
        with torch.no_grad():
            loss, logits_dict = model(**inputs, labels_dict=labels_dict)
        if prediction_loss_only:
            return loss, None, None
        return loss, logits_dict, labels_dict
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """Build one per-class classification report string per head.

    Positions whose label is -100 (ignored subwords / special tokens) are
    excluded before scoring; heads without labels are skipped, and heads
    with no scorable positions report "No valid predictions.".

    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: dict of {head_name: classification-report string}.
    """
    reports = {}
    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            continue
        predictions = np.argmax(logits, axis=-1)
        # Flatten (prediction, label) pairs, dropping ignored positions.
        kept = [
            (pred, lab)
            for pred_seq, label_seq in zip(predictions, labels_dict[head_name])
            for pred, lab in zip(pred_seq, label_seq)
            if lab != -100
        ]
        if not kept:
            reports[head_name] = "No valid predictions."
            continue
        # Convert numeric IDs to string labels before reporting.
        id2label = id2label_dict[head_name]
        pred_names = [id2label[pred] for pred, _ in kept]
        gold_names = [id2label[lab] for _, lab in kept]
        reports[head_name] = classification_report(
            gold_names, pred_names, zero_division=0
        )
    return reports
def multi_head_compute_metrics(logits_dict, labels_dict):
    """Compute per-head token-level accuracy plus macro/micro P, R, F1.

    Positions labeled -100 are excluded. Heads missing from ``labels_dict``
    or with no valid positions contribute no keys at all.

    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :return: dict of metrics keyed "<head_name>_<metric>", e.g.
        "xpos_accuracy", "xpos_f1_macro", "xpos_precision_micro", ...
    """
    results = {}
    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            # Mismatch or a head we didn't provide labels for.
            continue
        predictions = np.argmax(logits, axis=-1)  # (batch_size, seq_len)
        # Flatten (prediction, label) pairs, ignoring label == -100.
        kept = [
            (pred, lab)
            for pred_seq, label_seq in zip(predictions, labels_dict[head_name])
            for pred, lab in zip(pred_seq, label_seq)
            if lab != -100
        ]
        if not kept:
            # No valid data for this head — skip it entirely.
            continue
        valid_preds = np.array([pred for pred, _ in kept])
        valid_labels = np.array([lab for _, lab in kept])

        # Overall token-level accuracy.
        results[f"{head_name}_accuracy"] = (valid_preds == valid_labels).mean()
        # Macro treats each class equally; micro aggregates across classes.
        for avg in ("macro", "micro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                valid_labels, valid_preds, average=avg, zero_division=0
            )
            results[f"{head_name}_precision_{avg}"] = precision
            results[f"{head_name}_recall_{avg}"] = recall
            results[f"{head_name}_f1_{avg}"] = f1
    return results


def multi_head_compute_metrics_aggregate_f1(logits_dict, labels_dict):
    """Return all per-head metrics plus a single "f1_macro" aggregate.

    The aggregate is the mean of every "*_f1_macro" value (0.0 when there
    are none), so the Trainer can select a best model on one scalar while
    the per-head metrics remain available for logging.
    """
    results = multi_head_compute_metrics(logits_dict, labels_dict)
    macro_f1_values = [v for k, v in results.items() if k.endswith("_f1_macro")]
    final_dict = {
        "f1_macro": np.mean(macro_f1_values) if macro_f1_values else 0.0
    }
    # Keep all the individual metrics alongside the aggregate.
    final_dict.update(results)
    return final_dict


def compute_metrics_for_trainer(eval_pred):
    """HF Trainer compute_metrics hook: unpack EvalPrediction and delegate.

    ``eval_pred.predictions`` / ``eval_pred.label_ids`` are the per-head
    dicts produced by MultiHeadTrainer.prediction_step.
    """
    return multi_head_compute_metrics_aggregate_f1(
        eval_pred.predictions, eval_pred.label_ids
    )
action="store", type=int, default=10) arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.", action="store", type=int, default=2) arg_parser.add_argument("--from-base", help="Load a base model.", action="store", default=None, choices=[ "microsoft/deberta-v3-base", # Requires --deberta-v3 "microsoft/deberta-v3-large", # Requires --deberta-v3 # More? ]) arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.", action="store", type=float, default=2e-5) arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.', action="store_true", default=False) arg_parser.add_argument("--save-path", help="Save final model to specified path.", action="store", default="./final") arg_parser.add_argument("--show", help="Show examples: //