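"""Training loops for an autoencoder + rectified-flow latent editing pipeline.

`Trainer` wraps a text autoencoder (`ae`) and an optional rectified-flow model
(`flow`) that learns to map "bad" source latents to "good" target latents in
the autoencoder's latent space. The AE losses treat real tokens, PAD positions,
and the stop/SEP token as separate weighted terms.
"""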
import torch
import torch.nn.functional as F
from tqdm import tqdm

class Trainer:
    def __init__(self, ae, flow, cfg, loader, pad_id, stop_id):
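        """
        ae: text autoencoder; flow: rectified-flow model over AE latents (may be
        None when only the AE is trained); pad_id / stop_id: PAD and stop (SEP)
        token ids used by the weighted loss terms below.
        """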
        self.ae = ae.to(cfg.device)
        self.flow = flow.to(cfg.device) if flow else None
        self.cfg = cfg
        self.loader = loader
        self.device = cfg.device
        self.pad_id = pad_id
        self.stop_id = stop_id

    def train_ae(self, optimizer):
        self.ae.train()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train AE")
        optimizer.zero_grad()
        
        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)
            
            # Forward pass; the decoder mask defaults to the encoder mask.
            logits, z = self.ae(tgt, mask)

            V = logits.size(-1)
            B, L = tgt.shape

            # 1) Token loss: computed only where mask == 1.
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(
                logits.view(-1, V),
                labels_tok.view(-1),
                ignore_index=-100,
                reduction="mean"
            )

            # 2) Pad loss: positions with mask == 0 are pushed to predict PAD (small weight).
            pad_pos = (mask == 0)
            if pad_pos.any():
                # Per-position cross-entropy against the PAD id.
                ce_all = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.new_full((B * L,), self.pad_id),
                    reduction="none"
                ).view(B, L)
                loss_pad = (ce_all * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = logits.new_tensor(0.0)

            # 3) Optional: up-weight stop positions so the SEP token is predicted more reliably.
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_tok = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.view(-1),
                    reduction="none"
                ).view(B, L)
                loss_stop = (ce_tok * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = logits.new_tensor(0.0)

            # Combine the terms; keep the pad/stop weights small.
            lambda_pad = 0.1
            lambda_stop = 0.2
            loss = loss_tok + lambda_pad * loss_pad + lambda_stop * loss_stop

            loss = loss / self.cfg.grad_accum_steps
            loss.backward()
            
            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                
            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
            
        return total_loss / len(self.loader)

    def train_robust_ae(self, optimizer):
        """Denoising training: decode from noised latents so the decoder tolerates latent error."""
        self.ae.train()
        total_loss = 0
        noise_std = 0.05

        for batch in self.loader:
            tgt_ids = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)
            
            # 1. Encode the clean latent (no gradients: the encoder is frozen here,
            #    so only the decoder is trained)
            with torch.no_grad():
                z_clean = self.ae.encode(tgt_ids, tgt_mask)
            
            # 2. Add Gaussian noise (denoising training): the decoder must
            #    reconstruct from perturbed latents, like those a flow model produces
            noise = torch.randn_like(z_clean) * noise_std
            z_noisy = z_clean + noise
            
            # 3. Decode
            logits = self.ae.decode(z_noisy, attention_mask=tgt_mask)

            # 4. Loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), 
                                   tgt_ids.view(-1), 
                                   reduction='none')
            loss = (loss * tgt_mask.view(-1)).sum() / tgt_mask.sum()

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
        return total_loss / len(self.loader)


    def train_ae_combined(self, optimizer, epoch, max_epochs):
        """
        结合了 基础重建 + 课程去噪 + Pad/Stop 优化
        """
        self.ae.train()
        total_loss = 0
        
        # --- Curriculum noise schedule ---
        # No noise for the first 20% of epochs, so clean reconstruction is learned first;
        # afterwards the noise level ramps linearly up to a maximum of 0.1.
        if epoch < max_epochs * 0.2:
            current_noise = 0.0
        else:
            progress = (epoch - max_epochs * 0.2) / (max_epochs * 0.8)
            current_noise = 0.1 * progress
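        # e.g. with max_epochs = 50: epochs 0-9 use noise 0.0; at epoch 30,
        # progress = (30 - 10) / 40 = 0.5, so current_noise = 0.05.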

        pbar = tqdm(self.loader, desc=f"Train AE (Noise={current_noise:.4f})")
        
        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)
            
            # 1. Encode the clean latent (no_grad: encoder frozen, decoder-only training)
            with torch.no_grad():
                z_clean = self.ae.encode(tgt, mask)
            
            # 2. Add noise (if the scheduled noise level is positive)
            if current_noise > 0:
                noise = torch.randn_like(z_clean) * current_noise
                z_input = z_clean + noise
            else:
                z_input = z_clean
            
            # 3. Decode
            logits = self.ae.decode(z_input, attention_mask=mask)
            
            # 4. Structured loss (same terms as in train_ae)
            V = logits.size(-1)
            B, L = tgt.shape

            # Token loss (only where mask == 1)
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(logits.view(-1, V), labels_tok.view(-1), ignore_index=-100)

            # Pad Loss (mask==0)
            pad_pos = (mask == 0)
            if pad_pos.any():
                ce_pad = F.cross_entropy(logits.view(-1, V), tgt.new_full((B*L,), self.pad_id), reduction='none').view(B,L)
                loss_pad = (ce_pad * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = torch.tensor(0.0, device=self.device)
                
            # Stop Loss
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_stop = F.cross_entropy(logits.view(-1, V), tgt.view(-1), reduction='none').view(B,L)
                loss_stop = (ce_stop * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = torch.tensor(0.0, device=self.device)

            # Combine the terms; the stop weight is slightly higher here than in train_ae
            loss = loss_tok + 0.1 * loss_pad + 0.5 * loss_stop
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())
            
        return total_loss / len(self.loader)

    def train_flow(self, optimizer):
        self.flow.train()
        self.ae.eval()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train Flow")
        optimizer.zero_grad()

        scale = getattr(self.ae, "latent_scale", 10.0)
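        # `latent_scale` is the norm the AE's latents are scaled to; it is only
        # used by the sphere-projection lines that are currently commented out.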
        
        for step, batch in enumerate(pbar):
            src = batch['src_ids'].to(self.device)
            src_mask = batch['src_mask'].to(self.device)
            tgt = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)
            
            with torch.no_grad():
                z_bad = self.ae.encode(src, src_mask) # norm ~ scale
                z_good = self.ae.encode(tgt, tgt_mask) # norm ~ scale
                
            # Rectified flow: sample t ~ U(0, 1) per example.
            bs = z_bad.shape[0]
            t = torch.rand(bs, device=self.device).view(bs, 1, 1)

            # Linear interpolation from bad to good. The interpolant could also be
            # projected back onto the sphere of radius `scale` (disabled for now;
            # whether to normalize before or after is still being tested).
            z_t_linear = (1 - t) * z_bad + t * z_good
            # z_t = F.normalize(z_t_linear, p=2, dim=-1) * scale
            z_t = z_t_linear
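            # At t = 0 the interpolant equals z_bad and at t = 1 it equals z_good,
            # so the network sees every point on the straight path between them.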

            # x1-prediction: the flow regresses the target latent z_good directly
            # instead of the velocity v = z_good - z_bad (the velocity variant is
            # kept below for reference):
            # target_v = z_good - z_bad
            # pred_v = self.flow(z_t, t.squeeze(), condition=z_bad)
            # loss = F.mse_loss(pred_v, target_v)
            pred_z1 = self.flow(z_t, t, condition=z_bad)
            # If z_t is projected onto the sphere above, project the prediction back
            # as well to stay on-manifold; normalize both or neither.
            # pred_z1 = F.normalize(pred_z1, p=2, dim=-1) * scale
            # MSE to z_good, averaged over valid tokens only (per the target mask).
            mse = (pred_z1 - z_good).pow(2).mean(dim=-1)  # [B, L]
            w = tgt_mask.float()

            # Up-weight stop positions so the SEP latent is matched more precisely.
            stop_pos = ((tgt == self.stop_id) & (tgt_mask == 1))
            w = w + stop_pos.float() * 2.0

            loss = (mse * w).sum() / (w.sum() + 1e-6)

            loss = loss / self.cfg.grad_accum_steps
            loss.backward()
            
            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                
            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
            
        return total_loss / len(self.loader)
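

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): `TrainConfig`, `build_models`, and `build_loader`
# are placeholders, not part of this repo. It only illustrates the intended
# two-stage schedule: train the AE first, then freeze it and train the flow in
# its latent space.
#
#   cfg = TrainConfig(device="cuda", grad_accum_steps=4)
#   ae, flow = build_models(cfg)
#   loader, pad_id, stop_id = build_loader(cfg)
#   trainer = Trainer(ae, flow, cfg, loader, pad_id, stop_id)
#
#   ae_opt = torch.optim.AdamW(ae.parameters(), lr=1e-4)
#   for epoch in range(ae_epochs):
#       trainer.train_ae_combined(ae_opt, epoch, ae_epochs)
#
#   flow_opt = torch.optim.AdamW(flow.parameters(), lr=1e-4)
#   for _ in range(flow_epochs):
#       trainer.train_flow(flow_opt)
# ---------------------------------------------------------------------------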