| import torch |
| from torch import nn |
|
|
|
|
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every ``nn.Linear`` under ``module`` with a ``LinearLora``.

    The replacement is done in place via ``setattr`` on the parent module. The
    new layer reuses the original layer's ``weight`` and ``bias`` Parameter
    objects (they are shared, not copied) and adds low-rank adapters of rank
    ``max_rank`` (``LinearLora`` clamps it to ``min(in_features, out_features)``).

    Children that are already ``LinearLora`` are left untouched and are not
    descended into. Calling this function twice therefore keeps previously
    trained or loaded adapter weights instead of wrapping the layer again.

    Args:
        module: Root module to modify in place.
        max_rank: Requested adapter rank for every replaced layer.
        scale: Multiplier applied to each adapter's output.
    """
    for name, child in module.named_children():
        if isinstance(child, LinearLora):
            # Already adapted: re-wrapping would discard its lora_A/lora_B weights.
            continue

        if not isinstance(child, nn.Linear):
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )
            continue

        new_lora = LinearLora(
            in_features=child.in_features,
            out_features=child.out_features,
            bias=child.bias is not None,
            rank=max_rank,
            scale=scale,
            dtype=child.weight.dtype,
            device=child.weight.device,
        )

        # Share the existing parameters rather than keeping the freshly
        # initialized ones created by the LinearLora constructor.
        new_lora.weight = child.weight
        new_lora.bias = child.bias

        # Replacing the value of an existing key does not resize the dict that
        # named_children() is iterating, so this is safe inside the loop.
        setattr(module, name, new_lora)
|
|
|
|
class LinearLora(nn.Linear):
    """``nn.Linear`` with an additive low-rank (LoRA) adapter.

    ``forward(x)`` returns ``linear(x) + scale * lora_B(lora_A(x))``, where:

    - ``lora_A`` projects ``in_features -> rank`` and has no bias.
    - ``lora_B`` projects ``rank -> out_features`` and has a bias iff ``lora_bias``.

    ``lora_B`` is zero-initialized (weight and bias), so a freshly constructed
    layer produces exactly the base ``nn.Linear`` output until adapter weights
    are trained or loaded.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: "bool | torch.Tensor | None",
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base linear layer and the two adapter projections.

        Args:
            in_features: Input size, as for ``nn.Linear``.
            out_features: Output size, as for ``nn.Linear``.
            bias: Whether the base layer has a bias. A ``bool`` is used as-is.
                Any other value means "has bias" unless it is ``None``; this
                lets callers pass an existing layer's ``bias`` tensor.
            rank: Requested adapter rank, clamped to
                ``min(in_features, out_features)``.
            dtype: Forwarded to the base layer and both adapters.
            device: Forwarded to the base layer and both adapters.
            lora_bias: Whether ``lora_B`` has a bias term.
            scale: Multiplier for the adapter output (int or float, stored as float).
            *args: Forwarded to ``nn.Linear.__init__``.
            **kwargs: Forwarded to ``nn.Linear.__init__``.

        Raises:
            TypeError: If ``scale`` is not an int or float.
        """
        # Honor an explicit bool; otherwise treat the value as an optional tensor.
        has_bias = bias if isinstance(bias, bool) else bias is not None
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        self.scale = self._validate_scale(scale)
        # A rank above min(in, out) adds parameters without adding expressiveness.
        self.rank = min(rank, self.out_features, self.in_features)
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
        # Zero-init B so the adapter's initial contribution is exactly zero and
        # wrapping an existing layer does not change its output.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)

    @staticmethod
    def _validate_scale(scale: float) -> float:
        """Return ``scale`` as a float; raise ``TypeError`` for non-numbers (incl. bool)."""
        if isinstance(scale, bool) or not isinstance(scale, (int, float)):
            raise TypeError(
                f"scale must be an int or float, got {type(scale).__name__}"
            )
        return float(scale)

    def set_scale(self, scale: float) -> None:
        """Set the multiplier applied to the adapter output.

        Raises:
            TypeError: If ``scale`` is not an int or float.
        """
        self.scale = self._validate_scale(scale)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the base projection plus the scaled low-rank update."""
        base_out = super().forward(input)
        lora_update = self.lora_B(self.lora_A(input)) * self.scale
        return base_out + lora_update
|
|