File size: 2,913 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.optim as optim
from transformers import AutoTokenizer
import os
import argparse

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder,ResidualAutoencoder
from src.trainer import Trainer
from src.utils.data_utils import prepare_data

def _pick_stop_id(tokenizer):
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

def main():
    """Train only the autoencoder stage and checkpoint the best/last weights.

    CLI arguments:
        --save_dir: directory where ``ae_best.pt`` / ``ae_last.pt`` are written.
        --encoder_name: path or model id of the pretrained encoder backbone
            (previously hard-coded; the default preserves the old behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", type=str,
                        default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints",
                        help="Directory to save checkpoints")
    parser.add_argument("--encoder_name", type=str,
                        default="../jina-embeddings-v2-base-code",
                        help="Path or model id of the pretrained encoder")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {args.save_dir}")

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name=args.encoder_name,
        latent_dim=512,
        max_seq_len=128
    )

    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,  # epochs for the autoencoder stage only
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4
    )

    # --- Data & Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(args.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")

    # --- Model ---
    # Swap in ResidualAutoencoder(m_cfg) here to train the residual variant.
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()

    # Some encoder configs ship without a pad id; inherit it from the tokenizer.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Trainer ---
    # flow=None because this script trains only the autoencoder.
    trainer = Trainer(
        ae=ae,
        flow=None,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )

    # --- Optimizer ---
    # Only optimize trainable parameters (parts of the encoder may be frozen).
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)

    # --- Training Loop ---
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")

    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae_combined(opt_ae, epoch, t_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Keep the best-so-far checkpoint separately from the rolling "last".
        if loss < best_ae_loss:
            best_ae_loss = loss
            save_path = os.path.join(args.save_dir, "ae_best.pt")
            torch.save(ae.state_dict(), save_path)
            print(f"  Saved Best AE to {save_path}")

        # Always overwrite the latest-epoch checkpoint (useful for resuming).
        torch.save(ae.state_dict(), os.path.join(args.save_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")


if __name__ == "__main__":
    main()