| """ |
| 语言切换器 |
| 判断当前噪声状态更接近哪种语言 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Tuple, Optional |
|
|
|
|
class LanguageSwitcher(nn.Module):
    """Language-switch classifier.

    Mean-pools a noisy state over the sequence axis and scores two
    languages with a small MLP.

    Input:  noisy state ``x_t`` of shape [batch, seq_len, d_model]
    Output: logits of shape [batch, 2], ordered (Chinese, English)
    """

    def __init__(self, d_model: int = 256, hidden_dim: int = 128, dropout: float = 0.1):
        super().__init__()

        # Global average pooling over the sequence axis; used on the
        # unmasked path in forward().
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Three-layer MLP head: d_model -> hidden_dim -> hidden_dim -> 2.
        self.classifier = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

        # Re-initialise every Linear layer created above.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear weights from N(0, 0.02^2) and zero the biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute language logits for a batch of noisy states.

        Args:
            x_t: [batch, seq_len, d_model]
            mask: optional [batch, seq_len]; non-zero entries mark the
                positions that contribute to the pooled representation.

        Returns:
            [batch, 2] logits ordered (Chinese, English).
        """
        if mask is None:
            # AdaptiveAvgPool1d pools the last axis, so move seq_len there.
            pooled = self.global_pool(x_t.transpose(1, 2)).squeeze(-1)
        else:
            # Masked mean: divide by the number of kept positions rather
            # than by seq_len, so the result does not shrink with the
            # amount of padding in the batch.
            weights = mask.unsqueeze(-1).to(x_t.dtype)
            summed = (x_t * weights).sum(dim=1)
            # Guard against an all-zero mask (pooled vector becomes zeros).
            counts = weights.sum(dim=1).clamp(min=1e-6)
            pooled = summed / counts

        return self.classifier(pooled)

    def predict(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[str, float]:
        """Predict the language of the first sample in the batch.

        Runs in eval mode without gradients, then restores whichever
        train/eval mode the module was in before the call.

        Returns:
            lang: "zh" or "en"
            confidence: softmax probability of the returned language.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x_t, mask)
                probs = F.softmax(logits, dim=-1)
        finally:
            # Do not leak eval mode into a surrounding training loop.
            self.train(was_training)

        # Only batch element 0 is inspected.
        zh_prob = probs[0, 0].item()
        en_prob = probs[0, 1].item()

        if zh_prob > en_prob:
            return "zh", zh_prob
        return "en", en_prob

    def get_probabilities(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-sample Chinese and English probabilities.

        Unlike predict(), this does not change the module mode and does
        not disable gradients.

        Returns:
            zh_probs: [batch] Chinese probability
            en_probs: [batch] English probability
        """
        logits = self.forward(x_t, mask)
        probs = F.softmax(logits, dim=-1)
        return probs[:, 0], probs[:, 1]
|
|
|
|
class AdaptiveSwitcher(nn.Module):
    """Adaptive language switcher.

    Wraps a LanguageSwitcher and rescales its logits with a learned,
    timestep-dependent gate. Stated design intent (not enforced by the
    code; the gate is learned):
    - early steps (high noise): more aggressive switching
    - late steps (low noise): more conservative switching
    """

    def __init__(
        self,
        d_model: int = 256,
        hidden_dim: int = 128,
        dropout: float = 0.1,
        switch_threshold: float = 0.6,
        num_timesteps: float = 1000.0,
    ):
        """
        Args:
            d_model: feature width of the noisy state.
            hidden_dim: hidden width of the classifier and the time gate.
            dropout: dropout probability of the base classifier.
            switch_threshold: minimum confidence required by should_switch().
            num_timesteps: divisor used to normalise ``t`` in forward().

        Raises:
            ValueError: if ``num_timesteps`` is not positive.
        """
        super().__init__()

        if num_timesteps <= 0:
            raise ValueError("num_timesteps must be positive")

        self.switch_threshold = switch_threshold
        self.num_timesteps = float(num_timesteps)

        # Base classifier that produces the unmodulated logits.
        self.base_switcher = LanguageSwitcher(d_model, hidden_dim, dropout)

        # Maps a normalised timestep to one gate in (0, 1) per language.
        self.time_modulation = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 2),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x_t: torch.Tensor,
        t: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute (optionally time-modulated) language logits.

        Args:
            x_t: [batch, seq_len, d_model]
            t: [batch] timesteps used for modulation; when None the base
                logits are returned unchanged.
            mask: optional [batch, seq_len] mask forwarded to the base switcher.

        Returns:
            [batch, 2] logits.
        """
        logits = self.base_switcher(x_t, mask)

        if t is not None:
            # Normalise the timestep, then gate each logit multiplicatively.
            t_norm = t.float().unsqueeze(-1) / self.num_timesteps
            modulation = self.time_modulation(t_norm)
            logits = logits * modulation

        return logits

    def should_switch(
        self,
        x_t: torch.Tensor,
        current_lang: str,
        t: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[bool, str, float]:
        """Decide whether to switch language for the first sample in the batch.

        Runs in eval mode without gradients, then restores whichever
        train/eval mode the module was in before the call.

        Returns:
            should_switch: True when the predicted language differs from
                ``current_lang`` and its confidence exceeds the threshold.
            new_lang: predicted language, "zh" or "en".
            confidence: softmax probability of the predicted language.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x_t, t, mask)
                probs = F.softmax(logits, dim=-1)
        finally:
            # Do not leak eval mode into a surrounding training loop.
            self.train(was_training)

        # Only batch element 0 is inspected.
        zh_prob = probs[0, 0].item()
        en_prob = probs[0, 1].item()

        predicted_lang = "zh" if zh_prob > en_prob else "en"
        confidence = max(zh_prob, en_prob)

        should_switch = (
            predicted_lang != current_lang
            and confidence > self.switch_threshold
        )

        return should_switch, predicted_lang, confidence
|
|
|
|
def create_switcher(config) -> LanguageSwitcher:
    """Build a LanguageSwitcher from a configuration object.

    Reads ``config.model.d_model`` and ``config.model.dropout``; the hidden
    width is half of ``d_model`` (floor division).
    """
    model_cfg = config.model
    width = model_cfg.d_model
    half_width = width // 2
    return LanguageSwitcher(
        d_model=width,
        hidden_dim=half_width,
        dropout=model_cfg.dropout,
    )
|
|