# File size: 2,913 Bytes | 77d636f
import torch
import torch.optim as optim
from transformers import AutoTokenizer
import os
import argparse
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder,ResidualAutoencoder
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
def _pick_stop_id(tokenizer):
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
def main():
    """Train the autoencoder on its own and write checkpoints to disk.

    Reads ``--save_dir`` from the command line and creates that directory if
    it does not exist. It then builds the model and training configs, the
    tokenizer, the "wiki" training loader, a ``ReshapedAutoencoder`` and a
    ``Trainer`` (with ``flow=None``), and runs ``num_epochs_ae`` epochs of
    ``Trainer.train_ae_combined``.

    Checkpoints written into the save directory:
      * ``ae_best.pt`` - weights from the epoch with the lowest loss so far.
      * ``ae_last.pt`` - weights from the most recent epoch.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--save_dir",
        type=str,
        default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints",
        help="Directory to save checkpoints",
    )
    save_dir = cli.parse_args().save_dir

    os.makedirs(save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {save_dir}")

    # Both checkpoint locations are fixed for the whole run.
    best_path = os.path.join(save_dir, "ae_best.pt")
    last_path = os.path.join(save_dir, "ae_last.pt")

    # --- Config ---
    model_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',  # adjust to your actual local path
        latent_dim=512,
        max_seq_len=128,
    )
    train_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,  # only the AE epoch count matters in this script
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4,
    )

    # --- Data & Tokenizer ---
    # The tokenizer is loaded strictly from local files, without remote code.
    tokenizer = AutoTokenizer.from_pretrained(
        model_cfg.encoder_name,
        local_files_only=True,
        trust_remote_code=False,
    )
    train_loader = prepare_data(
        "wiki",
        tokenizer,
        model_cfg.max_seq_len,
        train_cfg.batch_size,
        split="train",
    )

    # --- Model ---
    # ResidualAutoencoder(model_cfg) appears to be a drop-in alternative that
    # was tried here previously - confirm before swapping it in.
    autoencoder = ReshapedAutoencoder(model_cfg).to(train_cfg.device).float()
    encoder_cfg = autoencoder.encoder.config
    if encoder_cfg.pad_token_id is None:
        # Encoder config has no pad id: copy it over from the tokenizer.
        encoder_cfg.pad_token_id = tokenizer.pad_token_id

    # --- Trainer ---
    # flow is None because only the autoencoder is trained in this script.
    trainer = Trainer(
        ae=autoencoder,
        flow=None,
        cfg=train_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer),
    )

    # --- Optimizer ---
    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = [p for p in autoencoder.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=train_cfg.lr_ae)

    # --- Training Loop ---
    lowest_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")

    for epoch in range(train_cfg.num_epochs_ae):
        # trainer.train_ae / trainer.train_robust_ae were used here in earlier
        # experiments; the combined routine is the one currently active.
        loss = trainer.train_ae_combined(optimizer, epoch, train_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Keep a separate copy of the weights with the lowest loss so far.
        if loss < lowest_loss:
            lowest_loss = loss
            torch.save(autoencoder.state_dict(), best_path)
            print(f" Saved Best AE to {best_path}")

        # Always refresh the "last" checkpoint.
        # NOTE(review): assumed to run once per epoch (inside the loop); after
        # the final epoch the file content is the same either way - confirm.
        torch.save(autoencoder.state_dict(), last_path)

    print(f"AE Training Done. Best Loss: {lowest_loss:.4f}")
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()