"""
Loop over a dataset with a minGPT-style trainer skeleton.

The parameter-update step in ``Infer.run`` is currently disabled, so the loop
only runs forward passes, logs the losses, and exports the decoded predictions
as ``.ldr`` files.
"""

from typing import Optional, Tuple, List
import time
import os
from collections import defaultdict

from accelerate import Accelerator
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader

# from mingpt.utils import CfgNode as CN
from omegaconf import DictConfig as CN  # omegaconf replaces the minGPT CfgNode
from cube3d.training.utils import save_model_weights, mask_cross_entropy, normalize_bboxs, top_k_prob_mask, visualize_token_probabilities, visualize_max_prob_distribution
from cube3d.training.process_single_ldr import logits2ldr, logits2ldrot, logits2ldrp, logits2flatldrp, logits2flatldrpr, logits2botldrpr
from cube3d.inference.utils import load_model_weights
from tqdm import tqdm


def generate_tokens(
    engine,
    prompt,
    inputs_ids,
    latent,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
    strategy=None,
    mode=None,
):
    """Run a single ``engine.t2t`` pass and return its raw output.

    The KV cache is always disabled. ``disable_postprocess`` is accepted for
    interface compatibility but is NOT forwarded to ``engine.t2t``.

    Returns:
        Whatever ``engine.t2t`` returns. ``Infer.run`` unpacks it as
        ``(logits, inputs_ids, strategy, mask, cut_idx)``.
    """
    output_ids = engine.t2t(
        prompt,
        inputs_ids=inputs_ids,
        latent=latent,
        use_kv_cache=False,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
        strategy=strategy,
        mode=mode,
    )
    return output_ids


class Infer:
    """Drive ``engine`` over ``train_dataset`` and log/export its predictions.

    Each batch gets two passes:
      1. strategy 1: the rotation logits are read from this pass;
      2. strategy 0: run with the stage-1 rotation argmax substituted into the
         rotation column of the targets.
    The stage-2 logits, with the stage-1 rotation logits spliced back in, are
    written to an ``.ldr`` file.
    """

    @staticmethod
    def get_default_config():
        """Return a DictConfig holding the default run parameters."""
        # DictConfig requires an explicit ``content`` argument; ``CN()`` raises TypeError.
        C = CN({})
        # device to run on ('auto' -> cuda if available, else cpu)
        C.device = 'auto'
        # dataloader parameters
        C.num_workers = 4
        # optimizer parameters
        C.max_iters = None  # must be set by the caller before ``run``
        C.batch_size = 4
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # only applied on matmul weights
        C.grad_norm_clip = 1.0
        C.save_interval = None  # None/0 disables checkpointing
        return C

    def __init__(
        self,
        config,
        engine,
        train_dataset,
        accelerator,
        tb,
        prompt: str,
        indices: Optional[List[int]] = None,
        resolution_base: float = 8.0,
        disable_postprocessing: bool = False,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float, ...]] = None,
        save_gpt_ckpt_path: Optional[str] = None,
        mode: str = 'train',
    ):
        self.config = config
        self.engine = engine
        self.model = engine.gpt_model
        self.optimizer = None
        self.scheduler = None
        self.callbacks = defaultdict(list)
        self.train_dataset = train_dataset
        self.accelerator = accelerator

        # Generation parameters (prompt/targets are overwritten per batch in ``run``)
        self.prompt = prompt
        self.targets = indices
        self.box = None
        self.resolution_base = resolution_base
        self.disable_postprocessing = disable_postprocessing
        self.top_p = top_p
        self.bounding_box_xyz = bounding_box_xyz
        self.save_gpt_ckpt_path = save_gpt_ckpt_path

        # determine the device we'll run on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # bookkeeping read by callbacks / logging
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
        self.tb_writer = tb
        self.mode = mode
        self.probs = []
        self.loss = None

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def _generate(self, targets, strategy, bbox_max):
        """Run ``generate_tokens`` on the current batch with the given strategy.

        The bounding boxes are normalized afresh on every call, as the
        original inline calls did.
        """
        return generate_tokens(
            self.engine,
            self.prompt,
            targets,
            None,
            self.resolution_base,
            self.disable_postprocessing,
            self.top_p,
            normalize_bboxs(self.box.float(), list(bbox_max)),
            strategy,
            self.mode,
        )

    def run(self):
        """Iterate over the dataset, compute/log losses and export ``.ldr`` predictions.

        Stops after iteration 5, or after ``config.max_iters`` if that comes
        first. The backward/optimizer step is intentionally disabled, so model
        weights are never updated here.

        Raises:
            ValueError: if ``config.max_iters`` is not set.
        """
        model, config = self.model, self.config
        if config.max_iters is None:
            raise ValueError("config.max_iters must be set before calling run()")

        # The optimizer/scheduler are still built because accelerator.prepare expects one.
        self.optimizer, self.scheduler = self.engine.configure_optimizers_scratch_linear(config)

        train_loader = DataLoader(
            self.train_dataset,
            shuffle=self.mode == 'train',
            batch_size=config.batch_size,
        )

        model.train()
        model, self.optimizer, train_loader = self.accelerator.prepare(model, self.optimizer, train_loader)

        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)

        ema_loss_for_log = 0.0
        ema_ploss_for_log = 0.0
        ema_rloss_for_log = 0.0

        # Token-group sizes used to carve the shared logit vocabulary.
        x_num = 251
        y_num = 215
        z_num = 525
        rot_num = 24
        shift = 0
        stride = 5  # tokens per part in the sequence
        attr_shift = stride - 3
        bert_shift = 1

        x = x_num
        xy = x_num + y_num + rot_num
        xyz = x_num + y_num + z_num + rot_num
        bbox_max = [x_num - 1, y_num - 1, z_num - 1]

        # Sequence positions (dim 1 of logits) read for each attribute.
        rot_pos = slice(1 + bert_shift, -3, stride)
        x_pos = slice(1 + attr_shift + bert_shift, -1, stride)
        y_pos = slice(0 + attr_shift + bert_shift, -2, stride)
        z_pos = slice(2 + attr_shift + bert_shift, None, stride)
        # Vocabulary ranges (dim 2 of logits); each is one wider than its group
        # size (the original code notes "+1 for padding").
        rot_vocab = slice(None, rot_num + 1)
        x_vocab = slice(rot_num + 1, x + rot_num + 2)
        y_vocab = slice(x + rot_num + 2, xy + 3)
        z_vocab = slice(xy + 3, xyz + 4)

        lambda_position = 1.0
        lambda_rotation = 1.0
        # Iteration index after which the loop stops (hard-coded debug limit).
        last_iter = 4
        # None/0 disables checkpointing; ``%`` on None would raise TypeError.
        save_interval = getattr(config, 'save_interval', None)

        output_dir = "train_d2r2p_scratch_whole_perm10re"
        os.makedirs(output_dir, exist_ok=True)

        progress_bar = tqdm(range(0, config.max_iters), desc="Training progress")
        try:
            for iter_num in range(0, config.max_iters + 1):
                self.iter_num = iter_num

                # fetch the next batch and re-init the iterator when exhausted
                try:
                    batch = next(data_iter)
                except StopIteration:
                    data_iter = iter(train_loader)
                    batch = next(data_iter)

                self.prompt = batch['prompt']
                self.targets = batch['target'].to(self.device)
                self.box = batch['bbox']

                # Stage 1 (strategy 1): source of the rotation logits.
                logits, inputs_ids, strategy, _mask, cut_idx = self._generate(
                    self.targets.clone(), 1, bbox_max
                )

                px_loss = F.cross_entropy(
                    logits[:, x_pos, x_vocab].permute(0, 2, 1),
                    inputs_ids[:, shift:, -5],
                    ignore_index=-1,
                )
                py_loss = F.cross_entropy(
                    logits[:, y_pos, y_vocab].permute(0, 2, 1),
                    inputs_ids[:, shift:, -4],
                    ignore_index=-1,
                )
                pz_loss = F.cross_entropy(
                    logits[:, z_pos, z_vocab].permute(0, 2, 1),
                    inputs_ids[:, shift:, -3],
                    ignore_index=-1,
                )
                position_loss = px_loss + py_loss + pz_loss

                rotation_loss = F.cross_entropy(
                    logits[:, rot_pos, rot_vocab].permute(0, 2, 1),
                    inputs_ids[:, shift:, -7],
                    ignore_index=-1,
                )

                self.loss = lambda_position * position_loss
                if strategy == 1 or strategy == 2:
                    self.loss += lambda_rotation * rotation_loss

                # Stage 2 (strategy 0): substitute the predicted rotations into the targets.
                # Uses the same positions as the rotation loss (previously the
                # slice omitted ``bert_shift`` and read a different token).
                targets = self.targets.clone()
                targets[:, shift:, -7] = logits[:, rot_pos, rot_vocab].argmax(dim=-1)

                logits_x, inputs_ids, strategy, _mask, cut_idx = self._generate(targets, 0, bbox_max)
                # Export stage-2 positions together with the stage-1 rotations.
                logits_x[:, rot_pos, rot_vocab] = logits[:, rot_pos, rot_vocab]
                logits2botldrpr(
                    logits_x[0].cpu().detach().numpy(),
                    inputs_ids[0].cpu().detach().numpy(),
                    stride,
                    0,
                    output_file=os.path.join(output_dir, f"test_d2r2p_{self.iter_num}_scratch_0p85_bert.ldr"),
                )

                model.zero_grad(set_to_none=True)

                if self.iter_num > last_iter:
                    break

                # Backward/optimizer step is intentionally disabled in this loop:
                # accelerator.backward(loss); clip_grad_norm_; optimizer.step(); scheduler.step()

                with torch.no_grad():
                    ema_loss_for_log = 0.4 * self.loss.item() + 0.6 * ema_loss_for_log
                    ema_ploss_for_log = 0.4 * position_loss.item() + 0.6 * ema_ploss_for_log
                    ema_rloss_for_log = 0.4 * rotation_loss.item() + 0.6 * ema_rloss_for_log
                    if self.iter_num % 10 == 0:
                        progress_bar.set_postfix({
                            "Loss": f"{ema_loss_for_log:.7f}",
                            "Positon_Loss": f"{ema_ploss_for_log:.7f}",
                            "Rotation_Loss": f"{ema_rloss_for_log:.7f}",
                        })
                        progress_bar.update(10)

                    if save_interval and self.iter_num % save_interval == 0 and self.iter_num != 0:
                        if self.accelerator.is_main_process:
                            save_model_weights(
                                self.engine.gpt_model,
                                self.save_gpt_ckpt_path,
                            )

                    if self.tb_writer:
                        self.tb_writer.add_scalar('train_loss/position_loss', position_loss.item(), self.iter_num)
                        self.tb_writer.add_scalar('train_loss/rotation_loss', rotation_loss.item(), self.iter_num)
                        self.tb_writer.add_scalar('train_loss/total_loss', self.loss.item(), self.iter_num)
        finally:
            # Close even when the loop exits early via ``break`` or an exception.
            progress_bar.close()