# Diff-Refine / run_wiki_flow.py
# Uploaded by 2ira using the upload-large-folder tool (commit 77d636f, verified)
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import evaluate
import sacrebleu
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
### Helper to determine the EOS/stop token id
def _pick_stop_id(tokenizer):
# BERT/Jina 系通常 eos_token_id=None,用 sep_token_id 作为终止符
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
def _first_pos(x_1d, token_id, default):
# x_1d: [L]
idx = (x_1d == token_id).nonzero(as_tuple=True)[0]
return idx[0].item() if idx.numel() > 0 else default
def calculate_metrics(sources, predictions, references):
    """Compute SARI, BLEU and compression ratio for simplification outputs.

    Args:
        sources: List[str] of input sentences.
        predictions: List[str] of model outputs, aligned with sources.
        references: List[str], a single reference per source.

    Returns:
        dict with float entries "SARI", "BLEU" and "Compression Ratio".
    """
    # Guard: an empty corpus would crash sacrebleu and divide by zero below.
    if not predictions:
        return {"SARI": 0.0, "BLEU": 0.0, "Compression Ratio": 0.0}
    # 1. BLEU — sacrebleu expects references as List[List[str]] (one inner list
    # per reference *set*); we have a single reference per sentence, so wrap once.
    bleu = sacrebleu.corpus_bleu(predictions, [references])
    # 2. SARI
    try:
        # corpus_sari returns a score object whose .score attribute is a float.
        sari = sacrebleu.corpus_sari(sources, predictions, [references])
        sari_score = sari.score
    except Exception as e:
        # NOTE(review): stock sacrebleu does not ship corpus_sari — this path
        # silently reports 0.0 unless a SARI-enabled fork is installed; confirm
        # the intended dependency.
        print(f"SARI calculation failed: {e}")
        sari_score = 0.0
    # 3. Compression ratio (character-level), guarded against empty sources.
    ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0
    return {
        "SARI": sari_score,      # already a plain float
        "BLEU": bleu.score,
        "Compression Ratio": avg_ratio
    }
@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt",use_oneshot=True):
    """Generate outputs for every batch in ``loader`` and save them as TSV.

    Each source batch is encoded with the autoencoder, pushed through the flow
    model either one-shot (endpoint prediction at t=0) or with ``steps`` Euler
    steps, then decoded twice: pass 1 locates the stop token, pass 2 re-decodes
    with a length mask for a cleaner result.

    Returns:
        (all_sources, all_targets, all_generated): parallel lists of cleaned
        (newline-stripped) source, reference and generated strings.
    """
    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id
    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")
    all_sources = []
    all_targets = []
    all_generated = []
    # Target norm used whenever latents are re-projected onto the sphere.
    scale = getattr(ae, "latent_scale", 10.0)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")
        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()
            # Two sampling modes: one-shot endpoint prediction vs. multi-step Euler.
            if use_oneshot:
                # x-prediction is most stable: directly predict the endpoint at t=0.
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    # Avoid division by zero near t=1 (rare, but guard anyway).
                    if t_val >= 0.999: break
                    t = torch.ones(z_curr.shape[0], device=device) * t_val
                    # The model predicts the endpoint z1; convert it to a
                    # velocity v = (z1 - z_t) / (1 - t) and take one Euler step.
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    v = (pred_z1 - z_curr) / (1.0 - t_val + + 1e-4) # epsilon keeps the divisor positive
                    z_curr = z_curr + v * dt
                    z_curr = F.normalize(z_curr, p=2, dim=-1) * scale
                z_curr = pred_z1  # use the final endpoint prediction directly
            z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1) * scale  # re-align latent scale
            # ---- 3) two-pass decode to determine length via the stop token ----
            full_mask = torch.ones(B, L, device=device)  # allow growth: every position may generate
            # Pass-1: decode with the full mask.
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)  # [B, L]
            # Find stop positions and build gen_mask.
            stop_pos = []
            for i in range(B):
                # If no stop token is predicted, treat L-1 as the maximum length.
                pos = _first_pos(ids1[i], stop_id, default=L - 1)
                stop_pos.append(pos)
            stop_pos = torch.tensor(stop_pos, device=device)
            gen_mask = torch.zeros(B, L, device=device)
            for i in range(B):
                gen_mask[i, : stop_pos[i].item() + 1] = 1.0
            # Pass-2: decode again with gen_mask, reducing tail interference.
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)
            # Enforce pad after the stop position for clean decoding.
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)
            # ---- 4) decode ids back to text, truncated at the stop position ----
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = []
            for i in range(B):
                end = stop_pos[i].item() + 1
                ids_cut = ids2[i, :end]
                gen_texts.append(tokenizer.decode(ids_cut, skip_special_tokens=True))
            # Save & collect.
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                # Simple post-processing: strip newlines so rows stay valid TSV.
                s_clean = s.replace("\n", " ")
                t_clean = t.replace("\n", " ")
                g_clean = g.replace("\n", " ")
                f.write(f"{s_clean}\t{t_clean}\t{g_clean}\n")
                all_sources.append(s_clean)
                all_targets.append(t_clean)
                all_generated.append(g_clean)
    return all_sources, all_targets, all_generated
### Adds checkpoint saving (best + last) during training
def main():
    """End-to-end Wiki pipeline: train the AE, train the Flow DiT, then evaluate."""
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {ckpt_dir}")
    # Config
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128  # Wiki texts are short; 128 is sufficient and fast
    )
    t_cfg = TrainConfig(
        batch_size=16,       # can be larger at inference time
        num_epochs_ae=20,    # a bit more AE training
        num_epochs_flow=35,  # more Flow training
        grad_accum_steps=4,
        use_amp=False
    )
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)
    # 1. Load data (train & test).
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")
    # Init models.
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id
    # Pass pad_id and stop_id through to the trainer.
    trainer = Trainer(ae, flow, t_cfg, train_loader, pad_id=tokenizer.pad_token_id, stop_id=_pick_stop_id(tokenizer))
    # 2. Train the autoencoder.
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")
        # Save best checkpoint.
        if loss < best_ae_loss:
            best_ae_loss = loss
            torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_best.pt"))
        # Save last (overwritten each epoch; for resuming or inspection).
        torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_last.pt"))
    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")
    # 3. Train the flow model.
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")
        # Save best checkpoint.
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_best.pt"))
        # Save last.
        torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_last.pt"))
    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")
    # 4. Evaluation — load the best weights when available.
    ae_path = os.path.join(ckpt_dir, "ae_best.pt")
    flow_path = os.path.join(ckpt_dir, "flow_best.pt")
    if os.path.exists(ae_path):
        ae.load_state_dict(torch.load(ae_path, map_location=t_cfg.device))
        print("Loaded AE Best.")
    else:
        print("Warning: AE Best ckpt not found, using last state.")
    if os.path.exists(flow_path):
        flow.load_state_dict(torch.load(flow_path, map_location=t_cfg.device))
        print("Loaded Flow Best.")
    else:
        print("Warning: Flow Best ckpt not found, using last state.")
    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path="wiki_results.tsv"
    )
    # Compute and report metrics.
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    print(f"\nResults saved to wiki_results.tsv")
# Script entry point.
if __name__ == "__main__":
    main()