import math
from typing import Any

import torch
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTunerLayer

from .config import PolyConfig
from .router import get_router


class PolyLayer(BaseTunerLayer):
    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names = ("poly_lora_A", "poly_lora_B", "poly_router")
    # All names of other parameters that may contain adapter-related parameters
    other_param_names = ("r", "n_tasks", "n_skills", "n_splits")

    def __init__(self, base_layer: nn.Module, **kwargs):
        self.base_layer = base_layer
        self.r = {}
        self.n_tasks = {}
        self.n_skills = {}
        self.n_splits = {}
        self.poly_type = {}
        self.poly_router = nn.ModuleDict()
        self.poly_lora_A = nn.ParameterDict()
        self.poly_lora_B = nn.ParameterDict()
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

        self.in_features = in_features
        self.out_features = out_features

    def update_layer(self, adapter_name, poly_config, inference_mode: bool = False, **kwargs):
        if poly_config.r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {poly_config.r}")

        self.r[adapter_name] = poly_config.r
        self.n_tasks[adapter_name] = poly_config.n_tasks
        self.n_skills[adapter_name] = poly_config.n_skills
        self.n_splits[adapter_name] = poly_config.n_splits
        self.poly_type[adapter_name] = poly_config.poly_type
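
        # Low-rank factors are stored per (split, skill) pair:
        # poly_lora_A: (n_splits, n_skills, in_features // n_splits, r)
        # poly_lora_B: (n_splits, n_skills, r, out_features // n_splits)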
        self.poly_lora_A[adapter_name] = nn.Parameter(
            torch.empty(
                poly_config.n_splits,
                poly_config.n_skills,
                self.in_features // poly_config.n_splits,
                poly_config.r,
            )
        )
        self.poly_lora_B[adapter_name] = nn.Parameter(
            torch.empty(
                poly_config.n_splits,
                poly_config.n_skills,
                poly_config.r,
                self.out_features // poly_config.n_splits,
            )
        )
        self.poly_router[adapter_name] = get_router(poly_config)

        self.reset_poly_parameters(adapter_name, init_weights=poly_config.init_weights)

        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters, inference_mode=inference_mode)

    def reset_poly_parameters(self, adapter_name, init_weights):
        if adapter_name in self.poly_lora_A.keys():
            # initialize A the same way as the default for nn.Linear
            n_splits, n_skills, d, r = self.poly_lora_A[adapter_name].shape
            for skill in range(n_skills):
                for split in range(n_splits):
                    param = torch.empty((r, d))
                    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
                    self.poly_lora_A[adapter_name].data[split, skill, :, :] = param.T

            if init_weights:
                # initialize B to zero
                torch.nn.init.zeros_(self.poly_lora_B[adapter_name])
            else:
                # initialize B the same way as the default for nn.Linear
                n_splits, n_skills, r, d = self.poly_lora_B[adapter_name].shape
                for skill in range(n_skills):
                    for split in range(n_splits):
                        param = torch.empty((d, r))
                        torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
                        self.poly_lora_B[adapter_name].data[split, skill, :, :] = param.T

            # reset the router parameters as well
            self.poly_router[adapter_name].reset()


class Linear(nn.Module, PolyLayer):
    # Poly implemented in a dense (nn.Linear) layer
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        poly_config: PolyConfig,
        **kwargs,
    ) -> None:
        super().__init__()
        PolyLayer.__init__(self, base_layer, **kwargs)

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, poly_config)

    def forward(self, x: torch.Tensor, *args: Any, task_ids: torch.Tensor = None, **kwargs: Any) -> torch.Tensor:
        previous_dtype = x.dtype
        if self.disable_adapters:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.poly_lora_A.keys():
                    continue

                r = self.r[active_adapter]
                poly_router = self.poly_router[active_adapter]
                poly_lora_A = self.poly_lora_A[active_adapter]
                poly_lora_B = self.poly_lora_B[active_adapter]

                # the router produces per-example mixing weights over (split, skill)
                mixing_weights = poly_router(task_ids=task_ids, input_ids=x)
                bs, n_splits, n_skills = mixing_weights.size()

                # A is    (n_splits, n_skills, D // n_splits, r)
                # we want (bs,       n_splits, D // n_splits, r)
                A = torch.einsum("bqs,qsdr->bqdr", (mixing_weights, poly_lora_A))
                B = torch.einsum("bqs,qsrd->bqrd", (mixing_weights, poly_lora_B))
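
                # collapse the split dimension so the mixed low-rank update can be
                # applied with two batched matmuls per example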
                A = A.reshape(bs, self.in_features, r)
                B = B.transpose(1, 2).reshape(bs, r, self.out_features)

                x = x.to(A.dtype)
                result += x.bmm(A).bmm(B) / r

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "poly." + rep
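

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of wrapping an nn.Linear with the Poly adapter layer; the
# config values below are arbitrary assumptions chosen for demonstration.
# Because of the relative imports above, this only runs in package context,
# e.g. `python -m peft.tuners.poly.layer`.
if __name__ == "__main__":
    base = nn.Linear(16, 32)
    config = PolyConfig(r=4, n_tasks=2, n_skills=4, n_splits=1)
    layer = Linear(base, adapter_name="default", poly_config=config)

    x = torch.randn(8, 10, 16)  # (batch, seq_len, in_features)
    task_ids = torch.randint(0, 2, (8,))  # one task id per example in the batch
    out = layer(x, task_ids=task_ids)
    print(out.shape)  # expected: torch.Size([8, 10, 32])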