"""
This module contains the implementation of the LoRA-FA optimizer.
"""

from __future__ import annotations

import math
from collections.abc import Iterable
from typing import Callable

import torch
import torch.nn as nn
from accelerate.utils.imports import is_bf16_available
from torch import autocast
from torch.optim import Optimizer

from ..peft_model import PeftModel
from ..utils.other import infer_device


class LoraFAOptimizer(Optimizer):
    """
    Implements the LoRA-FA optimizer, designed specifically for training Low-Rank Adaptation (LoRA) parameters
    efficiently. Note that LoraFAOptimizer is based on adamw-hf in transformers, with only the LoRA part modified;
    without LoRA it falls back to adamw-hf. Each parameter group is additionally expected to carry "names" (the
    parameter names) and "scaling_factor" (the LoRA scaling), as set up by `create_lorafa_optimizer`.

    Args:
        params (Iterable[nn.parameter.Parameter]): Parameters to optimize.
        lr (float, optional): Learning rate (default: 1e-3).
        betas (Tuple[float, float], optional):
            Coefficients for computing running averages of gradient and squared gradient (default: (0.9, 0.999)).
        eps (float, optional): Term added to the denominator to improve numerical stability (default: 1e-6).
        weight_decay (float, optional): Weight decay (L2 penalty) (default: 0.0).
        correct_bias (bool, optional): Whether to apply bias correction as in the original Adam (default: True).

    Args in sub-function step:
        closure (Callable, optional): A closure that reevaluates the model and returns the loss.

    Reference:
        - LoRA-FA: https://huggingface.co/papers/2308.03303
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            scaling_factor = group["scaling_factor"]
            param_list = []
            name_list = []
            for p, n in zip(group["params"], group["names"]):
                # skip non-LoRA parameters that have no gradient; LoRA parameters (including the
                # frozen lora_A) are still processed below
                if "lora" not in n and p.grad is None:
                    continue
                grad = p.grad
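                # LoRA parameters are collected in pairs: lora_A is expected first, then lora_B of
                # the same layer; the shared optimizer state is keyed by their common name prefix
                # ending in "lora"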
                if "lora" in n:
                    param_list.append(p)
                    name_list.append(n)
                    if len(param_list) == 2:
                        name = n[: n.find("lora")] + "lora"
                    elif len(param_list) == 1:
                        continue
                else:
                    name = n

                state = self.state[name]
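                # Lazily initialize the Adam state on first use. For a lora_A/lora_B pair only the
                # moments of lora_B are tracked (lora_A stays frozen); all other parameters get the
                # usual exp_avg / exp_avg_sq buffers.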
                if len(state) == 0:
                    if len(param_list) == 2:
                        state["step"] = 0
                        state["exp_avg_B"] = torch.zeros_like(param_list[1])
                        state["exp_avg_sq_B"] = torch.zeros_like(param_list[1])
                    else:
                        state["step"] = 0
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)

                if len(param_list) == 2:
                    A = param_list[0]
                    B = param_list[1]
                    grad_B_orin = B.grad
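                    # LoRA-FA update for lora_B: project the gradient with the (regularized)
                    # pseudo-inverse of A @ A.T and rescale by 1 / scaling_factor**2,
                    # i.e. grad_B = (1 / s**2) * grad_B_orin @ (A @ A.T + delta * I)^-1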
                    delta = 1e-8
                    AA_T = A @ A.T
                    AA_T_inv = torch.linalg.pinv(AA_T + delta * torch.eye(A.shape[0]).to(A.device))

                    device_type = infer_device()
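                    # run the projection under bfloat16 autocast when available; otherwise compute it
                    # in the ambient precision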
                    if is_bf16_available():
                        with autocast(device_type=device_type, dtype=torch.bfloat16):
                            grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)
                    else:
                        grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)

                    if grad_B.dtype != B.grad.dtype:
                        grad_B = grad_B.to(B.grad.dtype)

                    exp_avg_B, exp_avg_sq_B = state["exp_avg_B"], state["exp_avg_sq_B"]
                    beta1, beta2 = group["betas"]
                    state["step"] += 1
                    exp_avg_B.mul_(beta1).add_(grad_B, alpha=(1.0 - beta1))
                    exp_avg_sq_B.mul_(beta2).addcmul_(grad_B, grad_B, value=1.0 - beta2)

                    denom_B = exp_avg_sq_B.sqrt().add_(group["eps"])
                    step_size = group["lr"]
                    if group["correct_bias"]:
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
                    B.addcdiv_(exp_avg_B, denom_B, value=-step_size)
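                    # weight decay is applied directly to lora_B (decoupled, as in AdamW)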
                    if group["weight_decay"] > 0.0:
                        B.add_(B, alpha=(-group["lr"] * group["weight_decay"]))
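                    # this lora_A/lora_B pair is done; reset the buffers for the next pair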
                    param_list = []
                    name_list = []

                else:
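                    # non-LoRA parameters fall back to the standard adamw-hf update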
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]

                    state["step"] += 1

                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p.addcdiv_(exp_avg, denom, value=-step_size)

                    if group["weight_decay"] > 0.0:
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


def create_lorafa_optimizer(
    model: PeftModel, r: int, lora_alpha: int, lr: float, weight_decay: float = 0.0, use_rslora: bool = False
) -> Optimizer:
    """
    Helper function to instantiate a LoRA-FA optimizer configured for a given model trained with the LoRA method.

    This function will:
    - Disable gradient updates for the "lora_A" parameters (these are kept frozen during LoRA-FA training).
    - Compute the scaling factor from the provided `lora_alpha` and rank `r` for proper gradient projection.
    - Create and configure the parameter groups for the optimizer, including the specified learning rate, weight
      decay, and additional optimizer options.

    For hyper-parameters, LoRA-FA uses the same hyper-parameters as AdamW, except for the LoRA hyper-parameters (r,
    lora_alpha, use_rslora). One can use the same lr and weight_decay values as for AdamW in LoRA tuning.

    Args:
        model (PeftModel): The model containing LoRA-adapted parameters.
        r (int): Rank of the LoRA decomposition.
        lora_alpha (int): Scaling factor for the LoRA parameterization.
        lr (float): Learning rate for optimizer updates.
        weight_decay (float): Weight decay for AdamW.
        use_rslora (bool):
            Whether to use rsLoRA. With rsLoRA, the LoRA scaling factor becomes lora_alpha / math.sqrt(r) instead of
            lora_alpha / r.

    Returns:
        Optimizer: Configured LoRA-FA optimizer instance ready for training.
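
    Example (a minimal usage sketch; `base_model`, the dataset, and the exact `LoraConfig` values are placeholders
    rather than part of this module, and the import path of `create_lorafa_optimizer` is assumed here):

        ```python
        from peft import LoraConfig, get_peft_model
        from peft.optimizers import create_lorafa_optimizer

        # `base_model` stands in for a transformers model loaded elsewhere
        lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
        peft_model = get_peft_model(base_model, lora_config)
        # r and lora_alpha must match the LoraConfig so the gradient projection uses the right scaling
        optimizer = create_lorafa_optimizer(model=peft_model, r=16, lora_alpha=32, lr=7e-5)
        # the optimizer can then be used in a standard training loop or passed to a Trainer
        ```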
    """
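    # freeze lora_A so that only lora_B (and any other trainable parameters) are updated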
    for name, param in model.named_parameters():
        if "lora_A" in name:
            param.requires_grad_(False)
    lora_scaling = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    param_groups = [
        {
            "params": model.parameters(),
            "lr": lr,
            "names": [name for name, _ in model.named_parameters()],
            "scaling_factor": lora_scaling,
            "betas": (0.9, 0.999),
            "weight_decay": weight_decay,
        }
    ]
    return LoraFAOptimizer(param_groups)