# Q-RAG_Test / train_q_rag_logt.py
import sys
import os
import time  # added: time module for wall-clock tracking
# Add the repository directory to sys.path so the project-local modules below can be imported
repo_dir = os.path.dirname(os.path.abspath("./"))
if repo_dir not in sys.path:
    print(f'add repository dir: {repo_dir}')
    sys.path.append(repo_dir)
from torch.utils.tensorboard import SummaryWriter
import torch
from rl.agents.pqn import PQN
import numpy as np
from envs.qa_env import QAEnv
from envs.parallel_env import ParallelTextEnv
from tqdm import tqdm
from omegaconf import OmegaConf, DictConfig
from hydra.utils import instantiate
from hydra import initialize, compose
import random
from datetime import datetime
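# Training entry point for the Q-RAG agent:
#   1. load the Hydra config (configs/training.yaml) and set up TensorBoard + log.txt logging,
#   2. instantiate the QA environments and the PQN agent,
#   3. alternate parallel-env rollouts with Q-network updates,
#   4. evaluate periodically, checkpoint the last/best models, and track wall-clock time.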
@torch.no_grad()
def evaluate(env_test, agent):
    """Roll out one episode in env_test with exploration disabled and return the summed reward."""
    s_t = env_test.reset()
    done_t = False
    a_embeds_t, a_embeds_target_t = env_test.get_extra_embeds(agent.action_tokenizer, agent.critic.action_embed, agent.action_embed_target)
    r_sum_t = 0
    while not done_t:
        a_embeds_t = env_test.update_embeds(a_embeds_t, agent.critic.action_embed)
        a_embeds_target_t = env_test.update_embeds(a_embeds_target_t, agent.action_embed_target)
        action_t, _, _ = agent.select_action(s_t, a_embeds_t["rope"], a_embeds_target_t["rope"], random=False, evaluate=True)
        s_t, _, reward_t, done_t = env_test.step(action_t)
        r_sum_t += reward_t
    return r_sum_t
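# Usage: Hydra-style overrides are read directly from the command line, e.g.
#   python train_q_rag_logt.py seed=1 envs_parallel=4 logger.log_dir=./runs
# (illustrative values; any key from configs/training.yaml can be overridden)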
def load_config(name, overrides=None):
    with initialize(version_base="1.3", config_path="./configs"):
        cfg = compose(
            config_name=name,
            overrides=sys.argv[1:]  # overrides if overrides else []
        )
    # cli_cfg = OmegaConf.from_cli()
    # cfg = OmegaConf.merge(cfg, cli_cfg)
    cfg = prepare_config(cfg)
    return cfg
def prepare_config(cfg):
    """
    Modify the config for parameters that should depend on each other.
    """
    if cfg.logger.log_dir is not None:
        dir_name = datetime.now().strftime("%b%d_%H-%M-%S") + cfg.logger.tensorboard.comment
        cfg.logger.log_dir = os.path.join(cfg.logger.log_dir, dir_name)
        cfg.logger.tensorboard.log_dir = os.path.join(cfg.logger.log_dir, 'tb_logs/')
    # enumerate_facts = (cfg.positional_coding == 'enum')  # TODO: add a version that enumerates all chunks
    # cfg.envs.env.dataset.task_dataset.add_sentence_idx = enumerate_facts
    # cfg.envs.test_env.dataset.task_dataset.add_sentence_idx = enumerate_facts
    return cfg
def set_all_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# Format a duration in seconds as HH:MM:SS, e.g. format_time(3723) -> "01:02:03"
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
cfg: DictConfig = load_config(name="training.yaml")
#cfg: DictConfig = load_config(name="training_gte_combined.yaml")
writer: SummaryWriter = instantiate(cfg.logger.tensorboard)
os.makedirs(cfg.logger.log_dir, exist_ok=True)
config_save_path = os.path.join(cfg.logger.log_dir, "config.yaml")
OmegaConf.save(config=cfg, f=config_save_path, resolve=False)
print(f"[INFO] Training config saved to {config_save_path}")
agent_config: DictConfig = cfg.algo
env_config: DictConfig = cfg.envs
print("Embedder model:", agent_config.model.model_name)
# Paths to the checkpoints and the metric used to select the best model
ckpt_last_path = os.path.join(cfg.logger.log_dir, "model_last.pt")
ckpt_best_path = os.path.join(cfg.logger.log_dir, "model_best.pt")
best_eval_reward = -float("inf")
torch.set_default_device(cfg.device)
torch.set_float32_matmul_precision('high')
set_all_seeds(cfg.seed)
# MAX_TOKEN_LENGTH["state"] = cfg.max_state_length
# MAX_TOKEN_LENGTH["action"] = cfg.max_action_length
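# The PQN agent exposes the pieces used below: agent.critic.action_embed, agent.action_embed_target,
# and the state/action tokenizers shared with the environments.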
agent = PQN(agent_config)
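# Disabled scaffolding below: optional bf16 / fp16 mixed-precision setup, kept for reference.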
# if bf16:
# for m in [agent.critic, agent.policy, agent.random_policy,
# agent.v_net_target, agent.action_embed_target]:
# m.to(dtype=torch.bfloat16)
#
# if args.fp16:
# # import apex
# # apex.amp.register_half_function(torch, 'einsum')
# from torch.cuda.amp import autocast, GradScaler
#
# scaler = GradScaler()
#
# device_type = torch.device(cfg.device).type
# amp_dtype = torch.bfloat16 if bf16 else torch.float16
# amp_enabled = bf16 or mixed_precision
# autocast = torch.cuda.amp.autocast if device_type == 'cuda' else torch.autocast
env: QAEnv = instantiate(env_config.env)
env_test: QAEnv = instantiate(env_config.test_env)
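# Wrap cfg.envs_parallel copies of the training env so rollouts are collected in parallel;
# the agent's state/action tokenizers are passed in so states and actions are encoded consistently.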
parallel_env = ParallelTextEnv(
    [env] + [env.copy() for _ in range(cfg.envs_parallel - 1)],
    state_tokenizer=agent.state_tokenizer,
    action_tokenizer=agent.action_tokenizer)
total_steps = cfg.steps_count * cfg.accumulate_grads  # e.g. 80000
eval_interval = cfg.eval_interval * cfg.accumulate_grads  # e.g. 50 * 8 = 400 iterations
log_interval = 100  # log reward / qf_loss every 100 iterations
# assuming we don't need to scale cfg.learning_start with grad accumulation
progress_bar = tqdm(range(total_steps), desc="Training")
# ---------- log.txt initialization ----------
log_path = os.path.join(cfg.logger.log_dir, "log.txt")
log_file = open(log_path, "w", buffering=1)  # buffering=1: line-buffered, entries are flushed to disk immediately
print(f"[INFO] Log file saved to {log_path}")
# Write the log header
log_file.write(f"Training started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total iterations: {total_steps}\n")
log_file.write("="*100 + "\n")
# --------------------------------------------
# Record the training start time
train_start_time = time.time()
states_list, _ = parallel_env.reset()
step = 0
train_rewards = []
last_eval_reward = float("nan")  # shown as nan until the first evaluation has run
for it in progress_bar:
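    # Each iteration: collect a batch of transitions from the parallel envs (acting randomly
    # until `step` reaches 2 * cfg.learning_start for warm-up exploration), then run one
    # Q-network update on that batch. `step` counts environment transitions, not iterations.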
    agent.train()
    states_list, rewards, train_batch = parallel_env.rollout(cfg.batch_size, states_list, agent, random=(step < 2 * cfg.learning_start))
    step += np.prod(train_batch.reward.shape)
    train_rewards.extend(rewards)
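    # One PQN update on the freshly collected batch; agent.update returns the Q-function loss used for logging.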
    qf_loss = agent.update(
        train_batch.state,
        train_batch.action,
        train_batch.next_state,
        train_batch.q_values,
        train_batch.reward,
        train_batch.not_done)
    # ---- Every 100 iterations (log_interval): log reward / qf_loss to TensorBoard + log.txt ----
    if it % log_interval == 0 and it > 0:
        # Compute timing information
        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_it = total_steps - it - 1
        remaining_time = remaining_it * avg_time_per_it
        total_estimated_time = elapsed_time + remaining_time
        # Format the times as HH:MM:SS
        elapsed_time_str = format_time(elapsed_time)
        remaining_time_str = format_time(remaining_time)
        total_time_str = format_time(total_estimated_time)
        # Compute the progress percentage
        progress_pct = (it + 1) / total_steps * 100
        train_r_mean = np.mean(train_rewards)
        writer.add_scalar("train r_sum", train_r_mean, step)
        writer.add_scalar("qf_loss", qf_loss, step)
        writer.add_scalar("training/elapsed_time_hours", elapsed_time/3600, step)
        writer.add_scalar("training/remaining_time_hours", remaining_time/3600, step)
        log_file.write(
            f"{it}/{total_steps} ({progress_pct:.1f}%), "
            f"elapsed_time={elapsed_time_str}, "
            f"remaining_time={remaining_time_str}, "
            f"total_estimated={total_time_str}, "
            f"reward={train_r_mean:.3f}, "
            f"eval_reward={last_eval_reward:.3f}, "
            f"qf_loss={float(qf_loss):.3f}, "
            f"step={step}\n"
        )
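        # Example log.txt line (illustrative values):
        # 400/80000 (0.5%), elapsed_time=00:05:00, remaining_time=16:30:00, total_estimated=16:35:00, reward=0.125, eval_reward=nan, qf_loss=0.042, step=3200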
    # ---- Every eval_interval iterations (e.g. 50 * 8 = 400): run evaluation and update eval_reward ----
    if it % eval_interval == 0:
        agent.eval()
        r_eval = []
        # Record the evaluation start time
        eval_start_time = time.time()
        for j in range(cfg.eval_episodes):
            r_eval.append(evaluate(env_test, agent))
            print(f"\reval prog: {len(r_eval)}/{cfg.eval_episodes}", end="")
        eval_elapsed_time = time.time() - eval_start_time
        last_eval_reward = float(np.mean(r_eval))
        writer.add_scalar("eval r_sum", last_eval_reward, step)
        writer.add_scalar("eval/elapsed_time_seconds", eval_elapsed_time, step)
        # Update the progress bar with timing information
        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_time = (total_steps - it - 1) * avg_time_per_it
        progress_bar.set_postfix({
            'reward': np.mean(train_rewards),
            'eval_reward': last_eval_reward,
            'qf_loss': qf_loss,
            'step': step,
            'elapsed': format_time(elapsed_time),
            'remaining': format_time(remaining_time)
        })
        agent.save(ckpt_last_path)
        # torch.save(agent.state_dict(), ckpt_last_path)
        if last_eval_reward > best_eval_reward:
            best_eval_reward = last_eval_reward
            agent.save(ckpt_best_path)
            # torch.save(agent.state_dict(), ckpt_best_path)
            # print(f"[INFO] New best model saved with reward {best_eval_reward:.3f}")
        train_rewards = []
# Training finished: record the final timing information
total_train_time = time.time() - train_start_time
log_file.write("\n" + "="*100 + "\n")
log_file.write(f"Training finished at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total training time: {format_time(total_train_time)}\n")
log_file.write(f"Best eval reward: {best_eval_reward:.3f}\n")
log_file.close()
# Print the final summary
print(f"\n[INFO] Training completed!")
print(f"Total training time: {format_time(total_train_time)}")
print(f"Best evaluation reward: {best_eval_reward:.3f}")
print(f"Logs saved to: {cfg.logger.log_dir}")