# nvan13's picture
# Upload folder using huggingface_hub
# ecadbd9 verified
from typing import Optional
import torch
import torch.nn as nn
from enum import Enum
from dataclasses import asdict
from tqdm import tqdm
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules
from .layer import SamaLayer, Linear
# Default target-module table used by SamaTuner, looked up by model type.
# It starts out as a copy of PEFT's LoRA table so that rebinding entries here
# does not rebind them in the LoRA table.
# NOTE(review): if the LoRA table is a plain dict, `.copy()` is shallow and the
# per-model values are shared with it - confirm before mutating them.
TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
# PEFT tuner that installs SaMA adapter layers on the target modules of a model.
class SamaTuner(BaseTuner):
# Substring used throughout this class to recognise adapter-owned module and
# parameter names (anything whose name contains it is treated as adapter state).
prefix: str = "_sama"
# Layer class associated with this tuner.
tuner_layer_class = SamaLayer
# Default target modules per model type (see the module-level mapping).
target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING
@staticmethod
def _check_target_module_exists(rotation_config, key: str) -> bool:
    """
    Report whether the module at ``key`` is targeted by ``rotation_config``.

    This is a thin wrapper: the decision is delegated to PEFT's
    ``check_target_module_exists`` and its result is returned unchanged.

    Args:
        rotation_config: Adapter configuration handed to the PEFT helper
        key: Module key handed to the PEFT helper

    Returns:
        The result of ``check_target_module_exists(rotation_config, key)``
    """
    is_target = check_target_module_exists(rotation_config, key)
    return is_target
def _create_and_replace(
    self,
    sama_config,
    adapter_name: str,
    target: nn.Module,
    target_name: str,
    parent: nn.Module,
    current_key: str,
    **optional_kwargs,
) -> None:
    """
    Attach the adapter ``adapter_name`` to ``target``.

    If ``target`` is already a ``SamaLayer``, the adapter is registered on it
    in place through ``update_layer``. Otherwise a wrapping module is built
    with ``_create_new_module`` and installed on ``parent`` under
    ``target_name``; when no module can be built (``None`` is returned) the
    target is left untouched.

    Args:
        sama_config: Configuration for the SaMA adapter
        adapter_name: Name of the adapter to add
        target: The target module to augment
        target_name: Name of the target module
        parent: Parent module containing the target
        current_key: Full key path to the current module
        **optional_kwargs: Additional arguments forwarded to ``_create_new_module``

    Raises:
        ValueError: If current_key is None
    """
    if current_key is None:
        raise ValueError("current_key must be provided to create Rotation layer")

    if isinstance(target, SamaLayer):
        # The adapter name comes from the caller's argument, not from the
        # config object: the config describes an adapter but is not read for
        # its name anywhere else in this class.
        # NOTE(review): the remaining keywords (`share_factor`,
        # `num_unique_blocks`) differ from the fields used by
        # `_create_new_module` and `_check_new_adapter_config`
        # (`share_factor_L` / `share_factor_R`). Confirm them against
        # `SamaLayer.update_layer` before relying on multi-adapter support.
        target.update_layer(
            adapter_name=adapter_name,
            share_factor=sama_config.share_factor,
            scaling=sama_config.scaling,
            num_unique_blocks=sama_config.num_unique_blocks,
            col_L=sama_config.col_L,
            row_R=sama_config.row_R,
            drop_out=sama_config.drop_out,
        )
        return

    # Plain module: wrap it in a new adapter layer and swap it into the parent.
    new_module = self._create_new_module(
        sama_config=sama_config,
        adapter_name=adapter_name,
        target=target,
        **optional_kwargs,
    )
    if new_module is not None:
        self._replace_module(parent, target_name, new_module, target)
def _replace_module(self, parent, child_name, new_module, child):
    """
    Install ``new_module`` on ``parent`` and align adapter submodule devices.

    After the attribute swap, every submodule of ``new_module`` whose name
    contains ``self.prefix`` or ``"ranknum"`` is moved to the device of the
    replaced module's weight tensor, unless that submodule has a parameter on
    the meta device.

    Args:
        parent: Module that owns the attribute being replaced
        child_name: Attribute name on ``parent``
        new_module: Module to install
        child: Module being replaced; if it has a ``base_layer`` attribute,
            that inner layer is used as the device reference instead
    """
    setattr(parent, child_name, new_module)

    # The replaced module may itself be a wrapper; look at what it wraps.
    reference = child.base_layer if hasattr(child, "base_layer") else child
    meta_device = torch.device("meta")

    for sub_name, sub_module in new_module.named_modules():
        if self.prefix not in sub_name and "ranknum" not in sub_name:
            continue

        # Pick the tensor whose device the adapter submodule should follow.
        if hasattr(reference, "qweight"):
            anchor = reference.qweight
        elif hasattr(reference, "W_q"):
            anchor = reference.W_q
        elif hasattr(reference, "weight"):
            anchor = reference.weight
        elif getattr(reference, "in_proj_weight", None) is not None:  # MHA
            anchor = reference.in_proj_weight
        else:
            anchor = next(reference.parameters())

        has_meta_param = any(p.device == meta_device for p in sub_module.parameters())
        if not has_meta_param:
            sub_module.to(anchor.device)
def _mark_only_adapters_as_trainable(self, model):
    """
    Set ``requires_grad`` so that only adapter parameters are trainable.

    A parameter is trainable exactly when its name contains ``self.prefix``.
    Afterwards, the ``bias`` setting of each active adapter's config may
    re-enable bias parameters:

    - ``"none"``: nothing further is enabled
    - ``"all"``: every parameter whose name contains ``"bias"`` is enabled
    - ``"sama_only"``: the ``bias`` attribute of each ``SamaLayer`` is enabled
      when present and not None

    Args:
        model: Model whose parameters are updated in place

    Raises:
        NotImplementedError: If an active adapter has any other bias setting
    """
    for param_name, param in model.named_parameters():
        param.requires_grad = self.prefix in param_name

    for adapter in self.active_adapters:
        bias_mode = self.peft_config[adapter].bias
        if bias_mode == "none":
            continue

        if bias_mode == "all":
            for param_name, param in model.named_parameters():
                if "bias" in param_name:
                    param.requires_grad = True
        elif bias_mode == "sama_only":
            for _, layer in model.named_modules():
                if not isinstance(layer, SamaLayer):
                    continue
                if hasattr(layer, "bias") and layer.bias is not None:
                    layer.bias.requires_grad = True
        else:
            raise NotImplementedError(
                f"Requested bias configuration '{bias_mode}' is not implemented. "
                f"Supported values: 'none', 'all', 'sama_only'"
            )
@staticmethod
def _create_new_module(
    sama_config,
    adapter_name: str,
    target: nn.Module,
    **kwargs,
) -> Optional[nn.Module]:
    """
    Build the SaMA wrapper for ``target``.

    Only ``nn.Linear`` targets are supported. For any other module type a
    message is printed and ``None`` is returned so the caller can skip it.

    Args:
        sama_config: Configuration for the SaMA adapter; its
            ``share_factor_L``, ``share_factor_R``, ``scaling``, ``col_L``,
            ``row_R`` and ``drop_out`` fields are passed to the new layer
        adapter_name: Name of the adapter
        target: Base module to wrap
        **kwargs: Additional arguments forwarded to the layer constructor

    Returns:
        A ``Linear`` adapter layer wrapping ``target``, or None if unsupported
    """
    if not isinstance(target, nn.Linear):
        # Unsupported layer type: report it and let the caller move on.
        print(
            f"SaMA layer does not support {type(target).__name__} yet. "
            f"Skipping this module."
        )
        return None

    layer_options = dict(
        share_factor_L=sama_config.share_factor_L,
        share_factor_R=sama_config.share_factor_R,
        scaling=sama_config.scaling,
        col_L=sama_config.col_L,
        row_R=sama_config.row_R,
        drop_out=sama_config.drop_out,
    )
    return Linear(
        base_layer=target,
        adapter_name=adapter_name,
        **layer_options,
        **kwargs,
    )
def __getattr__(self, name: str):
    """
    Resolve ``name`` on this tuner, falling back to the wrapped model.

    The parent class lookup is tried first. If it raises ``AttributeError``
    the attribute is read from ``self.model`` instead, except for the name
    ``"model"`` itself, whose error is re-raised.
    """
    try:
        return super().__getattr__(name)  # defer to nn.Module's logic
    except AttributeError:
        # see #1892: looking up "model" on self.model would recurse forever
        # when the class is not initialized, so only other names fall through.
        if name != "model":
            return getattr(self.model, name)
        raise
def get_peft_config_as_dict(self, inference: bool = False):
    """
    Return the configuration of every adapter as plain dictionaries.

    Args:
        inference: When True, ``"inference_mode"`` is forced to ``True`` in
            each returned adapter config

    Returns:
        A dict mapping each adapter name in ``self.peft_config`` to a dict of
        that config's dataclass fields, with ``Enum`` field values replaced by
        their ``.value``. Empty when no adapter is registered.
    """
    config_dict = {}
    for adapter_name, adapter_config in self.peft_config.items():
        fields = {
            field_name: field_value.value if isinstance(field_value, Enum) else field_value
            for field_name, field_value in asdict(adapter_config).items()
        }
        if inference:
            fields["inference_mode"] = True
        config_dict[adapter_name] = fields
    # Return the accumulated mapping: returning the loop variable would hand
    # back only the last adapter's fields and fail outright when there are no
    # adapters at all.
    return config_dict
def _set_adapter_layers(self, enabled=True):
    """
    Call ``enable_adapters(enabled)`` on every adapter-aware module.

    Args:
        enabled: Value passed to ``enable_adapters`` on each
            ``BaseTunerLayer`` or ``ModulesToSaveWrapper`` found in
            ``self.model``
    """
    adapter_aware = (BaseTunerLayer, ModulesToSaveWrapper)
    for layer in self.model.modules():
        if not isinstance(layer, adapter_aware):
            continue
        layer.enable_adapters(enabled)
def enable_adapter_layers(self) -> None:
    """
    Turn all adapter layers back on.

    Intended for use after ``disable_adapter_layers``; delegates to
    ``_set_adapter_layers`` with ``enabled=True``.
    """
    self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
    """
    Turn all adapter layers off.

    Before disabling, a notice is printed for each active adapter whose
    ``bias`` setting is not ``"none"``, since the model output then differs
    from that of the unadapted base model.
    """
    for adapter in self.active_adapters:
        bias_mode = self.peft_config[adapter].bias
        if bias_mode != "none":
            print(
                f"Careful, disabling adapter layers with bias configured to be '{bias_mode}' does not produce the same "
                "output as the base model would without adaption."
            )
    self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name, inference_mode=False):
    """Set the active adapter(s).

    Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
    not desired, use the following code.

    ```py
    >>> for name, param in model_peft.named_parameters():
    ...     if ...:  # some check on name (ex. if 'lora' in name)
    ...         param.requires_grad = False
    ```

    Args:
        adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        inference_mode (`bool`, *optional*):
            Forwarded unchanged to every `SamaLayer.set_adapter` call. Defaults to `False` so that callers which
            pass only `adapter_name` keep working.
    """
    for module in self.model.modules():
        if not isinstance(module, SamaLayer):
            continue
        if module.merged:
            # The active adapter cannot be switched while weights are merged.
            print("Adapter cannot be set when the model is merged. Unmerging the model first.")
            module.unmerge()
        module.set_adapter(adapter_name, inference_mode)
    self.active_adapter = adapter_name
def merge_adapter(self, adapter_names: Optional[list[str]] = None, safe_merge: bool = False) -> None:
    """
    Merge adapter weights into the base model weights.

    Calls ``merge`` on every ``SamaLayer`` in ``self.model`` so that the
    adapter no longer has to be applied separately at runtime.

    Args:
        adapter_names: List of adapter names to merge, passed to each
            layer's ``merge``. If None, the layer decides which adapters to
            merge.
        safe_merge: Passed to each layer's ``merge``. Defaults to False,
            which is the value this method always used before the parameter
            was exposed.
    """
    for module in self.model.modules():
        if isinstance(module, SamaLayer):
            module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
def unmerge_adapter(self) -> None:
    """
    Undo a previous merge on every SaMA layer.

    Calls ``unmerge`` on each ``SamaLayer`` found in ``self.model``.
    """
    sama_layers = (layer for layer in self.model.modules() if isinstance(layer, SamaLayer))
    for layer in sama_layers:
        layer.unmerge()
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
    """
    Fill in default ``target_modules`` when the config does not set them.

    Args:
        peft_config: Adapter configuration; modified in place when its
            ``target_modules`` is None
        model_config: Mapping with a ``"model_type"`` entry, consulted only
            when defaults are needed

    Returns:
        The same ``peft_config`` object

    Raises:
        ValueError: If ``target_modules`` is None and the model type has no
            entry in the default mapping
    """
    if peft_config.target_modules is not None:
        return peft_config

    defaults = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING
    if model_config["model_type"] not in defaults:
        raise ValueError("Please specify `target_modules` in `peft_config`")
    peft_config.target_modules = set(defaults[model_config["model_type"]])
    return peft_config
def _check_new_adapter_config(self, config) -> None:
    """
    Validate a new adapter configuration.

    ``share_factor_L``, ``share_factor_R``, ``col_L`` and ``row_R`` must all
    be greater than zero. If the config has a ``bias`` attribute it must be
    one of ``"none"``, ``"all"`` or ``"sama_only"``.

    Args:
        config: Configuration to validate

    Raises:
        ValueError: On the first check that fails
    """
    # (attribute to read, label used at the start of the error message)
    strictly_positive = (
        ("share_factor_L", "share_factor_L"),
        ("share_factor_R", "share_factor_R"),
        ("col_L", "#cols of L"),
        ("row_R", "#rows of R"),
    )
    for attr, label in strictly_positive:
        value = getattr(config, attr)
        if value <= 0:
            raise ValueError(f"{label} must be positive, got {value}")

    # Bias is only checked when the config defines it at all.
    if not hasattr(config, "bias"):
        return
    valid_bias_configs = ["none", "all", "sama_only"]
    if config.bias not in valid_bias_configs:
        raise ValueError(
            f"Invalid bias configuration '{config.bias}'. "
            f"Must be one of {valid_bias_configs}"
        )
def _unload_and_optionally_merge(
    self,
    merge=True,
    progressbar: bool = False,
    safe_merge: bool = False,
    adapter_names: Optional[list[str]] = None,
):
    """
    Strip adapter wrappers from the model, optionally merging them first.

    Every module whose name does not contain ``self.prefix`` is visited. A
    module that offers ``unload_and_optionally_merge_module`` is replaced by
    that method's result; a module that has a ``base_layer`` is (optionally)
    merged and then replaced by ``get_base_layer()``. Other modules are left
    as they are.

    Args:
        merge: Whether to merge adapter weights before unwrapping. When True,
            ``self._check_merge_allowed()`` is called first
        progressbar: Whether to show a tqdm progress bar
        safe_merge: Forwarded to the layers' merge routines
        adapter_names: Forwarded to the layers' merge routines

    Returns:
        ``self.model`` after the wrappers have been replaced
    """
    if merge:
        self._check_merge_allowed()

    module_keys = [name for name, _ in self.model.named_modules() if self.prefix not in name]
    description = "Unloading " + ("and merging " if merge else "") + "model"

    for module_key in tqdm(module_keys, disable=not progressbar, desc=description):
        try:
            parent, target, target_name = _get_submodules(self.model, module_key)
        except AttributeError:
            # The key could not be resolved on the current model; skip it.
            continue

        with onload_layer(target):
            if hasattr(target, "unload_and_optionally_merge_module"):
                # if layers have special unloading method, like MultiheadAttention, use that
                replacement = target.unload_and_optionally_merge_module(
                    merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
                )
            elif hasattr(target, "base_layer"):
                if merge:
                    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                replacement = target.get_base_layer()
            else:
                continue
            self._replace_module(parent, target_name, replacement, target)

    return self.model
def delete_adapter(self, adapter_name: str) -> None:
    """
    Delete an existing adapter.

    The adapter's config entry is removed, the adapter is deleted from every
    ``SamaLayer``, and ``self.active_adapter`` is reset to the active
    adapters reported by the first such layer (or an empty list when there is
    none).

    Args:
        adapter_name (str): Name of the adapter to be deleted.

    Raises:
        ValueError: If no adapter with that name is registered
    """
    known_adapters = list(self.peft_config.keys())
    if adapter_name not in known_adapters:
        raise ValueError(f"Adapter {adapter_name} does not exist")
    del self.peft_config[adapter_name]

    module_keys = [name for name, _ in self.model.named_modules() if self.prefix not in name]
    remaining_active = None
    for module_key in module_keys:
        _parent, target, _target_name = _get_submodules(self.model, module_key)
        if not isinstance(target, SamaLayer):
            continue
        target.delete_adapter(adapter_name)
        # Only the first SaMA layer is consulted for the new active set.
        if remaining_active is None:
            remaining_active = target.active_adapters[:]

    self.active_adapter = remaining_active or []
    self._delete_auxiliary_adapter(adapter_name, new_active_adapters=remaining_active)
def merge_and_unload(
    self,
    progressbar: bool = False,
    safe_merge: bool = False,
    adapter_names: Optional[list[str]] = None,
) -> torch.nn.Module:
    r"""
    Merge the SaMA layers into the base model and return the result, for use as a standalone model.

    Args:
        progressbar (`bool`):
            whether to show a progressbar indicating the unload and merge process
        safe_merge (`bool`):
            whether to activate the safe merging check to check if there is any potential Nan in the adapter
            weights
        adapter_names (`List[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.
    """
    merged_model = self._unload_and_optionally_merge(
        progressbar=progressbar,
        safe_merge=safe_merge,
        adapter_names=adapter_names,
    )
    return merged_model
def unload(self) -> torch.nn.Module:
    """
    Remove all SaMA modules without merging them and return the resulting model.

    Delegates to ``_unload_and_optionally_merge`` with ``merge=False``.
    """
    base_model = self._unload_and_optionally_merge(merge=False)
    return base_model