# Diff-Refine / train_ae.py — standalone script to train the autoencoder stage.
import torch
import torch.optim as optim
from transformers import AutoTokenizer
import os
import argparse
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder,ResidualAutoencoder
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
def _pick_stop_id(tokenizer):
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
def main():
    """Train the autoencoder (AE) stage only and checkpoint best/last weights.

    Parses --save_dir, builds the model/data/trainer from the project config,
    runs the AE training loop, and writes ``ae_best.pt`` (lowest epoch loss)
    and ``ae_last.pt`` (most recent epoch) under the save directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints",
        help="Directory to save checkpoints",
    )
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {args.save_dir}")

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',  # adjust to the actual local model path
        latent_dim=512,
        max_seq_len=128
    )
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,  # epoch count for the AE stage only
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4
    )

    # --- Data & Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")

    # --- Model ---
    # NOTE: swap in ResidualAutoencoder(m_cfg) here to train the residual variant.
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    # The encoder config may ship without a pad token id; inherit it from the
    # tokenizer so padding is handled consistently.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Trainer ---
    # flow=None: only the AE is trained in this stage.
    trainer = Trainer(
        ae=ae,
        flow=None,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )

    # --- Optimizer ---
    # Optimize only trainable parameters (parts of the encoder may be frozen).
    opt_ae = optim.AdamW((p for p in ae.parameters() if p.requires_grad), lr=t_cfg.lr_ae)

    # --- Training Loop ---
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        # Alternative objectives: trainer.train_ae(opt_ae) or
        # trainer.train_robust_ae(opt_ae); the combined objective is the default.
        loss = trainer.train_ae_combined(opt_ae, epoch, t_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Save Best: keep the checkpoint with the lowest epoch loss seen so far.
        if loss < best_ae_loss:
            best_ae_loss = loss
            save_path = os.path.join(args.save_dir, "ae_best.pt")
            torch.save(ae.state_dict(), save_path)
            print(f" Saved Best AE to {save_path}")

        # Save Last: overwritten every epoch so a crash loses at most one epoch.
        torch.save(ae.state_dict(), os.path.join(args.save_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")


if __name__ == "__main__":
    main()