| from sklearn.metrics import classification_report, precision_recall_fscore_support |
| from transformers import ( |
| DebertaV2TokenizerFast, |
| Trainer, |
| TrainingArguments, |
| ) |
| import logging |
| import numpy as np |
| import torch |
|
|
| from multi_head_model import MultiHeadModel, MultiHeadModelConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
class ExampleAligner:
    """Tokenize pre-split examples and align per-feature token labels to sub-words.

    Designed for use with ``datasets.Dataset.map(batched=True)``: each input
    batch has a ``tokens`` column (lists of words) plus one column per feature
    in ``all_labels`` holding a label string per word.
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: dict {feature_name: [label strings]}; only its keys
            are iterated here.
        :param label2id: dict {feature_name: {label_str: label_id}}.
        :param tokenizer: optional fast tokenizer to use. When None, falls back
            to the module-level ``tokenizer`` global (the original behavior),
            so existing callers are unaffected.
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self.tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """
        For each example, the tokenizer may produce multiple overlapping
        chunks if the tokens exceed 512 subwords. Each chunk will be
        length=512, with a stride=128 for the next chunk.
        We'll align labels so that subwords beyond the first in a token get -100.

        :param examples: batch dict with "tokens" plus one column per feature.
        :return: dict with "input_ids", "attention_mask", and one
            "labels_<feature>" column per feature, all aligned per chunk.
        """
        # Prefer the injected tokenizer; otherwise use the module-level global
        # (must be defined before this method runs, as in the __main__ script).
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length"
        )

        # Map each produced chunk back to the example it came from. When the
        # tokenizer emits no overflow mapping, chunks correspond 1:1 to examples.
        if "overflow_to_sample_mapping" not in tokenized_batch:
            sample_map = list(range(len(tokenized_batch["input_ids"])))
        else:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]

        final_input_ids = []
        final_attention_mask = []
        final_labels_columns = {feat: [] for feat in self.all_labels}

        for i in range(len(tokenized_batch["input_ids"])):
            chunk_input_ids = tokenized_batch["input_ids"][i]
            chunk_attn_mask = tokenized_batch["attention_mask"][i]

            original_index = sample_map[i]
            word_ids = tokenized_batch.word_ids(batch_index=i)

            chunk_labels_dict = {}
            for feat_name in self.all_labels:
                # Word-level label strings for the originating example.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []

                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special tokens / padding: ignored by the loss.
                        chunk_label_ids.append(-100)
                    else:
                        if w_id == previous_word_id:
                            # Sub-words after the first of a word: ignored.
                            chunk_label_ids.append(-100)
                        else:
                            label_str = token_labels[w_id]
                            label_id = self.label2id[feat_name][label_str]
                            chunk_label_ids.append(label_id)
                        # Only advance on real words, matching the original
                        # alignment behavior (None positions don't update it).
                        previous_word_id = w_id

                chunk_labels_dict[feat_name] = chunk_label_ids

            final_input_ids.append(chunk_input_ids)
            final_attention_mask.append(chunk_attn_mask)
            for feat_name in self.all_labels:
                final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

        return result
|
|
| |
| |
| |
|
|
class MultiHeadTrainer(Trainer):
    """Trainer subclass that routes per-head ``labels_<feat>`` columns into the
    model's ``labels_dict`` keyword instead of the standard ``labels`` field.

    The wrapped model's forward is expected to return ``(loss, logits_dict)``
    when given ``labels_dict`` (see MultiHeadModel).
    """

    def __init__(self, all_labels, **kwargs):
        """
        :param all_labels: dict {feature_name: [...]} — only its keys are used,
            to know which ``labels_<feat>`` input columns to collect.
        :param kwargs: forwarded to ``transformers.Trainer``.
        """
        self.all_labels = all_labels
        super().__init__(**kwargs)

    def _extract_labels(self, inputs):
        # Pop every known labels_<feat> column out of `inputs` so the model
        # forward only receives standard tokenizer fields.
        labels_dict = {}
        for feat_name in self.all_labels:
            key = f"labels_{feat_name}"
            if key in inputs:
                labels_dict[feat_name] = inputs.pop(key)
        return labels_dict

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined multi-head loss.

        ``num_items_in_batch`` is accepted for compatibility with newer
        Trainer versions but is intentionally unused here.
        """
        labels_dict = self._extract_labels(inputs)

        # Defensive sweep: drop any stray labels_* columns not covered by
        # self.all_labels so they never reach the model forward.
        for key in [k for k in inputs if k.startswith("labels_")]:
            del inputs[key]

        loss, logits_dict = model(**inputs, labels_dict=labels_dict)

        return (loss, logits_dict) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Evaluation step: returns (loss, logits_dict, labels_dict).

        Note: logits/labels are returned as dicts keyed by head name, which the
        downstream metric helpers in this file expect.
        """
        labels_dict = self._extract_labels(inputs)

        with torch.no_grad():
            loss, logits_dict = model(**inputs, labels_dict=labels_dict)

        if prediction_loss_only:
            return (loss, None, None)

        return loss, logits_dict, labels_dict
|
|
|
|
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.
    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head.
    """
    reports = {}

    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            continue

        predictions = np.argmax(logits, axis=-1)
        labels = np.asarray(labels_dict[head_name])

        # Vectorized masking: -100 marks special tokens / non-first sub-words
        # that must be excluded from scoring.
        valid_mask = labels != -100
        valid_preds = predictions[valid_mask]
        valid_labels = labels[valid_mask]

        if valid_preds.size == 0:
            reports[head_name] = "No valid predictions."
            continue

        # Map ids back to label strings so the report rows are human-readable.
        valid_preds_str = [id2label_dict[head_name][int(p)] for p in valid_preds]
        valid_labels_str = [id2label_dict[head_name][int(l)] for l in valid_labels]

        report_str = classification_report(
            valid_labels_str,
            valid_preds_str,
            zero_division=0
        )
        reports[head_name] = report_str

    return reports
|
|
|
|
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
    - Accuracy
    - Precision (macro/micro)
    - Recall (macro/micro)
    - F1 (macro/micro)

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy", "xpos_f1_macro", etc.
    """
    results = {}

    for head_name, logits in logits_dict.items():
        # Heads without reference labels cannot be scored.
        if head_name not in labels_dict:
            continue

        predictions = np.argmax(logits, axis=-1)
        labels = np.asarray(labels_dict[head_name])

        # Vectorized masking: -100 marks special tokens / non-first sub-words
        # that must be excluded from scoring.
        valid_mask = labels != -100
        valid_preds = predictions[valid_mask]
        valid_labels = labels[valid_mask]

        # Nothing to score for this head (e.g. every position masked).
        if valid_preds.size == 0:
            continue

        accuracy = (valid_preds == valid_labels).mean()

        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="macro", zero_division=0
        )

        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="micro", zero_division=0
        )

        results[f"{head_name}_accuracy"] = accuracy
        results[f"{head_name}_precision_macro"] = precision_macro
        results[f"{head_name}_recall_macro"] = recall_macro
        results[f"{head_name}_f1_macro"] = f1_macro
        results[f"{head_name}_precision_micro"] = precision_micro
        results[f"{head_name}_recall_micro"] = recall_micro
        results[f"{head_name}_f1_micro"] = f1_micro

    return results
|
|
|
|
| if __name__ == "__main__": |
| from datasets import DatasetDict, load_from_disk |
| import argparse |
| import logging.config |
|
|
| from utils import default_logging_config, get_torch_device, get_uniq_training_labels, show_examples |
|
|
| arg_parser = argparse.ArgumentParser(description="Train multi-task model.") |
| arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.", |
| action="store", type=int, default=8) |
| arg_parser.add_argument("--data-only", help='Show training data info and exit.', |
| action="store_true", default=False) |
| arg_parser.add_argument("--data-path", help="Load training dataset from specified path.", |
| action="store", default="./training_data") |
| arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.", |
| action="store", type=int, default=3) |
| arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.", |
| action="store", type=int, default=2) |
| arg_parser.add_argument("--from-base", help="Load a base model.", |
| action="store", default=None, |
| choices=[ |
| "microsoft/deberta-v3-base", |
| "microsoft/deberta-v3-large", |
| |
| ]) |
| arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.", |
| action="store", type=float, default=5e-5) |
| arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.', |
| action="store_true", default=False) |
| arg_parser.add_argument("--save-path", help="Save final model to specified path.", |
| action="store", default="./final") |
| arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", |
| action="store", default=None) |
| arg_parser.add_argument("--train", help='Train model using loaded examples.', |
| action="store_true", default=False) |
| arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.", |
| action="store", type=int, default=2) |
| args = arg_parser.parse_args() |
| logging.config.dictConfig(default_logging_config) |
| logger.info(f"Args {args}") |
|
|
| |
| |
| |
|
|
| loaded_dataset = load_from_disk(args.data_path) |
| show_examples(loaded_dataset, args.show) |
|
|
| |
| |
| |
|
|
| if args.from_base: |
| |
| ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()} |
| LABEL2ID = { |
| feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])} |
| for feat_name in ALL_LABELS |
| } |
| ID2LABEL = { |
| feat_name: {i: label for label, i in LABEL2ID[feat_name].items()} |
| for feat_name in LABEL2ID |
| } |
| |
| NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()} |
| model_name_or_path = args.from_base |
| multi_head_model = MultiHeadModel.from_pretrained( |
| model_name_or_path, |
| config=MultiHeadModelConfig.from_pretrained( |
| model_name_or_path, |
| num_labels_dict=NUM_LABELS_DICT, |
| label_maps=ALL_LABELS |
| ) |
| ) |
| else: |
| model_name_or_path = args.save_path |
| |
| multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path) |
| |
| |
| ALL_LABELS = multi_head_model.config.label_maps |
| LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS} |
| ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID} |
| logger.info(f"using {model_name_or_path}") |
|
|
| |
| device = get_torch_device() |
| multi_head_model.to(device) |
|
|
| tokenizer = DebertaV2TokenizerFast.from_pretrained( |
| model_name_or_path, |
| add_prefix_space=True, |
| ) |
|
|
| |
| |
| |
|
|
| if args.mini: |
| loaded_dataset = DatasetDict({ |
| "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)), |
| "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)), |
| "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)), |
| }) |
|
|
| |
| example_aligner = ExampleAligner(ALL_LABELS, LABEL2ID) |
| tokenized_dataset = loaded_dataset.map( |
| example_aligner.tokenize_and_align_labels, |
| batched=True, |
| remove_columns=loaded_dataset["train"].column_names, |
| ) |
|
|
| |
| |
| |
|
|
| trainer = MultiHeadTrainer( |
| ALL_LABELS, |
| model=multi_head_model, |
| args=TrainingArguments( |
| |
| eval_strategy="steps", |
| save_strategy="steps", |
| load_best_model_at_end=True, |
| num_train_epochs=args.train_epochs, |
| learning_rate=args.learning_rate, |
|
|
| output_dir="training_output", |
| overwrite_output_dir=True, |
| remove_unused_columns=False, |
|
|
| logging_dir="training_logs", |
| logging_steps=100, |
|
|
| |
| per_device_eval_batch_size=args.eval_batch_size, |
| per_device_train_batch_size=args.train_batch_size, |
| gradient_accumulation_steps=args.accumulation_steps, |
|
|
| warmup_ratio=0.1, |
| weight_decay=0.01, |
| ), |
| train_dataset=tokenized_dataset["train"], |
| eval_dataset=tokenized_dataset["validation"], |
| ) |
|
|
| if args.train: |
| trainer.train() |
| trainer.evaluate() |
| trainer.save_model(args.save_path) |
| tokenizer.save_pretrained(args.save_path) |
|
|
| |
| |
| |
|
|
| pred_output = trainer.predict(tokenized_dataset["test"]) |
| pred_logits_dict = pred_output.predictions |
| pred_labels_dict = pred_output.label_ids |
| id2label_dict = ID2LABEL |
|
|
| |
| metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict) |
| for k,v in metrics.items(): |
| print(f"{k}: {v:.4f}") |
|
|
| |
| reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict) |
| for head_name, rstr in reports.items(): |
| print(f"----- {head_name} classification report -----") |
| print(rstr) |