import torch
import torch.nn as nn
import math
from typing import Optional


class LoRALayer(nn.Module):
    """Low-Rank Adaptation (LoRA) layer.

    Learns a low-rank update ``x @ A @ B`` scaled by ``alpha / rank``;
    the caller adds this on top of a frozen base projection.

    Args:
        in_features: input dimension of the adapted projection.
        out_features: output dimension of the adapted projection.
        rank: rank of the low-rank decomposition.
        alpha: scaling numerator; effective scale is ``alpha / rank``.
        dropout: dropout probability applied to the input.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling keeps the update magnitude rank-independent.
        self.scaling = alpha / rank

        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # A gets Kaiming init, B stays zero -> the adapter starts as a no-op.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``.

        Fix: dropout is applied to the *input* (as in the reference LoRA
        implementation), not to the low-rank output. Identical in eval mode.
        """
        return (self.dropout(x) @ self.lora_A @ self.lora_B) * self.scaling


class LinearWithLoRA(nn.Module):
    """Linear layer with an optional LoRA adapter.

    The LoRA update can be merged into the base weight for inference
    (``merge``) and split back out (``unmerge``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_lora = use_lora

        self.base_linear = nn.Linear(in_features, out_features, bias=bias)

        if use_lora:
            self.lora = LoRALayer(
                in_features, out_features, lora_rank, lora_alpha, lora_dropout
            )
        else:
            self.lora = None
        self.merged = False

    def merge(self):
        """Fold the LoRA update into the base weight (inference speed-up)."""
        if self.use_lora and not self.merged:
            lora_weight = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            # base weight is (out, in); the LoRA product is (in, out).
            self.base_linear.weight.data += lora_weight.T
            self.merged = True

    def unmerge(self):
        """Undo :meth:`merge`, restoring the original base weight."""
        if self.use_lora and self.merged:
            lora_weight = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            self.base_linear.weight.data -= lora_weight.T
            self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base projection plus the LoRA update (skipped when merged)."""
        output = self.base_linear(x)
        if self.use_lora and self.lora is not None and not self.merged:
            output = output + self.lora(x)
        return output


class AdapterLayer(nn.Module):
    """Bottleneck adapter layer for lightweight fine-tuning.

    Pre-norm down-projection -> activation -> up-projection with a scaled
    residual connection. The up-projection is zero-initialized so the
    adapter starts as an identity mapping.

    Args:
        dim: model (input/output) dimension.
        bottleneck_dim: hidden dimension of the bottleneck.
        dropout: dropout probability (applied twice inside the bottleneck).
        activation: one of ``'gelu'``, ``'relu'``, ``'silu'``; unknown
            names fall back to GELU (preserves original behavior).
        residual_scale: multiplier on the adapter output before the
            residual addition.
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.1,
        activation: str = 'gelu',
        residual_scale: float = 1.0
    ):
        super().__init__()
        self.residual_scale = residual_scale

        self.down_proj = nn.Linear(dim, bottleneck_dim)
        # Dispatch table instead of an if/elif chain; same GELU fallback.
        activations = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}
        self.activation = activations.get(activation, nn.GELU)()
        self.up_proj = nn.Linear(bottleneck_dim, dim)
        self.dropout = nn.Dropout(dropout)

        # Project-local norm; imported lazily so the module can be imported
        # without `components` being on the path until an adapter is built.
        from components import RMSNorm
        self.layer_norm = RMSNorm(dim)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init the down-projection; zero the up-projection/biases."""
        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        if self.down_proj.bias is not None:
            nn.init.zeros_(self.down_proj.bias)
        if self.up_proj.bias is not None:
            nn.init.zeros_(self.up_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-norm bottleneck transform with a scaled residual connection."""
        residual = x
        x = self.layer_norm(x)
        x = self.down_proj(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.up_proj(x)
        x = self.dropout(x)
        return residual + x * self.residual_scale


class PrefixTuning(nn.Module):
    """Prefix tuning: learned per-layer key/value prefix tokens.

    Stores one ``(2, num_tokens, num_heads, head_dim)`` prefix (key and
    value) per transformer layer.
    """

    def __init__(
        self,
        num_layers: int,
        num_tokens: int,
        dim: int,
        num_heads: int
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_tokens = num_tokens
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Layout: (layers, k/v, tokens, heads, head_dim).
        self.prefix = nn.Parameter(
            torch.randn(num_layers, 2, num_tokens, num_heads, head_dim)
        )
        nn.init.normal_(self.prefix, std=0.02)

    def forward(self, layer_idx: int, batch_size: int) -> torch.Tensor:
        """Return the prefix for one layer, broadcast over the batch.

        Returns:
            Tensor of shape ``(2, batch_size, num_heads, num_tokens, head_dim)``.

        Fix: the stored slice is (2, tokens, heads, head_dim); the original
        code expanded it with the heads/tokens axes swapped, which raised a
        RuntimeError whenever ``num_tokens != num_heads``. Permute to
        heads-first before broadcasting over the batch dimension.
        """
        prefix = self.prefix[layer_idx]          # (2, tokens, heads, head_dim)
        prefix = prefix.permute(0, 2, 1, 3)      # (2, heads, tokens, head_dim)
        prefix = prefix.unsqueeze(1).expand(
            2, batch_size, self.num_heads, self.num_tokens, -1
        )
        return prefix


class PromptTuning(nn.Module):
    """Prompt tuning: learned soft-prompt embeddings prepended to the input.

    Args:
        num_tokens: number of soft-prompt tokens.
        dim: embedding dimension.
        init_from_vocab: if True (and ``vocab_embeddings`` is given),
            initialize prompts from randomly chosen vocabulary embeddings.
        vocab_embeddings: embedding table to sample initial prompts from.
    """

    def __init__(
        self,
        num_tokens: int,
        dim: int,
        init_from_vocab: bool = False,
        vocab_embeddings: Optional[nn.Embedding] = None
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.prompt_embeddings = nn.Parameter(torch.randn(num_tokens, dim))
        if init_from_vocab and vocab_embeddings is not None:
            indices = torch.randint(
                0, vocab_embeddings.num_embeddings, (num_tokens,)
            )
            self.prompt_embeddings.data = vocab_embeddings.weight[indices].clone()
        else:
            nn.init.normal_(self.prompt_embeddings, std=0.02)

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return prompt embeddings of shape (batch_size, num_tokens, dim)."""
        return self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)


class IALayer(nn.Module):
    """(IA)^3 layer: a single learned element-wise scaling vector."""

    def __init__(self, dim: int):
        super().__init__()
        # Initialized to ones so the layer starts as an identity mapping.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the last dimension of ``x`` element-wise."""
        return x * self.scale