|
|
| import importlib
|
| import math
|
| import numpy as np
|
| import os
|
| import re
|
| import shutil
|
| import sys
|
| import time
|
| import warnings
|
| from dataclasses import asdict, dataclass, field
|
| from enum import Enum
|
| from packaging import version
|
| from typing import List, Literal, Optional, Union, Dict, Tuple
|
|
|
| import peft
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from accelerate import dispatch_model, infer_auto_device_map
|
| from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
|
| from accelerate.utils import get_balanced_memory
|
| from huggingface_hub import hf_hub_download
|
| from peft import LoftQConfig, LoraRuntimeConfig, PeftConfig, PeftModelForCausalLM, PeftType, set_peft_model_state_dict
|
| from peft.utils import transpose
|
| from torch.utils.data import RandomSampler
|
| from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
|
| from transformers.pytorch_utils import Conv1D
|
| from transformers.trainer import Trainer, TrainerState, TrainOutput, has_length, is_sagemaker_mp_enabled, get_model_param_count, speed_metrics, deepspeed_init, TRAINER_STATE_NAME
|
| from transformers.trainer_callback import ExportableState
|
| from transformers.trainer_pt_utils import IterableDatasetShard
|
| from transformers.utils import logging, is_torch_xla_available, is_apex_available
|
|
|
|
|
| if is_apex_available():
|
| from apex import amp
|
|
|
| def is_bnb_available():
|
| return importlib.util.find_spec("bitsandbytes") is not None
|
|
|
| if is_bnb_available():
|
| import bitsandbytes as bnb
|
| from bitsandbytes.nn.modules import Int8Params
|
|
|
| parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
| is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
| logger = logging.get_logger(__name__)
|
|
|
| WEIGHTS_NAME = "adapter_model.bin"
|
| CONFIG_NAME = "adapter_config.json"
|
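| # NUM_ATTENTION_HEADS and the module names in `pruning_groups` assume a LLaMA-7B-style
|
| # decoder (32 attention heads, q/k/v/o attention projections, gate/up/down MLP projections).
|
| # Modules in 'self_attn' and 'mlp' are pruned along their output dimension; the 'block'
|
| # entries (o_proj, down_proj) are later pruned along their input dimension so the residual
|
| # stream width stays unchanged. Adjust these constants for other architectures.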
| NUM_ATTENTION_HEADS = 32
|
| pruning_groups = {'self_attn': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
|
| 'mlp': ['up_proj', 'gate_proj'],
|
| 'block': ['o_proj', 'down_proj']}
|
|
|
|
|
| @dataclass
|
| class LoraConfig(PeftConfig):
|
| """Configuration class to store the configuration of a LoRA model."""
|
| r: int = field(default=8, metadata={"help": "Lora attention dimension"})
|
| target_modules: Optional[Union[List[str], str]] = field(
|
| default=None,
|
| metadata={"help": "List of module names or regex expression of the module names to replace with Lora."}
|
| )
|
| lora_alpha: int = field(default=8, metadata={"help": "Lora alpha (scaling numerator)"})
|
| lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout probability"})
|
| merge_weights: bool = field(default=False, metadata={"help": "Merge weights of the original model and the Lora model"})
|
| fan_in_fan_out: bool = field(default=False, metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"})
|
| enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
|
| bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
|
| lora_bias: str = field(default="none", metadata={"help": "Bias training mode used by this implementation's `mark_only_lora_as_trainable`. Can be 'none', 'all' or 'lora_only'"})
|
| modules_to_save: Optional[List[str]] = field(default=None, metadata={"help": "List of modules apart from LoRA layers to be set as trainable and saved."})
|
| use_rslora: bool = field(default=False, metadata={"help": "Use Rank-Stabilized LoRA."})
|
| init_lora_weights: Union[bool, Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]] = field(
|
| default=True, metadata={"help": "How to initialize the weights of the LoRA layers."}
|
| )
|
| layers_to_transform: Optional[Union[List[int], int]] = field(default=None, metadata={"help": "The layer indexes to transform."})
|
| layer_replication: Optional[List[Tuple[int, int]]] = field(default=None, metadata={"help": "Layer replication configuration."})
|
| rank_pattern: Optional[Dict] = field(default_factory=dict, metadata={"help": "Mapping from layer names to ranks."})
|
| alpha_pattern: Optional[Dict] = field(default_factory=dict, metadata={"help": "Mapping from layer names to alphas."})
|
| megatron_config: Optional[Dict] = field(default=None, metadata={"help": "TransformerConfig from Megatron."})
|
| megatron_core: Optional[str] = field(default="megatron.core", metadata={"help": "Core module from Megatron."})
|
| loftq_config: Union[LoftQConfig, Dict] = field(default_factory=dict, metadata={"help": "Configuration of LoftQ."})
|
| use_dora: bool = field(default=False, metadata={"help": "Enable Weight-Decomposed Low-Rank Adaptation (DoRA)."})
|
| runtime_config: LoraRuntimeConfig = field(default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"})
|
| _custom_modules: Optional[Dict] = field(default=None, metadata={"help": "Custom module mapping for LoRA."})
|
|
|
| def __post_init__(self):
|
| self.peft_type = PeftType.LORA
|
| self.target_modules = set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
|
| if isinstance(self.target_modules, str) and (self.layers_to_transform is not None or hasattr(self, 'layers_pattern') and self.layers_pattern is not None):
|
| raise ValueError("`layers_to_transform` or `layers_pattern` cannot be used when `target_modules` is a str.")
|
| if self.use_dora and self.megatron_config:
|
| raise ValueError("DoRA does not support megatron_core, set `use_dora=False`.")
|
| if self.init_lora_weights == "loftq":
|
| if not importlib.util.find_spec("scipy"):
|
| raise ImportError("The required package 'scipy' is not installed.")
|
| if not self.loftq_config:
|
| raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
|
| if self.use_rslora and (self.rank_pattern or self.alpha_pattern) and (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa") or self.init_lora_weights == "olora")):
|
| warnings.warn("Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion may conflict.")
|
| if self.loftq_config and not isinstance(self.loftq_config, dict):
|
| self.loftq_config = vars(self.loftq_config)
|
|
|
| def to_dict(self):
|
| """Convert the config to a dictionary for compatibility with peft."""
|
| rv = asdict(self)
|
| rv.pop("runtime_config", None)
|
| rv["peft_type"] = self.peft_type.value
|
| return rv
|
|
|
| def update(self, updates):
|
| """Add an update method to allow merging with a dictionary."""
|
| for key, value in updates.items():
|
| if hasattr(self, key):
|
| setattr(self, key, value)
|
| else:
|
| raise KeyError(f"Cannot update {key}: not a valid attribute of LoraConfig")
|
|
|
|
|
| class LoraModel(torch.nn.Module):
|
| def __init__(self, config, model):
|
| super().__init__()
|
| self.peft_config = config
|
| self.model = model
|
| self._find_and_replace()
|
| mark_only_lora_as_trainable(self.model, self.peft_config.lora_bias)
|
| self.forward = self.model.forward
|
|
|
| def _find_and_replace(self):
|
| loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
| if loaded_in_8bit and not is_bnb_available():
|
| raise ImportError("To use Lora with 8-bit quantization, install `bitsandbytes`.")
|
| is_target_modules_in_base_model = False
|
| kwargs = {"r": self.peft_config.r, "lora_alpha": self.peft_config.lora_alpha, "lora_dropout": self.peft_config.lora_dropout,
|
| "fan_in_fan_out": self.peft_config.fan_in_fan_out, "merge_weights": self.peft_config.merge_weights or self.peft_config.inference_mode}
|
| key_list = [key for key, _ in self.model.named_modules()]
|
| for key in key_list:
|
| target_module_found = re.fullmatch(self.peft_config.target_modules, key) if isinstance(self.peft_config.target_modules, str) else \
|
| any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
|
| if target_module_found:
|
| if not is_target_modules_in_base_model:
|
| is_target_modules_in_base_model = True
|
| parent, target, target_name = self._get_submodules(key)
|
| bias = target.bias is not None
|
| if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
| new_kwargs = kwargs.copy()
|
| new_kwargs.update({"has_fp16_weights": target.state.has_fp16_weights, "memory_efficient_backward": target.state.memory_efficient_backward,
|
| "threshold": target.state.threshold, "index": target.index})
|
| new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **new_kwargs) if self.peft_config.enable_lora is None else \
|
| MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **{**new_kwargs, "enable_lora": self.peft_config.enable_lora})
|
| elif isinstance(target, peft.tuners.lora.layer.Linear) and self.peft_config.enable_lora is None:
|
| new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
|
| elif isinstance(target, nn.Linear):
|
| new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
|
| elif self.peft_config.enable_lora is not None:
|
| new_kwargs = kwargs.copy()
|
| new_kwargs.update({"enable_lora": self.peft_config.enable_lora})
|
| in_features, out_features = (target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape) if isinstance(target, Conv1D) else \
|
| (target.in_features, target.out_features)
|
| if new_kwargs["fan_in_fan_out"] and not isinstance(target, Conv1D):
|
| warnings.warn("fan_in_fan_out is set to True but the target module is not a Conv1D. Setting to False.")
|
| new_kwargs["fan_in_fan_out"] = False
|
| new_module = Linear(in_features, out_features, bias=bias, **new_kwargs)
|
| else:
|
| raise ValueError("No valid condition met for creating a new module.")
|
| self._replace_module(parent, target_name, new_module, target)
|
| if not is_target_modules_in_base_model:
|
| raise ValueError(f"Target modules {self.peft_config.target_modules} not found in the base model.")
|
|
|
| def _get_submodules(self, key):
|
| parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
|
| target_name = key.split(".")[-1]
|
| target = self.model.get_submodule(key)
|
| return parent, target, target_name
|
|
|
| def _replace_module(self, parent_module, child_name, new_module, old_module):
|
| setattr(parent_module, child_name, new_module)
|
| new_module.weight = old_module.weight
|
| if old_module.bias is not None:
|
| new_module.bias = old_module.bias
|
| if getattr(old_module, "state", None) is not None:
|
| new_module.state = old_module.state
|
| new_module.to(old_module.weight.device)
|
|
|
| def __getattr__(self, name: str):
|
| try:
|
| return super().__getattr__(name)
|
| except AttributeError:
|
| return getattr(self.model, name)
|
|
|
| @property
|
| def modules_to_save(self):
|
| return None
|
|
|
| def get_peft_config_as_dict(self, inference: bool = False):
|
| config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
|
| if inference:
|
| config["inference_mode"] = True
|
| return config
|
|
|
| def _set_adapter_layers(self, enabled=True):
|
| for module in self.model.modules():
|
| if isinstance(module, LoraLayer):
|
| module.disable_adapters = not enabled
|
|
|
| def enable_adapter_layers(self):
|
| self._set_adapter_layers(enabled=True)
|
|
|
| def disable_adapter_layers(self):
|
| self._set_adapter_layers(enabled=False)
|
|
|
| class LoraLayer:
|
| def __init__(self, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool):
|
| self.r = r
|
| self.lora_alpha = lora_alpha
|
| self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else lambda x: x
|
| self.merged = False
|
| self.merge_weights = merge_weights
|
| self.disable_adapters = False
|
|
|
| class Linear(nn.Linear, LoraLayer):
|
| def __init__(self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0,
|
| fan_in_fan_out: bool = False, merge_weights: bool = True, **kwargs):
|
| nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
| LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
|
| self.fan_in_fan_out = fan_in_fan_out
|
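| # Per-output-channel pruning mask: channels whose entry is 0 are zeroed in forward() and
|
| # later removed physically by prune_fp16_module.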
| self.lora_mask = nn.Parameter(torch.ones(out_features), requires_grad=False)
|
| if r > 0:
|
| self.lora_A = nn.Linear(in_features, r, bias=False)
|
| self.lora_B = nn.Linear(r, out_features, bias=False)
|
| self.scaling = self.lora_alpha / self.r
|
| self.weight.requires_grad = False
|
| self.is_prune = True
|
| self.reset_parameters()
|
| if fan_in_fan_out:
|
| self.weight.data = self.weight.data.T
|
|
|
| def reset_parameters(self):
|
| nn.Linear.reset_parameters(self)
|
| if hasattr(self, "lora_A"):
|
| nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| nn.init.zeros_(self.lora_B.weight)
|
|
|
| def train(self, mode: bool = True):
|
| nn.Linear.train(self, mode)
|
| self.lora_A.train(mode)
|
| self.lora_B.train(mode)
|
| if not mode and self.merge_weights and not self.merged:
|
| if self.r > 0:
|
| self.weight.data += transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
| self.merged = True
|
| elif self.merge_weights and self.merged:
|
| if self.r > 0:
|
| self.weight.data -= transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
| self.merged = False
|
|
|
| def eval(self):
|
| nn.Linear.eval(self)
|
| self.lora_A.eval()
|
| self.lora_B.eval()
|
|
|
| def forward(self, x: torch.Tensor):
|
| if self.disable_adapters:
|
| if self.r > 0 and self.merged:
|
| self.weight.data -= transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
| self.merged = False
|
| result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
| elif self.r > 0 and not self.merged:
|
| result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
| lora_output = self.lora_B(self.lora_A(self.lora_dropout(x).to(self.lora_A.weight.dtype))) * self.scaling
|
| result += lora_output.to(result.dtype)
|
| else:
|
| result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
| if hasattr(self, 'lora_mask'):
|
| result *= self.lora_mask.reshape(1, 1, -1)
|
| return result
|
|
|
| if is_bnb_available():
|
| class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
|
| def __init__(self, in_features, out_features, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, **kwargs):
|
| bnb.nn.Linear8bitLt.__init__(self, in_features, out_features, bias=kwargs.get("bias", True),
|
| has_fp16_weights=kwargs.get("has_fp16_weights", True),
|
| memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
|
| threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None))
|
| LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
| if r > 0:
|
| self.lora_A = nn.Linear(in_features, r, bias=False)
|
| self.lora_B = nn.Linear(r, out_features, bias=False)
|
| self.scaling = self.lora_alpha / self.r
|
| self.weight.requires_grad = False
|
| self.lora_mask = nn.Parameter(torch.ones(out_features), requires_grad=False)
|
| self.is_prune = True
|
| self.reset_parameters()
|
|
|
| def reset_parameters(self):
|
| if hasattr(self, "lora_A"):
|
| nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| nn.init.zeros_(self.lora_B.weight)
|
|
|
| def forward(self, x: torch.Tensor):
|
| result = super().forward(x)
|
| if self.disable_adapters:
|
| return result
|
| elif self.r > 0:
|
| if not torch.is_autocast_enabled():
|
| expected_dtype = result.dtype
|
| if x.dtype != torch.float32:
|
| x = x.float()
|
| output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
|
| result += output
|
| else:
|
| output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
|
| result += output
|
| if hasattr(self, 'lora_mask'):
|
| result *= self.lora_mask.reshape(1, 1, -1)
|
| return result
|
|
|
| def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
|
| for n, p in model.named_parameters():
|
| if "lora_" not in n:
|
| p.requires_grad = False
|
| if bias == "none":
|
| return
|
| elif bias == "all":
|
| for n, p in model.named_parameters():
|
| if "bias" in n:
|
| p.requires_grad = True
|
| elif bias == "lora_only":
|
| for m in model.modules():
|
| if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
|
| m.bias.requires_grad = True
|
| else:
|
| raise NotImplementedError
|
|
|
|
|
| class LoraPeftModelForCausalLM(PeftModelForCausalLM):
|
| def __init__(self, model, peft_config):
|
| super().__init__(model, peft_config)
|
| self.base_model = LoraModel(peft_config, model)
|
|
|
| @classmethod
|
| def from_pretrained(cls, model, model_id, **kwargs):
|
| from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
|
| config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
|
| model = LoraPeftModelForCausalLM(model, config)
|
| if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
|
| print("loading .................")
|
| filename = os.path.join(model_id, WEIGHTS_NAME)
|
| else:
|
| try:
|
| filename = hf_hub_download(model_id, WEIGHTS_NAME)
|
| except Exception as exc:
|
| raise ValueError(f"Can't find weights for {model_id} locally or on the Hugging Face Hub.") from exc
|
| adapters_weights = torch.load(filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
| set_peft_model_state_dict(model, adapters_weights)  # loads weights in place; recent peft versions return a load report rather than the model
|
| if getattr(model, "hf_device_map", None) is not None:
|
| device_map = kwargs.get("device_map", "auto")
|
| max_memory = kwargs.get("max_memory", None)
|
| no_split_module_classes = model._no_split_modules
|
| if device_map != "sequential":
|
| max_memory = get_balanced_memory(model, max_memory=max_memory, no_split_module_classes=no_split_module_classes,
|
| low_zero=(device_map == "balanced_low_0"))
|
| if isinstance(device_map, str):
|
| device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_module_classes)
|
| model = dispatch_model(model, device_map=device_map)
|
| hook = AlignDevicesHook(io_same_device=True)
|
| if model.peft_config.peft_type == PeftType.LORA:
|
| add_hook_to_module(model.base_model.model, hook)
|
| else:
|
| remove_hook_from_submodules(model.prompt_encoder)
|
| add_hook_to_module(model.base_model, hook)
|
| return model
|
|
|
| @property
|
| def active_peft_config(self):
|
| return self.peft_config
|
|
|
| def get_peft_model(model, peft_config):
|
| """Returns a PEFT model object from a model and a config."""
|
| model_config = model.config.to_dict()
|
| peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
|
| return LoraPeftModelForCausalLM(model, peft_config)
|
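| # Illustrative usage sketch (not executed here). The model path and hyper-parameter values
|
| # are placeholders; any LLaMA-style causal LM with the module names from `pruning_groups`
|
| # should work:
|
| #
|
| #   from transformers import AutoModelForCausalLM
|
| #   base = AutoModelForCausalLM.from_pretrained("path/to/llama-7b", torch_dtype=torch.float16)
|
| #   config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM",
|
| #                       target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
|
| #   peft_model = get_peft_model(base, config)
|
| #   print_trainable_parameters(peft_model)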
|
|
|
|
| def _is_target_layer(module):
|
| lora_types = (Linear, Linear8bitLt) if is_bnb_available() else (Linear,)
|
| return isinstance(module, lora_types) and getattr(module, "is_prune", False)
|
|
|
| def unfreeze(model):
|
| for name, module in model.named_modules():
|
| if _is_target_layer(module):
|
| module.weight.requires_grad = True
|
|
|
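| # Exclude the first ~10% of decoder layers and the last layer from pruning. Assumes module
|
| # names of the form "base_model.model.model.layers.<i>...", so the layer index is the fifth
|
| # dot-separated field.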
| def freeze(model):
|
| layers = len(model.model.model.layers)
|
| freeze_layer = int(layers * 0.1)
|
| for name, module in model.named_modules():
|
| if _is_target_layer(module):
|
| layer = int(name.split('.')[4])
|
| if layer < freeze_layer or layer == layers-1:
|
| module.is_prune = False
|
|
|
| def init_sensitivity_dict(model):
|
| sensitivity_record = {}
|
| for name, module in model.named_modules():
|
| if _is_target_layer(module):
|
| module_name = name.split('.')[-1]
|
| if module_name in pruning_groups['self_attn']:
|
| head_dim = module.out_features // NUM_ATTENTION_HEADS
|
| groups = module.out_features // head_dim
|
| else:
|
| groups = module.out_features
|
| name = ".".join(name.split('.')[:-1])
|
| if name in sensitivity_record:
|
| continue
|
| sensitivity_record[name] = module.lora_A.weight.data.new_zeros(groups)
|
| return sensitivity_record
|
|
|
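| # Accumulate the current-step sensitivity of every prunable module into its group, then fold
|
| # it into the running estimate with an exponential moving average (0.9 old / 0.1 new). If the
|
| # fresh scores contain NaNs, the previous estimate is returned unchanged.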
| def update_sensitivity_dict(model, s_dict, pruning_type):
|
| s_all = init_sensitivity_dict(model)
|
| for name, module in model.named_modules():
|
| if _is_target_layer(module):
|
| is_attn = name.split('.')[-1] in pruning_groups['self_attn']
|
| fan_in = name.split('.')[-1] in pruning_groups['block']
|
| s = compute_sensitivity(module, is_attn, pruning_type, fan_in)
|
| name = ".".join(name.split('.')[:-1])
|
| s_all[name] += s
|
| for name, imp in s_all.items():
|
| if torch.isnan(imp.sum()):
|
| return s_dict
|
| for name, imp in s_dict.items():
|
| s_dict[name] = imp * 0.9 + s_all[name] * 0.1
|
| return s_dict
|
|
|
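| # First-order (Taylor-style) importance of each output group: s = |grad * (B @ A * scaling + W)|.
|
| # With prune_metric='lora' the gradient of the merged weight is approximated from the LoRA factor
|
| # gradients alone (grad_B @ A + B @ grad_A - grad_B @ grad_A), so the frozen base weight needs no
|
| # gradient; 'magnitude' drops the gradient term and 'grad' uses the dense weight gradient. For
|
| # 8-bit layers the weight is dequantised with the stored column scales (SCB / 127). Attention
|
| # scores are summed per head, and each module's scores are L2-normalised before aggregation.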
| def compute_sensitivity(layer, is_attn, prune_metric='lora', transpose=False, norm=True):
|
| a = layer.lora_A.weight.data
|
| b = layer.lora_B.weight.data
|
| if prune_metric == 'lora':
|
| grad_a = layer.lora_A.weight.grad
|
| grad_b = layer.lora_B.weight.grad
|
| grad = (grad_b @ a + b @ grad_a - grad_b @ grad_a)
|
| elif prune_metric == 'magnitude':
|
| grad = 1
|
| elif prune_metric == 'grad':
|
| grad = layer.weight.grad
|
| else:
|
| raise NotImplementedError
|
| if hasattr(layer, 'state'):
|
| weight = (layer.weight.data * layer.state.SCB.reshape(-1, 1)) / 127
|
| else:
|
| weight = layer.weight.data
|
| s = (grad * (b @ a * layer.scaling + weight)).abs()
|
| if transpose:
|
| s = s.t()
|
| if is_attn:
|
| head_dim = layer.out_features // NUM_ATTENTION_HEADS
|
| s = s.reshape(s.shape[0] // head_dim, -1)
|
| s = s.sum(1)
|
| if norm:
|
| s = s / (torch.linalg.norm(s) + 1e-8)
|
| return s
|
|
|
| def prune_fp16_module(module, mask, transpose):
|
| mask = mask.bool()
|
| module.train()
|
| if not transpose:
|
| module.weight.data = module.weight.data[mask]
|
| module.out_features = int(mask.sum())
|
| if module.bias is not None:
|
| module.bias.data = module.bias.data[mask]
|
| module.lora_B.weight.data = module.lora_B.weight.data[mask]
|
| module.lora_B.out_features = int(mask.sum())
|
| else:
|
| module.weight.data = module.weight.data[:, mask]
|
| module.in_features = int(mask.sum())
|
| module.lora_A.weight.data = module.lora_A.weight.data[:, mask]
|
| module.lora_A.in_features = int(mask.sum())
|
| module.merge_weights = True
|
| module.train(False)
|
|
|
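| # Physically remove masked channels from one decoder layer: q/k/v are pruned along their output
|
| # dimension with their own masks, while o_proj reuses q_proj's mask along its input dimension
|
| # (and down_proj reuses gate_proj's mask) so the matmul shapes stay consistent. The divisor 128
|
| # is the per-head dimension of LLaMA-7B (4096 hidden size / 32 heads); change it for other models.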
| def prune_one_layer(layer):
|
| prune_fp16_module(layer.self_attn.q_proj, layer.self_attn.q_proj.lora_mask, False)
|
| prune_fp16_module(layer.self_attn.k_proj, layer.self_attn.k_proj.lora_mask, False)
|
| prune_fp16_module(layer.self_attn.v_proj, layer.self_attn.v_proj.lora_mask, False)
|
| prune_fp16_module(layer.self_attn.o_proj, layer.self_attn.q_proj.lora_mask, True)
|
| layer.self_attn.num_heads = int(layer.self_attn.q_proj.lora_mask.sum()) // 128
|
| layer.self_attn.hidden_size = int(layer.self_attn.q_proj.lora_mask.sum())
|
| prune_fp16_module(layer.mlp.gate_proj, layer.mlp.gate_proj.lora_mask, False)
|
| prune_fp16_module(layer.mlp.up_proj, layer.mlp.up_proj.lora_mask, False)
|
| prune_fp16_module(layer.mlp.down_proj, layer.mlp.gate_proj.lora_mask, True)
|
| del(layer.self_attn.q_proj.lora_mask)
|
| del(layer.self_attn.k_proj.lora_mask)
|
| del(layer.self_attn.v_proj.lora_mask)
|
| del(layer.mlp.gate_proj.lora_mask)
|
| del(layer.mlp.up_proj.lora_mask)
|
|
|
| def prune(model):
|
| for layer_id, layer in enumerate(model.model.model.layers):
|
| print(f"pruning layer {layer_id}")
|
| prune_one_layer(layer)
|
|
|
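| # One pruning step: for every prunable module (the 'block' projections simply follow their
|
| # group's mask), rebuild lora_mask so that the `ratio` least sensitive output groups are zeroed,
|
| # skipping modules that already reached `target_ratio`. Attention modules are masked per head
|
| # rather than per channel.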
| def local_prune(model, s_dict, ratio, target_ratio):
|
| original_param_num = pruned_param_num = 0
|
| for name, module in model.named_modules():
|
| if _is_target_layer(module):
|
| original_param_num += np.prod(module.weight.shape)
|
| pruned_param_num += np.prod(module.weight.shape) * ratio
|
| is_attn = name.split('.')[-1] in pruning_groups['self_attn']
|
| if name.split('.')[-1] in pruning_groups['block']:
|
| continue
|
| name = ".".join(name.split('.')[:-1])
|
| if not hasattr(module, 'lora_mask') or (1-module.lora_mask.mean()).item() >= target_ratio:
|
| continue
|
| total_num = module.lora_mask.numel()
|
| c_mask = module.lora_mask.data
|
| mask = torch.ones_like(c_mask)
|
| if is_attn:
|
| head_dim = module.out_features // NUM_ATTENTION_HEADS
|
| mask = mask.reshape(-1, head_dim)[:, 0]
|
| c_mask = c_mask.reshape(-1, head_dim)[:, 0]
|
| total_num /= head_dim
|
| need_prune_num = int(total_num * ratio)
|
| importance = s_dict[name] * c_mask
|
| can_prune = torch.argsort(importance)[:need_prune_num]
|
| mask[can_prune] = 0
|
| if is_attn:
|
| mask = (mask.new_ones(module.lora_mask.shape).reshape(-1, head_dim) * mask.unsqueeze(1)).reshape(-1)
|
| module.lora_mask.data = mask
|
| else:
|
| if hasattr(module, 'weight'):
|
| original_param_num += np.prod(module.weight.shape)
|
| print(f"pruned/original parameters number:{pruned_param_num*1e-9:.3f}/{original_param_num*1e-9:.3f} ratio:{pruned_param_num/original_param_num:.3f}")
|
|
|
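| # Cubic sparsity schedule: hold `initial_sparsity` for the first `initial_warmup` fraction of
|
| # training, hold `final_sparsity` for the last `final_warmup` fraction, and interpolate with a
|
| # cubic curve in between. For example, with total_step=1000, initial_warmup=0.1, final_warmup=0.1,
|
| # initial_sparsity=0.0 and final_sparsity=0.5: step 100 -> 0.0, step 500 -> about 0.44, step 900 -> 0.5.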
| def schedule_sparsity_ratio(step, total_step, initial_warmup, final_warmup, initial_sparsity, final_sparsity):
|
| if step <= initial_warmup * total_step:
|
| sparsity = initial_sparsity
|
| elif step > (total_step - final_warmup * total_step):
|
| sparsity = final_sparsity
|
| else:
|
| spars_warmup_steps = initial_warmup * total_step
|
| spars_schedu_steps = (final_warmup + initial_warmup) * total_step
|
| mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
| sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (mul_coeff ** 3)
|
| return sparsity
|
|
|
| def prune_from_checkpoint(model):
|
| prune(model)
|
|
|
| def print_trainable_parameters(model):
|
| total_params = trainable_params = 0
|
| for n, p in model.named_parameters():
|
| if p.requires_grad:
|
| trainable_params += p.numel()
|
| total_params += p.numel()
|
| print(f"total params:{total_params * 1e-6:.2f}M trainable params:{trainable_params * 1e-6:.2f}M ratio:{trainable_params / total_params:.3f}")
|
|
|
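| # Trainer that interleaves LoRA fine-tuning with iterative structured pruning: after each optimizer
|
| # step it refreshes the sensitivity estimates from the LoRA gradients, and every `prune_freq` steps
|
| # it re-derives the channel/head masks, ramping sparsity from `init_ratio` to `ratio` according to
|
| # schedule_sparsity_ratio with `warmup_iters`/`cooldown_iters` as the warmup and cooldown fractions.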
| class LoRAPruneTrainer(Trainer):
|
| def __init__(self, model, train_dataset, eval_dataset, args, data_collator, ratio, init_ratio, warmup_iters,
|
| cooldown_iters, prune_freq, prune_metric, **kwargs):
|
| super().__init__(
|
| model=model,
|
| train_dataset=train_dataset,
|
| eval_dataset=eval_dataset,
|
| args=args,
|
| data_collator=data_collator,
|
| **kwargs
|
| )
|
| self.ratio = ratio
|
| self.init_ratio = init_ratio
|
| self.warmup_iters = warmup_iters
|
| self.cooldown_iters = cooldown_iters
|
| self.prune_freq = prune_freq
|
| self.prune_metric = prune_metric
|
|
|
| def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers releases
|
|
|
| labels = inputs.get("labels")
|
| outputs = model(**inputs)
|
| logits = outputs.logits
|
|
|
|
|
| loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
|
| shift_logits = logits[..., :-1, :].contiguous()
|
| shift_labels = labels[..., 1:].contiguous()
|
| loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
|
| return (loss, outputs) if return_outputs else loss
|
|
|
| def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
|
| self._train_batch_size = batch_size
|
| train_dataloader = self.get_train_dataloader()
|
| total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
|
|
|
| len_dataloader = len(train_dataloader) if has_length(train_dataloader) else None
|
| if len_dataloader:
|
| num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
|
| num_examples = self.num_examples(train_dataloader)
|
| if args.max_steps > 0:
|
| max_steps = args.max_steps
|
| num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(args.max_steps % num_update_steps_per_epoch > 0)
|
| num_train_samples = args.max_steps * total_train_batch_size
|
| else:
|
| max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
| num_train_epochs = math.ceil(args.num_train_epochs)
|
| num_train_samples = num_examples * args.num_train_epochs
|
| elif args.max_steps > 0:
|
| max_steps = args.max_steps
|
| num_train_epochs = sys.maxsize
|
| num_update_steps_per_epoch = max_steps
|
| num_examples = num_train_samples = args.max_steps * total_train_batch_size
|
| else:
|
| raise ValueError(f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}")
|
|
|
| if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug and self.args.n_gpu > 1:
|
| raise ValueError("Currently --debug underflow_overflow is not supported under DP. Please use DDP.")
|
| else:
|
| debug_overflow = DebugUnderflowOverflow(self.model) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug else None
|
|
|
| delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
| if args.deepspeed:
|
| deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint)
|
| self.model, self.model_wrapped, self.deepspeed, self.optimizer, self.lr_scheduler = deepspeed_engine.module, deepspeed_engine, deepspeed_engine, optimizer, lr_scheduler
|
| elif not delay_optimizer_creation:
|
| self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
|
|
| self.state = TrainerState(stateful_callbacks=[cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)])
|
| self.state.is_hyper_param_search = trial is not None
|
| if args.gradient_checkpointing:
|
| self.model.gradient_checkpointing_enable()
|
| model = self._wrap_model(self.model_wrapped)
|
| if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
|
| self._load_from_checkpoint(resume_from_checkpoint, model)
|
| if model is not self.model:
|
| self.model_wrapped = model
|
| if delay_optimizer_creation:
|
| self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
| self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
|
|
| total_params = kept_params = sum([p.numel() if not p.requires_grad else 0 for p in model.parameters()])
|
| logger.info("***** Running training *****")
|
| logger.info(f" Num examples = {num_examples:,}")
|
| logger.info(f" Num Epochs = {num_train_epochs:,}")
|
| logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}")
|
| logger.info(f" Total train batch size = {total_train_batch_size:,}")
|
| logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| logger.info(f" Total optimization steps = {max_steps:,}")
|
| logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
|
|
|
| self.state.epoch, start_time, epochs_trained, steps_trained_in_current_epoch = 0, time.time(), 0, 0
|
| if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
|
| self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
| epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
| steps_trained_in_current_epoch = (self.state.global_step % num_update_steps_per_epoch) * args.gradient_accumulation_steps if not args.ignore_data_skip else 0
|
| logger.info(f" Continuing training from epoch {epochs_trained}, global step {self.state.global_step}")
|
|
|
| self.callback_handler.model, self.callback_handler.optimizer, self.callback_handler.lr_scheduler, self.callback_handler.train_dataloader = self.model, self.optimizer, self.lr_scheduler, train_dataloader
|
| if self.hp_name is not None and self._trial is not None:
|
| self.state.trial_name = self.hp_name(self._trial)
|
| self.state.trial_params, self.state.max_steps, self.state.num_train_epochs = None, max_steps, num_train_epochs
|
| self.state.is_local_process_zero, self.state.is_world_process_zero = self.is_local_process_zero(), self.is_world_process_zero()
|
|
|
| tr_loss = torch.tensor(0.0).to(args.device)
|
| self._total_loss_scalar, self._globalstep_last_logged = 0.0, self.state.global_step
|
| model.zero_grad()
|
| self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
|
|
| if not args.ignore_data_skip:
|
| for epoch in range(epochs_trained):
|
| is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler)
|
| if is_torch_less_than_1_11 or not is_random_sampler:
|
| for _ in train_dataloader:
|
| break
|
| else:
|
| _ = list(train_dataloader.sampler)
|
|
|
| total_batched_samples = 0
|
| if self.prune_metric == 'grad':
|
| unfreeze(model)
|
| sensitivity_dict = init_sensitivity_dict(model)
|
|
|
| for epoch in range(epochs_trained, num_train_epochs):
|
| if isinstance(train_dataloader, torch.utils.data.DataLoader) and isinstance(train_dataloader.sampler, torch.utils.data.distributed.DistributedSampler):
|
| train_dataloader.sampler.set_epoch(epoch)
|
| elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
|
| train_dataloader.dataset.set_epoch(epoch)
|
| epoch_iterator = train_dataloader
|
| if args.past_index >= 0:
|
| self._past = None
|
| steps_in_epoch = len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
|
| self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
|
|
| if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
|
| self._load_rng_state(resume_from_checkpoint)
|
|
|
| rng_to_sync, steps_skipped = False, 0
|
| for step, inputs in enumerate(epoch_iterator):
|
| total_batched_samples += 1
|
| if rng_to_sync:
|
| self._load_rng_state(resume_from_checkpoint)
|
| rng_to_sync = False
|
| if steps_trained_in_current_epoch > 0:
|
| steps_trained_in_current_epoch -= 1
|
| continue
|
| if step % args.gradient_accumulation_steps == 0:
|
| self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
|
|
|
| tr_loss_step = self.training_step(model, inputs)
|
| if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
|
| tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
|
| else:
|
| tr_loss += tr_loss_step
|
| self.current_flos += float(self.floating_point_ops(inputs))
|
|
|
| if self.deepspeed:
|
| self.deepspeed.step()
|
| if total_batched_samples % args.gradient_accumulation_steps == 0 or (steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch):
|
| grad_norm = None  # ensure grad_norm is defined even when gradient clipping is skipped
|
| if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
|
| if is_sagemaker_mp_enabled() and args.fp16:
|
| grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
|
| elif hasattr(self.optimizer, "clip_grad_norm"):
|
| grad_norm = self.optimizer.clip_grad_norm(args.max_grad_norm)
|
| elif hasattr(model, "clip_grad_norm_"):
|
| grad_norm = model.clip_grad_norm_(args.max_grad_norm)
|
| else:
|
| grad_norm = nn.utils.clip_grad_norm_(amp.master_params(self.optimizer) if self.use_apex else model.parameters(), args.max_grad_norm)
|
| if not self.deepspeed:
|
| sensitivity_dict = update_sensitivity_dict(model, sensitivity_dict, self.prune_metric)
|
| ratio = schedule_sparsity_ratio(self.state.global_step, self.state.max_steps, self.warmup_iters, self.cooldown_iters, self.init_ratio, self.ratio)
|
| if (self.state.global_step) % self.prune_freq == 0 and ratio > self.init_ratio and ratio < self.ratio:
|
| local_prune(model, sensitivity_dict, ratio, self.ratio)
|
| optimizer_was_run = True
|
| if not self.deepspeed:
|
| self.optimizer.step()
|
| if optimizer_was_run:
|
| self.lr_scheduler.step()
|
| model.zero_grad()
|
| self.state.global_step += 1
|
| self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
| self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
| self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
|
| else:
|
| self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
|
| if self.control.should_epoch_stop or self.control.should_training_stop:
|
| break
|
| if step < 0:
|
| logger.warning(f"No samples in epoch_iterator, stopping at step {self.state.global_step}.")
|
| self.control.should_training_stop = True
|
| self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
| self._maybe_log_save_evaluate(tr_loss, None, model, trial, epoch, ignore_keys_for_eval)
|
| if self.control.should_training_stop:
|
| break
|
|
|
| if args.past_index and hasattr(self, "_past"):
|
| delattr(self, "_past")
|
| logger.info("\n\nTraining completed. Share your model on huggingface.co/models =)\n\n")
|
| if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
| self._load_best_model()
|
|
|
| self._total_loss_scalar += tr_loss.item()
|
| train_loss = self._total_loss_scalar / self.state.global_step
|
| metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
|
| self.store_flos()
|
| metrics["total_flos"], metrics["train_loss"] = self.state.total_flos, train_loss
|
| self.is_in_train = False
|
| self._memory_tracker.stop_and_update_metrics(metrics)
|
| self.log(metrics)
|
|
|
| run_dir = self._get_output_dir(trial)
|
| checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
|
| if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
|
| for checkpoint in checkpoints_sorted:
|
| if checkpoint != self.state.best_model_checkpoint:
|
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
|
| shutil.rmtree(checkpoint)
|
|
|
| self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
| return TrainOutput(self.state.global_step, train_loss, metrics) |