# BERT/src/train.py
import json
from collections import Counter, defaultdict
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import torch
from torch import nn
from transformers import Trainer, AutoModelForSequenceClassification, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_from_disk

from preprocess import dataset_preprocess


class BatchCollator:
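    """Collate tokenized examples into model inputs.

    For each example, randomly selects the title-only encoding (with
    probability p_only_title) or the full-text encoding, and builds a
    multi-hot float label matrix for multi-label classification.
    """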
def __init__(
self,
num_classes: int,
p_only_title: float = 0.4,
):
self.num_classes = num_classes
self.p_only_title = p_only_title

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
batch_size = len(features)
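        # Per-example coin flip: train on the title-only encoding with
        # probability p_only_title, otherwise on the full-text encoding.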
use_only_title = np.random.rand(batch_size) < self.p_only_title
input_ids = []
attention_mask = []
        required_keys = [
            'title_input_ids',
            'title_attention_mask',
            'full_input_ids',
            'full_attention_mask',
            'labels_ids',
        ]
        for i, feat in enumerate(features):
            if not all(k in feat for k in required_keys):
                raise KeyError(
                    f'dataset elements must contain all of {required_keys}'
                )
if use_only_title[i]:
input_ids.append(feat["title_input_ids"])
attention_mask.append(feat["title_attention_mask"])
else:
input_ids.append(feat["full_input_ids"])
attention_mask.append(feat["full_attention_mask"])
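        # The encodings are assumed pre-padded to a fixed length upstream,
        # so the per-example lists stack directly into rectangular tensors.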
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
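        # Multi-hot float targets: one row per example, 1.0 at every label id.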
labels = torch.zeros((batch_size, self.num_classes), dtype=torch.float32)
for i, el in enumerate(features):
for label_id in el["labels_ids"]:
labels[i, label_id] = 1.0
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}


def get_dataset(
dataset_json_path: str | PathLike[str],
model_name: str,
categories_column: str
) -> Dict[str, Any]:
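    """Return the tokenized dataset together with the category/id mappings.

    If a cached copy exists under ./data, load it from disk (restoring the
    integer types lost in JSON round-tripping); otherwise preprocess from
    dataset_json_path and cache the result via save_data.
    """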
save_path = Path('./data')
if (save_path / 'tokenized_dataset').exists():
dataset = load_from_disk(str(save_path / 'tokenized_dataset'))
with open(save_path / 'cat2ids.json', 'r') as f:
cat2ids = json.load(f)
cat2ids = {k: int(v) for k, v in cat2ids.items()}
with open(save_path / 'ids2cat.json', 'r') as f:
ids2cat = json.load(f)
ids2cat = {int(k): v for k, v in ids2cat.items()}
return {
"dataset": dataset,
"cat2ids": cat2ids,
"ids2cat": ids2cat,
}
result = dataset_preprocess(
dataset_json_path,
model_name,
categories_column
)
save_data(result, './data')
return result


def save_data(result: Dict[str, Any], save_path: str | PathLike[str]):
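    """Persist the tokenized dataset and both category/id mappings to
    save_path, skipping any artifact that already exists on disk."""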
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
if not (save_path / 'tokenized_dataset').exists():
dataset = result["dataset"]
dataset.save_to_disk(save_path / 'tokenized_dataset')
if not (save_path / 'cat2ids.json').exists():
with open(save_path / 'cat2ids.json', "w", encoding="utf-8") as f:
json.dump(result['cat2ids'], f, ensure_ascii=False, indent=4)
if not (save_path / 'ids2cat.json').exists():
with open(save_path / 'ids2cat.json', "w", encoding="utf-8") as f:
json.dump(result['ids2cat'], f, ensure_ascii=False, indent=4)


def compute_cats_weights(dataset, num_classes: int) -> torch.Tensor:
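    """Compute per-class pos_weight for binary_cross_entropy_with_logits.

    For each class, pos_weight = (num_examples - num_positives) / num_positives,
    i.e. the negative-to-positive ratio; classes never seen in the dataset
    fall back to 1.0.
    """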
obj_per_cat = defaultdict(int)
for labels in dataset["labels_ids"]:
for label in labels:
obj_per_cat[label] += 1
num_obj = len(dataset)
pos_weights = []
for class_id in range(num_classes):
pos_count = obj_per_cat[class_id]
if pos_count == 0:
pos_weights.append(1.0)
else:
pos_weights.append((num_obj - pos_count) / pos_count)
return torch.tensor(pos_weights, dtype=torch.float32)


def make_compute_loss_func(pos_weights: torch.Tensor):
    """Build a Trainer-compatible loss function with the given per-class
    pos_weight baked in, instead of reading a module-level global."""
    def compute_loss_func(outputs, labels, num_items_in_batch=None):
        logits = outputs.logits
        return nn.functional.binary_cross_entropy_with_logits(
            input=logits,
            target=labels,
            pos_weight=pos_weights.to(logits.device),
        )
    return compute_loss_func


if __name__ == '__main__':
dataset_path = Path('dataset.json')
model_name = 'oracat/bert-paper-classifier-arxiv'
output_model_path = Path(f'./checkpoints/{model_name.split("/")[-1]}/checkpoints')
categories_column = 'categories'
result = get_dataset(dataset_path, model_name, categories_column)
num_cats = len(result['cat2ids'])
print(f'num_cats: {num_cats}')
dataset = result['dataset']
dataset = dataset.train_test_split(test_size=0.1, seed=123)
train_dataset, test_dataset = dataset['train'], dataset['test']
    # Sanity check: report how many distinct classes appear in each split,
    # since a rare class can be missing from train or test after the split.
    for split_name, split in (('train', train_dataset), ('test', test_dataset)):
        cnt = Counter(label for labels in split['labels_ids'] for label in labels)
        print(f'num classes in {split_name}: {len(cnt)}')
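    # Load the base model with a freshly sized classification head;
    # ignore_mismatched_sizes permits replacing the pretrained head, whose
    # label count differs from num_cats.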
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=num_cats,
problem_type='multi_label_classification',
ignore_mismatched_sizes=True
)
pos_weights = compute_cats_weights(train_dataset, num_cats)
assert pos_weights.shape[0] == num_cats
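    # LoRA adapters on PEFT's default target modules for this architecture;
    # use_rslora scales updates by lora_alpha / sqrt(r) rather than
    # lora_alpha / r, and the new classifier head stays fully trainable
    # via modules_to_save.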
lora_config = LoraConfig(
modules_to_save=['classifier'],
use_rslora=True,
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type=TaskType.SEQ_CLS
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
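    # Effective train batch size: 64 per device x 4 gradient-accumulation
    # steps = 256 examples per optimizer step (per device).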
training_args = TrainingArguments(
per_device_train_batch_size=64,
per_device_eval_batch_size=64,
output_dir=output_model_path,
logging_dir=output_model_path / "runs",
num_train_epochs=20,
learning_rate=5e-4,
lr_scheduler_type='cosine',
warmup_steps=10,
optim='adamw_torch_fused',
weight_decay=0.001,
gradient_accumulation_steps=4,
bf16=True,
logging_strategy='epoch',
eval_strategy='epoch',
load_best_model_at_end=True,
save_only_model=False,
save_total_limit=2,
save_strategy='epoch',
disable_tqdm=False,
remove_unused_columns=False,
seed=42,
dataloader_num_workers=4
)
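    # remove_unused_columns=False above is required: the Trainer would
    # otherwise drop the custom title_*/full_* columns before they reach
    # the collator.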
collator = BatchCollator(
num_classes=num_cats,
p_only_title=0.4
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=collator,
        compute_loss_func=make_compute_loss_func(pos_weights),
)
trainer.train()
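    # save_pretrained on a PEFT model writes only the adapter weights and
    # the modules_to_save (classifier head), not the full base model.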
model.save_pretrained(output_model_path / 'final_model')