File size: 8,464 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import argparse
import sacrebleu

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data

# --- Helper Functions for Inference (copied here so this script can run standalone) ---
def _pick_stop_id(tokenizer):
    """Return the token id used to mark the end of a generated sequence.

    ``tokenizer.eos_token_id`` is preferred. When it is ``None`` the
    tokenizer's ``sep_token_id`` is returned instead (which may itself be
    ``None`` if the tokenizer defines neither).
    """
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        return tokenizer.sep_token_id
    return eos_id

def _first_pos(x_1d, token_id, default):
    """Return the index of the first element of ``x_1d`` equal to ``token_id``.

    ``x_1d`` is compared element-wise against ``token_id``. If nothing
    matches, ``default`` is returned unchanged; otherwise the position of the
    earliest match is returned as a plain Python number.
    """
    matches = (x_1d == token_id).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        return default
    return matches[0].item()

def calculate_metrics(sources, predictions, references):
    """Compute corpus-level metrics for generated text.

    Args:
        sources: Input sentences, aligned index-by-index with ``predictions``.
        predictions: System outputs to score.
        references: One reference string per prediction; passed to sacrebleu
            as a single reference stream.

    Returns:
        dict with the keys ``"SARI"``, ``"BLEU"`` and ``"Compression Ratio"``.
        ``"Compression Ratio"`` is the mean of ``len(prediction) / len(source)``
        measured in characters (0 for an empty source, 0 overall when there
        are no pairs). ``"SARI"`` falls back to ``0.0`` when it cannot be
        computed; a ``RuntimeWarning`` is emitted in that case so the
        placeholder is not mistaken for a real score.
    """
    import warnings

    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # SARI is best-effort: any failure degrades to 0.0 rather than aborting the
    # evaluation, but the failure is reported instead of being swallowed.
    # NOTE(review): released sacrebleu versions do not appear to export
    # `corpus_sari`, in which case this branch always falls back -- confirm and
    # consider a dedicated SARI implementation.
    try:
        sari_score = sacrebleu.corpus_sari(sources, predictions, [references]).score
    except Exception as exc:
        warnings.warn(
            f"SARI could not be computed ({type(exc).__name__}: {exc}); reporting 0.0",
            RuntimeWarning,
            stacklevel=2,
        )
        sari_score = 0.0

    # Character-length ratio of each prediction to its source.
    ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0

    return {"SARI": sari_score, "BLEU": bleu.score, "Compression Ratio": avg_ratio}

@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    """Generate text for every batch in ``loader`` and write the results as TSV.

    Per batch: the source ids are encoded with ``ae.encode``, the latent is
    passed through ``flow`` (a single call at t=0 when ``use_oneshot`` is
    true, otherwise ``steps`` iterative updates), and the result is decoded
    twice with ``ae.decode`` -- once with an all-ones mask to locate the first
    stop token in each row, then again with a mask truncated at that position.

    Args:
        ae: Autoencoder exposing ``encode(ids, mask)`` and
            ``decode(z, attention_mask=...)``; put into eval mode here.
        flow: Module called as ``flow(z, t, condition=...)``; put into eval
            mode here. Its output is used as the predicted target latent
            (presumably z1 -- confirm against the model definition).
        loader: Iterable of dict batches holding ``'src_ids'``, ``'src_mask'``
            and ``'tgt_ids'`` tensors; must expose a sized ``.dataset``.
        tokenizer: Tokenizer providing ``pad_token_id``, ``decode`` and
            ``batch_decode``.
        device: Device the batch tensors are moved to.
        steps: Number of iterative updates when ``use_oneshot`` is false.
        save_path: Output file; overwritten, tab-separated, one header row.
        use_oneshot: Use a single flow call at t=0 instead of the loop.

    Returns:
        Tuple ``(sources, targets, generated)`` of lists of strings with
        newlines replaced by spaces.

    Raises:
        ValueError: If ``use_oneshot`` is false and ``steps`` is below 1.
    """
    # Fail before truncating `save_path`: a non-positive step count would
    # otherwise divide by zero or leave the final prediction unbound.
    if not use_oneshot and steps < 1:
        raise ValueError(f"steps must be >= 1 for multi-step sampling, got {steps}")

    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources, all_targets, all_generated = [], [], []

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            # Encode; the untouched source latent doubles as the condition.
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            # Flow sampling
            if use_oneshot:
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    if t_val >= 0.999:
                        break
                    t = torch.ones(B, device=device) * t_val
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    # Velocity implied by the current endpoint prediction;
                    # the epsilon keeps the division finite as t approaches 1.
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)
                    z_curr = z_curr + v * dt
                # The last endpoint prediction, not the integrated state, is
                # what gets decoded.
                z_curr = pred_z1

            # Decode pass 1: find where each row emits its first stop token.
            full_mask = torch.ones(B, L, device=device)
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)
            stop_pos = [_first_pos(ids1[i], stop_id, default=L - 1) for i in range(B)]

            # Decode pass 2: re-decode attending only up to the stop position.
            gen_mask = torch.zeros(B, L, device=device)
            for i, pos in enumerate(stop_pos):
                gen_mask[i, : pos + 1] = 1.0

            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)

            # Convert to text
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = [
                tokenizer.decode(ids2[i, : stop_pos[i] + 1], skip_special_tokens=True)
                for i in range(B)
            ]

            for src_txt, tgt_txt, gen_txt in zip(src_texts, tgt_texts, gen_texts):
                s_c = src_txt.replace("\n", " ")
                t_c = tgt_txt.replace("\n", " ")
                g_c = gen_txt.replace("\n", " ")
                # A tab or carriage return inside a field would shift or split
                # the TSV columns, so blank them in the file row only; the
                # returned strings are left as before.
                row = "\t".join(
                    field.replace("\t", " ").replace("\r", " ")
                    for field in (s_c, t_c, g_c)
                )
                f.write(row + "\n")
                all_sources.append(s_c)
                all_targets.append(t_c)
                all_generated.append(g_c)

    return all_sources, all_targets, all_generated


def main():
    """Train the flow model on a frozen pre-trained autoencoder, then evaluate it.

    Parses the command line, builds the model/training configs, loads the
    tokenizer and the "wiki" train/test loaders, loads and freezes the
    autoencoder, trains the flow for ``num_epochs_flow`` epochs (writing
    ``flow_best.pt`` and ``flow_last.pt`` into ``--save_dir``), reloads the
    best checkpoint, runs inference on the test split and prints the metrics.

    Raises:
        FileNotFoundError: If the ``--ae_ckpt`` path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ae_ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="Path to pre-trained AE checkpoint")
    parser.add_argument("--save_dir", type=str, default="residual_robust_checkpoints", help="Directory to save flow checkpoints")
    parser.add_argument("--use_oneshot", dest="use_oneshot", action="store_true", default=True, help="Use one-shot sampling for inference")
    # `--use_oneshot` defaults to True, so on its own it can never be turned
    # off; this companion flag makes multi-step sampling reachable.
    parser.add_argument("--no_oneshot", dest="use_oneshot", action="store_false", help="Use multi-step sampling instead of one-shot sampling for inference")
    args = parser.parse_args()

    # Fail fast: check the checkpoint before any expensive data/model loading.
    if not os.path.exists(args.ae_ckpt):
        raise FileNotFoundError(f"AE checkpoint not found at {args.ae_ckpt}. Please run train_ae.py first.")

    os.makedirs(args.save_dir, exist_ok=True)

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128
    )

    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_flow=35,  # only the Flow epochs matter in this script
        grad_accum_steps=4,
        use_amp=False,
        lr_flow=2e-4
    )

    # --- Tokenizer & Data ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    # --- Load AE (Pre-trained) ---
    print(f"\n>>> Loading Pre-trained Autoencoder from {args.ae_ckpt} ...")
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ae.load_state_dict(torch.load(args.ae_ckpt, map_location=t_cfg.device))

    # Freeze every AE parameter; the AE is not updated during Flow training.
    ae.eval()
    for param in ae.parameters():
        param.requires_grad = False
    print(">>> Autoencoder loaded and frozen.")

    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Initialize Flow ---
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()

    # --- Trainer ---
    trainer = Trainer(
        ae=ae,
        flow=flow,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )

    # --- Optimizer ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)

    # --- Training Loop ---
    best_flow_loss = float('inf')
    best_flow_path = os.path.join(args.save_dir, "flow_best.pt")
    last_flow_path = os.path.join(args.save_dir, "flow_last.pt")
    print("\n>>> Start Training Flow DiT...")

    for epoch in range(t_cfg.num_epochs_flow):
        # Pass opt_flow in to train the Flow.
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

        # Save Best
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), best_flow_path)

        # Save Last
        torch.save(flow.state_dict(), last_flow_path)

    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # --- Inference / Evaluation ---
    print("\n>>> Loading Best Flow Checkpoint for Evaluation...")
    if os.path.exists(best_flow_path):
        flow.load_state_dict(torch.load(best_flow_path, map_location=t_cfg.device))
    else:
        print("Warning: Best checkpoint not found, utilizing last epoch weights.")

    print("\n--- Starting Inference ---")
    results_path = "wiki_results.tsv"
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path=results_path,
        use_oneshot=args.use_oneshot
    )

    # Metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print(f"\nResults saved to {results_path}")

# Script entry point: run the training + evaluation pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()