TroyHow committed on
Commit
d427d3b
·
verified ·
1 Parent(s): 4b0a7d0
Files changed (1) hide show
  1. train_q_rag_log.py +205 -0
train_q_rag_log.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys

# Make the repository root importable so the local `rl` and `envs` packages
# resolve no matter where the script is launched from.
# NOTE(review): os.path.dirname(os.path.abspath("./")) is the PARENT of the
# current working directory — confirm this is really the repository root.
repo_dir = os.path.dirname(os.path.abspath("./"))
if repo_dir not in sys.path:
    print(f'add repository dir: {repo_dir}')
    sys.path.append(repo_dir)

import random
from datetime import datetime

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from omegaconf import OmegaConf, DictConfig
from hydra import initialize, compose
from hydra.utils import instantiate

# Project-local imports; these rely on the sys.path bootstrap above.
from rl.agents.pqn import PQN
from envs.qa_env import QAEnv
from envs.parallel_env import ParallelTextEnv
22
+
23
+
24
@torch.no_grad()
def evaluate(env_test, agent):
    """Run one greedy (non-random) evaluation episode and return the summed reward.

    Gradients are disabled for the whole rollout via the decorator.

    Args:
        env_test: evaluation environment providing reset/step and embed helpers.
        agent: agent exposing select_action, tokenizers and embedding modules.

    Returns:
        The accumulated per-step reward over the episode.
    """
    state = env_test.reset()
    # One embedding set per network: online critic and its target.
    embeds, embeds_target = env_test.get_extra_embeds(
        agent.action_tokenizer, agent.critic.action_embed, agent.action_embed_target)

    total_reward = 0
    done = False
    while not done:
        # Refresh the action embeddings before every decision step.
        embeds = env_test.update_embeds(embeds, agent.critic.action_embed)
        embeds_target = env_test.update_embeds(embeds_target, agent.action_embed_target)

        action, _, _ = agent.select_action(
            state, embeds["rope"], embeds_target["rope"], random=False, evaluate=True)
        state, _, reward, done = env_test.step(action)
        total_reward += reward

    return total_reward
39
+
40
+
41
def load_config(name, overrides=None):
    """Compose the Hydra config `name` from ./configs and post-process it.

    Args:
        name: config file name under ./configs (e.g. "training.yaml").
        overrides: optional list of Hydra override strings. When None,
            command-line arguments (sys.argv[1:]) are used, which matches
            the previous behavior of this script.

    Returns:
        The prepared DictConfig (see prepare_config).
    """
    with initialize(version_base="1.3", config_path="./configs"):
        cfg = compose(
            config_name=name,
            # Bug fix: `overrides` used to be accepted but silently ignored;
            # honor it when provided, falling back to the CLI arguments.
            overrides=overrides if overrides is not None else sys.argv[1:],
        )
    cfg = prepare_config(cfg)
    return cfg
52
+
53
+
54
def prepare_config(cfg):
    """Derive inter-dependent config values in place and return the config.

    When a log directory is configured, a timestamped run directory
    (strftime "%b%d_%H-%M-%S" plus the tensorboard comment) is nested under
    it, and the tensorboard log path is set to `<run_dir>/tb_logs/`.
    """
    log_root = cfg.logger.log_dir
    if log_root is not None:
        run_name = datetime.now().strftime("%b%d_%H-%M-%S") + cfg.logger.tensorboard.comment
        run_dir = os.path.join(log_root, run_name)
        cfg.logger.log_dir = run_dir
        cfg.logger.tensorboard.log_dir = os.path.join(run_dir, 'tb_logs/')

    # TODO: add version that enumerates all chunks
    # enumerate_facts = (cfg.positional_coding == 'enum')
    # cfg.envs.env.dataset.task_dataset.add_sentence_idx = enumerate_facts
    # cfg.envs.test_env.dataset.task_dataset.add_sentence_idx = enumerate_facts
    return cfg
67
+
68
+
69
def set_all_seeds(seed):
    """Seed every RNG the run touches (python, numpy, torch CPU and CUDA)
    and force deterministic cuDNN kernels for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
75
+
76
+
77
# --- Top-level training script: config, logging, env/agent setup, train loop ---
cfg: DictConfig = load_config(name="training.yaml")
#cfg: DictConfig = load_config(name="training_gte_combined.yaml")

# Tensorboard writer plus a plain-text copy of the resolved config for the run.
writer: SummaryWriter = instantiate(cfg.logger.tensorboard)
os.makedirs(cfg.logger.log_dir, exist_ok=True)
config_save_path = os.path.join(cfg.logger.log_dir, "config.yaml")
OmegaConf.save(config=cfg, f=config_save_path, resolve=False)
print(f"[INFO] Training config saved to {config_save_path}")

agent_config: DictConfig = cfg.algo
env_config: DictConfig = cfg.envs
print("Embedder model:", agent_config.model.model_name)

# path to checkpoints and metric to determine the best model
ckpt_last_path = os.path.join(cfg.logger.log_dir, "model_last.pt")
ckpt_best_path = os.path.join(cfg.logger.log_dir, "model_best.pt")
best_eval_reward = -float("inf")

torch.set_default_device(cfg.device)
torch.set_float32_matmul_precision('high')
set_all_seeds(cfg.seed)

# MAX_TOKEN_LENGTH["state"] = cfg.max_state_length
# MAX_TOKEN_LENGTH["action"] = cfg.max_action_length

agent = PQN(agent_config)

# if bf16:
#     for m in [agent.critic, agent.policy, agent.random_policy,
#               agent.v_net_target, agent.action_embed_target]:
#         m.to(dtype=torch.bfloat16)
#
# if args.fp16:
#     # import apex
#     # apex.amp.register_half_function(torch, 'einsum')
#     from torch.cuda.amp import autocast, GradScaler
#
#     scaler = GradScaler()
#
#     device_type = torch.device(cfg.device).type
#     amp_dtype = torch.bfloat16 if bf16 else torch.float16
#     amp_enabled = bf16 or mixed_precision
#     autocast = torch.cuda.amp.autocast if device_type == 'cuda' else torch.autocast

# Train env plus copies for parallel rollout; a separate env for evaluation.
env: QAEnv = instantiate(env_config.env)
env_test: QAEnv = instantiate(env_config.test_env)
parallel_env = ParallelTextEnv(
    [env] + [env.copy() for _ in range(cfg.envs_parallel - 1)],
    state_tokenizer=agent.state_tokenizer,
    action_tokenizer=agent.action_tokenizer)

total_steps = cfg.steps_count * cfg.accumulate_grads  # 80000
eval_interval = cfg.eval_interval * cfg.accumulate_grads  # 50 * 8 = 400 it
log_interval = cfg.accumulate_grads * 100 // cfg.accumulate_grads  # fixed: log once every 100 it
# NOTE(review): the line above always evaluates to 100 and is immediately
# overwritten below — dead code, candidate for removal.
log_interval = 100  # log reward / qf_loss every 100 it

#assuming we don't need to scale cfg.learning_start with grad_accumulation
progress_bar = tqdm(range(total_steps), desc="Training")

# ---------- log.txt initialization ----------
log_path = os.path.join(cfg.logger.log_dir, "log.txt")
log_file = open(log_path, "w", buffering=1)  # buffering=1: line-buffered, flushed to disk in real time
print(f"[INFO] Log file saved to {log_path}")
# ------------------------------------

states_list, _ = parallel_env.reset()
step = 0  # cumulative env-step counter, used as the x-axis for all scalars
train_rewards = []
last_eval_reward = float("nan")  # displayed as nan until the first eval has run

for it in progress_bar:

    agent.train()
    # Collect a batch of transitions; act randomly during warm-up
    # (while step < 2 * learning_start).
    states_list, rewards, train_batch = parallel_env.rollout(cfg.batch_size, states_list, agent, random=(step < 2 * cfg.learning_start))
    # NOTE(review): assumes train_batch.reward is an array whose total size
    # equals the number of env steps collected — confirm against rollout().
    step += np.prod(train_batch.reward.shape)
    train_rewards.extend(rewards)

    qf_loss = agent.update(
        train_batch.state,
        train_batch.action,
        train_batch.next_state,
        train_batch.q_values,
        train_batch.reward,
        train_batch.not_done)

    # ---- every 100 it: log reward / qf_loss to tensorboard + log.txt ----
    if it % log_interval == 0 and it > 0:
        train_r_mean = np.mean(train_rewards)
        writer.add_scalar("train r_sum", train_r_mean, step)
        writer.add_scalar("qf_loss", qf_loss, step)
        log_file.write(
            f"{it}/{total_steps}, "
            f"reward={train_r_mean:.3f}, "
            f"eval_reward={last_eval_reward:.3f}, "
            f"qf_loss={float(qf_loss):.3f}, "
            f"step={step}\n"
        )

    # ---- every eval_interval it (e.g. 50*8 = 400): run eval, update eval_reward ----
    if it % eval_interval == 0:

        agent.eval()

        r_eval = []
        for j in range(cfg.eval_episodes):
            r_eval.append(evaluate(env_test, agent))
            print(f"\reval prog: {len(r_eval)}/{cfg.eval_episodes}", end="")

        last_eval_reward = float(np.mean(r_eval))
        writer.add_scalar("eval r_sum", last_eval_reward, step)

        progress_bar.set_postfix({
            'reward': np.mean(train_rewards),
            "eval_reward": last_eval_reward,
            'qf_loss': qf_loss,
            'step': step,
        })
        # Always keep the latest checkpoint; keep a separate best-by-eval one.
        agent.save(ckpt_last_path)
        #torch.save(agent.state_dict(), ckpt_last_path)

        if last_eval_reward > best_eval_reward:
            best_eval_reward = last_eval_reward
            agent.save(ckpt_best_path)
            #torch.save(agent.state_dict(), ckpt_best_path)
            #print(f"[INFO] New best model saved with reward {best_eval_reward:.3f}")

        # Reset the running train-reward window after each eval.
        train_rewards = []

log_file.close()