import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml


class LoRAConfig:
    def __init__(self, config_file):
        # Load the YAML configuration file
        with open(config_file, "r") as file:
            config = yaml.safe_load(file)
        # Set class attributes based on the loaded YAML config
        for key, value in config.items():
            setattr(self, key, value)
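
# Example (illustrative, not from the original repo): the code below expects the
# YAML file to define at least the attributes read by modify_with_lora. A minimal
# config might look like:
#
#     lora_modules: ".*attention"
#     lora_layers: "q|k|v|o"
#     lora_rank: 4
#     lora_scaling_rank: 0
#     lora_init_scale: 0.01
#
# loaded with: config = LoRAConfig("lora.yaml")
# (the file name and the regex/value choices here are assumptions.)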

class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # parsimonious implementation for ia3 and lora scaling
            if self.multi_lora_a.requires_grad:
                hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        else:
            # general implementation for lora (adding and scaling)
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )
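
# Note on forward(): the general branch materializes an adapted weight
#     W' = W * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank
# i.e. an (IA)^3-style multiplicative rescaling followed by the additive low-rank
# LoRA update. A quick sanity-check sketch (assuming rank > 0, scaling_rank = 0,
# and init_scale >= 0, so lora_b starts at zero and the update is a no-op at
# initialization):
#
#     base = nn.Linear(8, 8)
#     wrapped = LoRALinear(base, rank=4, scaling_rank=0, init_scale=0.02)
#     x = torch.randn(2, 8)
#     assert torch.allclose(base(x), wrapped(x))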

def modify_with_lora(transformer, config):
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer
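
# Minimal smoke-test sketch of modify_with_lora. The toy model and the regexes
# below are illustrative assumptions, not part of the original code: any model
# whose submodule names match config.lora_modules and whose child nn.Linear
# names match config.lora_layers would work the same way.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyAttention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.query = nn.Linear(dim, dim)
            self.value = nn.Linear(dim, dim)

        def forward(self, x):
            return self.value(self.query(x))

    class ToyModel(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.attention = ToyAttention(dim)

        def forward(self, x):
            return self.attention(x)

    config = SimpleNamespace(
        lora_modules=".*attention",  # modules whose children get wrapped
        lora_layers="query|value",   # child Linear layers to replace
        lora_rank=4,
        lora_scaling_rank=0,
        lora_init_scale=0.01,
    )
    model = modify_with_lora(ToyModel(), config)
    print(model)  # attention.query and attention.value are now LoRALinear
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])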