Spaces:
Running
Running
| """ | |
| Diffutslator Hugging Face Space 应用 | |
| 基于扩散模型的机器翻译演示 | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import gradio as gr | |
| from typing import Optional, Tuple, List | |
| from dataclasses import dataclass, field | |
| import json | |
| # ==================== 配置(与config.py保持一致,用于加载检查点)==================== | |
| class ModelConfig: | |
| d_model: int = 256 | |
| n_heads: int = 4 | |
| n_layers: int = 4 | |
| d_ff: int = 512 | |
| max_len: int = 128 | |
| dropout: float = 0.1 | |
| vocab_size_zh: int = 8000 | |
| vocab_size_en: int = 8000 | |
| pad_token: str = "<pad>" | |
| sos_token: str = "<sos>" | |
| eos_token: str = "<eos>" | |
| unk_token: str = "<unk>" | |
| mask_token: str = "<mask>" | |
| class DiffusionConfig: | |
| timesteps: int = 1000 | |
| ddim_steps: int = 50 | |
| beta_start: float = 0.0001 | |
| beta_end: float = 0.02 | |
| length_noise_scale: float = 0.3 | |
| interpolation_strength: float = 0.8 # 语言插值强度 | |
| cross_lingual_mode: bool = True # 跨语言扩散模式 | |
| class TrainingConfig: | |
| batch_size: int = 64 | |
| gradient_accumulation: int = 1 | |
| learning_rate: float = 1e-4 | |
| weight_decay: float = 0.01 | |
| warmup_steps: int = 500 | |
| epochs: int = 10 | |
| save_every: int = 1 | |
| eval_every: int = 100 | |
| quick_mode: bool = False | |
| quick_samples: int = 1000 | |
| checkpoint_dir: str = "checkpoints" | |
| resume: Optional[str] = None | |
| class DataConfig: | |
| tatoeba_path: str = "" | |
| cveto_zh_path: str = "" | |
| cveto_en_path: str = "" | |
| max_samples: Optional[int] = None | |
| min_len: int = 2 | |
| max_len: int = 128 | |
| use_cache: bool = True | |
| cache_dir: str = ".cache" | |
| class Config: | |
| model: ModelConfig = field(default_factory=ModelConfig) | |
| diffusion: DiffusionConfig = field(default_factory=DiffusionConfig) | |
| training: TrainingConfig = field(default_factory=TrainingConfig) | |
| data: DataConfig = field(default_factory=DataConfig) | |
| project_dir: str = "" | |
| # 创建一个假的config模块,用于加载检查点时反序列化 | |
| class _FakeConfigModule: | |
| Config = Config | |
| ModelConfig = ModelConfig | |
| DiffusionConfig = DiffusionConfig | |
| TrainingConfig = TrainingConfig | |
| DataConfig = DataConfig | |
| # 将假模块注入sys.modules | |
| sys.modules['config'] = _FakeConfigModule() | |
| # ==================== 分词器 ==================== | |
| import re | |
| class Tokenizer: | |
| """BPE分词器(与tokenizer.py兼容)""" | |
| def __init__(self, vocab_size: int = 8000, lang: str = "zh"): | |
| self.vocab_size = vocab_size | |
| self.lang = lang | |
| # 特殊token | |
| self.pad_token = "<pad>" | |
| self.sos_token = "<sos>" | |
| self.eos_token = "<eos>" | |
| self.unk_token = "<unk>" | |
| self.mask_token = "<mask>" | |
| self.special_tokens = [self.pad_token, self.sos_token, self.eos_token, self.unk_token, self.mask_token] | |
| # 词表 | |
| self.token_to_id: dict = {} | |
| self.id_to_token: dict = {} | |
| # BPE合并规则 | |
| self.merges: list = [] | |
| self.bpe_ranks: dict = {} | |
| def vocab_size_actual(self) -> int: | |
| return len(self.token_to_id) | |
| def pad_id(self) -> int: | |
| return self.token_to_id[self.pad_token] | |
| def sos_id(self) -> int: | |
| return self.token_to_id[self.sos_token] | |
| def eos_id(self) -> int: | |
| return self.token_to_id[self.eos_token] | |
| def unk_id(self) -> int: | |
| return self.token_to_id[self.unk_token] | |
| def _is_chinese(self, char: str) -> bool: | |
| return '\u4e00' <= char <= '\u9fff' | |
| def _pre_tokenize(self, text: str) -> List[str]: | |
| """预分词""" | |
| if self.lang == "zh": | |
| tokens = [] | |
| current = "" | |
| for char in text: | |
| if self._is_chinese(char): | |
| if current: | |
| tokens.append(current) | |
| current = "" | |
| tokens.append(char) | |
| elif char.isalnum(): | |
| current += char.lower() | |
| else: | |
| if current: | |
| tokens.append(current) | |
| current = "" | |
| if char.strip(): | |
| tokens.append(char) | |
| if current: | |
| tokens.append(current) | |
| return tokens | |
| else: | |
| text = text.lower() | |
| tokens = re.findall(r"\w+|[^\w\s]", text) | |
| return tokens | |
| def _get_pairs(self, word: tuple) -> set: | |
| """获取词中的所有相邻字符对""" | |
| pairs = set() | |
| prev = word[0] | |
| for char in word[1:]: | |
| pairs.add((prev, char)) | |
| prev = char | |
| return pairs | |
| def _apply_bpe(self, token: str) -> List[str]: | |
| """对单个token应用BPE""" | |
| if not token: | |
| return [] | |
| word = tuple(token) + ('</w>',) | |
| while True: | |
| pairs = self._get_pairs(word) | |
| if not pairs: | |
| break | |
| # 找到rank最高的pair | |
| min_pair = None | |
| min_rank = float('inf') | |
| for pair in pairs: | |
| rank = self.bpe_ranks.get(pair, float('inf')) | |
| if rank < min_rank: | |
| min_rank = rank | |
| min_pair = pair | |
| if min_pair is None or min_rank == float('inf'): | |
| break | |
| # 合并 | |
| new_word = [] | |
| i = 0 | |
| while i < len(word): | |
| if i < len(word) - 1 and word[i] == min_pair[0] and word[i + 1] == min_pair[1]: | |
| new_word.append(min_pair[0] + min_pair[1]) | |
| i += 2 | |
| else: | |
| new_word.append(word[i]) | |
| i += 1 | |
| word = tuple(new_word) | |
| return [t for t in word if t != '</w>'] | |
| def encode(self, text: str, add_sos: bool = True, add_eos: bool = True) -> List[int]: | |
| """编码文本为token id序列""" | |
| tokens = self._pre_tokenize(text) | |
| ids = [] | |
| if add_sos: | |
| ids.append(self.sos_id) | |
| for token in tokens: | |
| bpe_tokens = self._apply_bpe(token) | |
| for t in bpe_tokens: | |
| ids.append(self.token_to_id.get(t, self.unk_id)) | |
| if add_eos: | |
| ids.append(self.eos_id) | |
| return ids | |
| def decode(self, ids: List[int], skip_special: bool = True) -> str: | |
| """解码token id序列为文本""" | |
| tokens = [] | |
| for id in ids: | |
| token = self.id_to_token.get(id, self.unk_token) | |
| if skip_special and token in self.special_tokens: | |
| continue | |
| token = token.replace('</w>', '') | |
| if token: | |
| tokens.append(token) | |
| if self.lang == "en": | |
| text = ' '.join(tokens) | |
| text = re.sub(r'\s+([.,!?;:\'\"])', r'\1', text) | |
| text = re.sub(r'([.,!?;:])([a-zA-Z])', r'\1 \2', text) | |
| text = re.sub(r'\s+', ' ', text).strip() | |
| else: | |
| text = ''.join(tokens) | |
| return text | |
| def load(cls, path: str) -> "Tokenizer": | |
| """加载分词器""" | |
| with open(path, "r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| tokenizer = cls(vocab_size=data["vocab_size"], lang=data["lang"]) | |
| tokenizer.token_to_id = data["token_to_id"] | |
| tokenizer.id_to_token = {int(k): v for k, v in data["id_to_token"].items()} | |
| tokenizer.merges = [tuple(m) for m in data["merges"]] | |
| tokenizer.bpe_ranks = {pair: i for i, pair in enumerate(tokenizer.merges)} | |
| tokenizer.special_tokens = data["special_tokens"] | |
| return tokenizer | |
| # ==================== 模型组件 ==================== | |
| class PositionalEncoding(nn.Module): | |
| """正弦位置编码""" | |
| def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1): | |
| super().__init__() | |
| self.dropout = nn.Dropout(p=dropout) | |
| pe = torch.zeros(max_len, d_model) | |
| position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) | |
| div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) | |
| pe[:, 0::2] = torch.sin(position * div_term) | |
| pe[:, 1::2] = torch.cos(position * div_term) | |
| pe = pe.unsqueeze(0) | |
| self.register_buffer("pe", pe) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = x + self.pe[:, :x.size(1), :] | |
| return self.dropout(x) | |
| class SinusoidalTimeEmbedding(nn.Module): | |
| """时间步的正弦嵌入(用于扩散)""" | |
| def __init__(self, d_model: int): | |
| super().__init__() | |
| self.d_model = d_model | |
| def forward(self, t: torch.Tensor) -> torch.Tensor: | |
| t = t.float().unsqueeze(-1) | |
| half_dim = self.d_model // 2 | |
| emb = math.log(10000) / (half_dim - 1) | |
| emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) | |
| emb = t * emb.unsqueeze(0) | |
| emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) | |
| return emb | |
| class LanguageEmbedding(nn.Module): | |
| """语言特定的嵌入层""" | |
| def __init__(self, vocab_size: int, d_model: int, max_len: int = 128, dropout: float = 0.1): | |
| super().__init__() | |
| self.d_model = d_model | |
| self.token_embedding = nn.Embedding(vocab_size, d_model) | |
| self.position_encoding = PositionalEncoding(d_model, max_len, dropout) | |
| self.length_embedding = nn.Embedding(max_len + 1, d_model) | |
| self.scale = math.sqrt(d_model) | |
| def forward(self, token_ids: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| x = self.token_embedding(token_ids) * self.scale | |
| x = self.position_encoding(x) | |
| if lengths is not None: | |
| len_emb = self.length_embedding(lengths) | |
| x = x + len_emb.unsqueeze(1) | |
| return x | |
| class DualLanguageEmbedding(nn.Module): | |
| """双语嵌入层""" | |
| def __init__(self, vocab_size_zh: int, vocab_size_en: int, d_model: int, max_len: int = 128, dropout: float = 0.1): | |
| super().__init__() | |
| self.d_model = d_model | |
| self.zh_embedding = LanguageEmbedding(vocab_size_zh, d_model, max_len, dropout) | |
| self.en_embedding = LanguageEmbedding(vocab_size_en, d_model, max_len, dropout) | |
| def forward(self, token_ids: torch.Tensor, lang: str, lengths: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| if lang == 'zh': | |
| return self.zh_embedding(token_ids, lengths) | |
| else: | |
| return self.en_embedding(token_ids, lengths) | |
| class OutputProjection(nn.Module): | |
| """输出投影层""" | |
| def __init__(self, d_model: int, vocab_size: int): | |
| super().__init__() | |
| self.projection = nn.Linear(d_model, vocab_size, bias=False) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.projection(x) | |
| class DualOutputProjection(nn.Module): | |
| """双语输出投影层""" | |
| def __init__(self, d_model: int, vocab_size_zh: int, vocab_size_en: int): | |
| super().__init__() | |
| self.zh_projection = OutputProjection(d_model, vocab_size_zh) | |
| self.en_projection = OutputProjection(d_model, vocab_size_en) | |
| def forward(self, x: torch.Tensor, lang: str) -> torch.Tensor: | |
| if lang == 'zh': | |
| return self.zh_projection(x) | |
| else: | |
| return self.en_projection(x) | |
| class MultiHeadAttention(nn.Module): | |
| """改进的多头自注意力 - 合并 QKV 投影""" | |
| def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1): | |
| super().__init__() | |
| assert d_model % n_heads == 0 | |
| self.d_model = d_model | |
| self.n_heads = n_heads | |
| self.d_k = d_model // n_heads | |
| self.scale = math.sqrt(self.d_k) | |
| # 合并 QKV 投影(更高效) | |
| self.qkv = nn.Linear(d_model, d_model * 3, bias=False) | |
| self.w_o = nn.Linear(d_model, d_model, bias=False) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| batch_size = q.size(0) | |
| seq_len = q.size(1) | |
| # 合并计算 QKV | |
| qkv = self.qkv(q) | |
| q, k, v = qkv.chunk(3, dim=-1) | |
| # 分头 | |
| q = q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2) | |
| k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) | |
| v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) | |
| # 注意力计算 | |
| scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale | |
| if mask is not None: | |
| scores = scores.masked_fill(mask == 0, float('-inf')) | |
| attn = F.softmax(scores, dim=-1) | |
| attn = self.dropout(attn) | |
| # 合并头 | |
| out = torch.matmul(attn, v) | |
| out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) | |
| return self.w_o(out) | |
| class FeedForward(nn.Module): | |
| """前馈网络 - 使用 GLU 结构""" | |
| def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, use_glu: bool = True): | |
| super().__init__() | |
| self.use_glu = use_glu | |
| if use_glu: | |
| # GLU 结构 - 更好的表达能力 | |
| self.w1 = nn.Linear(d_model, d_ff * 2) | |
| self.w2 = nn.Linear(d_ff, d_model) | |
| else: | |
| self.w1 = nn.Linear(d_model, d_ff) | |
| self.w2 = nn.Linear(d_ff, d_model) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| if self.use_glu: | |
| x, gate = self.w1(x).chunk(2, dim=-1) | |
| x = F.gelu(x) * F.gelu(gate) | |
| else: | |
| x = F.gelu(self.w1(x)) | |
| return self.dropout(self.w2(x)) | |
| class TransformerBlock(nn.Module): | |
| """Transformer块 - Pre-LayerNorm 结构""" | |
| def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1): | |
| super().__init__() | |
| # Pre-LayerNorm 结构 | |
| self.norm1 = nn.LayerNorm(d_model) | |
| self.attn = MultiHeadAttention(d_model, n_heads, dropout) | |
| self.norm2 = nn.LayerNorm(d_model) | |
| self.ff = FeedForward(d_model, d_ff, dropout, use_glu=True) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| # 自注意力 + 残差 (Pre-LN) | |
| x = x + self.dropout(self.attn(self.norm1(x), self.norm1(x), self.norm1(x), mask)) | |
| # 前馈 + 残差 (Pre-LN) | |
| x = x + self.dropout(self.ff(self.norm2(x))) | |
| return x | |
| class DualNoisePredictor(nn.Module): | |
| """双语言噪声预测器 - 使用先进架构""" | |
| def __init__(self, d_model: int = 256, n_heads: int = 4, n_layers: int = 4, d_ff: int = 512, max_len: int = 128, dropout: float = 0.1): | |
| super().__init__() | |
| self.d_model = d_model | |
| # 时间步嵌入(共享)- 使用 TimeEmbedding 结构 | |
| self.time_embedding = nn.Module() | |
| self.time_embedding.sinusoidal = SinusoidalTimeEmbedding(d_model) | |
| self.time_embedding.mlp = nn.Sequential( | |
| nn.Linear(d_model, d_model * 4), | |
| nn.SiLU(), | |
| nn.Linear(d_model * 4, d_model * 4), | |
| nn.SiLU(), | |
| nn.Linear(d_model * 4, d_model), | |
| ) | |
| # 语言特定的输入投影(多层) | |
| self.zh_input_proj = nn.Sequential( | |
| nn.Linear(d_model, d_model), | |
| nn.LayerNorm(d_model), | |
| nn.GELU(), | |
| nn.Linear(d_model, d_model), | |
| ) | |
| self.en_input_proj = nn.Sequential( | |
| nn.Linear(d_model, d_model), | |
| nn.LayerNorm(d_model), | |
| nn.GELU(), | |
| nn.Linear(d_model, d_model), | |
| ) | |
| # 共享Transformer层 | |
| self.layers = nn.ModuleList([ | |
| TransformerBlock(d_model, n_heads, d_ff, dropout) | |
| for _ in range(n_layers) | |
| ]) | |
| # 语言特定的输出投影(多层) | |
| self.zh_output_proj = nn.Sequential( | |
| nn.Linear(d_model, d_model), | |
| nn.LayerNorm(d_model), | |
| nn.GELU(), | |
| nn.Linear(d_model, d_model), | |
| ) | |
| self.en_output_proj = nn.Sequential( | |
| nn.Linear(d_model, d_model), | |
| nn.LayerNorm(d_model), | |
| nn.GELU(), | |
| nn.Linear(d_model, d_model), | |
| ) | |
| self.output_norm = nn.LayerNorm(d_model) | |
| def forward(self, x_t: torch.Tensor, t: torch.Tensor, lang: str = "zh", mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| # 时间步嵌入 | |
| t_emb = self.time_embedding.sinusoidal(t) | |
| t_emb = self.time_embedding.mlp(t_emb) | |
| # 语言特定输入投影 | |
| if lang == "zh": | |
| x = self.zh_input_proj(x_t) | |
| else: | |
| x = self.en_input_proj(x_t) | |
| # 添加时间信息 | |
| x = x + t_emb.unsqueeze(1) | |
| # 共享Transformer | |
| for layer in self.layers: | |
| x = layer(x, mask) | |
| # 输出归一化 | |
| x = self.output_norm(x) | |
| # 语言特定输出投影 | |
| if lang == "zh": | |
| noise_pred = self.zh_output_proj(x) | |
| else: | |
| noise_pred = self.en_output_proj(x) | |
| return noise_pred | |
| class LanguageSwitcher(nn.Module): | |
| """语言切换分类器""" | |
| def __init__(self, d_model: int = 256, hidden_dim: int = 128, dropout: float = 0.1): | |
| super().__init__() | |
| self.global_pool = nn.AdaptiveAvgPool1d(1) | |
| self.classifier = nn.Sequential( | |
| nn.Linear(d_model, hidden_dim), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(hidden_dim, hidden_dim), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(hidden_dim, 2), | |
| ) | |
| def forward(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| if mask is not None: | |
| x_t = x_t * mask.unsqueeze(-1) | |
| x = x_t.transpose(1, 2) | |
| x = self.global_pool(x).squeeze(-1) | |
| logits = self.classifier(x) | |
| return logits | |
| def predict(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[str, float]: | |
| self.eval() | |
| with torch.no_grad(): | |
| logits = self.forward(x_t, mask) | |
| probs = F.softmax(logits, dim=-1) | |
| zh_prob = probs[0, 0].item() | |
| en_prob = probs[0, 1].item() | |
| if zh_prob > en_prob: | |
| return "zh", zh_prob | |
| else: | |
| return "en", en_prob | |
| # ==================== 扩散过程 ==================== | |
| class CrossLingualDiffusion: | |
| """跨语言扩散模型:支持源语言和目标语言之间的插值""" | |
| def __init__(self, config: DiffusionConfig): | |
| self.config = config | |
| self.timesteps = config.timesteps | |
| self.interpolation_strength = config.interpolation_strength | |
| # Beta schedule (linear) | |
| betas = torch.linspace(config.beta_start, config.beta_end, self.timesteps) | |
| alphas = 1.0 - betas | |
| alphas_cumprod = torch.cumprod(alphas, dim=0) | |
| self.register_buffer("betas", betas) | |
| self.register_buffer("alphas", alphas) | |
| self.register_buffer("alphas_cumprod", alphas_cumprod) | |
| self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) | |
| self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod)) | |
| def register_buffer(self, name: str, tensor: torch.Tensor): | |
| setattr(self, name, tensor) | |
| def get_interpolation_factor(self, t: torch.Tensor) -> torch.Tensor: | |
| """计算插值因子(smoothstep平滑过渡)""" | |
| normalized_t = t.float() / self.timesteps | |
| # smoothstep: 3t^2 - 2t^3 | |
| factor = normalized_t * normalized_t * (3 - 2 * normalized_t) | |
| return factor * self.interpolation_strength | |
| def _align_sequences(self, x_source: torch.Tensor, x_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: | |
| """对齐两个序列到相同长度""" | |
| source_len = x_source.size(1) | |
| target_len = x_target.size(1) | |
| target_seq_len = max(source_len, target_len) | |
| if source_len < target_seq_len: | |
| # 填充源序列 | |
| pad_len = target_seq_len - source_len | |
| x_source_aligned = F.pad(x_source, (0, 0, 0, pad_len)) | |
| else: | |
| x_source_aligned = x_source | |
| if target_len < target_seq_len: | |
| # 填充目标序列 | |
| pad_len = target_seq_len - target_len | |
| x_target_aligned = F.pad(x_target, (0, 0, 0, pad_len)) | |
| else: | |
| x_target_aligned = x_target | |
| return x_source_aligned, x_target_aligned, target_seq_len | |
| def q_sample( | |
| self, | |
| x_source: torch.Tensor, | |
| x_target: torch.Tensor, | |
| t: torch.Tensor, | |
| noise: Optional[torch.Tensor] = None | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """跨语言前向扩散:源语言和目标语言之间的插值 + 加噪""" | |
| # 对齐序列 | |
| x_source_aligned, x_target_aligned, seq_len = self._align_sequences(x_source, x_target) | |
| if noise is None: | |
| noise = torch.randn_like(x_source_aligned) | |
| # 计算插值因子 | |
| interp_factor = self.get_interpolation_factor(t).view(-1, 1, 1) | |
| # 插值:从源语言逐渐过渡到目标语言 | |
| x_interp = (1 - interp_factor) * x_source_aligned + interp_factor * x_target_aligned | |
| # 加噪 | |
| sqrt_alpha = self.sqrt_alphas_cumprod[t] | |
| sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t] | |
| x_t = sqrt_alpha.view(-1, 1, 1) * x_interp + sqrt_one_minus_alpha.view(-1, 1, 1) * noise | |
| return x_t, noise | |
| def p_sample(self, x_t: torch.Tensor, t: torch.Tensor, predicted_noise: torch.Tensor) -> torch.Tensor: | |
| """反向扩散单步""" | |
| beta = self.betas[t] | |
| sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t] | |
| sqrt_recip_alpha = 1.0 / torch.sqrt(self.alphas[t]) | |
| # 去噪 | |
| x_0_pred = sqrt_recip_alpha.view(-1, 1, 1) * (x_t - sqrt_one_minus_alpha.view(-1, 1, 1) * predicted_noise) | |
| # 添加噪声(除了最后一步) | |
| if t[0] > 0: | |
| noise = torch.randn_like(x_t) | |
| x_prev = x_0_pred + torch.sqrt(beta).view(-1, 1, 1) * noise | |
| else: | |
| x_prev = x_0_pred | |
| return x_prev | |
| class DDIMSampler: | |
| def __init__(self, diffusion: CrossLingualDiffusion, ddim_steps: int = 50): | |
| self.diffusion = diffusion | |
| self.ddim_steps = ddim_steps | |
| # 选择均匀分布的时间步,从高到低(从噪声到干净) | |
| c = self.diffusion.timesteps // ddim_steps | |
| ddim_timesteps = [i * c for i in range(ddim_steps)] | |
| self.ddim_timesteps = torch.tensor(list(reversed(ddim_timesteps))) | |
| def ddim_step(self, x_t: torch.Tensor, t: int, t_prev: int, | |
| predicted_noise: torch.Tensor, eta: float = 0.0) -> torch.Tensor: | |
| """DDIM单步""" | |
| alpha_t = self.diffusion.alphas_cumprod[t] | |
| alpha_prev = self.diffusion.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0) | |
| # 预测 x_0 | |
| x_0_pred = (x_t - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t) | |
| # 方差 | |
| sigma = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev) | |
| # DDIM更新 | |
| dir_xt = torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise | |
| if t_prev >= 0: | |
| noise = torch.randn_like(x_t) | |
| x_prev = torch.sqrt(alpha_prev) * x_0_pred + dir_xt + sigma * noise | |
| else: | |
| x_prev = x_0_pred | |
| return x_prev | |
| # ==================== 翻译器 ==================== | |
| class Translator: | |
| def __init__(self, model_dir: str = "."): | |
| self.device = torch.device("cpu") | |
| # 配置 | |
| self.model_config = ModelConfig() | |
| self.diffusion_config = DiffusionConfig() | |
| # 加载分词器 | |
| self.zh_tokenizer = Tokenizer.load(os.path.join(model_dir, "tokenizer_zh.json")) | |
| self.en_tokenizer = Tokenizer.load(os.path.join(model_dir, "tokenizer_en.json")) | |
| # 初始化模型 | |
| self.embedding = DualLanguageEmbedding( | |
| vocab_size_zh=self.zh_tokenizer.vocab_size_actual, | |
| vocab_size_en=self.en_tokenizer.vocab_size_actual, | |
| d_model=self.model_config.d_model, | |
| max_len=self.model_config.max_len, | |
| dropout=0.0, | |
| ) | |
| self.output_proj = DualOutputProjection( | |
| d_model=self.model_config.d_model, | |
| vocab_size_zh=self.zh_tokenizer.vocab_size_actual, | |
| vocab_size_en=self.en_tokenizer.vocab_size_actual, | |
| ) | |
| self.model = DualNoisePredictor( | |
| d_model=self.model_config.d_model, | |
| n_heads=self.model_config.n_heads, | |
| n_layers=self.model_config.n_layers, | |
| d_ff=self.model_config.d_ff, | |
| max_len=self.model_config.max_len, | |
| dropout=0.0, | |
| ) | |
| self.switcher = LanguageSwitcher( | |
| d_model=self.model_config.d_model, | |
| hidden_dim=self.model_config.d_model // 2, | |
| dropout=0.0, | |
| ) | |
| self.diffusion = CrossLingualDiffusion(self.diffusion_config) | |
| # 加载权重 | |
| self._load_checkpoint(os.path.join(model_dir, "best.pt")) | |
| def _load_checkpoint(self, path: str): | |
| state = torch.load(path, map_location=self.device, weights_only=False) | |
| self.embedding.load_state_dict(state['embedding']) | |
| self.output_proj.load_state_dict(state['output_proj']) | |
| self.model.load_state_dict(state['model']) | |
| self.switcher.load_state_dict(state['switcher']) | |
| print(f"已加载模型: {path}") | |
| def _encode(self, text: str, lang: str) -> torch.Tensor: | |
| if lang == "zh": | |
| ids = self.zh_tokenizer.encode(text, add_sos=True, add_eos=True) | |
| else: | |
| ids = self.en_tokenizer.encode(text, add_sos=True, add_eos=True) | |
| return torch.tensor(ids, dtype=torch.long).unsqueeze(0) | |
| def _decode(self, ids: torch.Tensor, lang: str) -> str: | |
| ids = ids[0].tolist() | |
| if lang == "zh": | |
| return self.zh_tokenizer.decode(ids, skip_special=True) | |
| else: | |
| return self.en_tokenizer.decode(ids, skip_special=True) | |
| def _embed_to_tokens(self, x: torch.Tensor, lang: str) -> torch.Tensor: | |
| logits = self.output_proj(x, lang) | |
| return logits.argmax(dim=-1) | |
| def translate( | |
| self, | |
| text: str, | |
| source_lang: str, | |
| ddim_steps: int = 50, | |
| show_process: bool = False, | |
| ) -> Tuple[str, List[str]]: | |
| """翻译文本,返回结果和中间过程(跨语言扩散)""" | |
| self.model.eval() | |
| self.embedding.eval() | |
| self.output_proj.eval() | |
| self.switcher.eval() | |
| target_lang = "en" if source_lang == "zh" else "zh" | |
| # 更新DDIM步数 | |
| self.diffusion_config.ddim_steps = ddim_steps | |
| ddim_sampler = DDIMSampler(self.diffusion, ddim_steps) | |
| # 编码源语言 | |
| source_ids = self._encode(text, source_lang) | |
| source_len = torch.tensor([source_ids.size(1)]) | |
| # 嵌入源语言 | |
| source_emb = self.embedding(source_ids, source_lang, source_len) | |
| # 初始状态:使用跨语言扩散的前向过程 | |
| batch_size = source_emb.size(0) | |
| t_start = torch.full((batch_size,), self.diffusion_config.timesteps - 1, dtype=torch.long) | |
| noise_start = torch.randn_like(source_emb) | |
| # 模拟目标语言嵌入(随机噪声 + 源语言信息) | |
| target_emb_fake = torch.randn_like(source_emb) * 0.3 + source_emb * 0.7 | |
| x_t, _ = self.diffusion.q_sample(source_emb, target_emb_fake, t_start, noise_start) | |
| # DDIM反向扩散 | |
| timesteps = ddim_sampler.ddim_timesteps | |
| total_steps = len(timesteps) | |
| process_steps = [] | |
| for i, t in enumerate(timesteps[:-1]): | |
| t_prev = timesteps[i + 1] | |
| # 计算进度,决定当前语言 | |
| progress = i / total_steps | |
| if progress < 0.3: | |
| current_lang = source_lang | |
| else: | |
| current_lang = target_lang | |
| # 预测噪声 | |
| t_tensor = torch.full((x_t.size(0),), t.item(), dtype=torch.long) | |
| predicted_noise = self.model(x_t, t_tensor, lang=current_lang) | |
| # 记录过程 | |
| if show_process and i % max(1, total_steps // 10) == 0: | |
| current_ids = self._embed_to_tokens(x_t, current_lang) | |
| current_text = self._decode(current_ids, current_lang) | |
| process_steps.append(f"Step {t.item()} [{current_lang}]: {current_text[:50]}") | |
| # DDIM步骤 | |
| x_t = ddim_sampler.ddim_step(x_t, t.item(), t_prev.item(), predicted_noise, eta=0.0) | |
| # 最终解码 | |
| final_ids = self._embed_to_tokens(x_t, target_lang) | |
| result = self._decode(final_ids, target_lang) | |
| return result, process_steps | |
| # ==================== Gradio 应用 ==================== | |
| def create_app(): | |
| # 加载模型 | |
| print("正在加载模型...") | |
| # 使用脚本所在目录作为模型目录 | |
| script_dir = os.path.dirname(os.path.abspath(__file__)) | |
| translator = Translator(model_dir=script_dir) | |
| print("模型加载完成!") | |
| def translate_text(text: str, language: str, ddim_steps: int, show_process: bool): | |
| if not text.strip(): | |
| return "", [] | |
| # 自动检测或手动选择 | |
| if language == "自动检测": | |
| if any('\u4e00' <= c <= '\u9fff' for c in text): | |
| source_lang = "zh" | |
| else: | |
| source_lang = "en" | |
| else: | |
| source_lang = "zh" if language == "中文 → 英文" else "en" | |
| try: | |
| result, process = translator.translate( | |
| text, source_lang, ddim_steps, show_process | |
| ) | |
| process_text = "\n".join(process) if process else "(过程未显示)" | |
| return result, process_text | |
| except Exception as e: | |
| return f"翻译出错: {str(e)}", "" | |
| # 创建界面 | |
| with gr.Blocks( | |
| title="Diffutslator", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .output-box { min-height: 100px; } | |
| .process-box { font-family: monospace; font-size: 12px; } | |
| """ | |
| ) as app: | |
| gr.Markdown( | |
| """ | |
| # Diffutslator 扩散翻译器 | |
| 基于扩散模型的机器翻译系统,可视化翻译过程中的语言渐变。 | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| input_text = gr.Textbox( | |
| label="输入文本", | |
| placeholder="输入要翻译的中文或英文...", | |
| lines=5, | |
| ) | |
| with gr.Row(): | |
| language = gr.Dropdown( | |
| choices=["自动检测", "中文 → 英文", "英文 → 中文"], | |
| value="自动检测", | |
| label="翻译方向", | |
| ) | |
| ddim_steps = gr.Slider( | |
| minimum=10, | |
| maximum=100, | |
| value=50, | |
| step=5, | |
| label="DDIM步数", | |
| info="步数越多质量越高,速度越慢", | |
| ) | |
| show_process = gr.Checkbox( | |
| value=False, | |
| label="显示扩散过程", | |
| info="显示翻译中间步骤(会增加推理时间)", | |
| ) | |
| translate_btn = gr.Button("翻译", variant="primary", size="lg") | |
| with gr.Column(scale=2): | |
| output_text = gr.Textbox( | |
| label="翻译结果", | |
| lines=5, | |
| interactive=False, | |
| elem_classes=["output-box"], | |
| ) | |
| process_text = gr.Textbox( | |
| label="扩散过程", | |
| lines=5, | |
| interactive=False, | |
| visible=False, | |
| elem_classes=["process-box"], | |
| ) | |
| # 示例 | |
| gr.Examples( | |
| examples=[ | |
| ["你好,世界!", "自动检测"], | |
| ["Hello, how are you today?", "自动检测"], | |
| ["机器学习正在改变世界。", "中文 → 英文"], | |
| ["The quick brown fox jumps over the lazy dog.", "英文 → 中文"], | |
| ], | |
| inputs=[input_text, language], | |
| ) | |
| # 事件处理 | |
| def toggle_process(show): | |
| return gr.Textbox(visible=show) | |
| show_process.change( | |
| fn=toggle_process, | |
| inputs=[show_process], | |
| outputs=[process_text], | |
| ) | |
| translate_btn.click( | |
| fn=translate_text, | |
| inputs=[input_text, language, ddim_steps, show_process], | |
| outputs=[output_text, process_text], | |
| ) | |
| # 回车提交 | |
| input_text.submit( | |
| fn=translate_text, | |
| inputs=[input_text, language, ddim_steps, show_process], | |
| outputs=[output_text, process_text], | |
| ) | |
| return app | |
| if __name__ == "__main__": | |
| app = create_app() | |
| app.launch() | |