import torch
import torch.nn.functional as F
from tqdm import tqdm


class Trainer:
    def __init__(self, ae, flow, cfg, loader, pad_id, stop_id):
        self.ae = ae.to(cfg.device)
        self.flow = flow.to(cfg.device) if flow is not None else None
        self.cfg = cfg
        self.loader = loader
        self.device = cfg.device
        self.pad_id = pad_id
        self.stop_id = stop_id

    def train_ae(self, optimizer):
        """One epoch of autoencoder training with gradient accumulation."""
        self.ae.train()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train AE")
        optimizer.zero_grad()
        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            logits, z = self.ae(tgt, mask)  # decoder_mask defaults to mask
            V = logits.size(-1)
            B, L = tgt.shape

            # 1) Token loss: only positions where mask == 1.
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(
                logits.view(-1, V),
                labels_tok.view(-1),
                ignore_index=-100,
                reduction="mean",
            )

            # 2) Pad loss: force mask == 0 positions to predict PAD (small weight).
            pad_pos = (mask == 0)
            if pad_pos.any():
                # Per-position cross entropy against a constant PAD target.
                ce_all = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.new_full((B * L,), self.pad_id),
                    reduction="none",
                ).view(B, L)
                loss_pad = (ce_all * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = logits.new_tensor(0.0)

            # 3) Optional: upweight stop positions so SEP is predicted reliably.
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_tok = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.view(-1),
                    reduction="none",
                ).view(B, L)
                loss_stop = (ce_tok * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = logits.new_tensor(0.0)

            # Combine: keep the pad/stop weights small.
            lambda_pad = 0.1
            lambda_stop = 0.2
            loss = loss_tok + lambda_pad * loss_pad + lambda_stop * loss_stop

            loss = loss / self.cfg.grad_accum_steps
            loss.backward()
            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
        return total_loss / len(self.loader)

    def train_robust_ae(self, optimizer):
        """Denoising pass: encode, perturb the latent, and reconstruct."""
        self.ae.train()
        total_loss = 0
        noise_std = 0.05
        for batch in self.loader:
            tgt_ids = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)

            # 1. Get the clean latent (encoder frozen for this pass).
            with torch.no_grad():
                z_clean = self.ae.encode(tgt_ids, tgt_mask)

            # 2. Add noise (denoising training): the decoder must still
            #    reconstruct from a perturbed z, like the ones the flow produces.
            noise = torch.randn_like(z_clean) * noise_std
            z_noisy = z_clean + noise

            # 3. Decode.
            logits = self.ae.decode(z_noisy, attention_mask=tgt_mask)

            # 4. Masked token-level cross entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   tgt_ids.view(-1), reduction='none')
            loss = (loss * tgt_mask.view(-1)).sum() / tgt_mask.sum()

            # Backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(self.loader)

    def train_ae_combined(self, optimizer, epoch, max_epochs):
        """
        Combines basic reconstruction, curriculum denoising, and the
        pad/stop loss terms.
        """
        self.ae.train()
        total_loss = 0

        # --- Curriculum noise schedule ---
        # No noise for the first 20% of epochs (learn reconstruction first),
        # then ramp linearly up to a maximum std of 0.1.
        if epoch < max_epochs * 0.2:
            current_noise = 0.0
        else:
            progress = (epoch - max_epochs * 0.2) / (max_epochs * 0.8)
            current_noise = 0.1 * progress

        pbar = tqdm(self.loader, desc=f"Train AE (Noise={current_noise:.4f})")
        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            # 1. Encode the clean input (encoder frozen for this pass).
            with torch.no_grad():
                z_clean = self.ae.encode(tgt, mask)

            # 2. Add noise if the schedule says so.
            if current_noise > 0:
                noise = torch.randn_like(z_clean) * current_noise
                z_input = z_clean + noise
            else:
                z_input = z_clean

            # 3. Decode.
            logits = self.ae.decode(z_input, attention_mask=mask)

            # 4. Same weighted loss as train_ae.
            V = logits.size(-1)
            B, L = tgt.shape

            # Token loss (only mask == 1 positions).
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(logits.view(-1, V), labels_tok.view(-1),
                                       ignore_index=-100)

            # Pad loss (mask == 0 positions).
            pad_pos = (mask == 0)
            if pad_pos.any():
                ce_pad = F.cross_entropy(logits.view(-1, V),
                                         tgt.new_full((B * L,), self.pad_id),
                                         reduction='none').view(B, L)
                loss_pad = (ce_pad * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = torch.tensor(0.0, device=self.device)

            # Stop loss.
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_stop = F.cross_entropy(logits.view(-1, V), tgt.view(-1),
                                          reduction='none').view(B, L)
                loss_stop = (ce_stop * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = torch.tensor(0.0, device=self.device)

            # Combine, with a slightly higher stop weight than in train_ae.
            loss = loss_tok + 0.1 * loss_pad + 0.5 * loss_stop

            # Backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())
        return total_loss / len(self.loader)

    def train_flow(self, optimizer):
        """One epoch of rectified-flow training on frozen AE latents."""
        self.flow.train()
        self.ae.eval()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train Flow")
        optimizer.zero_grad()
        scale = getattr(self.ae, "latent_scale", 10.0)  # used by the normalize variants below
        for step, batch in enumerate(pbar):
            src = batch['src_ids'].to(self.device)
            src_mask = batch['src_mask'].to(self.device)
            tgt = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)

            with torch.no_grad():
                z_bad = self.ae.encode(src, src_mask)    # norm ~ scale
                z_good = self.ae.encode(tgt, tgt_mask)   # norm ~ scale

            # Rectified flow: sample t and interpolate bad -> good.
            bs = z_bad.shape[0]
            t = torch.rand(bs, device=self.device).view(bs, 1, 1)
            z_t = (1 - t) * z_bad + t * z_good
            # Variant to test: push the interpolant back onto the radius-`scale`
            # sphere (unclear whether normalizing before or after helps).
            # z_t = F.normalize(z_t, p=2, dim=-1) * scale

            # Endpoint prediction: the flow predicts z_good directly instead of
            # the velocity v = z_good - z_bad. The velocity formulation would be:
            #   pred_v = self.flow(z_t, t.squeeze(), condition=z_bad)
            #   loss = F.mse_loss(pred_v, z_good - z_bad)
            pred_z1 = self.flow(z_t, t, condition=z_bad)
            # Variant to test: project the prediction back onto the same sphere
            # to avoid off-manifold outputs.
            # pred_z1 = F.normalize(pred_z1, p=2, dim=-1) * scale

            # Loss: distance to z_good, averaged over valid (masked) tokens only.
            mse = (pred_z1 - z_good).pow(2).mean(dim=-1)  # [B, L]
            w = tgt_mask.float()
            # Upweight stop positions (+2) so SEP is predicted more reliably.
            stop_pos = ((tgt == self.stop_id) & (tgt_mask == 1))
            w = w + stop_pos.float() * 2.0
            loss = (mse * w).sum() / (w.sum() + 1e-6)

            loss = loss / self.cfg.grad_accum_steps
            loss.backward()
            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
        return total_loss / len(self.loader)
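

# --- Minimal smoke-test sketch (not part of the original file) ---
# This exercises Trainer end to end with tiny hypothetical stub modules, to
# document the interface the real models must expose, as inferred from the
# methods above: ae(tgt, mask) -> (logits, z), ae.encode(ids, mask) -> z,
# ae.decode(z, attention_mask=...) -> logits, flow(z_t, t, condition=...) -> z.
# StubAE, StubFlow, and the SimpleNamespace cfg are illustrative stand-ins,
# not the repo's actual classes or config.
import torch.nn as nn
from types import SimpleNamespace


class StubAE(nn.Module):
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def encode(self, ids, mask):
        return self.emb(ids) * mask.unsqueeze(-1)           # [B, L, D]

    def decode(self, z, attention_mask=None):
        return self.out(z)                                   # [B, L, V]

    def forward(self, ids, mask):
        z = self.encode(ids, mask)
        return self.decode(z, attention_mask=mask), z


class StubFlow(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(2 * dim + 1, dim)

    def forward(self, z_t, t, condition):
        # Broadcast the per-sample time t ([B, 1, 1]) over the sequence length.
        t_feat = t.expand(-1, z_t.size(1), -1)
        return self.net(torch.cat([z_t, condition, t_feat], dim=-1))


if __name__ == "__main__":
    B, L, V = 4, 12, 32
    # A single random batch stands in for a DataLoader (a list works, since
    # the loops above only iterate and take len()).
    batch = {
        'src_ids': torch.randint(0, V, (B, L)),
        'src_mask': torch.ones(B, L, dtype=torch.long),
        'tgt_ids': torch.randint(0, V, (B, L)),
        'tgt_mask': torch.ones(B, L, dtype=torch.long),
    }
    cfg = SimpleNamespace(device="cpu", grad_accum_steps=1)
    ae, flow = StubAE(V), StubFlow()
    trainer = Trainer(ae, flow, cfg, loader=[batch], pad_id=0, stop_id=2)
    print("AE loss:", trainer.train_ae(torch.optim.Adam(ae.parameters(), lr=1e-3)))
    print("Flow loss:", trainer.train_flow(torch.optim.Adam(flow.parameters(), lr=1e-3)))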