| """ |
| 嵌入层 |
| 语言特定的嵌入,包含位置编码和长度编码 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| from typing import Optional |
|
|
|
|
| class PositionalEncoding(nn.Module): |
| """正弦位置编码""" |
| |
| def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1): |
| super().__init__() |
| self.dropout = nn.Dropout(p=dropout) |
| |
| |
| pe = torch.zeros(max_len, d_model) |
| position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
| div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) |
| |
| pe[:, 0::2] = torch.sin(position * div_term) |
| pe[:, 1::2] = torch.cos(position * div_term) |
| |
| pe = pe.unsqueeze(0) |
| self.register_buffer('pe', pe) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| x: [batch, seq_len, d_model] |
| """ |
| x = x + self.pe[:, :x.size(1), :] |
| return self.dropout(x) |
|
|
|
|
| class SinusoidalTimeEmbedding(nn.Module): |
| """时间步的正弦嵌入(用于扩散)""" |
| |
| def __init__(self, d_model: int): |
| super().__init__() |
| self.d_model = d_model |
| |
| def forward(self, t: torch.Tensor) -> torch.Tensor: |
| """ |
| t: [batch] 时间步,范围 [0, T] |
| 返回: [batch, d_model] |
| """ |
| |
| t = t.float().unsqueeze(-1) |
| |
| half_dim = self.d_model // 2 |
| emb = math.log(10000) / (half_dim - 1) |
| emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) |
| emb = t * emb.unsqueeze(0) |
| emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) |
| |
| return emb |
|
|
|
|
| class LanguageEmbedding(nn.Module): |
| """语言特定的嵌入层""" |
| |
| def __init__( |
| self, |
| vocab_size: int, |
| d_model: int, |
| max_len: int = 128, |
| dropout: float = 0.1, |
| ): |
| super().__init__() |
| |
| self.d_model = d_model |
| |
| |
| self.token_embedding = nn.Embedding(vocab_size, d_model) |
| |
| |
| self.position_encoding = PositionalEncoding(d_model, max_len, dropout) |
| |
| |
| self.length_embedding = nn.Embedding(max_len + 1, d_model) |
| |
| |
| self.scale = math.sqrt(d_model) |
| |
| |
| nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02) |
| nn.init.normal_(self.length_embedding.weight, mean=0.0, std=0.02) |
| |
| def forward( |
| self, |
| token_ids: torch.Tensor, |
| lengths: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """ |
| token_ids: [batch, seq_len] |
| lengths: [batch] 可选,序列实际长度 |
| 返回: [batch, seq_len, d_model] |
| """ |
| |
| x = self.token_embedding(token_ids) * self.scale |
| |
| |
| x = self.position_encoding(x) |
| |
| |
| if lengths is not None: |
| |
| len_emb = self.length_embedding(lengths) |
| x = x + len_emb.unsqueeze(1) |
| |
| return x |
| |
| def embed_noise(self, shape: tuple, device: torch.device) -> torch.Tensor: |
| """生成纯噪声嵌入 |
| |
| shape: (batch, seq_len, d_model) |
| """ |
| return torch.randn(shape, device=device) |
|
|
|
|
| class DualLanguageEmbedding(nn.Module): |
| """双语嵌入层,管理中英文嵌入""" |
| |
| def __init__( |
| self, |
| vocab_size_zh: int, |
| vocab_size_en: int, |
| d_model: int, |
| max_len: int = 128, |
| dropout: float = 0.1, |
| ): |
| super().__init__() |
| |
| self.d_model = d_model |
| |
| self.zh_embedding = LanguageEmbedding(vocab_size_zh, d_model, max_len, dropout) |
| self.en_embedding = LanguageEmbedding(vocab_size_en, d_model, max_len, dropout) |
| |
| def forward( |
| self, |
| token_ids: torch.Tensor, |
| lang: str, |
| lengths: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """ |
| lang: 'zh' 或 'en' |
| """ |
| if lang == 'zh': |
| return self.zh_embedding(token_ids, lengths) |
| else: |
| return self.en_embedding(token_ids, lengths) |
| |
| def embed_tokens( |
| self, |
| zh_ids: Optional[torch.Tensor] = None, |
| en_ids: Optional[torch.Tensor] = None, |
| zh_lens: Optional[torch.Tensor] = None, |
| en_lens: Optional[torch.Tensor] = None, |
| ) -> tuple: |
| """同时嵌入中英文""" |
| zh_emb = None |
| en_emb = None |
| |
| if zh_ids is not None: |
| zh_emb = self.zh_embedding(zh_ids, zh_lens) |
| if en_ids is not None: |
| en_emb = self.en_embedding(en_ids, en_lens) |
| |
| return zh_emb, en_emb |
|
|
|
|
| class OutputProjection(nn.Module): |
| """输出投影层,将隐藏状态投影回词表空间""" |
| |
| def __init__(self, d_model: int, vocab_size: int): |
| super().__init__() |
| self.projection = nn.Linear(d_model, vocab_size, bias=False) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| x: [batch, seq_len, d_model] |
| 返回: [batch, seq_len, vocab_size] logits |
| """ |
| return self.projection(x) |
|
|
|
|
| class DualOutputProjection(nn.Module): |
| """双语输出投影层""" |
| |
| def __init__(self, d_model: int, vocab_size_zh: int, vocab_size_en: int): |
| super().__init__() |
| |
| self.zh_projection = OutputProjection(d_model, vocab_size_zh) |
| self.en_projection = OutputProjection(d_model, vocab_size_en) |
| |
| def forward(self, x: torch.Tensor, lang: str) -> torch.Tensor: |
| if lang == 'zh': |
| return self.zh_projection(x) |
| else: |
| return self.en_projection(x) |
|
|