| """ |
| 扩散模型 |
| 轻量级Transformer用于噪声预测 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple |
|
|
| from embedding import SinusoidalTimeEmbedding |
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Applies ``Linear(d_model -> d_ff)``, GELU, ``Linear(d_ff -> d_model)``
    and finally dropout to the last dimension of the input.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Size of the hidden (expanded) feature dimension.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Expansion and contraction projections; the attribute names are part
        # of the state_dict layout.
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(w2(gelu(w1(x))))``."""
        hidden = self.w1(x)
        activated = F.gelu(hidden)
        projected = self.w2(activated)
        return self.dropout(projected)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature dimension of the inputs and the output.
        n_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head feature size

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape ``[batch, len, d_model]`` to ``[batch, n_heads, len, d_k]``."""
        return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries, ``[batch, q_len, d_model]``.
            k: Keys, ``[batch, k_len, d_model]``.
            v: Values, ``[batch, k_len, d_model]``.
            mask: Optional mask that must be broadcastable to the score
                tensor ``[batch, n_heads, q_len, k_len]``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``[batch, q_len, d_model]``.
        """
        batch_size = q.size(0)

        # Project, then split the feature dimension across heads.
        q = self._split_heads(self.w_q(q), batch_size)
        k = self._split_heads(self.w_k(k), batch_size)
        v = self._split_heads(self.w_v(v), batch_size)

        # Scaled dot-product scores: [batch, n_heads, q_len, k_len].
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value rather than -inf:
            # a query row whose keys are all masked would otherwise become
            # softmax(-inf, ..., -inf) == NaN and poison the whole network.
            # Partially masked rows are unaffected (exp underflows to 0).
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back together.
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.w_o(out)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention followed by a feed-forward net.

    Each sub-layer is applied as ``x + dropout(sublayer(layer_norm(x)))``.

    Args:
        d_model: Feature dimension of the block input and output.
        n_heads: Number of attention heads.
        d_ff: Hidden dimension of the feed-forward network.
        dropout: Dropout probability used by both sub-layers and on the
            residual branches.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the block.

        Args:
            x: Input tensor whose last dimension is ``d_model``.
            mask: Optional attention mask, forwarded unchanged to the
                attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalise once and reuse the result for query, key and value;
        # evaluating the LayerNorm three times gives the same numbers at
        # three times the cost.
        normed = self.norm1(x)
        x = x + self.dropout(self.attn(normed, normed, normed, mask))

        # NOTE(review): FeedForward already ends with its own dropout, so this
        # branch is dropped out twice -- confirm that this is intended.
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x
|
|
|
|
class NoisePredictor(nn.Module):
    """Noise-prediction network for the diffusion model.

    Takes the noised embeddings ``x_t`` together with the timestep ``t`` and
    returns the predicted noise, using a stack of ``TransformerBlock`` layers.

    Args:
        d_model: Feature dimension of the embeddings.
        n_heads: Number of attention heads per block.
        n_layers: Number of Transformer blocks.
        d_ff: Hidden dimension of each feed-forward network.
        max_len: Currently unused; accepted for interface compatibility.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 4,
        n_layers: int = 4,
        d_ff: int = 512,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model

        # Timestep conditioning: sinusoidal features refined by a small MLP.
        self.time_embedding = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

        # Transformer trunk.
        # NOTE(review): no positional encoding is added anywhere in this
        # class; presumably x_t already carries position information -- confirm.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Output head.
        self.output_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        # Re-initialise every Linear / LayerNorm created above.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear weights ~ N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @staticmethod
    def _expand_padding_mask(
        mask: Optional[torch.Tensor],
        batch_size: int,
        seq_len: int,
    ) -> Optional[torch.Tensor]:
        """Turn a ``[batch, seq_len]`` mask into ``[batch, 1, 1, seq_len]``.

        Any other mask (including ``None``) is returned unchanged.
        """
        if (
            mask is not None
            and mask.dim() == 2
            and tuple(mask.shape) == (batch_size, seq_len)
        ):
            return mask[:, None, None, :]
        return mask

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict the noise contained in ``x_t``.

        Args:
            x_t: ``[batch, seq_len, d_model]`` noised embeddings.
            t: ``[batch]`` timesteps.
            mask: Optional ``[batch, seq_len]`` mask; positions equal to 0 are
                not attended to. A mask that does not have exactly that shape
                is forwarded to the blocks unchanged.

        Returns:
            ``[batch, seq_len, d_model]`` predicted noise.
        """
        batch_size, seq_len, _ = x_t.shape

        # Attention scores are [batch, heads, q_len, k_len]. A raw
        # [batch, seq_len] mask does not broadcast against that (or, when
        # batch == seq_len, is silently misread as [q_len, k_len]), so lift
        # it to [batch, 1, 1, seq_len] first.
        mask = self._expand_padding_mask(mask, batch_size, seq_len)

        # Timestep conditioning; the embedding is assumed to come back as
        # [batch, d_model] -- it is broadcast over the sequence dimension.
        t_emb = self.time_mlp(self.time_embedding(t))
        x = x_t + t_emb.unsqueeze(1)

        for layer in self.layers:
            x = layer(x, mask)

        return self.output_proj(self.output_norm(x))
|
|
|
|
class DualNoisePredictor(nn.Module):
    """Bilingual noise predictor.

    A single Transformer trunk is shared between Chinese (``"zh"``) and
    English (``"en"``); each language has its own input and output
    projection.

    Args:
        d_model: Feature dimension of the embeddings.
        n_heads: Number of attention heads per block.
        n_layers: Number of Transformer blocks.
        d_ff: Hidden dimension of each feed-forward network.
        max_len: Currently unused; accepted for interface compatibility.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 4,
        n_layers: int = 4,
        d_ff: int = 512,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model

        # Timestep conditioning: sinusoidal features refined by a small MLP.
        self.time_embedding = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

        # Language-specific input projections.
        self.zh_input_proj = nn.Linear(d_model, d_model)
        self.en_input_proj = nn.Linear(d_model, d_model)

        # Shared Transformer trunk.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Language-specific output projections.
        self.zh_output_proj = nn.Linear(d_model, d_model)
        self.en_output_proj = nn.Linear(d_model, d_model)

        self.output_norm = nn.LayerNorm(d_model)

        # Re-initialise every Linear / LayerNorm created above.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear weights ~ N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @staticmethod
    def _expand_padding_mask(
        mask: Optional[torch.Tensor],
        x_t: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Turn a ``[batch, seq_len]`` mask into ``[batch, 1, 1, seq_len]``.

        Any other mask (including ``None``) is returned unchanged.
        """
        if (
            mask is not None
            and mask.dim() == 2
            and tuple(mask.shape) == tuple(x_t.shape[:2])
        ):
            return mask[:, None, None, :]
        return mask

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        lang: str = "zh",
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict the noise contained in ``x_t`` for the given language.

        Args:
            x_t: ``[batch, seq_len, d_model]`` noised embeddings.
            t: ``[batch]`` timesteps.
            lang: ``"zh"`` or ``"en"``; selects the input/output projections.
            mask: Optional mask; positions equal to 0 are not attended to.
                A ``[batch, seq_len]`` mask is expanded to
                ``[batch, 1, 1, seq_len]``; any other shape is forwarded to
                the blocks unchanged.

        Returns:
            Predicted noise with the same shape as ``x_t``.

        Raises:
            ValueError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
        """
        # Fail loudly instead of silently routing an unknown language code
        # through the English projections.
        if lang == "zh":
            input_proj, output_proj = self.zh_input_proj, self.zh_output_proj
        elif lang == "en":
            input_proj, output_proj = self.en_input_proj, self.en_output_proj
        else:
            raise ValueError(f"unsupported lang {lang!r}; expected 'zh' or 'en'")

        # Attention scores are [batch, heads, q_len, k_len]; a raw
        # [batch, seq_len] mask does not broadcast against that shape.
        mask = self._expand_padding_mask(mask, x_t)

        # Timestep conditioning; the embedding is assumed to come back as
        # [batch, d_model] -- it is broadcast over the sequence dimension.
        t_emb = self.time_mlp(self.time_embedding(t))

        x = input_proj(x_t) + t_emb.unsqueeze(1)

        for layer in self.layers:
            x = layer(x, mask)

        return output_proj(self.output_norm(x))
|
|
|
|
def create_model(config) -> DualNoisePredictor:
    """Build a ``DualNoisePredictor`` from the ``config.model`` section.

    Reads ``d_model``, ``n_heads``, ``n_layers``, ``d_ff``, ``max_len`` and
    ``dropout`` from ``config.model`` and passes them as keyword arguments.
    """
    model_cfg = config.model
    field_names = ("d_model", "n_heads", "n_layers", "d_ff", "max_len", "dropout")
    hyperparams = {name: getattr(model_cfg, name) for name in field_names}
    return DualNoisePredictor(**hyperparams)
|
|