| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import warnings |
| from typing import Any, Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.pytorch_utils import Conv1D |
|
|
| from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
|
|
| from .constants import WAVELET_REDUCTIONS |
| from .waverec2d import waverec2d |
|
|
|
|
class WaveFTLayer(BaseTunerLayer):
    """Per-adapter WaveFT state and weight-update computation for a wrapped layer.

    Every adapter owns a 1-D trainable vector (``waveft_spectrum``) whose entries are
    scattered to fixed, seed-derived (row, col) positions of a 2-D coefficient matrix.
    The weight update is obtained either by feeding that matrix through ``waverec2d``
    (``use_idwt=True``) or by using the sparse matrix directly (``use_idwt=False``),
    and is multiplied by the adapter's scaling factor in both cases.
    """

    # Names of the containers holding trainable adapter parameters.
    adapter_layer_names = ("waveft_spectrum",)
    # Names of the containers holding the remaining per-adapter settings.
    other_param_names = (
        "waveft_n_frequency",
        "waveft_scaling",
        "waveft_random_loc_seed",
        "waveft_wavelet_family",
        "waveft_indices",
        "waveft_use_idwt",
    )

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        """Store the wrapped layer and create empty per-adapter containers.

        Args:
            base_layer: The layer being adapted. The layer returned by
                ``get_base_layer()`` must be an ``nn.Linear`` or a ``Conv1D``.
            **kwargs: Kept unchanged on ``self.kwargs``.

        Raises:
            ValueError: If the layer returned by ``get_base_layer()`` is neither
                ``nn.Linear`` nor ``Conv1D``.
        """
        self.base_layer = base_layer
        self.waveft_n_frequency = {}
        self.waveft_scaling = {}
        self.waveft_spectrum = nn.ParameterDict({})
        self.waveft_wavelet_family = {}
        self.waveft_indices = {}
        self.waveft_random_loc_seed = {}
        self.waveft_use_idwt = {}

        self._disable_adapters = False
        self.merged_adapters = []
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
        elif isinstance(base_layer, Conv1D):
            # Conv1D keeps its weight as (in_features, out_features); prefer the
            # `ds_shape` attribute when the weight exposes one.
            weight = base_layer.weight
            self.in_features, self.out_features = weight.ds_shape if hasattr(weight, "ds_shape") else weight.shape
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

    def update_layer(
        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, wavelet_family="db1", use_idwt=True
    ):
        """Register (or overwrite) the adapter ``adapter_name`` on this layer.

        Args:
            adapter_name: Key under which all per-adapter state is stored.
            n_frequency: Number of trainable coefficients; must be in
                ``[1, in_features * out_features]``.
            scaling: Factor applied to the weight update in ``get_delta_weight``.
            init_weights: If truthy the spectrum starts at zero, otherwise it is
                drawn from a normal distribution with standard deviation 0.01.
            random_loc_seed: Seed of the generator that picks the coefficient positions.
            wavelet_family: Key into ``WAVELET_REDUCTIONS``; also passed to ``waverec2d``.
            use_idwt: Whether ``get_delta_weight`` goes through ``waverec2d``.

        Raises:
            ValueError: If ``n_frequency`` is outside the allowed range.
            KeyError: If ``wavelet_family`` is not a key of ``WAVELET_REDUCTIONS``.
        """
        if n_frequency <= 0:
            raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
        n_weight_entries = self.in_features * self.out_features
        if n_frequency > n_weight_entries:
            raise ValueError(
                f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
                f"but the value passed is {n_frequency} and the product is {n_weight_entries}"
            )
        # Validation-only lookup, done before any state is written so that an unknown
        # family raises KeyError without leaving a half-registered adapter behind.
        _ = WAVELET_REDUCTIONS[wavelet_family]

        self.waveft_n_frequency[adapter_name] = n_frequency
        self.waveft_random_loc_seed[adapter_name] = random_loc_seed
        self.waveft_wavelet_family[adapter_name] = wavelet_family
        self.waveft_use_idwt[adapter_name] = use_idwt
        self.waveft_scaling[adapter_name] = scaling

        # Pick `n_frequency` distinct flat positions of the (out_features, in_features)
        # matrix with a dedicated, seeded generator, then split them into a
        # (2, n_frequency) tensor of row and column indices. Any padding needed by the
        # reconstruction is applied later, in `get_delta_weight`.
        generator = torch.Generator().manual_seed(random_loc_seed)
        flat_indices = torch.randperm(n_weight_entries, generator=generator)[:n_frequency]
        self.waveft_indices[adapter_name] = torch.stack(
            [flat_indices // self.in_features, flat_indices % self.in_features], dim=0
        )

        if init_weights:
            # Allocate uninitialised storage and zero it via the reset hook.
            self.waveft_spectrum[adapter_name] = nn.Parameter(torch.empty(n_frequency), requires_grad=True)
            self.reset_wave_parameters(adapter_name)
        else:
            std_dev = 0.01
            self.waveft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency) * std_dev, requires_grad=True)

        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    @torch.no_grad()
    def reset_wave_parameters(self, adapter_name):
        """Zero the spectrum of ``adapter_name`` in place; no-op for unknown adapters."""
        if adapter_name in self.waveft_spectrum:
            nn.init.zeros_(self.waveft_spectrum[adapter_name])

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """Compute the scaled weight update of ``adapter``.

        Args:
            adapter: Name of a registered adapter.

        Returns:
            A tensor on the spectrum's device and with the spectrum's dtype. Without
            IDWT its shape is ``(out_features, in_features)``; with IDWT the output of
            ``waverec2d`` is centre-cropped to that shape whenever it differs.
        """
        spectrum = self.waveft_spectrum[adapter]
        indices = self.waveft_indices[adapter].to(spectrum.device)
        scaling = self.waveft_scaling[adapter]

        if not self.waveft_use_idwt[adapter]:
            # Direct mode: the sparse coefficient matrix itself is the update.
            dense_spectrum = torch.zeros(
                self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype
            )
            dense_spectrum[indices[0], indices[1]] = spectrum
            return dense_spectrum * scaling

        wavelet_family = self.waveft_wavelet_family[adapter]
        reduction_rows, reduction_cols = WAVELET_REDUCTIONS[wavelet_family]

        # Enlarge the coefficient matrix by the family-specific amounts and round each
        # side up to an even size so it can be split into four equal quadrants.
        padded_out_features = self.out_features + reduction_rows
        padded_in_features = self.in_features + reduction_cols
        padded_out_features += padded_out_features % 2
        padded_in_features += padded_in_features % 2

        dense_spectrum = torch.zeros(
            padded_out_features, padded_in_features, device=spectrum.device, dtype=spectrum.dtype
        )

        # Shift the stored positions so the original index range sits centred inside
        # the padded matrix.
        rows = indices[0] + (padded_out_features - self.out_features) // 2
        cols = indices[1] + (padded_in_features - self.in_features) // 2

        # Keep only positions that fall below the padded upper bounds.
        in_bounds = (rows < padded_out_features) & (cols < padded_in_features)
        dense_spectrum[rows[in_bounds], cols[in_bounds]] = spectrum[in_bounds]

        # Quadrants of the padded matrix become the four coefficient bands handed to
        # `waverec2d` as (cA, (cH, cV, cD)).
        half_rows, half_cols = padded_out_features // 2, padded_in_features // 2
        cA = dense_spectrum[:half_rows, :half_cols]
        cH = dense_spectrum[:half_rows, half_cols:]
        cV = dense_spectrum[half_rows:, :half_cols]
        cD = dense_spectrum[half_rows:, half_cols:]

        delta_weight = waverec2d((cA, (cH, cV, cD)), wavelet_family) * scaling

        if delta_weight.shape[0] != self.out_features or delta_weight.shape[1] != self.in_features:
            # Centre-crop the reconstruction back to the base weight's logical shape.
            start_row = (delta_weight.shape[0] - self.out_features) // 2
            start_col = (delta_weight.shape[1] - self.in_features) // 2
            delta_weight = delta_weight[
                start_row : start_row + self.out_features, start_col : start_col + self.in_features
            ]

        return delta_weight
|
|
|
|
class WaveFTLinear(nn.Module, WaveFTLayer):
    """WaveFT adapter wrapped around an ``nn.Linear`` or ``Conv1D`` base layer."""

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        n_frequency: int = 1000,
        scaling: float = 150.0,
        fan_in_fan_out: bool = False,
        init_weights: Union[bool, str] = False,
        random_loc_seed: int = 777,
        wavelet_family: str = "db1",
        use_idwt: bool = True,
        **kwargs,
    ) -> None:
        """Wrap ``base_layer`` and register ``adapter_name`` as its first adapter.

        Args:
            base_layer: Layer to adapt (see ``WaveFTLayer.__init__`` for supported types).
            adapter_name: Name of the adapter to create; also becomes the active adapter.
            n_frequency: Number of trainable coefficients.
            scaling: Factor applied to the weight update.
            fan_in_fan_out: Set to ``True`` when the base weight is stored as
                ``(in_features, out_features)``; the update is then transposed
                before it is merged into or removed from the base weight.
            init_weights: Forwarded to ``update_layer``.
            random_loc_seed: Seed used to pick the coefficient positions.
            wavelet_family: Wavelet family name forwarded to ``update_layer``.
            use_idwt: Forwarded to ``update_layer``.
            **kwargs: Forwarded to ``WaveFTLayer.__init__``.
        """
        super().__init__()
        WaveFTLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed, wavelet_family, use_idwt)

    def _delta_for_base_weight(self, adapter) -> torch.Tensor:
        """Return the update of ``adapter`` laid out like the stored base weight."""
        delta_weight = self.get_delta_weight(adapter)
        # `get_delta_weight` is (out_features, in_features); a fan_in_fan_out base
        # layer stores the transpose, so flip the update to match before adding it.
        return delta_weight.T if self.fan_in_fan_out else delta_weight

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: With `safe_merge=True`, if the merged weights contain non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # Nothing left to merge.
            return

        for active_adapter in adapter_names:
            if active_adapter not in self.waveft_spectrum:
                continue

            base_layer = self.get_base_layer()
            delta_weight = self._delta_for_base_weight(active_adapter)
            if safe_merge:
                # Work on a copy so the base weight stays untouched if the check fails.
                orig_weights = base_layer.weight.data.clone()
                orig_weights += delta_weight

                if not torch.isfinite(orig_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = orig_weights
            else:
                base_layer.weight.data += delta_weight
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        # Undo the merges in reverse order of application.
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.waveft_spectrum:
                self.get_base_layer().weight.data -= self._delta_for_base_weight(active_adapter)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """Return the update of ``adapter`` exactly as computed by ``WaveFTLayer``."""
        return super().get_delta_weight(adapter)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the base layer and, unless merged or disabled, add each active adapter's update.

        Args:
            x: Input tensor handed to the base layer.
            *args: Extra positional arguments forwarded to the base layer.
            **kwargs: Extra keyword arguments forwarded to the base layer.

        Returns:
            The layer output, cast back to the dtype ``x`` had on entry.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            # Adapters are off: make sure no merged update is left in the base weight.
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            # The update already lives in the base weight.
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.waveft_spectrum:
                    continue

                # F.linear expects an (out_features, in_features) weight, which is
                # the layout `get_delta_weight` produces, so no transpose is needed.
                delta_w = self.get_delta_weight(active_adapter)
                x = self._cast_input_dtype(x, delta_w.dtype)
                result = result + F.linear(x, delta_w)

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        """Prefix the default module representation with ``waveft.``."""
        rep = super().__repr__()
        return "waveft." + rep
|
|