File size: 5,747 Bytes
2651102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
语言切换器
判断当前噪声状态更接近哪种语言
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class LanguageSwitcher(nn.Module):
    """Binary language classifier over a (noisy) sequence state.

    Input:  noisy state ``x_t`` of shape [batch, seq_len, d_model]
    Output: logits of shape [batch, 2], ordered as (Chinese, English).

    The sequence is mean-pooled over ``seq_len`` (ignoring masked-out
    positions when a mask is given) and passed through a small MLP head.
    """

    def __init__(self, d_model: int = 256, hidden_dim: int = 128, dropout: float = 0.1):
        super().__init__()

        # Global mean pooling over the sequence axis (used when no mask is given).
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Classification head: d_model -> hidden_dim -> hidden_dim -> 2 (zh / en).
        self.classifier = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),  # 2 classes: Chinese / English
        )

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise every Linear layer with N(0, 0.02) weights and zero bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _pool(self, x_t: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Mean-pool ``x_t`` over the sequence axis -> [batch, d_model].

        With a mask, only positions where the mask is non-zero contribute,
        and the sum is divided by the mask total rather than by ``seq_len``.
        """
        if mask is None:
            # [batch, seq_len, d_model] -> [batch, d_model, seq_len] -> [batch, d_model]
            return self.global_pool(x_t.transpose(1, 2)).squeeze(-1)

        weights = mask.to(dtype=x_t.dtype).unsqueeze(-1)  # [batch, seq_len, 1]
        summed = (x_t * weights).sum(dim=1)  # [batch, d_model]
        # Dividing by seq_len would let padding dilute the mean; divide by the
        # number of valid positions instead. The clamp keeps fully-masked rows
        # at zero instead of producing NaN.
        count = weights.sum(dim=1).clamp(min=1e-6)  # [batch, 1]
        return summed / count

    def forward(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute language logits.

        Args:
            x_t: [batch, seq_len, d_model] noisy state.
            mask: optional [batch, seq_len] mask (bool or float); zero marks
                positions to ignore.

        Returns:
            [batch, 2] logits ordered as (Chinese, English).
        """
        pooled = self._pool(x_t, mask)
        return self.classifier(pooled)

    def predict(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[str, float]:
        """Predict the language of the FIRST sample in the batch.

        Runs in eval mode without gradients, then restores the module's
        previous training/eval mode.

        Returns:
            lang: "zh" or "en" ("en" on an exact tie).
            confidence: softmax probability of the predicted class, in [0, 1].
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = F.softmax(self.forward(x_t, mask), dim=-1)
        finally:
            # Do not leave dropout disabled for a caller that is still training.
            self.train(was_training)

        # Only sample 0 is inspected (the caller is assumed to pass batch=1).
        zh_prob = probs[0, 0].item()
        en_prob = probs[0, 1].item()
        if zh_prob > en_prob:
            return "zh", zh_prob
        return "en", en_prob

    def get_probabilities(self, x_t: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-sample class probabilities (gradients are kept).

        Returns:
            zh_probs: [batch] Chinese probability.
            en_probs: [batch] English probability.
        """
        probs = F.softmax(self.forward(x_t, mask), dim=-1)
        return probs[:, 0], probs[:, 1]


class AdaptiveSwitcher(nn.Module):
    """Language switcher whose logits are gated by the diffusion timestep.

    Wraps a ``LanguageSwitcher``. When a timestep ``t`` is supplied, its
    (Chinese, English) logits are multiplied element-wise by a learned
    per-class gate in (0, 1) computed from the normalised timestep.

    Design intent: switch more aggressively early (high noise) and more
    conservatively late (low noise). NOTE(review): the code does not enforce
    this policy; it can only emerge from training the time gate, and
    ``switch_threshold`` itself is constant across timesteps.
    """

    def __init__(
        self,
        d_model: int = 256,
        hidden_dim: int = 128,
        dropout: float = 0.1,
        switch_threshold: float = 0.6,  # minimum confidence required to switch
    ):
        super().__init__()

        # should_switch() only reports a switch when the winning-class
        # probability strictly exceeds this value.
        self.switch_threshold = switch_threshold

        # Base pooled-MLP classifier producing (zh, en) logits.
        self.base_switcher = LanguageSwitcher(d_model, hidden_dim, dropout)

        # Maps the normalised timestep [..., 1] to a per-class gate in (0, 1) [..., 2].
        self.time_modulation = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 2),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x_t: torch.Tensor,
        t: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute (optionally time-gated) language logits.

        Args:
            x_t: [batch, seq_len, d_model] noisy state.
            t: optional [batch] timestep tensor used for gating.
            mask: optional [batch, seq_len] mask forwarded to the base switcher.

        Returns:
            [batch, 2] logits ordered as (Chinese, English).
        """
        logits = self.base_switcher(x_t, mask)

        if t is not None:
            # NOTE(review): the 1000 divisor presumably matches a 1000-step
            # diffusion schedule; confirm against the scheduler in use.
            t_norm = t.float().unsqueeze(-1) / 1000.0  # [batch, 1]
            gate = self.time_modulation(t_norm)  # [batch, 2]
            logits = logits * gate

        return logits

    def should_switch(
        self,
        x_t: torch.Tensor,
        current_lang: str,
        t: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[bool, str, float]:
        """Decide whether to switch away from ``current_lang``.

        Only the FIRST sample in the batch is inspected. Inference runs in
        eval mode without gradients, and the module's previous
        training/eval mode is restored afterwards.

        Returns:
            should_switch: True iff the predicted language differs from
                ``current_lang`` and its probability exceeds ``switch_threshold``.
            new_lang: predicted language, "zh" or "en" ("en" on an exact tie).
            confidence: probability of the predicted language.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = F.softmax(self.forward(x_t, t, mask), dim=-1)
        finally:
            # Do not leave dropout disabled for a caller that is still training.
            self.train(was_training)

        zh_prob = probs[0, 0].item()
        en_prob = probs[0, 1].item()

        predicted_lang = "zh" if zh_prob > en_prob else "en"
        confidence = max(zh_prob, en_prob)

        switch = predicted_lang != current_lang and confidence > self.switch_threshold
        return switch, predicted_lang, confidence


def create_switcher(config) -> LanguageSwitcher:
    """Build a ``LanguageSwitcher`` from a config object.

    Reads ``config.model.d_model`` and ``config.model.dropout``. The
    classifier's hidden width is half of ``d_model`` (floor division).
    """
    model_cfg = config.model
    width = model_cfg.d_model
    return LanguageSwitcher(
        d_model=width,
        hidden_dim=width // 2,
        dropout=model_cfg.dropout,
    )