import sys
import os
import time

# Add the repository root to sys.path so that the rl/ and envs/ packages can be imported.
repo_dir = os.path.dirname(os.path.abspath("./"))
if repo_dir not in sys.path:
    print(f'add repository dir: {repo_dir}')
    sys.path.append(repo_dir)
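
# Third-party and project imports. The rl/ and envs/ packages are expected to
# resolve against the repository directory added to sys.path above.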
from torch.utils.tensorboard import SummaryWriter
import torch
from rl.agents.pqn import PQN
import numpy as np
from envs.qa_env import QAEnv
from envs.parallel_env import ParallelTextEnv
from tqdm import tqdm
from omegaconf import OmegaConf, DictConfig
from hydra.utils import instantiate
from hydra import initialize, compose
import random
from datetime import datetime


@torch.no_grad()
def evaluate(env_test, agent):
    """Run one greedy episode on the test environment and return the summed reward."""
    s_t = env_test.reset()
    done_t = False
    a_embeds_t, a_embeds_target_t = env_test.get_extra_embeds(
        agent.action_tokenizer, agent.critic.action_embed, agent.action_embed_target)
    r_sum_t = 0
    while not done_t:
        # Refresh the action embeddings with the current online and target embedders.
        a_embeds_t = env_test.update_embeds(a_embeds_t, agent.critic.action_embed)
        a_embeds_target_t = env_test.update_embeds(a_embeds_target_t, agent.action_embed_target)

        action_t, _, _ = agent.select_action(
            s_t, a_embeds_t["rope"], a_embeds_target_t["rope"], random=False, evaluate=True)
        s_t, _, reward_t, done_t = env_test.step(action_t)
        r_sum_t += reward_t

    return r_sum_t
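

# Hydra configuration helpers: load_config composes the YAML config from
# ./configs and prepare_config fills in fields that depend on other fields.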
def load_config(name, overrides=None):
    with initialize(version_base="1.3", config_path="./configs"):
        cfg = compose(
            config_name=name,
            # Default to command-line overrides when none are passed explicitly.
            overrides=overrides if overrides is not None else sys.argv[1:]
        )

    cfg = prepare_config(cfg)
    return cfg


def prepare_config(cfg):
    """
    Modifies config entries whose values depend on other config parameters.
    """
    if cfg.logger.log_dir is not None:
        # Give each run its own timestamped directory, with TensorBoard logs nested inside.
        dir_name = datetime.now().strftime("%b%d_%H-%M-%S") + cfg.logger.tensorboard.comment
        cfg.logger.log_dir = os.path.join(cfg.logger.log_dir, dir_name)
        cfg.logger.tensorboard.log_dir = os.path.join(cfg.logger.log_dir, 'tb_logs/')

    return cfg
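

# Reproducibility: seed every RNG the run touches. Note that forcing
# deterministic cuDNN kernels can reduce GPU throughput.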
def set_all_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def format_time(seconds):
    """Format a duration in seconds as HH:MM:SS, e.g. format_time(3725) -> '01:02:05'."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
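

# ---------------------------------------------------------------------------
# Script body: load the config, set up logging and seeding, build the agent and
# environments, then run the training loop.
# ---------------------------------------------------------------------------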
cfg: DictConfig = load_config(name="training.yaml")

writer: SummaryWriter = instantiate(cfg.logger.tensorboard)
os.makedirs(cfg.logger.log_dir, exist_ok=True)
config_save_path = os.path.join(cfg.logger.log_dir, "config.yaml")
OmegaConf.save(config=cfg, f=config_save_path, resolve=False)
print(f"[INFO] Training config saved to {config_save_path}")

agent_config: DictConfig = cfg.algo
env_config: DictConfig = cfg.envs
print("Embedder model:", agent_config.model.model_name)

ckpt_last_path = os.path.join(cfg.logger.log_dir, "model_last.pt")
ckpt_best_path = os.path.join(cfg.logger.log_dir, "model_best.pt")
best_eval_reward = -float("inf")

torch.set_default_device(cfg.device)
torch.set_float32_matmul_precision('high')
set_all_seeds(cfg.seed)
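
# The PQN agent (implemented in rl.agents.pqn) owns the critic, the target
# action embedder, and the state/action tokenizers used when building the
# parallel environment below.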
agent = PQN(agent_config)

env: QAEnv = instantiate(env_config.env)
env_test: QAEnv = instantiate(env_config.test_env)
parallel_env = ParallelTextEnv(
    [env] + [env.copy() for _ in range(cfg.envs_parallel - 1)],
    state_tokenizer=agent.state_tokenizer,
    action_tokenizer=agent.action_tokenizer)

# steps_count and eval_interval appear to be counted in optimizer updates, so they
# are scaled by accumulate_grads to obtain loop-iteration counts.
total_steps = cfg.steps_count * cfg.accumulate_grads
eval_interval = cfg.eval_interval * cfg.accumulate_grads
log_interval = 100
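
# Run bookkeeping: a tqdm progress bar plus a plain-text log file that mirrors
# the scalars written to TensorBoard.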
progress_bar = tqdm(range(total_steps), desc="Training")

log_path = os.path.join(cfg.logger.log_dir, "log.txt")
log_file = open(log_path, "w", buffering=1)
print(f"[INFO] Log file saved to {log_path}")

log_file.write(f"Training started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total iterations: {total_steps}\n")
log_file.write("=" * 100 + "\n")

train_start_time = time.time()

states_list, _ = parallel_env.reset()
step = 0
train_rewards = []
last_eval_reward = float("nan")
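
# Main training loop: every iteration collects a rollout from the parallel
# environments and applies one Q-network update; progress is logged every
# log_interval iterations and the agent is evaluated every eval_interval iterations.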

for it in progress_bar:

    agent.train()
    # Collect one batch of transitions; act randomly during the initial exploration phase.
    states_list, rewards, train_batch = parallel_env.rollout(
        cfg.batch_size, states_list, agent, random=(step < 2 * cfg.learning_start))
    step += np.prod(train_batch.reward.shape)
    train_rewards.extend(rewards)

    qf_loss = agent.update(
        train_batch.state,
        train_batch.action,
        train_batch.next_state,
        train_batch.q_values,
        train_batch.reward,
        train_batch.not_done)

    if it % log_interval == 0 and it > 0:

        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_it = total_steps - it - 1
        remaining_time = remaining_it * avg_time_per_it
        total_estimated_time = elapsed_time + remaining_time

        elapsed_time_str = format_time(elapsed_time)
        remaining_time_str = format_time(remaining_time)
        total_time_str = format_time(total_estimated_time)

        progress_pct = (it + 1) / total_steps * 100

        train_r_mean = np.mean(train_rewards)
        writer.add_scalar("train r_sum", train_r_mean, step)
        writer.add_scalar("qf_loss", qf_loss, step)
        writer.add_scalar("training/elapsed_time_hours", elapsed_time / 3600, step)
        writer.add_scalar("training/remaining_time_hours", remaining_time / 3600, step)

        log_file.write(
            f"{it}/{total_steps} ({progress_pct:.1f}%), "
            f"elapsed_time={elapsed_time_str}, "
            f"remaining_time={remaining_time_str}, "
            f"total_estimated={total_time_str}, "
            f"reward={train_r_mean:.3f}, "
            f"eval_reward={last_eval_reward:.3f}, "
            f"qf_loss={float(qf_loss):.3f}, "
            f"step={step}\n"
        )

    if it % eval_interval == 0:

        agent.eval()

        r_eval = []

        eval_start_time = time.time()
        for j in range(cfg.eval_episodes):
            r_eval.append(evaluate(env_test, agent))
            print(f"\reval prog: {len(r_eval)}/{cfg.eval_episodes}", end="")
        eval_elapsed_time = time.time() - eval_start_time

        last_eval_reward = float(np.mean(r_eval))
        writer.add_scalar("eval r_sum", last_eval_reward, step)
        writer.add_scalar("eval/elapsed_time_seconds", eval_elapsed_time, step)

        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_time = (total_steps - it - 1) * avg_time_per_it

        progress_bar.set_postfix({
            'reward': np.mean(train_rewards),
            'eval_reward': last_eval_reward,
            'qf_loss': qf_loss,
            'step': step,
            'elapsed': format_time(elapsed_time),
            'remaining': format_time(remaining_time)
        })
        agent.save(ckpt_last_path)

        # Keep the best checkpoint by mean evaluation reward.
        if last_eval_reward > best_eval_reward:
            best_eval_reward = last_eval_reward
            agent.save(ckpt_best_path)

        # Start a fresh window of training rewards for the next period.
        train_rewards = []
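
# Final bookkeeping: close the text log and report where the run artifacts live.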

total_train_time = time.time() - train_start_time
log_file.write("\n" + "=" * 100 + "\n")
log_file.write(f"Training finished at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total training time: {format_time(total_train_time)}\n")
log_file.write(f"Best eval reward: {best_eval_reward:.3f}\n")
log_file.close()

print("\n[INFO] Training completed!")
print(f"Total training time: {format_time(total_train_time)}")
print(f"Best evaluation reward: {best_eval_reward:.3f}")
print(f"Logs saved to: {cfg.logger.log_dir}")