# vAIbe_diffutslator / embedding.py
# forthezero's picture
# Upload 28 files
# 2651102 verified
"""
嵌入层
语言特定的嵌入,包含位置编码和长度编码
"""
import math
import torch
import torch.nn as nn
from typing import Optional
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A table of shape ``[1, max_len, d_model]`` is computed once in the
    constructor and registered as the buffer ``pe``. Even feature columns hold
    sines and odd feature columns hold cosines of ``position * frequency``.

    Args:
        d_model: Feature dimension of the inputs. Both even and odd values
            are supported.
        max_len: Number of positions stored in the precomputed table.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the encoding table.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the surplus frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_model]``. The table only
                holds ``max_len`` positions, so ``seq_len`` is expected not to
                exceed it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of timesteps (used for diffusion).

    Args:
        d_model: Size of the returned embedding vector.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps as concatenated sine and cosine features.

        The timesteps are used as-is (cast to float); no normalization is
        applied.

        Args:
            t: Tensor of shape ``[batch]`` holding the timesteps.

        Returns:
            Tensor of shape ``[batch, d_model]``. The first ``d_model // 2``
            columns are sines, the next ``d_model // 2`` are cosines, and for
            an odd ``d_model`` a final zero column pads the width to
            ``d_model``.
        """
        t = t.float().unsqueeze(-1)  # [batch, 1]
        half_dim = self.d_model // 2

        # max(..., 1) avoids a division by zero when half_dim == 1
        # (d_model of 2 or 3); for half_dim >= 2 the value is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)

        args = t * freqs.unsqueeze(0)  # [batch, half_dim]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        if self.d_model % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; pad so the output width
            # matches d_model.
            pad = emb.new_zeros(*emb.shape[:-1], 1)
            emb = torch.cat([emb, pad], dim=-1)
        return emb
class LanguageEmbedding(nn.Module):
    """Embedding layer for a single language.

    Combines a scaled token embedding, a positional encoding, and an optional
    per-sequence length embedding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding dimension.
        max_len: Passed to the positional encoding; the length table has
            ``max_len + 1`` rows.
        dropout: Dropout probability passed to the positional encoding.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        # Sub-layers, in order: tokens, positions, sequence length.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.length_embedding = nn.Embedding(max_len + 1, d_model)

        # Token embeddings are multiplied by sqrt(d_model) in forward().
        self.scale = math.sqrt(d_model)

        # Re-initialise both lookup tables from N(0, 0.02^2).
        for table in (self.token_embedding, self.length_embedding):
            nn.init.normal_(table.weight, mean=0.0, std=0.02)

    def forward(
        self,
        token_ids: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            token_ids: ``[batch, seq_len]`` token indices.
            lengths: Optional ``[batch]`` tensor of actual sequence lengths.

        Returns:
            ``[batch, seq_len, d_model]`` embeddings.
        """
        scaled_tokens = self.token_embedding(token_ids) * self.scale
        embedded = self.position_encoding(scaled_tokens)

        if lengths is None:
            return embedded

        # One vector per sequence, broadcast across every position.
        length_vectors = self.length_embedding(lengths)  # [batch, d_model]
        return embedded + length_vectors[:, None]

    def embed_noise(self, shape: tuple, device: torch.device) -> torch.Tensor:
        """Draw a standard-normal noise tensor.

        Args:
            shape: Output shape, ``(batch, seq_len, d_model)``.
            device: Device on which the tensor is created.

        Returns:
            Tensor of the requested shape sampled with ``torch.randn``.
        """
        noise = torch.randn(size=shape, device=device)
        return noise
class DualLanguageEmbedding(nn.Module):
    """Bilingual embedding layer holding one embedding per language.

    Args:
        vocab_size_zh: Vocabulary size for the Chinese embedding.
        vocab_size_en: Vocabulary size for the English embedding.
        d_model: Embedding dimension shared by both languages.
        max_len: Maximum length passed to both embeddings.
        dropout: Dropout probability passed to both embeddings.
    """

    def __init__(
        self,
        vocab_size_zh: int,
        vocab_size_en: int,
        d_model: int,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        # Both languages share every hyper-parameter except vocabulary size.
        shared_args = (d_model, max_len, dropout)
        self.zh_embedding = LanguageEmbedding(vocab_size_zh, *shared_args)
        self.en_embedding = LanguageEmbedding(vocab_size_en, *shared_args)

    def forward(
        self,
        token_ids: torch.Tensor,
        lang: str,
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``token_ids`` with the embedding selected by ``lang``.

        Args:
            token_ids: Token indices to embed.
            lang: ``'zh'`` selects the Chinese embedding; any other value
                selects the English embedding.
            lengths: Optional sequence lengths forwarded to the embedding.

        Returns:
            The output of the selected embedding.
        """
        embedder = self.zh_embedding if lang == 'zh' else self.en_embedding
        return embedder(token_ids, lengths)

    def embed_tokens(
        self,
        zh_ids: Optional[torch.Tensor] = None,
        en_ids: Optional[torch.Tensor] = None,
        zh_lens: Optional[torch.Tensor] = None,
        en_lens: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Embed Chinese and English inputs in one call.

        Args:
            zh_ids: Chinese token ids, or ``None`` to skip Chinese.
            en_ids: English token ids, or ``None`` to skip English.
            zh_lens: Optional lengths for the Chinese batch.
            en_lens: Optional lengths for the English batch.

        Returns:
            ``(zh_emb, en_emb)``; an entry is ``None`` when its ids were not
            given.
        """
        zh_emb = self.zh_embedding(zh_ids, zh_lens) if zh_ids is not None else None
        en_emb = self.en_embedding(en_ids, en_lens) if en_ids is not None else None
        return zh_emb, en_emb
class OutputProjection(nn.Module):
    """Output head that maps hidden states back to vocabulary logits.

    Args:
        d_model: Size of the incoming hidden states.
        vocab_size: Number of logits produced per position.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        # Bias-free linear map from hidden size to vocabulary size.
        self.projection = nn.Linear(
            in_features=d_model,
            out_features=vocab_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project hidden states to logits.

        Args:
            x: ``[batch, seq_len, d_model]`` hidden states.

        Returns:
            ``[batch, seq_len, vocab_size]`` logits.
        """
        logits = self.projection(x)
        return logits
class DualOutputProjection(nn.Module):
    """Bilingual output head with one projection per language.

    Args:
        d_model: Size of the incoming hidden states.
        vocab_size_zh: Vocabulary size of the Chinese projection.
        vocab_size_en: Vocabulary size of the English projection.
    """

    def __init__(self, d_model: int, vocab_size_zh: int, vocab_size_en: int):
        super().__init__()
        self.zh_projection = OutputProjection(d_model, vocab_size_zh)
        self.en_projection = OutputProjection(d_model, vocab_size_en)

    def forward(self, x: torch.Tensor, lang: str) -> torch.Tensor:
        """Project ``x`` with the head selected by ``lang``.

        Args:
            x: Hidden states to project.
            lang: ``'zh'`` selects the Chinese head; any other value selects
                the English head.

        Returns:
            The output of the selected projection.
        """
        head = self.zh_projection if lang == 'zh' else self.en_projection
        return head(x)