| """ |
| Finetuning functions to do post-distillation |
| """ |
| from os.path import join |
| from omegaconf import OmegaConf |
|
|
| import torch |
| from torch.nn import Module |
|
|
| from src.utils.setup import update_config_from_args |
| from src.dataloaders import load_data |
| from src.trainer import get_trainer, get_optimizer, get_scheduler |
|
|
|
|
| def prepare_finetune_configs(args, model_config: dict, |
| finetune_config_name: str = None, |
| finetune_checkpoint_name: str = None, |
| config_dir='./configs/experiment'): |
| """ |
| Prepare finetuning configs |
| """ |
    # Recover the config name from the checkpoint filename if not given
    # explicitly: '...-f=<config>-...' -> '<config>'
    finetune_config = (finetune_config_name if finetune_config_name is not None else
                       finetune_checkpoint_name.split('-f=')[-1].split('-')[0])
    finetune_config_path = join(config_dir, f'{finetune_config}.yaml')
    finetune_config = OmegaConf.load(finetune_config_path)
    finetune_config = update_config_from_args(finetune_config, args,
                                              ignore_args=['lr', 'weight_decay'])
    # Keep the dataset's pretrained model references (tokenizer, cache)
    # consistent with the model being finetuned
    if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
        for k in ['pretrained_model_name_or_path', 'cache_dir']:
            finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]
    # Mirror trainer hyperparameters onto args so downstream code can read them
    for arg, argv in finetune_config.trainer.items():
        if arg != 'name':
            setattr(args, arg, argv)
    # Attach the remaining sub-configs to args as plain dicts
    for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
        setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
    return finetune_config, args
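

# Usage sketch for recovering a config from a checkpoint name. The filename
# below is hypothetical, but the '-f=' parsing matches the function above
# (it would resolve to './configs/experiment/finetune_lora.yaml'):
#
#   finetune_config, args = prepare_finetune_configs(
#       args, model_config,
#       finetune_checkpoint_name='distill-f=finetune_lora-seed=0.pt')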


def get_finetuner(model: Module, finetune_config: DictConfig, device: torch.device,
                  args: Any, wandb: Any, initial_eval: bool = False):
    """
    Initialize the finetuning trainer with its optimizer, scheduler, and
    dataloaders. `initial_eval` is accepted for interface compatibility
    but is unused here.
    """
    model.to(device)
    model.train()

    # Build the optimizer and learning-rate scheduler from the finetuning config
    optimizer = get_optimizer(model=model, **finetune_config.optimizer)
    scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler)

    # Load dataloaders for the train / eval splits named in the trainer config
    dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader)
    train_loader = dataloaders[finetune_config.trainer.train_split]
    eval_loader = dataloaders[finetune_config.trainer.val_split]

    # Look up the trainer class by name; the remaining trainer config entries
    # are forwarded as keyword arguments
    OurTrainer = get_trainer(finetune_config.trainer.name)
    trainer = OurTrainer(model=model,
                         args=args,
                         train_loader=train_loader,
                         eval_loader=eval_loader,
                         optimizer_and_scheduler=(optimizer, scheduler),
                         device=device,
                         wandb=wandb,
                         checkpoint_suffix='_ft',
                         **finetune_config.trainer)
    return trainer
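

# End-to-end sketch of how these helpers compose (illustrative only; `args`,
# `model_config`, `model`, and `wandb` are assumed to come from the caller's
# distillation setup, and the trainer's `train()` entry point is an assumed
# interface of the class returned by get_trainer):
#
#   finetune_config, args = prepare_finetune_configs(
#       args, model_config, finetune_config_name='finetune_lora')
#   trainer = get_finetuner(model, finetune_config,
#                           device=torch.device('cuda'), args=args, wandb=wandb)
#   model = trainer.train()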