TroyHow committed on
Commit
c29f6a5
·
verified ·
1 Parent(s): d427d3b

log with time record

Browse files
Files changed (1) hide show
  1. train_q_rag_logt.py +265 -0
train_q_rag_logt.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time # 新增:导入时间模块
4
+
5
+ repo_dir = os.path.dirname(os.path.abspath("./"))
6
+ if repo_dir not in sys.path:
7
+ print(f'add repository dir: {repo_dir}')
8
+ sys.path.append(repo_dir)
9
+
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ import torch
12
+ import sys
13
+ from rl.agents.pqn import PQN
14
+ import numpy as np
15
+ from envs.qa_env import QAEnv
16
+ from envs.parallel_env import ParallelTextEnv
17
+ from tqdm import tqdm
18
+ from omegaconf import OmegaConf, DictConfig
19
+ from hydra.utils import instantiate
20
+ from hydra import initialize, compose
21
+ import random
22
+ from datetime import datetime
23
+
24
+
25
@torch.no_grad()
def evaluate(env_test, agent):
    """Run one greedy evaluation episode on ``env_test``.

    Returns the sum of rewards collected until the episode terminates.
    Gradients are disabled for the whole rollout.
    """
    state = env_test.reset()
    embeds, target_embeds = env_test.get_extra_embeds(
        agent.action_tokenizer, agent.critic.action_embed, agent.action_embed_target)
    total_reward = 0
    done = False
    while not done:
        # Refresh cached action embeddings from the current networks before acting.
        embeds = env_test.update_embeds(embeds, agent.critic.action_embed)
        target_embeds = env_test.update_embeds(target_embeds, agent.action_embed_target)

        # Greedy selection: no exploration during evaluation.
        action, _, _ = agent.select_action(
            state, embeds["rope"], target_embeds["rope"], random=False, evaluate=True)
        state, _, reward, done = env_test.step(action)
        total_reward += reward

    return total_reward
40
+
41
+
42
def load_config(name, overrides=None):
    """Compose a Hydra config from ``./configs`` and derive dependent fields.

    Parameters
    ----------
    name : str
        Config file name (e.g. ``"training.yaml"``).
    overrides : list[str] | None
        Hydra override strings. Defaults to the process command line
        (``sys.argv[1:]``), preserving the original script behaviour.

    Returns
    -------
    DictConfig
        The composed config after :func:`prepare_config`.
    """
    with initialize(version_base="1.3", config_path="./configs"):
        cfg = compose(
            config_name=name,
            # Bug fix: ``overrides`` used to be ignored in favour of sys.argv.
            overrides=overrides if overrides is not None else sys.argv[1:],
        )
        cfg = prepare_config(cfg)
    return cfg
53
+
54
+
55
def prepare_config(cfg):
    """Derive config fields that depend on each other; returns the same config.

    When ``cfg.logger.log_dir`` is set, appends a timestamped run directory
    (timestamp + tensorboard comment) and places TensorBoard logs underneath.
    """
    log_root = cfg.logger.log_dir
    if log_root is None:
        # Logging disabled: nothing to derive.
        return cfg

    run_name = datetime.now().strftime("%b%d_%H-%M-%S") + cfg.logger.tensorboard.comment
    run_dir = os.path.join(log_root, run_name)
    cfg.logger.log_dir = run_dir
    cfg.logger.tensorboard.log_dir = os.path.join(run_dir, 'tb_logs/')
    return cfg
68
+
69
+
70
def set_all_seeds(seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Trade kernel speed for reproducible cuDNN results.
    torch.backends.cudnn.deterministic = True
76
+
77
+
78
def format_time(seconds):
    """Render a non-negative duration in seconds as an ``HH:MM:SS`` string."""
    minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
84
+
85
+
86
cfg: DictConfig = load_config(name="training.yaml")

# ---- logger / config persistence ----
writer: SummaryWriter = instantiate(cfg.logger.tensorboard)
os.makedirs(cfg.logger.log_dir, exist_ok=True)
config_save_path = os.path.join(cfg.logger.log_dir, "config.yaml")
OmegaConf.save(config=cfg, f=config_save_path, resolve=False)
print(f"[INFO] Training config saved to {config_save_path}")

agent_config: DictConfig = cfg.algo
env_config: DictConfig = cfg.envs
print("Embedder model:", agent_config.model.model_name)

# Checkpoint paths; "best" is selected by mean eval reward.
ckpt_last_path = os.path.join(cfg.logger.log_dir, "model_last.pt")
ckpt_best_path = os.path.join(cfg.logger.log_dir, "model_best.pt")
best_eval_reward = -float("inf")

torch.set_default_device(cfg.device)
torch.set_float32_matmul_precision('high')
set_all_seeds(cfg.seed)

agent = PQN(agent_config)

# One template train env plus copies running in parallel; separate test env.
env: QAEnv = instantiate(env_config.env)
env_test: QAEnv = instantiate(env_config.test_env)
parallel_env = ParallelTextEnv(
    [env] + [env.copy() for _ in range(cfg.envs_parallel - 1)],
    state_tokenizer=agent.state_tokenizer,
    action_tokenizer=agent.action_tokenizer)

total_steps = cfg.steps_count * cfg.accumulate_grads
eval_interval = cfg.eval_interval * cfg.accumulate_grads
# Log reward / qf_loss every 100 iterations.
# (Removed the dead `cfg.accumulate_grads * 100 // cfg.accumulate_grads`
# assignment: it always evaluated to 100 and was immediately overwritten.)
log_interval = 100

# Assuming we don't need to scale cfg.learning_start with grad accumulation.
progress_bar = tqdm(range(total_steps), desc="Training")

# ---------- log.txt setup ----------
log_path = os.path.join(cfg.logger.log_dir, "log.txt")
log_file = open(log_path, "w", buffering=1)  # buffering=1: line-buffered, flushed to disk per line
print(f"[INFO] Log file saved to {log_path}")
# Write the log header.
log_file.write(f"Training started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total iterations: {total_steps}\n")
log_file.write("="*100 + "\n")
# ------------------------------------

# Wall-clock reference for elapsed/remaining-time estimates.
train_start_time = time.time()

states_list, _ = parallel_env.reset()
step = 0
train_rewards = []
last_eval_reward = float("nan")  # displayed as nan until the first eval has run
162
+
163
for it in progress_bar:

    agent.train()
    # Collect a batch of transitions; act randomly until 2 * learning_start
    # environment steps have been gathered, then follow the agent's policy.
    states_list, rewards, train_batch = parallel_env.rollout(cfg.batch_size, states_list, agent, random=(step < 2 * cfg.learning_start))
    # `step` counts environment transitions, not optimizer iterations.
    step += np.prod(train_batch.reward.shape)
    train_rewards.extend(rewards)

    qf_loss = agent.update(
        train_batch.state,
        train_batch.action,
        train_batch.next_state,
        train_batch.q_values,
        train_batch.reward,
        train_batch.not_done)

    # ---- every `log_interval` iterations: log reward / qf_loss to tensorboard + log.txt ----
    if it % log_interval == 0 and it > 0:
        # Timing bookkeeping: linear ETA extrapolation from the average
        # time per iteration so far.
        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_it = total_steps - it - 1
        remaining_time = remaining_it * avg_time_per_it
        total_estimated_time = elapsed_time + remaining_time

        # Human-readable HH:MM:SS strings.
        elapsed_time_str = format_time(elapsed_time)
        remaining_time_str = format_time(remaining_time)
        total_time_str = format_time(total_estimated_time)

        # Progress as a percentage of total iterations.
        progress_pct = (it + 1) / total_steps * 100

        # Mean reward since the last eval (train_rewards is reset there).
        train_r_mean = np.mean(train_rewards)
        writer.add_scalar("train r_sum", train_r_mean, step)
        writer.add_scalar("qf_loss", qf_loss, step)
        writer.add_scalar("training/elapsed_time_hours", elapsed_time/3600, step)
        writer.add_scalar("training/remaining_time_hours", remaining_time/3600, step)

        log_file.write(
            f"{it}/{total_steps} ({progress_pct:.1f}%), "
            f"elapsed_time={elapsed_time_str}, "
            f"remaining_time={remaining_time_str}, "
            f"total_estimated={total_time_str}, "
            f"reward={train_r_mean:.3f}, "
            f"eval_reward={last_eval_reward:.3f}, "
            f"qf_loss={float(qf_loss):.3f}, "
            f"step={step}\n"
        )

    # ---- every `eval_interval` iterations (e.g. 50 x 8 = 400): run eval,
    # update eval_reward, refresh the progress bar, and checkpoint ----
    if it % eval_interval == 0:

        agent.eval()

        r_eval = []
        # Time the evaluation pass separately from training.
        eval_start_time = time.time()
        for j in range(cfg.eval_episodes):
            r_eval.append(evaluate(env_test, agent))
            print(f"\reval prog: {len(r_eval)}/{cfg.eval_episodes}", end="")
        eval_elapsed_time = time.time() - eval_start_time

        last_eval_reward = float(np.mean(r_eval))
        writer.add_scalar("eval r_sum", last_eval_reward, step)
        writer.add_scalar("eval/elapsed_time_seconds", eval_elapsed_time, step)

        # Update the progress bar postfix with reward and timing info.
        elapsed_time = time.time() - train_start_time
        avg_time_per_it = elapsed_time / (it + 1)
        remaining_time = (total_steps - it - 1) * avg_time_per_it

        progress_bar.set_postfix({
            'reward': np.mean(train_rewards),
            "eval_reward": last_eval_reward,
            'qf_loss': qf_loss,
            'step': step,
            'elapsed': format_time(elapsed_time),
            'remaining': format_time(remaining_time)
        })
        # Always keep the most recent checkpoint.
        agent.save(ckpt_last_path)

        # Additionally keep the checkpoint with the best mean eval reward.
        if last_eval_reward > best_eval_reward:
            best_eval_reward = last_eval_reward
            agent.save(ckpt_best_path)

        # Reset the window used for the mean training reward.
        train_rewards = []
252
+
253
# Training finished: record final timing info and close the log file.
total_train_time = time.time() - train_start_time
log_file.write("\n" + "="*100 + "\n")
log_file.write(f"Training finished at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Total training time: {format_time(total_train_time)}\n")
log_file.write(f"Best eval reward: {best_eval_reward:.3f}\n")
log_file.close()

# Print the final summary to stdout as well.
print(f"\n[INFO] Training completed!")
print(f"Total training time: {format_time(total_train_time)}")
print(f"Best evaluation reward: {best_eval_reward:.3f}")
print(f"Logs saved to: {cfg.logger.log_dir}")