import json
from collections import Counter, defaultdict
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import torch
from datasets import load_from_disk
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

from preprocess import dataset_preprocess


class BatchCollator:
    """Collate tokenized examples into a batch, randomly choosing the
    title-only encoding or the full-text encoding for each example, and
    building a multi-hot label matrix for multi-label classification."""

    def __init__(
        self,
        num_classes: int,
        p_only_title: float = 0.4,
    ):
        self.num_classes = num_classes
        self.p_only_title = p_only_title

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch_size = len(features)

        # For each example, decide whether to train on the title alone.
        use_only_title = np.random.rand(batch_size) < self.p_only_title

        input_ids = []
        attention_mask = []

        required_keys = [
            'title_input_ids',
            'title_attention_mask',
            'full_input_ids',
            'full_attention_mask',
            'labels_ids',
        ]

        for i, feat in enumerate(features):
            if not all(k in feat for k in required_keys):
                raise KeyError(f'{required_keys} must be in dataset elements!')

            if use_only_title[i]:
                input_ids.append(feat["title_input_ids"])
                attention_mask.append(feat["title_attention_mask"])
            else:
                input_ids.append(feat["full_input_ids"])
                attention_mask.append(feat["full_attention_mask"])

        # torch.tensor() requires rectangular lists, so both encodings must
        # already be padded/truncated to a fixed length during preprocessing.
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        # Multi-hot encode the labels: labels[i, c] == 1.0 iff example i
        # belongs to class c.
        labels = torch.zeros((batch_size, self.num_classes), dtype=torch.float32)
        for i, el in enumerate(features):
            for label_id in el["labels_ids"]:
                labels[i, label_id] = 1.0

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
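
# Illustrative example (hypothetical values) of one dataset element the
# collator expects; both encodings share one fixed padded length:
#
#   {
#       "title_input_ids":      [101, 2054, ..., 0, 0],
#       "title_attention_mask": [1, 1, ..., 0, 0],
#       "full_input_ids":       [101, 2054, ..., 102],
#       "full_attention_mask":  [1, 1, ..., 1],
#       "labels_ids":           [3, 17],
#   }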


def get_dataset(
    dataset_json_path: str | PathLike[str],
    model_name: str,
    categories_column: str
) -> Dict[str, Any]:
    """Load the tokenized dataset and label mappings from the ./data cache,
    or build them from the raw JSON via dataset_preprocess() on first run."""
    save_path = Path('./data')

    if (save_path / 'tokenized_dataset').exists():
        dataset = load_from_disk(str(save_path / 'tokenized_dataset'))

        # JSON keys always round-trip as strings; restore the integer ids.
        with open(save_path / 'cat2ids.json', 'r') as f:
            cat2ids = json.load(f)
        cat2ids = {k: int(v) for k, v in cat2ids.items()}

        with open(save_path / 'ids2cat.json', 'r') as f:
            ids2cat = json.load(f)
        ids2cat = {int(k): v for k, v in ids2cat.items()}

        return {
            "dataset": dataset,
            "cat2ids": cat2ids,
            "ids2cat": ids2cat,
        }

    result = dataset_preprocess(
        dataset_json_path,
        model_name,
        categories_column
    )
    save_data(result, './data')

    return result
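
# Note: the cache above is keyed only by the ./data path, not by its inputs;
# delete ./data/tokenized_dataset (and the two JSON files) to force
# re-tokenization after changing model_name, the source JSON, or
# categories_column.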


def save_data(result: Dict[str, Any], save_path: str | PathLike[str]):
    """Persist the tokenized dataset and both label mappings, skipping any
    artifact that already exists on disk."""
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    if not (save_path / 'tokenized_dataset').exists():
        dataset = result["dataset"]
        dataset.save_to_disk(save_path / 'tokenized_dataset')

    if not (save_path / 'cat2ids.json').exists():
        with open(save_path / 'cat2ids.json', "w", encoding="utf-8") as f:
            json.dump(result['cat2ids'], f, ensure_ascii=False, indent=4)

    if not (save_path / 'ids2cat.json').exists():
        with open(save_path / 'ids2cat.json', "w", encoding="utf-8") as f:
            json.dump(result['ids2cat'], f, ensure_ascii=False, indent=4)


def compute_cats_weights(dataset, num_classes: int) -> torch.Tensor:
    """Compute per-class pos_weight values for a BCE-with-logits loss:
    (negative count / positive count) per class, 1.0 for unseen classes."""
    obj_per_cat = defaultdict(int)

    for labels in dataset["labels_ids"]:
        for label in labels:
            obj_per_cat[label] += 1

    num_obj = len(dataset)
    pos_weights = []

    for class_id in range(num_classes):
        pos_count = obj_per_cat[class_id]

        if pos_count == 0:
            # Class absent from this split; fall back to a neutral weight.
            pos_weights.append(1.0)
        else:
            pos_weights.append((num_obj - pos_count) / pos_count)

    return torch.tensor(pos_weights, dtype=torch.float32)
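
# Worked example (hypothetical numbers): with 10_000 training examples of
# which 250 are positive for some class, its pos_weight is
# (10_000 - 250) / 250 = 39.0, so each positive example counts 39x in the
# loss, offsetting that class's 39:1 negative/positive imbalance.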


def compute_loss_func(outputs, labels, num_items_in_batch=None):
    """Weighted multi-label loss passed to the Trainer. Relies on the
    module-level `pos_weights` tensor assigned in the __main__ block below
    before trainer.train() is called."""
    logits = outputs.logits
    loss = nn.functional.binary_cross_entropy_with_logits(
        input=logits,
        target=labels,
        pos_weight=pos_weights.to(logits.device)
    )
    return loss
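
# A closure would make the pos_weights dependency explicit instead of going
# through module scope; a minimal sketch (not used below):
#
#   def make_loss_func(pos_weights: torch.Tensor):
#       def loss_func(outputs, labels, num_items_in_batch=None):
#           return nn.functional.binary_cross_entropy_with_logits(
#               input=outputs.logits,
#               target=labels,
#               pos_weight=pos_weights.to(outputs.logits.device),
#           )
#       return loss_func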


if __name__ == '__main__':
    dataset_path = Path('dataset.json')
    model_name = 'oracat/bert-paper-classifier-arxiv'
    output_model_path = Path(f'./checkpoints/{model_name.split("/")[-1]}/checkpoints')

    categories_column = 'categories'

    result = get_dataset(dataset_path, model_name, categories_column)

    num_cats = len(result['cat2ids'])
    print(f'num_cats: {num_cats}')

    dataset = result['dataset']
    dataset = dataset.train_test_split(test_size=0.1, seed=123)
    train_dataset, test_dataset = dataset['train'], dataset['test']

    # Sanity check: after a random split, rare labels can be missing from one
    # split, so report how many distinct classes each split actually contains.
    data = []
    for x in train_dataset['labels_ids']:
        data.extend(x)

    cnt = Counter(data)
    print(f'num classes in train: {len(cnt.keys())}')

    data = []
    for x in test_dataset['labels_ids']:
        data.extend(x)

    cnt = Counter(data)
    print(f'num classes in test: {len(cnt.keys())}')

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_cats,
        problem_type='multi_label_classification',
        # The checkpoint's classifier head was trained for a different label
        # count; this lets a freshly initialized head of size num_cats replace it.
        ignore_mismatched_sizes=True
    )

    pos_weights = compute_cats_weights(train_dataset, num_cats)
    assert pos_weights.shape[0] == num_cats

    lora_config = LoraConfig(
        # Train the new classifier head fully alongside the LoRA adapters.
        # With target_modules unset, PEFT falls back to its built-in defaults
        # for the base architecture (the attention query/value projections
        # for BERT-style models).
        modules_to_save=['classifier'],
        use_rslora=True,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        task_type=TaskType.SEQ_CLS
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,

        output_dir=str(output_model_path),
        logging_dir=str(output_model_path / "runs"),

        num_train_epochs=20,
        learning_rate=5e-4,
        lr_scheduler_type='cosine',
        warmup_steps=10,
        optim='adamw_torch_fused',
        weight_decay=0.001,

        # Effective batch size per device: 64 * 4 = 256.
        gradient_accumulation_steps=4,
        bf16=True,

        logging_strategy='epoch',
        eval_strategy='epoch',

        load_best_model_at_end=True,
        save_only_model=False,
        save_total_limit=2,
        save_strategy='epoch',
        disable_tqdm=False,
        # Keep the custom title/full columns; the Trainer's default column
        # pruning would drop everything the collator needs.
        remove_unused_columns=False,
        seed=42,
        dataloader_num_workers=4
    )

    collator = BatchCollator(
        num_classes=num_cats,
        p_only_title=0.4
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collator,
        compute_loss_func=compute_loss_func,
    )

    trainer.train()

    model.save_pretrained(output_model_path / 'final_model')
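
# Usage sketch (assuming this file is saved as train.py, with dataset.json
# and the local preprocess.py module in the working directory):
#
#   python train.py
#
# The first run tokenizes and caches the dataset under ./data; later runs
# reuse the cache. Checkpoints and logs land under
# ./checkpoints/bert-paper-classifier-arxiv/checkpoints.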
|