# Source: multi-classifier/ud_multi_task_classifier.py on the Hugging Face Hub
# (author: veryfansome; commit ab4b5ab "feat: working training"; 21.6 kB).
from datasets import DatasetDict, load_from_disk
from sklearn.metrics import classification_report, precision_recall_fscore_support
from transformers import (
DebertaV2Config,
DebertaV2Model,
DebertaV2PreTrainedModel,
DebertaV2TokenizerFast,
Trainer,
TrainingArguments,
)
import argparse
import logging.config
import numpy as np
import torch
import torch.nn as nn
from utils import default_logging_config, get_uniq_training_labels, show_examples
logger = logging.getLogger(__name__)


def _build_arg_parser():
    """Create the command-line parser for training / evaluating the multi-task model.

    :return: a configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Train multi-task model.")
    add = parser.add_argument
    add("-A", "--accumulation-steps", action="store", type=int, default=8,
        help="Gradient accumulation steps.")
    add("--data-only", action="store_true", default=False,
        help="Show training data info and exit.")
    add("--data-path", action="store", default="./training_data",
        help="Load training dataset from specified path.")
    add("-E", "--train-epochs", action="store", type=int, default=3,
        help="Number of epochs to train for.")
    add("-V", "--eval-batch-size", action="store", type=int, default=2,
        help="Per device eval batch size.")
    # Supported pretrained checkpoints to start from; extend the list as needed.
    add("--from-base", action="store", default=None,
        choices=[
            "microsoft/deberta-v3-base",
            "microsoft/deberta-v3-large",
        ],
        help="Load a base model.")
    add("-L", "--learning-rate", action="store", type=float, default=5e-5,
        help="Learning rate.")
    add("--mini", action="store_true", default=False,
        help="Train model using small subset of examples for pipeline testing.")
    add("--save-path", action="store", default="./final",
        help="Save final model to specified path.")
    add("--show", action="store", default=None,
        help="Show examples: <split>/<col>/<label>/<count>")
    add("--train", action="store_true", default=False,
        help="Train model using loaded examples.")
    add("-T", "--train-batch-size", action="store", type=int, default=2,
        help="Per device train batch size.")
    return parser


arg_parser = _build_arg_parser()
args = arg_parser.parse_args()
# Logging is configured only after argument parsing, from the project's shared config.
logging.config.dictConfig(default_logging_config)
logger.info(f"Args {args}")
# ------------------------------------------------------------------------------
# Load dataset and show examples for manual inspection
# ------------------------------------------------------------------------------
loaded_dataset = load_from_disk(args.data_path)
show_examples(loaded_dataset, args.show)
# ------------------------------------------------------------------------------
# Convert label analysis data into label sets for each head
# ------------------------------------------------------------------------------
# {head_name: [label, ...]} -- a label's position in its list is its id.
# NOTE(review): the ordering here is whatever get_uniq_training_labels yields; if
# it returns sets, the order may differ between runs. A fresh model stores this
# ordering in its config (label_maps), and reloading a checkpoint later replaces
# these tables with the stored ordering.
ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}
# {head_name: {label: id}}
LABEL2ID = {
    feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
    for feat_name in ALL_LABELS
}
# {head_name: {id: label}} -- inverse of LABEL2ID.
ID2LABEL = {
    feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
    for feat_name in LABEL2ID
}
# Each head's number of labels:
NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}
if args.data_only:
    # `exit()` is a convenience the `site` module injects for interactive use and
    # is not guaranteed to exist (e.g. `python -S`); SystemExit always works and
    # still exits with status 0.
    raise SystemExit
# ------------------------------------------------------------------------------
# Create a custom config that can store our multi-label info
# ------------------------------------------------------------------------------
class MultiHeadModelConfig(DebertaV2Config):
    """DebertaV2Config extended with per-head label metadata.

    :param label_maps: {head_name: [label, ...]}; in this script a label's list
        position is used as its id.
    :param num_labels_dict: {head_name: number_of_labels}; MultiHeadModel sizes one
        classification head per entry.
    :param kwargs: forwarded unchanged to DebertaV2Config.

    Both dicts are written into the serialized config by ``to_dict``.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Any falsy value (None or empty) is replaced by a fresh empty dict.
        self.label_maps = label_maps if label_maps else {}
        self.num_labels_dict = num_labels_dict if num_labels_dict else {}

    def to_dict(self):
        serialized = super().to_dict()
        serialized.update(
            label_maps=self.label_maps,
            num_labels_dict=self.num_labels_dict,
        )
        return serialized
# ------------------------------------------------------------------------------
# Define a multi-head model
# ------------------------------------------------------------------------------
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one linear token-classification head per label column.

    ``config.num_labels_dict`` ({head_name: n_labels}) defines the heads; each head
    maps every position of the encoder's last hidden state to ``n_labels`` logits.
    """

    # Make ``from_pretrained`` (when called without an explicit ``config``) rebuild
    # the config as MultiHeadModelConfig rather than a plain DebertaV2Config.
    config_class = MultiHeadModelConfig
    # NOTE(review): forward() takes **kwargs, which recent transformers Trainer
    # versions read as "the model normalizes its loss by num_items_in_batch itself"
    # and then skip dividing the loss by gradient_accumulation_steps. This model
    # returns per-head means and never receives num_items_in_batch, so opt out
    # explicitly; Trainer versions that don't consult this attribute ignore it.
    accepts_loss_kwargs = False

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)
        # Initialize newly added weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs,
    ):
        """Run the encoder and every classification head.

        :param input_ids: forwarded to DebertaV2Model.
        :param attention_mask: forwarded to DebertaV2Model.
        :param token_type_ids: forwarded to DebertaV2Model.
        :param labels_dict: optional {head_name: label-id tensor shaped like
            ``input_ids``}; positions equal to -100 are ignored. Heads missing
            from the dict (or mapped to None) contribute no loss.
        :param kwargs: forwarded to DebertaV2Model.
        :return: ``logits_dict`` ({head_name: (batch, seq_len, n_labels)}) when
            ``labels_dict`` is None; otherwise ``(total_loss, logits_dict)`` where
            ``total_loss`` is a scalar tensor summing each head's mean
            cross-entropy over its supervised positions.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }
        if labels_dict is None:
            # Inference: just return predictions.
            return logits_dict

        # Summed CE divided by the number of supervised positions equals
        # CrossEntropyLoss's default mean, except that a head whose labels are all
        # -100 in this batch yields 0 instead of NaN (0 / 0).
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
        # Start from a tensor so the loss is a tensor even when no head has labels.
        total_loss = torch.zeros((), device=sequence_output.device)
        for label_name, logits in logits_dict.items():
            label_ids = labels_dict.get(label_name)
            if label_ids is None:
                continue
            n_active = (label_ids != -100).sum().clamp(min=1)
            head_loss = loss_fct(
                logits.reshape(-1, logits.shape[-1]),
                label_ids.reshape(-1),
            )
            total_loss = total_loss + head_loss / n_active
        return total_loss, logits_dict
# ------------------------------------------------------------------------------
# Tokenize with max_length=512, stride=128, and subword alignment
# ------------------------------------------------------------------------------
def _align_word_labels(word_ids, token_labels, label2id, skip_word_id=None):
    """Return one label id per subword position of a single tokenized chunk.

    :param word_ids: per-position word index from the tokenizer (None for special
        tokens and padding).
    :param token_labels: the example's word-level label strings, indexed by word id.
    :param label2id: {label_str: label_id} for this feature.
    :param skip_word_id: word index whose subwords are all ignored (None = none).
    :return: list of label ids, -100 wherever the loss must ignore the position.
    """
    chunk_label_ids = []
    previous_word_id = None
    for w_id in word_ids:
        if w_id is None or w_id == previous_word_id or w_id == skip_word_id:
            # Special token / padding, a non-first subword, or a skipped word.
            chunk_label_ids.append(-100)
        else:
            # First subword of a word carries the word's label.
            chunk_label_ids.append(label2id[token_labels[w_id]])
        previous_word_id = w_id
    return chunk_label_ids


def tokenize_and_align_labels(examples):
    """Tokenize word-split examples into 512-subword chunks with aligned labels.

    Examples longer than 512 subwords are split into several chunks (stride=128),
    each padded to length 512. Only the first subword of each word gets a label;
    all other positions get -100.

    :param examples: a batch dict with a "tokens" column (list of word strings per
        example) plus one word-level label column per feature in ALL_LABELS.
    :return: dict with "input_ids", "attention_mask" and one "labels_<feature>"
        column, one row per chunk (so a batched ``map`` can emit more rows than
        it received).
    """
    # is_split_into_words=True because examples["tokens"] is a list of token strings.
    tokenized_batch = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        max_length=512,
        stride=128,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=False,  # not mandatory for basic alignment
        padding="max_length",
    )
    # "overflow_to_sample_mapping"[i] is the index (within `examples`) of the
    # example chunk i came from; without it, chunks map 1:1 to examples.
    if "overflow_to_sample_mapping" in tokenized_batch:
        sample_map = tokenized_batch["overflow_to_sample_mapping"]
    else:
        sample_map = list(range(len(tokenized_batch["input_ids"])))

    label_columns = {feat_name: [] for feat_name in ALL_LABELS}
    for i, original_index in enumerate(sample_map):
        word_ids = tokenized_batch.word_ids(batch_index=i)
        # A later chunk of the same example starts inside the region it shares
        # (stride=128 subwords) with the previous chunk, so its leading word may
        # begin mid-word and its first visible subword is not the word's first
        # subword. That word is already labelled in the previous chunk, so skip it.
        # NOTE(review): relies on an example's overflow chunks being consecutive
        # and overlapping by `stride` subwords -- confirm for the tokenizer in use.
        skip_word_id = None
        if i > 0 and sample_map[i - 1] == original_index:
            skip_word_id = next((w for w in word_ids if w is not None), None)
        for feat_name in ALL_LABELS:
            label_columns[feat_name].append(
                _align_word_labels(
                    word_ids,
                    examples[feat_name][original_index],
                    LABEL2ID[feat_name],
                    skip_word_id,
                )
            )

    result = {
        "input_ids": list(tokenized_batch["input_ids"]),
        "attention_mask": list(tokenized_batch["attention_mask"]),
    }
    # One label column per feature, e.g. labels_xpos, labels_deprel.
    for feat_name in ALL_LABELS:
        result[f"labels_{feat_name}"] = label_columns[feat_name]
    return result
# ------------------------------------------------------------------------------
# Trainer Setup
# ------------------------------------------------------------------------------
class MultiHeadTrainer(Trainer):
    """Trainer for MultiHeadModel, whose labels live in ``labels_<feature>`` columns.

    Every ``labels_<feature>`` entry of a batch (for features in ALL_LABELS) is
    collected into a ``labels_dict`` for the model. Prediction returns
    {head: logits} and {head: label ids} dicts.
    """

    @staticmethod
    def _split_label_columns(inputs):
        """Split a batch into (model_inputs, labels_dict) without mutating ``inputs``.

        All ``labels_*`` keys are removed from the model inputs, because the model
        forwards unknown kwargs to the encoder. Only features listed in ALL_LABELS
        go into ``labels_dict``.
        """
        labels_dict = {
            feat_name: inputs[f"labels_{feat_name}"]
            for feat_name in ALL_LABELS
            if f"labels_{feat_name}" in inputs
        }
        model_inputs = {k: v for k, v in inputs.items() if not k.startswith("labels_")}
        return model_inputs, labels_dict

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the summed per-head loss, or ``(loss, logits_dict)`` if requested.

        ``num_items_in_batch`` is accepted for Trainer-signature compatibility but is
        unused: the model already averages each head's loss over its supervised
        positions.
        """
        model_inputs, labels_dict = self._split_label_columns(inputs)
        loss, logits_dict = model(**model_inputs, labels_dict=labels_dict)
        return (loss, logits_dict) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Return ``(loss, logits_dict, labels_dict)`` for one evaluation batch.

        For a batch without any label columns, returns ``(None, logits_dict, None)``,
        matching the stock Trainer convention for unlabeled data.
        """
        # Same device/dtype preparation the stock prediction_step performs.
        inputs = self._prepare_inputs(inputs)
        model_inputs, labels_dict = self._split_label_columns(inputs)
        with torch.no_grad():
            if labels_dict:
                loss, logits_dict = model(**model_inputs, labels_dict=labels_dict)
            else:
                loss, logits_dict = None, model(**model_inputs)
        if prediction_loss_only:
            return (loss, None, None)
        # 'predictions' and 'labels' are both {head_name: tensor} dicts.
        return (loss, logits_dict, labels_dict or None)
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.
    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}; -100 marks
        positions to ignore. Must have the same (batch_size, seq_len) as the logits.
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head. Heads absent
        from labels_dict are skipped; heads with no supervised position get
        "No valid predictions.".
    """
    reports = {}
    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            continue
        predictions = np.argmax(logits, axis=-1)  # (batch_size, seq_len)
        labels = np.asarray(labels_dict[head_name])
        # Vectorized filter instead of a per-token Python loop: keep only
        # supervised positions (-100 marks special tokens, padding and
        # non-first subwords).
        mask = labels != -100
        valid_preds = predictions[mask]
        valid_labels = labels[mask]
        if valid_labels.size == 0:
            reports[head_name] = "No valid predictions."
            continue
        # Convert numeric IDs to string labels
        id2label = id2label_dict[head_name]
        valid_preds_str = [id2label[int(p)] for p in valid_preds]
        valid_labels_str = [id2label[int(lab)] for lab in valid_labels]
        # Generate the per-class classification report
        reports[head_name] = classification_report(
            valid_labels_str,
            valid_preds_str,
            zero_division=0,
        )
    return reports
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
    - Accuracy
    - Precision (macro/micro)
    - Recall (macro/micro)
    - F1 (macro/micro)
    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)};
        -100 marks positions to ignore.
    :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy",
        "xpos_f1_macro", etc. Heads without labels or without any supervised
        position are omitted.
    """
    # We'll accumulate metrics in one big dictionary, keyed by "<head>_<metric>"
    results = {}
    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            # In case there's a mismatch or a head we didn't provide labels for
            continue
        predictions = np.argmax(logits, axis=-1)  # (batch_size, seq_len)
        labels = np.asarray(labels_dict[head_name])
        # Vectorized filter (instead of a per-token Python loop) dropping
        # positions where label == -100.
        mask = labels != -100
        valid_preds = predictions[mask]
        valid_labels = labels[mask]
        if valid_labels.size == 0:
            # No valid data for this head -- skip
            continue
        # Overall token-level accuracy
        accuracy = (valid_preds == valid_labels).mean()
        # Macro average => treat each class equally
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="macro", zero_division=0
        )
        # Micro average => aggregate across all classes
        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            valid_labels, valid_preds, average="micro", zero_division=0
        )
        results[f"{head_name}_accuracy"] = accuracy
        results[f"{head_name}_precision_macro"] = precision_macro
        results[f"{head_name}_recall_macro"] = recall_macro
        results[f"{head_name}_f1_macro"] = f1_macro
        results[f"{head_name}_precision_micro"] = precision_micro
        results[f"{head_name}_recall_micro"] = recall_micro
        results[f"{head_name}_f1_micro"] = f1_micro
    return results
# ------------------------------------------------------------------------------
# Instantiate model and tokenizer
# ------------------------------------------------------------------------------
def _pick_torch_device():
    """Return the preferred torch device: CUDA, then Apple Silicon MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if not args.from_base:
    # No base model requested: reload the saved checkpoint with its stored config.
    model_name_or_path = args.save_path
    multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
    # EXTREMELY IMPORTANT: rebuild the label tables from the checkpoint's stored
    # label_maps so label ids match the ordering used at training time.
    ALL_LABELS = multi_head_model.config.label_maps
    LABEL2ID = {
        feat: {label: i for i, label in enumerate(ALL_LABELS[feat])}
        for feat in ALL_LABELS
    }
    ID2LABEL = {
        feat: {i: label for label, i in LABEL2ID[feat].items()}
        for feat in LABEL2ID
    }
else:
    # Start from a pretrained base; heads are sized from the dataset's label sets.
    model_name_or_path = args.from_base
    multi_head_model = MultiHeadModel.from_pretrained(
        model_name_or_path,
        config=MultiHeadModelConfig.from_pretrained(
            model_name_or_path,
            num_labels_dict=NUM_LABELS_DICT,
            label_maps=ALL_LABELS,
        ),
    )
logger.info(f"using {model_name_or_path}")

device = _pick_torch_device()
logger.info(f"using {device}")
multi_head_model.to(device)

# The tokenizer comes from the same location as the model weights.
tokenizer = DebertaV2TokenizerFast.from_pretrained(
    model_name_or_path,
    add_prefix_space=True,
)
# ------------------------------------------------------------------------------
# Shuffle, (optionally) sample, and tokenize final merged dataset
# ------------------------------------------------------------------------------
def _shuffled_subset(split, max_rows, seed=42):
    """Shuffle ``split`` with a fixed seed and keep at most ``max_rows`` rows.

    Capping at ``len(split)`` keeps --mini usable on splits smaller than
    ``max_rows``; selecting indices past the end would otherwise fail.
    """
    shuffled = split.shuffle(seed=seed)
    return shuffled.select(range(min(max_rows, len(shuffled))))


if args.mini:
    loaded_dataset = DatasetDict({
        "train": _shuffled_subset(loaded_dataset["train"], 1000),
        "validation": _shuffled_subset(loaded_dataset["validation"], 100),
        "test": _shuffled_subset(loaded_dataset["test"], 100),
    })
# remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
# (assumes every split has the same columns as "train").
tokenized_dataset = loaded_dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=loaded_dataset["train"].column_names,
)
# ------------------------------------------------------------------------------
# Train the model!
# ------------------------------------------------------------------------------
# Current bests:
#   deberta-v3-base:
#     num_train_epochs=3,
#     learning_rate=5e-5,
#     per_device_train_batch_size=2,
#     gradient_accumulation_steps=8,
training_args = TrainingArguments(
    output_dir="training_output",
    overwrite_output_dir=True,
    logging_dir="training_logs",
    logging_steps=100,
    # Evaluate on the validation split at the end of every epoch.
    eval_strategy="epoch",
    num_train_epochs=args.train_epochs,
    learning_rate=args.learning_rate,
    # Effective batch size = train_batch_size x gradient_accumulation_steps
    per_device_train_batch_size=args.train_batch_size,
    gradient_accumulation_steps=args.accumulation_steps,
    per_device_eval_batch_size=args.eval_batch_size,
    # Important: keeps the labels_<feature> columns in the batches.
    remove_unused_columns=False,
)

trainer = MultiHeadTrainer(
    model=multi_head_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

if args.train:
    trainer.train()
    trainer.evaluate()
    # Save weights/config and tokenizer together so --save-path can be reloaded.
    trainer.save_model(args.save_path)
    tokenizer.save_pretrained(args.save_path)
# ------------------------------------------------------------------------------
# Evaluate the model!
# ------------------------------------------------------------------------------
pred_output = trainer.predict(tokenized_dataset["test"])
# Both are {head_name: array} dicts produced by MultiHeadTrainer.prediction_step.
pred_logits_dict, pred_labels_dict = pred_output.predictions, pred_output.label_ids
id2label_dict = ID2LABEL  # from earlier definitions

# 1) Aggregate token-level metrics per head
metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")

# 2) Per-class classification report per head
reports = multi_head_classification_reports(
    pred_logits_dict,
    pred_labels_dict,
    id2label_dict,
)
for head_name, rstr in reports.items():
    print(f"----- {head_name} classification report -----")
    print(rstr)