import json
from collections import Counter, defaultdict
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, List

import numpy as np
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_from_disk

from preprocess import dataset_preprocess


class BatchCollator:
    """Collates tokenized examples, feeding only the title for a random
    fraction of the batch so the model also learns to classify from titles alone."""

    REQUIRED_KEYS = (
        'title_input_ids',
        'title_attention_mask',
        'full_input_ids',
        'full_attention_mask',
        'labels_ids',
    )

    def __init__(self, num_classes: int, p_only_title: float = 0.4):
        self.num_classes = num_classes
        self.p_only_title = p_only_title

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch_size = len(features)
        # Per-example coin flip: with probability p_only_title use the
        # title-only encoding instead of the full (title + body) one.
        use_only_title = np.random.rand(batch_size) < self.p_only_title

        input_ids = []
        attention_mask = []
        for i, feat in enumerate(features):
            missing = [k for k in self.REQUIRED_KEYS if k not in feat]
            if missing:
                raise KeyError(f'dataset elements are missing keys: {missing}')
            if use_only_title[i]:
                input_ids.append(feat['title_input_ids'])
                attention_mask.append(feat['title_attention_mask'])
            else:
                input_ids.append(feat['full_input_ids'])
                attention_mask.append(feat['full_attention_mask'])

        # Sequences are assumed to be padded to a fixed length during
        # preprocessing, so they can be stacked into rectangular tensors.
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        # Multi-hot label matrix for multi-label classification.
        labels = torch.zeros((batch_size, self.num_classes), dtype=torch.float32)
        for i, el in enumerate(features):
            for label_id in el['labels_ids']:
                labels[i, label_id] = 1.0

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }


def get_dataset(
    dataset_json_path: str | PathLike[str],
    model_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    """Load the cached tokenized dataset and label mappings if they exist,
    otherwise preprocess from scratch and cache the result."""
    save_path = Path('./data')
    if (save_path / 'tokenized_dataset').exists():
        dataset = load_from_disk(str(save_path / 'tokenized_dataset'))
        with open(save_path / 'cat2ids.json', 'r') as f:
            cat2ids = json.load(f)
        cat2ids = {k: int(v) for k, v in cat2ids.items()}
        # JSON object keys are always strings, so restore the integer keys.
        with open(save_path / 'ids2cat.json', 'r') as f:
            ids2cat = json.load(f)
        ids2cat = {int(k): v for k, v in ids2cat.items()}
        return {
            'dataset': dataset,
            'cat2ids': cat2ids,
            'ids2cat': ids2cat,
        }

    result = dataset_preprocess(dataset_json_path, model_name, categories_column)
    save_data(result, save_path)
    return result


def save_data(result: Dict[str, Any], save_path: str | PathLike[str]):
    """Persist the tokenized dataset and label mappings, skipping files that
    already exist."""
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    if not (save_path / 'tokenized_dataset').exists():
        result['dataset'].save_to_disk(save_path / 'tokenized_dataset')
    if not (save_path / 'cat2ids.json').exists():
        with open(save_path / 'cat2ids.json', 'w', encoding='utf-8') as f:
            json.dump(result['cat2ids'], f, ensure_ascii=False, indent=4)
    if not (save_path / 'ids2cat.json').exists():
        with open(save_path / 'ids2cat.json', 'w', encoding='utf-8') as f:
            json.dump(result['ids2cat'], f, ensure_ascii=False, indent=4)


def compute_cats_weights(dataset, num_classes: int) -> torch.Tensor:
    """Compute per-class pos_weight = negatives / positives to counter class
    imbalance in BCE-with-logits; unseen classes get a neutral weight of 1."""
    obj_per_cat = defaultdict(int)
    for labels in dataset['labels_ids']:
        for label in labels:
            obj_per_cat[label] += 1
    num_obj = len(dataset)
    pos_weights = []
    for class_id in range(num_classes):
        pos_count = obj_per_cat[class_id]
        if pos_count == 0:
            pos_weights.append(1.0)
        else:
            pos_weights.append((num_obj - pos_count) / pos_count)
    return torch.tensor(pos_weights, dtype=torch.float32)


def make_loss_func(pos_weights: torch.Tensor) -> Callable:
    """Bind the class weights into a loss function with the signature Trainer
    expects, instead of relying on a module-level global."""

    def compute_loss_func(outputs, labels, num_items_in_batch=None):
        logits = outputs.logits
        return nn.functional.binary_cross_entropy_with_logits(
            input=logits,
            target=labels,
            pos_weight=pos_weights.to(logits.device),
        )

    return compute_loss_func
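# A quick, hedged sanity check of the pos_weight math above (illustrative
# only, never executed by the training run): with 4 rows where class 0 appears
# once and class 1 appears three times, the weights come out to
# (4 - 1) / 1 = 3.0 and (4 - 3) / 3 ≈ 0.33, so rarer positives contribute more
# to the loss. The toy dataset here is hypothetical.
#
#   from datasets import Dataset
#   toy = Dataset.from_dict({'labels_ids': [[0], [1], [1], [1]]})
#   compute_cats_weights(toy, num_classes=2)  # tensor([3.0000, 0.3333])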
if __name__ == '__main__':
    dataset_path = Path('dataset.json')
    model_name = 'oracat/bert-paper-classifier-arxiv'
    output_model_path = Path(f'./checkpoints/{model_name.split("/")[-1]}/checkpoints')
    categories_column = 'categories'

    result = get_dataset(dataset_path, model_name, categories_column)
    num_cats = len(result['cat2ids'])
    print(f'num_cats: {num_cats}')

    dataset = result['dataset'].train_test_split(test_size=0.1, seed=123)
    train_dataset, test_dataset = dataset['train'], dataset['test']

    # Sanity check: every class should appear in both splits.
    for split_name, split in (('train', train_dataset), ('test', test_dataset)):
        cnt = Counter(label for labels in split['labels_ids'] for label in labels)
        print(f'num classes in {split_name}: {len(cnt)}')

    # Swap the checkpoint's classification head for one sized to our label
    # set; ignore_mismatched_sizes lets the head be re-initialized.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_cats,
        problem_type='multi_label_classification',
        ignore_mismatched_sizes=True,
    )

    pos_weights = compute_cats_weights(train_dataset, num_cats)
    assert pos_weights.shape[0] == num_cats

    # LoRA adapters on the backbone; the freshly initialized classifier head
    # is trained in full via modules_to_save.
    lora_config = LoraConfig(
        modules_to_save=['classifier'],
        use_rslora=True,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        task_type=TaskType.SEQ_CLS,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        output_dir=str(output_model_path),
        logging_dir=str(output_model_path / 'runs'),
        num_train_epochs=20,
        learning_rate=5e-4,
        lr_scheduler_type='cosine',
        warmup_steps=10,
        optim='adamw_torch_fused',
        weight_decay=0.001,
        gradient_accumulation_steps=4,
        bf16=True,
        logging_strategy='epoch',
        eval_strategy='epoch',
        load_best_model_at_end=True,
        save_only_model=False,
        save_total_limit=2,
        save_strategy='epoch',
        disable_tqdm=False,
        # Keep the custom title/full columns; the default column pruning would
        # drop everything BatchCollator needs.
        remove_unused_columns=False,
        seed=42,
        dataloader_num_workers=4,
    )

    collator = BatchCollator(num_classes=num_cats, p_only_title=0.4)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collator,
        compute_loss_func=make_loss_func(pos_weights),
    )
    trainer.train()
    model.save_pretrained(output_model_path / 'final_model')
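# --- Inference sketch (hedged, illustrative; not part of the training run) ---
# One possible way to reload the adapter saved above for prediction. The 0.5
# sigmoid threshold and the example title are arbitrary choices, not from this
# script; the base model must be rebuilt with the same head shape before
# attaching the adapter.
#
#   from peft import PeftModel
#   from transformers import AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained(model_name)
#   base = AutoModelForSequenceClassification.from_pretrained(
#       model_name,
#       num_labels=num_cats,
#       problem_type='multi_label_classification',
#       ignore_mismatched_sizes=True,
#   )
#   clf = PeftModel.from_pretrained(base, str(output_model_path / 'final_model'))
#   clf.eval()
#   with torch.no_grad():
#       enc = tok('Attention Is All You Need', return_tensors='pt', truncation=True)
#       probs = torch.sigmoid(clf(**enc).logits)[0]
#   predicted = [result['ids2cat'][i]
#                for i in (probs > 0.5).nonzero().flatten().tolist()]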