vAIbe_diffutslator / model.py
forthezero's picture
Upload 28 files
2651102 verified
"""
扩散模型
轻量级Transformer用于噪声预测
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from embedding import SinusoidalTimeEmbedding
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Expands the feature dimension from ``d_model`` to ``d_ff``, applies GELU,
    projects back to ``d_model`` and applies dropout to the projected output.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay w1 / w2.
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> GELU -> project -> dropout to ``x``."""
        hidden = F.gelu(self.w1(x))
        projected = self.w2(hidden)
        return self.dropout(projected)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_heads``
    heads of size ``d_model // n_heads``, attended per head, merged back and
    passed through an output projection.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention of ``q`` over ``k`` / ``v``.

        Args:
            q, k, v: tensors whose first dimension is the batch and whose last
                dimension is ``d_model``.
            mask: optional tensor broadcastable to the score tensor
                ``[batch, n_heads, q_len, k_len]``; positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``[batch, q_len, d_model]``.
        """
        batch_size = q.size(0)

        # Project, then split the feature dimension into heads:
        # [batch, len, d_model] -> [batch, n_heads, len, d_k]
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: [batch, n_heads, q_len, k_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Fill with the most negative finite value rather than -inf: a row
            # whose keys are all masked would otherwise be softmax(-inf, ...)
            # = NaN and poison the whole output. Partially masked rows are
            # unaffected because exp() of this value still underflows to 0.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back into d_model.
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_o(out)
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention and feed-forward sublayers,
    each applied to a LayerNorm of the input and added back as a residual.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one block over ``x``.

        Args:
            x: input tensor; the last dimension is ``d_model``.
            mask: optional attention mask, forwarded unchanged to the
                attention sublayer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention + residual. Normalize once and reuse the result for
        # query, key and value instead of running LayerNorm three times.
        normed = self.norm1(x)
        x = x + self.dropout(self.attn(normed, normed, normed, mask))
        # Feed-forward + residual.
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x
class NoisePredictor(nn.Module):
    """Noise-prediction network.

    Input: the noised embedding ``x_t`` and the diffusion timestep ``t``.
    Output: the predicted noise, with the same shape as ``x_t``.

    The timestep is embedded, passed through a small MLP and added to every
    sequence position before a stack of Transformer blocks is applied.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 4,
        n_layers: int = 4,
        d_ff: int = 512,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        # NOTE: ``max_len`` is accepted for interface compatibility but is not
        # used by this class.
        super().__init__()
        self.d_model = d_model

        # Timestep embedding followed by an MLP (d_model -> 4*d_model -> d_model).
        self.time_embedding = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

        # Transformer layers.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Output head.
        self.output_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        # Re-initialize every Linear / LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear weights as N(0, 0.02) with zero bias, and
        LayerNorm as weight=1, bias=0."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict the noise contained in ``x_t``.

        Args:
            x_t: [batch, seq_len, d_model] noised embedding.
            t: [batch] timesteps.
            mask: optional [batch, seq_len] attention mask; positions equal
                to 0 are not attended to. Masks of any other rank are passed
                to the layers unchanged.

        Returns:
            [batch, seq_len, d_model] predicted noise.
        """
        # Timestep embedding; expected to be [batch, d_model] so that it
        # broadcasts over the sequence dimension below.
        t_emb = self.time_embedding(t)
        t_emb = self.time_mlp(t_emb)

        # Add the time information to every position.
        x = x_t + t_emb.unsqueeze(1)

        # The layers apply the mask to [batch, heads, q_len, k_len] attention
        # scores. A raw [batch, seq_len] mask would be broadcast as
        # (q_len, k_len), so expand it to [batch, 1, 1, seq_len] to mask keys
        # per batch element.
        if mask is not None and mask.dim() == 2:
            mask = mask[:, None, None, :]

        for layer in self.layers:
            x = layer(x, mask)

        x = self.output_norm(x)
        noise_pred = self.output_proj(x)
        return noise_pred
class DualNoisePredictor(nn.Module):
    """Two-language noise predictor.

    The timestep embedding and the Transformer stack are shared; the input
    and output projections are language-specific (``zh`` / ``en``).
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 4,
        n_layers: int = 4,
        d_ff: int = 512,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        # NOTE: ``max_len`` is accepted for interface compatibility but is not
        # used by this class.
        super().__init__()
        self.d_model = d_model

        # Timestep embedding (shared) followed by an MLP.
        self.time_embedding = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

        # Language-specific input projections.
        self.zh_input_proj = nn.Linear(d_model, d_model)
        self.en_input_proj = nn.Linear(d_model, d_model)

        # Shared Transformer layers.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Language-specific output projections and the shared output norm.
        self.zh_output_proj = nn.Linear(d_model, d_model)
        self.en_output_proj = nn.Linear(d_model, d_model)
        self.output_norm = nn.LayerNorm(d_model)

        # Re-initialize every Linear / LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear weights as N(0, 0.02) with zero bias, and
        LayerNorm as weight=1, bias=0."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        lang: str = "zh",
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict the noise in ``x_t`` using the projections for ``lang``.

        Args:
            x_t: [batch, seq_len, d_model] noised embedding.
            t: [batch] timesteps.
            lang: ``"zh"`` selects the Chinese projections; any other value
                selects the English ones.
            mask: optional attention mask. A 2-D mask is treated as
                [batch, seq_len], where positions equal to 0 are not attended
                to. Masks of any other rank are passed to the layers unchanged.

        Returns:
            Tensor with the same shape as ``x_t``.
        """
        use_zh = lang == "zh"

        # Timestep embedding; expected to be [batch, d_model] so that it
        # broadcasts over the sequence dimension below.
        t_emb = self.time_embedding(t)
        t_emb = self.time_mlp(t_emb)

        # Language-specific input projection.
        if use_zh:
            x = self.zh_input_proj(x_t)
        else:
            x = self.en_input_proj(x_t)

        # Add the time information to every position.
        x = x + t_emb.unsqueeze(1)

        # The layers apply the mask to [batch, heads, q_len, k_len] attention
        # scores. A raw [batch, seq_len] mask would be broadcast as
        # (q_len, k_len), so expand it to [batch, 1, 1, seq_len] to mask keys
        # per batch element.
        if mask is not None and mask.dim() == 2:
            mask = mask[:, None, None, :]

        # Shared Transformer stack.
        for layer in self.layers:
            x = layer(x, mask)

        # Output normalization.
        x = self.output_norm(x)

        # Language-specific output projection.
        if use_zh:
            noise_pred = self.zh_output_proj(x)
        else:
            noise_pred = self.en_output_proj(x)
        return noise_pred
def create_model(config) -> DualNoisePredictor:
    """Build a DualNoisePredictor from the hyperparameters in ``config.model``.

    Reads ``d_model``, ``n_heads``, ``n_layers``, ``d_ff``, ``max_len`` and
    ``dropout`` from ``config.model`` and passes them as keyword arguments.
    """
    hparams = config.model
    field_names = ("d_model", "n_heads", "n_layers", "d_ff", "max_len", "dropout")
    kwargs = {name: getattr(hparams, name) for name in field_names}
    return DualNoisePredictor(**kwargs)