# loraprune/loraprune.py
# Imports and Dependencies
import importlib
import math
import numpy as np
import os
import re
import shutil
import sys
import time
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from packaging import version
from typing import Dict, List, Literal, Optional, Tuple, Union
import peft
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from peft import LoftQConfig, LoraRuntimeConfig, PeftConfig, PeftModelForCausalLM, PeftType, set_peft_model_state_dict
from peft.utils import transpose
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.pytorch_utils import Conv1D
from transformers.trainer import Trainer, TrainerState, TrainOutput, has_length, is_sagemaker_mp_enabled, get_model_param_count, speed_metrics, deepspeed_init, TRAINER_STATE_NAME
from transformers.trainer_callback import ExportableState
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.utils import logging, is_torch_xla_available, is_apex_available
if is_apex_available():
from apex import amp
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
if is_bnb_available():
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Int8Params
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
logger = logging.get_logger(__name__)
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"
NUM_ATTENTION_HEADS = 32
pruning_groups = {'self_attn': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
'mlp': ['up_proj', 'gate_proj'],
'block': ['o_proj', 'down_proj']}
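# Note: NUM_ATTENTION_HEADS = 32 (and the head_dim of 128 hard-coded further down) assume a
# LLaMA-7B-style model. The groups above drive structured pruning: 'self_attn' projections are
# scored and masked head-wise, 'mlp' up/gate projections channel-wise, and the 'block' members
# (o_proj, down_proj) receive no mask of their own; they are later pruned along their input
# dimension using the q_proj and gate_proj masks, respectively.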
# LoRA Configuration
@dataclass
class LoraConfig(PeftConfig):
"""Configuration class to store the configuration of a LoRA model."""
r: int = field(default=8, metadata={"help": "Lora attention dimension"})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={"help": "List of module names or regex expression of the module names to replace with Lora."}
)
lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
merge_weights: bool = field(default=False, metadata={"help": "Merge weights of the original model and the Lora model"})
fan_in_fan_out: bool = field(default=False, metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"})
enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
lora_bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
modules_to_save: Optional[List[str]] = field(default=None, metadata={"help": "List of modules apart from LoRA layers to be set as trainable and saved."})
use_rslora: bool = field(default=False, metadata={"help": "Use Rank-Stabilized LoRA."})
init_lora_weights: Union[bool, Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]] = field(
default=True, metadata={"help": "How to initialize the weights of the LoRA layers."}
)
layers_to_transform: Optional[Union[List[int], int]] = field(default=None, metadata={"help": "The layer indexes to transform."})
layer_replication: Optional[List[Tuple[int, int]]] = field(default=None, metadata={"help": "Layer replication configuration."})
rank_pattern: Optional[Dict] = field(default_factory=dict, metadata={"help": "Mapping from layer names to ranks."})
alpha_pattern: Optional[Dict] = field(default_factory=dict, metadata={"help": "Mapping from layer names to alphas."})
megatron_config: Optional[Dict] = field(default=None, metadata={"help": "TransformerConfig from Megatron."})
megatron_core: Optional[str] = field(default="megatron.core", metadata={"help": "Core module from Megatron."})
loftq_config: Union[LoftQConfig, Dict] = field(default_factory=dict, metadata={"help": "Configuration of LoftQ."})
use_dora: bool = field(default=False, metadata={"help": "Enable Weight-Decomposed Low-Rank Adaptation (DoRA)."})
runtime_config: LoraRuntimeConfig = field(default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"})
_custom_modules: Optional[Dict] = field(default=None, metadata={"help": "Custom module mapping for LoRA."})
def __post_init__(self):
self.peft_type = PeftType.LORA
self.target_modules = set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
if isinstance(self.target_modules, str) and (self.layers_to_transform is not None or hasattr(self, 'layers_pattern') and self.layers_pattern is not None):
raise ValueError("`layers_to_transform` or `layers_pattern` cannot be used when `target_modules` is a str.")
if self.use_dora and self.megatron_config:
raise ValueError("DoRA does not support megatron_core, set `use_dora=False`.")
if self.init_lora_weights == "loftq":
if not importlib.util.find_spec("scipy"):
raise ImportError("The required package 'scipy' is not installed.")
if self.loftq_config is None:
raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
if self.use_rslora and (self.rank_pattern or self.alpha_pattern) and (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa") or self.init_lora_weights == "olora")):
warnings.warn("Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion may conflict.")
if self.loftq_config and not isinstance(self.loftq_config, dict):
self.loftq_config = vars(self.loftq_config)
def to_dict(self):
"""Convert the config to a dictionary for compatibility with peft."""
rv = asdict(self)
rv.pop("runtime_config", None) # Remove runtime_config as in original
rv["peft_type"] = self.peft_type.value # Ensure peft_type is a string
return rv
def update(self, updates):
"""Add an update method to allow merging with a dictionary."""
for key, value in updates.items():
if hasattr(self, key):
setattr(self, key, value)
else:
raise KeyError(f"Cannot update {key}: not a valid attribute of LoraConfig")
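# Illustrative configuration (values are placeholders, not prescribed defaults); lora_bias is
# the field this file's LoraModel consults when marking bias parameters trainable:
#
#     config = LoraConfig(
#         r=8,
#         lora_alpha=16,
#         lora_dropout=0.05,
#         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
#         lora_bias="none",
#     )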
# LoRA Model and Layers
class LoraModel(torch.nn.Module):
def __init__(self, config, model):
super().__init__()
self.peft_config = config
self.model = model
self._find_and_replace()
mark_only_lora_as_trainable(self.model, self.peft_config.lora_bias) # Use lora_bias instead of bias
self.forward = self.model.forward
def _find_and_replace(self):
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
raise ImportError("To use Lora with 8-bit quantization, install `bitsandbytes`.")
is_target_modules_in_base_model = False
kwargs = {"r": self.peft_config.r, "lora_alpha": self.peft_config.lora_alpha, "lora_dropout": self.peft_config.lora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out, "merge_weights": self.peft_config.merge_weights or self.peft_config.inference_mode}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
target_module_found = re.fullmatch(self.peft_config.target_modules, key) if isinstance(self.peft_config.target_modules, str) else \
any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
new_kwargs = kwargs.copy() # Fix: Create a copy of kwargs
new_kwargs.update({"has_fp16_weights": target.state.has_fp16_weights, "memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold, "index": target.index})
                    new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **new_kwargs) if self.peft_config.enable_lora is None else \
                        MergedLinear8bitLt(target.in_features, target.out_features, bias=bias,
                                           enable_lora=self.peft_config.enable_lora, **new_kwargs)  # Fix: dict.update() returns None, so pass enable_lora explicitly
elif isinstance(target, peft.tuners.lora.layer.Linear) and self.peft_config.enable_lora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, nn.Linear): # Add this condition
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_lora is not None:
new_kwargs = kwargs.copy() # Fix: Create a copy of kwargs
new_kwargs.update({"enable_lora": self.peft_config.enable_lora})
in_features, out_features = (target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape) if isinstance(target, Conv1D) else \
(target.in_features, target.out_features)
if new_kwargs["fan_in_fan_out"] and not isinstance(target, Conv1D):
warnings.warn("fan_in_fan_out is set to True but the target module is not a Conv1D. Setting to False.")
new_kwargs["fan_in_fan_out"] = False
new_module = Linear(in_features, out_features, bias=bias, **new_kwargs)
else:
raise ValueError("No valid condition met for creating a new module.")
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(f"Target modules {self.peft_config.target_modules} not found in the base model.")
def _get_submodules(self, key):
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = self.model.get_submodule(key)
return parent, target, target_name
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
new_module.weight = old_module.weight
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.model, name)
@property
def modules_to_save(self):
return None
def get_peft_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
if inference:
config["inference_mode"] = True
return config
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, LoraLayer):
module.disable_adapters = False if enabled else True
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
class LoraLayer:
def __init__(self, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool):
self.r = r
self.lora_alpha = lora_alpha
self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else lambda x: x
self.merged = False
self.merge_weights = merge_weights
self.disable_adapters = False
class Linear(nn.Linear, LoraLayer):
def __init__(self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, merge_weights: bool = True, **kwargs):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
self.lora_mask = nn.Parameter(torch.ones(out_features), requires_grad=False)
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
self.weight.requires_grad = False
self.is_prune = True
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.lora_A.train(mode)
self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
if self.r > 0:
self.weight.data += transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
self.merged = True
elif self.merge_weights and self.merged:
if self.r > 0:
self.weight.data -= transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.lora_A.eval()
self.lora_B.eval()
def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
self.merged = False
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
lora_output = self.lora_B(self.lora_A(self.lora_dropout(x).to(self.lora_A.weight.dtype))) * self.scaling
result += lora_output.to(result.dtype)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if hasattr(self, 'lora_mask'):
result *= self.lora_mask.reshape(1, 1, -1)
return result
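# In effect the forward pass above computes
#     y = (x @ W.T + scaling * dropout(x) @ A.T @ B.T) * lora_mask
# when the adapter is active and unmerged; lora_mask zeroes whole output channels (or, for
# attention projections, whole heads) that local_prune() has selected, so training already sees
# the pruned network while the dense weights stay intact until prune_one_layer() removes them.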
if is_bnb_available():
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
def __init__(self, in_features, out_features, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, **kwargs):
bnb.nn.Linear8bitLt.__init__(self, in_features, out_features, bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None))
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
self.weight.requires_grad = False
self.lora_mask = nn.Parameter(torch.ones(out_features), requires_grad=False)
self.is_prune = True
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
if hasattr(self, 'lora_mask'):
result *= self.lora_mask.reshape(1, 1, -1)
return result
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
if bias == "none":
return
elif bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
# LoRA PEFT Model
class LoraPeftModelForCausalLM(PeftModelForCausalLM):
def __init__(self, model, peft_config):
super().__init__(model, peft_config)
self.base_model = LoraModel(peft_config, model)
@classmethod
def from_pretrained(cls, model, model_id, **kwargs):
from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
model = LoraPeftModelForCausalLM(model, config)
if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
print("loading .................")
filename = os.path.join(model_id, WEIGHTS_NAME)
else:
try:
filename = hf_hub_download(model_id, WEIGHTS_NAME)
            except Exception as exc:
                raise ValueError(f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub.") from exc
adapters_weights = torch.load(filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        set_peft_model_state_dict(model, adapters_weights)  # recent peft returns a load-result tuple, not the model, so keep the existing reference
if getattr(model, "hf_device_map", None) is not None:
device_map = kwargs.get("device_map", "auto")
max_memory = kwargs.get("max_memory", None)
no_split_module_classes = model._no_split_modules
if device_map != "sequential":
max_memory = get_balanced_memory(model, max_memory=max_memory, no_split_module_classes=no_split_module_classes,
low_zero=(device_map == "balanced_low_0"))
if isinstance(device_map, str):
device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_module_classes)
model = dispatch_model(model, device_map=device_map)
hook = AlignDevicesHook(io_same_device=True)
if model.peft_config.peft_type == PeftType.LORA:
add_hook_to_module(model.base_model.model, hook)
else:
remove_hook_from_submodules(model.prompt_encoder)
add_hook_to_module(model.base_model, hook)
return model
@property
def active_peft_config(self):
return self.peft_config
def get_peft_model(model, peft_config):
"""Returns a PEFT model object from a model and a config."""
model_config = model.config.to_dict()
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
return LoraPeftModelForCausalLM(model, peft_config)
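# Minimal usage sketch (the checkpoint path and target module list below are illustrative
# assumptions, not part of this file):
#
#     from transformers import AutoModelForCausalLM
#     base = AutoModelForCausalLM.from_pretrained("path/to/llama-7b", torch_dtype=torch.float16)
#     config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
#                         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
#                                         "up_proj", "gate_proj", "down_proj"])
#     model = get_peft_model(base, config)
#     print_trainable_parameters(model)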
# Pruning Utilities
def _is_target_layer(module):
return (isinstance(module, Linear) or isinstance(module, Linear8bitLt)) and module.is_prune
def unfreeze(model):
for name, module in model.named_modules():
        if _is_target_layer(module):
module.weight.requires_grad = True
def freeze(model):
layers = len(model.model.model.layers)
freeze_layer = int(layers * 0.1)
for name, module in model.named_modules():
        if _is_target_layer(module):
layer = int(name.split('.')[4])
if layer < freeze_layer or layer == layers-1:
module.is_prune = False
def init_sensitivity_dict(model):
sensitivity_record = {}
for name, module in model.named_modules():
        if _is_target_layer(module):
module_name = name.split('.')[-1]
if module_name in pruning_groups['self_attn']:
head_dim = module.out_features // NUM_ATTENTION_HEADS
groups = module.out_features // head_dim
else:
groups = module.out_features
name = ".".join(name.split('.')[:-1])
if name in sensitivity_record:
continue
sensitivity_record[name] = module.lora_A.weight.data.new_zeros(groups)
return sensitivity_record
def update_sensitivity_dict(model, s_dict, pruning_type):
s_all = init_sensitivity_dict(model)
for name, module in model.named_modules():
        if _is_target_layer(module):
is_attn = name.split('.')[-1] in pruning_groups['self_attn']
fan_in = name.split('.')[-1] in pruning_groups['block']
s = compute_sensitivity(module, is_attn, pruning_type, fan_in)
name = ".".join(name.split('.')[:-1])
s_all[name] += s
for name, imp in s_all.items():
if torch.isnan(imp.sum()):
return s_dict
for name, imp in s_dict.items():
s_dict[name] = imp * 0.9 + s_all[name] * 0.1
return s_dict
def compute_sensitivity(layer, is_attn, prune_metric='lora', transpose=False, norm=True):
a = layer.lora_A.weight.data
b = layer.lora_B.weight.data
if prune_metric == 'lora':
grad_a = layer.lora_A.weight.grad
grad_b = layer.lora_B.weight.grad
grad = (grad_b @ a + b @ grad_a - grad_b @ grad_a)
elif prune_metric == 'magnitude':
grad = 1
elif prune_metric == 'grad':
grad = layer.weight.grad
else:
raise NotImplementedError
if hasattr(layer, 'state'):
weight = (layer.weight.data * layer.state.SCB.reshape(-1, 1)) / 127
else:
weight = layer.weight.data
s = (grad * (b @ a * layer.scaling + weight)).abs()
if transpose:
s = s.t()
if is_attn:
head_dim = layer.out_features // NUM_ATTENTION_HEADS
s = s.reshape(s.shape[0] // head_dim, -1)
s = s.sum(1)
if norm:
s = s / (torch.linalg.norm(s) + 1e-8)
return s
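# Importance criterion used above (LoRAPrune-style): with frozen weight W, LoRA factors B, A and
# scaling s, the gradient of the effective weight is approximated from the LoRA gradients alone,
#     dL/dW ~= dB @ A + B @ dA - dB @ dA,
# and the element-wise importance is
#     S = | dL/dW * (W + s * B @ A) |,
# which is summed over head_dim-sized groups for attention projections and L2-normalised so
# scores are comparable across layers. prune_metric='magnitude' drops the gradient factor
# (grad = 1); prune_metric='grad' uses the true weight gradient and therefore needs unfreeze().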
def prune_fp16_module(module, mask, transpose):
mask = mask.bool()
module.train()
if not transpose:
module.weight.data = module.weight.data[mask]
module.out_features = int(mask.sum())
        if module.bias is not None:
module.bias.data = module.bias.data[mask]
module.lora_B.weight.data = module.lora_B.weight.data[mask]
module.lora_B.out_features = int(mask.sum())
else:
module.weight.data = module.weight.data[:, mask]
module.in_features = int(mask.sum())
module.lora_A.weight.data = module.lora_A.weight.data[:, mask]
module.lora_A.in_features = int(mask.sum())
module.merge_weights = True
module.train(False)
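# prune_fp16_module edits the layer in place: transpose=False removes masked *output* rows of the
# weight (plus bias and the matching rows of lora_B); transpose=True removes masked *input*
# columns (plus the matching columns of lora_A). Setting merge_weights=True and calling
# train(False) then folds the surviving LoRA update back into the dense weight.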
def prune_one_layer(layer):
prune_fp16_module(layer.self_attn.q_proj, layer.self_attn.q_proj.lora_mask, False)
prune_fp16_module(layer.self_attn.k_proj, layer.self_attn.k_proj.lora_mask, False)
prune_fp16_module(layer.self_attn.v_proj, layer.self_attn.v_proj.lora_mask, False)
prune_fp16_module(layer.self_attn.o_proj, layer.self_attn.q_proj.lora_mask, True)
layer.self_attn.num_heads = int(layer.self_attn.q_proj.lora_mask.sum()) // 128
layer.self_attn.hidden_size = int(layer.self_attn.q_proj.lora_mask.sum())
prune_fp16_module(layer.mlp.gate_proj, layer.mlp.gate_proj.lora_mask, False)
prune_fp16_module(layer.mlp.up_proj, layer.mlp.up_proj.lora_mask, False)
prune_fp16_module(layer.mlp.down_proj, layer.mlp.gate_proj.lora_mask, True)
del(layer.self_attn.q_proj.lora_mask)
del(layer.self_attn.k_proj.lora_mask)
del(layer.self_attn.v_proj.lora_mask)
del(layer.mlp.gate_proj.lora_mask)
del(layer.mlp.up_proj.lora_mask)
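# The `// 128` above assumes a LLaMA-7B/13B head dimension of 128; adjust it for models with a
# different head_dim. o_proj is pruned with q_proj's head mask and down_proj with gate_proj's
# channel mask, which keeps every matmul in the block shape-consistent after pruning.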
def prune(model):
for layer_id, layer in enumerate(model.model.model.layers):
print(f"pruning layer {layer_id}")
prune_one_layer(layer)
def local_prune(model, s_dict, ratio, target_ratio):
original_param_num = pruned_param_num = 0
for name, module in model.named_modules():
        if _is_target_layer(module):
original_param_num += np.prod(module.weight.shape)
pruned_param_num += np.prod(module.weight.shape) * ratio
is_attn = name.split('.')[-1] in pruning_groups['self_attn']
if name.split('.')[-1] in pruning_groups['block']:
continue
name = ".".join(name.split('.')[:-1])
if not hasattr(module, 'lora_mask') or (1-module.lora_mask.mean()).item() >= target_ratio:
continue
total_num = module.lora_mask.numel()
c_mask = module.lora_mask.data
mask = torch.ones_like(c_mask)
if is_attn:
head_dim = module.out_features // NUM_ATTENTION_HEADS
mask = mask.reshape(-1, head_dim)[:, 0]
c_mask = c_mask.reshape(-1, head_dim)[:, 0]
total_num /= head_dim
need_prune_num = int(total_num * ratio)
importance = s_dict[name] * c_mask
can_prune = torch.argsort(importance)[:need_prune_num]
mask[can_prune] = 0
if is_attn:
mask = (mask.new_ones(module.lora_mask.shape).reshape(-1, head_dim) * mask.unsqueeze(1)).reshape(-1)
module.lora_mask.data = mask
else:
if hasattr(module, 'weight'):
original_param_num += np.prod(module.weight.shape)
print(f"pruned/original parameters number:{pruned_param_num*1e-9:.3f}/{original_param_num*1e-9:.3f} ratio:{pruned_param_num/original_param_num:.3f}")
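# local_prune only updates the binary lora_mask of each prunable module: attention projections
# are masked one head at a time (head_dim consecutive outputs share one mask entry), up/gate
# projections one channel at a time. The int(total * ratio) lowest-importance groups are zeroed,
# where importance is the accumulated sensitivity times the current mask, so already-pruned
# groups (importance 0) stay pruned and the module's sparsity is raised to the scheduled ratio,
# capped by target_ratio via the early `continue`.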
def schedule_sparsity_ratio(step, total_step, initial_warmup, final_warmup, initial_sparsity, final_sparsity):
if step <= initial_warmup * total_step:
sparsity = initial_sparsity
elif step > (total_step - final_warmup * total_step):
sparsity = final_sparsity
else:
spars_warmup_steps = initial_warmup * total_step
spars_schedu_steps = (final_warmup + initial_warmup) * total_step
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (mul_coeff ** 3)
return sparsity
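# Cubic sparsity schedule: the ratio is held at initial_sparsity for the first
# initial_warmup * total_step steps, at final_sparsity for the last final_warmup * total_step
# steps, and interpolated with a cubic in between. Worked example: with total_step=1000,
# initial_warmup=0.1, final_warmup=0.1, initial_sparsity=0.0, final_sparsity=0.5, step 500 gives
#     mul_coeff = 1 - (500 - 100) / (1000 - 200) = 0.5
#     sparsity  = 0.5 + (0.0 - 0.5) * 0.5 ** 3 = 0.4375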
def prune_from_checkpoint(model):
prune(model)
def print_trainable_parameters(model):
total_params = trainable_params = 0
for n, p in model.named_parameters():
if p.requires_grad:
trainable_params += p.numel()
total_params += p.numel()
print(f"total params:{total_params * 1e-6:.2f}M trainable params:{trainable_params * 1e-6:.2f}M ratio:{trainable_params / total_params:.3f}")
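# LoRAPruneTrainer interleaves LoRA fine-tuning with iterative structured pruning: at every
# gradient-accumulation boundary it refreshes the per-group sensitivity scores with an
# exponential moving average (0.9 old / 0.1 new), and every `prune_freq` optimizer steps it calls
# local_prune() with the sparsity returned by schedule_sparsity_ratio(). `ratio` is the target
# sparsity, `init_ratio` the starting sparsity, and `warmup_iters` / `cooldown_iters` are the
# schedule fractions passed as initial_warmup / final_warmup.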
class LoRAPruneTrainer(Trainer):
def __init__(self, model, train_dataset, eval_dataset, args, data_collator, ratio, init_ratio, warmup_iters,
cooldown_iters, prune_freq, prune_metric, **kwargs):
super().__init__(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=args,
data_collator=data_collator,
**kwargs # Forward additional arguments like label_names
)
self.ratio = ratio
self.init_ratio = init_ratio
self.warmup_iters = warmup_iters
self.cooldown_iters = cooldown_iters
self.prune_freq = prune_freq
self.prune_metric = prune_metric
def compute_loss(self, model, inputs, return_outputs=False):
# Extract inputs and labels
labels = inputs.get("labels")
outputs = model(**inputs)
logits = outputs.logits
# Compute loss (causal LM typically uses cross-entropy)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)  # -100 marks label positions ignored in the loss (padding/prompt tokens)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return (loss, outputs) if return_outputs else loss
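    # Standard causal-LM objective: logits at position t are scored against the token at t + 1,
    # hence the shift by one before the flattened cross-entropy; positions labelled -100 are
    # excluded from the loss.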
def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
self._train_batch_size = batch_size
train_dataloader = self.get_train_dataloader()
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
len_dataloader = len(train_dataloader) if has_length(train_dataloader) else None
if len_dataloader:
num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(args.max_steps % num_update_steps_per_epoch > 0)
num_train_samples = args.max_steps * total_train_batch_size
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = num_examples * args.num_train_epochs
elif args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}")
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug and self.args.n_gpu > 1:
raise ValueError("Currently --debug underflow_overflow is not supported under DP. Please use DDP.")
else:
debug_overflow = DebugUnderflowOverflow(self.model) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug else None
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint)
self.model, self.model_wrapped, self.deepspeed, self.optimizer, self.lr_scheduler = deepspeed_engine.module, deepspeed_engine, deepspeed_engine, optimizer, lr_scheduler
elif not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState(stateful_callbacks=[cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)])
self.state.is_hyper_param_search = trial is not None
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
model = self._wrap_model(self.model_wrapped)
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
total_params = kept_params = sum([p.numel() if not p.requires_grad else 0 for p in model.parameters()])
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}")
logger.info(f" Total train batch size = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
self.state.epoch, start_time, epochs_trained, steps_trained_in_current_epoch = 0, time.time(), 0, 0
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = (self.state.global_step % num_update_steps_per_epoch) * args.gradient_accumulation_steps if not args.ignore_data_skip else 0
logger.info(f" Continuing training from epoch {epochs_trained}, global step {self.state.global_step}")
self.callback_handler.model, self.callback_handler.optimizer, self.callback_handler.lr_scheduler, self.callback_handler.train_dataloader = self.model, self.optimizer, self.lr_scheduler, train_dataloader
if self.hp_name is not None and self._trial is not None:
self.state.trial_name = self.hp_name(self._trial)
self.state.trial_params, self.state.max_steps, self.state.num_train_epochs = None, max_steps, num_train_epochs
self.state.is_local_process_zero, self.state.is_world_process_zero = self.is_local_process_zero(), self.is_world_process_zero()
tr_loss = torch.tensor(0.0).to(args.device)
self._total_loss_scalar, self._globalstep_last_logged = 0.0, self.state.global_step
model.zero_grad()
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, torch.utils.data.RandomSampler)
if is_torch_less_than_1_11 or not is_random_sampler:
for _ in train_dataloader:
break
else:
_ = list(train_dataloader.sampler)
total_batched_samples = 0
if self.prune_metric == 'grad':
unfreeze(model)
sensitivity_dict = init_sensitivity_dict(model)
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, torch.utils.data.DataLoader) and isinstance(train_dataloader.sampler, torch.utils.data.distributed.DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
train_dataloader.dataset.set_epoch(epoch)
epoch_iterator = train_dataloader
if args.past_index >= 0:
self._past = None
steps_in_epoch = len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
            rng_to_sync, steps_skipped = False, 0
            step = -1  # stays -1 if the dataloader yields no batches, so the empty-epoch check below works
            for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
tr_loss_step = self.training_step(model, inputs)
if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
tr_loss += tr_loss_step
self.current_flos += float(self.floating_point_ops(inputs))
if self.deepspeed:
self.deepspeed.step()
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch):
                    grad_norm = None  # defined even when no gradient clipping runs, so the logging call below never sees an unbound name
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
if is_sagemaker_mp_enabled() and args.fp16:
grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):
grad_norm = self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
grad_norm = model.clip_grad_norm_(args.max_grad_norm)
else:
grad_norm = nn.utils.clip_grad_norm_(amp.master_params(self.optimizer) if self.use_apex else model.parameters(), args.max_grad_norm)
if not self.deepspeed:
sensitivity_dict = update_sensitivity_dict(model, sensitivity_dict, self.prune_metric)
ratio = schedule_sparsity_ratio(self.state.global_step, self.state.max_steps, self.warmup_iters, self.cooldown_iters, self.init_ratio, self.ratio)
if (self.state.global_step) % self.prune_freq == 0 and ratio > self.init_ratio and ratio < self.ratio:
local_prune(model, sensitivity_dict, ratio, self.ratio)
optimizer_was_run = True
if not self.deepspeed:
self.optimizer.step()
if optimizer_was_run:
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm if grad_norm is not None else None, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if step < 0:
logger.warning(f"No samples in epoch_iterator, stopping at step {self.state.global_step}.")
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, None, model, trial, epoch, ignore_keys_for_eval)
if self.control.should_training_stop:
break
if args.past_index and hasattr(self, "_past"):
delattr(self, "_past")
logger.info("\n\nTraining completed. Share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
self._load_best_model()
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
self.store_flos()
metrics["total_flos"], metrics["train_loss"] = self.state.total_flos, train_loss
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if checkpoint != self.state.best_model_checkpoint:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
return TrainOutput(self.state.global_step, train_loss, metrics)
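# End-to-end sketch (illustrative; the dataset, collator, TrainingArguments and hyper-parameter
# values below are assumptions, not defined in this file):
#
#     trainer = LoRAPruneTrainer(
#         model=model, train_dataset=train_ds, eval_dataset=eval_ds, args=training_args,
#         data_collator=collator, ratio=0.5, init_ratio=0.0, warmup_iters=0.1,
#         cooldown_iters=0.1, prune_freq=10, prune_metric="lora",
#     )
#     trainer.train()
#     prune(model)                      # physically remove the masked heads/channels
#     print_trainable_parameters(model)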