| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| import warnings |
|
|
| import torch |
| from transformers.pytorch_utils import Conv1D |
|
|
| from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists |
| from peft.utils import ( |
| TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING, |
| ) |
| from peft.utils.other import get_pattern_key |
|
|
| from .layer import WaveFTLayer, WaveFTLinear |
|
|
|
|
| class WaveFTModel(BaseTuner): |
| prefix: str = "waveft_" |
| tuner_layer_cls: type[BaseTunerLayer] = WaveFTLayer |
| target_module_mapping = TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING |
|
|
| def _calculate_proportional_parameters(self, model: torch.nn.Module, waveft_config): |
| """Calculate proportional parameter allocation for all target modules.""" |
| target_modules_info = [] |
| for name, module in model.named_modules(): |
| if check_target_module_exists(waveft_config, name): |
| |
| if isinstance(module, WaveFTLayer): |
| |
| base_module = module.base_layer |
| if isinstance(base_module, torch.nn.Linear): |
| input_dim, output_dim = base_module.in_features, base_module.out_features |
| elif isinstance(base_module, Conv1D): |
| input_dim, output_dim = base_module.weight.shape[1], base_module.weight.shape[0] |
| else: |
| continue |
| elif isinstance(module, torch.nn.Linear): |
| input_dim, output_dim = module.in_features, module.out_features |
| elif isinstance(module, Conv1D): |
| input_dim, output_dim = module.weight.shape[1], module.weight.shape[0] |
| else: |
| continue |
| target_modules_info.append((name, input_dim, output_dim)) |
|
|
| if not target_modules_info: |
| raise ValueError("No target modules found for proportional parameter allocation.") |
|
|
| total_sum = sum(input_dim * output_dim for (_, input_dim, output_dim) in target_modules_info) |
| num_layers = len(target_modules_info) |
| total_budget = waveft_config.n_frequency * num_layers |
|
|
| n_frequency_dict = {} |
| for name, input_dim, output_dim in target_modules_info: |
| layer_ratio = (input_dim * output_dim) / total_sum |
| n_freq = round(layer_ratio * total_budget) |
| n_frequency_dict[name] = n_freq |
|
|
| return n_frequency_dict |
|
|
| def _create_and_replace( |
| self, |
| waveft_config, |
| adapter_name, |
| target, |
| target_name, |
| parent, |
| current_key, |
| **optional_kwargs, |
| ): |
| if current_key is None: |
| raise ValueError("Current Key shouldn't be `None`") |
|
|
| |
| if waveft_config.proportional_parameters: |
| if not hasattr(self, "_proportional_params_cache"): |
| self._proportional_params_cache = {} |
| if adapter_name not in self._proportional_params_cache: |
| n_frequency_dict = self._calculate_proportional_parameters(self.model, waveft_config) |
| self._proportional_params_cache[adapter_name] = n_frequency_dict |
|
|
| |
| |
| |
| |
| |
| n_frequency = None |
| if ( |
| waveft_config.proportional_parameters |
| and hasattr(self, "_proportional_params_cache") |
| and adapter_name in self._proportional_params_cache |
| ): |
| n_frequency = self._proportional_params_cache[adapter_name].get(current_key) |
|
|
| if n_frequency is None and "n_frequency" in optional_kwargs: |
| n_frequency = optional_kwargs["n_frequency"] |
|
|
| if n_frequency is None: |
| pattern_keys = list(waveft_config.n_frequency_pattern.keys()) |
| target_name_key = get_pattern_key(pattern_keys, current_key) |
| n_frequency = waveft_config.n_frequency_pattern.get(target_name_key, waveft_config.n_frequency) |
|
|
| |
| wavelet_family = None |
| if "wavelet_family" in optional_kwargs: |
| wavelet_family = optional_kwargs["wavelet_family"] |
| if wavelet_family is None: |
| wavelet_family = waveft_config.wavelet_family |
|
|
| scaling = waveft_config.scaling |
| random_loc_seed = waveft_config.random_loc_seed |
| bias = hasattr(target, "bias") and target.bias is not None |
| |
| kwargs = { |
| "n_frequency": n_frequency, |
| "scaling": scaling, |
| "fan_in_fan_out": waveft_config.fan_in_fan_out, |
| "init_weights": waveft_config.init_weights, |
| "random_loc_seed": waveft_config.random_loc_seed, |
| "wavelet_family": wavelet_family, |
| } |
| kwargs["bias"] = bias |
|
|
| if isinstance(target, WaveFTLayer): |
| target.update_layer( |
| adapter_name, |
| n_frequency, |
| scaling, |
| waveft_config.init_weights, |
| random_loc_seed, |
| wavelet_family=wavelet_family, |
| use_idwt=waveft_config.use_idwt, |
| ) |
| else: |
| new_module = self._create_new_module(waveft_config, adapter_name, target, **kwargs) |
| if adapter_name != self.active_adapter: |
| new_module.requires_grad_(False) |
| self._replace_module(parent, target_name, new_module, target) |
|
|
| @staticmethod |
| def _create_new_module(waveft_config, adapter_name, target, **kwargs): |
| if isinstance(target, BaseTunerLayer): |
| target_base_layer = target.get_base_layer() |
| else: |
| target_base_layer = target |
|
|
| if isinstance(target_base_layer, torch.nn.Linear): |
| if kwargs["fan_in_fan_out"]: |
| warnings.warn( |
| "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " |
| "Setting fan_in_fan_out to False." |
| ) |
| kwargs["fan_in_fan_out"] = waveft_config.fan_in_fan_out = False |
| elif isinstance(target_base_layer, Conv1D): |
| kwargs["is_target_conv_1d_layer"] = True |
| if not kwargs["fan_in_fan_out"]: |
| warnings.warn( |
| "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." |
| ) |
| kwargs["fan_in_fan_out"] = waveft_config.fan_in_fan_out = True |
| else: |
| raise ValueError( |
| f"Target module {target} is not supported. Currently, only the following modules are supported: " |
| "`torch.nn.Linear`." |
| ) |
|
|
| kwargs["wavelet_family"] = waveft_config.wavelet_family |
| kwargs["use_idwt"] = waveft_config.use_idwt |
| new_module = WaveFTLinear(target, adapter_name, **kwargs) |
|
|
| return new_module |
|
|
| def delete_adapter(self, adapter_name: str) -> None: |
| """ |
| Deletes an existing adapter. |
| |
| Args: |
| adapter_name (str): Name of the adapter to be deleted. |
| """ |
| super().delete_adapter(adapter_name) |
| |
| if hasattr(self, "_proportional_params_cache") and adapter_name in self._proportional_params_cache: |
| del self._proportional_params_cache[adapter_name] |
|
|