# MultiModal / peft_.py
# Parameter-efficient fine-tuning modules: LoRA, Adapter, Prefix/Prompt tuning, (IA)^3.
import torch
import torch.nn as nn
import math
class LoRALayer(nn.Module):
    """Low-Rank Adaptation (LoRA) layer.

    Learns a low-rank update ``x @ A @ B`` scaled by ``alpha / rank``; the
    caller adds this update to the output of a frozen base projection.

    Args:
        in_features: input feature dimension.
        out_features: output feature dimension.
        rank: rank of the low-rank decomposition.
        alpha: scaling numerator; the effective scale is ``alpha / rank``.
        dropout: dropout probability applied to the *input* before the
            low-rank projection (matching the reference LoRA implementation).
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # A gets Kaiming init, B stays zero, so A @ B == 0 and the adapter
        # starts as an exact no-op.
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        # Kept for wrappers that track merge state; unused inside this class.
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``.

        Fix: dropout is applied to the input rather than to the already
        projected output, so the rank-``r`` projection sees the regularized
        activations (as in the reference LoRA implementation).
        """
        return (self.dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
class LinearWithLoRA(nn.Module):
    """Linear layer with an optional LoRA adapter.

    Wraps ``nn.Linear`` and, when ``use_lora`` is set, adds a trainable
    low-rank update. The update can be folded into the base weight for
    inference (``merge``) and removed again (``unmerge``).

    Args:
        in_features: input feature dimension.
        out_features: output feature dimension.
        bias: whether the base linear layer has a bias term.
        use_lora: attach a LoRA adapter when True.
        lora_rank: rank of the LoRA decomposition.
        lora_alpha: LoRA scaling numerator (scale = alpha / rank).
        lora_dropout: dropout used inside the LoRA branch.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_lora = use_lora
        self.base_linear = nn.Linear(in_features, out_features, bias=bias)
        if use_lora:
            self.lora = LoRALayer(
                in_features,
                out_features,
                lora_rank,
                lora_alpha,
                lora_dropout
            )
        else:
            self.lora = None
        self.merged = False

    def _delta_weight(self) -> torch.Tensor:
        """Return the LoRA weight delta in ``nn.Linear`` layout (out, in)."""
        return ((self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling).T

    def merge(self):
        """Fold the LoRA update into the base weight.

        After merging, ``forward`` skips the LoRA branch; intended for
        inference (the merged delta does not include LoRA dropout).
        """
        if self.use_lora and not self.merged:
            with torch.no_grad():
                self.base_linear.weight += self._delta_weight()
            self.merged = True

    def unmerge(self):
        """Subtract the LoRA update from the base weight, undoing ``merge``."""
        if self.use_lora and self.merged:
            with torch.no_grad():
                self.base_linear.weight -= self._delta_weight()
            self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base projection plus the (unmerged) LoRA update."""
        output = self.base_linear(x)
        # Once merged, the delta already lives in the base weight; applying
        # the branch again would double-count it.
        if self.use_lora and self.lora is not None and not self.merged:
            output = output + self.lora(x)
        return output
class AdapterLayer(nn.Module):
    """Bottleneck adapter for lightweight fine-tuning.

    Pre-norm bottleneck: RMSNorm -> down-projection -> activation -> dropout
    -> up-projection -> dropout, added back to the input through a scaled
    residual connection.

    Args:
        dim: model (input/output) dimension.
        bottleneck_dim: hidden width of the bottleneck.
        dropout: dropout probability used at both bottleneck stages.
        activation: one of 'gelu', 'relu', 'silu'; anything else silently
            falls back to GELU.
        residual_scale: multiplier on the adapter output before the
            residual addition.
    """
    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.1,
        activation: str = 'gelu',
        residual_scale: float = 1.0
    ):
        super().__init__()
        self.residual_scale = residual_scale
        self.down_proj = nn.Linear(dim, bottleneck_dim)
        # Dict lookup replaces the if/elif chain; unknown names -> GELU.
        act_cls = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}.get(activation, nn.GELU)
        self.activation = act_cls()
        self.up_proj = nn.Linear(bottleneck_dim, dim)
        self.dropout = nn.Dropout(dropout)
        from components import RMSNorm
        self.layer_norm = RMSNorm(dim)
        self._init_weights()

    def _init_weights(self):
        """Kaiming init for the down projection; zeros elsewhere so the
        adapter starts as an identity mapping."""
        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        for bias in (self.down_proj.bias, self.up_proj.bias):
            if bias is not None:
                nn.init.zeros_(bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bottleneck transform and add the scaled residual."""
        normed = self.layer_norm(x)
        hidden = self.dropout(self.activation(self.down_proj(normed)))
        projected = self.dropout(self.up_proj(hidden))
        return x + projected * self.residual_scale
class PrefixTuning(nn.Module):
    """Prefix tuning: learned per-layer key/value prefix tokens.

    Stores one parameter of shape
    ``(num_layers, 2, num_tokens, num_heads, head_dim)`` where the ``2``
    axis holds the key and value prefixes.

    Args:
        num_layers: number of transformer layers to provide prefixes for.
        num_tokens: number of prefix tokens per layer.
        dim: model dimension; assumed divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    def __init__(
        self,
        num_layers: int,
        num_tokens: int,
        dim: int,
        num_heads: int
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_tokens = num_tokens
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.prefix = nn.Parameter(
            torch.randn(num_layers, 2, num_tokens, num_heads, head_dim)
        )
        nn.init.normal_(self.prefix, std=0.02)

    def forward(self, layer_idx: int, batch_size: int) -> torch.Tensor:
        """Return the key/value prefix for one layer, broadcast over batch.

        Returns:
            Tensor of shape ``(2, batch_size, num_heads, num_tokens, head_dim)``.

        Fix: the stored layout is ``(2, num_tokens, num_heads, head_dim)``;
        the original expanded it directly to
        ``(2, batch, num_heads, num_tokens, head_dim)``, mismatching the
        token/head axes and raising unless ``num_tokens == num_heads``.
        We swap the token and head axes before broadcasting.
        """
        # (2, num_tokens, num_heads, head_dim) -> (2, num_heads, num_tokens, head_dim)
        prefix = self.prefix[layer_idx].permute(0, 2, 1, 3)
        # Insert a batch axis and broadcast without copying.
        return prefix.unsqueeze(1).expand(
            2, batch_size, self.num_heads, self.num_tokens, -1
        )
class PromptTuning(nn.Module):
    """Prompt tuning: a trainable soft prompt broadcast over the batch.

    Args:
        num_tokens: number of soft-prompt tokens.
        dim: embedding dimension.
        init_from_vocab: when True (and ``vocab_embeddings`` is given),
            initialize the prompt from randomly sampled vocabulary rows
            instead of a normal distribution.
        vocab_embeddings: embedding table to sample from.
    """
    def __init__(
        self,
        num_tokens: int,
        dim: int,
        init_from_vocab: bool = False,
        vocab_embeddings: nn.Embedding = None
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.prompt_embeddings = nn.Parameter(torch.randn(num_tokens, dim))
        if init_from_vocab and vocab_embeddings is not None:
            # Seed the prompt with randomly chosen vocabulary rows.
            picks = torch.randint(0, vocab_embeddings.num_embeddings, (num_tokens,))
            self.prompt_embeddings.data = vocab_embeddings.weight[picks].clone()
        else:
            nn.init.normal_(self.prompt_embeddings, std=0.02)

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return the prompt broadcast to ``(batch_size, num_tokens, dim)``."""
        batch_view = self.prompt_embeddings[None, :, :]
        return batch_view.expand(batch_size, -1, -1)
class IALayer(nn.Module):
    """(IA)^3: learned element-wise rescaling of activations.

    Holds a single per-feature scale vector, initialized to ones so the
    layer starts as an identity mapping.
    """
    def __init__(self, dim: int):
        super().__init__()
        # One learnable multiplier per feature.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` element-wise along the last dimension."""
        return self.scale * x