0xZohar committed on
Commit 904d07e · verified · 1 Parent(s): aaa20ea

Add code/cube3d/training/bert_infer.py

Files changed (1)
  1. code/cube3d/training/bert_infer.py +442 -0
code/cube3d/training/bert_infer.py ADDED
@@ -0,0 +1,442 @@
+ """
+ Simple training loop; boilerplate that could apply to any arbitrary neural network,
+ so nothing in this file really has anything to do with GPT specifically.
+ """
+ from typing import Optional, Tuple, List
+ import time
+ import os
+ from collections import defaultdict
+
+ from accelerate import Accelerator
+ import torch
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ # from mingpt.utils import CfgNode as CN
+ from omegaconf import DictConfig as CN  # use omegaconf as the config class instead
+ from cube3d.training.utils import save_model_weights, mask_cross_entropy, normalize_bboxs, top_k_prob_mask, visualize_token_probabilities, visualize_max_prob_distribution
+ from cube3d.training.process_single_ldr import logits2ldr, logits2ldrot, logits2ldrp, logits2flatldrp, logits2flatldrpr, logits2botldrpr
+ from cube3d.inference.utils import load_model_weights
+ from tqdm import tqdm
+
+
+ def generate_tokens(
+     engine,
+     prompt,
+     inputs_ids,
+     latent,
+     resolution_base=8.0,
+     disable_postprocess=False,
+     top_p=None,
+     bounding_box_xyz=None,
+     strategy=None,
+     mode=None,
+ ):
+     output_ids = engine.t2t(
+         # [prompt],
+         prompt,
+         # use_kv_cache=True,
+         inputs_ids=inputs_ids,
+         latent=latent,
+         use_kv_cache=False,
+         resolution_base=resolution_base,
+         top_p=top_p,
+         bounding_box_xyz=bounding_box_xyz,
+         strategy=strategy,
+         mode=mode,
+     )
+
+     return output_ids
+
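+ # NOTE (annotation): despite the name `output_ids`, the call sites below
+ # unpack five values (logits, inputs_ids, strategy, mask, cut_idx), so in
+ # this mode engine.t2t is assumed to run the model teacher-forced over
+ # `inputs_ids` and return per-position logits plus bookkeeping rather than
+ # sampled token ids alone.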
+ class Infer:
+
+     @staticmethod
+     def get_default_config():
+         C = CN({})  # DictConfig requires initial content
+         # device to train on
+         C.device = 'auto'
+         # dataloader parameters
+         C.num_workers = 4
+         # optimizer parameters
+         C.max_iters = None
+         C.batch_size = 4
+         C.learning_rate = 3e-4
+         C.betas = [0.9, 0.95]  # list, since omegaconf does not support tuples
+         C.weight_decay = 0.1  # only applied on matmul weights
+         C.grad_norm_clip = 1.0
+         C.save_interval = None
+         return C
+
+     def __init__(
+         self,
+         config,
+         engine,
+         train_dataset,
+         accelerator,
+         tb,
+         prompt: str,
+         indices: Optional[List[int]] = None,
+         resolution_base: float = 8.0,
+         disable_postprocessing: bool = False,
+         top_p: Optional[float] = None,
+         bounding_box_xyz: Optional[Tuple[float, float, float]] = None,
+         save_gpt_ckpt_path: Optional[str] = None,
+         mode: str = 'train',
+     ):
+         self.config = config
+         self.engine = engine
+         self.model = engine.gpt_model
+         self.optimizer = None
+         self.callbacks = defaultdict(list)
+         self.train_dataset = train_dataset
+         self.accelerator = accelerator
+
+         # Training parameters
+         self.prompt = prompt
+         self.targets = indices
+         self.resolution_base = resolution_base
+         self.disable_postprocessing = disable_postprocessing
+         self.top_p = top_p
+         self.bounding_box_xyz = bounding_box_xyz
+         self.save_gpt_ckpt_path = save_gpt_ckpt_path
+
+         # determine the device we'll train on
+         if config.device == 'auto':
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = config.device
+
+         self.model = self.model.to(self.device)
+         print("running on device", self.device)
+
+         # variables assigned to the trainer class later, for logging etc.
+         self.iter_num = 0
+         self.iter_time = 0.0
+         self.iter_dt = 0.0
+
+         self.tb_writer = tb
+         self.mode = mode
+
+         self.probs = []
+
+     def add_callback(self, onevent: str, callback):
+         self.callbacks[onevent].append(callback)
+
+     def set_callback(self, onevent: str, callback):
+         self.callbacks[onevent] = [callback]
+
+     def trigger_callbacks(self, onevent: str):
+         for callback in self.callbacks.get(onevent, []):
+             callback(self)
+
+     def run(self):
+         model, config = self.model, self.config
+         # setup the optimizer
+         # self.optimizer = self.engine.configure_optimizers(config)
+         self.optimizer, self.scheduler = self.engine.configure_optimizers_scratch_linear(config)  # or self.engine.configure_optimizers_lora_linear(config)
+
+         # setup the dataloader
+         train_loader = DataLoader(
+             self.train_dataset,
+             shuffle=(self.mode == 'train'),
+             batch_size=config.batch_size,
+         )
+
+         model.train()
+
+         model, self.optimizer, train_loader = self.accelerator.prepare(model, self.optimizer, train_loader)
+
+         self.iter_num = 0
+         self.iter_time = time.time()
+         data_iter = iter(train_loader)
+         ema_loss_for_log = 0.0
+         ema_ploss_for_log = 0.0
+         ema_rloss_for_log = 0.0
+         ema_dloss_for_log = 0.0
+         ema_floss_for_log = 0.0
+
+         # loss / token-layout constants
+         dat_num = 1217  # 301
+         x_num = 251  # 213
+         y_num = 215  # 73
+         z_num = 525  # 411
+         rot_num = 24
+         shift = 0
+         stride = 5
+         attr_shift = stride - 3  # with dat and rot; +1 for bert
+         bert_shift = 1
+
+         x = x_num
+         xy = x_num + y_num + rot_num
+         xyz = x_num + y_num + z_num + rot_num
+
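+         # NOTE (annotation, inferred from the loss slices below): the logit
+         # vector is one shared vocabulary partitioned per attribute, roughly
+         #   [0 : rot_num+1]            rotation bins (24 + 1 extra slot),
+         #   [rot_num+1 : x+rot_num+2]  x-position bins,
+         #   then y bins up to xy+3 and z bins up to xyz+4;
+         # each brick occupies `stride` = 5 consecutive sequence tokens, and
+         # `bert_shift`/`attr_shift` pick each attribute's offset inside that
+         # group.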
+         progress_bar = tqdm(range(0, config.max_iters), desc="Training progress")
+         # while True:
+         for self.iter_num in range(0, config.max_iters + 1):
+             # fetch the next batch (x, y) and re-init iterator if needed
+             try:
+                 batch = next(data_iter)
+             except StopIteration:
+                 data_iter = iter(train_loader)
+                 batch = next(data_iter)
+
+             # batch = [t['latent'].to(self.device) for t in batch]
+             self.prompt, self.targets, self.box = batch['prompt'], batch['target'].to(self.device), batch['bbox']
+             # self.targets = batch['latent'].to(self.device)
+             targets = self.targets.clone()
+             # import ipdb; ipdb.set_trace()
+             logits, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+                 self.engine,
+                 self.prompt,
+                 targets,
+                 None,
+                 self.resolution_base,
+                 self.disable_postprocessing,
+                 self.top_p,
+                 # self.bounding_box_xyz,
+                 normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  # batch_normalization(self.box)
+                 1,
+                 self.mode,
+             )
+
+             # rotation_loss = F.cross_entropy(
+             #     logits[:, :-1, :rot_num].permute(0, 2, 1),
+             #     inputs_ids[:, shift:, :rot_num].argmax(-1),
+             # )
+
+             # px_loss = mask_cross_entropy(rot_num, x + rot_num, self.box[:, 0], logits, inputs_ids, shift)
+             # py_loss = mask_cross_entropy(x + rot_num, xy, self.box[:, 1], logits, inputs_ids, shift)
+             # pz_loss = mask_cross_entropy(xy, xyz, self.box[:, 2], logits, inputs_ids, shift)
+
+             px_loss = F.cross_entropy(
+                 logits[:, 1 + attr_shift + bert_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -5],
+                 ignore_index=-1,  # +1 for padding
+             )
+             py_loss = F.cross_entropy(
+                 logits[:, 0 + attr_shift + bert_shift:-2:stride, x + rot_num + 2:xy + 3].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -4],
+                 ignore_index=-1,
+             )
+             pz_loss = F.cross_entropy(
+                 logits[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -3],
+                 ignore_index=-1,
+             )
+
+             position_loss = px_loss + py_loss + pz_loss
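+             # NOTE (annotation): each attribute is supervised only at its own
+             # offset within every 5-token brick group (the start:stop:stride
+             # slices above), and ignore_index=-1 drops padded targets, so the
+             # three terms are independent per-attribute cross-entropies over
+             # the same sequence.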
+
+             # dat_loss = F.cross_entropy(
+             #     logits[:, 0:-4:stride, :dat_num + 1].permute(0, 2, 1),
+             #     inputs_ids[:, shift:, -6],
+             #     ignore_index=-1,
+             # )
+
+             rotation_loss = F.cross_entropy(
+                 logits[:, 1 + bert_shift:-3:stride, :rot_num + 1].permute(0, 2, 1),
+                 inputs_ids[:, shift:, -7],
+                 ignore_index=-1,
+             )
+
+             # flag_loss = F.cross_entropy(
+             #     logits[:, :-1, xyz + dat_num:xyz + dat_num + 2].permute(0, 2, 1),
+             #     inputs_ids[:, shift:, xyz + dat_num:xyz + dat_num + 2].argmax(-1),
+             # )
+
+             # flag_loss = F.cross_entropy(
+             #     logits[:, :-1, -2:].permute(0, 2, 1),
+             #     inputs_ids[:, shift:, -2:].argmax(-1),
+             # )
+
+             lambda_position = 1.0
+             lambda_rotation = 1.0
+             lambda_dat = 1.0
+             lambda_flag = 50.0
+
+             self.loss = lambda_position * position_loss
+             # + lambda_rotation * rotation_loss
+             # + lambda_flag * flag_loss
+             # + lambda_dat * dat_loss
+
+             if strategy == 1 or strategy == 2:
+                 self.loss += lambda_rotation * rotation_loss
+
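+             # NOTE (assumption): `strategy` appears to select which attribute
+             # group the engine decodes in a given pass (1 was requested for the
+             # rotation pass above, 0 for the position pass below), so the
+             # rotation term is only added when that pass actually produced
+             # rotation logits.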
+             targets = self.targets.clone()
+             # mask_topk, mask_inv = top_k_prob_mask(F.softmax(logits[:, 1:-3:stride, :rot_num + 1], dim=2), cut_idx, top_percent=0.5)
+             # targets[:, shift:, -7][mask_topk] = logits[:, 1:-3:stride, :rot_num + 1].permute(0, 2, 1).argmax(dim=1)[mask_topk]
+             # targets[:, shift:, -7][mask_inv] = self.engine.gpt_model.rot_num + 1
+
+             targets[:, shift:, -7] = logits[:, 1:-3:stride, :rot_num + 1].permute(0, 2, 1).argmax(dim=1)
+             # targets[:, shift:, -4] = logits_y[:, 0 + attr_shift:-2:stride, x + rot_num + 2:xy + 3].permute(0, 2, 1).argmax(dim=1)
+             # fig = visualize_token_probabilities(
+             #     probs=F.softmax(logits[:, 1:-3:stride, :rot_num + 1], dim=2),
+             #     cut_idx=cut_idx,
+             #     sample_idx=0,
+             #     tokens_per_page=10,
+             #     save_dir='token_probability_pages',  # images are written to this folder
+             # )
+
+             logits_x, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+                 self.engine,
+                 self.prompt,
+                 targets,
+                 None,
+                 self.resolution_base,
+                 self.disable_postprocessing,
+                 self.top_p,
+                 # self.bounding_box_xyz,
+                 normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  # batch_normalization(self.box)
+                 0,
+                 self.mode,
+             )
+
+             logits_x[:, 1 + bert_shift:-3:stride, :rot_num + 1] = logits[:, 1 + bert_shift:-3:stride, :rot_num + 1]
+             output_dir = "train_d2r2p_scratch_whole_perm10re"
+             os.makedirs(output_dir, exist_ok=True)
+             logits2botldrpr(logits_x[0].cpu().detach().numpy(), inputs_ids[0].cpu().detach().numpy(), stride, 0, output_file=os.path.join(output_dir, f"test_d2r2p_{self.iter_num}_scratch_0p85_bert.ldr"))
+
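+             # NOTE (annotation): the block above is a two-pass decode. Pass 1
+             # (strategy=1) predicts rotations; their argmaxes are written back
+             # into `targets` so pass 2 (strategy=0) predicts x/y/z positions
+             # conditioned on them. The rotation logits are then spliced into
+             # `logits_x`, and the combined prediction is exported as an LDraw
+             # (.ldr) model for visual inspection.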
+
+             # targets = self.targets.clone()
+             # targets[:, shift:, -7] = logits[:, 1 + bert_shift:-3:stride, :rot_num + 1].permute(0, 2, 1).argmax(dim=1)
+
+             # mask_x, mask_x_inv = top_k_prob_mask(F.softmax(logits_x[:, 1 + attr_shift + bert_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1], dim=2), cut_idx, top_percent=0.3)
+             # mask_y, mask_y_inv = top_k_prob_mask(F.softmax(logits_x[:, 0 + attr_shift + bert_shift:-2:stride, x + rot_num + 2:xy + 3], dim=2), cut_idx, top_percent=0.3)
+             # mask_z, mask_z_inv = top_k_prob_mask(F.softmax(logits_x[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4], dim=2), cut_idx, top_percent=0.3)
+
+             # targets[:, shift:, -5][mask_x] = logits_x[:, 1 + attr_shift + bert_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1].permute(0, 2, 1).argmax(dim=1)[mask_x]
+             # targets[:, shift:, -5][mask_x_inv] = self.engine.gpt_model.x_num + 1
+             # targets[:, shift:, -4][mask_y] = logits_x[:, 0 + attr_shift + bert_shift:-2:stride, x + rot_num + 2:xy + 3].permute(0, 2, 1).argmax(dim=1)[mask_y]
+             # targets[:, shift:, -4][mask_y_inv] = self.engine.gpt_model.y_num + 1
+             # targets[:, shift:, -3][mask_z] = logits_x[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4].permute(0, 2, 1).argmax(dim=1)[mask_z]
+             # targets[:, shift:, -3][mask_z_inv] = self.engine.gpt_model.z_num + 1
+
+             # fig = visualize_token_probabilities(
+             #     probs=F.softmax(logits_x[:, 1 + attr_shift + bert_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1], dim=2),
+             #     cut_idx=cut_idx,
+             #     sample_idx=0,
+             #     tokens_per_page=10,
+             #     save_dir='token_probability_pages_x',  # images are written to this folder
+             # )
+
+             # fig = visualize_token_probabilities(
+             #     probs=F.softmax(logits_x[:, 0 + attr_shift + bert_shift:-2:stride, x + rot_num + 2:xy + 3], dim=2),
+             #     cut_idx=cut_idx,
+             #     sample_idx=0,
+             #     tokens_per_page=10,
+             #     save_dir='token_probability_pages_y',  # images are written to this folder
+             # )
+
+             # fig = visualize_token_probabilities(
+             #     probs=F.softmax(logits_x[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4], dim=2),
+             #     cut_idx=cut_idx,
+             #     sample_idx=0,
+             #     tokens_per_page=10,
+             #     save_dir='token_probability_pages_z',  # images are written to this folder
+             # )
+             # fig = visualize_max_prob_distribution(
+             #     probs=F.softmax(logits_x[:, 2 + attr_shift + bert_shift::stride, xy + 3:xyz + 4], dim=2),
+             #     cut_idx=cut_idx,
+             #     sample_idx=0,
+             #     bins=20,  # split [0, 1] into 20 bins (0.05 each)
+             #     figsize=(12, 6),
+             # )
+             # fig.savefig(f'token_max_probability_distribution_iter{self.iter_num}.png')
+             # current_probs = torch.softmax(logits[:, 1 + bert_shift:-3:stride, :rot_num + 1], dim=2)
+             # self.probs.append(current_probs[:, :min(int(cut_idx[0]), current_probs.shape[1]), :])
+             # logits_p, inputs_ids, strategy, mask, cut_idx = generate_tokens(
+             #     self.engine,
+             #     self.prompt,
+             #     targets,
+             #     None,
+             #     self.resolution_base,
+             #     self.disable_postprocessing,
+             #     self.top_p,
+             #     # self.bounding_box_xyz,
+             #     normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  # batch_normalization(self.box)
+             #     0,
+             # )
+
+             # logits_p[:, 1 + bert_shift:-3:stride, :rot_num + 1] = logits[:, 1 + bert_shift:-3:stride, :rot_num + 1]
+             # logits2botldrpr(logits_p[0].cpu().detach().numpy(), inputs_ids[0].cpu().detach().numpy(), stride, 0, output_file=f"gt_d2r2p2p_scratch_0p85_bot/test_d2r2p2p_{self.iter_num}_scratch_0p85_bert.ldr")
+
+             # targets = self.targets.clone()
+             # targets[:, shift:, -7] = logits[:, 1:-3:stride, :rot_num + 1].permute(0, 2, 1).argmax(dim=1)
+             # targets[:, shift:, -4] = logits_y[:, 0 + attr_shift:-2:stride, x + rot_num + 2:xy + 3].permute(0, 2, 1).argmax(dim=1)
+             # targets[:, shift:, -5] = logits_x[:, 1 + attr_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1].permute(0, 2, 1).argmax(dim=1)
+             # logits_z, inputs_ids, strategy = generate_tokens(
+             #     self.engine,
+             #     self.prompt,
+             #     targets,
+             #     None,
+             #     self.resolution_base,
+             #     self.disable_postprocessing,
+             #     self.top_p,
+             #     # self.bounding_box_xyz,
+             #     normalize_bboxs(self.box.float(), [x_num - 1, y_num - 1, z_num - 1]),  # batch_normalization(self.box)
+             #     3,
+             # )
+
+             # backprop and update the parameters
+             model.zero_grad(set_to_none=True)
+             # if self.mode != 'train':
+             # logits_z[:, 1:-3:stride, :rot_num + 1] = logits[:, 1:-3:stride, :rot_num + 1]
+             # logits_z[:, 0 + attr_shift:-2:stride, x + rot_num + 2:xy + 3] = logits_y[:, 0 + attr_shift:-2:stride, x + rot_num + 2:xy + 3]
+             # logits_z[:, 1 + attr_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1] = logits_x[:, 1 + attr_shift:-1:stride, rot_num + 1:x + rot_num + 1 + 1]
+
+             if self.iter_num > 4:
+                 break
+
+             # the update itself is disabled here; this loop only runs inference
+             # self.accelerator.backward(self.loss)
+             # torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
+             # self.optimizer.step()
+             # self.scheduler.step()
+
+             with torch.no_grad():
+                 # progress bar: exponential moving averages for smoother logging
+                 ema_loss_for_log = 0.4 * self.loss.item() + 0.6 * ema_loss_for_log
+                 ema_ploss_for_log = 0.4 * position_loss.item() + 0.6 * ema_ploss_for_log
+                 ema_rloss_for_log = 0.4 * rotation_loss.item() + 0.6 * ema_rloss_for_log
+                 # ema_dloss_for_log = 0.4 * dat_loss.item() + 0.6 * ema_dloss_for_log
+                 # ema_floss_for_log = 0.4 * flag_loss.item() + 0.6 * ema_floss_for_log
+                 if self.iter_num % 10 == 0:
+                     progress_bar.set_postfix({
+                         "Loss": f"{ema_loss_for_log:.7f}",
+                         "Position_Loss": f"{ema_ploss_for_log:.7f}",
+                         "Rotation_Loss": f"{ema_rloss_for_log:.7f}",
+                         # "Dat_Loss": f"{ema_dloss_for_log:.7f}",
+                         # "Flag_Loss": f"{ema_floss_for_log:.7f}",
+                     })
+                     progress_bar.update(10)
+
+             # logits2ldr(logits[0].cpu().detach().numpy())
+
+             if self.iter_num % config.save_interval == 0 and self.iter_num != 0:
+                 if self.accelerator.is_main_process:
+                     save_model_weights(
+                         self.engine.gpt_model,
+                         self.save_gpt_ckpt_path,
+                     )
+                     # self.engine.gpt_model.save_pretrained(self.save_gpt_ckpt_path)
+                     # torch.save({
+                     #     "ldr_proj": self.engine.gpt_model.ldr_proj.state_dict(),
+                     #     "ldr_head": self.engine.gpt_model.ldr_head.state_dict(),
+                     #     "rte": self.engine.gpt_model.rte.state_dict(),
+                     #     "dte": self.engine.gpt_model.dte.state_dict(),
+                     #     "xte": self.engine.gpt_model.xte.state_dict(),
+                     #     "yte": self.engine.gpt_model.yte.state_dict(),
+                     #     "zte": self.engine.gpt_model.zte.state_dict(),
+                     # }, f"{self.save_gpt_ckpt_path}/unfrozen_weights.pth")
+
+             if self.tb_writer:  # and self.accelerator.is_main_process:
+                 self.tb_writer.add_scalar('train_loss/position_loss', position_loss.item(), self.iter_num)
+                 self.tb_writer.add_scalar('train_loss/rotation_loss', rotation_loss.item(), self.iter_num)
+                 # self.tb_writer.add_scalar('train_loss/dat_loss', dat_loss.item(), self.iter_num)
+                 # self.tb_writer.add_scalar('train_loss/flag_loss', flag_loss.item(), self.iter_num)
+                 self.tb_writer.add_scalar('train_loss/total_loss', self.loss.item(), self.iter_num)
+
+             if self.iter_num == config.max_iters:
+                 progress_bar.close()
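+
+
+ # Usage sketch (assumed wiring; `engine`, `train_dataset`, and the prompt come
+ # from the surrounding cube3d training setup, and `tb` may be a
+ # torch.utils.tensorboard SummaryWriter or None):
+ #
+ #   from accelerate import Accelerator
+ #
+ #   cfg = Infer.get_default_config()
+ #   cfg.max_iters = 100
+ #   cfg.save_interval = 50
+ #   infer = Infer(cfg, engine, train_dataset, Accelerator(), None, prompt="a chair")
+ #   infer.run()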