import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRALinear(nn.Module):
"""
LoRA 线性层:直接持有 weight/bias,保持与 nn.Linear 相同的 state_dict key 结构。
state_dict 结构:
- weight: 原始权重(与 nn.Linear 一致)
- bias: 原始偏置(与 nn.Linear 一致)
- lora_A: LoRA 低秩矩阵 A
- lora_B: LoRA 低秩矩阵 B
这样设计的好处:加载预训练权重时无需做 key 转换。
"""
def __init__(
self,
base: nn.Linear,
r: int,
alpha: float = 1.0,
dropout: float = 0.0,
):
super().__init__()
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
self.in_features = base.in_features
self.out_features = base.out_features
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
        # Store scaling in a buffer so that changing its value in place does
        # not trigger torch.compile recompilation.
        # persistent=False keeps it out of the state_dict, so loading a
        # checkpoint saved from nn.Linear does not report a missing key.
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
        # Hold weight and bias directly (transferred from the original Linear).
        self.weight = base.weight
        self.bias = base.bias  # may be None
        # LoRA parameters
if r > 0:
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
else:
self.register_parameter("lora_A", None)
self.register_parameter("lora_B", None)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base linear computation
result = F.linear(x, self.weight, self.bias)
if self.r <= 0 or self.lora_A is None:
return result
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
return result + self.dropout(lora_out) * self.scaling
def reset_lora_parameters(self):
"""重置 LoRA 参数到初始状态"""
if self.r > 0 and self.lora_A is not None:
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def set_enabled(self, enabled: bool):
"""启用/禁用 LoRA(通过 scaling 控制,兼容 torch.compile)"""
# 使用 fill_ 原地修改 buffer 值,不会触发重编译
self.scaling.fill_(self._base_scaling if enabled else 0.0)
@property
def enabled(self) -> bool:
return self.scaling.item() != 0.0
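# Illustrative sketch, not part of the original module: it shows that
# LoRALinear keeps the same "weight"/"bias" state_dict keys as the wrapped
# nn.Linear (plus "lora_A"/"lora_B"), and that set_enabled(False) falls back
# to the base layer's output. The name _demo_state_dict_compat is hypothetical.
def _demo_state_dict_compat():
    base = nn.Linear(8, 4)
    lora = LoRALinear(base, r=2, alpha=4.0)
    # Base keys are unchanged, so an nn.Linear checkpoint loads without remapping.
    assert set(lora.state_dict()) == {"weight", "bias", "lora_A", "lora_B"}
    lora.load_state_dict(base.state_dict(), strict=False)
    x = torch.randn(3, 8)
    # lora_B starts at zero, so the initial output equals the base layer's.
    assert torch.allclose(lora(x), base(x))
    # Disabling via the scaling buffer also reproduces the base output.
    lora.set_enabled(False)
    assert torch.allclose(lora(x), base(x))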
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
"""
根据类似 'layers.0.self_attn.q_proj' 的全名,返回 parent module(即 q_proj 的上一级)。
"""
parts = name.split(".")
if len(parts) == 1:
return root
parent = root
for p in parts[:-1]:
if not hasattr(parent, p):
return None
parent = getattr(parent, p)
return parent
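# Example (illustrative): for a path like "layers.0.self_attn.q_proj",
# _get_parent_module returns the self_attn module, so the caller can do
# setattr(parent, "q_proj", new_layer). nn.Module.get_submodule resolves the
# same dotted paths but raises AttributeError instead of returning None.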
def apply_lora_to_named_linear_modules(
root: nn.Module,
*,
target_submodule_names: list[str],
r: int,
alpha: float,
dropout: float,
) -> None:
"""
在给定模块及其子模块中,对名字以 target_submodule_names 结尾的 Linear 层注入 LoRA。
例如 target_submodule_names=["q_proj", "v_proj"] 时,
会在所有名为 *.q_proj / *.v_proj 的 nn.Linear 上替换为 LoRALinear。
"""
for full_name, module in list(root.named_modules()):
if not isinstance(module, nn.Linear):
continue
short_name = full_name.split(".")[-1]
if short_name not in target_submodule_names:
continue
parent = _get_parent_module(root, full_name)
if parent is None:
continue
        # Replace the original Linear with a LoRALinear
lora_layer = LoRALinear(
base=module,
r=r,
alpha=alpha,
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
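# Usage sketch under assumptions not in the original file: ToyAttention is a
# made-up module with q_proj/k_proj/v_proj Linears, and the freezing loop
# reflects common LoRA practice (train only lora_A/lora_B) rather than
# anything this module enforces.
if __name__ == "__main__":
    class ToyAttention(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.q_proj(x) + self.k_proj(x) + self.v_proj(x)

    model = ToyAttention()
    apply_lora_to_named_linear_modules(
        model,
        target_submodule_names=["q_proj", "v_proj"],
        r=4,
        alpha=8.0,
        dropout=0.0,
    )
    # Freeze everything except the injected LoRA matrices.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name
    out = model(torch.randn(2, 16))
    print(out.shape)
    print([n for n, p in model.named_parameters() if p.requires_grad])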