| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| import os |
| import evaluate |
| import sacrebleu |
|
|
| from src.config import ModelConfig, TrainConfig |
| from src.models.autoencoder import ReshapedAutoencoder |
| from src.models.dit import PatchedFlowDiT |
| from src.trainer import Trainer |
| from src.utils.data_utils import prepare_data |
|
|
| |
| def _pick_stop_id(tokenizer): |
| |
| return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id |
|
|
| def _first_pos(x_1d, token_id, default): |
| |
| idx = (x_1d == token_id).nonzero(as_tuple=True)[0] |
| return idx[0].item() if idx.numel() > 0 else default |
|
|
|
|
def calculate_metrics(sources, predictions, references):
    """Compute SARI, BLEU, and the average character-level compression ratio.

    Args:
        sources: list of source sentences.
        predictions: list of model outputs, aligned with ``sources``.
        references: list of reference sentences, aligned with ``sources``.

    Returns:
        dict with keys "SARI", "BLEU", and "Compression Ratio".
    """
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    try:
        # NOTE(review): corpus_sari is not part of mainline sacrebleu releases;
        # the except branch keeps the pipeline alive (score 0.0) when the
        # installed sacrebleu does not provide it.
        sari = sacrebleu.corpus_sari(sources, predictions, [references])
        sari_score = sari.score
    except Exception as e:
        print(f"SARI calculation failed: {e}")
        sari_score = 0.0

    # Character-length ratio prediction/source; 0 when the source is empty.
    ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
    # Guard the empty evaluation set (the original divided by len(ratios)
    # unconditionally, raising ZeroDivisionError on no examples).
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0

    return {
        "SARI": sari_score,
        "BLEU": bleu.score,
        "Compression Ratio": avg_ratio
    }
|
|
@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt",use_oneshot=True):
    """Generate text for every batch in `loader` and write a TSV of results.

    For each example: encode the source to a latent, map it through the flow
    model (one-shot or iteratively), decode twice (first to locate the stop
    token, then with a mask restricted to the kept prefix), and record the
    decoded strings.

    Returns:
        (all_sources, all_targets, all_generated): three parallel lists of
        newline-flattened strings, in loader order.
    """
    ae.eval()
    flow.eval()

    # Stop token: EOS when the tokenizer defines one, SEP otherwise.
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources = []
    all_targets = []
    all_generated = []

    # Latent norm scale; falls back to 10.0 when the AE exposes no latent_scale.
    scale = getattr(ae, "latent_scale", 10.0)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            # Encode the source; the flow is conditioned on the same latent.
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            if use_oneshot:
                # Single forward pass at t=0: the flow predicts the endpoint directly.
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                # Iterative integration from t=0 toward t=1 in `steps` steps.
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    # Stop just short of t=1 to avoid the singular denominator below.
                    if t_val >= 0.999: break
                    t = torch.ones(z_curr.shape[0], device=device) * t_val

                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    # Velocity toward the predicted endpoint; the `+ +` is a
                    # unary plus and the 1e-4 keeps the denominator nonzero.
                    v = (pred_z1 - z_curr) / (1.0 - t_val + + 1e-4)
                    z_curr = z_curr + v * dt
                    z_curr = F.normalize(z_curr, p=2, dim=-1) * scale
                    # NOTE(review): this overwrites the Euler update computed on the
                    # two lines above, so every iteration effectively jumps straight
                    # to pred_z1 — confirm whether those lines are intentional dead code.
                    z_curr = pred_z1

            # Re-project onto the AE's latent sphere before decoding.
            z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1) * scale

            # Pass 1: decode with a full attention mask to locate the stop token.
            full_mask = torch.ones(B, L, device=device)

            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)

            # First stop-token position per sample; default keeps all L tokens.
            stop_pos = []
            for i in range(B):
                pos = _first_pos(ids1[i], stop_id, default=L - 1)
                stop_pos.append(pos)
            stop_pos = torch.tensor(stop_pos, device=device)

            # Mask covering tokens up to and including each stop position.
            gen_mask = torch.zeros(B, L, device=device)
            for i in range(B):
                gen_mask[i, : stop_pos[i].item() + 1] = 1.0

            # Pass 2: decode again, attending only to the kept prefix.
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)

            # Replace everything past the stop position with padding.
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)

            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)

            # Decode each generated prefix individually (lengths differ per sample).
            gen_texts = []
            for i in range(B):
                end = stop_pos[i].item() + 1
                ids_cut = ids2[i, :end]
                gen_texts.append(tokenizer.decode(ids_cut, skip_special_tokens=True))

            # Flatten newlines so every example stays on a single TSV row.
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                s_clean = s.replace("\n", " ")
                t_clean = t.replace("\n", " ")
                g_clean = g.replace("\n", " ")

                f.write(f"{s_clean}\t{t_clean}\t{g_clean}\n")

                all_sources.append(s_clean)
                all_targets.append(t_clean)
                all_generated.append(g_clean)

    return all_sources, all_targets, all_generated
|
|
| |
def main():
    """Train the autoencoder, then the flow model, then evaluate on the test split."""

    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {ckpt_dir}")

    # Model hyper-parameters.
    # NOTE(review): encoder_name is a relative local path — confirm the
    # sibling directory ../jina-embeddings-v2-base-code exists at run time.
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128
    )

    # Training hyper-parameters; AMP disabled, models kept in float32 below.
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,
        num_epochs_flow=35,
        grad_accum_steps=4,
        use_amp=False
    )

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)

    # Data loaders for the "wiki" dataset splits.
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()

    # Some encoder checkpoints ship without a pad id; borrow the tokenizer's.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    trainer = Trainer(ae, flow, t_cfg, train_loader, pad_id=tokenizer.pad_token_id, stop_id=_pick_stop_id(tokenizer))

    # --- Stage 1: autoencoder training ---
    # "Best" = lowest training loss so far; no validation split is used here.
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        if loss < best_ae_loss:
            best_ae_loss = loss
            torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_best.pt"))

        # Always keep the latest weights as well (overwritten every epoch).
        torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")

    # --- Stage 2: flow model training (same best/last checkpoint scheme) ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_best.pt"))

        torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_last.pt"))

    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # --- Reload the best checkpoints; fall back to the in-memory weights ---
    ae_path = os.path.join(ckpt_dir, "ae_best.pt")
    flow_path = os.path.join(ckpt_dir, "flow_best.pt")

    if os.path.exists(ae_path):
        ae.load_state_dict(torch.load(ae_path, map_location=t_cfg.device))
        print("Loaded AE Best.")
    else:
        print("Warning: AE Best ckpt not found, using last state.")

    if os.path.exists(flow_path):
        flow.load_state_dict(torch.load(flow_path, map_location=t_cfg.device))
        print("Loaded Flow Best.")
    else:
        print("Warning: Flow Best ckpt not found, using last state.")

    # --- Inference + evaluation on the test split ---
    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path="wiki_results.tsv"
    )

    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print(f"\nResults saved to wiki_results.tsv")
|
|
# Script entry point: run the full train + evaluate pipeline.
if __name__ == "__main__":
    main()