# File size: 10,693 Bytes
# 77d636f
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import evaluate
import sacrebleu
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
### Helper functions for detecting the EOS / stop token
def _pick_stop_id(tokenizer):
    """Return the token id that marks the end of a generated sequence.

    ``tokenizer.eos_token_id`` is used when it is set. BERT/Jina-style
    tokenizers usually have ``eos_token_id=None``, so in that case the
    function falls back to ``tokenizer.sep_token_id`` as the terminator.

    Args:
        tokenizer: Object exposing ``eos_token_id`` and ``sep_token_id``.

    Returns:
        The EOS id if it is not ``None``, otherwise the SEP id (which may
        itself be ``None`` if the tokenizer defines neither).
    """
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        return tokenizer.sep_token_id
    return eos_id
def _first_pos(x_1d, token_id, default):
    """Return the index of the first occurrence of ``token_id`` in ``x_1d``.

    Args:
        x_1d: 1-D tensor of token ids, shape ``[L]``.
        token_id: Token id to look for.
        default: Value returned when ``token_id`` does not occur in ``x_1d``.

    Returns:
        The first matching index as a Python number, or ``default`` when
        there is no match.
    """
    hits = (x_1d == token_id).nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        return default
    return hits[0].item()
def calculate_metrics(sources, predictions, references):
    """Compute SARI, BLEU and the average compression ratio.

    Args:
        sources: Source sentences, one string per example.
        predictions: Generated sentences, aligned with ``sources``.
        references: Reference sentences (a single reference per example),
            aligned with ``predictions``.

    Returns:
        A dict with the keys ``"SARI"``, ``"BLEU"`` and
        ``"Compression Ratio"``. The compression ratio is the mean of
        ``len(prediction) / len(source)`` measured in characters; an empty
        source contributes 0. When ``predictions`` is empty, all three
        metrics are reported as ``0.0`` instead of raising.
    """
    # Nothing to score: avoid dividing by zero below (and avoid handing an
    # empty corpus to sacrebleu).
    if not predictions:
        return {"SARI": 0.0, "BLEU": 0.0, "Compression Ratio": 0.0}

    # The Hugging Face `evaluate` SARI/BLEU scripts are not used here because
    # they could not be downloaded from the hub (network issues), so scoring
    # goes through sacrebleu instead.

    # 1. BLEU
    # sacrebleu expects references as List[List[str]] (one inner list per
    # reference set). `references` is a flat List[str] with a single
    # reference per example, so it is wrapped in an outer list.
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # 2. SARI (best effort: a failure is reported and scored as 0.0)
    # NOTE(review): sacrebleu's documented API does not appear to include
    # `corpus_sari`; if the installed version lacks it, this branch always
    # falls back to 0.0 -- confirm and switch to a real SARI implementation.
    try:
        sari = sacrebleu.corpus_sari(sources, predictions, [references])
        sari_score = sari.score
    except Exception as e:
        print(f"SARI calculation failed: {e}")
        sari_score = 0.0

    # 3. Compression ratio (character-level length of prediction vs. source)
    ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0

    return {
        "SARI": sari_score,
        "BLEU": bleu.score,
        "Compression Ratio": avg_ratio,
    }
@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    """Generate text for every batch in ``loader`` and write a TSV report.

    For each batch the source ids are encoded with ``ae.encode``, the latent
    is transformed by ``flow`` (conditioned on the encoded source), and the
    result is decoded twice with ``ae.decode``: a first pass over the full
    length to locate the stop token, and a second pass restricted to the
    positions up to and including that stop token.

    Args:
        ae: Autoencoder exposing ``encode(ids, mask)`` and
            ``decode(z, attention_mask=...)``; its optional ``latent_scale``
            attribute (default 10.0) sets the norm latents are rescaled to.
        flow: Callable ``flow(z, t, condition=...)``; its output is used as a
            predicted end-point latent.
        loader: Iterable of batches with ``'src_ids'``, ``'src_mask'`` and
            ``'tgt_ids'`` tensors; must also expose ``.dataset``.
        tokenizer: Tokenizer used for stop/pad ids and for decoding text.
        device: Device the batch tensors are moved to.
        steps: Number of integration steps when ``use_oneshot`` is False.
        save_path: Path of the TSV file (``Source``, ``Target``,
            ``Generated`` columns) that is overwritten with the results.
        use_oneshot: If True, call ``flow`` once at ``t=0``; otherwise run
            ``steps`` Euler-style updates.

    Returns:
        Three lists of strings ``(sources, targets, generated)``. Newlines,
        tabs and carriage returns in each string are replaced by spaces so
        that every example occupies exactly one TSV row.

    Raises:
        ValueError: If ``use_oneshot`` is False and ``steps`` is below 1.
    """
    # Multi-step sampling divides by `steps` and relies on at least one
    # iteration to produce the final prediction, so reject bad values early.
    if not use_oneshot and steps < 1:
        raise ValueError(f"steps must be >= 1 for multi-step sampling, got {steps}")

    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    # Characters that would break the one-example-per-row TSV layout.
    tsv_safe = str.maketrans({"\n": " ", "\t": " ", "\r": " "})

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")
    all_sources = []
    all_targets = []
    all_generated = []
    scale = getattr(ae, "latent_scale", 10.0)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")
        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            # ---- 1) encode the source; it is both start point and condition
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            # ---- 2) sample: one-shot or multi-step
            if use_oneshot:
                # x-prediction is most stable as a single call at t=0.
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    # Stay away from t=1, where the velocity denominator
                    # below approaches zero.
                    if t_val >= 0.999:
                        break
                    t = torch.full((z_curr.shape[0],), t_val, device=device)
                    # The model output is treated as the end-point latent;
                    # convert it to a velocity and take one Euler step.
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)  # epsilon guards the division
                    z_curr = z_curr + v * dt
                    z_curr = F.normalize(z_curr, p=2, dim=-1) * scale
                # Use the last end-point prediction directly as the result.
                z_curr = pred_z1

            # Rescale to the latent norm used by the autoencoder.
            z_curr = F.normalize(z_curr, p=2, dim=-1) * scale

            # ---- 3) two-pass decode to determine length by the stop token
            # Pass 1: every position may be generated (length can grow).
            full_mask = torch.ones(B, L, device=device)
            ids1 = ae.decode(z_curr, attention_mask=full_mask).argmax(dim=-1)  # [B, L]

            # First stop position per row; fall back to the last position
            # (L - 1) when no stop token was predicted.
            stop_list = [_first_pos(row, stop_id, default=L - 1) for row in ids1]
            stop_pos = torch.tensor(stop_list, device=device)

            # gen_mask[i, j] is 1 for j <= stop_pos[i], else 0.
            positions = torch.arange(L, device=device)
            gen_mask = (positions.unsqueeze(0) <= stop_pos.unsqueeze(1)).to(full_mask.dtype)

            # Pass 2: decode again with the tail masked out to reduce its
            # interference, then force padding after the stop token.
            ids2 = ae.decode(z_curr, attention_mask=gen_mask).argmax(dim=-1)
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)

            # ---- 4) decode to text, truncating each row at its stop token
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = [
                tokenizer.decode(ids2[i, : end + 1], skip_special_tokens=True)
                for i, end in enumerate(stop_list)
            ]

            # ---- 5) save and collect
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                s_clean = s.translate(tsv_safe)
                t_clean = t.translate(tsv_safe)
                g_clean = g.translate(tsv_safe)
                f.write(f"{s_clean}\t{t_clean}\t{g_clean}\n")
                all_sources.append(s_clean)
                all_targets.append(t_clean)
                all_generated.append(g_clean)

    return all_sources, all_targets, all_generated
### add saving ckpts
def main():
    """Train the autoencoder and the flow model, then evaluate on the test set.

    Steps: build configs and tokenizer, load the "wiki" train/test loaders,
    train the autoencoder and then the flow model (writing ``*_best.pt`` and
    ``*_last.pt`` checkpoints under ``checkpoints/``), reload the best
    checkpoints, run ``inference_batch`` on the test loader and print the
    metrics from ``calculate_metrics``.
    """
    ckpt_dir = "checkpoints"
    results_path = "wiki_results.tsv"
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {ckpt_dir}")

    # Config
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128,  # Wiki texts are short; 128 is enough and fast
    )
    t_cfg = TrainConfig(
        batch_size=16,  # could be larger at inference time
        num_epochs_ae=20,  # a bit more AE training
        num_epochs_flow=35,  # more flow training
        grad_accum_steps=4,
        use_amp=False,
    )
    # NOTE: trust_remote_code=True executes code shipped with the model
    # files; only point encoder_name at a trusted checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)

    # 1. Load data (train & test)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    # Init models
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # The trainer is given the pad id and the stop id explicitly.
    trainer = Trainer(
        ae, flow, t_cfg, train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer),
    )

    # 2. Train AE (only parameters that require gradients are optimised)
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    print("\n>>> Start Training Autoencoder...")
    best_ae_loss = _train_with_checkpoints(
        trainer.train_ae, ae, opt_ae, t_cfg.num_epochs_ae, ckpt_dir, "ae", "AE"
    )
    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")

    # 3. Train Flow
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    print("\n>>> Start Training Flow DiT...")
    best_flow_loss = _train_with_checkpoints(
        trainer.train_flow, flow, opt_flow, t_cfg.num_epochs_flow, ckpt_dir, "flow", "Flow"
    )
    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # 4. Evaluation with the best weights
    _load_best_checkpoint(ae, os.path.join(ckpt_dir, "ae_best.pt"), "AE", t_cfg.device)
    _load_best_checkpoint(flow, os.path.join(ckpt_dir, "flow_best.pt"), "Flow", t_cfg.device)

    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path=results_path,
    )

    # Calculate metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    print(f"\nResults saved to {results_path}")


def _train_with_checkpoints(train_epoch, model, optimizer, num_epochs, ckpt_dir, tag, label):
    """Run ``num_epochs`` epochs, saving best/last weights; return the best loss.

    ``train_epoch(optimizer)`` is called once per epoch and must return the
    epoch loss. ``<tag>_best.pt`` is written whenever the loss improves and
    ``<tag>_last.pt`` is overwritten every epoch (useful for resuming or
    inspection). Returns ``inf`` when ``num_epochs`` is 0.
    """
    best_path = os.path.join(ckpt_dir, f"{tag}_best.pt")
    last_path = os.path.join(ckpt_dir, f"{tag}_last.pt")
    best_loss = float('inf')
    for epoch in range(num_epochs):
        loss = train_epoch(optimizer)
        print(f"{label} Epoch {epoch}: Loss {loss:.4f}")
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), best_path)
        torch.save(model.state_dict(), last_path)
    return best_loss


def _load_best_checkpoint(model, path, label, device):
    """Load the state dict at ``path`` into ``model`` if the file exists.

    When the file is missing, a warning is printed and the model keeps its
    current (last-epoch) weights.
    """
    if os.path.exists(path):
        # NOTE(review): torch.load unpickles the file; these checkpoints are
        # written by this script. Consider weights_only=True if the
        # installed torch version supports it.
        model.load_state_dict(torch.load(path, map_location=device))
        print(f"Loaded {label} Best.")
    else:
        print(f"Warning: {label} Best ckpt not found, using last state.")
# Script entry point: run the full train + evaluate pipeline.
if __name__ == "__main__":
    main()