| | from sklearn.metrics import classification_report, precision_recall_fscore_support |
| | from transformers import ( |
| | DebertaV2TokenizerFast, |
| | Trainer, |
| | TrainingArguments, |
| | ) |
| | import logging |
| | import numpy as np |
| | import torch |
| |
|
| | from multi_head_model import MultiHeadModel, MultiHeadModelConfig |
| |
|
# Module-level logger; handlers/levels are configured in __main__ via
# logging.config.dictConfig(default_logging_config).
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ExampleAligner:
    """Tokenizes pre-split examples and aligns per-head token labels to subwords.

    Long examples may be split by the tokenizer into several overlapping
    chunks (max_length=512, stride=128). Labels are aligned per chunk:
    special tokens and non-first subwords of a word receive -100 so the
    loss ignores them.
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: dict {feature_name: [label strings]}, one entry per head.
        :param label2id: dict {feature_name: {label_str: label_id}}.
        :param tokenizer: optional fast tokenizer to use. When None, falls back
            to the module-level ``tokenizer`` global (the historical behavior,
            kept for backward compatibility with existing call sites).
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self.tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """
        For each example, the tokenizer may produce multiple overlapping
        chunks if the tokens exceed 512 subwords. Each chunk will be
        length=512, with a stride=128 for the next chunk.
        We'll align labels so that subwords beyond the first in a token get -100.

        :param examples: batch dict with "tokens" (list of word lists) plus one
            column per head name in ``self.all_labels`` holding word-level
            label strings.
        :return: dict with "input_ids", "attention_mask", and one
            "labels_<feat>" column of label-id lists per head.
        """
        # Prefer the injected tokenizer; only fall back to the module-level
        # global when none was provided (legacy behavior).
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length"
        )

        # Map each produced chunk back to the original example index. If the
        # tokenizer emitted no overflow mapping, chunks correspond 1:1.
        if "overflow_to_sample_mapping" not in tokenized_batch:
            sample_map = list(range(len(tokenized_batch["input_ids"])))
        else:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]

        final_input_ids = []
        final_attention_mask = []
        final_labels_columns = {feat: [] for feat in self.all_labels}

        for i in range(len(tokenized_batch["input_ids"])):
            chunk_input_ids = tokenized_batch["input_ids"][i]
            chunk_attn_mask = tokenized_batch["attention_mask"][i]

            original_index = sample_map[i]
            word_ids = tokenized_batch.word_ids(batch_index=i)

            chunk_labels_dict = {}
            for feat_name in self.all_labels:
                # Word-level label strings for this head on the source example.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []

                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special token ([CLS]/[SEP]/padding): no label.
                        chunk_label_ids.append(-100)
                    elif w_id == previous_word_id:
                        # Non-first subword of a word: ignored by the loss.
                        chunk_label_ids.append(-100)
                    else:
                        label_str = token_labels[w_id]
                        chunk_label_ids.append(self.label2id[feat_name][label_str])
                        previous_word_id = w_id

                chunk_labels_dict[feat_name] = chunk_label_ids

            final_input_ids.append(chunk_input_ids)
            final_attention_mask.append(chunk_attn_mask)
            for feat_name in self.all_labels:
                final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

        return result
| |
|
| | |
| | |
| | |
| |
|
class MultiHeadTrainer(Trainer):
    """Trainer that routes per-head ``labels_<feat>`` columns to a multi-head model.

    The model is expected to accept a ``labels_dict`` keyword and to return a
    ``(loss, logits_dict)`` pair.
    """

    def __init__(self, all_labels, **kwargs):
        # Head names (keys of the label map) — used to recognize label columns.
        self.all_labels = all_labels
        super().__init__(**kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Pull the per-head label columns out of the batch.
        label_columns = {}
        for feat in self.all_labels:
            column = f"labels_{feat}"
            if column in inputs:
                label_columns[feat] = inputs.pop(column)

        # Drop any leftover label columns so the model only sees real inputs.
        for stray in [key for key in inputs if key.startswith("labels_")]:
            del inputs[stray]

        loss, logits_dict = model(**inputs, labels_dict=label_columns)

        # num_items_in_batch is accepted for Trainer API compatibility; the
        # model aggregates its own loss, so it is intentionally unused here.
        return (loss, logits_dict) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        # Separate label columns from model inputs, as in compute_loss.
        label_columns = {}
        for feat in self.all_labels:
            column = f"labels_{feat}"
            if column in inputs:
                label_columns[feat] = inputs.pop(column)

        with torch.no_grad():
            loss, logits_dict = model(**inputs, labels_dict=label_columns)

        if prediction_loss_only:
            return (loss, None, None)

        # Return dict-shaped logits/labels; downstream metric helpers expect
        # {head_name: array} structures.
        return loss, logits_dict, label_columns
| |
|
| |
|
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.

    Positions whose gold label is -100 (special tokens, non-first subwords)
    are excluded via a vectorized numpy mask instead of a per-token Python
    loop.

    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head.
    """
    reports = {}

    for head_name, logits in logits_dict.items():
        # Heads without gold labels cannot be reported on.
        if head_name not in labels_dict:
            continue

        # np.asarray makes this robust to plain nested lists as well.
        predictions = np.argmax(np.asarray(logits), axis=-1)
        gold = np.asarray(labels_dict[head_name])
        mask = gold != -100
        valid_preds = predictions[mask]
        valid_labels = gold[mask]

        if valid_preds.size == 0:
            reports[head_name] = "No valid predictions."
            continue

        # Convert ids to label strings so the report shows readable class names.
        valid_preds_str = [id2label_dict[head_name][int(p)] for p in valid_preds]
        valid_labels_str = [id2label_dict[head_name][int(l)] for l in valid_labels]

        report_str = classification_report(
            valid_labels_str,
            valid_preds_str,
            zero_division=0
        )
        reports[head_name] = report_str

    return reports
| |
|
| |
|
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
      - Accuracy
      - Precision (macro/micro)
      - Recall (macro/micro)
      - F1 (macro/micro)

    Positions whose gold label is -100 are excluded via a vectorized numpy
    mask instead of a per-token Python loop.

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy", "xpos_f1_macro", etc.
    """
    results = {}

    for head_name, logits in logits_dict.items():
        # Skip heads for which no gold labels were supplied.
        if head_name not in labels_dict:
            continue

        # np.asarray makes this robust to plain nested lists as well.
        predictions = np.argmax(np.asarray(logits), axis=-1)
        gold = np.asarray(labels_dict[head_name])
        mask = gold != -100
        valid_preds = predictions[mask]
        valid_labels = gold[mask]

        # Nothing to score for this head (e.g. every position was ignored).
        if valid_preds.size == 0:
            continue

        accuracy = (valid_preds == valid_labels).mean()

        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="macro", zero_division=0
        )

        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="micro", zero_division=0
        )

        results[f"{head_name}_accuracy"] = accuracy
        results[f"{head_name}_precision_macro"] = precision_macro
        results[f"{head_name}_recall_macro"] = recall_macro
        results[f"{head_name}_f1_macro"] = f1_macro
        results[f"{head_name}_precision_micro"] = precision_micro
        results[f"{head_name}_recall_micro"] = recall_micro
        results[f"{head_name}_f1_micro"] = f1_micro

    return results
| |
|
| |
|
if __name__ == "__main__":
    from datasets import DatasetDict, load_from_disk
    import argparse
    import logging.config

    from utils import default_logging_config, get_torch_device, get_uniq_training_labels, show_examples

    # ---- Command line --------------------------------------------------
    arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
    arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
                            action="store", type=int, default=8)
    # NOTE(review): --data-only is parsed but never referenced below — confirm
    # whether the "show info and exit" behavior was lost or lives elsewhere.
    arg_parser.add_argument("--data-only", help='Show training data info and exit.',
                            action="store_true", default=False)
    arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
                            action="store", default="./training_data")
    arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
                            action="store", type=int, default=3)
    arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
                            action="store", type=int, default=2)
    arg_parser.add_argument("--from-base", help="Load a base model.",
                            action="store", default=None,
                            choices=[
                                "microsoft/deberta-v3-base",
                                "microsoft/deberta-v3-large",
                            ])
    arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
                            action="store", type=float, default=5e-5)
    arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                            action="store", default="./ud_final")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    arg_parser.add_argument("--train", help='Train model using loaded examples.',
                            action="store_true", default=False)
    arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
                            action="store", type=int, default=2)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)
    logger.info(f"Args {args}")

    # ---- Data ----------------------------------------------------------
    loaded_dataset = load_from_disk(args.data_path)
    show_examples(loaded_dataset, args.show)

    # ---- Model & label maps --------------------------------------------
    if args.from_base:
        # Fresh model from a HF base checkpoint: derive the per-head label
        # inventory from the training data and bake it into the model config.
        ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
        LABEL2ID = {
            feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
            for feat_name in ALL_LABELS
        }
        ID2LABEL = {
            feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
            for feat_name in LABEL2ID
        }
        NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
        model_name_or_path = args.from_base
        multi_head_model = MultiHeadModel.from_pretrained(
            model_name_or_path,
            config=MultiHeadModelConfig.from_pretrained(
                model_name_or_path,
                num_labels_dict=NUM_LABELS_DICT,
                label_maps=ALL_LABELS
            )
        )
    else:
        # Resume from a previously saved multi-head model; label maps are
        # recovered from its saved config so ids stay consistent.
        model_name_or_path = args.save_path
        multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
        ALL_LABELS = multi_head_model.config.label_maps
        LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
        ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
    logger.info(f"using {model_name_or_path}")

    device = get_torch_device()
    multi_head_model.to(device)

    # Tokenizer comes from the same checkpoint/path as the model.
    tokenizer = DebertaV2TokenizerFast.from_pretrained(
        model_name_or_path,
        add_prefix_space=True,
    )

    # ---- Optional mini subset for fast pipeline tests ------------------
    if args.mini:
        loaded_dataset = DatasetDict({
            "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
            "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
            "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
        })

    # ---- Tokenize + align labels (rows may expand via overflow chunks) --
    example_aligner = ExampleAligner(ALL_LABELS, LABEL2ID)
    tokenized_dataset = loaded_dataset.map(
        example_aligner.tokenize_and_align_labels,
        batched=True,
        remove_columns=loaded_dataset["train"].column_names,
    )

    # ---- Trainer --------------------------------------------------------
    trainer = MultiHeadTrainer(
        ALL_LABELS,
        model=multi_head_model,
        args=TrainingArguments(
            eval_strategy="steps",
            save_strategy="steps",
            load_best_model_at_end=True,
            num_train_epochs=args.train_epochs,
            learning_rate=args.learning_rate,

            output_dir="training_output",
            overwrite_output_dir=True,
            # Keep the labels_* columns — the custom compute_loss consumes them.
            remove_unused_columns=False,

            logging_dir="training_logs",
            logging_steps=100,

            per_device_eval_batch_size=args.eval_batch_size,
            per_device_train_batch_size=args.train_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,

            warmup_ratio=0.1,
            weight_decay=0.01,
        ),
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
    )

    if args.train:
        trainer.train()
        trainer.evaluate()
        trainer.save_model(args.save_path)
        tokenizer.save_pretrained(args.save_path)

    # ---- Final evaluation on the test split -----------------------------
    # prediction_step returns dict-shaped logits/labels, so predictions and
    # label_ids are {head_name: array} dicts here.
    pred_output = trainer.predict(tokenized_dataset["test"])
    pred_logits_dict = pred_output.predictions
    pred_labels_dict = pred_output.label_ids
    id2label_dict = ID2LABEL

    metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
    for k,v in metrics.items():
        print(f"{k}: {v:.4f}")

    reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
    for head_name, rstr in reports.items():
        print(f"----- {head_name} classification report -----")
        print(rstr)