| from dataclasses import dataclass, field |
| from functools import reduce |
| from typing import Callable, Dict, List, Optional, Tuple, Union, Any |
|
|
| import numpy as np |
| import torch |
| import wandb |
| import torch.nn as nn |
| from torch.utils.data import Dataset |
|
|
| from transformers import Trainer, Seq2SeqTrainingArguments |
| from transformers.data.data_collator import DataCollator |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
| from transformers.trainer import ( |
| EvalPrediction, |
| PreTrainedModel, |
| PreTrainedTokenizerBase, |
| TrainerCallback, |
| ) |
| from transformers.trainer_pt_utils import get_parameter_names |
| from transformers.utils import is_sagemaker_mp_enabled, logging |
| from peft.tuners.lora.layer import Linear as LoraLinear |
|
|
| |
# Substrings of parameter / module names whose weights and activations are
# tracked by LogTrainer; a name is tracked if it contains any entry.
include_keywords = [
    "encoder.block.2",
    "encoder.block.3",
    "encoder.block.4",
]

# Global switch read by LogTrainer.training_step: when False, the base
# Trainer.training_step is used and nothing extra is logged.
do_log = False
|
|
|
|
def get_forward_hook(name, log_fn=None):
    """Build a forward hook that logs activation statistics of a module.

    The returned hook logs the mean and standard deviation of the module's
    first positional input and of its output under the keys
    ``"{name}/input_mean"``, ``"{name}/input_std"``, ``"{name}/output_mean"``
    and ``"{name}/output_std"``. It passes ``commit=False``, so the values are
    buffered until the next committing log call.

    Args:
        name: Prefix for the logged metric keys (typically the module path).
        log_fn: Callable invoked as ``log_fn(metrics, commit=False)``.
            Defaults to ``wandb.log``, looked up each time the hook fires.

    Returns:
        A callable with the ``(module, inputs, output)`` signature accepted by
        ``nn.Module.register_forward_hook``.
    """

    def hook(module, inputs, output):
        # Some modules return a tuple; its first element is the main tensor.
        if isinstance(output, (tuple, list)):
            output = output[0]
        # Detach so these statistics do not add nodes to the autograd graph.
        x = inputs[0].detach()
        y = output.detach()
        # Resolve the logger lazily so the hook uses whatever wandb.log is
        # bound to at call time rather than at hook-creation time.
        log = log_fn if log_fn is not None else wandb.log
        log(
            {
                f"{name}/input_mean": x.mean().item(),
                f"{name}/input_std": x.std().item(),
                f"{name}/output_mean": y.mean().item(),
                f"{name}/output_std": y.std().item(),
            },
            commit=False,
        )

    return hook
|
|
class LogTrainer(Trainer):
    """``Trainer`` that can log weight-update diagnostics to Weights & Biases.

    Behaviour is controlled by the module-level flag ``do_log``. When it is
    false, :meth:`training_step` defers to ``Trainer.training_step``.

    When it is true:

    - The first training step snapshots every trainable parameter whose name
      contains one of ``include_keywords``.
    - On the last micro-batch of each gradient-accumulation window (after that
      micro-batch's backward pass), the trainer logs for each tracked weight:
      its norm, its gradient norm, its distance from the snapshot, and how
      concentrated that change is in its top singular values.

    For PEFT models the change is measured on the LoRA product ``B @ A``, and
    forward hooks also log activation statistics of the tracked LoRA layers.
    For other models the change is measured on each tracked 2-D parameter.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
    ):
        """Forward all arguments to ``Trainer`` and initialise logging state."""
        # Keyword arguments keep this call correct even if the positional
        # order of ``Trainer.__init__`` differs between transformers releases.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # Inspect ``self.model`` rather than the ``model`` argument, which is
        # None when the model is built through ``model_init``.
        self.is_peft = "PeftModel" in type(self.model).__name__
        # Scaling of the "default" adapter on the first LoraLinear layer
        # found. It is only used to log a scaled variant of the B@A drift and
        # stays None when no LoraLinear layer exists.
        self.scaling = None
        if self.is_peft:
            first_lora = next(
                (m for m in self.model.modules() if isinstance(m, LoraLinear)),
                None,
            )
            if first_lora is not None:
                self.scaling = first_lora.scaling["default"]
        # Parameter snapshots taken on the first step with logging enabled.
        self.orig_A = None  # LoRA A weights keyed by layer-name prefix
        self.orig_B = None  # LoRA B weights keyed by layer-name prefix
        self.orig_W = None  # full weights keyed by parameter name
        # Number of training_step calls made while logging is enabled.
        self.gradient_accumulation_counter = 0

    @staticmethod
    def _is_tracked(name: str) -> bool:
        """Return True if ``name`` contains any substring in ``include_keywords``."""
        return any(kw in name for kw in include_keywords)

    @staticmethod
    def _norm_or_nan(tensor: Optional[torch.Tensor]) -> float:
        """Return ``torch.norm(tensor)`` as a float, or NaN when ``tensor`` is None.

        Gradients can be None (e.g. for a parameter that took no part in the
        forward pass); logging NaN keeps the metric instead of crashing.
        """
        return float("nan") if tensor is None else torch.norm(tensor).item()

    @staticmethod
    def _spectrum_ratios(singular_values: torch.Tensor) -> Tuple[float, float]:
        """Fraction of the singular-value sum carried by the top 1 and top 4 values.

        Assumes ``singular_values`` is sorted in descending order, as the SVD
        routines used by the callers return it. Both ratios are NaN when every
        singular value is zero, i.e. the weight has not moved since the
        snapshot.
        """
        total = singular_values.sum()
        top_1 = (singular_values[0] / total).item()
        top_4 = (singular_values[:4].sum() / total).item()
        return top_1, top_4

    def _snapshot_lora(self, model: nn.Module) -> None:
        """Copy the tracked LoRA A/B weights and hook the tracked LoRA layers.

        Called once. Registering the hooks here rather than on every step
        gives each tracked layer exactly one logging hook.
        """
        self.orig_A, self.orig_B = {}, {}
        for name, param in model.named_parameters():
            if not (param.requires_grad and self._is_tracked(name)):
                continue
            if "lora_A" in name:
                self.orig_A[name.split("lora_A.")[0]] = param.detach().clone()
            elif "lora_B" in name:
                self.orig_B[name.split("lora_B.")[0]] = param.detach().clone()
        for name, module in model.named_modules():
            if self._is_tracked(name) and isinstance(module, LoraLinear):
                module.register_forward_hook(get_forward_hook(name))

    def _snapshot_full(self, model: nn.Module) -> None:
        """Copy the current values of all tracked trainable parameters."""
        self.orig_W = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad and self._is_tracked(name)
        }

    def _lora_metrics(self, model: nn.Module) -> Dict[str, float]:
        """Collect norm, gradient, drift and spectrum metrics for tracked LoRA layers.

        Raises:
            RuntimeError: if the tracked LoRA A/B factors no longer match the
                layers recorded in the snapshot.
        """
        A_dict, B_dict = {}, {}
        for name, param in model.named_parameters():
            if not (param.requires_grad and self._is_tracked(name)):
                continue
            if "lora_A" in name:
                A_dict[name.split("lora_A.")[0]] = param
            elif "lora_B" in name:
                B_dict[name.split("lora_B.")[0]] = param
        if not (
            A_dict.keys() == B_dict.keys() == self.orig_A.keys() == self.orig_B.keys()
        ):
            raise RuntimeError(
                "Tracked LoRA layers do not match the snapshot: "
                f"{len(A_dict)} A / {len(self.orig_A)} orig A / "
                f"{len(B_dict)} B / {len(self.orig_B)} orig B"
            )

        metrics: Dict[str, float] = {}
        for key, A in A_dict.items():
            B = B_dict[key]
            A_0 = self.orig_A[key]
            B_0 = self.orig_B[key]
            lora_r = A.shape[0]
            A_diff = A - A_0
            B_diff = B - B_0
            BA_diff = torch.matmul(B, A) - torch.matmul(B_0, A_0)
            BA_diff_norm = torch.norm(BA_diff).item()
            # Keep q within the matrix dimensions for small layers.
            q = min(2 * lora_r, *BA_diff.shape)
            singular_values = torch.svd_lowrank(BA_diff.float(), q=q)[1][:lora_r]
            top_1_ratio, top_4_ratio = self._spectrum_ratios(singular_values)

            metrics[f"A_norm/{key}"] = torch.norm(A).item()
            metrics[f"B_norm/{key}"] = torch.norm(B).item()
            metrics[f"A_grad_norm/{key}"] = self._norm_or_nan(A.grad)
            metrics[f"B_grad_norm/{key}"] = self._norm_or_nan(B.grad)
            metrics[f"A_diff_norm/{key}"] = torch.norm(A_diff).item()
            metrics[f"B_diff_norm/{key}"] = torch.norm(B_diff).item()
            metrics[f"BA_diff_norm/{key}"] = BA_diff_norm
            if self.scaling is not None:
                metrics[f"scaled_BA_diff_norm/{key}"] = self.scaling * BA_diff_norm
            metrics[f"BA_top_1_ratio/{key}"] = top_1_ratio
            metrics[f"BA_top_4_ratio/{key}"] = top_4_ratio
        return metrics

    def _full_metrics(self, model: nn.Module) -> Dict[str, float]:
        """Collect norm, gradient, drift and spectrum metrics for tracked 2-D weights."""
        metrics: Dict[str, float] = {}
        for name, W in model.named_parameters():
            if not (W.requires_grad and self._is_tracked(name) and W.dim() == 2):
                continue
            W_diff = W - self.orig_W[name]
            # Only singular values are needed, so skip computing U and V.
            singular_values = torch.linalg.svdvals(W_diff.float())
            top_1_ratio, top_4_ratio = self._spectrum_ratios(singular_values)

            metrics[f"W_norm/{name}"] = torch.norm(W).item()
            metrics[f"W_grad_norm/{name}"] = self._norm_or_nan(W.grad)
            metrics[f"W_diff_norm/{name}"] = torch.norm(W_diff).item()
            metrics[f"W_top_1_ratio/{name}"] = top_1_ratio
            metrics[f"W_top_4_ratio/{name}"] = top_4_ratio
        return metrics

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """Run one forward/backward pass, logging diagnostics when enabled.

        With ``do_log`` false this is exactly ``Trainer.training_step``.
        Otherwise it:

        1. prepares the inputs and computes the loss;
        2. averages the loss across GPUs when ``n_gpu > 1``;
        3. runs the backward pass through the accelerator;
        4. on the last micro-batch of each accumulation window, logs all
           collected metrics in a single wandb commit.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        if not do_log:
            return super().training_step(model, inputs)

        if self.is_peft and self.orig_A is None:
            self._snapshot_lora(model)
        elif not self.is_peft and self.orig_W is None:
            self._snapshot_full(model)

        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # average the per-GPU losses
        self.accelerator.backward(loss)

        accumulation_steps = self.args.gradient_accumulation_steps
        if self.gradient_accumulation_counter % accumulation_steps == accumulation_steps - 1:
            with torch.no_grad():
                if self.is_peft:
                    metrics = self._lora_metrics(model)
                else:
                    metrics = self._full_metrics(model)
            if metrics:
                metrics["train/global_step"] = self.state.global_step
                # One commit per logging step; this also flushes the
                # commit=False activation stats buffered by the forward hooks.
                wandb.log(metrics)
        self.gradient_accumulation_counter += 1

        return loss.detach() / accumulation_steps
|
|