"""End-to-end pipeline: train a text autoencoder + flow-matching DiT on the
Wiki simplification task, then run batched inference and report metrics
(SARI / BLEU / compression ratio)."""

import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import evaluate
import sacrebleu
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data


def _pick_stop_id(tokenizer):
    """Return the id of the token that terminates generation.

    BERT/Jina-family tokenizers usually have ``eos_token_id=None``; in that
    case the ``[SEP]`` token id is used as the stop marker instead.
    """
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return tokenizer.sep_token_id


def _first_pos(x_1d, token_id, default):
    """Return the index of the first occurrence of ``token_id`` in the 1-D
    tensor ``x_1d`` ([L]), or ``default`` if it does not occur."""
    hits = (x_1d == token_id).nonzero(as_tuple=True)[0]
    return hits[0].item() if hits.numel() > 0 else default


def calculate_metrics(sources, predictions, references):
    """Compute SARI, BLEU and the average character-level compression ratio.

    Args:
        sources:     list[str] -- original (complex) sentences.
        predictions: list[str] -- model outputs.
        references:  list[str] -- one gold simplification per source.

    Returns:
        dict mapping "SARI" / "BLEU" / "Compression Ratio" to floats.
    """
    # 1. BLEU -- sacrebleu expects references as List[List[str]] (one inner
    #    list per reference *set*), so wrap the single-reference list once.
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # 2. SARI -- sacrebleu does NOT provide corpus_sari (the original call
    #    always raised AttributeError). Use the HF `evaluate` SARI metric
    #    instead; it may need network access to fetch the metric script, so
    #    on any failure fall back to 0.0 rather than aborting evaluation.
    try:
        sari_metric = evaluate.load("sari")
        sari_score = sari_metric.compute(
            sources=sources,
            predictions=predictions,
            references=[[r] for r in references],
        )["sari"]
    except Exception as e:
        print(f"SARI calculation failed: {e}")
        sari_score = 0.0

    # 3. Compression ratio (prediction length / source length, in chars).
    ratios = [len(p) / len(s) if len(s) > 0 else 0
              for p, s in zip(predictions, sources)]
    # Guard against an empty prediction list (avoids ZeroDivisionError).
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0

    return {
        "SARI": sari_score,
        "BLEU": bleu.score,
        "Compression Ratio": avg_ratio,
    }


@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10,
                    save_path="results.txt", use_oneshot=True):
    """Generate simplifications for every batch in ``loader`` and save them.

    For each batch: encode the source with the AE, sample with the flow
    model (either a single t=0 x-prediction or ``steps`` Euler steps), then
    decode twice -- first with a full attention mask to locate the stop
    token, then with a mask truncated at the stop token to reduce tail
    interference.

    Returns:
        (sources, targets, generated) as lists of newline-stripped strings;
        also written as TSV to ``save_path``.
    """
    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources = []
    all_targets = []
    all_generated = []
    # Radius of the sphere the AE keeps latents on (AE convention;
    # presumably set by the model -- default 10.0 if absent).
    scale = getattr(ae, "latent_scale", 10.0)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            # ---- 1) flow sampling: one-shot or multi-step Euler ----
            if use_oneshot:
                # x-prediction is most stable: single forward pass at t=0.
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                pred_z1 = z_curr  # safe default if the loop exits early
                for i in range(steps):
                    t_val = i / steps
                    # Avoid division by ~0 near t=1 (rare but possible).
                    if t_val >= 0.999:
                        break
                    t = torch.ones(z_curr.shape[0], device=device) * t_val
                    # Model predicts the endpoint z1; convert to a velocity
                    # for the Euler update. Epsilon avoids division by zero.
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)
                    z_curr = z_curr + v * dt
                    z_curr = F.normalize(z_curr, p=2, dim=-1) * scale
                # Use the final endpoint prediction directly.
                z_curr = pred_z1

            # ---- 2) re-project onto the AE latent sphere (scale align) ----
            z_curr = F.normalize(z_curr, p=2, dim=-1) * scale

            # ---- 3) two-pass decode to determine length by the stop id ----
            full_mask = torch.ones(B, L, device=device)  # full length allowed

            # Pass 1: decode with the full mask to find stop positions.
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)  # [B, L]

            # First stop token per row; L-1 if none was predicted.
            stop_pos = torch.tensor(
                [_first_pos(ids1[i], stop_id, default=L - 1)
                 for i in range(B)],
                device=device,
            )
            gen_mask = torch.zeros(B, L, device=device)
            for i in range(B):
                gen_mask[i, : stop_pos[i].item() + 1] = 1.0

            # Pass 2: decode with the truncated mask to reduce tail
            # interference, then pad everything after the stop position.
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)

            # ---- 4) ids -> text, truncated at the stop token ----
            src_texts = tokenizer.batch_decode(src_ids,
                                               skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids,
                                               skip_special_tokens=True)
            gen_texts = []
            for i in range(B):
                end = stop_pos[i].item() + 1
                gen_texts.append(
                    tokenizer.decode(ids2[i, :end],
                                     skip_special_tokens=True))

            # ---- 5) save & collect (strip newlines to keep TSV valid) ----
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                s_clean = s.replace("\n", " ")
                t_clean = t.replace("\n", " ")
                g_clean = g.replace("\n", " ")
                f.write(f"{s_clean}\t{t_clean}\t{g_clean}\n")
                all_sources.append(s_clean)
                all_targets.append(t_clean)
                all_generated.append(g_clean)

    return all_sources, all_targets, all_generated


def main():
    """Train the AE, then the Flow DiT, checkpointing best/last weights,
    then evaluate the best checkpoints on the Wiki test split."""
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {ckpt_dir}")

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128,  # Wiki sentences are short; 128 is enough and fast
    )
    t_cfg = TrainConfig(
        batch_size=16,       # can be larger at inference time
        num_epochs_ae=20,    # a bit more AE training
        num_epochs_flow=35,  # more Flow training
        grad_accum_steps=4,
        use_amp=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name,
                                              trust_remote_code=True)

    # --- 1. Load data (train & test splits) ---
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len,
                                t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len,
                               t_cfg.batch_size, split="test")

    # --- Init models ---
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # Trainer needs pad_id and stop_id (for loss masking / stop handling).
    trainer = Trainer(ae, flow, t_cfg, train_loader,
                      pad_id=tokenizer.pad_token_id,
                      stop_id=_pick_stop_id(tokenizer))

    # --- 2. Train autoencoder ---
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()),
                         lr=t_cfg.lr_ae)
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")
        # Keep the best checkpoint by training loss.
        if loss < best_ae_loss:
            best_ae_loss = loss
            torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_best.pt"))
        # Overwrite "last" each epoch (for resuming / inspection).
        torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_last.pt"))
    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")

    # --- 3. Train flow DiT ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(),
                       os.path.join(ckpt_dir, "flow_best.pt"))
        torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_last.pt"))
    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # --- 4. Evaluation: load best weights if they exist ---
    ae_path = os.path.join(ckpt_dir, "ae_best.pt")
    flow_path = os.path.join(ckpt_dir, "flow_best.pt")
    if os.path.exists(ae_path):
        ae.load_state_dict(torch.load(ae_path, map_location=t_cfg.device))
        print("Loaded AE Best.")
    else:
        print("Warning: AE Best ckpt not found, using last state.")
    if os.path.exists(flow_path):
        flow.load_state_dict(torch.load(flow_path, map_location=t_cfg.device))
        print("Loaded Flow Best.")
    else:
        print("Warning: Flow Best ckpt not found, using last state.")

    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10, save_path="wiki_results.tsv",
    )

    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    print(f"\nResults saved to wiki_results.tsv")


if __name__ == "__main__":
    main()