0xZohar committed
Commit d6a9e66 · verified · 1 Parent(s): 19c667c

Add code/cube3d/training/trainer.py

Files changed (1)
  1. code/cube3d/training/trainer.py +386 -0
code/cube3d/training/trainer.py ADDED
@@ -0,0 +1,386 @@
+ """
+ Simple training loop; boilerplate that could apply to any arbitrary neural network,
+ so nothing in this file really has anything to do with GPT specifically.
+ """
+ import os
+ import time
+ from collections import defaultdict
+ from typing import List, Optional, Tuple
+
+ import torch
+ from accelerate import Accelerator
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from mingpt.utils import CfgNode as CN
+ from cube3d.inference.utils import load_model_weights
+ from cube3d.training.process_single_ldr import logits2ldr, logits2ldrot, logits2ldrp, logits2flatldrp, logits2flatldrpr
+ from cube3d.training.utils import save_model_weights, mask_cross_entropy, normalize_bboxs, top_k_prob_mask
+
+
+ def generate_tokens(
+     engine,
+     prompt,
+     inputs_ids,
+     latent,
+     resolution_base=8.0,
+     disable_postprocess=False,  # currently unused; kept so positional call sites line up
+     top_p=None,
+     bounding_box_xyz=None,
+     strategy=None,
+ ):
+     output_ids = engine.t2t(
+         # [prompt],
+         prompt,
+         # use_kv_cache=True,
+         inputs_ids=inputs_ids,
+         latent=latent,
+         use_kv_cache=False,
+         resolution_base=resolution_base,
+         top_p=top_p,
+         bounding_box_xyz=bounding_box_xyz,
+         strategy=strategy,
+     )
+
+     return output_ids
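+ # Note (inferred from the call site in Trainer.run, not documented here):
+ # despite the name `output_ids`, during training `engine.t2t` is unpacked as a
+ # 5-tuple `(logits, inputs_ids, strategy, mask, cut_idx)`, i.e. it is assumed
+ # to return per-step logits plus bookkeeping when given teacher-forcing
+ # `inputs_ids` with the KV cache disabled.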
+
+
+ class Trainer:
+
+     @staticmethod
+     def get_default_config():
+         C = CN()
+         # device to train on
+         C.device = 'auto'
+         # dataloader parameters
+         C.num_workers = 4
+         # optimizer parameters
+         C.max_iters = None
+         C.batch_size = 4
+         C.learning_rate = 3e-4
+         C.betas = (0.9, 0.95)
+         C.weight_decay = 0.1  # only applied on matmul weights
+         C.grad_norm_clip = 1.0
+         C.save_interval = None
+         return C
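+
+     # Minimal usage sketch (an illustration, not part of the original commit;
+     # it assumes minGPT-style CfgNode fields are overridden by plain attribute
+     # assignment, and `engine`, `dataset`, `acc`, `tb` stand in for objects
+     # built elsewhere):
+     #
+     #   cfg = Trainer.get_default_config()
+     #   cfg.max_iters = 10_000
+     #   cfg.save_interval = 1_000
+     #   trainer = Trainer(cfg, engine, dataset, acc, tb, prompt="a chair")
+     #   trainer.run()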
+
+     def __init__(
+         self,
+         config,
+         engine,
+         train_dataset,
+         accelerator,
+         tb,
+         prompt: str,
+         indices: Optional[List[int]] = None,
+         resolution_base: float = 8.0,
+         disable_postprocessing: bool = False,
+         top_p: Optional[float] = None,
+         bounding_box_xyz: Optional[Tuple[float, ...]] = None,
+         save_gpt_ckpt_path: Optional[str] = None,
+         mode: str = 'train',
+     ):
+         self.config = config
+         self.engine = engine
+         self.model = engine.gpt_model
+         self.optimizer = None
+         self.callbacks = defaultdict(list)
+         self.train_dataset = train_dataset
+         self.accelerator = accelerator
+
+         # Training parameters
+         self.prompt = prompt
+         self.targets = indices
+         self.resolution_base = resolution_base
+         self.disable_postprocessing = disable_postprocessing
+         self.top_p = top_p
+         self.bounding_box_xyz = bounding_box_xyz
+         self.save_gpt_ckpt_path = save_gpt_ckpt_path
+
+         # determine the device we'll train on
+         if config.device == 'auto':
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = config.device
+
+         self.model = self.model.to(self.device)
+         print("running on device", self.device)
+
+         # variables assigned to the trainer class later, for logging etc.
+         self.iter_num = 0
+         self.iter_time = 0.0
+         self.iter_dt = 0.0
+
+         self.tb_writer = tb
+         self.mode = mode
+
+     def add_callback(self, onevent: str, callback):
+         self.callbacks[onevent].append(callback)
+
+     def set_callback(self, onevent: str, callback):
+         self.callbacks[onevent] = [callback]
+
+     def trigger_callbacks(self, onevent: str):
+         for callback in self.callbacks.get(onevent, []):
+             callback(self)
+
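+     # Note: the callback registry is carried over from the minGPT trainer,
+     # but run() below never calls trigger_callbacks, so a registered callback
+     # (e.g. trainer.add_callback('on_batch_end', fn)) is currently inert
+     # unless a trigger is added to the training loop.
+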
+     def run(self):
+         model, config = self.model, self.config
+         # setup the optimizer and LR scheduler
+         # self.optimizer = self.engine.configure_optimizers(config)
+         # (alternative: self.engine.configure_optimizers_lora_linear(config))
+         self.optimizer, self.scheduler = self.engine.configure_optimizers_scratch_linear(config)
+
+         # setup the dataloader
+         train_loader = DataLoader(
+             self.train_dataset,
+             shuffle=(self.mode == 'train'),
+             batch_size=config.batch_size,
+             num_workers=config.num_workers,  # was defined in the config but previously unused
+         )
+
+         model.train()
+
+         model, self.optimizer, train_loader = self.accelerator.prepare(model, self.optimizer, train_loader)
+
+         self.iter_num = 0
+         self.iter_time = time.time()
+         data_iter = iter(train_loader)
+         ema_loss_for_log = 0.0
+         ema_ploss_for_log = 0.0
+         ema_rloss_for_log = 0.0
+         ema_dloss_for_log = 0.0
+         ema_floss_for_log = 0.0
+
+         # loss
+         dat_num = 1217  # 286
+         x_num = 251
+         y_num = 215
+         z_num = 525
+         rot_num = 24
+         shift = 0
+         stride = 5
+         attr_shift = stride - 3  # with dat and rot, +1 for bert
+         bert_shift = 1
+
+         x = x_num
+         xy = x_num + y_num + rot_num
+         xyz = x_num + y_num + z_num + rot_num
+
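+         # Channel-layout sketch (inferred from the loss slicing below; an
+         # assumption, not a documented spec): logits appear packed per brick
+         # with rotation classes in channels [0, rot_num], x-position classes
+         # in [rot_num+1, rot_num+1+x_num], then y up to xy+2 and z up to
+         # xyz+3, each range carrying one extra class for padding.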
+         progress_bar = tqdm(range(0, config.max_iters), desc="Training progress")
+         # while True:
+         for self.iter_num in range(0, config.max_iters + 1):
+             # fetch the next batch (x, y) and re-init iterator if needed
+             try:
+                 batch = next(data_iter)
+             except StopIteration:
+                 data_iter = iter(train_loader)
+                 batch = next(data_iter)
+
+             # batch = [t['latent'].to(self.device) for t in batch]
+             self.prompt, self.targets, self.box = batch['prompt'], batch['target'].to(self.device), batch['bbox']
+             # self.targets = batch['latent'].to(self.device)
+             targets = self.targets.clone()
+             logits, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+                 self.engine,
+                 self.prompt,
+                 targets,
+                 None,
+                 self.resolution_base,
+                 self.disable_postprocessing,
+                 self.top_p,
+                 # self.bounding_box_xyz,
+                 normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  # batch_normalization(self.box)
+                 None,
+             )
+
+             # rotation_loss = F.cross_entropy(
+             #     logits[:,:-1,:rot_num].permute(0, 2, 1),
+             #     inputs_ids[:,shift:,:rot_num].argmax(-1),
+             # )
+
+             # px_loss = mask_cross_entropy(rot_num, x+rot_num, self.box[:, 0], logits, inputs_ids, shift)
+             # py_loss = mask_cross_entropy(x+rot_num, xy, self.box[:, 1], logits, inputs_ids, shift)
+             # pz_loss = mask_cross_entropy(xy, xyz, self.box[:, 2], logits, inputs_ids, shift)
+
+             px_loss = F.cross_entropy(
+                 logits[:, 1 + attr_shift + bert_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -5],
+                 ignore_index=-1,  # +1 for padding
+             )
+             py_loss = F.cross_entropy(
+                 logits[:, 0 + attr_shift + bert_shift:-2:stride, x + rot_num + 2:xy + 3].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -4],
+                 ignore_index=-1,
+             )
+             pz_loss = F.cross_entropy(
+                 logits[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -3],
+                 ignore_index=-1,
+             )
+
+             position_loss = px_loss + py_loss + pz_loss
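+             # Worked example of the slicing above (an inference from the
+             # constants, not documented anywhere): with stride=5, attr_shift=2
+             # and bert_shift=1, the y/x/z logits are read from sequence
+             # positions 3, 4 and 5 (then 8, 9, 10; 13, 14, 15; ...), i.e. one
+             # y/x/z prediction per 5-token brick record, while the targets sit
+             # in the last columns of `inputs_ids` (-5, -4, -3).
+             # `ignore_index=-1` drops padded slots from the loss.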
+
+             # dat_loss = F.cross_entropy(
+             #     logits[:,0:-4:stride,:dat_num+1].permute(0, 2, 1),
+             #     inputs_ids[:,shift:,-6],
+             #     ignore_index=-1,
+             # )
+
+             rotation_loss = F.cross_entropy(
+                 logits[:, 1 + bert_shift:-3:stride, :rot_num + 1].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -7],
+                 ignore_index=-1,
+             )
+
+             # flag_loss = F.cross_entropy(
+             #     logits[:,:-1,xyz+dat_num:xyz+dat_num+2].permute(0, 2, 1),
+             #     inputs_ids[:,shift:,xyz+dat_num:xyz+dat_num+2].argmax(-1),
+             # )
+
+             # flag_loss = F.cross_entropy(
+             #     logits[:,:-1,-2:].permute(0, 2, 1),
+             #     inputs_ids[:,shift:,-2:].argmax(-1),
+             # )
+
+             lambda_position = 1.0
+             lambda_rotation = 1.0
+             lambda_dat = 1.0
+             lambda_flag = 50.0
+
+             self.loss = lambda_position * position_loss
+             # + lambda_rotation * rotation_loss
+             # + lambda_flag * flag_loss
+             # + lambda_dat * dat_loss
+
+             # rotation supervision is only added for strategies 1 and 2
+             if strategy == 1 or strategy == 2:
+                 self.loss += lambda_rotation * rotation_loss
+
+
+             # targets = self.targets.clone()
+             # # mask_topk, mask_inv = top_k_prob_mask(F.softmax(logits[:, 1:-3:stride, :rot_num+1], dim=2), cut_idx, top_percent=0.5)
+             # # targets[:,shift:,-7][mask_topk] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)[mask_topk]
+             # # targets[:,shift:,-7][mask_inv] = self.engine.gpt_model.rot_num+1
+
+             # targets[:,shift:,-7] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
+             # # targets[:,shift:,-4] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)
+             # logits_x, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+             #     self.engine,
+             #     self.prompt,
+             #     targets,
+             #     None,
+             #     self.resolution_base,
+             #     self.disable_postprocessing,
+             #     self.top_p,
+             #     # self.bounding_box_xyz,
+             #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]),  # batch_normalization(self.box)
+             #     0,
+             # )
+
+             # targets = self.targets.clone()
+             # targets[:,shift:,-7] = logits_x[:,1+bert_shift:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
+
+             # mask_x, mask_x_inv = top_k_prob_mask(F.softmax(logits[:,1+attr_shift+bert_shift:-1:stride,rot_num+1:x+rot_num+1+1], dim=2), cut_idx, top_percent=0.5)
+             # mask_y, mask_y_inv = top_k_prob_mask(F.softmax(logits[:,0+attr_shift+bert_shift:-2:stride,x+rot_num+2:xy+3], dim=2), cut_idx, top_percent=0.5)
+             # mask_z, mask_z_inv = top_k_prob_mask(F.softmax(logits[:,2+attr_shift+bert_shift::stride,xy+3:xyz+4], dim=2), cut_idx, top_percent=0.5)
+
+             # targets[:,shift:,-5][mask_x] = logits_x[:,1+attr_shift+bert_shift:-1:stride,rot_num+1:x+rot_num+1+1].permute(0, 2, 1).argmax(dim=1)[mask_x]
+             # targets[:,shift:,-5][mask_x_inv] = self.engine.gpt_model.x_num+1
+             # targets[:,shift:,-4][mask_y] = logits_x[:,0+attr_shift+bert_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)[mask_y]
+             # targets[:,shift:,-4][mask_y_inv] = self.engine.gpt_model.y_num+1
+             # targets[:,shift:,-3][mask_z] = logits_x[:,2+attr_shift+bert_shift::stride,xy+3:xyz+4].permute(0, 2, 1).argmax(dim=1)[mask_z]
+             # targets[:,shift:,-3][mask_z_inv] = self.engine.gpt_model.z_num+1
+             # logits_p, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+             #     self.engine,
+             #     self.prompt,
+             #     targets,
+             #     None,
+             #     self.resolution_base,
+             #     self.disable_postprocessing,
+             #     self.top_p,
+             #     # self.bounding_box_xyz,
+             #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]),  # batch_normalization(self.box)
+             #     None,
+             # )
+
+             # logits_p[:,1+bert_shift:-3:stride,:rot_num+1] = logits[:,1+bert_shift:-3:stride,:rot_num+1]
+             # logits2flatldrpr(logits_p[0].cpu().detach().numpy(), inputs_ids[0].cpu().detach().numpy(), stride, 0, output_file=f"test_rightd2r2p2p_{self.iter_num}_scratch_0p5_bert.ldr")
+
+             # targets = self.targets.clone()
+             # targets[:,shift:,-7] = logits[:,1:-3:stride,:rot_num+1].permute(0, 2, 1).argmax(dim=1)
+             # targets[:,shift:,-4] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3].permute(0, 2, 1).argmax(dim=1)
+             # targets[:,shift:,-5] = logits_x[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1].permute(0, 2, 1).argmax(dim=1)
+             # logits_z, inputs_ids, strategy = generate_tokens(
+             #     self.engine,
+             #     self.prompt,
+             #     targets,
+             #     None,
+             #     self.resolution_base,
+             #     self.disable_postprocessing,
+             #     self.top_p,
+             #     # self.bounding_box_xyz,
+             #     normalize_bboxs(self.box.float(), [x_num-1, y_num-1, z_num-1]),  # batch_normalization(self.box)
+             #     3,
+             # )
+
+             # backprop and update the parameters
+             model.zero_grad(set_to_none=True)
+             # # if self.mode != 'train':
+             # logits_z[:,1:-3:stride,:rot_num+1] = logits[:,1:-3:stride,:rot_num+1]
+             # logits_z[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3] = logits_y[:,0+attr_shift:-2:stride,x+rot_num+2:xy+3]
+             # logits_z[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1] = logits_x[:,1+attr_shift:-1:stride,rot_num+1:x+rot_num+1+1]
+
+             # if self.iter_num > 4:
+             #     break
+
+             self.accelerator.backward(self.loss)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
+             self.optimizer.step()
+             self.scheduler.step()
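+             # Note (assumption about intended usage, not a change): under
+             # mixed precision or gradient accumulation, accelerate recommends
+             # clipping via self.accelerator.clip_grad_norm_(model.parameters(),
+             # config.grad_norm_clip) so gradients are unscaled first; the
+             # direct torch call above is fine for plain fp32 training.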
+
+             with torch.no_grad():
+                 # Progress bar
+                 ema_loss_for_log = 0.4 * self.loss.item() + 0.6 * ema_loss_for_log
+                 ema_ploss_for_log = 0.4 * position_loss.item() + 0.6 * ema_ploss_for_log
+                 ema_rloss_for_log = 0.4 * rotation_loss.item() + 0.6 * ema_rloss_for_log
+                 # ema_dloss_for_log = 0.4 * dat_loss.item() + 0.6 * ema_dloss_for_log
+                 # ema_floss_for_log = 0.4 * flag_loss.item() + 0.6 * ema_floss_for_log
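+                 # These are exponential moving averages with smoothing factor
+                 # 0.4: each iteration's contribution decays by 0.6 per step,
+                 # so it retains ~0.6^10 ≈ 0.6% of its weight ten steps later.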
+                 if self.iter_num % 10 == 0:
+                     progress_bar.set_postfix({
+                         "Loss": f"{ema_loss_for_log:.7f}",
+                         "Position_Loss": f"{ema_ploss_for_log:.7f}",
+                         "Rotation_Loss": f"{ema_rloss_for_log:.7f}",
+                         # "Dat_Loss": f"{ema_dloss_for_log:.7f}",
+                         # "Flag_Loss": f"{ema_floss_for_log:.7f}",
+                     })
+                     progress_bar.update(10)
+
+             # logits2ldr(logits[0].cpu().detach().numpy())
+
+             # guard added: save_interval defaults to None in the config
+             if config.save_interval and self.iter_num % config.save_interval == 0 and self.iter_num != 0:
+                 if self.accelerator.is_main_process:
+                     save_model_weights(
+                         self.engine.gpt_model,
+                         self.save_gpt_ckpt_path,
+                     )
+                     # self.engine.gpt_model.save_pretrained(self.save_gpt_ckpt_path)
+                     # torch.save({
+                     #     "ldr_proj": self.engine.gpt_model.ldr_proj.state_dict(),
+                     #     "ldr_head": self.engine.gpt_model.ldr_head.state_dict(),
+                     #     "rte": self.engine.gpt_model.rte.state_dict(),
+                     #     "dte": self.engine.gpt_model.dte.state_dict(),
+                     #     "xte": self.engine.gpt_model.xte.state_dict(),
+                     #     "yte": self.engine.gpt_model.yte.state_dict(),
+                     #     "zte": self.engine.gpt_model.zte.state_dict(),
+                     # }, f"{self.save_gpt_ckpt_path}/unfrozen_weights.pth")
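+                     # Note (assumption): save_model_weights is handed the raw
+                     # self.engine.gpt_model rather than the accelerate-prepared
+                     # `model`; if one saved from `model` instead, it should
+                     # first be unwrapped with self.accelerator.unwrap_model(model).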
+
+             if self.tb_writer:  # and self.accelerator.is_main_process:
+                 self.tb_writer.add_scalar('train_loss/position_loss', position_loss.item(), self.iter_num)
+                 self.tb_writer.add_scalar('train_loss/rotation_loss', rotation_loss.item(), self.iter_num)
+                 # self.tb_writer.add_scalar('train_loss/dat_loss', dat_loss.item(), self.iter_num)
+                 # self.tb_writer.add_scalar('train_loss/flag_loss', flag_loss.item(), self.iter_num)
+                 self.tb_writer.add_scalar('train_loss/total_loss', self.loss.item(), self.iter_num)
+
+             if self.iter_num == config.max_iters:
+                 progress_bar.close()