import torch
from torch.amp import GradScaler  # on PyTorch < 2.3 this lives at torch.cuda.amp.GradScaler
from tqdm import tqdm
import wandb

from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
from utils.utils import scenes_to_batch, batch_to_scenes
from utils.geometry import center_pointmaps, uncenter_pointmaps, depth2pts
from utils.eval import eval_pred

# Assumed import path for adjust_learning_rate (used in train_epoch);
# utils.misc is a guess at its location in this repo.
import utils.misc as misc

def batch_to_device(batch, device='cuda'):
    """Recursively move every tensor in a (possibly nested) batch dict to `device`."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key], dict):
            batch[key] = batch_to_device(batch[key], device)
    return batch

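# Usage sketch (hypothetical shapes), assuming a nested batch dict:
#   batch = {'input_cams': {'imgs': torch.rand(2, 3, 224, 224)}}
#   batch = batch_to_device(batch, 'cuda')
#   batch['input_cams']['imgs'].device  # -> cuda:0
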
def eval_model(model, batch, mode='loss', device='cuda', dino_model=None, args=None, augmentor=None, return_scale=False):
    batch = batch_to_device(batch, device)

    # DDP wraps the model, so the configured DINO layers live on .module.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        dino_layers = model.module.dino_layers
    else:
        dino_layers = model.dino_layers
    # Build pointmaps and DINO features on the fly if the loader did not.
    if 'pointmaps' not in batch['input_cams']:
        batch = prepare_fast_batch(batch, dino_model, dino_layers)

    normalize_mode = args.normalize_mode if args is not None else 'median'
    batch, scale_factors = normalize_batch(batch, normalize_mode)
    if augmentor is not None:
        batch = augmentor(batch)

    # Flatten per-scene camera lists into one batch; n_cams lets us restore
    # the per-scene structure later for visualization.
    batch, n_cams = scenes_to_batch(batch)
    batch = center_pointmaps(batch)

    device = args.device if args is not None else 'cuda'
    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
        # The forward pass always runs in 'viz' mode so predictions are
        # available for metrics regardless of the requested `mode`.
        pred, gt, loss_dict = model(batch, mode='viz')

    # Ensure both representations exist: lift predicted depths to pointmaps
    # via the camera intrinsics, or read depths off the pointmaps' z channel.
    if 'pointmaps' not in pred:
        pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1), batch['new_cams']['Ks'])
    elif 'depths' not in pred:
        pred['depths'] = pred['pointmaps'][..., -1]
    loss_dict = {**loss_dict, **eval_pred(pred, gt)}
    if mode == 'loss':
        return loss_dict
    elif mode == 'viz':
        pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
        pred, gt, batch = batch_to_scenes(pred, gt, batch, n_cams)
        if return_scale:
            return pred, gt, loss_dict, scale_factors[0].item()
        return pred, gt, loss_dict
    else:
        raise ValueError(f"Invalid mode: {mode}")

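# Usage sketch (hypothetical objects): `model`, `batch`, `dino_model`, and
# `args` are assumed to come from the surrounding training code; this only
# illustrates the two modes, not a definitive API.
#   loss_dict = eval_model(model, batch, mode='loss', dino_model=dino_model, args=args)
#   pred, gt, loss_dict = eval_model(model, batch, mode='viz', dino_model=dino_model, args=args)
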
def update_loss_dict(loss_dict, loss_dict_new, sample_count):
    """Fold one batch's losses into running means over the epoch so far."""
    for key in loss_dict_new:
        if key not in loss_dict:
            loss_dict[key] = loss_dict_new[key]
        else:
            # Incremental mean: weight the previous average by the number of
            # batches seen so far, then average in the new value.
            loss_dict[key] = (loss_dict[key] * sample_count + loss_dict_new[key]) / (sample_count + 1)
    return loss_dict

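# Worked example of the incremental mean (hypothetical values): with a running
# average of 1.0 over one batch and a new loss of 2.0,
#   update_loss_dict({'loss': 1.0}, {'loss': 2.0}, sample_count=1)
# returns {'loss': 1.5}, i.e. (1.0 * 1 + 2.0) / (1 + 1).
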
def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0, log_wandb=False, epoch=0, batch_size=None, args=None, dino_model=None, augmentor=None):
    model.train()
    all_losses_dict = {}

    # Global sample index at the start of this epoch, used as the wandb step.
    sample_idx = epoch * batch_size * len(train_loader)
    scaler = GradScaler()
    for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        new_loss_dict = eval_model(model, batch, mode='loss', device=device, dino_model=dino_model, args=args, augmentor=augmentor)
        loss = new_loss_dict['loss']
        if loss is None:
            continue

        scaler.scale(loss).backward()
        # Unscale in place so the gradient norm and clipping see true gradients.
        scaler.unscale_(optimizer)
        grad_norm = torch.norm(torch.stack([torch.norm(p.grad) for p in model.parameters() if p.grad is not None]))
        if grad_norm.isnan():
            breakpoint()  # drop into the debugger if gradients blow up

        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # scaler.step applies the (already unscaled) gradients, skipping the
        # update if any are inf/NaN; no separate optimizer.step() is needed.
        scaler.step(optimizer)
        scaler.update()

        new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()

        misc.adjust_learning_rate(optimizer, epoch + i / len(train_loader), args)

        new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v) for k, v in new_loss_dict.items()}
        if log_wandb:
            wandb_dict = {f"train_{k}": v for k, v in new_loss_dict.items()}
            wandb_dict["train_lr"] = optimizer.param_groups[0]['lr']
            wandb.log(wandb_dict, step=sample_idx + (i + 1) * batch_size)

        all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=i)

        # Release per-batch tensors eagerly to keep GPU memory flat.
        torch.cuda.empty_cache()
        del loss
        del new_loss_dict
        del grad_norm
        del batch

    return all_losses_dict

def eval_epoch(model, test_loader, device='cuda', dino_model=None, args=None, augmentor=None):
    model.eval()
    all_losses_dict = {}
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
            new_loss_dict = eval_model(model, batch, mode='loss', device=device, dino_model=dino_model, args=args, augmentor=augmentor)
            if new_loss_dict is None:
                continue
            all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=i)

            torch.cuda.empty_cache()
            del new_loss_dict
            del batch

    return all_losses_dict
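

# Minimal wiring sketch, assuming hypothetical `build_model` / `build_loaders`
# helpers and an `args` namespace; none of these are defined in this file.
#   model = build_model(args).to(args.device)
#   train_loader, test_loader = build_loaders(args)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
#   for epoch in range(args.epochs):
#       train_losses = train_epoch(model, train_loader, optimizer, device=args.device,
#                                  batch_size=args.batch_size, args=args)
#       test_losses = eval_epoch(model, test_loader, device=args.device, args=args)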