"""Train the autoencoder (AE-only stage) and save best/last checkpoints."""

import argparse
import os

import torch
import torch.optim as optim
from transformers import AutoTokenizer

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder, ResidualAutoencoder
from src.trainer import Trainer
from src.utils.data_utils import prepare_data


def _pick_stop_id(tokenizer):
    """Prefer the EOS token as the stop token; fall back to SEP if EOS is absent."""
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints",
        help="Directory to save checkpoints",
    )
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {args.save_dir}")

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',  # adjust to your actual local path
        latent_dim=512,
        max_seq_len=128,
    )
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,  # epochs for the AE-only stage
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4,
    )

    # --- Data & Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")

    # --- Model ---
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    # ae = ResidualAutoencoder(m_cfg).to(t_cfg.device).float()
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Trainer ---
    # flow is None here because only the AE is trained in this stage.
    trainer = Trainer(
        ae=ae,
        flow=None,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer),
    )

    # --- Optimizer ---
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)

    # --- Training Loop ---
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        # loss = trainer.train_ae(opt_ae)
        # loss = trainer.train_robust_ae(opt_ae)
        loss = trainer.train_ae_combined(opt_ae, epoch, t_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Save best checkpoint
        if loss < best_ae_loss:
            best_ae_loss = loss
            save_path = os.path.join(args.save_dir, "ae_best.pt")
            torch.save(ae.state_dict(), save_path)
            print(f"  Saved Best AE to {save_path}")

        # Save last checkpoint (overwritten every epoch)
        torch.save(ae.state_dict(), os.path.join(args.save_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")


if __name__ == "__main__":
    main()
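
# ---------------------------------------------------------------------------
# Usage sketch (assumed invocation; the script's filename is not given in the
# repo, train_ae.py is a placeholder). The default --save_dir points at an
# HDFS mount, so override it for local runs:
#
#   python train_ae.py --save_dir ./robust_checkpoints
#
# The best checkpoint can later be reloaded with a standard state-dict load,
# e.g. (hypothetical snippet, assuming the same ModelConfig as above):
#
#   ae = ReshapedAutoencoder(m_cfg)
#   state = torch.load("./robust_checkpoints/ae_best.pt", map_location="cpu")
#   ae.load_state_dict(state)
# ---------------------------------------------------------------------------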