bb = breakpoint  # NOTE(review): short debug alias for the builtin breakpoint(); remove before release
| import torch |
| from torch.utils.data import DataLoader |
| import wandb |
| from argparse import ArgumentParser |
| from datasets.octmae import OctMae |
| from datasets.foundation_pose import FoundationPose |
| from datasets.generic_loader import GenericLoader |
|
|
| from utils.collate import collate |
| from models.rayquery import RayQuery |
| from engine import train_epoch, eval_epoch, eval_model |
| import torch.nn as nn |
| from models.rayquery import RayQuery, PointmapEncoder, RayEncoder |
| from models.losses import * |
| import utils.misc as misc |
| import os |
| from utils.viz import just_load_viz |
| from utils.fusion import fuse_batch |
| import socket |
| import time |
| from utils.augmentations import * |
|
|
def parse_args(argv=None):
    """Parse command-line options for training / visualization.

    Parameters
    ----------
    argv : list[str] | None
        Argument list to parse. ``None`` (the default) falls back to
        ``sys.argv[1:]``, so the previous zero-argument call sites are
        unaffected; passing an explicit list enables programmatic use
        and testing.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = ArgumentParser()
    # Dataset / model specs are Python expressions evaluated later with eval().
    parser.add_argument("--dataset_train", type=str, default="TableOfCubes(size=10,n_views=2,seed=747)")
    parser.add_argument("--dataset_test", type=str, default="TableOfCubes(size=10,n_views=2,seed=787)")
    parser.add_argument("--dataset_just_load", type=str, default=None)
    parser.add_argument("--logdir", type=str, default="logs")
    parser.add_argument("--batch_size", type=int, default=5)
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=4)
    parser.add_argument("--model", type=str, default="RayQuery(ray_enc=RayEncoder(),pointmap_enc=PointmapEncoder(),criterion=RayCompletion(ConfLoss(L21)))")
    parser.add_argument("--save_every", type=int, default=1)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--eval_every", type=int, default=3)
    # Logging / visualization.
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default="init")
    parser.add_argument("--just_load", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
    parser.add_argument("--mesh", action="store_true")
    # Optimization. max_norm < 0 presumably disables gradient clipping -- confirm in engine.train_epoch.
    parser.add_argument("--max_norm", type=float, default=-1)
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--normalize_mode', type=str, default='None')
    parser.add_argument('--start_from', type=str, default=None)
    parser.add_argument('--augmentor', type=str, default='None')
    return parser.parse_args(argv)
|
|
def main(args):
    """Build data loaders, model and optimizer, then either run the
    interactive visualization loop (``--just_load``) or train/evaluate.

    NOTE(review): dataset, model and augmentor specs arrive as CLI strings
    and are instantiated with ``eval()``. Acceptable for a research harness,
    but never expose these flags to untrusted input.
    """
    load_dino = False
    if not args.just_load:
        # --- training mode: train/test datasets with distributed samplers ---
        dataset_train = eval(args.dataset_train)
        dataset_test = eval(args.dataset_test)
        if not dataset_train.prefetch_dino:
            # Dataset does not ship precomputed DINO features; load the model.
            load_dino = True
        rank, world_size, local_rank = misc.setup_distributed()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=world_size, rank=rank, shuffle=True
        )
        sampler_test = torch.utils.data.DistributedSampler(
            dataset_test, num_replicas=world_size, rank=rank, shuffle=False
        )
        train_loader = DataLoader(
            dataset_train, sampler=sampler_train, batch_size=args.batch_size,
            shuffle=False, collate_fn=collate,  # sampler owns the ordering
            num_workers=args.n_workers,
            pin_memory=True,
            prefetch_factor=2,
            drop_last=True,
        )
        test_loader = DataLoader(
            dataset_test, sampler=sampler_test, batch_size=args.batch_size,
            shuffle=False, collate_fn=collate,
            num_workers=args.n_workers,
            pin_memory=True,
            prefetch_factor=2,
            drop_last=True,
        )
        # Global scenes per epoch (drop_last discards the ragged tail batch).
        n_scenes_epoch = len(train_loader) * args.batch_size * world_size
        print(f"Number of scenes in epoch: {n_scenes_epoch}")
    else:
        # --- visualization mode: a single dataset, deterministic order ---
        if args.dataset_just_load is None:
            dataset = eval(args.dataset_train)
        else:
            dataset = eval(args.dataset_just_load)
        if not dataset.prefetch_dino:
            load_dino = True
        rank, world_size, local_rank = misc.setup_distributed()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        just_loader = DataLoader(
            dataset, sampler=sampler_train, batch_size=args.batch_size,
            shuffle=False, collate_fn=collate,
            pin_memory=True,
            drop_last=True,
        )

    model = eval(args.model).to(args.device)
    augmentor = eval(args.augmentor) if args.augmentor != 'None' else None

    # Load DINOv2 only when features are not prefetched by the dataset AND
    # the model actually consumes DINO layers.
    if load_dino and len(model.dino_layers) > 0:
        dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
        dino_model.eval()
        dino_model.to(args.device)  # FIX: was hard-coded to "cuda"; honor --device
    else:
        dino_model = None

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], find_unused_parameters=True
    )
    model_without_ddp = model.module if hasattr(model, 'module') else model

    # Linear lr scaling with the effective (global) batch size, MAE-style.
    eff_batch_size = args.batch_size * misc.get_world_size()
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    os.makedirs(args.logdir, exist_ok=True)
    start_epoch = 0
    print("Running on host %s" % socket.gethostname())

    if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-latest.pth")):
        # Resume model weights and, when present, optimizer state + epoch.
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-latest.pth"), map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            print("With optim & sched!")
        del checkpoint  # release the CPU copy before training starts
    elif args.start_from is not None:
        # Warm-start weights only; optimizer state is left fresh.
        checkpoint = torch.load(args.start_from, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Start from checkpoint %s" % args.start_from)

    if args.just_load:
        # Endless visualization loop: render each batch to rerun, then pause.
        with torch.no_grad():
            while True:
                for data in just_loader:
                    pred, gt, loss_dict, batch = eval_model(
                        model, data, mode='viz', args=args,
                        dino_model=dino_model, augmentor=augmentor
                    )
                    gt = {k: v.float() for k, v in gt.items()}
                    pred = {k: v.float() for k, v in pred.items()}

                    print(f"{'Key':<10} {'Value':<10}")
                    print("-"*20)
                    for key, value in loss_dict.items():
                        print(f"{key:<10}: {value:.4f}")
                    print("-"*20)

                    name = args.logdir
                    addr = args.rr_addr
                    if args.mesh:
                        fused_meshes = fuse_batch(pred, gt, data, voxel_size=0.002)
                    else:
                        fused_meshes = None
                    just_load_viz(pred, gt, batch, addr=addr, name=name, fused_meshes=fused_meshes)
                    breakpoint()  # deliberate: pause after each batch for inspection
        return  # unreachable; the loop above only exits via the debugger / Ctrl-C
    else:
        # Only rank 0 initializes wandb; other ranks keep log_wandb=None.
        if args.wandb_project and misc.get_rank() == 0:
            wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
            log_wandb = args.wandb_project
        else:
            log_wandb = None
        for epoch in range(start_epoch, args.n_epochs):
            # FIX: DistributedSampler must be told the epoch to reshuffle;
            # without this every epoch replays the identical sample order.
            sampler_train.set_epoch(epoch)
            start_time = time.time()
            log_dict = train_epoch(
                model, train_loader, optimizer, device=args.device,
                max_norm=args.max_norm, epoch=epoch,
                log_wandb=log_wandb, batch_size=eff_batch_size, args=args,
                dino_model=dino_model, augmentor=augmentor
            )
            end_time = time.time()
            print(f"Epoch {epoch} train loss: {log_dict['loss']:.4f} grad_norm: {log_dict['grad_norm']:.4f} \n")
            print(f"Time taken for epoch {epoch}: {end_time - start_time:.2f} seconds")

            if epoch % args.eval_every == 0:
                test_log_dict = eval_epoch(
                    model, test_loader, device=args.device,
                    dino_model=dino_model, args=args, augmentor=augmentor
                )
                print(f"Epoch {epoch} test loss: {test_log_dict['loss']:.4f} \n")
                if log_wandb:
                    wandb_dict = {f"test_{k}": v for k, v in test_log_dict.items()}
                    # Step is in scenes seen so runs with different world
                    # sizes line up on the same x-axis.
                    wandb.log(wandb_dict, step=(epoch+1)*n_scenes_epoch)
            if epoch % args.save_every == 0:
                # Always overwrite the same "latest" checkpoint for resume.
                misc.save_model(args, epoch, model_without_ddp, optimizer, epoch_name="latest")
|
|
if __name__ == "__main__":
    # CLI entry point: parse flags, then hand off to main().
    main(parse_args())