|
|
from sklearn.metrics import classification_report, precision_recall_fscore_support |
|
|
from transformers import ( |
|
|
DebertaV2TokenizerFast, |
|
|
EarlyStoppingCallback, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
) |
|
|
import logging |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from multi_head_model import MultiHeadModel, MultiHeadModelConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExampleAligner:
    """Tokenizes word-level examples and aligns per-feature labels to subwords.

    Each dataset example provides `examples["tokens"]` (a list of words) plus
    one label column per feature in `all_labels`. Tokenization may split an
    example into several overlapping chunks (max_length=512, stride=128); each
    chunk becomes its own row in the output, with label ids aligned so that
    special tokens and non-first subwords get -100 (ignored by the loss).
    """

    def __init__(self, all_labels, label2id, tokenizer=None):
        """
        :param all_labels: dict of {feat_name: [label strings]} — keys name the
            label columns to align.
        :param label2id: dict of {feat_name: {label_str: label_id}}.
        :param tokenizer: optional fast tokenizer (must support
            is_split_into_words and word_ids). When None, falls back to the
            module-level `tokenizer` global for backward compatibility with the
            original call sites (defined in the __main__ block).
        """
        self.all_labels = all_labels
        self.label2id = label2id
        self._tokenizer = tokenizer

    def tokenize_and_align_labels(self, examples):
        """
        For each example, the tokenizer may produce multiple overlapping
        chunks if the tokens exceed 512 subwords. Each chunk will be
        length=512, with a stride=128 for the next chunk.
        We'll align labels so that subwords beyond the first in a token get -100.
        """
        # Prefer the injected tokenizer; otherwise use the module-level one
        # (same behavior as before this class accepted a tokenizer).
        tok = self._tokenizer if self._tokenizer is not None else tokenizer

        tokenized_batch = tok(
            examples["tokens"],
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length"
        )

        # Map each output chunk back to the input example it came from. If the
        # tokenizer did not report an overflow mapping, chunks correspond 1:1
        # to input examples.
        if "overflow_to_sample_mapping" not in tokenized_batch:
            sample_map = [i for i in range(len(tokenized_batch["input_ids"]))]
        else:
            sample_map = tokenized_batch["overflow_to_sample_mapping"]

        final_input_ids = []
        final_attention_mask = []
        final_labels_columns = {feat: [] for feat in self.all_labels}

        for i in range(len(tokenized_batch["input_ids"])):
            chunk_input_ids = tokenized_batch["input_ids"][i]
            chunk_attn_mask = tokenized_batch["attention_mask"][i]

            original_index = sample_map[i]
            # word_ids maps each subword position to its source word index
            # (None for special/padding tokens).
            word_ids = tokenized_batch.word_ids(batch_index=i)

            chunk_labels_dict = {}
            for feat_name in self.all_labels:
                # Word-level gold labels for this feature on the source example.
                token_labels = examples[feat_name][original_index]
                chunk_label_ids = []

                previous_word_id = None
                for w_id in word_ids:
                    if w_id is None:
                        # Special or padding token: ignored by the loss.
                        chunk_label_ids.append(-100)
                    else:
                        if w_id == previous_word_id:
                            # Continuation subword: only the first subword of a
                            # word carries the label.
                            chunk_label_ids.append(-100)
                        else:
                            label_str = token_labels[w_id]
                            label_id = self.label2id[feat_name][label_str]
                            chunk_label_ids.append(label_id)
                    previous_word_id = w_id

                chunk_labels_dict[feat_name] = chunk_label_ids

            final_input_ids.append(chunk_input_ids)
            final_attention_mask.append(chunk_attn_mask)
            for feat_name in self.all_labels:
                final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

        result = {
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
        }
        # One "labels_<feat>" column per head, consumed by MultiHeadTrainer.
        for feat_name in self.all_labels:
            result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadTrainer(Trainer):
    """Trainer subclass that routes per-head label columns into the model.

    The tokenized dataset carries one "labels_<feat>" column per head; the
    model expects them bundled as a `labels_dict` keyword argument instead.
    This subclass pops those columns off `inputs` before each forward pass.
    Requires TrainingArguments(remove_unused_columns=False) so the columns
    survive collation.
    """

    def __init__(self, all_labels, **kwargs):
        """
        :param all_labels: dict of {feat_name: [label strings]} — only the
            keys are used, to know which "labels_<feat>" columns to collect.
        :param kwargs: forwarded unchanged to transformers.Trainer.
        """
        self.all_labels = all_labels
        super().__init__(**kwargs)

    def _pop_labels(self, inputs):
        # Collect the per-head label tensors for known heads, then strip every
        # remaining "labels_*" key so model(**inputs) never sees an unexpected
        # keyword argument.
        labels_dict = {}
        for feat_name in self.all_labels:
            key = f"labels_{feat_name}"
            if key in inputs:
                labels_dict[feat_name] = inputs.pop(key)
        for key in [k for k in inputs if k.startswith("labels_")]:
            del inputs[key]
        return labels_dict

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward pass with labels_dict; returns loss (and logits if asked).

        `num_items_in_batch` is accepted for Trainer API compatibility but is
        not used — the model averages its loss internally.
        """
        labels_dict = self._pop_labels(inputs)
        loss, logits_dict = model(**inputs, labels_dict=labels_dict)
        if return_outputs:
            return loss, logits_dict
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Evaluation step: returns (loss, logits_dict, labels_dict).

        `ignore_keys` is accepted for Trainer API compatibility but is not
        used, since the model's outputs are already a plain (loss, logits)
        pair rather than a ModelOutput.
        """
        labels_dict = self._pop_labels(inputs)

        with torch.no_grad():
            loss, logits_dict = model(**inputs, labels_dict=labels_dict)

        if prediction_loss_only:
            return (loss, None, None)

        return loss, logits_dict, labels_dict
|
|
|
|
|
|
|
|
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.
    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head.
    """
    reports = {}

    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            continue

        predicted_ids = np.argmax(logits, axis=-1)

        # Keep only positions whose gold label is not the ignore index (-100).
        kept_pairs = [
            (pred, gold)
            for pred_row, gold_row in zip(predicted_ids, labels_dict[head_name])
            for pred, gold in zip(pred_row, gold_row)
            if gold != -100
        ]

        if not kept_pairs:
            reports[head_name] = "No valid predictions."
            continue

        # Report on label strings rather than raw ids for readability.
        id2label = id2label_dict[head_name]
        predicted_names = [id2label[pred] for pred, _ in kept_pairs]
        gold_names = [id2label[gold] for _, gold in kept_pairs]

        reports[head_name] = classification_report(
            gold_names,
            predicted_names,
            zero_division=0
        )

    return reports
|
|
|
|
|
|
|
|
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
    - Accuracy
    - Precision (macro/micro)
    - Recall (macro/micro)
    - F1 (macro/micro)

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy", "xpos_f1_macro", etc.
    """
    results = {}

    for head_name, logits in logits_dict.items():
        # Skip heads for which no gold labels were supplied.
        if head_name not in labels_dict:
            continue

        predicted_ids = np.argmax(logits, axis=-1)

        # Drop every position carrying the ignore index (-100): special
        # tokens, padding, and continuation subwords.
        kept_pairs = [
            (pred, gold)
            for pred_row, gold_row in zip(predicted_ids, labels_dict[head_name])
            for pred, gold in zip(pred_row, gold_row)
            if gold != -100
        ]

        if not kept_pairs:
            continue

        valid_preds = np.array([pred for pred, _ in kept_pairs])
        valid_labels = np.array([gold for _, gold in kept_pairs])

        head_metrics = {"accuracy": (valid_preds == valid_labels).mean()}

        # Same precision/recall/F1 computation under both averaging schemes.
        for average in ("macro", "micro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                valid_labels, valid_preds, average=average, zero_division=0
            )
            head_metrics[f"precision_{average}"] = precision
            head_metrics[f"recall_{average}"] = recall
            head_metrics[f"f1_{average}"] = f1

        for metric_name, value in head_metrics.items():
            results[f"{head_name}_{metric_name}"] = value

    return results
|
|
|
|
|
|
|
|
def multi_head_compute_metrics_aggregate_f1(logits_dict, labels_dict):
    """Compute per-head metrics plus a single aggregate "f1_macro" key.

    The aggregate is the unweighted mean of every head's macro-F1 (0.0 when
    no head produced metrics); per-head metrics are included alongside it.
    """
    per_head = multi_head_compute_metrics(logits_dict, labels_dict)

    macro_f1_values = [value for key, value in per_head.items() if key.endswith("_f1_macro")]
    overall_f1 = np.mean(macro_f1_values) if macro_f1_values else 0.0

    # "f1_macro" comes first so TrainingArguments(metric_for_best_model=...)
    # can select it; the per-head detail follows.
    combined = {"f1_macro": overall_f1}
    combined.update(per_head)
    return combined
|
|
|
|
|
|
|
|
def compute_metrics_for_trainer(eval_pred):
    """Adapter matching the Trainer compute_metrics hook signature.

    Unpacks the EvalPrediction into the (logits_dict, labels_dict) pair that
    the aggregate metric function expects.
    """
    return multi_head_compute_metrics_aggregate_f1(eval_pred.predictions, eval_pred.label_ids)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script-only imports: dataset loading, CLI parsing, logging configuration.
    from datasets import DatasetDict, load_from_disk
    import argparse
    import logging.config

    from utils import default_logging_config, get_torch_device, get_uniq_training_labels, show_examples

    # ---- Command-line interface ----
    arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
    arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
                            action="store", type=int, default=12)
    # NOTE(review): --data-only is parsed but never referenced below — confirm
    # whether show_examples() is expected to honor it or it is vestigial.
    arg_parser.add_argument("--data-only", help='Show training data info and exit.',
                            action="store_true", default=False)
    arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
                            action="store", default="./training_data")
    arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
                            action="store", type=int, default=10)
    arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
                            action="store", type=int, default=2)
    arg_parser.add_argument("--from-base", help="Load a base model.",
                            action="store", default=None,
                            choices=[
                                "microsoft/deberta-v3-base",
                                "microsoft/deberta-v3-large",
                            ])
    arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
                            action="store", type=float, default=2e-5)
    arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                            action="store", default="./final")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    arg_parser.add_argument("--train", help='Train model using loaded examples.',
                            action="store_true", default=False)
    arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
                            action="store", type=int, default=8)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)
    logger.info(f"Args {args}")

    # Load the on-disk DatasetDict and optionally print example rows
    # (--show "<split>/<col>/<label>/<count>").
    loaded_dataset = load_from_disk(args.data_path)
    show_examples(loaded_dataset, args.show)

    if args.from_base:
        # Fresh model from a HF base checkpoint: the label inventory for every
        # head is derived from the training data itself.
        ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
        LABEL2ID = {
            feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
            for feat_name in ALL_LABELS
        }
        ID2LABEL = {
            feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
            for feat_name in LABEL2ID
        }

        NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
        model_name_or_path = args.from_base
        multi_head_model = MultiHeadModel.from_pretrained(
            model_name_or_path,
            config=MultiHeadModelConfig.from_pretrained(
                model_name_or_path,
                num_labels_dict=NUM_LABELS_DICT,
                label_maps=ALL_LABELS
            )
        )
    else:
        # Resume from a previously saved multi-head model; label maps are
        # recovered from its stored config, so id assignments stay consistent
        # with the earlier training run.
        model_name_or_path = args.save_path

        multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)

        ALL_LABELS = multi_head_model.config.label_maps
        LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
        ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
    logger.info(f"using {model_name_or_path}")

    device = get_torch_device()
    multi_head_model.to(device)

    # Module-level tokenizer: ExampleAligner.tokenize_and_align_labels reads
    # this global.
    tokenizer = DebertaV2TokenizerFast.from_pretrained(
        model_name_or_path,
        add_prefix_space=True,
    )

    if args.mini:
        # Tiny fixed-seed subsets for quick end-to-end pipeline smoke tests.
        loaded_dataset = DatasetDict({
            "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
            "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
            "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
        })

    # Tokenize + label-align every split; original word-level columns are
    # dropped in favor of input_ids/attention_mask/labels_* chunks.
    example_aligner = ExampleAligner(ALL_LABELS, LABEL2ID)
    tokenized_dataset = loaded_dataset.map(
        example_aligner.tokenize_and_align_labels,
        batched=True,
        remove_columns=loaded_dataset["train"].column_names,
    )

    trainer = MultiHeadTrainer(
        ALL_LABELS,
        model=multi_head_model,
        args=TrainingArguments(
            eval_strategy="steps",
            save_strategy="steps",
            # Best checkpoint selected by the aggregate macro-F1 emitted by
            # compute_metrics_for_trainer.
            load_best_model_at_end=True,
            metric_for_best_model="f1_macro",
            greater_is_better=True,

            num_train_epochs=args.train_epochs,
            learning_rate=args.learning_rate,

            output_dir="training_output",
            overwrite_output_dir=True,
            # Must stay False: the custom labels_* columns are not model
            # forward() parameters and would otherwise be dropped before
            # MultiHeadTrainer can collect them.
            remove_unused_columns=False,

            logging_dir="training_logs",
            logging_steps=100,

            per_device_eval_batch_size=args.eval_batch_size,
            per_device_train_batch_size=args.train_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,

            warmup_ratio=0.1,
            weight_decay=0.01,
        ),
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        compute_metrics=compute_metrics_for_trainer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    if args.train:
        trainer.train()
        trainer.evaluate()
        trainer.save_model(args.save_path)
        tokenizer.save_pretrained(args.save_path)

    # Final test-set evaluation runs regardless of --train (i.e. also when
    # only loading a previously saved model).
    pred_output = trainer.predict(tokenized_dataset["test"])
    pred_logits_dict = pred_output.predictions
    pred_labels_dict = pred_output.label_ids
    id2label_dict = ID2LABEL

    reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
    for head_name, rstr in reports.items():
        print(f"----- {head_name} classification report -----")
        print(rstr)