"""
Training loop for fine-tuning the GPT model behind ``engine`` on LDR targets.

The optimizer / data-loading / checkpointing scaffolding is adapted from
minGPT's generic trainer, but the loss computation is specific to the
interleaved rotation / x / y / z token layout returned by ``engine.t2t``
(see ``Trainer._compute_losses``).
"""

from typing import Optional, Tuple, List
import time
import os
from collections import defaultdict

from accelerate import Accelerator
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import CfgNode as CN
from cube3d.training.utils import save_model_weights, mask_cross_entropy, normalize_bboxs, top_k_prob_mask
from cube3d.training.process_single_ldr import logits2ldr, logits2ldrot, logits2ldrp, logits2flatldrp, logits2flatldrpr
from cube3d.inference.utils import load_model_weights
from tqdm import tqdm


# ---------------------------------------------------------------------------
# Token layout constants
# ---------------------------------------------------------------------------
# Number of classes per attribute. Each attribute's logit slice is one column
# wider than its count; the extra class is presumably a padding token
# (TODO confirm against the model head).
_ROT_NUM = 24
_X_NUM = 251
_Y_NUM = 215
_Z_NUM = 525

# Each attribute recurs every _STRIDE positions along the sequence axis.
_STRIDE = 5
_ATTR_SHIFT = _STRIDE - 3
_BERT_SHIFT = 1    # one extra leading position in the logits sequence
_TARGET_SHIFT = 0  # first target position compared against the logits

# Vocabulary (last-axis) column ranges of each attribute inside `logits`.
_ROT_COLS = slice(0, _ROT_NUM + 1)                                             # 0:25
_X_COLS = slice(_ROT_NUM + 1, _ROT_NUM + _X_NUM + 2)                           # 25:277
_Y_COLS = slice(_ROT_NUM + _X_NUM + 2, _ROT_NUM + _X_NUM + _Y_NUM + 3)         # 277:493
_Z_COLS = slice(_ROT_NUM + _X_NUM + _Y_NUM + 3,
                _ROT_NUM + _X_NUM + _Y_NUM + _Z_NUM + 4)                       # 493:1019

# Sequence-axis positions at which each attribute is predicted.
_ROT_ROWS = slice(1 + _BERT_SHIFT, -3, _STRIDE)                # 2:-3:5
_X_ROWS = slice(1 + _ATTR_SHIFT + _BERT_SHIFT, -1, _STRIDE)    # 4:-1:5
_Y_ROWS = slice(_ATTR_SHIFT + _BERT_SHIFT, -2, _STRIDE)        # 3:-2:5
_Z_ROWS = slice(2 + _ATTR_SHIFT + _BERT_SHIFT, None, _STRIDE)  # 5::5

# Index into the last axis of `inputs_ids` holding each attribute's target.
_ROT_TARGET = -7
_X_TARGET = -5
_Y_TARGET = -4
_Z_TARGET = -3

# Loss weights.
_LAMBDA_POSITION = 1.0
_LAMBDA_ROTATION = 1.0

# Scale passed to `normalize_bboxs` (max index of each position axis).
_BBOX_SCALE = [_X_NUM - 1, _Y_NUM - 1, _Z_NUM - 1]


def _attribute_ce(logits, rows, cols, inputs_ids, target_col):
    """Cross-entropy of one attribute: logits[:, rows, cols] vs inputs_ids[..., target_col].

    The class axis is moved to dim 1 (as ``F.cross_entropy`` expects), and
    targets equal to -1 are ignored.
    """
    return F.cross_entropy(
        logits[:, rows, cols].permute(0, 2, 1),
        inputs_ids[:, _TARGET_SHIFT:, target_col],
        ignore_index=-1,
    )


def generate_tokens(
    engine,
    prompt,
    inputs_ids,
    latent,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
    strategy=None,
):
    """Call ``engine.t2t`` with the KV cache disabled and return its raw output.

    Args:
        engine: Object exposing a ``t2t`` method.
        prompt: Forwarded unchanged as the first positional argument of ``t2t``.
        inputs_ids: Forwarded as ``inputs_ids``.
        latent: Forwarded as ``latent`` (``Trainer.run`` passes ``None``).
        resolution_base: Forwarded as ``resolution_base``.
        disable_postprocess: Accepted for API compatibility; NOT forwarded.
        top_p: Forwarded as ``top_p``.
        bounding_box_xyz: Forwarded as ``bounding_box_xyz``.
        strategy: Forwarded as ``strategy``.

    Returns:
        Whatever ``engine.t2t`` returns. ``Trainer.run`` unpacks it as
        ``(logits, inputs_ids, strategy, mask, cut_idx)``.
    """
    return engine.t2t(
        prompt,
        inputs_ids=inputs_ids,
        latent=latent,
        use_kv_cache=False,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
        strategy=strategy,
    )


class Trainer:
    """Train ``engine.gpt_model`` with an Accelerate-prepared model/optimizer/loader."""

    @staticmethod
    def get_default_config():
        """Return a ``CfgNode`` with the default trainer settings."""
        C = CN()
        # device to train on; 'auto' -> 'cuda' if available, else 'cpu'
        C.device = 'auto'
        # dataloader parameters (NOTE: not currently forwarded to the DataLoader by run())
        C.num_workers = 4
        # optimizer parameters; max_iters must be set before run()
        C.max_iters = None
        C.batch_size = 4
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # how it is applied is decided by engine.configure_optimizers_*
        C.grad_norm_clip = 1.0
        # checkpoint every N iterations; None/0 disables checkpointing
        C.save_interval = None
        return C

    def __init__(
        self,
        config,
        engine,
        train_dataset,
        accelerator,
        tb,
        prompt: str,
        indices: Optional[List[int]] = None,
        resolution_base: float = 8.0,
        disable_postprocessing: bool = False,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
        save_gpt_ckpt_path: Optional[str] = None,
        mode: str = 'train',
    ):
        """Store the training setup and move ``engine.gpt_model`` to the target device.

        Args:
            config: Trainer config (see ``get_default_config``).
            engine: Provides ``gpt_model``, ``t2t`` and ``configure_optimizers_scratch_linear``.
            train_dataset: Dataset whose batches contain 'prompt', 'target' and 'bbox'.
            accelerator: ``accelerate.Accelerator`` used to prepare/backprop.
            tb: Optional TensorBoard-style writer (``add_scalar``); falsy disables logging.
            prompt: Initial prompt; overwritten per batch by ``run``.
            indices: Initial targets (stored as ``self.targets``); overwritten per batch.
            resolution_base: Forwarded to ``engine.t2t``.
            disable_postprocessing: Forwarded to ``generate_tokens`` (which ignores it).
            top_p: Forwarded to ``engine.t2t``.
            bounding_box_xyz: Stored only; ``run`` uses each batch's 'bbox' instead.
            save_gpt_ckpt_path: Destination passed to ``save_model_weights``.
            mode: 'train' enables shuffling in ``run``; any other value disables it.
        """
        self.config = config
        self.engine = engine
        self.model = engine.gpt_model
        self.optimizer = None
        self.callbacks = defaultdict(list)
        self.train_dataset = train_dataset
        self.accelerator = accelerator

        # Training parameters
        self.prompt = prompt
        self.targets = indices
        self.resolution_base = resolution_base
        self.disable_postprocessing = disable_postprocessing
        self.top_p = top_p
        self.bounding_box_xyz = bounding_box_xyz
        self.save_gpt_ckpt_path = save_gpt_ckpt_path

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # bookkeeping updated by run()
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
        self.tb_writer = tb
        self.mode = mode

    def add_callback(self, onevent: str, callback):
        """Register ``callback(trainer)`` for ``onevent`` (appended to existing ones)."""
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        """Replace all callbacks for ``onevent`` with ``callback``."""
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        """Invoke every callback registered for ``onevent`` with this trainer.

        Note: ``run`` does not currently trigger any events itself.
        """
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    @staticmethod
    def _next_batch(data_iter, train_loader):
        """Return ``(batch, data_iter)``, restarting ``train_loader`` when exhausted."""
        try:
            return next(data_iter), data_iter
        except StopIteration:
            pass
        data_iter = iter(train_loader)
        return next(data_iter), data_iter

    @staticmethod
    def _compute_losses(logits, inputs_ids, strategy):
        """Compute ``(total_loss, position_loss, rotation_loss)`` for one batch.

        ``position_loss`` is the sum of the x/y/z cross-entropies. The
        rotation term is always computed (for logging) but only added to the
        total when ``strategy`` is 1 or 2.
        """
        px_loss = _attribute_ce(logits, _X_ROWS, _X_COLS, inputs_ids, _X_TARGET)
        py_loss = _attribute_ce(logits, _Y_ROWS, _Y_COLS, inputs_ids, _Y_TARGET)
        pz_loss = _attribute_ce(logits, _Z_ROWS, _Z_COLS, inputs_ids, _Z_TARGET)
        position_loss = px_loss + py_loss + pz_loss
        rotation_loss = _attribute_ce(logits, _ROT_ROWS, _ROT_COLS, inputs_ids, _ROT_TARGET)

        loss = _LAMBDA_POSITION * position_loss
        if strategy == 1 or strategy == 2:
            loss = loss + _LAMBDA_ROTATION * rotation_loss
        return loss, position_loss, rotation_loss

    def run(self):
        """Train for ``config.max_iters + 1`` steps (iterations 0..max_iters inclusive).

        Each step fetches a batch (cycling the loader), runs ``engine.t2t`` via
        ``generate_tokens``, backpropagates through the Accelerator, clips
        gradients, and steps the optimizer and scheduler. Progress, TensorBoard
        scalars and periodic checkpoints are emitted along the way.

        Raises:
            ValueError: if ``config.max_iters`` is not set.
        """
        model, config = self.model, self.config
        if config.max_iters is None:
            raise ValueError("config.max_iters must be set before calling Trainer.run()")

        self.optimizer, self.scheduler = self.engine.configure_optimizers_scratch_linear(config)

        train_loader = DataLoader(
            self.train_dataset,
            shuffle=self.mode == 'train',
            batch_size=config.batch_size,
        )

        model.train()
        model, self.optimizer, train_loader = self.accelerator.prepare(
            model, self.optimizer, train_loader
        )

        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)

        # Exponential moving averages shown in the progress bar.
        ema_loss = 0.0
        ema_ploss = 0.0
        ema_rloss = 0.0

        with tqdm(total=config.max_iters + 1, desc="Training progress") as progress_bar:
            for self.iter_num in range(config.max_iters + 1):
                batch, data_iter = self._next_batch(data_iter, train_loader)
                self.prompt = batch['prompt']
                self.targets = batch['target'].to(self.device)
                self.box = batch['bbox']

                # clone so engine.t2t cannot mutate self.targets in place
                logits, inputs_ids, strategy, _mask, _cut_idx = generate_tokens(
                    self.engine,
                    self.prompt,
                    self.targets.clone(),
                    None,
                    resolution_base=self.resolution_base,
                    disable_postprocess=self.disable_postprocessing,
                    top_p=self.top_p,
                    bounding_box_xyz=normalize_bboxs(self.box.float(), _BBOX_SCALE),
                    strategy=None,
                )

                self.loss, position_loss, rotation_loss = self._compute_losses(
                    logits, inputs_ids, strategy
                )

                # backprop and update the parameters
                model.zero_grad(set_to_none=True)
                self.accelerator.backward(self.loss)
                # Accelerate's variant unscales gradients (mixed precision) before clipping.
                self.accelerator.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                self.optimizer.step()
                self.scheduler.step()

                now = time.time()
                self.iter_dt = now - self.iter_time
                self.iter_time = now

                loss_val = self.loss.item()
                ploss_val = position_loss.item()
                rloss_val = rotation_loss.item()
                ema_loss = 0.4 * loss_val + 0.6 * ema_loss
                ema_ploss = 0.4 * ploss_val + 0.6 * ema_ploss
                ema_rloss = 0.4 * rloss_val + 0.6 * ema_rloss

                progress_bar.update(1)
                if self.iter_num % 10 == 0:
                    progress_bar.set_postfix({
                        "Loss": f"{ema_loss:.7f}",
                        "Positon_Loss": f"{ema_ploss:.7f}",
                        "Rotation_Loss": f"{ema_rloss:.7f}",
                    })

                # Periodic checkpoint; a falsy save_interval disables it.
                if (
                    config.save_interval
                    and self.iter_num != 0
                    and self.iter_num % config.save_interval == 0
                    and self.accelerator.is_main_process
                ):
                    save_model_weights(
                        self.engine.gpt_model,
                        self.save_gpt_ckpt_path,
                    )

                if self.tb_writer:
                    self.tb_writer.add_scalar('train_loss/position_loss', ploss_val, self.iter_num)
                    self.tb_writer.add_scalar('train_loss/rotation_loss', rloss_val, self.iter_num)
                    self.tb_writer.add_scalar('train_loss/total_loss', loss_val, self.iter_num)