diff --git a/outdoor_v48_16gpu_v2/.hydra/config.yaml b/outdoor_v48_16gpu_v2/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a5d5c7cd09bfe9b78c47ee89f010cf4491f3944 --- /dev/null +++ b/outdoor_v48_16gpu_v2/.hydra/config.yaml @@ -0,0 +1,68 @@ +teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +load_only_encoder: false +long_context: false +fixed_length: true +resume: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth +benchmark: false +num_views: 64 +num_test_views: 4 +n_corres_train: 0 +n_corres_test: 0 +train_criterion: DistillLoss() +test_criterion: DistillLoss() +allow_repeat: false +root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti +root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset +root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset +root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_waymo: /scratch-shared/wwei2/waymo_v2 +root_waymo_lidar: /scratch-shared/wwei2/waymo_v2 +dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train', + ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), + (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}", + velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518, + 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, + num_views=${num_views}, n_corres=${n_corres_train}) +dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}", + lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336), + (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo} +test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518, + 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test}) +seed: 0 +batch_size: 1 +accum_iter: 1 +gradient_checkpointing: false +epochs: 10 +start_epoch: 0 +start_step: 0 +weight_decay: 0.05 +lr: 1.0e-05 +min_lr: 1.0e-08 +warmup_epochs: 0.5 +amp: 1 +num_workers: 4 +world_size: 1 +local-rank: -1 +dist_url: env:// +rank: 0 +gpu: 0 +distributed: false +dist_backend: nccl +eval_freq: 1 +save_freq: 0.1 +max_checkpoints: 10 +keep_freq: 1 +print_freq: 10 +print_img_freq: 50000000 +num_imgs_vis: 4 +save_dir: /scratch-shared/wwei2/training_upstream/checkpoints +exp_name: outdoor_v48_16gpu_v2 +task: StreamVGGT +logdir: ${save_dir}/${exp_name}/logs +output_dir: ${save_dir}/${exp_name}/ diff --git a/outdoor_v48_16gpu_v2/.hydra/hydra.yaml b/outdoor_v48_16gpu_v2/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f4df9d53ebffb8656b106e56d5023f6c275b44b --- /dev/null +++ b/outdoor_v48_16gpu_v2/.hydra/hydra.yaml @@ -0,0 +1,156 @@ +hydra: + run: + dir: ${save_dir}/${exp_name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - exp_name=outdoor_v48_16gpu_v2 + - resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth + job: + name: mytrain + chdir: null + override_dirname: exp_name=outdoor_v48_16gpu_v2,resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth + id: ??? + num: ??? + config_name: outdoor_v48 + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2 + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: true diff --git a/outdoor_v48_16gpu_v2/.hydra/overrides.yaml b/outdoor_v48_16gpu_v2/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d42a02c6004db9472a7282118ca4b65c1d90082 --- /dev/null +++ b/outdoor_v48_16gpu_v2/.hydra/overrides.yaml @@ -0,0 +1,2 @@ +- exp_name=outdoor_v48_16gpu_v2 +- resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..43571a69609444fc7d11dbdf6643c130dab6f127 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py @@ -0,0 +1,576 @@ +import PIL +import numpy as np +import torch +import random +import itertools +from dust3r.datasets.base.easy_dataset import EasyDataset +from dust3r.datasets.utils.transforms import ImgNorm, SeqColorJitter +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import dust3r.datasets.utils.cropping as cropping +from dust3r.datasets.utils.corr import extract_correspondences_from_pts3d + +from vggt.train_utils.augmentation import get_image_augmentation + + + +def get_ray_map(c2w1, c2w2, intrinsics, h, w): + c2w = np.linalg.inv(c2w1) @ c2w2 + i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") + grid = np.stack([i, j, np.ones_like(i)], axis=-1) + ro = c2w[:3, 3] + rd = np.linalg.inv(intrinsics) @ grid.reshape(-1, 3).T + rd = (c2w @ np.vstack([rd, np.ones_like(rd[0])])).T[:, :3].reshape(h, w, 3) + rd = rd / np.linalg.norm(rd, axis=-1, keepdims=True) + ro = np.broadcast_to(ro, (h, w, 3)) + ray_map = np.concatenate([ro, rd], axis=-1) + return ray_map + + +class BaseMultiViewDataset(EasyDataset): + """Define all basic options. + + Usage: + class MyDataset (BaseMultiViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__( + self, + *, # only keyword arguments + num_views=None, + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + n_corres=0, + nneg=0, + seed=None, + allow_repeat=False, + seq_aug_crop=False, + ): + assert num_views is not None, "undefined num_views" + self.num_views = num_views + self.split = split + self._set_resolutions(resolution) + + self.n_corres = n_corres + self.nneg = nneg + assert ( + self.n_corres == "all" + or isinstance(self.n_corres, int) + or ( + isinstance(self.n_corres, list) and len(self.n_corres) == self.num_views + ) + ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}" + assert ( + self.nneg == 0 or self.n_corres != "all" + ), "nneg should be 0 if n_corres is all" + + self.is_seq_color_jitter = False + if isinstance(transform, str): + transform = eval(transform) + if transform == SeqColorJitter: + transform = SeqColorJitter() + self.is_seq_color_jitter = True + self.transform = transform + + self.image_aug = get_image_augmentation( + color_jitter={ 'brightness': 0.5, + 'contrast': 0.5, + 'saturation': 0.5, + 'hue': 0.1, + 'p': 0.9}, +#common_config.augs.color_jitter, + gray_scale=True,#common_config.augs.gray_scale, + gau_blur=False, #common_config.augs.gau_blur, + ) + + + self.aug_crop = aug_crop + self.seed = seed + self.allow_repeat = allow_repeat + self.seq_aug_crop = seq_aug_crop + + def __len__(self): + return len(self.scenes) + + @staticmethod + def efficient_random_intervals( + start, + num_elements, + interval_range, + fixed_interval_prob=0.8, + weights=None, + seed=42, + ): + if random.random() < fixed_interval_prob: + intervals = random.choices(interval_range, weights=weights) * ( + num_elements - 1 + ) + else: + intervals = [ + random.choices(interval_range, weights=weights)[0] + for _ in range(num_elements - 1) + ] + return list(itertools.accumulate([start] + intervals)) + + def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1): + time_diffs = np.abs(timestamps - timestamps[i]) + ids_candidate = np.where(time_diffs < interval)[0] + ids_candidate = np.sort(ids_candidate) + if (self.allow_repeat and len(ids_candidate) < num_views // 3) or ( + len(ids_candidate) < num_views + ): + return [] + ids_sel_list = [] + ids_candidate_left = ids_candidate.copy() + while len(ids_candidate_left) >= num_views: + ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False) + ids_sel_list.append(sorted(ids_sel)) + ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel) + + if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views: + ids_sel = np.concatenate( + [ + ids_candidate_left, + np.random.choice( + np.setdiff1d(ids_candidate, ids_candidate_left), + num_views - len(ids_candidate_left), + replace=False, + ), + ] + ) + ids_sel_list.append(sorted(ids_sel)) + + if self.allow_repeat: + ids_sel_list.append( + sorted(np.random.choice(ids_candidate, num_views, replace=True)) + ) + + # add sequences with fixed intervals (all possible intervals) + pos_i = np.where(ids_candidate == i)[0][0] + curr_interval = 1 + stop = len(ids_candidate) < num_views + while not stop: + pos_sel = [pos_i] + count = 0 + while len(pos_sel) < num_views: + if count % 2 == 0: + curr_pos_i = pos_sel[-1] + curr_interval + if curr_pos_i >= len(ids_candidate): + stop = True + break + pos_sel.append(curr_pos_i) + else: + curr_pos_i = pos_sel[0] - curr_interval + if curr_pos_i < 0: + stop = True + break + pos_sel.insert(0, curr_pos_i) + count += 1 + if not stop and len(pos_sel) == num_views: + ids_sel = sorted([ids_candidate[pos] for pos in pos_sel]) + if ids_sel not in ids_sel_list: + ids_sel_list.append(ids_sel) + curr_interval += 1 + return ids_sel_list + + @staticmethod + def blockwise_shuffle(x, rng, block_shuffle): + if block_shuffle is None: + return rng.permutation(x).tolist() + else: + assert block_shuffle > 0 + blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)] + shuffled_blocks = [rng.permutation(block).tolist() for block in blocks] + shuffled_list = [item for block in shuffled_blocks for item in block] + return shuffled_list + + def get_seq_from_start_id( + self, + num_views, + id_ref, + ids_all, + rng, + min_interval=1, + max_interval=25, + video_prob=0.5, + fix_interval_prob=0.5, + block_shuffle=None, + ): + """ + args: + num_views: number of views to return + id_ref: the reference id (first id) + ids_all: all the ids + rng: random number generator + max_interval: maximum interval between two views + returns: + pos: list of positions of the views in ids_all, i.e., index for ids_all + is_video: True if the views are consecutive + """ + assert min_interval > 0, f"min_interval should be > 0, got {min_interval}" + assert ( + min_interval <= max_interval + ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}" + assert id_ref in ids_all + pos_ref = ids_all.index(id_ref) + all_possible_pos = np.arange(pos_ref, len(ids_all)) + + remaining_sum = len(ids_all) - 1 - pos_ref + + if remaining_sum >= num_views - 1: + if remaining_sum == num_views - 1: + assert ids_all[-num_views] == id_ref + return [pos_ref + i for i in range(num_views)], True + max_interval = min(max_interval, 2 * remaining_sum // (num_views - 1)) + intervals = [ + rng.choice(range(min_interval, max_interval + 1)) + for _ in range(num_views - 1) + ] + + # if video or collection + if rng.random() < video_prob: + # if fixed interval or random + if rng.random() < fix_interval_prob: + # regular interval + fixed_interval = rng.choice( + range( + 1, + min(remaining_sum // (num_views - 1) + 1, max_interval + 1), + ) + ) + intervals = [fixed_interval for _ in range(num_views - 1)] + is_video = True + else: + is_video = False + + pos = list(itertools.accumulate([pos_ref] + intervals)) + pos = [p for p in pos if p < len(ids_all)] + pos_candidates = [p for p in all_possible_pos if p not in pos] + pos = ( + pos + + rng.choice( + pos_candidates, num_views - len(pos), replace=False + ).tolist() + ) + + pos = ( + sorted(pos) + if is_video + else self.blockwise_shuffle(pos, rng, block_shuffle) + ) + #elif remaining_sum>1: + else: + # assert self.allow_repeat + uniq_num = remaining_sum + new_pos_ref = rng.choice(np.arange(pos_ref + 1)) + new_remaining_sum = len(ids_all) - 1 - new_pos_ref + new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1)) + new_intervals = [ + rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1) + ] + + revisit_random = rng.random() + video_random = rng.random() + + if rng.random() < fix_interval_prob and video_random < video_prob: + # regular interval + fixed_interval = rng.choice(range(1, new_max_interval + 1)) + new_intervals = [fixed_interval for _ in range(uniq_num - 1)] + pos = list(itertools.accumulate([new_pos_ref] + new_intervals)) + + is_video = False + if revisit_random < 0.5 or video_prob == 1.0: # revisit, video / collection + is_video = video_random < video_prob + pos = ( + self.blockwise_shuffle(pos, rng, block_shuffle) + if not is_video + else pos + ) + num_full_repeat = num_views // uniq_num + pos = ( + pos * num_full_repeat + + pos[: num_views - len(pos) * num_full_repeat] + ) + elif revisit_random < 0.9: # random + pos = rng.choice(pos, num_views, replace=True) + else: # ordered + pos = sorted(rng.choice(pos, num_views, replace=True)) + assert len(pos) == num_views + return pos, is_video + + def get_img_and_ray_masks(self, is_metric, v, rng, p=[0.8, 0.15, 0.05]): + # generate img mask and raymap mask + if v == 0 or (not is_metric): + img_mask = True + raymap_mask = False + else: + rand_val = rng.random() + if rand_val < p[0]: + img_mask = True + raymap_mask = False + elif rand_val < p[0] + p[1]: + img_mask = False + raymap_mask = True + else: + img_mask = True + raymap_mask = True + return img_mask, raymap_mask + + def get_stats(self): + return f"{len(self)} groups of views" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.num_views=}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace( + "self.", "" + ) + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, resolution, rng, num_views): + raise NotImplementedError() + + def __getitem__(self, idx): + # print("Receiving:" , idx) + if isinstance(idx, (tuple, list, np.ndarray)): + # the idx is specifying the aspect-ratio + idx, ar_idx, nview = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + nview = self.num_views + + assert nview >= 1 and nview <= self.num_views + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, "_rng"): + seed = torch.randint(0, 2**32, (1,)).item() + self._rng = np.random.default_rng(seed=seed) + + if self.aug_crop > 1 and self.seq_aug_crop: + self.delta_target_resolution = self._rng.integers(0, self.aug_crop) + + # over-loaded code + resolution = self._resolutions[ + ar_idx + ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng, nview) + assert len(views) == nview + + if "camera_pose" not in views[0]: + views[0]["camera_pose"] = np.ones((4, 4), dtype=np.float32) + first_view_camera_pose = views[0]["camera_pose"] + transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform + + for v, view in enumerate(views): + assert ( + "pts3d" not in view + ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view["idx"] = (idx, ar_idx, v) + + # encode the image + width, height = view["img"].size + + view["true_shape"] = np.int32((height, width)) + view["img"] = transform(view["img"]) + view["sky_mask"] = view["depthmap"] < 0 + + assert "camera_intrinsics" in view + if "camera_pose" not in view: + view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite( + view["camera_pose"] + ).all(), f"NaN in camera pose for view {view_name(view)}" + + ray_map = get_ray_map( + first_view_camera_pose, + view["camera_pose"], + view["camera_intrinsics"], + height, + width, + ) + view["ray_map"] = ray_map.astype(np.float32) + + assert "pts3d" not in view + assert "valid_mask" not in view + assert np.isfinite( + view["depthmap"] + ).all(), f"NaN in depthmap for view {view_name(view)}" + pts3d, pts3d_local, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + + + + view["pts3d"] = pts3d + view["pts3d_local"] = pts3d_local + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view["camera_intrinsics"] + if False: + if random.random() > 0.3:#self.cojitter_ratio: + images = torch.stack([view['img'] for view in views],axis=0) + images = self.image_aug(images) + for v, view in enumerate(views): + view['img'] = images[v] + + else: + for view in views: + view['img'] = self.image_aug(view['img'][None])[0] + + if self.n_corres > 0: + ref_view = views[0] + for view in views: + corres1, corres2, valid = extract_correspondences_from_pts3d( + ref_view, view, self.n_corres, self._rng, nneg=self.nneg + ) + view["corres"] = (corres1, corres2) + view["valid_corres"] = valid + + # last thing done! + for view in views: + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, "undefined resolution" + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance( + width, int + ), f"Bad type for {width=} {type(width)=}, should be int" + assert isinstance( + height, int + ), f"Bad type for {height=} {type(height)=}, should be int" + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, image, depthmap, intrinsics, resolution, rng=None, info=None + ): + """This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + assert min_margin_x > W / 5, f"Bad principal point in view={info}" + assert min_margin_y > H / 5, f"Bad principal point in view={info}" + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + # transpose the resolution if necessary + W, H = image.size # new size + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution += ( + rng.integers(0, self.aug_crop) + if not self.seq_aug_crop + else self.delta_target_resolution + ) + image, depthmap, intrinsics = cropping.rescale_image_depthmap( + image, depthmap, intrinsics, target_resolution + ) + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop( + intrinsics, image.size, resolution, offset_factor=0.5 + ) + crop_bbox = cropping.bbox_from_intrinsics_in_out( + intrinsics, intrinsics2, resolution + ) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """returns (is_good, err_msg)""" + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view["true_shape"] + + if width < height: + # rectify portrait to landscape + assert view["img"].shape == (3, height, width) + view["img"] = view["img"].swapaxes(1, 2) + + assert view["valid_mask"].shape == (height, width) + view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) + + assert view["depthmap"].shape == (height, width) + view["depthmap"] = view["depthmap"].swapaxes(0, 1) + + assert view["pts3d"].shape == (height, width, 3) + view["pts3d"] = view["pts3d"].swapaxes(0, 1) + + # transpose x and y pixels + view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] + + assert view["ray_map"].shape == (height, width, 6) + view["ray_map"] = view["ray_map"].swapaxes(0, 1) + + assert view["sky_mask"].shape == (height, width) + view["sky_mask"] = view["sky_mask"].swapaxes(0, 1) + + if "corres" in view: + # transpose correspondences x and y + view["corres"][0] = view["corres"][0][:, [1, 0]] + view["corres"][1] = view["corres"][1][:, [1, 0]] diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b556e913c55791eea3323057402e9637abc9888a --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,93 @@ +import numpy as np +import torch +from accelerate import Accelerator +import torch.utils +from torch.utils.data import BatchSampler, Sampler +import torch.utils.data + + +class CustomRandomSampler(Sampler): + """Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__( + self, + dataset, + batch_size, + pool_size, + min_view_size, + max_view_size, + world_size, + warmup=1, + drop_last=True, + ): + self.batch_size = batch_size + self.pool_size = pool_size + self.min_view_size = min_view_size + self.max_view_size = max_view_size + self.drop_last = drop_last + self.len_dataset = N = len(dataset) + self.total_size = N + self.epoch = None + self.epochf = 0.0 + + def __len__(self): + return self.total_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + if self.epoch is None: + raise ValueError( + "Epoch number not set. Please call 'set_epoch(epoch)' before iterating." + ) + + seed = self.epoch + 788 + rng = np.random.default_rng(seed=seed) + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + # random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + if self.pool_size > 1: + p = np.ones(self.pool_size) + p[: self.pool_size // 2] *= 2 + p = p / p.sum() + _feat_idxs = rng.choice(self.pool_size, size=n_batches, p=p) + else: + _feat_idxs = rng.integers(self.pool_size, size=n_batches) + _feat_idxs = np.broadcast_to(_feat_idxs[:, None], (n_batches, self.batch_size)) + _feat_idxs = _feat_idxs.ravel()[: self.total_size] + _view_idxs = rng.integers( + self.min_view_size, self.max_view_size + 1, size=n_batches + ) + _view_idxs = np.broadcast_to(_view_idxs[:, None], (n_batches, self.batch_size)) + _view_idxs = _view_idxs.ravel()[: self.total_size] + + idxs = np.c_[sample_idxs, _feat_idxs, _view_idxs] + yield from (tuple(idx) for idx in idxs) + + +class BatchedRandomSampler(BatchSampler): + """Batch sampler that groups indices from RandomSampler into batches.""" + + def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True): + self.sampler = sampler # An instance of RandomSampler + self.batch_size = batch_size + self.drop_last = drop_last + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple - 1 + return (total // multiple) * multiple diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..604048ba0d055d9e59713b87dbab0c2fb7db6d3c --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,212 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import numpy as np +from dust3r.datasets.base.batched_sampler import ( + BatchedRandomSampler, + CustomRandomSampler, +) +import torch + + +class EasyDataset: + """a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler( + self, batch_size, shuffle=True, drop_last=True, world_size=1, rank=0, fixed_length=False + ): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + num_of_views = self.num_views + sampler = CustomRandomSampler( + self, + batch_size, + num_of_aspect_ratios, + 4 if not fixed_length else num_of_views, + num_of_views, + world_size, + warmup=1, + drop_last=drop_last, + ) + return BatchedRandomSampler(sampler, batch_size, drop_last) + + +class MulDataset(EasyDataset): + """Artifically augmenting the size of a dataset.""" + + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f"{self.multiplicator}*{repr(self.dataset)}" + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other, another = idx + return self.dataset[idx // self.multiplicator, other, another] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + @property + def num_views(self): + return self.dataset.num_views + + +class ResizedDataset(EasyDataset): + """Artifically changing the size of a dataset.""" + + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str) - 1) // 3): + sep = -4 * i - 3 + size_str = size_str[:sep] + "_" + size_str[sep:] + return f"{size_str} @ {repr(self.dataset)}" + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch + 777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate( + [perm] * (1 + (len(self) - 1) // len(self.dataset)) + ) + self._idxs_mapping = shuffled_idxs[: self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr( + self, "_idxs_mapping" + ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()" + if isinstance(idx, tuple): + idx, other, another = idx + return self.dataset[self._idxs_mapping[idx], other, another] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + @property + def num_views(self): + return self.dataset.num_views + + +class CatDataset(EasyDataset): + """Concatenation of several datasets""" + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return " + ".join( + repr(dataset).replace( + ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))", + "", + ) + for dataset in self.datasets + ) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other, another = idx + + cause_error = False + while True: + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, "right") + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None and another is not None: + new_idx = (new_idx, other, another) + + try: + res_data = dataset[new_idx] + except Exception as e: + print(e) + print("DATA ERROR", new_idx) + idx += 1 + idx = idx % len(self) + continue + + break + return res_data + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions + + @property + def num_views(self): + num_views = self.datasets[0].num_views + for dataset in self.datasets[1:]: + assert dataset.num_views == num_views + return num_views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py new file mode 100644 index 0000000000000000000000000000000000000000..1d816e58be6518e1274fa84fa8c6a7cae73741ca --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py @@ -0,0 +1,137 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class DynamicReplica(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 16 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + self.scenes = os.listdir(os.path.join(self.ROOT, split)) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, self.split, scene, "left") + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: float(x), + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id], "left") + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="dynamic_replica", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3d3422ccc4b19753630d09a39beee191bae8fe --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py @@ -0,0 +1,174 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class HabitatHM3D_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 8 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + basenames = sorted( + [f[:-4] for f in os.listdir(scene_dir) if f.endswith(".npz")], + key=lambda x: int(x), + ) + + num_imgs = len(basenames) + # TODO: because current minghui's training data is backward moving, now use seq from -1 to 0 + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + self.invalid_scenes = {scene: False for scene in self.scenes} + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene, start_id = self.start_img_ids[idx] # 获取指定索引idx对应的场景名scene和起始图像id + + # 添加最大重试次数,防止无限循环导致分布式训练卡住 + max_retries = 100 + retry_count = 0 + + while invalid_seq: + retry_count += 1 + + # 超过重试次数限制,抛出异常 + if retry_count > max_retries: + raise RuntimeError( + f"[HabitatHM3D] Failed to get valid views after {max_retries} retries. " + f"idx={idx}, scene={scene}, num_views={num_views}. " + f"This may indicate insufficient valid frames in the dataset." + ) + + # 超过50次时打印警告 + if retry_count == 50: + print(f"[HabitatHM3D WARNING] Already retried {retry_count} times for idx={idx}, scene={scene}") + + # 如果当前场景被标记为invalid则随机选择一个新的场景和起始图像id + scene_retry = 0 + while self.invalid_scenes[scene]: + scene_retry += 1 + if scene_retry > len(self.start_img_ids): + raise RuntimeError( + f"[HabitatHM3D] All scenes are invalid! Cannot find valid scene after {scene_retry} attempts." + ) + idx = rng.integers(low=0, high=len(self.start_img_ids)) + scene, start_id = self.start_img_ids[idx] + + all_image_ids = self.scene_img_list[self.sceneids[start_id]] # 获取当前场景的所有图像id列表 + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) # 根据起始图像id和其他参数生成图像序列的索引pos 并返回有序视频 + image_idxs = np.array(all_image_ids)[pos] # 从all_image_ids提取图像序列 + + views = [] + load_failed = False + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + basename = self.images[view_idx] + + try: + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, "image_" + basename + ".png")) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "depth_" + basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + camera_params = np.load(osp.join(scene_dir, basename + ".npz")) + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params["R_cam2world"] + camera_pose[:3, 3] = camera_params["t_cam2world"] + except Exception as e: + print(f"[HabitatHM3D] Error loading {scene} {basename}: {e}, skipping scene") + self.invalid_scenes[scene] = True + load_failed = True + break + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="habitatHM3D", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + + # 只有成功加载所有视图才退出循环 + if not load_failed and len(views) == num_views: + invalid_seq = False + + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py new file mode 100644 index 0000000000000000000000000000000000000000..b602df5d4dd1493d02377039379fd2ffb3b08ba2 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py @@ -0,0 +1,84 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys +sys.path.append(osp.join(osp.dirname(__file__), '..','..')) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class HOI4D_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, 'rgb') + basenames = sorted([f[:-4] for f in os.listdir(rgb_dir) if f.endswith('.png')]) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + invalid_seq = True + while invalid_seq: + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + try: + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))["intrinsics"] + except: + print(f"Error loading {scene} {img_name}, skipping") + break + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics= self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='HOI4D', + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.99, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + )) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py new file mode 100644 index 0000000000000000000000000000000000000000..58eef2f61642deeca4e7accb84429f3d471a5bd9 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py @@ -0,0 +1,282 @@ +import os.path as osp +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys +import pickle +import h5py +from tqdm import tqdm + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class MapFree_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 30 + super().__init__(*args, **kwargs) + + self._load_data() + + def imgid2path(self, img_id, scene): + first_seq_id, first_frame_id = img_id + return os.path.join( + self.ROOT, + scene, + f"dense{first_seq_id}", + "rgb", + f"frame_{first_frame_id:05d}.jpg", + ) + + def path2imgid(self, subscene, filename): + first_seq_id = int(subscene[5:]) + first_frame_id = int(filename[6:-4]) + return [first_seq_id, first_frame_id] + + def _load_data(self): + cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5" + if os.path.exists(cache_file): + print(f"Loading cached metadata from {cache_file}") + with h5py.File(cache_file, "r") as hf: + self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:])) + self.sceneids = hf["sceneids"][:] + self.scope = hf["scope"][:] + self.video_flags = hf["video_flags"][:] + self.groups = hf["groups"][:] + self.id_ranges = hf["id_ranges"][:] + self.images = hf["images"][:] + else: + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + scenes = [] + sceneids = [] + groups = [] + scope = [] + images = [] + id_ranges = [] + is_video = [] + start = 0 + j = 0 + offset = 0 + + for scene in tqdm(scene_dirs): + scenes.append(scene) + # video sequences + subscenes = sorted( + [ + d + for d in os.listdir(os.path.join(self.ROOT, scene)) + if d.startswith("dense") + ] + ) + id_range_subscenes = [] + for subscene in subscenes: + rgb_paths = sorted( + [ + d + for d in os.listdir( + os.path.join(self.ROOT, scene, subscene, "rgb") + ) + if d.endswith(".jpg") + ] + ) + assert ( + len(rgb_paths) > 0 + ), f"{os.path.join(self.ROOT, scene, subscene)} is empty." + num_imgs = len(rgb_paths) + images.extend( + [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths] + ) + id_range_subscenes.append((offset, offset + num_imgs)) + offset += num_imgs + + # image collections + metadata = pickle.load( + open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") + ) + ref_imgs = list(metadata.keys()) + img_groups = [] + for ref_img in ref_imgs: + other_imgs = metadata[ref_img] + if len(other_imgs) + 1 < self.num_views: + continue + group = [(*other_img[0], other_img[1]) for other_img in other_imgs] + group.insert(0, (*ref_img, 1)) + img_groups.append(np.array(group)) + id_ranges.append(id_range_subscenes[ref_img[0]]) + scope.append(start) + start = start + len(group) + + num_groups = len(img_groups) + sceneids.extend([j] * num_groups) + groups.extend(img_groups) + is_video.extend([False] * num_groups) + j += 1 + + self.scenes = np.array(scenes) + self.sceneids = np.array(sceneids) + self.scope = np.array(scope) + self.video_flags = np.array(is_video) + self.groups = np.concatenate(groups, 0) + self.id_ranges = np.array(id_ranges) + self.images = np.array(images) + + data = dict( + scenes=self.scenes, + sceneids=self.sceneids, + scope=self.scope, + video_flags=self.video_flags, + groups=self.groups, + id_ranges=self.id_ranges, + images=self.images, + ) + + with h5py.File(cache_file, "w") as h5f: + h5f.create_dataset( + "scenes", + data=data["scenes"].astype(object), + dtype=h5py.string_dtype(encoding="utf-8"), + compression="lzf", + chunks=True, + ) + h5f.create_dataset( + "sceneids", data=data["sceneids"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "scope", data=data["scope"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "video_flags", + data=data["video_flags"], + compression="lzf", + chunks=True, + ) + h5f.create_dataset( + "groups", data=data["groups"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "images", data=data["images"], compression="lzf", chunks=True + ) + + def __len__(self): + return len(self.scope) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + scene = self.scenes[self.sceneids[idx]] + if rng.random() < 0.6: + ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1]) + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_ids = ids[: len(ids) - cut_off + 1] + start_id = rng.choice(start_ids) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + ids.tolist(), + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + ids = np.array(ids)[pos] + image_idxs = self.images[ids] + else: + ordered_video = False + seq_start_index = self.scope[idx] + seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None + image_idxs = ( + self.groups[seq_start_index:seq_end_index] + if seq_end_index is not None + else self.groups[seq_start_index:] + ) + image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2] + replace = ( + True + if self.allow_repeat + or len(overlap_scores[overlap_scores > 0]) < num_views + else False + ) + image_idxs = rng.choice( + image_idxs, + num_views, + replace=replace, + p=overlap_scores / np.sum(overlap_scores), + ) + image_idxs = image_idxs.astype(np.int64) + + views = [] + for v, view_idx in enumerate(image_idxs): + img_path = self.imgid2path(view_idx, scene) + depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy") + cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz") + sky_mask_path = img_path.replace("rgb", "sky_mask") + image = imread_cv2(img_path) + depthmap = np.load(depth_path) + camera_params = np.load(cam_path) + sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127 + + intrinsics = camera_params["intrinsic"].astype(np.float32) + camera_pose = camera_params["pose"].astype(np.float32) + + depthmap[sky_mask] = -1.0 + depthmap[depthmap > 400.0] = 0.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(img_path) + ) + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="MapFree", + label=img_path, + is_metric=self.is_metric, + instance=img_path, + is_video=ordered_video, + quantile=np.array(0.96, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..09f1b1a85364a8de08813396d76762a2f8f2c966 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py @@ -0,0 +1,144 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_pil + + +class MVS_Synth_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 4 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + print('DATA: mvs_synth', len(self)) + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + num_imgs = len(basenames) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + depthmap[depthmap > 1000] = 0.0 + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.8, 0.15, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="MVS_Synth", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".jpg"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8e1019c94e30c70dd1d9dd2d50ff9dee46b924 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py @@ -0,0 +1,146 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys +import json + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 +import re + + +def extract_number(filename): + match = re.search(r"\d+", filename) + if match: + return int(match.group()) + return 0 + + +class OmniObject3D_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False # True + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) and not d.startswith('.') + ] + with open(os.path.join(self.ROOT, "scale.json"), "r") as f: + self.scales = json.load(f) + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=extract_number, + ) + + num_imgs = len(basenames) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + scene, start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=100, video_prob=0.0 + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + scale = self.scales[self.scenes[scene_id]] + depthmap = depthmap / scale / 1000.0 + camera_pose[:3, 3] = camera_pose[:3, 3] / scale / 1000.0 + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.8, 0.15, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="OmniObject3D", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py new file mode 100644 index 0000000000000000000000000000000000000000..9ced302f1bdaed09fc2294fd6c3a7dd8e248f964 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py @@ -0,0 +1,178 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class PointOdyssey_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 4 + super().__init__(*args, **kwargs) + assert self.split in ["train", "test", "val"] + self.scenes_to_use = [ + # 'cab_h_bench_3rd', 'cab_h_bench_ego1', 'cab_h_bench_ego2', + "cnb_dlab_0215_3rd", + "cnb_dlab_0215_ego1", + "cnb_dlab_0225_3rd", + "cnb_dlab_0225_ego1", + "dancing", + "dancingroom0_3rd", + "footlab_3rd", + "footlab_ego1", + "footlab_ego2", + "girl", + "girl_egocentric", + "human_egocentric", + "human_in_scene", + "human_in_scene1", + "kg", + "kg_ego1", + "kg_ego2", + "kitchen_gfloor", + "kitchen_gfloor_ego1", + "kitchen_gfloor_ego2", + "scene_carb_h_tables", + "scene_carb_h_tables_ego1", + "scene_carb_h_tables_ego2", + "scene_j716_3rd", + "scene_j716_ego1", + "scene_j716_ego2", + "scene_recording_20210910_S05_S06_0_3rd", + "scene_recording_20210910_S05_S06_0_ego2", + "scene1_0129", + "scene1_0129_ego", + "seminar_h52_3rd", + "seminar_h52_ego1", + "seminar_h52_ego2", + ] + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + root = os.path.join(self.ROOT, split) + self.scenes = [] + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(os.listdir(root)): + if scene not in self.scenes_to_use: + continue + scene_dir = osp.join(root, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + # start_img_ids_ = img_ids[:-self.num_views+1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap > 1000] = 0.0 + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.9, 0.05, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="PointOdyssey", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".jpg"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py new file mode 100644 index 0000000000000000000000000000000000000000..34526946529905640be4ee49d0530b950bafdb04 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py @@ -0,0 +1,139 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class RE10K_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 128 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: int(x), + ) + + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + self.invalid_scenes = {scene: False for scene in self.scenes} + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene, start_id = self.start_img_ids[idx] + + while invalid_seq: + while self.invalid_scenes[scene]: + idx = rng.integers(low=0, high=len(self.start_img_ids)) + scene, start_id = self.start_img_ids[idx] + + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + try: + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap, no depth, set to all ones + depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32) + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = cam["pose"] + except: + print(f"Error loading {scene} {basename}, skipping") + self.invalid_scenes[scene] = True + break + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="realestate10k", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6644615d2e9761a2a3cec8178a22be5f316afa --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py @@ -0,0 +1,149 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2, imread_pil + + +class ScanNet_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 30 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data(self.split) + print('DATA: scannet', len(self)) + + def _load_data(self, split): + self.scene_root = osp.join( + self.ROOT, "scans_train" if split == "train" else "scans_test" + ) + self.scenes = [ + scene for scene in os.listdir(self.scene_root) if scene.startswith("scene") + ] + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.scene_root, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + basenames = data["images"] + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.6, + fix_interval_prob=0.6, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.scene_root, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "color") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(depth_dir, basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..4cca5c0ade3ccf79f97f31e5f30a823e032152c6 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py @@ -0,0 +1,211 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2, imread_pil + + +class ScanNetpp_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 3 + super().__init__(*args, **kwargs) + assert self.split == "train" + self.loaded_data = self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, "all_metadata.npz")) as data: + self.scenes = data["scenes"] + offset = 0 + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + groups = [] + id_ranges = [] + j = 0 + self.image_num = 0 + for scene in self.scenes: + scene_dir = osp.join(self.ROOT, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + imgs = data["images"] + self.image_num += len(imgs) + img_ids = np.arange(len(imgs)).tolist() + intrins = data["intrinsics"] + traj = data["trajectories"] + imgs_on_disk = sorted(os.listdir(osp.join(scene_dir, "images"))) + imgs_on_disk = list(map(lambda x: x[:-4], imgs_on_disk)) + + dslr_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("DSC") and imgs[i] in imgs_on_disk + ] + iphone_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("frame") and imgs[i] in imgs_on_disk + ] + + num_imgs = len(imgs) + assert max(dslr_ids) < min(iphone_ids) + assert "image_collection" in data + + img_groups = [] + img_id_ranges = [] + + # 使用与其他数据集一致的 cut_off 逻辑 + min_group_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + + for ref_id, group in data["image_collection"].item().items(): + if len(group) + 1 < min_group_len: + continue + group.insert(0, (ref_id, 1.0)) + sorted_group = sorted(group, key=lambda x: x[1], reverse=True) + group = [int(x[0] + offset) for x in sorted_group] + + # 确定对应的视频帧列表 + if imgs[ref_id].startswith("frame"): + video_ids = dslr_ids + else: + video_ids = iphone_ids + + # 只有当视频帧列表足够长时才添加 + if len(video_ids) >= min_group_len: + img_groups.append(sorted(group)) + img_id_ranges.append(video_ids) + + if len(img_groups) == 0: + print(f"Skipping {scene}") + continue + scenes.append(scene) + sceneids.extend([j] * num_imgs) + images.extend(imgs) + intrinsics.append(intrins) + trajectories.append(traj) + + # offset groups + groups.extend(img_groups) + id_ranges.extend(img_id_ranges) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.intrinsics = np.concatenate(intrinsics, axis=0) + self.trajectories = np.concatenate(trajectories, axis=0) + self.id_ranges = id_ranges + self.groups = groups + + def __len__(self): + return len(self.groups) * 10 + + def get_image_num(self): + return self.image_num + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + image_idxs = self.groups[idx] + rand_val = rng.random() + + image_idxs_video = self.id_ranges[idx] + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_image_idxs = image_idxs_video[: len(image_idxs_video) - cut_off + 1] + + if rand_val < 0.7 and len(start_image_idxs) > 0: + start_id = rng.choice(start_image_idxs) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + image_idxs_video, + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + image_idxs = np.array(image_idxs_video)[pos] + + else: + ordered_video = True + # ordered video with varying intervals + num_candidates = len(image_idxs) + max_id = min(num_candidates, int(num_views * (2 + 2 * rng.random()))) + + # 确保有足够的候选帧 + if num_candidates < num_views: + # 如果候选帧不足,使用重复采样 + image_idxs = sorted(rng.choice(image_idxs, size=num_views, replace=True)) + else: + image_idxs = sorted(rng.permutation(image_idxs[:max_id])[:num_views]) + + if rand_val > 0.75: + ordered_video = False + image_idxs = rng.permutation(image_idxs) + + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_pil(osp.join(scene_dir, "images", basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "depth", basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet++", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.99, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py new file mode 100644 index 0000000000000000000000000000000000000000..a5955aecd651f2bf1f6a666b0869b5d97816cf5f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py @@ -0,0 +1,85 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class SmartPortraits_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[ + "intrinsics" + ] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="SmartPortraits", + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py new file mode 100644 index 0000000000000000000000000000000000000000..face09abd00f76cd62e7654b1b673e9d1d3394b7 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py @@ -0,0 +1,111 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class ThreeDKenBurns(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + + num_imgs = len(basenames) + img_ids_ = list(np.arange(num_imgs) + offset) + + img_ids.extend(img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.img_ids = img_ids + + def __len__(self): + return len(self.img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + image_idxs = new_rng.choice(self.img_ids, num_views, replace=False) + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + depthmap = imread_cv2(osp.join(depth_dir, basename + ".exr")) + depthmap[depthmap > 20000] = 0.0 + depthmap = depthmap / 1000.0 + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="3DKenBurns", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9092928daacf527c99e1958bbee85ef9110035 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py @@ -0,0 +1,159 @@ +import os.path as osp +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + +R_conv = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).astype( + np.float32 +) + + +class UnReal4K_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.max_interval = 2 + self.is_metric = True + super().__init__(*args, **kwargs) + # loading all + assert self.split is None + self._load_data() + + def _load_data(self): + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + j = 0 + + seq_dirs = sorted( + [ + os.path.join(self.ROOT, scene, mode) + for scene in scene_dirs + for mode in ["0", "1"] + ] + ) + for seq_dir in seq_dirs: + basenames = sorted( + [f[:-8] for f in os.listdir(seq_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + # start_img_ids_ = img_ids[:-self.num_views+1] + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {seq_dir}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(seq_dir) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) * 10 + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)//10} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = self.scenes[scene_id] + basename = self.images[view_idx] + + img = basename + "_rgb.png" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy")) + camera_params = np.load(osp.join(scene_dir, basename + ".npz")) + + intrinsics = camera_params["intrinsics"].astype(np.float32) + camera_pose = camera_params["cam2world"].astype(np.float32) + + camera_pose = R_conv @ camera_pose + + sky_mask = depthmap >= 1000 + depthmap[sky_mask] = -1.0 # sky + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="UnReal4K", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..a0413d4cc035f21acd9b02fb2bccebe36ab57736 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py @@ -0,0 +1,129 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import numpy as np +from dust3r.utils.device import to_numpy +from dust3r.utils.geometry import inv, geotrf + + +def reproject_view(pts3d, view2): + shape = view2["pts3d"].shape[:2] + return reproject( + pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape + ) + + +def reproject(pts3d, K, world2cam, shape): + H, W, THREE = pts3d.shape + assert THREE == 3 + + # reproject in camera2 space + with np.errstate(divide="ignore", invalid="ignore"): + pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) + + # quantize to pixel positions + return (H, W), ravel_xy(pos, shape) + + +def ravel_xy(pos, shape): + H, W = shape + with np.errstate(invalid="ignore"): + qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T + quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip( + min=0, max=H - 1, out=qy + ) + return quantized_pos + + +def unravel_xy(pos, shape): + # convert (x+W*y) back to 2d (x,y) coordinates + return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() + + +def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): + is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)) + pos1 = is_reciprocal1.nonzero()[0] + pos2 = corres_1_to_2[pos1] + if ret_recip: + return is_reciprocal1, pos1, pos2 + return pos1, pos2 + + +def extract_correspondences_from_pts3d( + view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0 +): + view1, view2 = to_numpy((view1, view2)) + # project pixels from image1 --> 3d points --> image2 pixels + shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2) + shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1) + + # compute reciprocal correspondences: + # pos1 == valid pixels (correspondences) in image1 + is_reciprocal1, pos1, pos2 = reciprocal_1d( + corres1_to_2, corres2_to_1, ret_recip=True + ) + is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)) + + if target_n_corres is None: + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2 + + available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) + target_n_positives = int(target_n_corres * (1 - nneg)) + n_positives = min(len(pos1), target_n_positives) + n_negatives = min(target_n_corres - n_positives, available_negatives) + + if n_negatives + n_positives != target_n_corres: + # should be really rare => when there are not enough negatives + # in that case, break nneg and add a few more positives ? + n_positives = target_n_corres - n_negatives + assert n_positives <= len(pos1) + + assert n_positives <= len(pos1) + assert n_positives <= len(pos2) + assert n_negatives <= (~is_reciprocal1).sum() + assert n_negatives <= (~is_reciprocal2).sum() + assert n_positives + n_negatives == target_n_corres + + valid = np.ones(n_positives, dtype=bool) + if n_positives < len(pos1): + # random sub-sampling of valid correspondences + perm = rng.permutation(len(pos1))[:n_positives] + pos1 = pos1[perm] + pos2 = pos2[perm] + + if n_negatives > 0: + # add false correspondences if not enough + def norm(p): + return p / p.sum() + + pos1 = np.r_[ + pos1, + rng.choice( + shape1[0] * shape1[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal1), + ), + ] + pos2 = np.r_[ + pos2, + rng.choice( + shape2[0] * shape2[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal2), + ), + ] + valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] + + # convert (x+W*y) back to 2d (x,y) coordinates + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2, valid diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..6074f0d93b54ef5af36189276e0f179825a525fe --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py @@ -0,0 +1,147 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from dust3r.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) # noqa + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """Convenience class to aply the same operation to a whole set of images.""" + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap( + image, depthmap, camera_intrinsics, output_resolution, force=True +): + """Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize( + output_resolution, resample=lanczos if scale_final < 1 else bicubic + ) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final + ) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out( + input_camera_matrix, output_camera_matrix, output_resolution +): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..39a4450e57e3482315e307e72c0f3b19e77dea3b --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py @@ -0,0 +1,80 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + + +def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"values should be between {bound}, but got {value}.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + +import torch +import torchvision.transforms.functional as F + + +def SeqColorJitter(): + """ + Return a color jitter transform with same random parameters + """ + brightness = _check_input(0.5) + contrast = _check_input(0.5) + saturation = _check_input(0.5) + hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + fn_idx = torch.randperm(4) + brightness_factor = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + contrast_factor = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + saturation_factor = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) + hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + def _color_jitter(img): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return ImgNorm(img) + + return _color_jitter diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f811f144c638b931cb99fd246702a0fa2d18e7 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py @@ -0,0 +1,178 @@ +import os.path as osp +import os +import numpy as np +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import h5py +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class Waymo_Multi(BaseMultiViewDataset): + """Dataset of outdoor street scenes, 5 images each time""" + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.max_interval = 8 + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + assert self.split is None + self._load_data() + + def load_invalid_dict(self, h5_file_path): + invalid_dict = {} + with h5py.File(h5_file_path, "r") as h5f: + for scene in h5f: + data = h5f[scene]["invalid_pairs"][:] + invalid_pairs = set( + tuple(pair.decode("utf-8").split("_")) for pair in data + ) + invalid_dict[scene] = invalid_pairs + return invalid_dict + + def _load_data(self): + invalid_dict = self.load_invalid_dict( + os.path.join(self.ROOT, "invalid_files.h5") + ) + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + is_video = [] + j = 0 + + for scene in scene_dirs: + scene_dir = osp.join(self.ROOT, scene) + invalid_pairs = invalid_dict.get(scene, set()) + seq2frames = {} + for f in os.listdir(scene_dir): + if not f.endswith(".jpg"): + continue + basename = f[:-4] + frame_id = basename.split("_")[0] + seq_id = basename.split("_")[1] + if seq_id == "5": + continue + if (seq_id, frame_id) in invalid_pairs: + continue # Skip invalid files + if seq_id not in seq2frames: + seq2frames[seq_id] = [] + seq2frames[seq_id].append(frame_id) + + for seq_id, frame_ids in seq2frames.items(): + frame_ids = sorted(frame_ids) + num_imgs = len(frame_ids) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}_{seq_id}") + continue + + scenes.append((scene, seq_id)) + sceneids.extend([j] * num_imgs) + images.extend(frame_ids) + start_img_ids.extend(start_img_ids_) + scene_img_list.append(img_ids) + + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + self.is_video = is_video + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + _, seq_id = self.scenes[self.sceneids[start_id]] + max_interval = self.max_interval // 2 if seq_id == "4" else self.max_interval + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=max_interval, + video_prob=0.9, + fix_interval_prob=0.9, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + views = [] + ordered_video = True + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir, seq_id = self.scenes[scene_id] + scene_dir = osp.join(self.ROOT, scene_dir) + frame_id = self.images[view_idx] + + impath = f"{frame_id}_{seq_id}" + image = imread_cv2(osp.join(scene_dir, impath + ".jpg")) + depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr")) + camera_params = np.load(osp.join(scene_dir, impath + ".npz")) + + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.float32(camera_params["cam2world"]) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="Waymo", + label=osp.relpath(scene_dir, self.ROOT), + is_metric=self.is_metric, + instance=osp.join(scene_dir, impath + ".jpg"), + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba152e19b9dae9e3ddd254d632f19d779ccffbe --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py @@ -0,0 +1,56 @@ +import os.path as osp +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from dust3r.datasets.co3d import Co3d_Multi +from dust3r.utils.image import imread_cv2 + + +class WildRGBD_Multi(Co3d_Multi): + def __init__(self, mask_bg="rand", *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = "WildRGBD" + self.is_metric = True + # load all scenes + self.scenes.pop(("box", "scenes/scene_257"), None) + self.scene_list = list(self.scenes.keys()) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + self.cut_off = cut_off + self.all_ref_imgs = [ + (key, value) + for key, values in self.scenes.items() + for value in values[: len(values) - cut_off + 1] + ] + self.invalidate = {scene: {} for scene in self.scene_list} + self.invalid_scenes = {scene: False for scene in self.scene_list} + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "metadata", f"{view_idx:0>5d}.npz") + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "rgb", f"{view_idx:0>5d}.jpg") + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "depth", f"{view_idx:0>5d}.png") + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "masks", f"{view_idx:0>5d}.png") + + def _read_depthmap(self, depthpath, input_metadata): + # We store depths in the depth scale of 1000. + # That is, when we load depth image and divide by 1000, we could get depth in meters. + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000.0 + return depthmap + + def _get_views(self, idx, resolution, rng, num_views): + views = super()._get_views(idx, resolution, rng, num_views) + for view in views: + assert view["is_metric"] + view["quantile"] = np.array(0.96, dtype=np.float32) + return views diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..a76b52fcae78a004f74ae4fc1a4c187b743c5e57 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py @@ -0,0 +1,463 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from croco.models.blocks import Mlp + + +inf = float("inf") + + +class PoseDecoder(nn.Module): + def __init__( + self, + hidden_size=768, + mlp_ratio=4, + pose_encoding_type="absT_quaR", + ): + super().__init__() + + self.pose_encoding_type = pose_encoding_type + if self.pose_encoding_type == "absT_quaR": + self.target_dim = 7 + + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=int(hidden_size * mlp_ratio), + out_features=self.target_dim, + drop=0, + ) + + def forward( + self, + pose_feat, + ): + """ + pose_feat: BxC + preliminary_cameras: cameras in opencv coordinate. + """ + + pred_cameras = self.mlp(pose_feat) # Bx7, 3 for absT, 4 for quaR + return pred_cameras + + +class PoseEncoder(nn.Module): + def __init__( + self, + hidden_size=768, + mlp_ratio=4, + pose_mode=("exp", -inf, inf), + pose_encoding_type="absT_quaR", + ): + super().__init__() + self.pose_encoding_type = pose_encoding_type + self.pose_mode = pose_mode + + if self.pose_encoding_type == "absT_quaR": + self.target_dim = 7 + + self.embed_pose = PoseEmbedding( + target_dim=self.target_dim, + out_dim=hidden_size, + n_harmonic_functions=10, + append_input=True, + ) + self.pose_encoder = Mlp( + in_features=self.embed_pose.out_dim, + hidden_features=int(hidden_size * mlp_ratio), + out_features=hidden_size, + drop=0, + ) + + def forward(self, camera): + from dust3r.heads.postprocess import postprocess_pose + pose_enc = camera_to_pose_encoding( + camera, + pose_encoding_type=self.pose_encoding_type, + ).to(camera.dtype) + pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True) + pose_feat = self.embed_pose(pose_enc) + pose_feat = self.pose_encoder(pose_feat) + return pose_feat + + +class HarmonicEmbedding(torch.nn.Module): + def __init__( + self, + n_harmonic_functions: int = 6, + omega_0: float = 1.0, + logspace: bool = True, + append_input: bool = True, + ) -> None: + """ + The harmonic embedding layer supports the classical + Nerf positional encoding described in + `NeRF `_ + and the integrated position encoding in + `MIP-NeRF `_. + + During the inference you can provide the extra argument `diag_cov`. + + If `diag_cov is None`, it converts + rays parametrized with a `ray_bundle` to 3D points by + extending each ray according to the corresponding length. + Then it converts each feature + (i.e. vector along the last dimension) in `x` + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i], # only present if append_input is True. + ] + + where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. + + + If `diag_cov is not None`, it approximates + conical frustums following a ray bundle as gaussians, + defined by x, the means of the gaussians and diag_cov, + the diagonal covariances. + Then it converts each gaussian + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]), + ... + sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),, + ... + cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + x[..., i], # only present if append_input is True. + ] + + where N equals `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. + + If `logspace==True`, the frequencies `[f_1, ..., f_N]` are + powers of 2: + `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` + + If `logspace==False`, frequencies are linearly spaced between + `1.0` and `2**(n_harmonic_functions-1)`: + `f_1, ..., f_N = torch.linspace( + 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions + )` + + Note that `x` is also premultiplied by the base frequency `omega_0` + before evaluating the harmonic functions. + + Args: + n_harmonic_functions: int, number of harmonic + features + omega_0: float, base frequency + logspace: bool, Whether to space the frequencies in + logspace or linear space + append_input: bool, whether to concat the original + input to the harmonic embedding. If true the + output is of the form (embed.sin(), embed.cos(), x) + """ + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (n_harmonic_functions - 1), + n_harmonic_functions, + dtype=torch.float32, + ) + + self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) + self.register_buffer( + "_zero_half_pi", + torch.tensor([0.0, 0.5 * torch.pi]), + persistent=False, + ) + self.append_input = append_input + + def forward( + self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + """ + Args: + x: tensor of shape [..., dim] + diag_cov: An optional tensor of shape `(..., dim)` + representing the diagonal covariance matrices of our Gaussians, joined with x + as means of the Gaussians. + + Returns: + embedding: a harmonic embedding of `x` of shape + [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray] + """ + + embed = x[..., None] * self._frequencies + + embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None] + + embed = embed.sin() + if diag_cov is not None: + x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2) + exp_var = torch.exp(-0.5 * x_var) + + embed = embed * exp_var[..., None, :, :] + + embed = embed.reshape(*x.shape[:-1], -1) + + if self.append_input: + return torch.cat([embed, x], dim=-1) + return embed + + @staticmethod + def get_output_dim_static( + input_dims: int, n_harmonic_functions: int, append_input: bool + ) -> int: + """ + Utility to help predict the shape of the output of `forward`. + + Args: + input_dims: length of the last dimension of the input tensor + n_harmonic_functions: number of embedding frequencies + append_input: whether or not to concat the original + input to the harmonic embedding + Returns: + int: the length of the last dimension of the output tensor + """ + return input_dims * (2 * n_harmonic_functions + int(append_input)) + + def get_output_dim(self, input_dims: int = 3) -> int: + """ + Same as above. The default for input_dims is 3 for 3D applications + which use harmonic embedding for positional encoding, + so the input might be xyz. + """ + return self.get_output_dim_static( + input_dims, len(self._frequencies), self.append_input + ) + + +class PoseEmbedding(nn.Module): + def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True): + super().__init__() + + self._emb_pose = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, append_input=append_input + ) + + self.out_dim = self._emb_pose.get_output_dim(target_dim) + + def forward(self, pose_encoding): + e_pose_encoding = self._emb_pose(pose_encoding) + return e_pose_encoding + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + quaternions = F.normalize(quaternions, p=2, dim=-1) + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def camera_to_pose_encoding( + camera, + pose_encoding_type="absT_quaR", +): + """ + Inverse to pose_encoding_to_camera + camera: opencv, cam2world + """ + if pose_encoding_type == "absT_quaR": + + quaternion_R = matrix_to_quaternion(camera[:, :3, :3]) + + pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + return pose_encoding + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def pose_encoding_to_camera( + pose_encoding, + pose_encoding_type="absT_quaR", +): + """ + Args: + pose_encoding: A tensor of shape `BxC`, containing a batch of + `B` `C`-dimensional pose encodings. + pose_encoding_type: The type of pose encoding, + """ + + if pose_encoding_type == "absT_quaR": + + abs_T = pose_encoding[:, :3] + quaternion_R = pose_encoding[:, 3:7] + R = quaternion_to_matrix(quaternion_R) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device) + c2w_mats = c2w_mats[None].repeat(len(R), 1, 1) + c2w_mats[:, :3, :3] = R + c2w_mats[:, :3, 3] = abs_T + + return c2w_mats + + +def quaternion_conjugate(q): + """Compute the conjugate of quaternion q (w, x, y, z).""" + + q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) + return q_conj + + +def quaternion_multiply(q1, q2): + """Multiply two quaternions q1 and q2.""" + w1, x1, y1, z1 = q1.unbind(dim=-1) + w2, x2, y2, z2 = q2.unbind(dim=-1) + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + + return torch.stack((w, x, y, z), dim=-1) + + +def rotate_vector(q, v): + """Rotate vector v by quaternion q.""" + q_vec = q[..., 1:] + q_w = q[..., :1] + + t = 2.0 * torch.cross(q_vec, v, dim=-1) + v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1) + return v_rot + + +def relative_pose_absT_quatR(t1, q1, t2, q2): + """Compute the relative translation and quaternion between two poses.""" + + q1_inv = quaternion_conjugate(q1) + + q_rel = quaternion_multiply(q1_inv, q2) + + delta_t = t2 - t1 + t_rel = rotate_vector(q1_inv, delta_t) + return t_rel, q_rel diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5e8a44a0e634b4590695063f028847818bf12f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py @@ -0,0 +1,88 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): + return todevice(x, "numpy") + + +def to_cpu(x): + return todevice(x, "cpu") + + +def to_cuda(x): + return todevice(x, "cuda") + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return { + k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem + } + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return ( + listify(whatever) + if lists + else torch.cat([torch.from_numpy(x) for x in whatever]) + ) + + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4ab6a9338d112a5e0e27a1d249bd9be0f0c282 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py @@ -0,0 +1,554 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import torch +import numpy as np +from scipy.spatial import cKDTree as KDTree + +from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans +from dust3r.utils.device import to_numpy + + +def xy_grid( + W, + H, + device=None, + origin=(0, 0), + unsqueeze=None, + cat_dim=-1, + homogeneous=False, + **arange_kw, +): + """Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + if ( + isinstance(Trf, torch.Tensor) + and isinstance(pts, torch.Tensor) + and Trf.ndim == 3 + and pts.ndim == 4 + ): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = ( + torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + + Trf[:, None, None, :d, d] + ) + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """Invert a torch or numpy matrix""" + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + valid_mask = depthmap > 0.0 + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + X_world = ( + np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + ) + + return X_world, X_cam, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud( + pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None, ret_factor=False +): + """renorm pointmaps pts1, pts2 with norm_mode""" + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + if norm_mode == "avg": + + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = ( + invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + ) + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + + if norm_mode == "avg": + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == "median": + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == "sqrt": + norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2 + else: + raise ValueError(f"bad {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + if ret_factor: + res = res + (norm_factor,) + return res + + +def normalize_pointcloud_group( + pts_list, + norm_mode="avg_dis", + valid_list=None, + conf_list=None, + ret_factor=False, + ret_factor_only=False, +): + """renorm pointmaps pts1, pts2 with norm_mode""" + for pts in pts_list: + assert pts.ndim >= 3 and pts.shape[-1] == 3 + + norm_mode, dis_mode = norm_mode.split("_") + + if norm_mode == "avg": + + nan_pts_list, nnz_list = zip( + *[ + invalid_to_zeros(pts1, valid1, ndim=3) + for pts1, valid1 in zip(pts_list, valid_list) + ] + ) + all_pts = torch.cat(nan_pts_list, dim=1) + if conf_list is not None: + nan_conf_list = [ + invalid_to_zeros(conf1[..., None], valid1, ndim=3)[0] + for conf1, valid1 in zip(conf_list, valid_list) + ] + all_conf = torch.cat(nan_conf_list, dim=1)[..., 0] + else: + all_conf = torch.ones_like(all_pts[..., 0]) + + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H_W_list = [pts.shape[1:-1] for pts in pts_list] + pts_list = [ + pts + * warp_factor[:, sum(H_W_list[:i]) : sum(H_W_list[: i + 1])].view( + -1, H, W, 1 + ) + for i, (pts, (H, W)) in enumerate(zip(pts_list, H_W_list)) + ] + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = (all_conf * all_dis).sum(dim=1) / (all_conf.sum(dim=1) + 1e-8) + else: + + nan_pts_list = [ + invalid_to_nans(pts1, valid1, ndim=3) + for pts1, valid1 in zip(pts_list, valid_list) + ] + + all_pts = torch.cat(nan_pts_list, dim=1) + + all_dis = all_pts.norm(dim=-1) + + if norm_mode == "avg": + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == "median": + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == "sqrt": + norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2 + else: + raise ValueError(f"bad {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts_list[0].ndim: + norm_factor.unsqueeze_(-1) + + if ret_factor_only: + + return norm_factor + + res = [pts / norm_factor for pts in pts_list] + if ret_factor: + return res, norm_factor + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = ( + invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) + if z2 is not None + else None + ) + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_group_pointcloud_depth(zs, valid_masks, quantile=0.5): + + _zs = [ + invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + for z1, valid_mask1 in zip(zs, valid_masks) + ] + _z = torch.cat(_zs, dim=-1) + + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale( + pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True +): + + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = ( + invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) + if pts2 is not None + else None + ) + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +@torch.no_grad() +def get_group_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True): + + _pts = [ + invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + for pts1, valid_mask1 in zip(pts, valid_masks) + ] + _pts = torch.cat(_pts, dim=1) + + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)) + reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) + + +def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False): + """ + X: torch tensor B x N x 3 + Y: torch tensor B x N x 3 + w: torch tensor B x N + """ + assert len(A) == len(B) + if use_weights: + W1 = torch.abs(w).sum(1, keepdim=True) + w_norm = (w / (W1 + eps)).unsqueeze(-1) + a_mean = (w_norm * A).sum(dim=1, keepdim=True) + b_mean = (w_norm * B).sum(dim=1, keepdim=True) + + A_c = A - a_mean + B_c = B - b_mean + + H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c) + + else: + a_mean = A.mean(axis=1, keepdim=True) + b_mean = B.mean(axis=1, keepdim=True) + + A_c = A - a_mean + B_c = B - b_mean + + H = torch.einsum("bij,bik->bjk", A_c, B_c) + + U, S, V = torch.svd(H) # U: B x 3 x 3, S: B x 3, V: B x 3 x 3 + Z = torch.eye(3).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device) + Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2))) # B x 3 x 3 + R = V @ Z @ U.transpose(1, 2) # B x 3 x 3 + t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose( + -2, -1 + ) + if return_T: + T = torch.eye(4).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device) + T[:, :3, :3] = R + T[:, :3, 3] = t.squeeze() + return T + return R, t.squeeze() diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..60feb3048d342bf1e82483cbd37c57d6efd7fff3 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py @@ -0,0 +1,271 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import os +import torch +import numpy as np +import PIL.Image +from PIL.ImageOps import exif_transpose +import torchvision.transforms as tvf + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa + +try: + from pillow_heif import register_heif_opener # noqa + + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def img_to_arr(img): + if isinstance(img, str): + img = imread_cv2(img) + return img + + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def imread_pil(path): + """Open an RGB image using PIL and return as numpy array.""" + img = PIL.Image.open(path) + img = exif_transpose(img) + img = img.convert("RGB") + return np.array(img) + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images(folder_or_list, size, square_ok=False, verbose=True): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png", ".bmp"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + W1, H1 = img.size + if size == 224: + + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: + halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8 + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs + + +def load_images_for_eval( + folder_or_list, size, square_ok=False, verbose=True, crop=True +): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + if crop: + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: # resize + img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS) + else: + halfw, halfh = ((2 * cx) // 14) * 7, ((2 * cy) // 14) * 7 + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + if crop: + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + else: # resize + img = img.resize((2 * halfw, 2 * halfh), PIL.Image.LANCZOS) + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs + + +def load_images_512(folder_or_list, size, square_ok=False, verbose=True): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png", ".bmp"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + img = img.resize((512, 384)) + W1, H1 = img.size + if size == 224: + + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: + halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8 + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb3f225ba3b0a007541eb81362cd58e1c54d916 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py @@ -0,0 +1,127 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import torch + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1["instance"] + y = gt2["instance"] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """flip so that tensor[0::2] <=> tensor[1::2]""" + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + + def wrapper_no(decout, true_shape, **kwargs): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical" + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W), **kwargs) + return res + + def wrapper_yes(decout, true_shape, **kwargs): + B = len(true_shape) + + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + if is_landscape.all(): + return head(decout, (H, W), **kwargs) + if is_portrait.all(): + return transposed(head(decout, (W, H), **kwargs)) + + def selout(ar): + return [d[ar] for d in decout] + + if "pos" in kwargs: + kwargs_landscape = kwargs.copy() + kwargs_landscape["pos"] = kwargs["pos"][is_landscape] + kwargs_portrait = kwargs.copy() + kwargs_portrait["pos"] = kwargs["pos"][is_portrait] + l_result = head(selout(is_landscape), (H, W), **kwargs_landscape) + p_result = transposed(head(selout(is_portrait), (W, H), **kwargs_portrait)) + + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) if v.ndim > 2 else v for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float("nan") + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5082a85b8c66cdcddc7402c401c0c983c5f1078b --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py @@ -0,0 +1,87 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads( + function, + args, + workers=0, + star_args=False, + kw_args=False, + front_num=1, + Pool=ThreadPool, + **tqdm_kw +): + """tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float("inf") + + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append( + function(*a) if star_args else function(**a) if kw_args else function(a) + ) + + out = [] + with Pool(workers) as pool: + + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """Same as parallel_threads, with processes""" + import multiprocessing as mp + + kwargs["Pool"] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(*args) + + +def starstarcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(**args) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000000000000000000000000000000000000..108b532b440b49dd5c9f77eac86ec3562cb5c1e8 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py @@ -0,0 +1,47 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import sys +import os.path as path +import importlib + +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco")) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models") +# IMPORTANT: +# Do NOT add `.../src/croco` directly to sys.path, otherwise subfolders like +# `croco/datasets` become a top-level module named `datasets`, which will shadow +# HuggingFace `datasets` and break `accelerate` (and others). +# Instead, add `.../src` so we import as `croco.*`. +SRC_PATH = path.normpath(path.join(HERE_PATH, "../../..")) + +if path.isdir(CROCO_MODELS_PATH): + + # Prefer adding the `src` directory; this enables `import croco...` without + # polluting top-level module names. + if SRC_PATH not in sys.path: + sys.path.insert(0, SRC_PATH) + + # In case an old run already inserted CROCO_REPO_PATH, remove it to avoid + # shadowing top-level modules (e.g., `datasets`). + while CROCO_REPO_PATH in sys.path: + sys.path.remove(CROCO_REPO_PATH) + + # Backward-compat: DUSt3R code expects `models.*` to exist as a top-level package + # (historically achieved by adding CROCO_REPO_PATH to sys.path). We keep that + # import path working by aliasing `croco.models` to `models` without exposing + # other top-level names like `datasets`. + try: + _croco_models = importlib.import_module("croco.models") + sys.modules.setdefault("models", _croco_models) + except Exception: + # If croco isn't importable yet, downstream import will raise a clearer error. + pass +else: + raise ImportError( + f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?" + ) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py new file mode 100644 index 0000000000000000000000000000000000000000..bc61fa8993396c9cd850177c288eb2a798561333 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py @@ -0,0 +1,75 @@ +import torch +from gsplat import rasterization +from dust3r.utils.geometry import inv, geotrf + + +def render( + intrinsics: torch.Tensor, + pts3d: torch.Tensor, + rgbs: torch.Tensor | None = None, + scale: float = 0.002, + opacity: float = 0.95, +): + + device = pts3d.device + batch_size = len(intrinsics) + img_size = pts3d.shape[1:3] + pts3d = pts3d.reshape(batch_size, -1, 3) + num_pts = pts3d.shape[1] + quats = torch.randn((num_pts, 4), device=device) + quats = quats / quats.norm(dim=-1, keepdim=True) + scales = scale * torch.ones((num_pts, 3), device=device) + opacities = opacity * torch.ones((num_pts), device=device) + if rgbs is not None: + assert rgbs.shape[1] == 3 + rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2) + else: + rgbs = torch.ones_like(pts3d[:, :, :3]) + + rendered_rgbs = [] + rendered_depths = [] + accs = [] + for i in range(batch_size): + rgbd, acc, _ = rasterization( + pts3d[i], + quats, + scales, + opacities, + rgbs[i], + torch.eye(4, device=device)[None], + intrinsics[[i]], + width=img_size[1], + height=img_size[0], + packed=False, + render_mode="RGB+D", + ) + + rendered_depths.append(rgbd[..., 3]) + + rendered_depths = torch.cat(rendered_depths, dim=0) + + return rendered_rgbs, rendered_depths, accs + + +def get_render_results(gts, preds, self_view=False): + device = preds[0]["pts3d_in_other_view"].device + with torch.no_grad(): + depths = [] + gt_depths = [] + for i, (gt, pred) in enumerate(zip(gts, preds)): + if self_view: + camera = inv(gt["camera_pose"]).to(device) + intrinsics = gt["camera_intrinsics"].to(device) + pred = pred["pts3d_in_other_view"] + else: + camera = inv(gts[0]["camera_pose"]).to(device) + intrinsics = gts[0]["camera_intrinsics"].to(device) + pred = pred["pts3d_in_other_view"] + gt_img = gt["img"].to(device) + gt_pts3d = gt["pts3d"].to(device) + + _, depth, _ = render(intrinsics, pred, gt_img) + _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img) + depths.append(depth) + gt_depths.append(gt_depth) + return depths, gt_depths diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..73f15cfb082d0fe629f8aa312c9d9b27a64ad4e7 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block +from ...layers.attention import FlashAttention + + +# logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + # logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + # logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + attn_class=FlashAttention + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/cluster.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/config.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/dtype.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/param_groups.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/utils.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8842e4145414f6f040c4ae83bf38552de8f65b2 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +# logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/attention.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..538702c02ca4eb6e0768f2fb261e7bf6256d0adf --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/attention.py @@ -0,0 +1,369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch + +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import SDPBackend + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlashAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +""" +Following is written by GPT-4o +""" +class CrossAttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # Separate projection layers for query, key, and value + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + # Scale query + q = q * self.scale + + # Compute attention scores + attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M) + if attn_bias is not None: + attn = attn + attn_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # Compute attention output + x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttentionRope(CrossAttentionRope): + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(query, key, value, attn_bias) + + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + # Compute memory-efficient attention + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + +class AttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.rope = rope + + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix + # global_valid_id = torch.where(score_matrix > 0) + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None, attn_mask=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + if q.dtype == torch.bfloat16: + #with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + else: + with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +def get_attn_score(blk_class, x, frame_num, token_length, xpos=None): + x = blk_class.norm1(x) + + B, N, C = x.shape + qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype) + + if blk_class.attn.rope is not None: + q = blk_class.attn.rope(q, xpos) + k = blk_class.attn.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1) + + return score diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/block.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..4135a9c8622b8f8468e8324ef9e223964ede913f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/block.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope +from ..dinov2.layers.drop_path import DropPath +from ..dinov2.layers.layer_scale import LayerScale +from ..dinov2.layers.mlp import Mlp + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + +class BlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool=False, + rope=None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, xpos=None, N=None, branch=1, global_=False, attn_mask=None, **kwargs) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + if attn_mask is not None: + # Use externally provided mask directly + _attn_mask = attn_mask + elif global_: + if branch == 1: # frontend + B,NP,C = x.shape + S = N + L = NP #S * P + P = int(NP//N) + frame_ids = torch.arange(L, device=x.device) // P # [0,0,...,1,1,...,S-1] + future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0) + future_frame *= (frame_ids.unsqueeze(0)>1) + _attn_mask = future_frame.to(x.dtype) * torch.finfo(x.dtype).min + + elif branch == 2: # backend + _attn_mask = None + elif branch == 3: # mix-mode + B,NP,C = x.shape + S = N + L = NP #S * P + P = int(NP//N) + S_half= int(S//2) + frame_ids = torch.arange(L, device=x.device) // P # [0,0,...,1,1,...,S-1] + future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0) + future_frame *= (frame_ids.unsqueeze(0)>=S_half) + _attn_mask = future_frame.to(x.dtype) * torch.finfo(x.dtype).min + else: + _attn_mask = None + return self.ls1(self.attn(self.norm1(x), xpos=xpos, attn_mask=_attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +class CrossBlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope, + ffn_layer: Callable[..., nn.Module] = Mlp, + init_values=None, + qk_norm: bool=False, + rope=None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm + ) + + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls_y = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm_y = norm_layer(dim) + self.cross_attn = cross_attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm + ) + + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + bias=ffn_bias, + ) + + def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos)) + + def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor: + return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm3(x))) + + x = x + attn_residual_func(x) + y_ = self.norm_y(y) + x = x + cross_attn_residual_func(x, y_) + x = x + ffn_residual_func(x) + + return x diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/camera_head.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7d844f7b76851c3e523e419e18358838e9d23410 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/camera_head.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +from copy import deepcopy +import torch.nn.functional as F + +# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172' +class ResConvBlock(nn.Module): + """ + 1x1 convolution residual block + """ + def __init__(self, in_channels, out_channels): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.head_skip = nn.Identity() if self.in_channels == self.out_channels else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + + # change 1x1 convolution to linear + self.res_conv1 = nn.Linear(self.in_channels, self.out_channels) + self.res_conv2 = nn.Linear(self.out_channels, self.out_channels) + self.res_conv3 = nn.Linear(self.out_channels, self.out_channels) + + def forward(self, res): + x = F.relu(self.res_conv1(res)) + x = F.relu(self.res_conv2(x)) + x = F.relu(self.res_conv3(x)) + res = self.head_skip(res) + x + return res + +class CameraHead(nn.Module): + def __init__(self, dim=512): + super().__init__() + output_dim = dim + self.res_conv = nn.ModuleList([deepcopy(ResConvBlock(output_dim, output_dim)) + for _ in range(2)]) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(output_dim,output_dim), + nn.ReLU(), + nn.Linear(output_dim,output_dim), + nn.ReLU() + ) + self.fc_t = nn.Linear(output_dim, 3) + self.fc_rot = nn.Linear(output_dim, 9) + + def forward(self, feat, patch_h, patch_w): + BN, hw, c = feat.shape + + for i in range(2): + feat = self.res_conv[i](feat) + + # feat = self.avgpool(feat) + feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()) ########## + feat = feat.view(feat.size(0), -1) + + feat = self.more_mlps(feat) # [B, D_] + with torch.amp.autocast(device_type='cuda', enabled=False): + out_t = self.fc_t(feat.float()) # [B,3] + out_r = self.fc_rot(feat.float()) # [B,9] + pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device) + + return pose + + def convert_pose_to_4x4(self, B, out_r, out_t, device): + out_r = self.svd_orthogonalize(out_r) # [N,3,3] + pose = torch.zeros((B, 4, 4), device=device) + pose[:, :3, :3] = out_r + pose[:, :3, 3] = out_t + pose[:, 3, 3] = 1. + return pose + + def svd_orthogonalize(self, m): + """Convert 9D representation to SO(3) using SVD orthogonalization. + + Args: + m: [BATCH, 3, 3] 3x3 matrices. + + Returns: + [BATCH, 3, 3] SO(3) rotation matrices. + """ + if m.dim() < 3: + m = m.reshape((-1, 3, 3)) + m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2) + u, s, v = torch.svd(m_transpose) + det = torch.det(torch.matmul(v, u.transpose(-2, -1))) + # Check orientation reflection. + r = torch.matmul( + torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2), + u.transpose(-2, -1) + ) + return r \ No newline at end of file diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/dpt_head.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e82f9e20e7fd8824888cfb3f1da81977706d40aa --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/dpt_head.py @@ -0,0 +1,415 @@ +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from streamvggt.heads.head_act import activate_head +from streamvggt.heads.utils import create_uv_grid, position_grid_to_embed +import pdb + +class DPTHead(nn.Module): + """ + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [0], #[4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + #self.scratch.refinenet1 = _make_fusion_block(features) + #self.scratch.refinenet2 = _make_fusion_block(features) + #self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + patch_start_idx: int = 0, + frames_chunk_size: int = 8, + shape_: tuple = (1,16,3, 518, 392) + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. + + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = shape_ + + # If frames_chunk_size is not specified or greater than S, process all frames at once + + return self._forward_impl(aggregated_tokens_list, patch_start_idx, shape_ = shape_) + + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + shape_: tuple = (1,16,3, 518, 392) + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + + B, S, _, H, W = shape_ + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + + + x = x.reshape(B * S, -1, x.shape[-1]) + x = self.norm(x) + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. + out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.reshape(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.reshape(B, S, *preds.shape[1:]) + conf = conf.reshape(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_4 = features[0] + + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_4_rn.shape[2:]) + del layer_4_rn, layer_4 + out = self.scratch.output_conv1(out) + return out + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer4_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/pos_embed.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3ceb86b2d3992636a28b1a4abb3f5722dd959e --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/pos_embed.py @@ -0,0 +1,606 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Modified: Added RoPE3D_ChunkAware and PositionGetter3D_ChunkAware for chunk-aware position encoding +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + + + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token>0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:,:,0], cos, sin) + x = self.apply_rope1d(x, positions[:,:,1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens + +# patch embedding +class PositionGetter(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() + return pos + + +#---------------------------------------------------------- +# RoPE3D: RoPE implementation in 3D (ref_id + space) +# Fixed dimension split: ref_id=22, y=21, x=21 (total 64) +#---------------------------------------------------------- + +class RoPE3D(torch.nn.Module): + """ + RoPE implementation in 3D (ref_id, height, width) + + Fixed dimension split for head_dim=64: + - ref_id: 22 dimensions + - y: 21 dimensions + - x: 21 dimensions + + This ensures no dim divisibility issues (64 = 22 + 21 + 21) + """ + + def __init__(self, freq=100.0, F0=1.0, ref_freq=10.0): + super().__init__() + self.base = freq # for spatial dimensions (y, x) + self.ref_base = ref_freq # for ref_id dimension + self.F0 = F0 + self.cache = {} + + # Fixed dimension split + self.D_ref = 22 # ref_id gets 22 dims + self.D_y = 21 # y gets 21 dims + self.D_x = 21 # x gets 21 dims + + def get_cos_sin(self, D, seq_len, device, dtype, base): + key = (D, seq_len, device, dtype, base) + if key not in self.cache: + # For D dimensions, we need D/2 frequencies + # When D is odd, we compute ceil(D/2) frequencies and truncate + half_D = (D + 1) // 2 # ceil(D/2) + inv_freq = 1.0 / (base ** (torch.arange(0, half_D).float().to(device) * 2 / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + # Duplicate and truncate to exactly D dimensions + freqs = torch.cat((freqs, freqs), dim=-1)[:, :D] + cos = freqs.cos() # (Seq, D) + sin = freqs.sin() + self.cache[key] = (cos, sin) + return self.cache[key] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 3 (ref_id, y, x) + output: + * tokens after applying RoPE3D (batch_size x nheads x ntokens x dim) + """ + dim = tokens.size(3) + assert dim == self.D_ref + self.D_y + self.D_x, \ + f"dim {dim} != {self.D_ref} + {self.D_y} + {self.D_x}" + assert positions.ndim == 3 and positions.shape[-1] == 3 # Batch, Seq, 3 + + # Get cos/sin for ref_id dimension + cos_r, sin_r = self.get_cos_sin(self.D_ref, int(positions[:,:,0].max())+1, + tokens.device, tokens.dtype, self.ref_base) + # Get cos/sin for spatial dimensions (y and x have different D) + cos_y, sin_y = self.get_cos_sin(self.D_y, int(positions[:,:,1].max())+1, + tokens.device, tokens.dtype, self.base) + cos_x, sin_x = self.get_cos_sin(self.D_x, int(positions[:,:,2].max())+1, + tokens.device, tokens.dtype, self.base) + + # Split features using fixed dimensions + r = tokens[..., :self.D_ref] # [0:22] + y = tokens[..., self.D_ref:self.D_ref+self.D_y] # [22:43] + x = tokens[..., self.D_ref+self.D_y:] # [43:64] + + # Apply rope1d on each part with corresponding position + r = self.apply_rope1d(r, positions[:,:,0], cos_r, sin_r) # ref_id + y = self.apply_rope1d(y, positions[:,:,1], cos_y, sin_y) # height + x = self.apply_rope1d(x, positions[:,:,2], cos_x, sin_x) # width + + tokens = torch.cat((r, y, x), dim=-1) + return tokens + + +class PositionGetter3D(object): + """ return 3D positions of patches (ref_id, y, x) """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, n, h, w, device, chunk_size=None): + """ + Args: + b: batch size + n: number of frames + h: height in patches + w: width in patches + device: torch device + chunk_size: size of each chunk for computing ref_id + If None, uses frame_idx directly (backward compatible) + Returns: + pos: (b, n*h*w, 3) with (ref_id, y, x) + + ref_id formula: + - frame 0: ref_id = 0 + - other frames: ref_id = ((frame_idx - 1) // chunk_size) * chunk_size + """ + # Default chunk_size for backward compatibility + if chunk_size is None: + chunk_size = n // 2 + + cache_key = (n, h, w, chunk_size) + if cache_key not in self.cache_positions: + positions = [] + for frame_idx in range(n): + # Compute ref_id based on chunk_size + if frame_idx == 0: + ref_id = 0 + else: + ref_id = ((frame_idx - 1) // chunk_size) * chunk_size + + for py in range(h): + for px in range(w): + positions.append([ref_id, py, px]) + pos = torch.tensor(positions, dtype=torch.long, device=device) + self.cache_positions[cache_key] = pos + + pos = self.cache_positions[cache_key].to(device).view(1, n*h*w, 3).expand(b, -1, 3).clone() + return pos + + +#---------------------------------------------------------- +# RoPE4D_RefAware: RoPE with reference awareness (ref_id, t, y, x) +# 4D encoding solves dim divisibility issues (64 / 4 = 16) +#---------------------------------------------------------- + +class RoPE4D_ChunkAware(torch.nn.Module): + """ + Chunk-Aware 4D RoPE implementation: (chunk_idx, t, y, x) + + Key design improvements over ref_id-based encoding: + - chunk_idx: which chunk this frame belongs to (0, 1, 2, ...) - CONTINUOUS values + - t: temporal position within the chunk (0 for anchor, 1-chunk_size for others) + - y, x: spatial position in the patch grid + + Dimension split for head_dim=64 (optimized for value ranges): + - chunk_idx: 8 dims (small value range: 0, 1, 2, 3, ...) + - t: 8 dims (small value range: 0 to chunk_size) + - y: 24 dims (larger range: 0 to ~37 for 518x518 images) + - x: 24 dims (larger range: 0 to ~37) + + This gives more capacity to spatial dimensions while keeping temporal compact. + """ + + def __init__(self, freq=100.0, chunk_freq=10000, temporal_freq=1000.0): + super().__init__() + self.base = freq # for spatial dimensions (y, x) + self.chunk_base = chunk_freq # for chunk_idx dimension + self.temporal_base = temporal_freq # for temporal dimension t + self.cache = {} + + # Fixed dimension split (8 + 8 + 24 + 24 = 64) + self.D_chunk = 8 # chunk_idx gets 8 dims + self.D_t = 8 # t gets 8 dims + self.D_y = 24 # y gets 24 dims + self.D_x = 24 # x gets 24 dims + + def get_cos_sin(self, D, seq_len, device, dtype, base): + key = (D, seq_len, device, dtype, base) + if key not in self.cache: + # For D dimensions, we need D/2 frequencies + half_D = (D + 1) // 2 + inv_freq = 1.0 / (base ** (torch.arange(0, half_D).float().to(device) * 2 / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1)[:, :D] # Truncate to D + cos = freqs.cos() # (Seq, D) + sin = freqs.sin() + self.cache[key] = (cos, sin) + return self.cache[key] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 4 (chunk_idx, t, y, x) + output: + * tokens after applying RoPE (batch_size x nheads x ntokens x dim) + + Dimension layout: [chunk(8) | t(8) | y(24) | x(24)] + """ + dim = tokens.size(3) + expected_dim = self.D_chunk + self.D_t + self.D_y + self.D_x + assert dim == expected_dim, f"dim {dim} != expected {expected_dim}" + assert positions.ndim == 3 and positions.shape[-1] == 4 + + # Get cos/sin for each dimension with appropriate base frequency + cos_c, sin_c = self.get_cos_sin(self.D_chunk, int(positions[:,:,0].max())+1, + tokens.device, tokens.dtype, self.chunk_base) + cos_t, sin_t = self.get_cos_sin(self.D_t, int(positions[:,:,1].max())+1, + tokens.device, tokens.dtype, self.temporal_base) + cos_y, sin_y = self.get_cos_sin(self.D_y, int(positions[:,:,2].max())+1, + tokens.device, tokens.dtype, self.base) + cos_x, sin_x = self.get_cos_sin(self.D_x, int(positions[:,:,3].max())+1, + tokens.device, tokens.dtype, self.base) + + # Split features using fixed dimensions + c = tokens[..., :self.D_chunk] # [0:8] + t = tokens[..., self.D_chunk:self.D_chunk+self.D_t] # [8:16] + y = tokens[..., self.D_chunk+self.D_t:self.D_chunk+self.D_t+self.D_y] # [16:40] + x = tokens[..., self.D_chunk+self.D_t+self.D_y:] # [40:64] + + # Apply rope1d on each part + c = self.apply_rope1d(c, positions[:,:,0], cos_c, sin_c) # chunk_idx + t = self.apply_rope1d(t, positions[:,:,1], cos_t, sin_t) # temporal + y = self.apply_rope1d(y, positions[:,:,2], cos_y, sin_y) # height + x = self.apply_rope1d(x, positions[:,:,3], cos_x, sin_x) # width + + return torch.cat((c, t, y, x), dim=-1) + + +# Legacy alias - redirect to new class +class RoPE4D_RefAware(RoPE4D_ChunkAware): + """Backward compatibility alias for RoPE4D_ChunkAware""" + def __init__(self, freq=100.0, ref_freq=10.0, temporal_freq=10.0): + super().__init__(freq=freq, chunk_freq=ref_freq, temporal_freq=temporal_freq) + +# Aliases for backward compatibility +RoPE3D_RefAware = RoPE4D_RefAware +RoPE3D_ChunkAware = RoPE4D_RefAware + + +class PositionGetter4D_ChunkAware(object): + """ + Return 4D positions: (chunk_idx, t, y, x) + + Key improvement: uses chunk_idx (continuous: 0, 1, 2, ...) instead of ref_id (sparse: 0, 4, 8, ...) + + chunk_idx: which chunk this frame belongs to + t: temporal position within the chunk + y, x: spatial position in patch grid + + Position encoding distribution (N=12, chunk_size=4): + Frame: 0 1 2 3 4 5 6 7 8 9 10 11 + chunk_idx: 0 0 0 0 1 1 1 1 2 2 2 2 ← CONTINUOUS! + t: 0 1 2 3 0 1 2 3 0 1 2 3 ← Repeating pattern + + This gives RoPE a much better distribution to work with compared to: + ref_id: 0 0 0 0 0 4 4 4 4 8 8 8 ← SPARSE (problematic!) + """ + + def __init__(self): + self.cache = {} + + def __call__(self, b, n, h, w, device, chunk_size=None, ref_ids=None): + """ + Args: + b: batch size + n: total number of frames + h, w: patch grid size (height, width in patches) + chunk_size: size of each chunk (default: n//2 for backward compatibility) + ref_ids: Deprecated, ignored. Use chunk_size instead. + Returns: + pos: (b, n*h*w, 4) with (chunk_idx, t, y, x) + """ + if chunk_size is None: + chunk_size = n // 2 # backward compatible + + key = (n, h, w, chunk_size) + if key not in self.cache: + positions = [] + for frame_idx in range(n): + # Compute chunk_idx (continuous: 0, 1, 2, ...) + if frame_idx == 0: + chunk_idx = 0 + t = 0 # First frame (anchor of first chunk) + else: + # chunk_idx = floor((frame_idx - 1) / chunk_size) + chunk_idx = (frame_idx - 1) // chunk_size + # t = position within chunk (1 to chunk_size for non-first-frame of chunk) + t = (frame_idx - 1) % chunk_size + 1 + + for py in range(h): + for px in range(w): + positions.append([chunk_idx, t, py, px]) + pos = torch.tensor(positions, dtype=torch.long, device=device) + self.cache[key] = pos + + # Clone and expand for batch + pos = self.cache[key].to(device).view(1, n*h*w, 4).expand(b, -1, 4).clone() + return pos + + def get_position_for_single_frame(self, frame_idx, h, w, device, chunk_size): + """ + Get position encoding for a single frame (useful for inference). + + Args: + frame_idx: index of the frame + h, w: patch grid size + chunk_size: size of each chunk + Returns: + pos: (1, h*w, 4) with (chunk_idx, t, y, x) + """ + if frame_idx == 0: + chunk_idx = 0 + t = 0 + else: + chunk_idx = (frame_idx - 1) // chunk_size + t = (frame_idx - 1) % chunk_size + 1 + + positions = [] + for py in range(h): + for px in range(w): + positions.append([chunk_idx, t, py, px]) + + pos = torch.tensor(positions, dtype=torch.long, device=device) + return pos.view(1, h*w, 4) + + +# Legacy aliases for backward compatibility +class PositionGetter3D_RefAware(PositionGetter4D_ChunkAware): + """Backward compatibility alias""" + pass + +PositionGetter3D_ChunkAware = PositionGetter4D_ChunkAware + + +#---------------------------------------------------------- +# ALiBi2D_Temporal: ALiBi for temporal dimensions only +# RoPE2D handles spatial (y, x), ALiBi handles (chunk_id, t) +#---------------------------------------------------------- + +class ALiBi2D_Temporal(torch.nn.Module): + """ + ALiBi (Attention with Linear Biases) for temporal dimensions (chunk_id, t). + + This is used in HYBRID mode with RoPE2D: + - RoPE2D encodes spatial positions (y, x) via rotary embeddings + - ALiBi2D_Temporal encodes temporal positions (chunk_id, t) via attention bias + + Key advantage: Perfect extrapolation for inference + - Training: chunk_id ∈ [0, 3] + - Inference: chunk_id ∈ [0, 200+] + - ALiBi uses linear distance, so no distribution shift! + + Formula: + bias(i, j) = -m * (chunk_weight * |chunk_i - chunk_j| + temporal_weight * |t_i - t_j|) + + where m is the slope for each attention head (different per head). + """ + + def __init__(self, num_heads, chunk_weight=1.0, temporal_weight=0.5): + """ + Args: + num_heads: number of attention heads + chunk_weight: weight for chunk_id distance (default 1.0, most important) + temporal_weight: weight for temporal t distance (default 0.5) + """ + super().__init__() + self.num_heads = num_heads + # Store as Python floats to avoid tensor conversion issues + self.chunk_weight = float(chunk_weight) + self.temporal_weight = float(temporal_weight) + + # Generate slopes for each head using geometric sequence + # Paper recommendation: slopes = 2^(-8/n * i) for i in [1, 2, ..., n] + slopes = torch.pow(2.0, -torch.linspace(0, 8, num_heads)) + self.register_buffer('slopes', slopes) + + def forward(self, positions_4d): + """ + Compute ALiBi temporal bias from 4D positions. + + Args: + positions_4d: (B, L, 4) with (chunk_id, t, y, x) + We only use chunk_id and t here. + + Returns: + bias: (B, num_heads, L, L) attention bias matrix + Negative values penalize distant frames. + """ + B, L, _ = positions_4d.shape + device = positions_4d.device + + # Extract temporal dimensions only (chunk_id, t) + chunk_ids = positions_4d[:, :, 0:1].float() # (B, L, 1) + t_ids = positions_4d[:, :, 1:2].float() # (B, L, 1) + + # Compute pairwise distance matrices + # chunk_dist[b, i, j] = |chunk_i - chunk_j| + chunk_dist = torch.abs( + chunk_ids - chunk_ids.transpose(1, 2) + ) # (B, L, L) + + t_dist = torch.abs( + t_ids - t_ids.transpose(1, 2) + ) # (B, L, L) + + # Weighted combination of distances + total_dist = ( + self.chunk_weight * chunk_dist + + self.temporal_weight * t_dist + ) # (B, L, L) + + # Apply per-head slopes: bias = -slope * distance + # slopes: (num_heads,) -> (1, num_heads, 1, 1) + # total_dist: (B, L, L) -> (B, 1, L, L) + bias = -self.slopes.view(1, -1, 1, 1) * total_dist.unsqueeze(1) + + return bias # (B, num_heads, L, L) \ No newline at end of file diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/transformer_head.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8b03892d1629f995151fc06e1c5299f9f6b4a6f2 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/layers/transformer_head.py @@ -0,0 +1,81 @@ +from .attention import FlashAttentionRope +from .block import BlockRope +from ..dinov2.layers import Mlp +import torch.nn as nn +from functools import partial +from torch.utils.checkpoint import checkpoint +import torch.nn.functional as F + +class TransformerDecoder(nn.Module): + def __init__( + self, + in_dim, + out_dim, + dec_embed_dim=512, + depth=5, + dec_num_heads=8, + mlp_ratio=4, + rope=None, + need_project=True, + use_checkpoint=False, + ): + super().__init__() + + self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity() + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=None, + qk_norm=False, + # attn_class=MemEffAttentionRope, + attn_class=FlashAttentionRope, + rope=rope + ) for _ in range(depth)]) + + self.linear_out = nn.Linear(dec_embed_dim, out_dim) + + def forward(self, hidden, xpos=None): + hidden = self.projects(hidden) + for i, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False) + else: + hidden = blk(hidden, xpos=xpos) + out = self.linear_out(hidden) + return out + +class LinearPts3d (nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, patch_size, dec_embed_dim, output_dim=3,): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2) + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return feat.permute(0, 2, 3, 1) \ No newline at end of file diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/slamformer.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/slamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..be601fe006038e86ca755c1493070dcfc737f709 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/slamformer.py @@ -0,0 +1,304 @@ +import torch +import torch.nn as nn +from functools import partial +from copy import deepcopy + +from .dinov2.layers import Mlp +from ..utils.geometry import homogenize_points +from .layers.pos_embed import RoPE2D, PositionGetter +from .layers.block import BlockRope +from .layers.attention import FlashAttentionRope +from .layers.transformer_head import TransformerDecoder, LinearPts3d +from .layers.camera_head import CameraHead +from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg +from huggingface_hub import PyTorchModelHubMixin +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, List, Any +from dataclasses import dataclass +from transformers.file_utils import ModelOutput + + +from .layers.dpt_head import DPTHead + +@dataclass +class StreamVGGTOutput(ModelOutput): + ress: Optional[List[dict]] = None + views: Optional[torch.Tensor] = None + + + + +class SLAMFormer(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + pos_type='rope100', + decoder_size='large', + ): + super().__init__() + + # ---------------------- + # Encoder + # ---------------------- + self.encoder = dinov2_vitl14_reg(pretrained=False) + self.patch_size = 14 + del self.encoder.mask_token + + # ---------------------- + # Positonal Encoding + # ---------------------- + self.pos_type = pos_type if pos_type is not None else 'none' + self.rope=None + if self.pos_type.startswith('rope'): # eg rope100 + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(self.pos_type[len('rope'):]) + self.rope = RoPE2D(freq=freq) + self.position_getter = PositionGetter() + else: + raise NotImplementedError + + + # ---------------------- + # Decoder + # ---------------------- + enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024 + if decoder_size == 'small': + dec_embed_dim = 384 + dec_num_heads = 6 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'base': + dec_embed_dim = 768 + dec_num_heads = 12 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'large': + dec_embed_dim = 1024 + dec_num_heads = 16 + mlp_ratio = 4 + dec_depth = 36 + else: + raise NotImplementedError + self.decoder = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope + ) for _ in range(dec_depth)]) + self.dec_embed_dim = dec_embed_dim + + # ---------------------- + # Register_token + # ---------------------- + num_register_tokens = 5 + self.patch_start_idx = num_register_tokens + self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)) + nn.init.normal_(self.register_token, std=1e-6) + + # ---------------------- + # Local Points Decoder + # ---------------------- + self.point_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, + out_dim=1024, + rope=self.rope, + use_checkpoint=True + ) + self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3) + + # ---------------------- + # Conf Decoder + # ---------------------- + self.conf_decoder = deepcopy(self.point_decoder) + self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1) + ''' + + self.point_head = DPTHead(dim_in=1024, output_dim=4, activation="inv_log", conf_activation="expp1", intermediate_layer_idx=[0]) + ''' + + + # ---------------------- + # Camera Pose Decoder + # ---------------------- + self.camera_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, # 8 + out_dim=512, + rope=self.rope, + use_checkpoint=True + ) + self.camera_head = CameraHead(dim=512) + #self.intrin_head = CameraHead(dim=512) + + # For ImageNet Normalize + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + + def decode(self, N, H, W, hidden_I=None, hidden_F=None): + # branch 1: frontend + # branch 2: backend + # branch 3: mix-mode + + branch = 1 + if hidden_I is not None: + branch = 1 + hidden = hidden_I + + BN, hw, C = hidden.shape + B = BN // N + + hidden = hidden.reshape(B*N, hw, -1) + + register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:]) + + # Concatenate special tokens with patch tokens + hidden = torch.cat([register_token, hidden], dim=1) + hw = hidden.shape[1] + + if hidden_F is None: + branch = 1 + else: + branch = 3 + hidden = hidden.view(B,N,-1,C) + hidden = torch.cat([hidden_F.view(B,N,-1,C)[:,:int(N//2)], hidden[:,int(N//2):]],axis=1).reshape(B*N,-1,C) + + + elif hidden_I is None and hidden_F is not None: + branch = 2 + hidden = hidden_F + BN, hw, C = hidden.shape + B = BN // N + + + if self.pos_type.startswith('rope'): + pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + final_output = [] + for i in range(len(self.decoder)): + blk = self.decoder[i] + + if i % 2 == 0: + pos = pos.reshape(B*N, hw, -1) + hidden = hidden.reshape(B*N, hw, -1) + global_ = False + else: + pos = pos.reshape(B, N*hw, -1) + hidden = hidden.reshape(B, N*hw, -1) + global_ = True + + + hidden = checkpoint(blk, hidden, xpos=pos, N=N, branch=branch, global_=global_, use_reentrant=False) + + if i+1 in [len(self.decoder)-1, len(self.decoder)]: + final_output.append(hidden.reshape(B*N, hw, -1)) + + return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1) + + def forward(self, views, query_points): + imgs = torch.stack( + [view["img"] for view in views], dim=0 + ).permute(1, 0, 2, 3, 4) # B S C H W + + imgs = (imgs - self.image_mean) / self.image_std + + B, N, _, H, W = imgs.shape + patch_h, patch_w = H // 14, W // 14 + + shape_ = (B,N,H,W,patch_h, patch_w) + # encode by dinov2 + imgs = imgs.reshape(B*N, _, H, W) + hidden_I = self.encoder(imgs, is_training=True) + + if isinstance(hidden_I, dict): + hidden_I = hidden_I["x_norm_patchtokens"] + + hidden_F, pos = self.decode(N, H, W, hidden_I, hidden_F=None) + res_F = self.extract(hidden_F,pos,shape_) + + + C2 = hidden_F.shape[-1] + C = int(C2//2) + ''' + hidden_B, pos = self.decode(N, H, W, hidden_I=None, hidden_F=hidden_F[:,:,:int(C2//2)]) + res_B = self.extract(hidden_B,pos,shape_) + ''' + + hidden_M, pos = self.decode(N, H, W, hidden_I,hidden_F[:,:,:int(C2//2)]) + res_M = self.extract(hidden_M,pos,shape_) + + + N_half = int(N//2) + hidden_input_B = torch.cat([hidden_F.view(B,N,-1,C2)[:,:N_half,:,:int(C2//2)], hidden_M.view(B,N,-1,C2)[:,N_half:,:,:int(C2//2)]],axis=1).view(B*N,-1,C) + hidden_B, pos = self.decode(N, H, W, hidden_I=None,hidden_F=hidden_input_B) + res_B = self.extract(hidden_B,pos,shape_) + + + + output_F=StreamVGGTOutput(ress=[res_F],views=views) + output_M=StreamVGGTOutput(ress=[res_M],views=views) + output_B=StreamVGGTOutput(ress=[res_B],views=views) + + + return output_F, output_M, output_B + + + def extract(self, hidden, pos, shape_): + B,N,H,W,patch_h, patch_w = shape_ + + point_hidden = self.point_decoder(hidden, xpos=pos) # BN, P, 1024 + conf_hidden = self.conf_decoder(hidden, xpos=pos) + camera_hidden = self.camera_decoder(hidden, xpos=pos) + + + + with torch.amp.autocast(device_type='cuda', enabled=False): + # local points + point_hidden = point_hidden.float() + ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1) + xy, z = ret.split([2, 1], dim=-1) + z = torch.exp(z) + local_points = torch.cat([xy * z, z], dim=-1) + + # confidence + conf_hidden = conf_hidden.float() + conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1) + + # camera + camera_hidden = camera_hidden.float() + camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4) + + + # unproject local points using camera poses + points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3] + + output = dict(points=points, + local_points=local_points, + conf=conf, + camera_poses=camera_poses, + ) + + return output diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/basic.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac73492409b1f2441a84e2f9de9681b3cf3ca9f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/basic.py @@ -0,0 +1,223 @@ +import os +import os.path as osp +import math +import cv2 +from PIL import Image +import torch +from torchvision import transforms +from plyfile import PlyData, PlyElement +import numpy as np + +def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000): + """ + Loads images from a directory or video, resizes them to a uniform size, + then converts and stacks them into a single [N, 3, H, W] PyTorch tensor. + """ + sources = [] + + # --- 1. Load image paths or video frames --- + if osp.isdir(path): + print(f"Loading images from directory: {path}") + filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))]) + for i in range(0, len(filenames), interval): + img_path = osp.join(path, filenames[i]) + try: + sources.append(Image.open(img_path).convert('RGB')) + except Exception as e: + print(f"Could not load image {filenames[i]}: {e}") + elif path.lower().endswith('.mp4'): + print(f"Loading frames from video: {path}") + cap = cv2.VideoCapture(path) + if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}") + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret: break + if frame_idx % interval == 0: + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + sources.append(Image.fromarray(rgb_frame)) + frame_idx += 1 + cap.release() + else: + raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}") + + if not sources: + print("No images found or loaded.") + return torch.empty(0) + + print(f"Found {len(sources)} images/frames. Processing...") + + # --- 2. Determine a uniform target size for all images based on the first image --- + # This is necessary to ensure all tensors have the same dimensions for stacking. + first_img = sources[0] + W_orig, H_orig = first_img.size + scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1 + W_target, H_target = W_orig * scale, H_orig * scale + k, m = round(W_target / 14), round(H_target / 14) + while (k * 14) * (m * 14) > PIXEL_LIMIT: + if k / m > W_target / H_target: k -= 1 + else: m -= 1 + TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14 + print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})") + + # --- 3. Resize images and convert them to tensors in the [0, 1] range --- + tensor_list = [] + # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1] + to_tensor_transform = transforms.ToTensor() + + for img_pil in sources: + try: + # Resize to the uniform target size + resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS) + # Convert to tensor + img_tensor = to_tensor_transform(resized_img) + tensor_list.append(img_tensor) + except Exception as e: + print(f"Error processing an image: {e}") + + if not tensor_list: + print("No images were successfully processed.") + return torch.empty(0) + + # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor --- + return torch.stack(tensor_list, dim=0) + + +def tensor_to_pil(tensor): + """ + Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension + (if it has size 3) to the last axis before converting. + + Args: + tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. + """ + if torch.is_tensor(tensor): + array = tensor.detach().cpu().numpy() + else: + array = tensor + + return array_to_pil(array) + + +def array_to_pil(array): + """ + Converts a NumPy array to a PIL image. Automatically: + - Squeezes dimensions of size 1. + - Moves the channel dimension (if it has size 3) to the last axis. + + Args: + array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. + """ + # Remove singleton dimensions + array = np.squeeze(array) + + # Ensure the array has the channel dimension as the last axis + if array.ndim == 3 and array.shape[0] == 3: # If the channel is the first axis + array = np.transpose(array, (1, 2, 0)) # Move channel to the last axis + + # Handle single-channel grayscale images + if array.ndim == 2: # [H, W] + return Image.fromarray((array * 255).astype(np.uint8), mode="L") + elif array.ndim == 3 and array.shape[2] == 3: # [H, W, C] with 3 channels + return Image.fromarray((array * 255).astype(np.uint8), mode="RGB") + else: + raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}") + + +def rotate_target_dim_to_last_axis(x, target_dim=3): + shape = x.shape + axis_to_move = -1 + # Iterate backwards to find the first occurrence from the end + # (which corresponds to the last dimension of size 3 in the original order). + for i in range(len(shape) - 1, -1, -1): + if shape[i] == target_dim: + axis_to_move = i + break + + # 2. If the axis is found and it's not already in the last position, move it. + if axis_to_move != -1 and axis_to_move != len(shape) - 1: + # Create the new dimension order. + dims_order = list(range(len(shape))) + dims_order.pop(axis_to_move) + dims_order.append(axis_to_move) + + # Use permute to reorder the dimensions. + ret = x.transpose(*dims_order) + else: + ret = x + + return ret + + +def write_ply( + xyz, + rgb=None, + path='output.ply', +) -> None: + if torch.is_tensor(xyz): + xyz = xyz.detach().cpu().numpy() + + if torch.is_tensor(rgb): + rgb = rgb.detach().cpu().numpy() + + if rgb is not None and rgb.max() > 1: + rgb = rgb / 255. + + xyz = rotate_target_dim_to_last_axis(xyz, 3) + xyz = xyz.reshape(-1, 3) + + if rgb is not None: + rgb = rotate_target_dim_to_last_axis(rgb, 3) + rgb = rgb.reshape(-1, 3) + + if rgb is None: + min_coord = np.min(xyz, axis=0) + max_coord = np.max(xyz, axis=0) + normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8) + + hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2] + hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1) + + c = hsv[:,2:] * hsv[:,1:2] + x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 )) + m = hsv[:,2:] - c + + rgb = np.zeros_like(hsv) + cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1) + rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])]) + cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2) + rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])]) + cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]]) + cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]]) + cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5) + rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]]) + cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6) + rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]]) + rgb = (rgb + m) + + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + normals = np.zeros_like(xyz) + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb * 255), axis=1) + elements[:] = list(map(tuple, attributes)) + vertex_element = PlyElement.describe(elements, "vertex") + ply_data = PlyData([vertex_element]) + ply_data.write(path) \ No newline at end of file diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/debug.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..f3da8f3d6caece7c828cab2574bf1cbdd207779e --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/debug.py @@ -0,0 +1,63 @@ +import os +import json +import debugpy +import socket +import random + +def update_vscode_launch_file(host: str, port: int): + """Update the .vscode/launch.json file with the new host and port.""" + launch_file_path = ".vscode/launch.json" + # Desired configuration + new_config = { + "version": "0.2.0", + "configurations": [ + { + "name": "bash_debug", + "type": "debugpy", + "request": "attach", + "connect": { + "host": host, + "port": port + }, + "justMyCode": False + }, + ] + } + + # Ensure the .vscode directory exists + if not os.path.exists(".vscode"): + os.makedirs(".vscode") + + # Write the updated configuration to launch.json + with open(launch_file_path, "w") as f: + json.dump(new_config, f, indent=4) + print(f"Updated {launch_file_path} with host: {host} and port: {port}") + +def is_port_in_use(host, port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex((host, port)) == 0 + +def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)): + if is_main_process: + host = os.environ['SLURM_NODELIST'].split(',')[0] + + for _ in range(max_retries): + port = random.randint(*port_range) + try: + if is_port_in_use(host, port): + print(f"Port {port} is already in use, trying another...") + continue + + # 更新 launch.json + update_vscode_launch_file(host, port) + + print("master_addr = ", host) + debugpy.listen((host, port)) + print(f"Waiting for debugger attach at port {port}...", flush=True) + debugpy.wait_for_client() + print("Debugger attached", flush=True) + return + except Exception as e: + print(f"Failed to bind to port {port}: {e}") + + raise RuntimeError("Could not find a free port for debugpy after several attempts.") \ No newline at end of file diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/geometry.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..903b31d62f3fc9036169e2eba60cf21491a3791c --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/utils/geometry.py @@ -0,0 +1,414 @@ +import numpy as np +import torch +import torch.nn.functional as F + +def se3_inverse(T): + """ + Computes the inverse of a batch of SE(3) matrices. + T: Tensor of shape (B, 4, 4) + """ + if len(T.shape) == 2: + T = T[None] + unseq_flag = True + else: + unseq_flag = False + + if torch.is_tensor(T): + R = T[:, :3, :3] + t = T[:, :3, 3].unsqueeze(-1) + R_inv = R.transpose(-2, -1) + t_inv = -torch.matmul(R_inv, t) + T_inv = torch.cat([ + torch.cat([R_inv, t_inv], dim=-1), + torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(T.shape[0], 1, 1) + ], dim=1) + else: + R = T[:, :3, :3] + t = T[:, :3, 3, np.newaxis] + + R_inv = np.swapaxes(R, -2, -1) + t_inv = -R_inv @ t + + bottom_row = np.zeros((T.shape[0], 1, 4), dtype=T.dtype) + bottom_row[:, :, 3] = 1 + + top_part = np.concatenate([R_inv, t_inv], axis=-1) + T_inv = np.concatenate([top_part, bottom_row], axis=1) + + if unseq_flag: + T_inv = T_inv[0] + return T_inv + +def get_pixel(H, W): + # get 2D pixels (u, v) for image_a in cam_a pixel space + u_a, v_a = np.meshgrid(np.arange(W), np.arange(H)) + # u_a = np.flip(u_a, axis=1) + # v_a = np.flip(v_a, axis=0) + pixels_a = np.stack([ + u_a.flatten() + 0.5, + v_a.flatten() + 0.5, + np.ones_like(u_a.flatten()) + ], axis=0) + + return pixels_a + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + if z_far > 0: + valid_mask = valid_mask & (depthmap < z_far) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + # assert camera_intrinsics[0, 1] == 0.0 + # assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = (depthmap > 0.0) + # Invalid any depth > 80m + valid_mask = valid_mask + return X_cam, valid_mask + +def homogenize_points( + points, +): + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ], + indexing = 'ij' + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return x2, prob + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + # nonzero_mask = kpts0_depth != 0 + # Sample depth, get calculable_mask on depth > 0 + nonzero_mask = kpts0_depth > 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + +def opencv_camera_to_plucker(poses, K, H, W): + device = poses.device + B = poses.shape[0] + + pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1) # (3, H, W) + pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel) + ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel) + + ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1) + + ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True) + plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) + plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1) + + return plucker_ray + + +def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + +def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False): + """ + X: torch tensor B x N x 3 + Y: torch tensor B x N x 3 + w: torch tensor B x N + """ + assert len(A) == len(B) + if use_weights: + W1 = torch.abs(w).sum(1, keepdim=True) + w_norm = (w / (W1 + eps)).unsqueeze(-1) + a_mean = (w_norm * A).sum(dim=1, keepdim=True) + b_mean = (w_norm * B).sum(dim=1, keepdim=True) + + A_c = A - a_mean + B_c = B - b_mean + + H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c) + + else: + a_mean = A.mean(axis=1, keepdim=True) + b_mean = B.mean(axis=1, keepdim=True) + + A_c = A - a_mean + B_c = B - b_mean + + H = torch.einsum("bij,bik->bjk", A_c, B_c) + + U, S, V = torch.svd(H) # U: B x 3 x 3, S: B x 3, V: B x 3 x 3 + Z = torch.eye(3).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device) + Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2))) # B x 3 x 3 + R = V @ Z @ U.transpose(1, 2) # B x 3 x 3 + t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose(-2, -1) + if return_T: + T = torch.eye(4).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device) + T[:, :3, :3] = R + T[:, :3, 3] = t.squeeze() + return T + return R, t.squeeze() \ No newline at end of file diff --git a/outdoor_v48_8gpu/.hydra/config.yaml b/outdoor_v48_8gpu/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bb975029f17b7935e6fee0b19714ecb576b6b85 --- /dev/null +++ b/outdoor_v48_8gpu/.hydra/config.yaml @@ -0,0 +1,68 @@ +teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +load_only_encoder: false +long_context: false +fixed_length: true +resume: null +benchmark: false +num_views: 64 +num_test_views: 4 +n_corres_train: 0 +n_corres_test: 0 +train_criterion: DistillLoss() +test_criterion: DistillLoss() +allow_repeat: false +root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti +root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset +root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset +root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_waymo: /scratch-shared/wwei2/waymo_v2 +root_waymo_lidar: /scratch-shared/wwei2/waymo_v2 +dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train', + ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), + (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}", + velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518, + 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, + num_views=${num_views}, n_corres=${n_corres_train}) +dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}", + lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336), + (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo} +test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518, + 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test}) +seed: 0 +batch_size: 1 +accum_iter: 1 +gradient_checkpointing: false +epochs: 10 +start_epoch: 0 +start_step: 0 +weight_decay: 0.05 +lr: 1.0e-05 +min_lr: 1.0e-08 +warmup_epochs: 0.5 +amp: 1 +num_workers: 4 +world_size: 1 +local-rank: -1 +dist_url: env:// +rank: 0 +gpu: 0 +distributed: false +dist_backend: nccl +eval_freq: 1 +save_freq: 0.1 +max_checkpoints: 10 +keep_freq: 1 +print_freq: 10 +print_img_freq: 50000000 +num_imgs_vis: 4 +save_dir: /scratch-shared/wwei2/training_upstream/checkpoints +exp_name: outdoor_v48_8gpu +task: StreamVGGT +logdir: ${save_dir}/${exp_name}/logs +output_dir: ${save_dir}/${exp_name}/ diff --git a/outdoor_v48_8gpu/.hydra/hydra.yaml b/outdoor_v48_8gpu/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ea326dfea326af404151bcf3a59f4f178654c4b --- /dev/null +++ b/outdoor_v48_8gpu/.hydra/hydra.yaml @@ -0,0 +1,155 @@ +hydra: + run: + dir: ${save_dir}/${exp_name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - exp_name=outdoor_v48_8gpu + job: + name: mytrain + chdir: null + override_dirname: exp_name=outdoor_v48_8gpu + id: ??? + num: ??? + config_name: outdoor_v48 + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: true diff --git a/outdoor_v48_8gpu/.hydra/overrides.yaml b/outdoor_v48_8gpu/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..354a2ce1e9ef767e61de03fac9676b60c7068093 --- /dev/null +++ b/outdoor_v48_8gpu/.hydra/overrides.yaml @@ -0,0 +1 @@ +- exp_name=outdoor_v48_8gpu diff --git a/outdoor_v48_8gpu/mytrain.log b/outdoor_v48_8gpu/mytrain.log new file mode 100644 index 0000000000000000000000000000000000000000..e9ba41b99092df6def9b412ef1ecbe5fbd3ac50a --- /dev/null +++ b/outdoor_v48_8gpu/mytrain.log @@ -0,0 +1,1857 @@ +[2026-05-01 23:30:49,780][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/ +[2026-05-01 23:30:51,331][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/code/05_01-23:30:49 +[2026-05-01 23:30:51,332][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-01 23:30:51,332][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-01 23:30:51,333][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-01 23:30:51,334][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-01 23:35:09,093][__main__][INFO] - [RANK 0] Building test dataset 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-01 23:35:09,094][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-01 23:35:09,196][__main__][INFO] - [RANK 0] Loading model +[2026-05-01 23:35:15,120][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-01 23:35:15,121][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-01 23:35:15,121][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-01 23:35:15,562][__main__][INFO] - [RANK 0] Loading pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +[2026-05-01 23:35:26,983][__main__][INFO] - [RANK 0] _IncompatibleKeys(missing_keys=['register_token', 'image_mean', 'image_std', 'encoder.cls_token', 'encoder.pos_embed', 'encoder.register_tokens', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.ls1.gamma', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.0.ls2.gamma', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.ls1.gamma', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.1.ls2.gamma', 'encoder.blocks.2.norm1.weight', 'encoder.blocks.2.norm1.bias', 'encoder.blocks.2.attn.qkv.weight', 'encoder.blocks.2.attn.qkv.bias', 'encoder.blocks.2.attn.proj.weight', 'encoder.blocks.2.attn.proj.bias', 'encoder.blocks.2.ls1.gamma', 'encoder.blocks.2.norm2.weight', 'encoder.blocks.2.norm2.bias', 'encoder.blocks.2.mlp.fc1.weight', 'encoder.blocks.2.mlp.fc1.bias', 'encoder.blocks.2.mlp.fc2.weight', 'encoder.blocks.2.mlp.fc2.bias', 'encoder.blocks.2.ls2.gamma', 'encoder.blocks.3.norm1.weight', 'encoder.blocks.3.norm1.bias', 'encoder.blocks.3.attn.qkv.weight', 'encoder.blocks.3.attn.qkv.bias', 'encoder.blocks.3.attn.proj.weight', 'encoder.blocks.3.attn.proj.bias', 'encoder.blocks.3.ls1.gamma', 'encoder.blocks.3.norm2.weight', 'encoder.blocks.3.norm2.bias', 'encoder.blocks.3.mlp.fc1.weight', 'encoder.blocks.3.mlp.fc1.bias', 'encoder.blocks.3.mlp.fc2.weight', 'encoder.blocks.3.mlp.fc2.bias', 'encoder.blocks.3.ls2.gamma', 'encoder.blocks.4.norm1.weight', 'encoder.blocks.4.norm1.bias', 'encoder.blocks.4.attn.qkv.weight', 'encoder.blocks.4.attn.qkv.bias', 'encoder.blocks.4.attn.proj.weight', 'encoder.blocks.4.attn.proj.bias', 'encoder.blocks.4.ls1.gamma', 'encoder.blocks.4.norm2.weight', 'encoder.blocks.4.norm2.bias', 'encoder.blocks.4.mlp.fc1.weight', 'encoder.blocks.4.mlp.fc1.bias', 'encoder.blocks.4.mlp.fc2.weight', 'encoder.blocks.4.mlp.fc2.bias', 'encoder.blocks.4.ls2.gamma', 'encoder.blocks.5.norm1.weight', 'encoder.blocks.5.norm1.bias', 'encoder.blocks.5.attn.qkv.weight', 'encoder.blocks.5.attn.qkv.bias', 'encoder.blocks.5.attn.proj.weight', 'encoder.blocks.5.attn.proj.bias', 'encoder.blocks.5.ls1.gamma', 'encoder.blocks.5.norm2.weight', 'encoder.blocks.5.norm2.bias', 'encoder.blocks.5.mlp.fc1.weight', 'encoder.blocks.5.mlp.fc1.bias', 'encoder.blocks.5.mlp.fc2.weight', 'encoder.blocks.5.mlp.fc2.bias', 'encoder.blocks.5.ls2.gamma', 'encoder.blocks.6.norm1.weight', 'encoder.blocks.6.norm1.bias', 'encoder.blocks.6.attn.qkv.weight', 'encoder.blocks.6.attn.qkv.bias', 'encoder.blocks.6.attn.proj.weight', 'encoder.blocks.6.attn.proj.bias', 'encoder.blocks.6.ls1.gamma', 'encoder.blocks.6.norm2.weight', 'encoder.blocks.6.norm2.bias', 'encoder.blocks.6.mlp.fc1.weight', 'encoder.blocks.6.mlp.fc1.bias', 'encoder.blocks.6.mlp.fc2.weight', 'encoder.blocks.6.mlp.fc2.bias', 'encoder.blocks.6.ls2.gamma', 'encoder.blocks.7.norm1.weight', 'encoder.blocks.7.norm1.bias', 'encoder.blocks.7.attn.qkv.weight', 'encoder.blocks.7.attn.qkv.bias', 'encoder.blocks.7.attn.proj.weight', 'encoder.blocks.7.attn.proj.bias', 'encoder.blocks.7.ls1.gamma', 'encoder.blocks.7.norm2.weight', 'encoder.blocks.7.norm2.bias', 'encoder.blocks.7.mlp.fc1.weight', 'encoder.blocks.7.mlp.fc1.bias', 'encoder.blocks.7.mlp.fc2.weight', 'encoder.blocks.7.mlp.fc2.bias', 'encoder.blocks.7.ls2.gamma', 'encoder.blocks.8.norm1.weight', 'encoder.blocks.8.norm1.bias', 'encoder.blocks.8.attn.qkv.weight', 'encoder.blocks.8.attn.qkv.bias', 'encoder.blocks.8.attn.proj.weight', 'encoder.blocks.8.attn.proj.bias', 'encoder.blocks.8.ls1.gamma', 'encoder.blocks.8.norm2.weight', 'encoder.blocks.8.norm2.bias', 'encoder.blocks.8.mlp.fc1.weight', 'encoder.blocks.8.mlp.fc1.bias', 'encoder.blocks.8.mlp.fc2.weight', 'encoder.blocks.8.mlp.fc2.bias', 'encoder.blocks.8.ls2.gamma', 'encoder.blocks.9.norm1.weight', 'encoder.blocks.9.norm1.bias', 'encoder.blocks.9.attn.qkv.weight', 'encoder.blocks.9.attn.qkv.bias', 'encoder.blocks.9.attn.proj.weight', 'encoder.blocks.9.attn.proj.bias', 'encoder.blocks.9.ls1.gamma', 'encoder.blocks.9.norm2.weight', 'encoder.blocks.9.norm2.bias', 'encoder.blocks.9.mlp.fc1.weight', 'encoder.blocks.9.mlp.fc1.bias', 'encoder.blocks.9.mlp.fc2.weight', 'encoder.blocks.9.mlp.fc2.bias', 'encoder.blocks.9.ls2.gamma', 'encoder.blocks.10.norm1.weight', 'encoder.blocks.10.norm1.bias', 'encoder.blocks.10.attn.qkv.weight', 'encoder.blocks.10.attn.qkv.bias', 'encoder.blocks.10.attn.proj.weight', 'encoder.blocks.10.attn.proj.bias', 'encoder.blocks.10.ls1.gamma', 'encoder.blocks.10.norm2.weight', 'encoder.blocks.10.norm2.bias', 'encoder.blocks.10.mlp.fc1.weight', 'encoder.blocks.10.mlp.fc1.bias', 'encoder.blocks.10.mlp.fc2.weight', 'encoder.blocks.10.mlp.fc2.bias', 'encoder.blocks.10.ls2.gamma', 'encoder.blocks.11.norm1.weight', 'encoder.blocks.11.norm1.bias', 'encoder.blocks.11.attn.qkv.weight', 'encoder.blocks.11.attn.qkv.bias', 'encoder.blocks.11.attn.proj.weight', 'encoder.blocks.11.attn.proj.bias', 'encoder.blocks.11.ls1.gamma', 'encoder.blocks.11.norm2.weight', 'encoder.blocks.11.norm2.bias', 'encoder.blocks.11.mlp.fc1.weight', 'encoder.blocks.11.mlp.fc1.bias', 'encoder.blocks.11.mlp.fc2.weight', 'encoder.blocks.11.mlp.fc2.bias', 'encoder.blocks.11.ls2.gamma', 'encoder.blocks.12.norm1.weight', 'encoder.blocks.12.norm1.bias', 'encoder.blocks.12.attn.qkv.weight', 'encoder.blocks.12.attn.qkv.bias', 'encoder.blocks.12.attn.proj.weight', 'encoder.blocks.12.attn.proj.bias', 'encoder.blocks.12.ls1.gamma', 'encoder.blocks.12.norm2.weight', 'encoder.blocks.12.norm2.bias', 'encoder.blocks.12.mlp.fc1.weight', 'encoder.blocks.12.mlp.fc1.bias', 'encoder.blocks.12.mlp.fc2.weight', 'encoder.blocks.12.mlp.fc2.bias', 'encoder.blocks.12.ls2.gamma', 'encoder.blocks.13.norm1.weight', 'encoder.blocks.13.norm1.bias', 'encoder.blocks.13.attn.qkv.weight', 'encoder.blocks.13.attn.qkv.bias', 'encoder.blocks.13.attn.proj.weight', 'encoder.blocks.13.attn.proj.bias', 'encoder.blocks.13.ls1.gamma', 'encoder.blocks.13.norm2.weight', 'encoder.blocks.13.norm2.bias', 'encoder.blocks.13.mlp.fc1.weight', 'encoder.blocks.13.mlp.fc1.bias', 'encoder.blocks.13.mlp.fc2.weight', 'encoder.blocks.13.mlp.fc2.bias', 'encoder.blocks.13.ls2.gamma', 'encoder.blocks.14.norm1.weight', 'encoder.blocks.14.norm1.bias', 'encoder.blocks.14.attn.qkv.weight', 'encoder.blocks.14.attn.qkv.bias', 'encoder.blocks.14.attn.proj.weight', 'encoder.blocks.14.attn.proj.bias', 'encoder.blocks.14.ls1.gamma', 'encoder.blocks.14.norm2.weight', 'encoder.blocks.14.norm2.bias', 'encoder.blocks.14.mlp.fc1.weight', 'encoder.blocks.14.mlp.fc1.bias', 'encoder.blocks.14.mlp.fc2.weight', 'encoder.blocks.14.mlp.fc2.bias', 'encoder.blocks.14.ls2.gamma', 'encoder.blocks.15.norm1.weight', 'encoder.blocks.15.norm1.bias', 'encoder.blocks.15.attn.qkv.weight', 'encoder.blocks.15.attn.qkv.bias', 'encoder.blocks.15.attn.proj.weight', 'encoder.blocks.15.attn.proj.bias', 'encoder.blocks.15.ls1.gamma', 'encoder.blocks.15.norm2.weight', 'encoder.blocks.15.norm2.bias', 'encoder.blocks.15.mlp.fc1.weight', 'encoder.blocks.15.mlp.fc1.bias', 'encoder.blocks.15.mlp.fc2.weight', 'encoder.blocks.15.mlp.fc2.bias', 'encoder.blocks.15.ls2.gamma', 'encoder.blocks.16.norm1.weight', 'encoder.blocks.16.norm1.bias', 'encoder.blocks.16.attn.qkv.weight', 'encoder.blocks.16.attn.qkv.bias', 'encoder.blocks.16.attn.proj.weight', 'encoder.blocks.16.attn.proj.bias', 'encoder.blocks.16.ls1.gamma', 'encoder.blocks.16.norm2.weight', 'encoder.blocks.16.norm2.bias', 'encoder.blocks.16.mlp.fc1.weight', 'encoder.blocks.16.mlp.fc1.bias', 'encoder.blocks.16.mlp.fc2.weight', 'encoder.blocks.16.mlp.fc2.bias', 'encoder.blocks.16.ls2.gamma', 'encoder.blocks.17.norm1.weight', 'encoder.blocks.17.norm1.bias', 'encoder.blocks.17.attn.qkv.weight', 'encoder.blocks.17.attn.qkv.bias', 'encoder.blocks.17.attn.proj.weight', 'encoder.blocks.17.attn.proj.bias', 'encoder.blocks.17.ls1.gamma', 'encoder.blocks.17.norm2.weight', 'encoder.blocks.17.norm2.bias', 'encoder.blocks.17.mlp.fc1.weight', 'encoder.blocks.17.mlp.fc1.bias', 'encoder.blocks.17.mlp.fc2.weight', 'encoder.blocks.17.mlp.fc2.bias', 'encoder.blocks.17.ls2.gamma', 'encoder.blocks.18.norm1.weight', 'encoder.blocks.18.norm1.bias', 'encoder.blocks.18.attn.qkv.weight', 'encoder.blocks.18.attn.qkv.bias', 'encoder.blocks.18.attn.proj.weight', 'encoder.blocks.18.attn.proj.bias', 'encoder.blocks.18.ls1.gamma', 'encoder.blocks.18.norm2.weight', 'encoder.blocks.18.norm2.bias', 'encoder.blocks.18.mlp.fc1.weight', 'encoder.blocks.18.mlp.fc1.bias', 'encoder.blocks.18.mlp.fc2.weight', 'encoder.blocks.18.mlp.fc2.bias', 'encoder.blocks.18.ls2.gamma', 'encoder.blocks.19.norm1.weight', 'encoder.blocks.19.norm1.bias', 'encoder.blocks.19.attn.qkv.weight', 'encoder.blocks.19.attn.qkv.bias', 'encoder.blocks.19.attn.proj.weight', 'encoder.blocks.19.attn.proj.bias', 'encoder.blocks.19.ls1.gamma', 'encoder.blocks.19.norm2.weight', 'encoder.blocks.19.norm2.bias', 'encoder.blocks.19.mlp.fc1.weight', 'encoder.blocks.19.mlp.fc1.bias', 'encoder.blocks.19.mlp.fc2.weight', 'encoder.blocks.19.mlp.fc2.bias', 'encoder.blocks.19.ls2.gamma', 'encoder.blocks.20.norm1.weight', 'encoder.blocks.20.norm1.bias', 'encoder.blocks.20.attn.qkv.weight', 'encoder.blocks.20.attn.qkv.bias', 'encoder.blocks.20.attn.proj.weight', 'encoder.blocks.20.attn.proj.bias', 'encoder.blocks.20.ls1.gamma', 'encoder.blocks.20.norm2.weight', 'encoder.blocks.20.norm2.bias', 'encoder.blocks.20.mlp.fc1.weight', 'encoder.blocks.20.mlp.fc1.bias', 'encoder.blocks.20.mlp.fc2.weight', 'encoder.blocks.20.mlp.fc2.bias', 'encoder.blocks.20.ls2.gamma', 'encoder.blocks.21.norm1.weight', 'encoder.blocks.21.norm1.bias', 'encoder.blocks.21.attn.qkv.weight', 'encoder.blocks.21.attn.qkv.bias', 'encoder.blocks.21.attn.proj.weight', 'encoder.blocks.21.attn.proj.bias', 'encoder.blocks.21.ls1.gamma', 'encoder.blocks.21.norm2.weight', 'encoder.blocks.21.norm2.bias', 'encoder.blocks.21.mlp.fc1.weight', 'encoder.blocks.21.mlp.fc1.bias', 'encoder.blocks.21.mlp.fc2.weight', 'encoder.blocks.21.mlp.fc2.bias', 'encoder.blocks.21.ls2.gamma', 'encoder.blocks.22.norm1.weight', 'encoder.blocks.22.norm1.bias', 'encoder.blocks.22.attn.qkv.weight', 'encoder.blocks.22.attn.qkv.bias', 'encoder.blocks.22.attn.proj.weight', 'encoder.blocks.22.attn.proj.bias', 'encoder.blocks.22.ls1.gamma', 'encoder.blocks.22.norm2.weight', 'encoder.blocks.22.norm2.bias', 'encoder.blocks.22.mlp.fc1.weight', 'encoder.blocks.22.mlp.fc1.bias', 'encoder.blocks.22.mlp.fc2.weight', 'encoder.blocks.22.mlp.fc2.bias', 'encoder.blocks.22.ls2.gamma', 'encoder.blocks.23.norm1.weight', 'encoder.blocks.23.norm1.bias', 'encoder.blocks.23.attn.qkv.weight', 'encoder.blocks.23.attn.qkv.bias', 'encoder.blocks.23.attn.proj.weight', 'encoder.blocks.23.attn.proj.bias', 'encoder.blocks.23.ls1.gamma', 'encoder.blocks.23.norm2.weight', 'encoder.blocks.23.norm2.bias', 'encoder.blocks.23.mlp.fc1.weight', 'encoder.blocks.23.mlp.fc1.bias', 'encoder.blocks.23.mlp.fc2.weight', 'encoder.blocks.23.mlp.fc2.bias', 'encoder.blocks.23.ls2.gamma', 'encoder.norm.weight', 'encoder.norm.bias', 'decoder.0.norm1.weight', 'decoder.0.norm1.bias', 'decoder.0.attn.qkv.weight', 'decoder.0.attn.qkv.bias', 'decoder.0.attn.proj.weight', 'decoder.0.attn.proj.bias', 'decoder.0.attn.q_norm.weight', 'decoder.0.attn.q_norm.bias', 'decoder.0.attn.k_norm.weight', 'decoder.0.attn.k_norm.bias', 'decoder.0.ls1.gamma', 'decoder.0.norm2.weight', 'decoder.0.norm2.bias', 'decoder.0.mlp.fc1.weight', 'decoder.0.mlp.fc1.bias', 'decoder.0.mlp.fc2.weight', 'decoder.0.mlp.fc2.bias', 'decoder.0.ls2.gamma', 'decoder.1.norm1.weight', 'decoder.1.norm1.bias', 'decoder.1.attn.qkv.weight', 'decoder.1.attn.qkv.bias', 'decoder.1.attn.proj.weight', 'decoder.1.attn.proj.bias', 'decoder.1.attn.q_norm.weight', 'decoder.1.attn.q_norm.bias', 'decoder.1.attn.k_norm.weight', 'decoder.1.attn.k_norm.bias', 'decoder.1.ls1.gamma', 'decoder.1.norm2.weight', 'decoder.1.norm2.bias', 'decoder.1.mlp.fc1.weight', 'decoder.1.mlp.fc1.bias', 'decoder.1.mlp.fc2.weight', 'decoder.1.mlp.fc2.bias', 'decoder.1.ls2.gamma', 'decoder.2.norm1.weight', 'decoder.2.norm1.bias', 'decoder.2.attn.qkv.weight', 'decoder.2.attn.qkv.bias', 'decoder.2.attn.proj.weight', 'decoder.2.attn.proj.bias', 'decoder.2.attn.q_norm.weight', 'decoder.2.attn.q_norm.bias', 'decoder.2.attn.k_norm.weight', 'decoder.2.attn.k_norm.bias', 'decoder.2.ls1.gamma', 'decoder.2.norm2.weight', 'decoder.2.norm2.bias', 'decoder.2.mlp.fc1.weight', 'decoder.2.mlp.fc1.bias', 'decoder.2.mlp.fc2.weight', 'decoder.2.mlp.fc2.bias', 'decoder.2.ls2.gamma', 'decoder.3.norm1.weight', 'decoder.3.norm1.bias', 'decoder.3.attn.qkv.weight', 'decoder.3.attn.qkv.bias', 'decoder.3.attn.proj.weight', 'decoder.3.attn.proj.bias', 'decoder.3.attn.q_norm.weight', 'decoder.3.attn.q_norm.bias', 'decoder.3.attn.k_norm.weight', 'decoder.3.attn.k_norm.bias', 'decoder.3.ls1.gamma', 'decoder.3.norm2.weight', 'decoder.3.norm2.bias', 'decoder.3.mlp.fc1.weight', 'decoder.3.mlp.fc1.bias', 'decoder.3.mlp.fc2.weight', 'decoder.3.mlp.fc2.bias', 'decoder.3.ls2.gamma', 'decoder.4.norm1.weight', 'decoder.4.norm1.bias', 'decoder.4.attn.qkv.weight', 'decoder.4.attn.qkv.bias', 'decoder.4.attn.proj.weight', 'decoder.4.attn.proj.bias', 'decoder.4.attn.q_norm.weight', 'decoder.4.attn.q_norm.bias', 'decoder.4.attn.k_norm.weight', 'decoder.4.attn.k_norm.bias', 'decoder.4.ls1.gamma', 'decoder.4.norm2.weight', 'decoder.4.norm2.bias', 'decoder.4.mlp.fc1.weight', 'decoder.4.mlp.fc1.bias', 'decoder.4.mlp.fc2.weight', 'decoder.4.mlp.fc2.bias', 'decoder.4.ls2.gamma', 'decoder.5.norm1.weight', 'decoder.5.norm1.bias', 'decoder.5.attn.qkv.weight', 'decoder.5.attn.qkv.bias', 'decoder.5.attn.proj.weight', 'decoder.5.attn.proj.bias', 'decoder.5.attn.q_norm.weight', 'decoder.5.attn.q_norm.bias', 'decoder.5.attn.k_norm.weight', 'decoder.5.attn.k_norm.bias', 'decoder.5.ls1.gamma', 'decoder.5.norm2.weight', 'decoder.5.norm2.bias', 'decoder.5.mlp.fc1.weight', 'decoder.5.mlp.fc1.bias', 'decoder.5.mlp.fc2.weight', 'decoder.5.mlp.fc2.bias', 'decoder.5.ls2.gamma', 'decoder.6.norm1.weight', 'decoder.6.norm1.bias', 'decoder.6.attn.qkv.weight', 'decoder.6.attn.qkv.bias', 'decoder.6.attn.proj.weight', 'decoder.6.attn.proj.bias', 'decoder.6.attn.q_norm.weight', 'decoder.6.attn.q_norm.bias', 'decoder.6.attn.k_norm.weight', 'decoder.6.attn.k_norm.bias', 'decoder.6.ls1.gamma', 'decoder.6.norm2.weight', 'decoder.6.norm2.bias', 'decoder.6.mlp.fc1.weight', 'decoder.6.mlp.fc1.bias', 'decoder.6.mlp.fc2.weight', 'decoder.6.mlp.fc2.bias', 'decoder.6.ls2.gamma', 'decoder.7.norm1.weight', 'decoder.7.norm1.bias', 'decoder.7.attn.qkv.weight', 'decoder.7.attn.qkv.bias', 'decoder.7.attn.proj.weight', 'decoder.7.attn.proj.bias', 'decoder.7.attn.q_norm.weight', 'decoder.7.attn.q_norm.bias', 'decoder.7.attn.k_norm.weight', 'decoder.7.attn.k_norm.bias', 'decoder.7.ls1.gamma', 'decoder.7.norm2.weight', 'decoder.7.norm2.bias', 'decoder.7.mlp.fc1.weight', 'decoder.7.mlp.fc1.bias', 'decoder.7.mlp.fc2.weight', 'decoder.7.mlp.fc2.bias', 'decoder.7.ls2.gamma', 'decoder.8.norm1.weight', 'decoder.8.norm1.bias', 'decoder.8.attn.qkv.weight', 'decoder.8.attn.qkv.bias', 'decoder.8.attn.proj.weight', 'decoder.8.attn.proj.bias', 'decoder.8.attn.q_norm.weight', 'decoder.8.attn.q_norm.bias', 'decoder.8.attn.k_norm.weight', 'decoder.8.attn.k_norm.bias', 'decoder.8.ls1.gamma', 'decoder.8.norm2.weight', 'decoder.8.norm2.bias', 'decoder.8.mlp.fc1.weight', 'decoder.8.mlp.fc1.bias', 'decoder.8.mlp.fc2.weight', 'decoder.8.mlp.fc2.bias', 'decoder.8.ls2.gamma', 'decoder.9.norm1.weight', 'decoder.9.norm1.bias', 'decoder.9.attn.qkv.weight', 'decoder.9.attn.qkv.bias', 'decoder.9.attn.proj.weight', 'decoder.9.attn.proj.bias', 'decoder.9.attn.q_norm.weight', 'decoder.9.attn.q_norm.bias', 'decoder.9.attn.k_norm.weight', 'decoder.9.attn.k_norm.bias', 'decoder.9.ls1.gamma', 'decoder.9.norm2.weight', 'decoder.9.norm2.bias', 'decoder.9.mlp.fc1.weight', 'decoder.9.mlp.fc1.bias', 'decoder.9.mlp.fc2.weight', 'decoder.9.mlp.fc2.bias', 'decoder.9.ls2.gamma', 'decoder.10.norm1.weight', 'decoder.10.norm1.bias', 'decoder.10.attn.qkv.weight', 'decoder.10.attn.qkv.bias', 'decoder.10.attn.proj.weight', 'decoder.10.attn.proj.bias', 'decoder.10.attn.q_norm.weight', 'decoder.10.attn.q_norm.bias', 'decoder.10.attn.k_norm.weight', 'decoder.10.attn.k_norm.bias', 'decoder.10.ls1.gamma', 'decoder.10.norm2.weight', 'decoder.10.norm2.bias', 'decoder.10.mlp.fc1.weight', 'decoder.10.mlp.fc1.bias', 'decoder.10.mlp.fc2.weight', 'decoder.10.mlp.fc2.bias', 'decoder.10.ls2.gamma', 'decoder.11.norm1.weight', 'decoder.11.norm1.bias', 'decoder.11.attn.qkv.weight', 'decoder.11.attn.qkv.bias', 'decoder.11.attn.proj.weight', 'decoder.11.attn.proj.bias', 'decoder.11.attn.q_norm.weight', 'decoder.11.attn.q_norm.bias', 'decoder.11.attn.k_norm.weight', 'decoder.11.attn.k_norm.bias', 'decoder.11.ls1.gamma', 'decoder.11.norm2.weight', 'decoder.11.norm2.bias', 'decoder.11.mlp.fc1.weight', 'decoder.11.mlp.fc1.bias', 'decoder.11.mlp.fc2.weight', 'decoder.11.mlp.fc2.bias', 'decoder.11.ls2.gamma', 'decoder.12.norm1.weight', 'decoder.12.norm1.bias', 'decoder.12.attn.qkv.weight', 'decoder.12.attn.qkv.bias', 'decoder.12.attn.proj.weight', 'decoder.12.attn.proj.bias', 'decoder.12.attn.q_norm.weight', 'decoder.12.attn.q_norm.bias', 'decoder.12.attn.k_norm.weight', 'decoder.12.attn.k_norm.bias', 'decoder.12.ls1.gamma', 'decoder.12.norm2.weight', 'decoder.12.norm2.bias', 'decoder.12.mlp.fc1.weight', 'decoder.12.mlp.fc1.bias', 'decoder.12.mlp.fc2.weight', 'decoder.12.mlp.fc2.bias', 'decoder.12.ls2.gamma', 'decoder.13.norm1.weight', 'decoder.13.norm1.bias', 'decoder.13.attn.qkv.weight', 'decoder.13.attn.qkv.bias', 'decoder.13.attn.proj.weight', 'decoder.13.attn.proj.bias', 'decoder.13.attn.q_norm.weight', 'decoder.13.attn.q_norm.bias', 'decoder.13.attn.k_norm.weight', 'decoder.13.attn.k_norm.bias', 'decoder.13.ls1.gamma', 'decoder.13.norm2.weight', 'decoder.13.norm2.bias', 'decoder.13.mlp.fc1.weight', 'decoder.13.mlp.fc1.bias', 'decoder.13.mlp.fc2.weight', 'decoder.13.mlp.fc2.bias', 'decoder.13.ls2.gamma', 'decoder.14.norm1.weight', 'decoder.14.norm1.bias', 'decoder.14.attn.qkv.weight', 'decoder.14.attn.qkv.bias', 'decoder.14.attn.proj.weight', 'decoder.14.attn.proj.bias', 'decoder.14.attn.q_norm.weight', 'decoder.14.attn.q_norm.bias', 'decoder.14.attn.k_norm.weight', 'decoder.14.attn.k_norm.bias', 'decoder.14.ls1.gamma', 'decoder.14.norm2.weight', 'decoder.14.norm2.bias', 'decoder.14.mlp.fc1.weight', 'decoder.14.mlp.fc1.bias', 'decoder.14.mlp.fc2.weight', 'decoder.14.mlp.fc2.bias', 'decoder.14.ls2.gamma', 'decoder.15.norm1.weight', 'decoder.15.norm1.bias', 'decoder.15.attn.qkv.weight', 'decoder.15.attn.qkv.bias', 'decoder.15.attn.proj.weight', 'decoder.15.attn.proj.bias', 'decoder.15.attn.q_norm.weight', 'decoder.15.attn.q_norm.bias', 'decoder.15.attn.k_norm.weight', 'decoder.15.attn.k_norm.bias', 'decoder.15.ls1.gamma', 'decoder.15.norm2.weight', 'decoder.15.norm2.bias', 'decoder.15.mlp.fc1.weight', 'decoder.15.mlp.fc1.bias', 'decoder.15.mlp.fc2.weight', 'decoder.15.mlp.fc2.bias', 'decoder.15.ls2.gamma', 'decoder.16.norm1.weight', 'decoder.16.norm1.bias', 'decoder.16.attn.qkv.weight', 'decoder.16.attn.qkv.bias', 'decoder.16.attn.proj.weight', 'decoder.16.attn.proj.bias', 'decoder.16.attn.q_norm.weight', 'decoder.16.attn.q_norm.bias', 'decoder.16.attn.k_norm.weight', 'decoder.16.attn.k_norm.bias', 'decoder.16.ls1.gamma', 'decoder.16.norm2.weight', 'decoder.16.norm2.bias', 'decoder.16.mlp.fc1.weight', 'decoder.16.mlp.fc1.bias', 'decoder.16.mlp.fc2.weight', 'decoder.16.mlp.fc2.bias', 'decoder.16.ls2.gamma', 'decoder.17.norm1.weight', 'decoder.17.norm1.bias', 'decoder.17.attn.qkv.weight', 'decoder.17.attn.qkv.bias', 'decoder.17.attn.proj.weight', 'decoder.17.attn.proj.bias', 'decoder.17.attn.q_norm.weight', 'decoder.17.attn.q_norm.bias', 'decoder.17.attn.k_norm.weight', 'decoder.17.attn.k_norm.bias', 'decoder.17.ls1.gamma', 'decoder.17.norm2.weight', 'decoder.17.norm2.bias', 'decoder.17.mlp.fc1.weight', 'decoder.17.mlp.fc1.bias', 'decoder.17.mlp.fc2.weight', 'decoder.17.mlp.fc2.bias', 'decoder.17.ls2.gamma', 'decoder.18.norm1.weight', 'decoder.18.norm1.bias', 'decoder.18.attn.qkv.weight', 'decoder.18.attn.qkv.bias', 'decoder.18.attn.proj.weight', 'decoder.18.attn.proj.bias', 'decoder.18.attn.q_norm.weight', 'decoder.18.attn.q_norm.bias', 'decoder.18.attn.k_norm.weight', 'decoder.18.attn.k_norm.bias', 'decoder.18.ls1.gamma', 'decoder.18.norm2.weight', 'decoder.18.norm2.bias', 'decoder.18.mlp.fc1.weight', 'decoder.18.mlp.fc1.bias', 'decoder.18.mlp.fc2.weight', 'decoder.18.mlp.fc2.bias', 'decoder.18.ls2.gamma', 'decoder.19.norm1.weight', 'decoder.19.norm1.bias', 'decoder.19.attn.qkv.weight', 'decoder.19.attn.qkv.bias', 'decoder.19.attn.proj.weight', 'decoder.19.attn.proj.bias', 'decoder.19.attn.q_norm.weight', 'decoder.19.attn.q_norm.bias', 'decoder.19.attn.k_norm.weight', 'decoder.19.attn.k_norm.bias', 'decoder.19.ls1.gamma', 'decoder.19.norm2.weight', 'decoder.19.norm2.bias', 'decoder.19.mlp.fc1.weight', 'decoder.19.mlp.fc1.bias', 'decoder.19.mlp.fc2.weight', 'decoder.19.mlp.fc2.bias', 'decoder.19.ls2.gamma', 'decoder.20.norm1.weight', 'decoder.20.norm1.bias', 'decoder.20.attn.qkv.weight', 'decoder.20.attn.qkv.bias', 'decoder.20.attn.proj.weight', 'decoder.20.attn.proj.bias', 'decoder.20.attn.q_norm.weight', 'decoder.20.attn.q_norm.bias', 'decoder.20.attn.k_norm.weight', 'decoder.20.attn.k_norm.bias', 'decoder.20.ls1.gamma', 'decoder.20.norm2.weight', 'decoder.20.norm2.bias', 'decoder.20.mlp.fc1.weight', 'decoder.20.mlp.fc1.bias', 'decoder.20.mlp.fc2.weight', 'decoder.20.mlp.fc2.bias', 'decoder.20.ls2.gamma', 'decoder.21.norm1.weight', 'decoder.21.norm1.bias', 'decoder.21.attn.qkv.weight', 'decoder.21.attn.qkv.bias', 'decoder.21.attn.proj.weight', 'decoder.21.attn.proj.bias', 'decoder.21.attn.q_norm.weight', 'decoder.21.attn.q_norm.bias', 'decoder.21.attn.k_norm.weight', 'decoder.21.attn.k_norm.bias', 'decoder.21.ls1.gamma', 'decoder.21.norm2.weight', 'decoder.21.norm2.bias', 'decoder.21.mlp.fc1.weight', 'decoder.21.mlp.fc1.bias', 'decoder.21.mlp.fc2.weight', 'decoder.21.mlp.fc2.bias', 'decoder.21.ls2.gamma', 'decoder.22.norm1.weight', 'decoder.22.norm1.bias', 'decoder.22.attn.qkv.weight', 'decoder.22.attn.qkv.bias', 'decoder.22.attn.proj.weight', 'decoder.22.attn.proj.bias', 'decoder.22.attn.q_norm.weight', 'decoder.22.attn.q_norm.bias', 'decoder.22.attn.k_norm.weight', 'decoder.22.attn.k_norm.bias', 'decoder.22.ls1.gamma', 'decoder.22.norm2.weight', 'decoder.22.norm2.bias', 'decoder.22.mlp.fc1.weight', 'decoder.22.mlp.fc1.bias', 'decoder.22.mlp.fc2.weight', 'decoder.22.mlp.fc2.bias', 'decoder.22.ls2.gamma', 'decoder.23.norm1.weight', 'decoder.23.norm1.bias', 'decoder.23.attn.qkv.weight', 'decoder.23.attn.qkv.bias', 'decoder.23.attn.proj.weight', 'decoder.23.attn.proj.bias', 'decoder.23.attn.q_norm.weight', 'decoder.23.attn.q_norm.bias', 'decoder.23.attn.k_norm.weight', 'decoder.23.attn.k_norm.bias', 'decoder.23.ls1.gamma', 'decoder.23.norm2.weight', 'decoder.23.norm2.bias', 'decoder.23.mlp.fc1.weight', 'decoder.23.mlp.fc1.bias', 'decoder.23.mlp.fc2.weight', 'decoder.23.mlp.fc2.bias', 'decoder.23.ls2.gamma', 'decoder.24.norm1.weight', 'decoder.24.norm1.bias', 'decoder.24.attn.qkv.weight', 'decoder.24.attn.qkv.bias', 'decoder.24.attn.proj.weight', 'decoder.24.attn.proj.bias', 'decoder.24.attn.q_norm.weight', 'decoder.24.attn.q_norm.bias', 'decoder.24.attn.k_norm.weight', 'decoder.24.attn.k_norm.bias', 'decoder.24.ls1.gamma', 'decoder.24.norm2.weight', 'decoder.24.norm2.bias', 'decoder.24.mlp.fc1.weight', 'decoder.24.mlp.fc1.bias', 'decoder.24.mlp.fc2.weight', 'decoder.24.mlp.fc2.bias', 'decoder.24.ls2.gamma', 'decoder.25.norm1.weight', 'decoder.25.norm1.bias', 'decoder.25.attn.qkv.weight', 'decoder.25.attn.qkv.bias', 'decoder.25.attn.proj.weight', 'decoder.25.attn.proj.bias', 'decoder.25.attn.q_norm.weight', 'decoder.25.attn.q_norm.bias', 'decoder.25.attn.k_norm.weight', 'decoder.25.attn.k_norm.bias', 'decoder.25.ls1.gamma', 'decoder.25.norm2.weight', 'decoder.25.norm2.bias', 'decoder.25.mlp.fc1.weight', 'decoder.25.mlp.fc1.bias', 'decoder.25.mlp.fc2.weight', 'decoder.25.mlp.fc2.bias', 'decoder.25.ls2.gamma', 'decoder.26.norm1.weight', 'decoder.26.norm1.bias', 'decoder.26.attn.qkv.weight', 'decoder.26.attn.qkv.bias', 'decoder.26.attn.proj.weight', 'decoder.26.attn.proj.bias', 'decoder.26.attn.q_norm.weight', 'decoder.26.attn.q_norm.bias', 'decoder.26.attn.k_norm.weight', 'decoder.26.attn.k_norm.bias', 'decoder.26.ls1.gamma', 'decoder.26.norm2.weight', 'decoder.26.norm2.bias', 'decoder.26.mlp.fc1.weight', 'decoder.26.mlp.fc1.bias', 'decoder.26.mlp.fc2.weight', 'decoder.26.mlp.fc2.bias', 'decoder.26.ls2.gamma', 'decoder.27.norm1.weight', 'decoder.27.norm1.bias', 'decoder.27.attn.qkv.weight', 'decoder.27.attn.qkv.bias', 'decoder.27.attn.proj.weight', 'decoder.27.attn.proj.bias', 'decoder.27.attn.q_norm.weight', 'decoder.27.attn.q_norm.bias', 'decoder.27.attn.k_norm.weight', 'decoder.27.attn.k_norm.bias', 'decoder.27.ls1.gamma', 'decoder.27.norm2.weight', 'decoder.27.norm2.bias', 'decoder.27.mlp.fc1.weight', 'decoder.27.mlp.fc1.bias', 'decoder.27.mlp.fc2.weight', 'decoder.27.mlp.fc2.bias', 'decoder.27.ls2.gamma', 'decoder.28.norm1.weight', 'decoder.28.norm1.bias', 'decoder.28.attn.qkv.weight', 'decoder.28.attn.qkv.bias', 'decoder.28.attn.proj.weight', 'decoder.28.attn.proj.bias', 'decoder.28.attn.q_norm.weight', 'decoder.28.attn.q_norm.bias', 'decoder.28.attn.k_norm.weight', 'decoder.28.attn.k_norm.bias', 'decoder.28.ls1.gamma', 'decoder.28.norm2.weight', 'decoder.28.norm2.bias', 'decoder.28.mlp.fc1.weight', 'decoder.28.mlp.fc1.bias', 'decoder.28.mlp.fc2.weight', 'decoder.28.mlp.fc2.bias', 'decoder.28.ls2.gamma', 'decoder.29.norm1.weight', 'decoder.29.norm1.bias', 'decoder.29.attn.qkv.weight', 'decoder.29.attn.qkv.bias', 'decoder.29.attn.proj.weight', 'decoder.29.attn.proj.bias', 'decoder.29.attn.q_norm.weight', 'decoder.29.attn.q_norm.bias', 'decoder.29.attn.k_norm.weight', 'decoder.29.attn.k_norm.bias', 'decoder.29.ls1.gamma', 'decoder.29.norm2.weight', 'decoder.29.norm2.bias', 'decoder.29.mlp.fc1.weight', 'decoder.29.mlp.fc1.bias', 'decoder.29.mlp.fc2.weight', 'decoder.29.mlp.fc2.bias', 'decoder.29.ls2.gamma', 'decoder.30.norm1.weight', 'decoder.30.norm1.bias', 'decoder.30.attn.qkv.weight', 'decoder.30.attn.qkv.bias', 'decoder.30.attn.proj.weight', 'decoder.30.attn.proj.bias', 'decoder.30.attn.q_norm.weight', 'decoder.30.attn.q_norm.bias', 'decoder.30.attn.k_norm.weight', 'decoder.30.attn.k_norm.bias', 'decoder.30.ls1.gamma', 'decoder.30.norm2.weight', 'decoder.30.norm2.bias', 'decoder.30.mlp.fc1.weight', 'decoder.30.mlp.fc1.bias', 'decoder.30.mlp.fc2.weight', 'decoder.30.mlp.fc2.bias', 'decoder.30.ls2.gamma', 'decoder.31.norm1.weight', 'decoder.31.norm1.bias', 'decoder.31.attn.qkv.weight', 'decoder.31.attn.qkv.bias', 'decoder.31.attn.proj.weight', 'decoder.31.attn.proj.bias', 'decoder.31.attn.q_norm.weight', 'decoder.31.attn.q_norm.bias', 'decoder.31.attn.k_norm.weight', 'decoder.31.attn.k_norm.bias', 'decoder.31.ls1.gamma', 'decoder.31.norm2.weight', 'decoder.31.norm2.bias', 'decoder.31.mlp.fc1.weight', 'decoder.31.mlp.fc1.bias', 'decoder.31.mlp.fc2.weight', 'decoder.31.mlp.fc2.bias', 'decoder.31.ls2.gamma', 'decoder.32.norm1.weight', 'decoder.32.norm1.bias', 'decoder.32.attn.qkv.weight', 'decoder.32.attn.qkv.bias', 'decoder.32.attn.proj.weight', 'decoder.32.attn.proj.bias', 'decoder.32.attn.q_norm.weight', 'decoder.32.attn.q_norm.bias', 'decoder.32.attn.k_norm.weight', 'decoder.32.attn.k_norm.bias', 'decoder.32.ls1.gamma', 'decoder.32.norm2.weight', 'decoder.32.norm2.bias', 'decoder.32.mlp.fc1.weight', 'decoder.32.mlp.fc1.bias', 'decoder.32.mlp.fc2.weight', 'decoder.32.mlp.fc2.bias', 'decoder.32.ls2.gamma', 'decoder.33.norm1.weight', 'decoder.33.norm1.bias', 'decoder.33.attn.qkv.weight', 'decoder.33.attn.qkv.bias', 'decoder.33.attn.proj.weight', 'decoder.33.attn.proj.bias', 'decoder.33.attn.q_norm.weight', 'decoder.33.attn.q_norm.bias', 'decoder.33.attn.k_norm.weight', 'decoder.33.attn.k_norm.bias', 'decoder.33.ls1.gamma', 'decoder.33.norm2.weight', 'decoder.33.norm2.bias', 'decoder.33.mlp.fc1.weight', 'decoder.33.mlp.fc1.bias', 'decoder.33.mlp.fc2.weight', 'decoder.33.mlp.fc2.bias', 'decoder.33.ls2.gamma', 'decoder.34.norm1.weight', 'decoder.34.norm1.bias', 'decoder.34.attn.qkv.weight', 'decoder.34.attn.qkv.bias', 'decoder.34.attn.proj.weight', 'decoder.34.attn.proj.bias', 'decoder.34.attn.q_norm.weight', 'decoder.34.attn.q_norm.bias', 'decoder.34.attn.k_norm.weight', 'decoder.34.attn.k_norm.bias', 'decoder.34.ls1.gamma', 'decoder.34.norm2.weight', 'decoder.34.norm2.bias', 'decoder.34.mlp.fc1.weight', 'decoder.34.mlp.fc1.bias', 'decoder.34.mlp.fc2.weight', 'decoder.34.mlp.fc2.bias', 'decoder.34.ls2.gamma', 'decoder.35.norm1.weight', 'decoder.35.norm1.bias', 'decoder.35.attn.qkv.weight', 'decoder.35.attn.qkv.bias', 'decoder.35.attn.proj.weight', 'decoder.35.attn.proj.bias', 'decoder.35.attn.q_norm.weight', 'decoder.35.attn.q_norm.bias', 'decoder.35.attn.k_norm.weight', 'decoder.35.attn.k_norm.bias', 'decoder.35.ls1.gamma', 'decoder.35.norm2.weight', 'decoder.35.norm2.bias', 'decoder.35.mlp.fc1.weight', 'decoder.35.mlp.fc1.bias', 'decoder.35.mlp.fc2.weight', 'decoder.35.mlp.fc2.bias', 'decoder.35.ls2.gamma', 'point_decoder.projects.weight', 'point_decoder.projects.bias', 'point_decoder.blocks.0.norm1.weight', 'point_decoder.blocks.0.norm1.bias', 'point_decoder.blocks.0.attn.qkv.weight', 'point_decoder.blocks.0.attn.qkv.bias', 'point_decoder.blocks.0.attn.proj.weight', 'point_decoder.blocks.0.attn.proj.bias', 'point_decoder.blocks.0.norm2.weight', 'point_decoder.blocks.0.norm2.bias', 'point_decoder.blocks.0.mlp.fc1.weight', 'point_decoder.blocks.0.mlp.fc1.bias', 'point_decoder.blocks.0.mlp.fc2.weight', 'point_decoder.blocks.0.mlp.fc2.bias', 'point_decoder.blocks.1.norm1.weight', 'point_decoder.blocks.1.norm1.bias', 'point_decoder.blocks.1.attn.qkv.weight', 'point_decoder.blocks.1.attn.qkv.bias', 'point_decoder.blocks.1.attn.proj.weight', 'point_decoder.blocks.1.attn.proj.bias', 'point_decoder.blocks.1.norm2.weight', 'point_decoder.blocks.1.norm2.bias', 'point_decoder.blocks.1.mlp.fc1.weight', 'point_decoder.blocks.1.mlp.fc1.bias', 'point_decoder.blocks.1.mlp.fc2.weight', 'point_decoder.blocks.1.mlp.fc2.bias', 'point_decoder.blocks.2.norm1.weight', 'point_decoder.blocks.2.norm1.bias', 'point_decoder.blocks.2.attn.qkv.weight', 'point_decoder.blocks.2.attn.qkv.bias', 'point_decoder.blocks.2.attn.proj.weight', 'point_decoder.blocks.2.attn.proj.bias', 'point_decoder.blocks.2.norm2.weight', 'point_decoder.blocks.2.norm2.bias', 'point_decoder.blocks.2.mlp.fc1.weight', 'point_decoder.blocks.2.mlp.fc1.bias', 'point_decoder.blocks.2.mlp.fc2.weight', 'point_decoder.blocks.2.mlp.fc2.bias', 'point_decoder.blocks.3.norm1.weight', 'point_decoder.blocks.3.norm1.bias', 'point_decoder.blocks.3.attn.qkv.weight', 'point_decoder.blocks.3.attn.qkv.bias', 'point_decoder.blocks.3.attn.proj.weight', 'point_decoder.blocks.3.attn.proj.bias', 'point_decoder.blocks.3.norm2.weight', 'point_decoder.blocks.3.norm2.bias', 'point_decoder.blocks.3.mlp.fc1.weight', 'point_decoder.blocks.3.mlp.fc1.bias', 'point_decoder.blocks.3.mlp.fc2.weight', 'point_decoder.blocks.3.mlp.fc2.bias', 'point_decoder.blocks.4.norm1.weight', 'point_decoder.blocks.4.norm1.bias', 'point_decoder.blocks.4.attn.qkv.weight', 'point_decoder.blocks.4.attn.qkv.bias', 'point_decoder.blocks.4.attn.proj.weight', 'point_decoder.blocks.4.attn.proj.bias', 'point_decoder.blocks.4.norm2.weight', 'point_decoder.blocks.4.norm2.bias', 'point_decoder.blocks.4.mlp.fc1.weight', 'point_decoder.blocks.4.mlp.fc1.bias', 'point_decoder.blocks.4.mlp.fc2.weight', 'point_decoder.blocks.4.mlp.fc2.bias', 'point_decoder.linear_out.weight', 'point_decoder.linear_out.bias', 'point_head.proj.weight', 'point_head.proj.bias', 'conf_decoder.projects.weight', 'conf_decoder.projects.bias', 'conf_decoder.blocks.0.norm1.weight', 'conf_decoder.blocks.0.norm1.bias', 'conf_decoder.blocks.0.attn.qkv.weight', 'conf_decoder.blocks.0.attn.qkv.bias', 'conf_decoder.blocks.0.attn.proj.weight', 'conf_decoder.blocks.0.attn.proj.bias', 'conf_decoder.blocks.0.norm2.weight', 'conf_decoder.blocks.0.norm2.bias', 'conf_decoder.blocks.0.mlp.fc1.weight', 'conf_decoder.blocks.0.mlp.fc1.bias', 'conf_decoder.blocks.0.mlp.fc2.weight', 'conf_decoder.blocks.0.mlp.fc2.bias', 'conf_decoder.blocks.1.norm1.weight', 'conf_decoder.blocks.1.norm1.bias', 'conf_decoder.blocks.1.attn.qkv.weight', 'conf_decoder.blocks.1.attn.qkv.bias', 'conf_decoder.blocks.1.attn.proj.weight', 'conf_decoder.blocks.1.attn.proj.bias', 'conf_decoder.blocks.1.norm2.weight', 'conf_decoder.blocks.1.norm2.bias', 'conf_decoder.blocks.1.mlp.fc1.weight', 'conf_decoder.blocks.1.mlp.fc1.bias', 'conf_decoder.blocks.1.mlp.fc2.weight', 'conf_decoder.blocks.1.mlp.fc2.bias', 'conf_decoder.blocks.2.norm1.weight', 'conf_decoder.blocks.2.norm1.bias', 'conf_decoder.blocks.2.attn.qkv.weight', 'conf_decoder.blocks.2.attn.qkv.bias', 'conf_decoder.blocks.2.attn.proj.weight', 'conf_decoder.blocks.2.attn.proj.bias', 'conf_decoder.blocks.2.norm2.weight', 'conf_decoder.blocks.2.norm2.bias', 'conf_decoder.blocks.2.mlp.fc1.weight', 'conf_decoder.blocks.2.mlp.fc1.bias', 'conf_decoder.blocks.2.mlp.fc2.weight', 'conf_decoder.blocks.2.mlp.fc2.bias', 'conf_decoder.blocks.3.norm1.weight', 'conf_decoder.blocks.3.norm1.bias', 'conf_decoder.blocks.3.attn.qkv.weight', 'conf_decoder.blocks.3.attn.qkv.bias', 'conf_decoder.blocks.3.attn.proj.weight', 'conf_decoder.blocks.3.attn.proj.bias', 'conf_decoder.blocks.3.norm2.weight', 'conf_decoder.blocks.3.norm2.bias', 'conf_decoder.blocks.3.mlp.fc1.weight', 'conf_decoder.blocks.3.mlp.fc1.bias', 'conf_decoder.blocks.3.mlp.fc2.weight', 'conf_decoder.blocks.3.mlp.fc2.bias', 'conf_decoder.blocks.4.norm1.weight', 'conf_decoder.blocks.4.norm1.bias', 'conf_decoder.blocks.4.attn.qkv.weight', 'conf_decoder.blocks.4.attn.qkv.bias', 'conf_decoder.blocks.4.attn.proj.weight', 'conf_decoder.blocks.4.attn.proj.bias', 'conf_decoder.blocks.4.norm2.weight', 'conf_decoder.blocks.4.norm2.bias', 'conf_decoder.blocks.4.mlp.fc1.weight', 'conf_decoder.blocks.4.mlp.fc1.bias', 'conf_decoder.blocks.4.mlp.fc2.weight', 'conf_decoder.blocks.4.mlp.fc2.bias', 'conf_decoder.linear_out.weight', 'conf_decoder.linear_out.bias', 'conf_head.proj.weight', 'conf_head.proj.bias', 'camera_decoder.projects.weight', 'camera_decoder.projects.bias', 'camera_decoder.blocks.0.norm1.weight', 'camera_decoder.blocks.0.norm1.bias', 'camera_decoder.blocks.0.attn.qkv.weight', 'camera_decoder.blocks.0.attn.qkv.bias', 'camera_decoder.blocks.0.attn.proj.weight', 'camera_decoder.blocks.0.attn.proj.bias', 'camera_decoder.blocks.0.norm2.weight', 'camera_decoder.blocks.0.norm2.bias', 'camera_decoder.blocks.0.mlp.fc1.weight', 'camera_decoder.blocks.0.mlp.fc1.bias', 'camera_decoder.blocks.0.mlp.fc2.weight', 'camera_decoder.blocks.0.mlp.fc2.bias', 'camera_decoder.blocks.1.norm1.weight', 'camera_decoder.blocks.1.norm1.bias', 'camera_decoder.blocks.1.attn.qkv.weight', 'camera_decoder.blocks.1.attn.qkv.bias', 'camera_decoder.blocks.1.attn.proj.weight', 'camera_decoder.blocks.1.attn.proj.bias', 'camera_decoder.blocks.1.norm2.weight', 'camera_decoder.blocks.1.norm2.bias', 'camera_decoder.blocks.1.mlp.fc1.weight', 'camera_decoder.blocks.1.mlp.fc1.bias', 'camera_decoder.blocks.1.mlp.fc2.weight', 'camera_decoder.blocks.1.mlp.fc2.bias', 'camera_decoder.blocks.2.norm1.weight', 'camera_decoder.blocks.2.norm1.bias', 'camera_decoder.blocks.2.attn.qkv.weight', 'camera_decoder.blocks.2.attn.qkv.bias', 'camera_decoder.blocks.2.attn.proj.weight', 'camera_decoder.blocks.2.attn.proj.bias', 'camera_decoder.blocks.2.norm2.weight', 'camera_decoder.blocks.2.norm2.bias', 'camera_decoder.blocks.2.mlp.fc1.weight', 'camera_decoder.blocks.2.mlp.fc1.bias', 'camera_decoder.blocks.2.mlp.fc2.weight', 'camera_decoder.blocks.2.mlp.fc2.bias', 'camera_decoder.blocks.3.norm1.weight', 'camera_decoder.blocks.3.norm1.bias', 'camera_decoder.blocks.3.attn.qkv.weight', 'camera_decoder.blocks.3.attn.qkv.bias', 'camera_decoder.blocks.3.attn.proj.weight', 'camera_decoder.blocks.3.attn.proj.bias', 'camera_decoder.blocks.3.norm2.weight', 'camera_decoder.blocks.3.norm2.bias', 'camera_decoder.blocks.3.mlp.fc1.weight', 'camera_decoder.blocks.3.mlp.fc1.bias', 'camera_decoder.blocks.3.mlp.fc2.weight', 'camera_decoder.blocks.3.mlp.fc2.bias', 'camera_decoder.blocks.4.norm1.weight', 'camera_decoder.blocks.4.norm1.bias', 'camera_decoder.blocks.4.attn.qkv.weight', 'camera_decoder.blocks.4.attn.qkv.bias', 'camera_decoder.blocks.4.attn.proj.weight', 'camera_decoder.blocks.4.attn.proj.bias', 'camera_decoder.blocks.4.norm2.weight', 'camera_decoder.blocks.4.norm2.bias', 'camera_decoder.blocks.4.mlp.fc1.weight', 'camera_decoder.blocks.4.mlp.fc1.bias', 'camera_decoder.blocks.4.mlp.fc2.weight', 'camera_decoder.blocks.4.mlp.fc2.bias', 'camera_decoder.linear_out.weight', 'camera_decoder.linear_out.bias', 'camera_head.res_conv.0.res_conv1.weight', 'camera_head.res_conv.0.res_conv1.bias', 'camera_head.res_conv.0.res_conv2.weight', 'camera_head.res_conv.0.res_conv2.bias', 'camera_head.res_conv.0.res_conv3.weight', 'camera_head.res_conv.0.res_conv3.bias', 'camera_head.res_conv.1.res_conv1.weight', 'camera_head.res_conv.1.res_conv1.bias', 'camera_head.res_conv.1.res_conv2.weight', 'camera_head.res_conv.1.res_conv2.bias', 'camera_head.res_conv.1.res_conv3.weight', 'camera_head.res_conv.1.res_conv3.bias', 'camera_head.more_mlps.0.weight', 'camera_head.more_mlps.0.bias', 'camera_head.more_mlps.2.weight', 'camera_head.more_mlps.2.bias', 'camera_head.fc_t.weight', 'camera_head.fc_t.bias', 'camera_head.fc_rot.weight', 'camera_head.fc_rot.bias'], unexpected_keys=['module.register_token', 'module.image_mean', 'module.image_std', 'module.encoder.cls_token', 'module.encoder.pos_embed', 'module.encoder.register_tokens', 'module.encoder.patch_embed.proj.weight', 'module.encoder.patch_embed.proj.bias', 'module.encoder.blocks.0.norm1.weight', 'module.encoder.blocks.0.norm1.bias', 'module.encoder.blocks.0.attn.qkv.weight', 'module.encoder.blocks.0.attn.qkv.bias', 'module.encoder.blocks.0.attn.proj.weight', 'module.encoder.blocks.0.attn.proj.bias', 'module.encoder.blocks.0.ls1.gamma', 'module.encoder.blocks.0.norm2.weight', 'module.encoder.blocks.0.norm2.bias', 'module.encoder.blocks.0.mlp.fc1.weight', 'module.encoder.blocks.0.mlp.fc1.bias', 'module.encoder.blocks.0.mlp.fc2.weight', 'module.encoder.blocks.0.mlp.fc2.bias', 'module.encoder.blocks.0.ls2.gamma', 'module.encoder.blocks.1.norm1.weight', 'module.encoder.blocks.1.norm1.bias', 'module.encoder.blocks.1.attn.qkv.weight', 'module.encoder.blocks.1.attn.qkv.bias', 'module.encoder.blocks.1.attn.proj.weight', 'module.encoder.blocks.1.attn.proj.bias', 'module.encoder.blocks.1.ls1.gamma', 'module.encoder.blocks.1.norm2.weight', 'module.encoder.blocks.1.norm2.bias', 'module.encoder.blocks.1.mlp.fc1.weight', 'module.encoder.blocks.1.mlp.fc1.bias', 'module.encoder.blocks.1.mlp.fc2.weight', 'module.encoder.blocks.1.mlp.fc2.bias', 'module.encoder.blocks.1.ls2.gamma', 'module.encoder.blocks.2.norm1.weight', 'module.encoder.blocks.2.norm1.bias', 'module.encoder.blocks.2.attn.qkv.weight', 'module.encoder.blocks.2.attn.qkv.bias', 'module.encoder.blocks.2.attn.proj.weight', 'module.encoder.blocks.2.attn.proj.bias', 'module.encoder.blocks.2.ls1.gamma', 'module.encoder.blocks.2.norm2.weight', 'module.encoder.blocks.2.norm2.bias', 'module.encoder.blocks.2.mlp.fc1.weight', 'module.encoder.blocks.2.mlp.fc1.bias', 'module.encoder.blocks.2.mlp.fc2.weight', 'module.encoder.blocks.2.mlp.fc2.bias', 'module.encoder.blocks.2.ls2.gamma', 'module.encoder.blocks.3.norm1.weight', 'module.encoder.blocks.3.norm1.bias', 'module.encoder.blocks.3.attn.qkv.weight', 'module.encoder.blocks.3.attn.qkv.bias', 'module.encoder.blocks.3.attn.proj.weight', 'module.encoder.blocks.3.attn.proj.bias', 'module.encoder.blocks.3.ls1.gamma', 'module.encoder.blocks.3.norm2.weight', 'module.encoder.blocks.3.norm2.bias', 'module.encoder.blocks.3.mlp.fc1.weight', 'module.encoder.blocks.3.mlp.fc1.bias', 'module.encoder.blocks.3.mlp.fc2.weight', 'module.encoder.blocks.3.mlp.fc2.bias', 'module.encoder.blocks.3.ls2.gamma', 'module.encoder.blocks.4.norm1.weight', 'module.encoder.blocks.4.norm1.bias', 'module.encoder.blocks.4.attn.qkv.weight', 'module.encoder.blocks.4.attn.qkv.bias', 'module.encoder.blocks.4.attn.proj.weight', 'module.encoder.blocks.4.attn.proj.bias', 'module.encoder.blocks.4.ls1.gamma', 'module.encoder.blocks.4.norm2.weight', 'module.encoder.blocks.4.norm2.bias', 'module.encoder.blocks.4.mlp.fc1.weight', 'module.encoder.blocks.4.mlp.fc1.bias', 'module.encoder.blocks.4.mlp.fc2.weight', 'module.encoder.blocks.4.mlp.fc2.bias', 'module.encoder.blocks.4.ls2.gamma', 'module.encoder.blocks.5.norm1.weight', 'module.encoder.blocks.5.norm1.bias', 'module.encoder.blocks.5.attn.qkv.weight', 'module.encoder.blocks.5.attn.qkv.bias', 'module.encoder.blocks.5.attn.proj.weight', 'module.encoder.blocks.5.attn.proj.bias', 'module.encoder.blocks.5.ls1.gamma', 'module.encoder.blocks.5.norm2.weight', 'module.encoder.blocks.5.norm2.bias', 'module.encoder.blocks.5.mlp.fc1.weight', 'module.encoder.blocks.5.mlp.fc1.bias', 'module.encoder.blocks.5.mlp.fc2.weight', 'module.encoder.blocks.5.mlp.fc2.bias', 'module.encoder.blocks.5.ls2.gamma', 'module.encoder.blocks.6.norm1.weight', 'module.encoder.blocks.6.norm1.bias', 'module.encoder.blocks.6.attn.qkv.weight', 'module.encoder.blocks.6.attn.qkv.bias', 'module.encoder.blocks.6.attn.proj.weight', 'module.encoder.blocks.6.attn.proj.bias', 'module.encoder.blocks.6.ls1.gamma', 'module.encoder.blocks.6.norm2.weight', 'module.encoder.blocks.6.norm2.bias', 'module.encoder.blocks.6.mlp.fc1.weight', 'module.encoder.blocks.6.mlp.fc1.bias', 'module.encoder.blocks.6.mlp.fc2.weight', 'module.encoder.blocks.6.mlp.fc2.bias', 'module.encoder.blocks.6.ls2.gamma', 'module.encoder.blocks.7.norm1.weight', 'module.encoder.blocks.7.norm1.bias', 'module.encoder.blocks.7.attn.qkv.weight', 'module.encoder.blocks.7.attn.qkv.bias', 'module.encoder.blocks.7.attn.proj.weight', 'module.encoder.blocks.7.attn.proj.bias', 'module.encoder.blocks.7.ls1.gamma', 'module.encoder.blocks.7.norm2.weight', 'module.encoder.blocks.7.norm2.bias', 'module.encoder.blocks.7.mlp.fc1.weight', 'module.encoder.blocks.7.mlp.fc1.bias', 'module.encoder.blocks.7.mlp.fc2.weight', 'module.encoder.blocks.7.mlp.fc2.bias', 'module.encoder.blocks.7.ls2.gamma', 'module.encoder.blocks.8.norm1.weight', 'module.encoder.blocks.8.norm1.bias', 'module.encoder.blocks.8.attn.qkv.weight', 'module.encoder.blocks.8.attn.qkv.bias', 'module.encoder.blocks.8.attn.proj.weight', 'module.encoder.blocks.8.attn.proj.bias', 'module.encoder.blocks.8.ls1.gamma', 'module.encoder.blocks.8.norm2.weight', 'module.encoder.blocks.8.norm2.bias', 'module.encoder.blocks.8.mlp.fc1.weight', 'module.encoder.blocks.8.mlp.fc1.bias', 'module.encoder.blocks.8.mlp.fc2.weight', 'module.encoder.blocks.8.mlp.fc2.bias', 'module.encoder.blocks.8.ls2.gamma', 'module.encoder.blocks.9.norm1.weight', 'module.encoder.blocks.9.norm1.bias', 'module.encoder.blocks.9.attn.qkv.weight', 'module.encoder.blocks.9.attn.qkv.bias', 'module.encoder.blocks.9.attn.proj.weight', 'module.encoder.blocks.9.attn.proj.bias', 'module.encoder.blocks.9.ls1.gamma', 'module.encoder.blocks.9.norm2.weight', 'module.encoder.blocks.9.norm2.bias', 'module.encoder.blocks.9.mlp.fc1.weight', 'module.encoder.blocks.9.mlp.fc1.bias', 'module.encoder.blocks.9.mlp.fc2.weight', 'module.encoder.blocks.9.mlp.fc2.bias', 'module.encoder.blocks.9.ls2.gamma', 'module.encoder.blocks.10.norm1.weight', 'module.encoder.blocks.10.norm1.bias', 'module.encoder.blocks.10.attn.qkv.weight', 'module.encoder.blocks.10.attn.qkv.bias', 'module.encoder.blocks.10.attn.proj.weight', 'module.encoder.blocks.10.attn.proj.bias', 'module.encoder.blocks.10.ls1.gamma', 'module.encoder.blocks.10.norm2.weight', 'module.encoder.blocks.10.norm2.bias', 'module.encoder.blocks.10.mlp.fc1.weight', 'module.encoder.blocks.10.mlp.fc1.bias', 'module.encoder.blocks.10.mlp.fc2.weight', 'module.encoder.blocks.10.mlp.fc2.bias', 'module.encoder.blocks.10.ls2.gamma', 'module.encoder.blocks.11.norm1.weight', 'module.encoder.blocks.11.norm1.bias', 'module.encoder.blocks.11.attn.qkv.weight', 'module.encoder.blocks.11.attn.qkv.bias', 'module.encoder.blocks.11.attn.proj.weight', 'module.encoder.blocks.11.attn.proj.bias', 'module.encoder.blocks.11.ls1.gamma', 'module.encoder.blocks.11.norm2.weight', 'module.encoder.blocks.11.norm2.bias', 'module.encoder.blocks.11.mlp.fc1.weight', 'module.encoder.blocks.11.mlp.fc1.bias', 'module.encoder.blocks.11.mlp.fc2.weight', 'module.encoder.blocks.11.mlp.fc2.bias', 'module.encoder.blocks.11.ls2.gamma', 'module.encoder.blocks.12.norm1.weight', 'module.encoder.blocks.12.norm1.bias', 'module.encoder.blocks.12.attn.qkv.weight', 'module.encoder.blocks.12.attn.qkv.bias', 'module.encoder.blocks.12.attn.proj.weight', 'module.encoder.blocks.12.attn.proj.bias', 'module.encoder.blocks.12.ls1.gamma', 'module.encoder.blocks.12.norm2.weight', 'module.encoder.blocks.12.norm2.bias', 'module.encoder.blocks.12.mlp.fc1.weight', 'module.encoder.blocks.12.mlp.fc1.bias', 'module.encoder.blocks.12.mlp.fc2.weight', 'module.encoder.blocks.12.mlp.fc2.bias', 'module.encoder.blocks.12.ls2.gamma', 'module.encoder.blocks.13.norm1.weight', 'module.encoder.blocks.13.norm1.bias', 'module.encoder.blocks.13.attn.qkv.weight', 'module.encoder.blocks.13.attn.qkv.bias', 'module.encoder.blocks.13.attn.proj.weight', 'module.encoder.blocks.13.attn.proj.bias', 'module.encoder.blocks.13.ls1.gamma', 'module.encoder.blocks.13.norm2.weight', 'module.encoder.blocks.13.norm2.bias', 'module.encoder.blocks.13.mlp.fc1.weight', 'module.encoder.blocks.13.mlp.fc1.bias', 'module.encoder.blocks.13.mlp.fc2.weight', 'module.encoder.blocks.13.mlp.fc2.bias', 'module.encoder.blocks.13.ls2.gamma', 'module.encoder.blocks.14.norm1.weight', 'module.encoder.blocks.14.norm1.bias', 'module.encoder.blocks.14.attn.qkv.weight', 'module.encoder.blocks.14.attn.qkv.bias', 'module.encoder.blocks.14.attn.proj.weight', 'module.encoder.blocks.14.attn.proj.bias', 'module.encoder.blocks.14.ls1.gamma', 'module.encoder.blocks.14.norm2.weight', 'module.encoder.blocks.14.norm2.bias', 'module.encoder.blocks.14.mlp.fc1.weight', 'module.encoder.blocks.14.mlp.fc1.bias', 'module.encoder.blocks.14.mlp.fc2.weight', 'module.encoder.blocks.14.mlp.fc2.bias', 'module.encoder.blocks.14.ls2.gamma', 'module.encoder.blocks.15.norm1.weight', 'module.encoder.blocks.15.norm1.bias', 'module.encoder.blocks.15.attn.qkv.weight', 'module.encoder.blocks.15.attn.qkv.bias', 'module.encoder.blocks.15.attn.proj.weight', 'module.encoder.blocks.15.attn.proj.bias', 'module.encoder.blocks.15.ls1.gamma', 'module.encoder.blocks.15.norm2.weight', 'module.encoder.blocks.15.norm2.bias', 'module.encoder.blocks.15.mlp.fc1.weight', 'module.encoder.blocks.15.mlp.fc1.bias', 'module.encoder.blocks.15.mlp.fc2.weight', 'module.encoder.blocks.15.mlp.fc2.bias', 'module.encoder.blocks.15.ls2.gamma', 'module.encoder.blocks.16.norm1.weight', 'module.encoder.blocks.16.norm1.bias', 'module.encoder.blocks.16.attn.qkv.weight', 'module.encoder.blocks.16.attn.qkv.bias', 'module.encoder.blocks.16.attn.proj.weight', 'module.encoder.blocks.16.attn.proj.bias', 'module.encoder.blocks.16.ls1.gamma', 'module.encoder.blocks.16.norm2.weight', 'module.encoder.blocks.16.norm2.bias', 'module.encoder.blocks.16.mlp.fc1.weight', 'module.encoder.blocks.16.mlp.fc1.bias', 'module.encoder.blocks.16.mlp.fc2.weight', 'module.encoder.blocks.16.mlp.fc2.bias', 'module.encoder.blocks.16.ls2.gamma', 'module.encoder.blocks.17.norm1.weight', 'module.encoder.blocks.17.norm1.bias', 'module.encoder.blocks.17.attn.qkv.weight', 'module.encoder.blocks.17.attn.qkv.bias', 'module.encoder.blocks.17.attn.proj.weight', 'module.encoder.blocks.17.attn.proj.bias', 'module.encoder.blocks.17.ls1.gamma', 'module.encoder.blocks.17.norm2.weight', 'module.encoder.blocks.17.norm2.bias', 'module.encoder.blocks.17.mlp.fc1.weight', 'module.encoder.blocks.17.mlp.fc1.bias', 'module.encoder.blocks.17.mlp.fc2.weight', 'module.encoder.blocks.17.mlp.fc2.bias', 'module.encoder.blocks.17.ls2.gamma', 'module.encoder.blocks.18.norm1.weight', 'module.encoder.blocks.18.norm1.bias', 'module.encoder.blocks.18.attn.qkv.weight', 'module.encoder.blocks.18.attn.qkv.bias', 'module.encoder.blocks.18.attn.proj.weight', 'module.encoder.blocks.18.attn.proj.bias', 'module.encoder.blocks.18.ls1.gamma', 'module.encoder.blocks.18.norm2.weight', 'module.encoder.blocks.18.norm2.bias', 'module.encoder.blocks.18.mlp.fc1.weight', 'module.encoder.blocks.18.mlp.fc1.bias', 'module.encoder.blocks.18.mlp.fc2.weight', 'module.encoder.blocks.18.mlp.fc2.bias', 'module.encoder.blocks.18.ls2.gamma', 'module.encoder.blocks.19.norm1.weight', 'module.encoder.blocks.19.norm1.bias', 'module.encoder.blocks.19.attn.qkv.weight', 'module.encoder.blocks.19.attn.qkv.bias', 'module.encoder.blocks.19.attn.proj.weight', 'module.encoder.blocks.19.attn.proj.bias', 'module.encoder.blocks.19.ls1.gamma', 'module.encoder.blocks.19.norm2.weight', 'module.encoder.blocks.19.norm2.bias', 'module.encoder.blocks.19.mlp.fc1.weight', 'module.encoder.blocks.19.mlp.fc1.bias', 'module.encoder.blocks.19.mlp.fc2.weight', 'module.encoder.blocks.19.mlp.fc2.bias', 'module.encoder.blocks.19.ls2.gamma', 'module.encoder.blocks.20.norm1.weight', 'module.encoder.blocks.20.norm1.bias', 'module.encoder.blocks.20.attn.qkv.weight', 'module.encoder.blocks.20.attn.qkv.bias', 'module.encoder.blocks.20.attn.proj.weight', 'module.encoder.blocks.20.attn.proj.bias', 'module.encoder.blocks.20.ls1.gamma', 'module.encoder.blocks.20.norm2.weight', 'module.encoder.blocks.20.norm2.bias', 'module.encoder.blocks.20.mlp.fc1.weight', 'module.encoder.blocks.20.mlp.fc1.bias', 'module.encoder.blocks.20.mlp.fc2.weight', 'module.encoder.blocks.20.mlp.fc2.bias', 'module.encoder.blocks.20.ls2.gamma', 'module.encoder.blocks.21.norm1.weight', 'module.encoder.blocks.21.norm1.bias', 'module.encoder.blocks.21.attn.qkv.weight', 'module.encoder.blocks.21.attn.qkv.bias', 'module.encoder.blocks.21.attn.proj.weight', 'module.encoder.blocks.21.attn.proj.bias', 'module.encoder.blocks.21.ls1.gamma', 'module.encoder.blocks.21.norm2.weight', 'module.encoder.blocks.21.norm2.bias', 'module.encoder.blocks.21.mlp.fc1.weight', 'module.encoder.blocks.21.mlp.fc1.bias', 'module.encoder.blocks.21.mlp.fc2.weight', 'module.encoder.blocks.21.mlp.fc2.bias', 'module.encoder.blocks.21.ls2.gamma', 'module.encoder.blocks.22.norm1.weight', 'module.encoder.blocks.22.norm1.bias', 'module.encoder.blocks.22.attn.qkv.weight', 'module.encoder.blocks.22.attn.qkv.bias', 'module.encoder.blocks.22.attn.proj.weight', 'module.encoder.blocks.22.attn.proj.bias', 'module.encoder.blocks.22.ls1.gamma', 'module.encoder.blocks.22.norm2.weight', 'module.encoder.blocks.22.norm2.bias', 'module.encoder.blocks.22.mlp.fc1.weight', 'module.encoder.blocks.22.mlp.fc1.bias', 'module.encoder.blocks.22.mlp.fc2.weight', 'module.encoder.blocks.22.mlp.fc2.bias', 'module.encoder.blocks.22.ls2.gamma', 'module.encoder.blocks.23.norm1.weight', 'module.encoder.blocks.23.norm1.bias', 'module.encoder.blocks.23.attn.qkv.weight', 'module.encoder.blocks.23.attn.qkv.bias', 'module.encoder.blocks.23.attn.proj.weight', 'module.encoder.blocks.23.attn.proj.bias', 'module.encoder.blocks.23.ls1.gamma', 'module.encoder.blocks.23.norm2.weight', 'module.encoder.blocks.23.norm2.bias', 'module.encoder.blocks.23.mlp.fc1.weight', 'module.encoder.blocks.23.mlp.fc1.bias', 'module.encoder.blocks.23.mlp.fc2.weight', 'module.encoder.blocks.23.mlp.fc2.bias', 'module.encoder.blocks.23.ls2.gamma', 'module.encoder.norm.weight', 'module.encoder.norm.bias', 'module.decoder.0.norm1.weight', 'module.decoder.0.norm1.bias', 'module.decoder.0.attn.qkv.weight', 'module.decoder.0.attn.qkv.bias', 'module.decoder.0.attn.proj.weight', 'module.decoder.0.attn.proj.bias', 'module.decoder.0.attn.q_norm.weight', 'module.decoder.0.attn.q_norm.bias', 'module.decoder.0.attn.k_norm.weight', 'module.decoder.0.attn.k_norm.bias', 'module.decoder.0.ls1.gamma', 'module.decoder.0.norm2.weight', 'module.decoder.0.norm2.bias', 'module.decoder.0.mlp.fc1.weight', 'module.decoder.0.mlp.fc1.bias', 'module.decoder.0.mlp.fc2.weight', 'module.decoder.0.mlp.fc2.bias', 'module.decoder.0.ls2.gamma', 'module.decoder.1.norm1.weight', 'module.decoder.1.norm1.bias', 'module.decoder.1.attn.qkv.weight', 'module.decoder.1.attn.qkv.bias', 'module.decoder.1.attn.proj.weight', 'module.decoder.1.attn.proj.bias', 'module.decoder.1.attn.q_norm.weight', 'module.decoder.1.attn.q_norm.bias', 'module.decoder.1.attn.k_norm.weight', 'module.decoder.1.attn.k_norm.bias', 'module.decoder.1.ls1.gamma', 'module.decoder.1.norm2.weight', 'module.decoder.1.norm2.bias', 'module.decoder.1.mlp.fc1.weight', 'module.decoder.1.mlp.fc1.bias', 'module.decoder.1.mlp.fc2.weight', 'module.decoder.1.mlp.fc2.bias', 'module.decoder.1.ls2.gamma', 'module.decoder.2.norm1.weight', 'module.decoder.2.norm1.bias', 'module.decoder.2.attn.qkv.weight', 'module.decoder.2.attn.qkv.bias', 'module.decoder.2.attn.proj.weight', 'module.decoder.2.attn.proj.bias', 'module.decoder.2.attn.q_norm.weight', 'module.decoder.2.attn.q_norm.bias', 'module.decoder.2.attn.k_norm.weight', 'module.decoder.2.attn.k_norm.bias', 'module.decoder.2.ls1.gamma', 'module.decoder.2.norm2.weight', 'module.decoder.2.norm2.bias', 'module.decoder.2.mlp.fc1.weight', 'module.decoder.2.mlp.fc1.bias', 'module.decoder.2.mlp.fc2.weight', 'module.decoder.2.mlp.fc2.bias', 'module.decoder.2.ls2.gamma', 'module.decoder.3.norm1.weight', 'module.decoder.3.norm1.bias', 'module.decoder.3.attn.qkv.weight', 'module.decoder.3.attn.qkv.bias', 'module.decoder.3.attn.proj.weight', 'module.decoder.3.attn.proj.bias', 'module.decoder.3.attn.q_norm.weight', 'module.decoder.3.attn.q_norm.bias', 'module.decoder.3.attn.k_norm.weight', 'module.decoder.3.attn.k_norm.bias', 'module.decoder.3.ls1.gamma', 'module.decoder.3.norm2.weight', 'module.decoder.3.norm2.bias', 'module.decoder.3.mlp.fc1.weight', 'module.decoder.3.mlp.fc1.bias', 'module.decoder.3.mlp.fc2.weight', 'module.decoder.3.mlp.fc2.bias', 'module.decoder.3.ls2.gamma', 'module.decoder.4.norm1.weight', 'module.decoder.4.norm1.bias', 'module.decoder.4.attn.qkv.weight', 'module.decoder.4.attn.qkv.bias', 'module.decoder.4.attn.proj.weight', 'module.decoder.4.attn.proj.bias', 'module.decoder.4.attn.q_norm.weight', 'module.decoder.4.attn.q_norm.bias', 'module.decoder.4.attn.k_norm.weight', 'module.decoder.4.attn.k_norm.bias', 'module.decoder.4.ls1.gamma', 'module.decoder.4.norm2.weight', 'module.decoder.4.norm2.bias', 'module.decoder.4.mlp.fc1.weight', 'module.decoder.4.mlp.fc1.bias', 'module.decoder.4.mlp.fc2.weight', 'module.decoder.4.mlp.fc2.bias', 'module.decoder.4.ls2.gamma', 'module.decoder.5.norm1.weight', 'module.decoder.5.norm1.bias', 'module.decoder.5.attn.qkv.weight', 'module.decoder.5.attn.qkv.bias', 'module.decoder.5.attn.proj.weight', 'module.decoder.5.attn.proj.bias', 'module.decoder.5.attn.q_norm.weight', 'module.decoder.5.attn.q_norm.bias', 'module.decoder.5.attn.k_norm.weight', 'module.decoder.5.attn.k_norm.bias', 'module.decoder.5.ls1.gamma', 'module.decoder.5.norm2.weight', 'module.decoder.5.norm2.bias', 'module.decoder.5.mlp.fc1.weight', 'module.decoder.5.mlp.fc1.bias', 'module.decoder.5.mlp.fc2.weight', 'module.decoder.5.mlp.fc2.bias', 'module.decoder.5.ls2.gamma', 'module.decoder.6.norm1.weight', 'module.decoder.6.norm1.bias', 'module.decoder.6.attn.qkv.weight', 'module.decoder.6.attn.qkv.bias', 'module.decoder.6.attn.proj.weight', 'module.decoder.6.attn.proj.bias', 'module.decoder.6.attn.q_norm.weight', 'module.decoder.6.attn.q_norm.bias', 'module.decoder.6.attn.k_norm.weight', 'module.decoder.6.attn.k_norm.bias', 'module.decoder.6.ls1.gamma', 'module.decoder.6.norm2.weight', 'module.decoder.6.norm2.bias', 'module.decoder.6.mlp.fc1.weight', 'module.decoder.6.mlp.fc1.bias', 'module.decoder.6.mlp.fc2.weight', 'module.decoder.6.mlp.fc2.bias', 'module.decoder.6.ls2.gamma', 'module.decoder.7.norm1.weight', 'module.decoder.7.norm1.bias', 'module.decoder.7.attn.qkv.weight', 'module.decoder.7.attn.qkv.bias', 'module.decoder.7.attn.proj.weight', 'module.decoder.7.attn.proj.bias', 'module.decoder.7.attn.q_norm.weight', 'module.decoder.7.attn.q_norm.bias', 'module.decoder.7.attn.k_norm.weight', 'module.decoder.7.attn.k_norm.bias', 'module.decoder.7.ls1.gamma', 'module.decoder.7.norm2.weight', 'module.decoder.7.norm2.bias', 'module.decoder.7.mlp.fc1.weight', 'module.decoder.7.mlp.fc1.bias', 'module.decoder.7.mlp.fc2.weight', 'module.decoder.7.mlp.fc2.bias', 'module.decoder.7.ls2.gamma', 'module.decoder.8.norm1.weight', 'module.decoder.8.norm1.bias', 'module.decoder.8.attn.qkv.weight', 'module.decoder.8.attn.qkv.bias', 'module.decoder.8.attn.proj.weight', 'module.decoder.8.attn.proj.bias', 'module.decoder.8.attn.q_norm.weight', 'module.decoder.8.attn.q_norm.bias', 'module.decoder.8.attn.k_norm.weight', 'module.decoder.8.attn.k_norm.bias', 'module.decoder.8.ls1.gamma', 'module.decoder.8.norm2.weight', 'module.decoder.8.norm2.bias', 'module.decoder.8.mlp.fc1.weight', 'module.decoder.8.mlp.fc1.bias', 'module.decoder.8.mlp.fc2.weight', 'module.decoder.8.mlp.fc2.bias', 'module.decoder.8.ls2.gamma', 'module.decoder.9.norm1.weight', 'module.decoder.9.norm1.bias', 'module.decoder.9.attn.qkv.weight', 'module.decoder.9.attn.qkv.bias', 'module.decoder.9.attn.proj.weight', 'module.decoder.9.attn.proj.bias', 'module.decoder.9.attn.q_norm.weight', 'module.decoder.9.attn.q_norm.bias', 'module.decoder.9.attn.k_norm.weight', 'module.decoder.9.attn.k_norm.bias', 'module.decoder.9.ls1.gamma', 'module.decoder.9.norm2.weight', 'module.decoder.9.norm2.bias', 'module.decoder.9.mlp.fc1.weight', 'module.decoder.9.mlp.fc1.bias', 'module.decoder.9.mlp.fc2.weight', 'module.decoder.9.mlp.fc2.bias', 'module.decoder.9.ls2.gamma', 'module.decoder.10.norm1.weight', 'module.decoder.10.norm1.bias', 'module.decoder.10.attn.qkv.weight', 'module.decoder.10.attn.qkv.bias', 'module.decoder.10.attn.proj.weight', 'module.decoder.10.attn.proj.bias', 'module.decoder.10.attn.q_norm.weight', 'module.decoder.10.attn.q_norm.bias', 'module.decoder.10.attn.k_norm.weight', 'module.decoder.10.attn.k_norm.bias', 'module.decoder.10.ls1.gamma', 'module.decoder.10.norm2.weight', 'module.decoder.10.norm2.bias', 'module.decoder.10.mlp.fc1.weight', 'module.decoder.10.mlp.fc1.bias', 'module.decoder.10.mlp.fc2.weight', 'module.decoder.10.mlp.fc2.bias', 'module.decoder.10.ls2.gamma', 'module.decoder.11.norm1.weight', 'module.decoder.11.norm1.bias', 'module.decoder.11.attn.qkv.weight', 'module.decoder.11.attn.qkv.bias', 'module.decoder.11.attn.proj.weight', 'module.decoder.11.attn.proj.bias', 'module.decoder.11.attn.q_norm.weight', 'module.decoder.11.attn.q_norm.bias', 'module.decoder.11.attn.k_norm.weight', 'module.decoder.11.attn.k_norm.bias', 'module.decoder.11.ls1.gamma', 'module.decoder.11.norm2.weight', 'module.decoder.11.norm2.bias', 'module.decoder.11.mlp.fc1.weight', 'module.decoder.11.mlp.fc1.bias', 'module.decoder.11.mlp.fc2.weight', 'module.decoder.11.mlp.fc2.bias', 'module.decoder.11.ls2.gamma', 'module.decoder.12.norm1.weight', 'module.decoder.12.norm1.bias', 'module.decoder.12.attn.qkv.weight', 'module.decoder.12.attn.qkv.bias', 'module.decoder.12.attn.proj.weight', 'module.decoder.12.attn.proj.bias', 'module.decoder.12.attn.q_norm.weight', 'module.decoder.12.attn.q_norm.bias', 'module.decoder.12.attn.k_norm.weight', 'module.decoder.12.attn.k_norm.bias', 'module.decoder.12.ls1.gamma', 'module.decoder.12.norm2.weight', 'module.decoder.12.norm2.bias', 'module.decoder.12.mlp.fc1.weight', 'module.decoder.12.mlp.fc1.bias', 'module.decoder.12.mlp.fc2.weight', 'module.decoder.12.mlp.fc2.bias', 'module.decoder.12.ls2.gamma', 'module.decoder.13.norm1.weight', 'module.decoder.13.norm1.bias', 'module.decoder.13.attn.qkv.weight', 'module.decoder.13.attn.qkv.bias', 'module.decoder.13.attn.proj.weight', 'module.decoder.13.attn.proj.bias', 'module.decoder.13.attn.q_norm.weight', 'module.decoder.13.attn.q_norm.bias', 'module.decoder.13.attn.k_norm.weight', 'module.decoder.13.attn.k_norm.bias', 'module.decoder.13.ls1.gamma', 'module.decoder.13.norm2.weight', 'module.decoder.13.norm2.bias', 'module.decoder.13.mlp.fc1.weight', 'module.decoder.13.mlp.fc1.bias', 'module.decoder.13.mlp.fc2.weight', 'module.decoder.13.mlp.fc2.bias', 'module.decoder.13.ls2.gamma', 'module.decoder.14.norm1.weight', 'module.decoder.14.norm1.bias', 'module.decoder.14.attn.qkv.weight', 'module.decoder.14.attn.qkv.bias', 'module.decoder.14.attn.proj.weight', 'module.decoder.14.attn.proj.bias', 'module.decoder.14.attn.q_norm.weight', 'module.decoder.14.attn.q_norm.bias', 'module.decoder.14.attn.k_norm.weight', 'module.decoder.14.attn.k_norm.bias', 'module.decoder.14.ls1.gamma', 'module.decoder.14.norm2.weight', 'module.decoder.14.norm2.bias', 'module.decoder.14.mlp.fc1.weight', 'module.decoder.14.mlp.fc1.bias', 'module.decoder.14.mlp.fc2.weight', 'module.decoder.14.mlp.fc2.bias', 'module.decoder.14.ls2.gamma', 'module.decoder.15.norm1.weight', 'module.decoder.15.norm1.bias', 'module.decoder.15.attn.qkv.weight', 'module.decoder.15.attn.qkv.bias', 'module.decoder.15.attn.proj.weight', 'module.decoder.15.attn.proj.bias', 'module.decoder.15.attn.q_norm.weight', 'module.decoder.15.attn.q_norm.bias', 'module.decoder.15.attn.k_norm.weight', 'module.decoder.15.attn.k_norm.bias', 'module.decoder.15.ls1.gamma', 'module.decoder.15.norm2.weight', 'module.decoder.15.norm2.bias', 'module.decoder.15.mlp.fc1.weight', 'module.decoder.15.mlp.fc1.bias', 'module.decoder.15.mlp.fc2.weight', 'module.decoder.15.mlp.fc2.bias', 'module.decoder.15.ls2.gamma', 'module.decoder.16.norm1.weight', 'module.decoder.16.norm1.bias', 'module.decoder.16.attn.qkv.weight', 'module.decoder.16.attn.qkv.bias', 'module.decoder.16.attn.proj.weight', 'module.decoder.16.attn.proj.bias', 'module.decoder.16.attn.q_norm.weight', 'module.decoder.16.attn.q_norm.bias', 'module.decoder.16.attn.k_norm.weight', 'module.decoder.16.attn.k_norm.bias', 'module.decoder.16.ls1.gamma', 'module.decoder.16.norm2.weight', 'module.decoder.16.norm2.bias', 'module.decoder.16.mlp.fc1.weight', 'module.decoder.16.mlp.fc1.bias', 'module.decoder.16.mlp.fc2.weight', 'module.decoder.16.mlp.fc2.bias', 'module.decoder.16.ls2.gamma', 'module.decoder.17.norm1.weight', 'module.decoder.17.norm1.bias', 'module.decoder.17.attn.qkv.weight', 'module.decoder.17.attn.qkv.bias', 'module.decoder.17.attn.proj.weight', 'module.decoder.17.attn.proj.bias', 'module.decoder.17.attn.q_norm.weight', 'module.decoder.17.attn.q_norm.bias', 'module.decoder.17.attn.k_norm.weight', 'module.decoder.17.attn.k_norm.bias', 'module.decoder.17.ls1.gamma', 'module.decoder.17.norm2.weight', 'module.decoder.17.norm2.bias', 'module.decoder.17.mlp.fc1.weight', 'module.decoder.17.mlp.fc1.bias', 'module.decoder.17.mlp.fc2.weight', 'module.decoder.17.mlp.fc2.bias', 'module.decoder.17.ls2.gamma', 'module.decoder.18.norm1.weight', 'module.decoder.18.norm1.bias', 'module.decoder.18.attn.qkv.weight', 'module.decoder.18.attn.qkv.bias', 'module.decoder.18.attn.proj.weight', 'module.decoder.18.attn.proj.bias', 'module.decoder.18.attn.q_norm.weight', 'module.decoder.18.attn.q_norm.bias', 'module.decoder.18.attn.k_norm.weight', 'module.decoder.18.attn.k_norm.bias', 'module.decoder.18.ls1.gamma', 'module.decoder.18.norm2.weight', 'module.decoder.18.norm2.bias', 'module.decoder.18.mlp.fc1.weight', 'module.decoder.18.mlp.fc1.bias', 'module.decoder.18.mlp.fc2.weight', 'module.decoder.18.mlp.fc2.bias', 'module.decoder.18.ls2.gamma', 'module.decoder.19.norm1.weight', 'module.decoder.19.norm1.bias', 'module.decoder.19.attn.qkv.weight', 'module.decoder.19.attn.qkv.bias', 'module.decoder.19.attn.proj.weight', 'module.decoder.19.attn.proj.bias', 'module.decoder.19.attn.q_norm.weight', 'module.decoder.19.attn.q_norm.bias', 'module.decoder.19.attn.k_norm.weight', 'module.decoder.19.attn.k_norm.bias', 'module.decoder.19.ls1.gamma', 'module.decoder.19.norm2.weight', 'module.decoder.19.norm2.bias', 'module.decoder.19.mlp.fc1.weight', 'module.decoder.19.mlp.fc1.bias', 'module.decoder.19.mlp.fc2.weight', 'module.decoder.19.mlp.fc2.bias', 'module.decoder.19.ls2.gamma', 'module.decoder.20.norm1.weight', 'module.decoder.20.norm1.bias', 'module.decoder.20.attn.qkv.weight', 'module.decoder.20.attn.qkv.bias', 'module.decoder.20.attn.proj.weight', 'module.decoder.20.attn.proj.bias', 'module.decoder.20.attn.q_norm.weight', 'module.decoder.20.attn.q_norm.bias', 'module.decoder.20.attn.k_norm.weight', 'module.decoder.20.attn.k_norm.bias', 'module.decoder.20.ls1.gamma', 'module.decoder.20.norm2.weight', 'module.decoder.20.norm2.bias', 'module.decoder.20.mlp.fc1.weight', 'module.decoder.20.mlp.fc1.bias', 'module.decoder.20.mlp.fc2.weight', 'module.decoder.20.mlp.fc2.bias', 'module.decoder.20.ls2.gamma', 'module.decoder.21.norm1.weight', 'module.decoder.21.norm1.bias', 'module.decoder.21.attn.qkv.weight', 'module.decoder.21.attn.qkv.bias', 'module.decoder.21.attn.proj.weight', 'module.decoder.21.attn.proj.bias', 'module.decoder.21.attn.q_norm.weight', 'module.decoder.21.attn.q_norm.bias', 'module.decoder.21.attn.k_norm.weight', 'module.decoder.21.attn.k_norm.bias', 'module.decoder.21.ls1.gamma', 'module.decoder.21.norm2.weight', 'module.decoder.21.norm2.bias', 'module.decoder.21.mlp.fc1.weight', 'module.decoder.21.mlp.fc1.bias', 'module.decoder.21.mlp.fc2.weight', 'module.decoder.21.mlp.fc2.bias', 'module.decoder.21.ls2.gamma', 'module.decoder.22.norm1.weight', 'module.decoder.22.norm1.bias', 'module.decoder.22.attn.qkv.weight', 'module.decoder.22.attn.qkv.bias', 'module.decoder.22.attn.proj.weight', 'module.decoder.22.attn.proj.bias', 'module.decoder.22.attn.q_norm.weight', 'module.decoder.22.attn.q_norm.bias', 'module.decoder.22.attn.k_norm.weight', 'module.decoder.22.attn.k_norm.bias', 'module.decoder.22.ls1.gamma', 'module.decoder.22.norm2.weight', 'module.decoder.22.norm2.bias', 'module.decoder.22.mlp.fc1.weight', 'module.decoder.22.mlp.fc1.bias', 'module.decoder.22.mlp.fc2.weight', 'module.decoder.22.mlp.fc2.bias', 'module.decoder.22.ls2.gamma', 'module.decoder.23.norm1.weight', 'module.decoder.23.norm1.bias', 'module.decoder.23.attn.qkv.weight', 'module.decoder.23.attn.qkv.bias', 'module.decoder.23.attn.proj.weight', 'module.decoder.23.attn.proj.bias', 'module.decoder.23.attn.q_norm.weight', 'module.decoder.23.attn.q_norm.bias', 'module.decoder.23.attn.k_norm.weight', 'module.decoder.23.attn.k_norm.bias', 'module.decoder.23.ls1.gamma', 'module.decoder.23.norm2.weight', 'module.decoder.23.norm2.bias', 'module.decoder.23.mlp.fc1.weight', 'module.decoder.23.mlp.fc1.bias', 'module.decoder.23.mlp.fc2.weight', 'module.decoder.23.mlp.fc2.bias', 'module.decoder.23.ls2.gamma', 'module.decoder.24.norm1.weight', 'module.decoder.24.norm1.bias', 'module.decoder.24.attn.qkv.weight', 'module.decoder.24.attn.qkv.bias', 'module.decoder.24.attn.proj.weight', 'module.decoder.24.attn.proj.bias', 'module.decoder.24.attn.q_norm.weight', 'module.decoder.24.attn.q_norm.bias', 'module.decoder.24.attn.k_norm.weight', 'module.decoder.24.attn.k_norm.bias', 'module.decoder.24.ls1.gamma', 'module.decoder.24.norm2.weight', 'module.decoder.24.norm2.bias', 'module.decoder.24.mlp.fc1.weight', 'module.decoder.24.mlp.fc1.bias', 'module.decoder.24.mlp.fc2.weight', 'module.decoder.24.mlp.fc2.bias', 'module.decoder.24.ls2.gamma', 'module.decoder.25.norm1.weight', 'module.decoder.25.norm1.bias', 'module.decoder.25.attn.qkv.weight', 'module.decoder.25.attn.qkv.bias', 'module.decoder.25.attn.proj.weight', 'module.decoder.25.attn.proj.bias', 'module.decoder.25.attn.q_norm.weight', 'module.decoder.25.attn.q_norm.bias', 'module.decoder.25.attn.k_norm.weight', 'module.decoder.25.attn.k_norm.bias', 'module.decoder.25.ls1.gamma', 'module.decoder.25.norm2.weight', 'module.decoder.25.norm2.bias', 'module.decoder.25.mlp.fc1.weight', 'module.decoder.25.mlp.fc1.bias', 'module.decoder.25.mlp.fc2.weight', 'module.decoder.25.mlp.fc2.bias', 'module.decoder.25.ls2.gamma', 'module.decoder.26.norm1.weight', 'module.decoder.26.norm1.bias', 'module.decoder.26.attn.qkv.weight', 'module.decoder.26.attn.qkv.bias', 'module.decoder.26.attn.proj.weight', 'module.decoder.26.attn.proj.bias', 'module.decoder.26.attn.q_norm.weight', 'module.decoder.26.attn.q_norm.bias', 'module.decoder.26.attn.k_norm.weight', 'module.decoder.26.attn.k_norm.bias', 'module.decoder.26.ls1.gamma', 'module.decoder.26.norm2.weight', 'module.decoder.26.norm2.bias', 'module.decoder.26.mlp.fc1.weight', 'module.decoder.26.mlp.fc1.bias', 'module.decoder.26.mlp.fc2.weight', 'module.decoder.26.mlp.fc2.bias', 'module.decoder.26.ls2.gamma', 'module.decoder.27.norm1.weight', 'module.decoder.27.norm1.bias', 'module.decoder.27.attn.qkv.weight', 'module.decoder.27.attn.qkv.bias', 'module.decoder.27.attn.proj.weight', 'module.decoder.27.attn.proj.bias', 'module.decoder.27.attn.q_norm.weight', 'module.decoder.27.attn.q_norm.bias', 'module.decoder.27.attn.k_norm.weight', 'module.decoder.27.attn.k_norm.bias', 'module.decoder.27.ls1.gamma', 'module.decoder.27.norm2.weight', 'module.decoder.27.norm2.bias', 'module.decoder.27.mlp.fc1.weight', 'module.decoder.27.mlp.fc1.bias', 'module.decoder.27.mlp.fc2.weight', 'module.decoder.27.mlp.fc2.bias', 'module.decoder.27.ls2.gamma', 'module.decoder.28.norm1.weight', 'module.decoder.28.norm1.bias', 'module.decoder.28.attn.qkv.weight', 'module.decoder.28.attn.qkv.bias', 'module.decoder.28.attn.proj.weight', 'module.decoder.28.attn.proj.bias', 'module.decoder.28.attn.q_norm.weight', 'module.decoder.28.attn.q_norm.bias', 'module.decoder.28.attn.k_norm.weight', 'module.decoder.28.attn.k_norm.bias', 'module.decoder.28.ls1.gamma', 'module.decoder.28.norm2.weight', 'module.decoder.28.norm2.bias', 'module.decoder.28.mlp.fc1.weight', 'module.decoder.28.mlp.fc1.bias', 'module.decoder.28.mlp.fc2.weight', 'module.decoder.28.mlp.fc2.bias', 'module.decoder.28.ls2.gamma', 'module.decoder.29.norm1.weight', 'module.decoder.29.norm1.bias', 'module.decoder.29.attn.qkv.weight', 'module.decoder.29.attn.qkv.bias', 'module.decoder.29.attn.proj.weight', 'module.decoder.29.attn.proj.bias', 'module.decoder.29.attn.q_norm.weight', 'module.decoder.29.attn.q_norm.bias', 'module.decoder.29.attn.k_norm.weight', 'module.decoder.29.attn.k_norm.bias', 'module.decoder.29.ls1.gamma', 'module.decoder.29.norm2.weight', 'module.decoder.29.norm2.bias', 'module.decoder.29.mlp.fc1.weight', 'module.decoder.29.mlp.fc1.bias', 'module.decoder.29.mlp.fc2.weight', 'module.decoder.29.mlp.fc2.bias', 'module.decoder.29.ls2.gamma', 'module.decoder.30.norm1.weight', 'module.decoder.30.norm1.bias', 'module.decoder.30.attn.qkv.weight', 'module.decoder.30.attn.qkv.bias', 'module.decoder.30.attn.proj.weight', 'module.decoder.30.attn.proj.bias', 'module.decoder.30.attn.q_norm.weight', 'module.decoder.30.attn.q_norm.bias', 'module.decoder.30.attn.k_norm.weight', 'module.decoder.30.attn.k_norm.bias', 'module.decoder.30.ls1.gamma', 'module.decoder.30.norm2.weight', 'module.decoder.30.norm2.bias', 'module.decoder.30.mlp.fc1.weight', 'module.decoder.30.mlp.fc1.bias', 'module.decoder.30.mlp.fc2.weight', 'module.decoder.30.mlp.fc2.bias', 'module.decoder.30.ls2.gamma', 'module.decoder.31.norm1.weight', 'module.decoder.31.norm1.bias', 'module.decoder.31.attn.qkv.weight', 'module.decoder.31.attn.qkv.bias', 'module.decoder.31.attn.proj.weight', 'module.decoder.31.attn.proj.bias', 'module.decoder.31.attn.q_norm.weight', 'module.decoder.31.attn.q_norm.bias', 'module.decoder.31.attn.k_norm.weight', 'module.decoder.31.attn.k_norm.bias', 'module.decoder.31.ls1.gamma', 'module.decoder.31.norm2.weight', 'module.decoder.31.norm2.bias', 'module.decoder.31.mlp.fc1.weight', 'module.decoder.31.mlp.fc1.bias', 'module.decoder.31.mlp.fc2.weight', 'module.decoder.31.mlp.fc2.bias', 'module.decoder.31.ls2.gamma', 'module.decoder.32.norm1.weight', 'module.decoder.32.norm1.bias', 'module.decoder.32.attn.qkv.weight', 'module.decoder.32.attn.qkv.bias', 'module.decoder.32.attn.proj.weight', 'module.decoder.32.attn.proj.bias', 'module.decoder.32.attn.q_norm.weight', 'module.decoder.32.attn.q_norm.bias', 'module.decoder.32.attn.k_norm.weight', 'module.decoder.32.attn.k_norm.bias', 'module.decoder.32.ls1.gamma', 'module.decoder.32.norm2.weight', 'module.decoder.32.norm2.bias', 'module.decoder.32.mlp.fc1.weight', 'module.decoder.32.mlp.fc1.bias', 'module.decoder.32.mlp.fc2.weight', 'module.decoder.32.mlp.fc2.bias', 'module.decoder.32.ls2.gamma', 'module.decoder.33.norm1.weight', 'module.decoder.33.norm1.bias', 'module.decoder.33.attn.qkv.weight', 'module.decoder.33.attn.qkv.bias', 'module.decoder.33.attn.proj.weight', 'module.decoder.33.attn.proj.bias', 'module.decoder.33.attn.q_norm.weight', 'module.decoder.33.attn.q_norm.bias', 'module.decoder.33.attn.k_norm.weight', 'module.decoder.33.attn.k_norm.bias', 'module.decoder.33.ls1.gamma', 'module.decoder.33.norm2.weight', 'module.decoder.33.norm2.bias', 'module.decoder.33.mlp.fc1.weight', 'module.decoder.33.mlp.fc1.bias', 'module.decoder.33.mlp.fc2.weight', 'module.decoder.33.mlp.fc2.bias', 'module.decoder.33.ls2.gamma', 'module.decoder.34.norm1.weight', 'module.decoder.34.norm1.bias', 'module.decoder.34.attn.qkv.weight', 'module.decoder.34.attn.qkv.bias', 'module.decoder.34.attn.proj.weight', 'module.decoder.34.attn.proj.bias', 'module.decoder.34.attn.q_norm.weight', 'module.decoder.34.attn.q_norm.bias', 'module.decoder.34.attn.k_norm.weight', 'module.decoder.34.attn.k_norm.bias', 'module.decoder.34.ls1.gamma', 'module.decoder.34.norm2.weight', 'module.decoder.34.norm2.bias', 'module.decoder.34.mlp.fc1.weight', 'module.decoder.34.mlp.fc1.bias', 'module.decoder.34.mlp.fc2.weight', 'module.decoder.34.mlp.fc2.bias', 'module.decoder.34.ls2.gamma', 'module.decoder.35.norm1.weight', 'module.decoder.35.norm1.bias', 'module.decoder.35.attn.qkv.weight', 'module.decoder.35.attn.qkv.bias', 'module.decoder.35.attn.proj.weight', 'module.decoder.35.attn.proj.bias', 'module.decoder.35.attn.q_norm.weight', 'module.decoder.35.attn.q_norm.bias', 'module.decoder.35.attn.k_norm.weight', 'module.decoder.35.attn.k_norm.bias', 'module.decoder.35.ls1.gamma', 'module.decoder.35.norm2.weight', 'module.decoder.35.norm2.bias', 'module.decoder.35.mlp.fc1.weight', 'module.decoder.35.mlp.fc1.bias', 'module.decoder.35.mlp.fc2.weight', 'module.decoder.35.mlp.fc2.bias', 'module.decoder.35.ls2.gamma', 'module.point_decoder.projects.weight', 'module.point_decoder.projects.bias', 'module.point_decoder.blocks.0.norm1.weight', 'module.point_decoder.blocks.0.norm1.bias', 'module.point_decoder.blocks.0.attn.qkv.weight', 'module.point_decoder.blocks.0.attn.qkv.bias', 'module.point_decoder.blocks.0.attn.proj.weight', 'module.point_decoder.blocks.0.attn.proj.bias', 'module.point_decoder.blocks.0.norm2.weight', 'module.point_decoder.blocks.0.norm2.bias', 'module.point_decoder.blocks.0.mlp.fc1.weight', 'module.point_decoder.blocks.0.mlp.fc1.bias', 'module.point_decoder.blocks.0.mlp.fc2.weight', 'module.point_decoder.blocks.0.mlp.fc2.bias', 'module.point_decoder.blocks.1.norm1.weight', 'module.point_decoder.blocks.1.norm1.bias', 'module.point_decoder.blocks.1.attn.qkv.weight', 'module.point_decoder.blocks.1.attn.qkv.bias', 'module.point_decoder.blocks.1.attn.proj.weight', 'module.point_decoder.blocks.1.attn.proj.bias', 'module.point_decoder.blocks.1.norm2.weight', 'module.point_decoder.blocks.1.norm2.bias', 'module.point_decoder.blocks.1.mlp.fc1.weight', 'module.point_decoder.blocks.1.mlp.fc1.bias', 'module.point_decoder.blocks.1.mlp.fc2.weight', 'module.point_decoder.blocks.1.mlp.fc2.bias', 'module.point_decoder.blocks.2.norm1.weight', 'module.point_decoder.blocks.2.norm1.bias', 'module.point_decoder.blocks.2.attn.qkv.weight', 'module.point_decoder.blocks.2.attn.qkv.bias', 'module.point_decoder.blocks.2.attn.proj.weight', 'module.point_decoder.blocks.2.attn.proj.bias', 'module.point_decoder.blocks.2.norm2.weight', 'module.point_decoder.blocks.2.norm2.bias', 'module.point_decoder.blocks.2.mlp.fc1.weight', 'module.point_decoder.blocks.2.mlp.fc1.bias', 'module.point_decoder.blocks.2.mlp.fc2.weight', 'module.point_decoder.blocks.2.mlp.fc2.bias', 'module.point_decoder.blocks.3.norm1.weight', 'module.point_decoder.blocks.3.norm1.bias', 'module.point_decoder.blocks.3.attn.qkv.weight', 'module.point_decoder.blocks.3.attn.qkv.bias', 'module.point_decoder.blocks.3.attn.proj.weight', 'module.point_decoder.blocks.3.attn.proj.bias', 'module.point_decoder.blocks.3.norm2.weight', 'module.point_decoder.blocks.3.norm2.bias', 'module.point_decoder.blocks.3.mlp.fc1.weight', 'module.point_decoder.blocks.3.mlp.fc1.bias', 'module.point_decoder.blocks.3.mlp.fc2.weight', 'module.point_decoder.blocks.3.mlp.fc2.bias', 'module.point_decoder.blocks.4.norm1.weight', 'module.point_decoder.blocks.4.norm1.bias', 'module.point_decoder.blocks.4.attn.qkv.weight', 'module.point_decoder.blocks.4.attn.qkv.bias', 'module.point_decoder.blocks.4.attn.proj.weight', 'module.point_decoder.blocks.4.attn.proj.bias', 'module.point_decoder.blocks.4.norm2.weight', 'module.point_decoder.blocks.4.norm2.bias', 'module.point_decoder.blocks.4.mlp.fc1.weight', 'module.point_decoder.blocks.4.mlp.fc1.bias', 'module.point_decoder.blocks.4.mlp.fc2.weight', 'module.point_decoder.blocks.4.mlp.fc2.bias', 'module.point_decoder.linear_out.weight', 'module.point_decoder.linear_out.bias', 'module.point_head.proj.weight', 'module.point_head.proj.bias', 'module.conf_decoder.projects.weight', 'module.conf_decoder.projects.bias', 'module.conf_decoder.blocks.0.norm1.weight', 'module.conf_decoder.blocks.0.norm1.bias', 'module.conf_decoder.blocks.0.attn.qkv.weight', 'module.conf_decoder.blocks.0.attn.qkv.bias', 'module.conf_decoder.blocks.0.attn.proj.weight', 'module.conf_decoder.blocks.0.attn.proj.bias', 'module.conf_decoder.blocks.0.norm2.weight', 'module.conf_decoder.blocks.0.norm2.bias', 'module.conf_decoder.blocks.0.mlp.fc1.weight', 'module.conf_decoder.blocks.0.mlp.fc1.bias', 'module.conf_decoder.blocks.0.mlp.fc2.weight', 'module.conf_decoder.blocks.0.mlp.fc2.bias', 'module.conf_decoder.blocks.1.norm1.weight', 'module.conf_decoder.blocks.1.norm1.bias', 'module.conf_decoder.blocks.1.attn.qkv.weight', 'module.conf_decoder.blocks.1.attn.qkv.bias', 'module.conf_decoder.blocks.1.attn.proj.weight', 'module.conf_decoder.blocks.1.attn.proj.bias', 'module.conf_decoder.blocks.1.norm2.weight', 'module.conf_decoder.blocks.1.norm2.bias', 'module.conf_decoder.blocks.1.mlp.fc1.weight', 'module.conf_decoder.blocks.1.mlp.fc1.bias', 'module.conf_decoder.blocks.1.mlp.fc2.weight', 'module.conf_decoder.blocks.1.mlp.fc2.bias', 'module.conf_decoder.blocks.2.norm1.weight', 'module.conf_decoder.blocks.2.norm1.bias', 'module.conf_decoder.blocks.2.attn.qkv.weight', 'module.conf_decoder.blocks.2.attn.qkv.bias', 'module.conf_decoder.blocks.2.attn.proj.weight', 'module.conf_decoder.blocks.2.attn.proj.bias', 'module.conf_decoder.blocks.2.norm2.weight', 'module.conf_decoder.blocks.2.norm2.bias', 'module.conf_decoder.blocks.2.mlp.fc1.weight', 'module.conf_decoder.blocks.2.mlp.fc1.bias', 'module.conf_decoder.blocks.2.mlp.fc2.weight', 'module.conf_decoder.blocks.2.mlp.fc2.bias', 'module.conf_decoder.blocks.3.norm1.weight', 'module.conf_decoder.blocks.3.norm1.bias', 'module.conf_decoder.blocks.3.attn.qkv.weight', 'module.conf_decoder.blocks.3.attn.qkv.bias', 'module.conf_decoder.blocks.3.attn.proj.weight', 'module.conf_decoder.blocks.3.attn.proj.bias', 'module.conf_decoder.blocks.3.norm2.weight', 'module.conf_decoder.blocks.3.norm2.bias', 'module.conf_decoder.blocks.3.mlp.fc1.weight', 'module.conf_decoder.blocks.3.mlp.fc1.bias', 'module.conf_decoder.blocks.3.mlp.fc2.weight', 'module.conf_decoder.blocks.3.mlp.fc2.bias', 'module.conf_decoder.blocks.4.norm1.weight', 'module.conf_decoder.blocks.4.norm1.bias', 'module.conf_decoder.blocks.4.attn.qkv.weight', 'module.conf_decoder.blocks.4.attn.qkv.bias', 'module.conf_decoder.blocks.4.attn.proj.weight', 'module.conf_decoder.blocks.4.attn.proj.bias', 'module.conf_decoder.blocks.4.norm2.weight', 'module.conf_decoder.blocks.4.norm2.bias', 'module.conf_decoder.blocks.4.mlp.fc1.weight', 'module.conf_decoder.blocks.4.mlp.fc1.bias', 'module.conf_decoder.blocks.4.mlp.fc2.weight', 'module.conf_decoder.blocks.4.mlp.fc2.bias', 'module.conf_decoder.linear_out.weight', 'module.conf_decoder.linear_out.bias', 'module.conf_head.proj.weight', 'module.conf_head.proj.bias', 'module.camera_decoder.projects.weight', 'module.camera_decoder.projects.bias', 'module.camera_decoder.blocks.0.norm1.weight', 'module.camera_decoder.blocks.0.norm1.bias', 'module.camera_decoder.blocks.0.attn.qkv.weight', 'module.camera_decoder.blocks.0.attn.qkv.bias', 'module.camera_decoder.blocks.0.attn.proj.weight', 'module.camera_decoder.blocks.0.attn.proj.bias', 'module.camera_decoder.blocks.0.norm2.weight', 'module.camera_decoder.blocks.0.norm2.bias', 'module.camera_decoder.blocks.0.mlp.fc1.weight', 'module.camera_decoder.blocks.0.mlp.fc1.bias', 'module.camera_decoder.blocks.0.mlp.fc2.weight', 'module.camera_decoder.blocks.0.mlp.fc2.bias', 'module.camera_decoder.blocks.1.norm1.weight', 'module.camera_decoder.blocks.1.norm1.bias', 'module.camera_decoder.blocks.1.attn.qkv.weight', 'module.camera_decoder.blocks.1.attn.qkv.bias', 'module.camera_decoder.blocks.1.attn.proj.weight', 'module.camera_decoder.blocks.1.attn.proj.bias', 'module.camera_decoder.blocks.1.norm2.weight', 'module.camera_decoder.blocks.1.norm2.bias', 'module.camera_decoder.blocks.1.mlp.fc1.weight', 'module.camera_decoder.blocks.1.mlp.fc1.bias', 'module.camera_decoder.blocks.1.mlp.fc2.weight', 'module.camera_decoder.blocks.1.mlp.fc2.bias', 'module.camera_decoder.blocks.2.norm1.weight', 'module.camera_decoder.blocks.2.norm1.bias', 'module.camera_decoder.blocks.2.attn.qkv.weight', 'module.camera_decoder.blocks.2.attn.qkv.bias', 'module.camera_decoder.blocks.2.attn.proj.weight', 'module.camera_decoder.blocks.2.attn.proj.bias', 'module.camera_decoder.blocks.2.norm2.weight', 'module.camera_decoder.blocks.2.norm2.bias', 'module.camera_decoder.blocks.2.mlp.fc1.weight', 'module.camera_decoder.blocks.2.mlp.fc1.bias', 'module.camera_decoder.blocks.2.mlp.fc2.weight', 'module.camera_decoder.blocks.2.mlp.fc2.bias', 'module.camera_decoder.blocks.3.norm1.weight', 'module.camera_decoder.blocks.3.norm1.bias', 'module.camera_decoder.blocks.3.attn.qkv.weight', 'module.camera_decoder.blocks.3.attn.qkv.bias', 'module.camera_decoder.blocks.3.attn.proj.weight', 'module.camera_decoder.blocks.3.attn.proj.bias', 'module.camera_decoder.blocks.3.norm2.weight', 'module.camera_decoder.blocks.3.norm2.bias', 'module.camera_decoder.blocks.3.mlp.fc1.weight', 'module.camera_decoder.blocks.3.mlp.fc1.bias', 'module.camera_decoder.blocks.3.mlp.fc2.weight', 'module.camera_decoder.blocks.3.mlp.fc2.bias', 'module.camera_decoder.blocks.4.norm1.weight', 'module.camera_decoder.blocks.4.norm1.bias', 'module.camera_decoder.blocks.4.attn.qkv.weight', 'module.camera_decoder.blocks.4.attn.qkv.bias', 'module.camera_decoder.blocks.4.attn.proj.weight', 'module.camera_decoder.blocks.4.attn.proj.bias', 'module.camera_decoder.blocks.4.norm2.weight', 'module.camera_decoder.blocks.4.norm2.bias', 'module.camera_decoder.blocks.4.mlp.fc1.weight', 'module.camera_decoder.blocks.4.mlp.fc1.bias', 'module.camera_decoder.blocks.4.mlp.fc2.weight', 'module.camera_decoder.blocks.4.mlp.fc2.bias', 'module.camera_decoder.linear_out.weight', 'module.camera_decoder.linear_out.bias', 'module.camera_head.res_conv.0.res_conv1.weight', 'module.camera_head.res_conv.0.res_conv1.bias', 'module.camera_head.res_conv.0.res_conv2.weight', 'module.camera_head.res_conv.0.res_conv2.bias', 'module.camera_head.res_conv.0.res_conv3.weight', 'module.camera_head.res_conv.0.res_conv3.bias', 'module.camera_head.res_conv.1.res_conv1.weight', 'module.camera_head.res_conv.1.res_conv1.bias', 'module.camera_head.res_conv.1.res_conv2.weight', 'module.camera_head.res_conv.1.res_conv2.bias', 'module.camera_head.res_conv.1.res_conv3.weight', 'module.camera_head.res_conv.1.res_conv3.bias', 'module.camera_head.more_mlps.0.weight', 'module.camera_head.more_mlps.0.bias', 'module.camera_head.more_mlps.2.weight', 'module.camera_head.more_mlps.2.bias', 'module.camera_head.fc_t.weight', 'module.camera_head.fc_t.bias', 'module.camera_head.fc_rot.weight', 'module.camera_head.fc_rot.bias']) +[2026-05-01 23:35:26,986][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-01 23:35:26,991][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-01 23:35:26,991][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-01 23:35:26,991][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... +[2026-05-01 23:35:26,994][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + "decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + "decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + "decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + "decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + "decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + "decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", + "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + "decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + "decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + "decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + "decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + "decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + "decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + "decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + "point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + "point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + "conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + "conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + "camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + "camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + "decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", + "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + "decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + "decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", + "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + "point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + "camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + "camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-01 23:35:31,129][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-01 23:35:31,133][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/ +[2026-05-01 23:36:55,584][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/2175] eta: 2 days, 3:01:10 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 5439.7471 (5439.7471) Lcamera_frontend: 4.1901 (4.1901) Ldepth_frontend: 16.2068 (16.2068) Lpmap_frontend: 18.4157 (18.4157) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1823 (4.1823) Ldepth_mix: 16.2088 (16.2088) Lpmap_mix: 18.4142 (18.4142) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (4.1872) Ldepth_backend: 16.2021 (16.2021) Lpmap_backend: 18.4084 (18.4084) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5439.7471) time: 84.4462 data: 25.0957 max mem: 32998 +[2026-05-01 23:45:06,532][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/2175] eta: 1 day, 7:27:17 lr: 0.000000 epoch: 0.0023 (0.0023) step: 5.0000 (4.8182) loss: 5439.7471 (5074.5800) Lcamera_frontend: 4.1901 (3.9012) Ldepth_frontend: 16.5756 (17.1204) Lpmap_frontend: 18.4116 (18.3184) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1546 (3.8658) Ldepth_mix: 16.5724 (17.1330) Lpmap_mix: 18.4112 (18.3244) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (3.8723) Ldepth_backend: 16.5614 (17.1411) Lpmap_backend: 18.4082 (18.3288) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5074.5800) time: 52.3039 data: 2.3197 max mem: 78608 +[2026-05-01 23:53:39,326][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/2175] eta: 1 day, 7:01:03 lr: 0.000000 epoch: 0.0046 (0.0046) step: 10.0000 (9.8571) loss: 4455.6162 (4960.8985) Lcamera_frontend: 3.3706 (3.7847) Ldepth_frontend: 16.6039 (17.0273) Lpmap_frontend: 18.2098 (18.2580) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3677 (3.7682) Ldepth_mix: 16.6194 (17.0357) Lpmap_mix: 18.2047 (18.2626) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3689 (3.7819) Ldepth_backend: 16.6355 (17.0431) Lpmap_backend: 18.2009 (18.2675) Ltrack_backend: 0.0000 (0.0000) total: 4455.6162 (4960.8985) time: 50.1844 data: 0.0399 max mem: 78608 +[2026-05-02 00:02:34,752][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 30/2175] eta: 1 day, 7:12:18 lr: 0.000000 epoch: 0.0092 (0.0069) step: 20.0000 (14.8710) loss: 4440.2881 (5149.9153) Lcamera_frontend: 3.3706 (3.9496) Ldepth_frontend: 16.4113 (16.8969) Lpmap_frontend: 18.1052 (18.1768) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3677 (3.9320) Ldepth_mix: 16.3962 (16.9022) Lpmap_mix: 18.1084 (18.1795) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3689 (3.9407) Ldepth_backend: 16.4113 (16.9092) Lpmap_backend: 18.1076 (18.1839) Ltrack_backend: 0.0000 (0.0000) total: 4440.2881 (5149.9153) time: 52.4102 data: 0.0363 max mem: 78608 +[2026-05-02 01:06:58,085][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/ +[2026-05-02 01:06:58,456][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/code/05_02-01:06:58 +[2026-05-02 01:06:58,456][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-02 01:06:58,457][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-02 01:06:58,458][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 01:06:58,459][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 01:11:09,465][__main__][INFO] - [RANK 0] Building test dataset 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 01:11:09,466][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 01:11:09,577][__main__][INFO] - [RANK 0] Loading model +[2026-05-02 01:11:15,004][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-02 01:11:15,004][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-02 01:11:15,004][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-02 01:11:15,273][__main__][INFO] - [RANK 0] Loading pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +[2026-05-02 01:11:26,676][__main__][INFO] - [RANK 0] _IncompatibleKeys(missing_keys=['register_token', 'image_mean', 'image_std', 'encoder.cls_token', 'encoder.pos_embed', 'encoder.register_tokens', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.ls1.gamma', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.0.ls2.gamma', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.ls1.gamma', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.1.ls2.gamma', 'encoder.blocks.2.norm1.weight', 'encoder.blocks.2.norm1.bias', 'encoder.blocks.2.attn.qkv.weight', 'encoder.blocks.2.attn.qkv.bias', 'encoder.blocks.2.attn.proj.weight', 'encoder.blocks.2.attn.proj.bias', 'encoder.blocks.2.ls1.gamma', 'encoder.blocks.2.norm2.weight', 'encoder.blocks.2.norm2.bias', 'encoder.blocks.2.mlp.fc1.weight', 'encoder.blocks.2.mlp.fc1.bias', 'encoder.blocks.2.mlp.fc2.weight', 'encoder.blocks.2.mlp.fc2.bias', 'encoder.blocks.2.ls2.gamma', 'encoder.blocks.3.norm1.weight', 'encoder.blocks.3.norm1.bias', 'encoder.blocks.3.attn.qkv.weight', 'encoder.blocks.3.attn.qkv.bias', 'encoder.blocks.3.attn.proj.weight', 'encoder.blocks.3.attn.proj.bias', 'encoder.blocks.3.ls1.gamma', 'encoder.blocks.3.norm2.weight', 'encoder.blocks.3.norm2.bias', 'encoder.blocks.3.mlp.fc1.weight', 'encoder.blocks.3.mlp.fc1.bias', 'encoder.blocks.3.mlp.fc2.weight', 'encoder.blocks.3.mlp.fc2.bias', 'encoder.blocks.3.ls2.gamma', 'encoder.blocks.4.norm1.weight', 'encoder.blocks.4.norm1.bias', 'encoder.blocks.4.attn.qkv.weight', 'encoder.blocks.4.attn.qkv.bias', 'encoder.blocks.4.attn.proj.weight', 'encoder.blocks.4.attn.proj.bias', 'encoder.blocks.4.ls1.gamma', 'encoder.blocks.4.norm2.weight', 'encoder.blocks.4.norm2.bias', 'encoder.blocks.4.mlp.fc1.weight', 'encoder.blocks.4.mlp.fc1.bias', 'encoder.blocks.4.mlp.fc2.weight', 'encoder.blocks.4.mlp.fc2.bias', 'encoder.blocks.4.ls2.gamma', 'encoder.blocks.5.norm1.weight', 'encoder.blocks.5.norm1.bias', 'encoder.blocks.5.attn.qkv.weight', 'encoder.blocks.5.attn.qkv.bias', 'encoder.blocks.5.attn.proj.weight', 'encoder.blocks.5.attn.proj.bias', 'encoder.blocks.5.ls1.gamma', 'encoder.blocks.5.norm2.weight', 'encoder.blocks.5.norm2.bias', 'encoder.blocks.5.mlp.fc1.weight', 'encoder.blocks.5.mlp.fc1.bias', 'encoder.blocks.5.mlp.fc2.weight', 'encoder.blocks.5.mlp.fc2.bias', 'encoder.blocks.5.ls2.gamma', 'encoder.blocks.6.norm1.weight', 'encoder.blocks.6.norm1.bias', 'encoder.blocks.6.attn.qkv.weight', 'encoder.blocks.6.attn.qkv.bias', 'encoder.blocks.6.attn.proj.weight', 'encoder.blocks.6.attn.proj.bias', 'encoder.blocks.6.ls1.gamma', 'encoder.blocks.6.norm2.weight', 'encoder.blocks.6.norm2.bias', 'encoder.blocks.6.mlp.fc1.weight', 'encoder.blocks.6.mlp.fc1.bias', 'encoder.blocks.6.mlp.fc2.weight', 'encoder.blocks.6.mlp.fc2.bias', 'encoder.blocks.6.ls2.gamma', 'encoder.blocks.7.norm1.weight', 'encoder.blocks.7.norm1.bias', 'encoder.blocks.7.attn.qkv.weight', 'encoder.blocks.7.attn.qkv.bias', 'encoder.blocks.7.attn.proj.weight', 'encoder.blocks.7.attn.proj.bias', 'encoder.blocks.7.ls1.gamma', 'encoder.blocks.7.norm2.weight', 'encoder.blocks.7.norm2.bias', 'encoder.blocks.7.mlp.fc1.weight', 'encoder.blocks.7.mlp.fc1.bias', 'encoder.blocks.7.mlp.fc2.weight', 'encoder.blocks.7.mlp.fc2.bias', 'encoder.blocks.7.ls2.gamma', 'encoder.blocks.8.norm1.weight', 'encoder.blocks.8.norm1.bias', 'encoder.blocks.8.attn.qkv.weight', 'encoder.blocks.8.attn.qkv.bias', 'encoder.blocks.8.attn.proj.weight', 'encoder.blocks.8.attn.proj.bias', 'encoder.blocks.8.ls1.gamma', 'encoder.blocks.8.norm2.weight', 'encoder.blocks.8.norm2.bias', 'encoder.blocks.8.mlp.fc1.weight', 'encoder.blocks.8.mlp.fc1.bias', 'encoder.blocks.8.mlp.fc2.weight', 'encoder.blocks.8.mlp.fc2.bias', 'encoder.blocks.8.ls2.gamma', 'encoder.blocks.9.norm1.weight', 'encoder.blocks.9.norm1.bias', 'encoder.blocks.9.attn.qkv.weight', 'encoder.blocks.9.attn.qkv.bias', 'encoder.blocks.9.attn.proj.weight', 'encoder.blocks.9.attn.proj.bias', 'encoder.blocks.9.ls1.gamma', 'encoder.blocks.9.norm2.weight', 'encoder.blocks.9.norm2.bias', 'encoder.blocks.9.mlp.fc1.weight', 'encoder.blocks.9.mlp.fc1.bias', 'encoder.blocks.9.mlp.fc2.weight', 'encoder.blocks.9.mlp.fc2.bias', 'encoder.blocks.9.ls2.gamma', 'encoder.blocks.10.norm1.weight', 'encoder.blocks.10.norm1.bias', 'encoder.blocks.10.attn.qkv.weight', 'encoder.blocks.10.attn.qkv.bias', 'encoder.blocks.10.attn.proj.weight', 'encoder.blocks.10.attn.proj.bias', 'encoder.blocks.10.ls1.gamma', 'encoder.blocks.10.norm2.weight', 'encoder.blocks.10.norm2.bias', 'encoder.blocks.10.mlp.fc1.weight', 'encoder.blocks.10.mlp.fc1.bias', 'encoder.blocks.10.mlp.fc2.weight', 'encoder.blocks.10.mlp.fc2.bias', 'encoder.blocks.10.ls2.gamma', 'encoder.blocks.11.norm1.weight', 'encoder.blocks.11.norm1.bias', 'encoder.blocks.11.attn.qkv.weight', 'encoder.blocks.11.attn.qkv.bias', 'encoder.blocks.11.attn.proj.weight', 'encoder.blocks.11.attn.proj.bias', 'encoder.blocks.11.ls1.gamma', 'encoder.blocks.11.norm2.weight', 'encoder.blocks.11.norm2.bias', 'encoder.blocks.11.mlp.fc1.weight', 'encoder.blocks.11.mlp.fc1.bias', 'encoder.blocks.11.mlp.fc2.weight', 'encoder.blocks.11.mlp.fc2.bias', 'encoder.blocks.11.ls2.gamma', 'encoder.blocks.12.norm1.weight', 'encoder.blocks.12.norm1.bias', 'encoder.blocks.12.attn.qkv.weight', 'encoder.blocks.12.attn.qkv.bias', 'encoder.blocks.12.attn.proj.weight', 'encoder.blocks.12.attn.proj.bias', 'encoder.blocks.12.ls1.gamma', 'encoder.blocks.12.norm2.weight', 'encoder.blocks.12.norm2.bias', 'encoder.blocks.12.mlp.fc1.weight', 'encoder.blocks.12.mlp.fc1.bias', 'encoder.blocks.12.mlp.fc2.weight', 'encoder.blocks.12.mlp.fc2.bias', 'encoder.blocks.12.ls2.gamma', 'encoder.blocks.13.norm1.weight', 'encoder.blocks.13.norm1.bias', 'encoder.blocks.13.attn.qkv.weight', 'encoder.blocks.13.attn.qkv.bias', 'encoder.blocks.13.attn.proj.weight', 'encoder.blocks.13.attn.proj.bias', 'encoder.blocks.13.ls1.gamma', 'encoder.blocks.13.norm2.weight', 'encoder.blocks.13.norm2.bias', 'encoder.blocks.13.mlp.fc1.weight', 'encoder.blocks.13.mlp.fc1.bias', 'encoder.blocks.13.mlp.fc2.weight', 'encoder.blocks.13.mlp.fc2.bias', 'encoder.blocks.13.ls2.gamma', 'encoder.blocks.14.norm1.weight', 'encoder.blocks.14.norm1.bias', 'encoder.blocks.14.attn.qkv.weight', 'encoder.blocks.14.attn.qkv.bias', 'encoder.blocks.14.attn.proj.weight', 'encoder.blocks.14.attn.proj.bias', 'encoder.blocks.14.ls1.gamma', 'encoder.blocks.14.norm2.weight', 'encoder.blocks.14.norm2.bias', 'encoder.blocks.14.mlp.fc1.weight', 'encoder.blocks.14.mlp.fc1.bias', 'encoder.blocks.14.mlp.fc2.weight', 'encoder.blocks.14.mlp.fc2.bias', 'encoder.blocks.14.ls2.gamma', 'encoder.blocks.15.norm1.weight', 'encoder.blocks.15.norm1.bias', 'encoder.blocks.15.attn.qkv.weight', 'encoder.blocks.15.attn.qkv.bias', 'encoder.blocks.15.attn.proj.weight', 'encoder.blocks.15.attn.proj.bias', 'encoder.blocks.15.ls1.gamma', 'encoder.blocks.15.norm2.weight', 'encoder.blocks.15.norm2.bias', 'encoder.blocks.15.mlp.fc1.weight', 'encoder.blocks.15.mlp.fc1.bias', 'encoder.blocks.15.mlp.fc2.weight', 'encoder.blocks.15.mlp.fc2.bias', 'encoder.blocks.15.ls2.gamma', 'encoder.blocks.16.norm1.weight', 'encoder.blocks.16.norm1.bias', 'encoder.blocks.16.attn.qkv.weight', 'encoder.blocks.16.attn.qkv.bias', 'encoder.blocks.16.attn.proj.weight', 'encoder.blocks.16.attn.proj.bias', 'encoder.blocks.16.ls1.gamma', 'encoder.blocks.16.norm2.weight', 'encoder.blocks.16.norm2.bias', 'encoder.blocks.16.mlp.fc1.weight', 'encoder.blocks.16.mlp.fc1.bias', 'encoder.blocks.16.mlp.fc2.weight', 'encoder.blocks.16.mlp.fc2.bias', 'encoder.blocks.16.ls2.gamma', 'encoder.blocks.17.norm1.weight', 'encoder.blocks.17.norm1.bias', 'encoder.blocks.17.attn.qkv.weight', 'encoder.blocks.17.attn.qkv.bias', 'encoder.blocks.17.attn.proj.weight', 'encoder.blocks.17.attn.proj.bias', 'encoder.blocks.17.ls1.gamma', 'encoder.blocks.17.norm2.weight', 'encoder.blocks.17.norm2.bias', 'encoder.blocks.17.mlp.fc1.weight', 'encoder.blocks.17.mlp.fc1.bias', 'encoder.blocks.17.mlp.fc2.weight', 'encoder.blocks.17.mlp.fc2.bias', 'encoder.blocks.17.ls2.gamma', 'encoder.blocks.18.norm1.weight', 'encoder.blocks.18.norm1.bias', 'encoder.blocks.18.attn.qkv.weight', 'encoder.blocks.18.attn.qkv.bias', 'encoder.blocks.18.attn.proj.weight', 'encoder.blocks.18.attn.proj.bias', 'encoder.blocks.18.ls1.gamma', 'encoder.blocks.18.norm2.weight', 'encoder.blocks.18.norm2.bias', 'encoder.blocks.18.mlp.fc1.weight', 'encoder.blocks.18.mlp.fc1.bias', 'encoder.blocks.18.mlp.fc2.weight', 'encoder.blocks.18.mlp.fc2.bias', 'encoder.blocks.18.ls2.gamma', 'encoder.blocks.19.norm1.weight', 'encoder.blocks.19.norm1.bias', 'encoder.blocks.19.attn.qkv.weight', 'encoder.blocks.19.attn.qkv.bias', 'encoder.blocks.19.attn.proj.weight', 'encoder.blocks.19.attn.proj.bias', 'encoder.blocks.19.ls1.gamma', 'encoder.blocks.19.norm2.weight', 'encoder.blocks.19.norm2.bias', 'encoder.blocks.19.mlp.fc1.weight', 'encoder.blocks.19.mlp.fc1.bias', 'encoder.blocks.19.mlp.fc2.weight', 'encoder.blocks.19.mlp.fc2.bias', 'encoder.blocks.19.ls2.gamma', 'encoder.blocks.20.norm1.weight', 'encoder.blocks.20.norm1.bias', 'encoder.blocks.20.attn.qkv.weight', 'encoder.blocks.20.attn.qkv.bias', 'encoder.blocks.20.attn.proj.weight', 'encoder.blocks.20.attn.proj.bias', 'encoder.blocks.20.ls1.gamma', 'encoder.blocks.20.norm2.weight', 'encoder.blocks.20.norm2.bias', 'encoder.blocks.20.mlp.fc1.weight', 'encoder.blocks.20.mlp.fc1.bias', 'encoder.blocks.20.mlp.fc2.weight', 'encoder.blocks.20.mlp.fc2.bias', 'encoder.blocks.20.ls2.gamma', 'encoder.blocks.21.norm1.weight', 'encoder.blocks.21.norm1.bias', 'encoder.blocks.21.attn.qkv.weight', 'encoder.blocks.21.attn.qkv.bias', 'encoder.blocks.21.attn.proj.weight', 'encoder.blocks.21.attn.proj.bias', 'encoder.blocks.21.ls1.gamma', 'encoder.blocks.21.norm2.weight', 'encoder.blocks.21.norm2.bias', 'encoder.blocks.21.mlp.fc1.weight', 'encoder.blocks.21.mlp.fc1.bias', 'encoder.blocks.21.mlp.fc2.weight', 'encoder.blocks.21.mlp.fc2.bias', 'encoder.blocks.21.ls2.gamma', 'encoder.blocks.22.norm1.weight', 'encoder.blocks.22.norm1.bias', 'encoder.blocks.22.attn.qkv.weight', 'encoder.blocks.22.attn.qkv.bias', 'encoder.blocks.22.attn.proj.weight', 'encoder.blocks.22.attn.proj.bias', 'encoder.blocks.22.ls1.gamma', 'encoder.blocks.22.norm2.weight', 'encoder.blocks.22.norm2.bias', 'encoder.blocks.22.mlp.fc1.weight', 'encoder.blocks.22.mlp.fc1.bias', 'encoder.blocks.22.mlp.fc2.weight', 'encoder.blocks.22.mlp.fc2.bias', 'encoder.blocks.22.ls2.gamma', 'encoder.blocks.23.norm1.weight', 'encoder.blocks.23.norm1.bias', 'encoder.blocks.23.attn.qkv.weight', 'encoder.blocks.23.attn.qkv.bias', 'encoder.blocks.23.attn.proj.weight', 'encoder.blocks.23.attn.proj.bias', 'encoder.blocks.23.ls1.gamma', 'encoder.blocks.23.norm2.weight', 'encoder.blocks.23.norm2.bias', 'encoder.blocks.23.mlp.fc1.weight', 'encoder.blocks.23.mlp.fc1.bias', 'encoder.blocks.23.mlp.fc2.weight', 'encoder.blocks.23.mlp.fc2.bias', 'encoder.blocks.23.ls2.gamma', 'encoder.norm.weight', 'encoder.norm.bias', 'decoder.0.norm1.weight', 'decoder.0.norm1.bias', 'decoder.0.attn.qkv.weight', 'decoder.0.attn.qkv.bias', 'decoder.0.attn.proj.weight', 'decoder.0.attn.proj.bias', 'decoder.0.attn.q_norm.weight', 'decoder.0.attn.q_norm.bias', 'decoder.0.attn.k_norm.weight', 'decoder.0.attn.k_norm.bias', 'decoder.0.ls1.gamma', 'decoder.0.norm2.weight', 'decoder.0.norm2.bias', 'decoder.0.mlp.fc1.weight', 'decoder.0.mlp.fc1.bias', 'decoder.0.mlp.fc2.weight', 'decoder.0.mlp.fc2.bias', 'decoder.0.ls2.gamma', 'decoder.1.norm1.weight', 'decoder.1.norm1.bias', 'decoder.1.attn.qkv.weight', 'decoder.1.attn.qkv.bias', 'decoder.1.attn.proj.weight', 'decoder.1.attn.proj.bias', 'decoder.1.attn.q_norm.weight', 'decoder.1.attn.q_norm.bias', 'decoder.1.attn.k_norm.weight', 'decoder.1.attn.k_norm.bias', 'decoder.1.ls1.gamma', 'decoder.1.norm2.weight', 'decoder.1.norm2.bias', 'decoder.1.mlp.fc1.weight', 'decoder.1.mlp.fc1.bias', 'decoder.1.mlp.fc2.weight', 'decoder.1.mlp.fc2.bias', 'decoder.1.ls2.gamma', 'decoder.2.norm1.weight', 'decoder.2.norm1.bias', 'decoder.2.attn.qkv.weight', 'decoder.2.attn.qkv.bias', 'decoder.2.attn.proj.weight', 'decoder.2.attn.proj.bias', 'decoder.2.attn.q_norm.weight', 'decoder.2.attn.q_norm.bias', 'decoder.2.attn.k_norm.weight', 'decoder.2.attn.k_norm.bias', 'decoder.2.ls1.gamma', 'decoder.2.norm2.weight', 'decoder.2.norm2.bias', 'decoder.2.mlp.fc1.weight', 'decoder.2.mlp.fc1.bias', 'decoder.2.mlp.fc2.weight', 'decoder.2.mlp.fc2.bias', 'decoder.2.ls2.gamma', 'decoder.3.norm1.weight', 'decoder.3.norm1.bias', 'decoder.3.attn.qkv.weight', 'decoder.3.attn.qkv.bias', 'decoder.3.attn.proj.weight', 'decoder.3.attn.proj.bias', 'decoder.3.attn.q_norm.weight', 'decoder.3.attn.q_norm.bias', 'decoder.3.attn.k_norm.weight', 'decoder.3.attn.k_norm.bias', 'decoder.3.ls1.gamma', 'decoder.3.norm2.weight', 'decoder.3.norm2.bias', 'decoder.3.mlp.fc1.weight', 'decoder.3.mlp.fc1.bias', 'decoder.3.mlp.fc2.weight', 'decoder.3.mlp.fc2.bias', 'decoder.3.ls2.gamma', 'decoder.4.norm1.weight', 'decoder.4.norm1.bias', 'decoder.4.attn.qkv.weight', 'decoder.4.attn.qkv.bias', 'decoder.4.attn.proj.weight', 'decoder.4.attn.proj.bias', 'decoder.4.attn.q_norm.weight', 'decoder.4.attn.q_norm.bias', 'decoder.4.attn.k_norm.weight', 'decoder.4.attn.k_norm.bias', 'decoder.4.ls1.gamma', 'decoder.4.norm2.weight', 'decoder.4.norm2.bias', 'decoder.4.mlp.fc1.weight', 'decoder.4.mlp.fc1.bias', 'decoder.4.mlp.fc2.weight', 'decoder.4.mlp.fc2.bias', 'decoder.4.ls2.gamma', 'decoder.5.norm1.weight', 'decoder.5.norm1.bias', 'decoder.5.attn.qkv.weight', 'decoder.5.attn.qkv.bias', 'decoder.5.attn.proj.weight', 'decoder.5.attn.proj.bias', 'decoder.5.attn.q_norm.weight', 'decoder.5.attn.q_norm.bias', 'decoder.5.attn.k_norm.weight', 'decoder.5.attn.k_norm.bias', 'decoder.5.ls1.gamma', 'decoder.5.norm2.weight', 'decoder.5.norm2.bias', 'decoder.5.mlp.fc1.weight', 'decoder.5.mlp.fc1.bias', 'decoder.5.mlp.fc2.weight', 'decoder.5.mlp.fc2.bias', 'decoder.5.ls2.gamma', 'decoder.6.norm1.weight', 'decoder.6.norm1.bias', 'decoder.6.attn.qkv.weight', 'decoder.6.attn.qkv.bias', 'decoder.6.attn.proj.weight', 'decoder.6.attn.proj.bias', 'decoder.6.attn.q_norm.weight', 'decoder.6.attn.q_norm.bias', 'decoder.6.attn.k_norm.weight', 'decoder.6.attn.k_norm.bias', 'decoder.6.ls1.gamma', 'decoder.6.norm2.weight', 'decoder.6.norm2.bias', 'decoder.6.mlp.fc1.weight', 'decoder.6.mlp.fc1.bias', 'decoder.6.mlp.fc2.weight', 'decoder.6.mlp.fc2.bias', 'decoder.6.ls2.gamma', 'decoder.7.norm1.weight', 'decoder.7.norm1.bias', 'decoder.7.attn.qkv.weight', 'decoder.7.attn.qkv.bias', 'decoder.7.attn.proj.weight', 'decoder.7.attn.proj.bias', 'decoder.7.attn.q_norm.weight', 'decoder.7.attn.q_norm.bias', 'decoder.7.attn.k_norm.weight', 'decoder.7.attn.k_norm.bias', 'decoder.7.ls1.gamma', 'decoder.7.norm2.weight', 'decoder.7.norm2.bias', 'decoder.7.mlp.fc1.weight', 'decoder.7.mlp.fc1.bias', 'decoder.7.mlp.fc2.weight', 'decoder.7.mlp.fc2.bias', 'decoder.7.ls2.gamma', 'decoder.8.norm1.weight', 'decoder.8.norm1.bias', 'decoder.8.attn.qkv.weight', 'decoder.8.attn.qkv.bias', 'decoder.8.attn.proj.weight', 'decoder.8.attn.proj.bias', 'decoder.8.attn.q_norm.weight', 'decoder.8.attn.q_norm.bias', 'decoder.8.attn.k_norm.weight', 'decoder.8.attn.k_norm.bias', 'decoder.8.ls1.gamma', 'decoder.8.norm2.weight', 'decoder.8.norm2.bias', 'decoder.8.mlp.fc1.weight', 'decoder.8.mlp.fc1.bias', 'decoder.8.mlp.fc2.weight', 'decoder.8.mlp.fc2.bias', 'decoder.8.ls2.gamma', 'decoder.9.norm1.weight', 'decoder.9.norm1.bias', 'decoder.9.attn.qkv.weight', 'decoder.9.attn.qkv.bias', 'decoder.9.attn.proj.weight', 'decoder.9.attn.proj.bias', 'decoder.9.attn.q_norm.weight', 'decoder.9.attn.q_norm.bias', 'decoder.9.attn.k_norm.weight', 'decoder.9.attn.k_norm.bias', 'decoder.9.ls1.gamma', 'decoder.9.norm2.weight', 'decoder.9.norm2.bias', 'decoder.9.mlp.fc1.weight', 'decoder.9.mlp.fc1.bias', 'decoder.9.mlp.fc2.weight', 'decoder.9.mlp.fc2.bias', 'decoder.9.ls2.gamma', 'decoder.10.norm1.weight', 'decoder.10.norm1.bias', 'decoder.10.attn.qkv.weight', 'decoder.10.attn.qkv.bias', 'decoder.10.attn.proj.weight', 'decoder.10.attn.proj.bias', 'decoder.10.attn.q_norm.weight', 'decoder.10.attn.q_norm.bias', 'decoder.10.attn.k_norm.weight', 'decoder.10.attn.k_norm.bias', 'decoder.10.ls1.gamma', 'decoder.10.norm2.weight', 'decoder.10.norm2.bias', 'decoder.10.mlp.fc1.weight', 'decoder.10.mlp.fc1.bias', 'decoder.10.mlp.fc2.weight', 'decoder.10.mlp.fc2.bias', 'decoder.10.ls2.gamma', 'decoder.11.norm1.weight', 'decoder.11.norm1.bias', 'decoder.11.attn.qkv.weight', 'decoder.11.attn.qkv.bias', 'decoder.11.attn.proj.weight', 'decoder.11.attn.proj.bias', 'decoder.11.attn.q_norm.weight', 'decoder.11.attn.q_norm.bias', 'decoder.11.attn.k_norm.weight', 'decoder.11.attn.k_norm.bias', 'decoder.11.ls1.gamma', 'decoder.11.norm2.weight', 'decoder.11.norm2.bias', 'decoder.11.mlp.fc1.weight', 'decoder.11.mlp.fc1.bias', 'decoder.11.mlp.fc2.weight', 'decoder.11.mlp.fc2.bias', 'decoder.11.ls2.gamma', 'decoder.12.norm1.weight', 'decoder.12.norm1.bias', 'decoder.12.attn.qkv.weight', 'decoder.12.attn.qkv.bias', 'decoder.12.attn.proj.weight', 'decoder.12.attn.proj.bias', 'decoder.12.attn.q_norm.weight', 'decoder.12.attn.q_norm.bias', 'decoder.12.attn.k_norm.weight', 'decoder.12.attn.k_norm.bias', 'decoder.12.ls1.gamma', 'decoder.12.norm2.weight', 'decoder.12.norm2.bias', 'decoder.12.mlp.fc1.weight', 'decoder.12.mlp.fc1.bias', 'decoder.12.mlp.fc2.weight', 'decoder.12.mlp.fc2.bias', 'decoder.12.ls2.gamma', 'decoder.13.norm1.weight', 'decoder.13.norm1.bias', 'decoder.13.attn.qkv.weight', 'decoder.13.attn.qkv.bias', 'decoder.13.attn.proj.weight', 'decoder.13.attn.proj.bias', 'decoder.13.attn.q_norm.weight', 'decoder.13.attn.q_norm.bias', 'decoder.13.attn.k_norm.weight', 'decoder.13.attn.k_norm.bias', 'decoder.13.ls1.gamma', 'decoder.13.norm2.weight', 'decoder.13.norm2.bias', 'decoder.13.mlp.fc1.weight', 'decoder.13.mlp.fc1.bias', 'decoder.13.mlp.fc2.weight', 'decoder.13.mlp.fc2.bias', 'decoder.13.ls2.gamma', 'decoder.14.norm1.weight', 'decoder.14.norm1.bias', 'decoder.14.attn.qkv.weight', 'decoder.14.attn.qkv.bias', 'decoder.14.attn.proj.weight', 'decoder.14.attn.proj.bias', 'decoder.14.attn.q_norm.weight', 'decoder.14.attn.q_norm.bias', 'decoder.14.attn.k_norm.weight', 'decoder.14.attn.k_norm.bias', 'decoder.14.ls1.gamma', 'decoder.14.norm2.weight', 'decoder.14.norm2.bias', 'decoder.14.mlp.fc1.weight', 'decoder.14.mlp.fc1.bias', 'decoder.14.mlp.fc2.weight', 'decoder.14.mlp.fc2.bias', 'decoder.14.ls2.gamma', 'decoder.15.norm1.weight', 'decoder.15.norm1.bias', 'decoder.15.attn.qkv.weight', 'decoder.15.attn.qkv.bias', 'decoder.15.attn.proj.weight', 'decoder.15.attn.proj.bias', 'decoder.15.attn.q_norm.weight', 'decoder.15.attn.q_norm.bias', 'decoder.15.attn.k_norm.weight', 'decoder.15.attn.k_norm.bias', 'decoder.15.ls1.gamma', 'decoder.15.norm2.weight', 'decoder.15.norm2.bias', 'decoder.15.mlp.fc1.weight', 'decoder.15.mlp.fc1.bias', 'decoder.15.mlp.fc2.weight', 'decoder.15.mlp.fc2.bias', 'decoder.15.ls2.gamma', 'decoder.16.norm1.weight', 'decoder.16.norm1.bias', 'decoder.16.attn.qkv.weight', 'decoder.16.attn.qkv.bias', 'decoder.16.attn.proj.weight', 'decoder.16.attn.proj.bias', 'decoder.16.attn.q_norm.weight', 'decoder.16.attn.q_norm.bias', 'decoder.16.attn.k_norm.weight', 'decoder.16.attn.k_norm.bias', 'decoder.16.ls1.gamma', 'decoder.16.norm2.weight', 'decoder.16.norm2.bias', 'decoder.16.mlp.fc1.weight', 'decoder.16.mlp.fc1.bias', 'decoder.16.mlp.fc2.weight', 'decoder.16.mlp.fc2.bias', 'decoder.16.ls2.gamma', 'decoder.17.norm1.weight', 'decoder.17.norm1.bias', 'decoder.17.attn.qkv.weight', 'decoder.17.attn.qkv.bias', 'decoder.17.attn.proj.weight', 'decoder.17.attn.proj.bias', 'decoder.17.attn.q_norm.weight', 'decoder.17.attn.q_norm.bias', 'decoder.17.attn.k_norm.weight', 'decoder.17.attn.k_norm.bias', 'decoder.17.ls1.gamma', 'decoder.17.norm2.weight', 'decoder.17.norm2.bias', 'decoder.17.mlp.fc1.weight', 'decoder.17.mlp.fc1.bias', 'decoder.17.mlp.fc2.weight', 'decoder.17.mlp.fc2.bias', 'decoder.17.ls2.gamma', 'decoder.18.norm1.weight', 'decoder.18.norm1.bias', 'decoder.18.attn.qkv.weight', 'decoder.18.attn.qkv.bias', 'decoder.18.attn.proj.weight', 'decoder.18.attn.proj.bias', 'decoder.18.attn.q_norm.weight', 'decoder.18.attn.q_norm.bias', 'decoder.18.attn.k_norm.weight', 'decoder.18.attn.k_norm.bias', 'decoder.18.ls1.gamma', 'decoder.18.norm2.weight', 'decoder.18.norm2.bias', 'decoder.18.mlp.fc1.weight', 'decoder.18.mlp.fc1.bias', 'decoder.18.mlp.fc2.weight', 'decoder.18.mlp.fc2.bias', 'decoder.18.ls2.gamma', 'decoder.19.norm1.weight', 'decoder.19.norm1.bias', 'decoder.19.attn.qkv.weight', 'decoder.19.attn.qkv.bias', 'decoder.19.attn.proj.weight', 'decoder.19.attn.proj.bias', 'decoder.19.attn.q_norm.weight', 'decoder.19.attn.q_norm.bias', 'decoder.19.attn.k_norm.weight', 'decoder.19.attn.k_norm.bias', 'decoder.19.ls1.gamma', 'decoder.19.norm2.weight', 'decoder.19.norm2.bias', 'decoder.19.mlp.fc1.weight', 'decoder.19.mlp.fc1.bias', 'decoder.19.mlp.fc2.weight', 'decoder.19.mlp.fc2.bias', 'decoder.19.ls2.gamma', 'decoder.20.norm1.weight', 'decoder.20.norm1.bias', 'decoder.20.attn.qkv.weight', 'decoder.20.attn.qkv.bias', 'decoder.20.attn.proj.weight', 'decoder.20.attn.proj.bias', 'decoder.20.attn.q_norm.weight', 'decoder.20.attn.q_norm.bias', 'decoder.20.attn.k_norm.weight', 'decoder.20.attn.k_norm.bias', 'decoder.20.ls1.gamma', 'decoder.20.norm2.weight', 'decoder.20.norm2.bias', 'decoder.20.mlp.fc1.weight', 'decoder.20.mlp.fc1.bias', 'decoder.20.mlp.fc2.weight', 'decoder.20.mlp.fc2.bias', 'decoder.20.ls2.gamma', 'decoder.21.norm1.weight', 'decoder.21.norm1.bias', 'decoder.21.attn.qkv.weight', 'decoder.21.attn.qkv.bias', 'decoder.21.attn.proj.weight', 'decoder.21.attn.proj.bias', 'decoder.21.attn.q_norm.weight', 'decoder.21.attn.q_norm.bias', 'decoder.21.attn.k_norm.weight', 'decoder.21.attn.k_norm.bias', 'decoder.21.ls1.gamma', 'decoder.21.norm2.weight', 'decoder.21.norm2.bias', 'decoder.21.mlp.fc1.weight', 'decoder.21.mlp.fc1.bias', 'decoder.21.mlp.fc2.weight', 'decoder.21.mlp.fc2.bias', 'decoder.21.ls2.gamma', 'decoder.22.norm1.weight', 'decoder.22.norm1.bias', 'decoder.22.attn.qkv.weight', 'decoder.22.attn.qkv.bias', 'decoder.22.attn.proj.weight', 'decoder.22.attn.proj.bias', 'decoder.22.attn.q_norm.weight', 'decoder.22.attn.q_norm.bias', 'decoder.22.attn.k_norm.weight', 'decoder.22.attn.k_norm.bias', 'decoder.22.ls1.gamma', 'decoder.22.norm2.weight', 'decoder.22.norm2.bias', 'decoder.22.mlp.fc1.weight', 'decoder.22.mlp.fc1.bias', 'decoder.22.mlp.fc2.weight', 'decoder.22.mlp.fc2.bias', 'decoder.22.ls2.gamma', 'decoder.23.norm1.weight', 'decoder.23.norm1.bias', 'decoder.23.attn.qkv.weight', 'decoder.23.attn.qkv.bias', 'decoder.23.attn.proj.weight', 'decoder.23.attn.proj.bias', 'decoder.23.attn.q_norm.weight', 'decoder.23.attn.q_norm.bias', 'decoder.23.attn.k_norm.weight', 'decoder.23.attn.k_norm.bias', 'decoder.23.ls1.gamma', 'decoder.23.norm2.weight', 'decoder.23.norm2.bias', 'decoder.23.mlp.fc1.weight', 'decoder.23.mlp.fc1.bias', 'decoder.23.mlp.fc2.weight', 'decoder.23.mlp.fc2.bias', 'decoder.23.ls2.gamma', 'decoder.24.norm1.weight', 'decoder.24.norm1.bias', 'decoder.24.attn.qkv.weight', 'decoder.24.attn.qkv.bias', 'decoder.24.attn.proj.weight', 'decoder.24.attn.proj.bias', 'decoder.24.attn.q_norm.weight', 'decoder.24.attn.q_norm.bias', 'decoder.24.attn.k_norm.weight', 'decoder.24.attn.k_norm.bias', 'decoder.24.ls1.gamma', 'decoder.24.norm2.weight', 'decoder.24.norm2.bias', 'decoder.24.mlp.fc1.weight', 'decoder.24.mlp.fc1.bias', 'decoder.24.mlp.fc2.weight', 'decoder.24.mlp.fc2.bias', 'decoder.24.ls2.gamma', 'decoder.25.norm1.weight', 'decoder.25.norm1.bias', 'decoder.25.attn.qkv.weight', 'decoder.25.attn.qkv.bias', 'decoder.25.attn.proj.weight', 'decoder.25.attn.proj.bias', 'decoder.25.attn.q_norm.weight', 'decoder.25.attn.q_norm.bias', 'decoder.25.attn.k_norm.weight', 'decoder.25.attn.k_norm.bias', 'decoder.25.ls1.gamma', 'decoder.25.norm2.weight', 'decoder.25.norm2.bias', 'decoder.25.mlp.fc1.weight', 'decoder.25.mlp.fc1.bias', 'decoder.25.mlp.fc2.weight', 'decoder.25.mlp.fc2.bias', 'decoder.25.ls2.gamma', 'decoder.26.norm1.weight', 'decoder.26.norm1.bias', 'decoder.26.attn.qkv.weight', 'decoder.26.attn.qkv.bias', 'decoder.26.attn.proj.weight', 'decoder.26.attn.proj.bias', 'decoder.26.attn.q_norm.weight', 'decoder.26.attn.q_norm.bias', 'decoder.26.attn.k_norm.weight', 'decoder.26.attn.k_norm.bias', 'decoder.26.ls1.gamma', 'decoder.26.norm2.weight', 'decoder.26.norm2.bias', 'decoder.26.mlp.fc1.weight', 'decoder.26.mlp.fc1.bias', 'decoder.26.mlp.fc2.weight', 'decoder.26.mlp.fc2.bias', 'decoder.26.ls2.gamma', 'decoder.27.norm1.weight', 'decoder.27.norm1.bias', 'decoder.27.attn.qkv.weight', 'decoder.27.attn.qkv.bias', 'decoder.27.attn.proj.weight', 'decoder.27.attn.proj.bias', 'decoder.27.attn.q_norm.weight', 'decoder.27.attn.q_norm.bias', 'decoder.27.attn.k_norm.weight', 'decoder.27.attn.k_norm.bias', 'decoder.27.ls1.gamma', 'decoder.27.norm2.weight', 'decoder.27.norm2.bias', 'decoder.27.mlp.fc1.weight', 'decoder.27.mlp.fc1.bias', 'decoder.27.mlp.fc2.weight', 'decoder.27.mlp.fc2.bias', 'decoder.27.ls2.gamma', 'decoder.28.norm1.weight', 'decoder.28.norm1.bias', 'decoder.28.attn.qkv.weight', 'decoder.28.attn.qkv.bias', 'decoder.28.attn.proj.weight', 'decoder.28.attn.proj.bias', 'decoder.28.attn.q_norm.weight', 'decoder.28.attn.q_norm.bias', 'decoder.28.attn.k_norm.weight', 'decoder.28.attn.k_norm.bias', 'decoder.28.ls1.gamma', 'decoder.28.norm2.weight', 'decoder.28.norm2.bias', 'decoder.28.mlp.fc1.weight', 'decoder.28.mlp.fc1.bias', 'decoder.28.mlp.fc2.weight', 'decoder.28.mlp.fc2.bias', 'decoder.28.ls2.gamma', 'decoder.29.norm1.weight', 'decoder.29.norm1.bias', 'decoder.29.attn.qkv.weight', 'decoder.29.attn.qkv.bias', 'decoder.29.attn.proj.weight', 'decoder.29.attn.proj.bias', 'decoder.29.attn.q_norm.weight', 'decoder.29.attn.q_norm.bias', 'decoder.29.attn.k_norm.weight', 'decoder.29.attn.k_norm.bias', 'decoder.29.ls1.gamma', 'decoder.29.norm2.weight', 'decoder.29.norm2.bias', 'decoder.29.mlp.fc1.weight', 'decoder.29.mlp.fc1.bias', 'decoder.29.mlp.fc2.weight', 'decoder.29.mlp.fc2.bias', 'decoder.29.ls2.gamma', 'decoder.30.norm1.weight', 'decoder.30.norm1.bias', 'decoder.30.attn.qkv.weight', 'decoder.30.attn.qkv.bias', 'decoder.30.attn.proj.weight', 'decoder.30.attn.proj.bias', 'decoder.30.attn.q_norm.weight', 'decoder.30.attn.q_norm.bias', 'decoder.30.attn.k_norm.weight', 'decoder.30.attn.k_norm.bias', 'decoder.30.ls1.gamma', 'decoder.30.norm2.weight', 'decoder.30.norm2.bias', 'decoder.30.mlp.fc1.weight', 'decoder.30.mlp.fc1.bias', 'decoder.30.mlp.fc2.weight', 'decoder.30.mlp.fc2.bias', 'decoder.30.ls2.gamma', 'decoder.31.norm1.weight', 'decoder.31.norm1.bias', 'decoder.31.attn.qkv.weight', 'decoder.31.attn.qkv.bias', 'decoder.31.attn.proj.weight', 'decoder.31.attn.proj.bias', 'decoder.31.attn.q_norm.weight', 'decoder.31.attn.q_norm.bias', 'decoder.31.attn.k_norm.weight', 'decoder.31.attn.k_norm.bias', 'decoder.31.ls1.gamma', 'decoder.31.norm2.weight', 'decoder.31.norm2.bias', 'decoder.31.mlp.fc1.weight', 'decoder.31.mlp.fc1.bias', 'decoder.31.mlp.fc2.weight', 'decoder.31.mlp.fc2.bias', 'decoder.31.ls2.gamma', 'decoder.32.norm1.weight', 'decoder.32.norm1.bias', 'decoder.32.attn.qkv.weight', 'decoder.32.attn.qkv.bias', 'decoder.32.attn.proj.weight', 'decoder.32.attn.proj.bias', 'decoder.32.attn.q_norm.weight', 'decoder.32.attn.q_norm.bias', 'decoder.32.attn.k_norm.weight', 'decoder.32.attn.k_norm.bias', 'decoder.32.ls1.gamma', 'decoder.32.norm2.weight', 'decoder.32.norm2.bias', 'decoder.32.mlp.fc1.weight', 'decoder.32.mlp.fc1.bias', 'decoder.32.mlp.fc2.weight', 'decoder.32.mlp.fc2.bias', 'decoder.32.ls2.gamma', 'decoder.33.norm1.weight', 'decoder.33.norm1.bias', 'decoder.33.attn.qkv.weight', 'decoder.33.attn.qkv.bias', 'decoder.33.attn.proj.weight', 'decoder.33.attn.proj.bias', 'decoder.33.attn.q_norm.weight', 'decoder.33.attn.q_norm.bias', 'decoder.33.attn.k_norm.weight', 'decoder.33.attn.k_norm.bias', 'decoder.33.ls1.gamma', 'decoder.33.norm2.weight', 'decoder.33.norm2.bias', 'decoder.33.mlp.fc1.weight', 'decoder.33.mlp.fc1.bias', 'decoder.33.mlp.fc2.weight', 'decoder.33.mlp.fc2.bias', 'decoder.33.ls2.gamma', 'decoder.34.norm1.weight', 'decoder.34.norm1.bias', 'decoder.34.attn.qkv.weight', 'decoder.34.attn.qkv.bias', 'decoder.34.attn.proj.weight', 'decoder.34.attn.proj.bias', 'decoder.34.attn.q_norm.weight', 'decoder.34.attn.q_norm.bias', 'decoder.34.attn.k_norm.weight', 'decoder.34.attn.k_norm.bias', 'decoder.34.ls1.gamma', 'decoder.34.norm2.weight', 'decoder.34.norm2.bias', 'decoder.34.mlp.fc1.weight', 'decoder.34.mlp.fc1.bias', 'decoder.34.mlp.fc2.weight', 'decoder.34.mlp.fc2.bias', 'decoder.34.ls2.gamma', 'decoder.35.norm1.weight', 'decoder.35.norm1.bias', 'decoder.35.attn.qkv.weight', 'decoder.35.attn.qkv.bias', 'decoder.35.attn.proj.weight', 'decoder.35.attn.proj.bias', 'decoder.35.attn.q_norm.weight', 'decoder.35.attn.q_norm.bias', 'decoder.35.attn.k_norm.weight', 'decoder.35.attn.k_norm.bias', 'decoder.35.ls1.gamma', 'decoder.35.norm2.weight', 'decoder.35.norm2.bias', 'decoder.35.mlp.fc1.weight', 'decoder.35.mlp.fc1.bias', 'decoder.35.mlp.fc2.weight', 'decoder.35.mlp.fc2.bias', 'decoder.35.ls2.gamma', 'point_decoder.projects.weight', 'point_decoder.projects.bias', 'point_decoder.blocks.0.norm1.weight', 'point_decoder.blocks.0.norm1.bias', 'point_decoder.blocks.0.attn.qkv.weight', 'point_decoder.blocks.0.attn.qkv.bias', 'point_decoder.blocks.0.attn.proj.weight', 'point_decoder.blocks.0.attn.proj.bias', 'point_decoder.blocks.0.norm2.weight', 'point_decoder.blocks.0.norm2.bias', 'point_decoder.blocks.0.mlp.fc1.weight', 'point_decoder.blocks.0.mlp.fc1.bias', 'point_decoder.blocks.0.mlp.fc2.weight', 'point_decoder.blocks.0.mlp.fc2.bias', 'point_decoder.blocks.1.norm1.weight', 'point_decoder.blocks.1.norm1.bias', 'point_decoder.blocks.1.attn.qkv.weight', 'point_decoder.blocks.1.attn.qkv.bias', 'point_decoder.blocks.1.attn.proj.weight', 'point_decoder.blocks.1.attn.proj.bias', 'point_decoder.blocks.1.norm2.weight', 'point_decoder.blocks.1.norm2.bias', 'point_decoder.blocks.1.mlp.fc1.weight', 'point_decoder.blocks.1.mlp.fc1.bias', 'point_decoder.blocks.1.mlp.fc2.weight', 'point_decoder.blocks.1.mlp.fc2.bias', 'point_decoder.blocks.2.norm1.weight', 'point_decoder.blocks.2.norm1.bias', 'point_decoder.blocks.2.attn.qkv.weight', 'point_decoder.blocks.2.attn.qkv.bias', 'point_decoder.blocks.2.attn.proj.weight', 'point_decoder.blocks.2.attn.proj.bias', 'point_decoder.blocks.2.norm2.weight', 'point_decoder.blocks.2.norm2.bias', 'point_decoder.blocks.2.mlp.fc1.weight', 'point_decoder.blocks.2.mlp.fc1.bias', 'point_decoder.blocks.2.mlp.fc2.weight', 'point_decoder.blocks.2.mlp.fc2.bias', 'point_decoder.blocks.3.norm1.weight', 'point_decoder.blocks.3.norm1.bias', 'point_decoder.blocks.3.attn.qkv.weight', 'point_decoder.blocks.3.attn.qkv.bias', 'point_decoder.blocks.3.attn.proj.weight', 'point_decoder.blocks.3.attn.proj.bias', 'point_decoder.blocks.3.norm2.weight', 'point_decoder.blocks.3.norm2.bias', 'point_decoder.blocks.3.mlp.fc1.weight', 'point_decoder.blocks.3.mlp.fc1.bias', 'point_decoder.blocks.3.mlp.fc2.weight', 'point_decoder.blocks.3.mlp.fc2.bias', 'point_decoder.blocks.4.norm1.weight', 'point_decoder.blocks.4.norm1.bias', 'point_decoder.blocks.4.attn.qkv.weight', 'point_decoder.blocks.4.attn.qkv.bias', 'point_decoder.blocks.4.attn.proj.weight', 'point_decoder.blocks.4.attn.proj.bias', 'point_decoder.blocks.4.norm2.weight', 'point_decoder.blocks.4.norm2.bias', 'point_decoder.blocks.4.mlp.fc1.weight', 'point_decoder.blocks.4.mlp.fc1.bias', 'point_decoder.blocks.4.mlp.fc2.weight', 'point_decoder.blocks.4.mlp.fc2.bias', 'point_decoder.linear_out.weight', 'point_decoder.linear_out.bias', 'point_head.proj.weight', 'point_head.proj.bias', 'conf_decoder.projects.weight', 'conf_decoder.projects.bias', 'conf_decoder.blocks.0.norm1.weight', 'conf_decoder.blocks.0.norm1.bias', 'conf_decoder.blocks.0.attn.qkv.weight', 'conf_decoder.blocks.0.attn.qkv.bias', 'conf_decoder.blocks.0.attn.proj.weight', 'conf_decoder.blocks.0.attn.proj.bias', 'conf_decoder.blocks.0.norm2.weight', 'conf_decoder.blocks.0.norm2.bias', 'conf_decoder.blocks.0.mlp.fc1.weight', 'conf_decoder.blocks.0.mlp.fc1.bias', 'conf_decoder.blocks.0.mlp.fc2.weight', 'conf_decoder.blocks.0.mlp.fc2.bias', 'conf_decoder.blocks.1.norm1.weight', 'conf_decoder.blocks.1.norm1.bias', 'conf_decoder.blocks.1.attn.qkv.weight', 'conf_decoder.blocks.1.attn.qkv.bias', 'conf_decoder.blocks.1.attn.proj.weight', 'conf_decoder.blocks.1.attn.proj.bias', 'conf_decoder.blocks.1.norm2.weight', 'conf_decoder.blocks.1.norm2.bias', 'conf_decoder.blocks.1.mlp.fc1.weight', 'conf_decoder.blocks.1.mlp.fc1.bias', 'conf_decoder.blocks.1.mlp.fc2.weight', 'conf_decoder.blocks.1.mlp.fc2.bias', 'conf_decoder.blocks.2.norm1.weight', 'conf_decoder.blocks.2.norm1.bias', 'conf_decoder.blocks.2.attn.qkv.weight', 'conf_decoder.blocks.2.attn.qkv.bias', 'conf_decoder.blocks.2.attn.proj.weight', 'conf_decoder.blocks.2.attn.proj.bias', 'conf_decoder.blocks.2.norm2.weight', 'conf_decoder.blocks.2.norm2.bias', 'conf_decoder.blocks.2.mlp.fc1.weight', 'conf_decoder.blocks.2.mlp.fc1.bias', 'conf_decoder.blocks.2.mlp.fc2.weight', 'conf_decoder.blocks.2.mlp.fc2.bias', 'conf_decoder.blocks.3.norm1.weight', 'conf_decoder.blocks.3.norm1.bias', 'conf_decoder.blocks.3.attn.qkv.weight', 'conf_decoder.blocks.3.attn.qkv.bias', 'conf_decoder.blocks.3.attn.proj.weight', 'conf_decoder.blocks.3.attn.proj.bias', 'conf_decoder.blocks.3.norm2.weight', 'conf_decoder.blocks.3.norm2.bias', 'conf_decoder.blocks.3.mlp.fc1.weight', 'conf_decoder.blocks.3.mlp.fc1.bias', 'conf_decoder.blocks.3.mlp.fc2.weight', 'conf_decoder.blocks.3.mlp.fc2.bias', 'conf_decoder.blocks.4.norm1.weight', 'conf_decoder.blocks.4.norm1.bias', 'conf_decoder.blocks.4.attn.qkv.weight', 'conf_decoder.blocks.4.attn.qkv.bias', 'conf_decoder.blocks.4.attn.proj.weight', 'conf_decoder.blocks.4.attn.proj.bias', 'conf_decoder.blocks.4.norm2.weight', 'conf_decoder.blocks.4.norm2.bias', 'conf_decoder.blocks.4.mlp.fc1.weight', 'conf_decoder.blocks.4.mlp.fc1.bias', 'conf_decoder.blocks.4.mlp.fc2.weight', 'conf_decoder.blocks.4.mlp.fc2.bias', 'conf_decoder.linear_out.weight', 'conf_decoder.linear_out.bias', 'conf_head.proj.weight', 'conf_head.proj.bias', 'camera_decoder.projects.weight', 'camera_decoder.projects.bias', 'camera_decoder.blocks.0.norm1.weight', 'camera_decoder.blocks.0.norm1.bias', 'camera_decoder.blocks.0.attn.qkv.weight', 'camera_decoder.blocks.0.attn.qkv.bias', 'camera_decoder.blocks.0.attn.proj.weight', 'camera_decoder.blocks.0.attn.proj.bias', 'camera_decoder.blocks.0.norm2.weight', 'camera_decoder.blocks.0.norm2.bias', 'camera_decoder.blocks.0.mlp.fc1.weight', 'camera_decoder.blocks.0.mlp.fc1.bias', 'camera_decoder.blocks.0.mlp.fc2.weight', 'camera_decoder.blocks.0.mlp.fc2.bias', 'camera_decoder.blocks.1.norm1.weight', 'camera_decoder.blocks.1.norm1.bias', 'camera_decoder.blocks.1.attn.qkv.weight', 'camera_decoder.blocks.1.attn.qkv.bias', 'camera_decoder.blocks.1.attn.proj.weight', 'camera_decoder.blocks.1.attn.proj.bias', 'camera_decoder.blocks.1.norm2.weight', 'camera_decoder.blocks.1.norm2.bias', 'camera_decoder.blocks.1.mlp.fc1.weight', 'camera_decoder.blocks.1.mlp.fc1.bias', 'camera_decoder.blocks.1.mlp.fc2.weight', 'camera_decoder.blocks.1.mlp.fc2.bias', 'camera_decoder.blocks.2.norm1.weight', 'camera_decoder.blocks.2.norm1.bias', 'camera_decoder.blocks.2.attn.qkv.weight', 'camera_decoder.blocks.2.attn.qkv.bias', 'camera_decoder.blocks.2.attn.proj.weight', 'camera_decoder.blocks.2.attn.proj.bias', 'camera_decoder.blocks.2.norm2.weight', 'camera_decoder.blocks.2.norm2.bias', 'camera_decoder.blocks.2.mlp.fc1.weight', 'camera_decoder.blocks.2.mlp.fc1.bias', 'camera_decoder.blocks.2.mlp.fc2.weight', 'camera_decoder.blocks.2.mlp.fc2.bias', 'camera_decoder.blocks.3.norm1.weight', 'camera_decoder.blocks.3.norm1.bias', 'camera_decoder.blocks.3.attn.qkv.weight', 'camera_decoder.blocks.3.attn.qkv.bias', 'camera_decoder.blocks.3.attn.proj.weight', 'camera_decoder.blocks.3.attn.proj.bias', 'camera_decoder.blocks.3.norm2.weight', 'camera_decoder.blocks.3.norm2.bias', 'camera_decoder.blocks.3.mlp.fc1.weight', 'camera_decoder.blocks.3.mlp.fc1.bias', 'camera_decoder.blocks.3.mlp.fc2.weight', 'camera_decoder.blocks.3.mlp.fc2.bias', 'camera_decoder.blocks.4.norm1.weight', 'camera_decoder.blocks.4.norm1.bias', 'camera_decoder.blocks.4.attn.qkv.weight', 'camera_decoder.blocks.4.attn.qkv.bias', 'camera_decoder.blocks.4.attn.proj.weight', 'camera_decoder.blocks.4.attn.proj.bias', 'camera_decoder.blocks.4.norm2.weight', 'camera_decoder.blocks.4.norm2.bias', 'camera_decoder.blocks.4.mlp.fc1.weight', 'camera_decoder.blocks.4.mlp.fc1.bias', 'camera_decoder.blocks.4.mlp.fc2.weight', 'camera_decoder.blocks.4.mlp.fc2.bias', 'camera_decoder.linear_out.weight', 'camera_decoder.linear_out.bias', 'camera_head.res_conv.0.res_conv1.weight', 'camera_head.res_conv.0.res_conv1.bias', 'camera_head.res_conv.0.res_conv2.weight', 'camera_head.res_conv.0.res_conv2.bias', 'camera_head.res_conv.0.res_conv3.weight', 'camera_head.res_conv.0.res_conv3.bias', 'camera_head.res_conv.1.res_conv1.weight', 'camera_head.res_conv.1.res_conv1.bias', 'camera_head.res_conv.1.res_conv2.weight', 'camera_head.res_conv.1.res_conv2.bias', 'camera_head.res_conv.1.res_conv3.weight', 'camera_head.res_conv.1.res_conv3.bias', 'camera_head.more_mlps.0.weight', 'camera_head.more_mlps.0.bias', 'camera_head.more_mlps.2.weight', 'camera_head.more_mlps.2.bias', 'camera_head.fc_t.weight', 'camera_head.fc_t.bias', 'camera_head.fc_rot.weight', 'camera_head.fc_rot.bias'], unexpected_keys=['module.register_token', 'module.image_mean', 'module.image_std', 'module.encoder.cls_token', 'module.encoder.pos_embed', 'module.encoder.register_tokens', 'module.encoder.patch_embed.proj.weight', 'module.encoder.patch_embed.proj.bias', 'module.encoder.blocks.0.norm1.weight', 'module.encoder.blocks.0.norm1.bias', 'module.encoder.blocks.0.attn.qkv.weight', 'module.encoder.blocks.0.attn.qkv.bias', 'module.encoder.blocks.0.attn.proj.weight', 'module.encoder.blocks.0.attn.proj.bias', 'module.encoder.blocks.0.ls1.gamma', 'module.encoder.blocks.0.norm2.weight', 'module.encoder.blocks.0.norm2.bias', 'module.encoder.blocks.0.mlp.fc1.weight', 'module.encoder.blocks.0.mlp.fc1.bias', 'module.encoder.blocks.0.mlp.fc2.weight', 'module.encoder.blocks.0.mlp.fc2.bias', 'module.encoder.blocks.0.ls2.gamma', 'module.encoder.blocks.1.norm1.weight', 'module.encoder.blocks.1.norm1.bias', 'module.encoder.blocks.1.attn.qkv.weight', 'module.encoder.blocks.1.attn.qkv.bias', 'module.encoder.blocks.1.attn.proj.weight', 'module.encoder.blocks.1.attn.proj.bias', 'module.encoder.blocks.1.ls1.gamma', 'module.encoder.blocks.1.norm2.weight', 'module.encoder.blocks.1.norm2.bias', 'module.encoder.blocks.1.mlp.fc1.weight', 'module.encoder.blocks.1.mlp.fc1.bias', 'module.encoder.blocks.1.mlp.fc2.weight', 'module.encoder.blocks.1.mlp.fc2.bias', 'module.encoder.blocks.1.ls2.gamma', 'module.encoder.blocks.2.norm1.weight', 'module.encoder.blocks.2.norm1.bias', 'module.encoder.blocks.2.attn.qkv.weight', 'module.encoder.blocks.2.attn.qkv.bias', 'module.encoder.blocks.2.attn.proj.weight', 'module.encoder.blocks.2.attn.proj.bias', 'module.encoder.blocks.2.ls1.gamma', 'module.encoder.blocks.2.norm2.weight', 'module.encoder.blocks.2.norm2.bias', 'module.encoder.blocks.2.mlp.fc1.weight', 'module.encoder.blocks.2.mlp.fc1.bias', 'module.encoder.blocks.2.mlp.fc2.weight', 'module.encoder.blocks.2.mlp.fc2.bias', 'module.encoder.blocks.2.ls2.gamma', 'module.encoder.blocks.3.norm1.weight', 'module.encoder.blocks.3.norm1.bias', 'module.encoder.blocks.3.attn.qkv.weight', 'module.encoder.blocks.3.attn.qkv.bias', 'module.encoder.blocks.3.attn.proj.weight', 'module.encoder.blocks.3.attn.proj.bias', 'module.encoder.blocks.3.ls1.gamma', 'module.encoder.blocks.3.norm2.weight', 'module.encoder.blocks.3.norm2.bias', 'module.encoder.blocks.3.mlp.fc1.weight', 'module.encoder.blocks.3.mlp.fc1.bias', 'module.encoder.blocks.3.mlp.fc2.weight', 'module.encoder.blocks.3.mlp.fc2.bias', 'module.encoder.blocks.3.ls2.gamma', 'module.encoder.blocks.4.norm1.weight', 'module.encoder.blocks.4.norm1.bias', 'module.encoder.blocks.4.attn.qkv.weight', 'module.encoder.blocks.4.attn.qkv.bias', 'module.encoder.blocks.4.attn.proj.weight', 'module.encoder.blocks.4.attn.proj.bias', 'module.encoder.blocks.4.ls1.gamma', 'module.encoder.blocks.4.norm2.weight', 'module.encoder.blocks.4.norm2.bias', 'module.encoder.blocks.4.mlp.fc1.weight', 'module.encoder.blocks.4.mlp.fc1.bias', 'module.encoder.blocks.4.mlp.fc2.weight', 'module.encoder.blocks.4.mlp.fc2.bias', 'module.encoder.blocks.4.ls2.gamma', 'module.encoder.blocks.5.norm1.weight', 'module.encoder.blocks.5.norm1.bias', 'module.encoder.blocks.5.attn.qkv.weight', 'module.encoder.blocks.5.attn.qkv.bias', 'module.encoder.blocks.5.attn.proj.weight', 'module.encoder.blocks.5.attn.proj.bias', 'module.encoder.blocks.5.ls1.gamma', 'module.encoder.blocks.5.norm2.weight', 'module.encoder.blocks.5.norm2.bias', 'module.encoder.blocks.5.mlp.fc1.weight', 'module.encoder.blocks.5.mlp.fc1.bias', 'module.encoder.blocks.5.mlp.fc2.weight', 'module.encoder.blocks.5.mlp.fc2.bias', 'module.encoder.blocks.5.ls2.gamma', 'module.encoder.blocks.6.norm1.weight', 'module.encoder.blocks.6.norm1.bias', 'module.encoder.blocks.6.attn.qkv.weight', 'module.encoder.blocks.6.attn.qkv.bias', 'module.encoder.blocks.6.attn.proj.weight', 'module.encoder.blocks.6.attn.proj.bias', 'module.encoder.blocks.6.ls1.gamma', 'module.encoder.blocks.6.norm2.weight', 'module.encoder.blocks.6.norm2.bias', 'module.encoder.blocks.6.mlp.fc1.weight', 'module.encoder.blocks.6.mlp.fc1.bias', 'module.encoder.blocks.6.mlp.fc2.weight', 'module.encoder.blocks.6.mlp.fc2.bias', 'module.encoder.blocks.6.ls2.gamma', 'module.encoder.blocks.7.norm1.weight', 'module.encoder.blocks.7.norm1.bias', 'module.encoder.blocks.7.attn.qkv.weight', 'module.encoder.blocks.7.attn.qkv.bias', 'module.encoder.blocks.7.attn.proj.weight', 'module.encoder.blocks.7.attn.proj.bias', 'module.encoder.blocks.7.ls1.gamma', 'module.encoder.blocks.7.norm2.weight', 'module.encoder.blocks.7.norm2.bias', 'module.encoder.blocks.7.mlp.fc1.weight', 'module.encoder.blocks.7.mlp.fc1.bias', 'module.encoder.blocks.7.mlp.fc2.weight', 'module.encoder.blocks.7.mlp.fc2.bias', 'module.encoder.blocks.7.ls2.gamma', 'module.encoder.blocks.8.norm1.weight', 'module.encoder.blocks.8.norm1.bias', 'module.encoder.blocks.8.attn.qkv.weight', 'module.encoder.blocks.8.attn.qkv.bias', 'module.encoder.blocks.8.attn.proj.weight', 'module.encoder.blocks.8.attn.proj.bias', 'module.encoder.blocks.8.ls1.gamma', 'module.encoder.blocks.8.norm2.weight', 'module.encoder.blocks.8.norm2.bias', 'module.encoder.blocks.8.mlp.fc1.weight', 'module.encoder.blocks.8.mlp.fc1.bias', 'module.encoder.blocks.8.mlp.fc2.weight', 'module.encoder.blocks.8.mlp.fc2.bias', 'module.encoder.blocks.8.ls2.gamma', 'module.encoder.blocks.9.norm1.weight', 'module.encoder.blocks.9.norm1.bias', 'module.encoder.blocks.9.attn.qkv.weight', 'module.encoder.blocks.9.attn.qkv.bias', 'module.encoder.blocks.9.attn.proj.weight', 'module.encoder.blocks.9.attn.proj.bias', 'module.encoder.blocks.9.ls1.gamma', 'module.encoder.blocks.9.norm2.weight', 'module.encoder.blocks.9.norm2.bias', 'module.encoder.blocks.9.mlp.fc1.weight', 'module.encoder.blocks.9.mlp.fc1.bias', 'module.encoder.blocks.9.mlp.fc2.weight', 'module.encoder.blocks.9.mlp.fc2.bias', 'module.encoder.blocks.9.ls2.gamma', 'module.encoder.blocks.10.norm1.weight', 'module.encoder.blocks.10.norm1.bias', 'module.encoder.blocks.10.attn.qkv.weight', 'module.encoder.blocks.10.attn.qkv.bias', 'module.encoder.blocks.10.attn.proj.weight', 'module.encoder.blocks.10.attn.proj.bias', 'module.encoder.blocks.10.ls1.gamma', 'module.encoder.blocks.10.norm2.weight', 'module.encoder.blocks.10.norm2.bias', 'module.encoder.blocks.10.mlp.fc1.weight', 'module.encoder.blocks.10.mlp.fc1.bias', 'module.encoder.blocks.10.mlp.fc2.weight', 'module.encoder.blocks.10.mlp.fc2.bias', 'module.encoder.blocks.10.ls2.gamma', 'module.encoder.blocks.11.norm1.weight', 'module.encoder.blocks.11.norm1.bias', 'module.encoder.blocks.11.attn.qkv.weight', 'module.encoder.blocks.11.attn.qkv.bias', 'module.encoder.blocks.11.attn.proj.weight', 'module.encoder.blocks.11.attn.proj.bias', 'module.encoder.blocks.11.ls1.gamma', 'module.encoder.blocks.11.norm2.weight', 'module.encoder.blocks.11.norm2.bias', 'module.encoder.blocks.11.mlp.fc1.weight', 'module.encoder.blocks.11.mlp.fc1.bias', 'module.encoder.blocks.11.mlp.fc2.weight', 'module.encoder.blocks.11.mlp.fc2.bias', 'module.encoder.blocks.11.ls2.gamma', 'module.encoder.blocks.12.norm1.weight', 'module.encoder.blocks.12.norm1.bias', 'module.encoder.blocks.12.attn.qkv.weight', 'module.encoder.blocks.12.attn.qkv.bias', 'module.encoder.blocks.12.attn.proj.weight', 'module.encoder.blocks.12.attn.proj.bias', 'module.encoder.blocks.12.ls1.gamma', 'module.encoder.blocks.12.norm2.weight', 'module.encoder.blocks.12.norm2.bias', 'module.encoder.blocks.12.mlp.fc1.weight', 'module.encoder.blocks.12.mlp.fc1.bias', 'module.encoder.blocks.12.mlp.fc2.weight', 'module.encoder.blocks.12.mlp.fc2.bias', 'module.encoder.blocks.12.ls2.gamma', 'module.encoder.blocks.13.norm1.weight', 'module.encoder.blocks.13.norm1.bias', 'module.encoder.blocks.13.attn.qkv.weight', 'module.encoder.blocks.13.attn.qkv.bias', 'module.encoder.blocks.13.attn.proj.weight', 'module.encoder.blocks.13.attn.proj.bias', 'module.encoder.blocks.13.ls1.gamma', 'module.encoder.blocks.13.norm2.weight', 'module.encoder.blocks.13.norm2.bias', 'module.encoder.blocks.13.mlp.fc1.weight', 'module.encoder.blocks.13.mlp.fc1.bias', 'module.encoder.blocks.13.mlp.fc2.weight', 'module.encoder.blocks.13.mlp.fc2.bias', 'module.encoder.blocks.13.ls2.gamma', 'module.encoder.blocks.14.norm1.weight', 'module.encoder.blocks.14.norm1.bias', 'module.encoder.blocks.14.attn.qkv.weight', 'module.encoder.blocks.14.attn.qkv.bias', 'module.encoder.blocks.14.attn.proj.weight', 'module.encoder.blocks.14.attn.proj.bias', 'module.encoder.blocks.14.ls1.gamma', 'module.encoder.blocks.14.norm2.weight', 'module.encoder.blocks.14.norm2.bias', 'module.encoder.blocks.14.mlp.fc1.weight', 'module.encoder.blocks.14.mlp.fc1.bias', 'module.encoder.blocks.14.mlp.fc2.weight', 'module.encoder.blocks.14.mlp.fc2.bias', 'module.encoder.blocks.14.ls2.gamma', 'module.encoder.blocks.15.norm1.weight', 'module.encoder.blocks.15.norm1.bias', 'module.encoder.blocks.15.attn.qkv.weight', 'module.encoder.blocks.15.attn.qkv.bias', 'module.encoder.blocks.15.attn.proj.weight', 'module.encoder.blocks.15.attn.proj.bias', 'module.encoder.blocks.15.ls1.gamma', 'module.encoder.blocks.15.norm2.weight', 'module.encoder.blocks.15.norm2.bias', 'module.encoder.blocks.15.mlp.fc1.weight', 'module.encoder.blocks.15.mlp.fc1.bias', 'module.encoder.blocks.15.mlp.fc2.weight', 'module.encoder.blocks.15.mlp.fc2.bias', 'module.encoder.blocks.15.ls2.gamma', 'module.encoder.blocks.16.norm1.weight', 'module.encoder.blocks.16.norm1.bias', 'module.encoder.blocks.16.attn.qkv.weight', 'module.encoder.blocks.16.attn.qkv.bias', 'module.encoder.blocks.16.attn.proj.weight', 'module.encoder.blocks.16.attn.proj.bias', 'module.encoder.blocks.16.ls1.gamma', 'module.encoder.blocks.16.norm2.weight', 'module.encoder.blocks.16.norm2.bias', 'module.encoder.blocks.16.mlp.fc1.weight', 'module.encoder.blocks.16.mlp.fc1.bias', 'module.encoder.blocks.16.mlp.fc2.weight', 'module.encoder.blocks.16.mlp.fc2.bias', 'module.encoder.blocks.16.ls2.gamma', 'module.encoder.blocks.17.norm1.weight', 'module.encoder.blocks.17.norm1.bias', 'module.encoder.blocks.17.attn.qkv.weight', 'module.encoder.blocks.17.attn.qkv.bias', 'module.encoder.blocks.17.attn.proj.weight', 'module.encoder.blocks.17.attn.proj.bias', 'module.encoder.blocks.17.ls1.gamma', 'module.encoder.blocks.17.norm2.weight', 'module.encoder.blocks.17.norm2.bias', 'module.encoder.blocks.17.mlp.fc1.weight', 'module.encoder.blocks.17.mlp.fc1.bias', 'module.encoder.blocks.17.mlp.fc2.weight', 'module.encoder.blocks.17.mlp.fc2.bias', 'module.encoder.blocks.17.ls2.gamma', 'module.encoder.blocks.18.norm1.weight', 'module.encoder.blocks.18.norm1.bias', 'module.encoder.blocks.18.attn.qkv.weight', 'module.encoder.blocks.18.attn.qkv.bias', 'module.encoder.blocks.18.attn.proj.weight', 'module.encoder.blocks.18.attn.proj.bias', 'module.encoder.blocks.18.ls1.gamma', 'module.encoder.blocks.18.norm2.weight', 'module.encoder.blocks.18.norm2.bias', 'module.encoder.blocks.18.mlp.fc1.weight', 'module.encoder.blocks.18.mlp.fc1.bias', 'module.encoder.blocks.18.mlp.fc2.weight', 'module.encoder.blocks.18.mlp.fc2.bias', 'module.encoder.blocks.18.ls2.gamma', 'module.encoder.blocks.19.norm1.weight', 'module.encoder.blocks.19.norm1.bias', 'module.encoder.blocks.19.attn.qkv.weight', 'module.encoder.blocks.19.attn.qkv.bias', 'module.encoder.blocks.19.attn.proj.weight', 'module.encoder.blocks.19.attn.proj.bias', 'module.encoder.blocks.19.ls1.gamma', 'module.encoder.blocks.19.norm2.weight', 'module.encoder.blocks.19.norm2.bias', 'module.encoder.blocks.19.mlp.fc1.weight', 'module.encoder.blocks.19.mlp.fc1.bias', 'module.encoder.blocks.19.mlp.fc2.weight', 'module.encoder.blocks.19.mlp.fc2.bias', 'module.encoder.blocks.19.ls2.gamma', 'module.encoder.blocks.20.norm1.weight', 'module.encoder.blocks.20.norm1.bias', 'module.encoder.blocks.20.attn.qkv.weight', 'module.encoder.blocks.20.attn.qkv.bias', 'module.encoder.blocks.20.attn.proj.weight', 'module.encoder.blocks.20.attn.proj.bias', 'module.encoder.blocks.20.ls1.gamma', 'module.encoder.blocks.20.norm2.weight', 'module.encoder.blocks.20.norm2.bias', 'module.encoder.blocks.20.mlp.fc1.weight', 'module.encoder.blocks.20.mlp.fc1.bias', 'module.encoder.blocks.20.mlp.fc2.weight', 'module.encoder.blocks.20.mlp.fc2.bias', 'module.encoder.blocks.20.ls2.gamma', 'module.encoder.blocks.21.norm1.weight', 'module.encoder.blocks.21.norm1.bias', 'module.encoder.blocks.21.attn.qkv.weight', 'module.encoder.blocks.21.attn.qkv.bias', 'module.encoder.blocks.21.attn.proj.weight', 'module.encoder.blocks.21.attn.proj.bias', 'module.encoder.blocks.21.ls1.gamma', 'module.encoder.blocks.21.norm2.weight', 'module.encoder.blocks.21.norm2.bias', 'module.encoder.blocks.21.mlp.fc1.weight', 'module.encoder.blocks.21.mlp.fc1.bias', 'module.encoder.blocks.21.mlp.fc2.weight', 'module.encoder.blocks.21.mlp.fc2.bias', 'module.encoder.blocks.21.ls2.gamma', 'module.encoder.blocks.22.norm1.weight', 'module.encoder.blocks.22.norm1.bias', 'module.encoder.blocks.22.attn.qkv.weight', 'module.encoder.blocks.22.attn.qkv.bias', 'module.encoder.blocks.22.attn.proj.weight', 'module.encoder.blocks.22.attn.proj.bias', 'module.encoder.blocks.22.ls1.gamma', 'module.encoder.blocks.22.norm2.weight', 'module.encoder.blocks.22.norm2.bias', 'module.encoder.blocks.22.mlp.fc1.weight', 'module.encoder.blocks.22.mlp.fc1.bias', 'module.encoder.blocks.22.mlp.fc2.weight', 'module.encoder.blocks.22.mlp.fc2.bias', 'module.encoder.blocks.22.ls2.gamma', 'module.encoder.blocks.23.norm1.weight', 'module.encoder.blocks.23.norm1.bias', 'module.encoder.blocks.23.attn.qkv.weight', 'module.encoder.blocks.23.attn.qkv.bias', 'module.encoder.blocks.23.attn.proj.weight', 'module.encoder.blocks.23.attn.proj.bias', 'module.encoder.blocks.23.ls1.gamma', 'module.encoder.blocks.23.norm2.weight', 'module.encoder.blocks.23.norm2.bias', 'module.encoder.blocks.23.mlp.fc1.weight', 'module.encoder.blocks.23.mlp.fc1.bias', 'module.encoder.blocks.23.mlp.fc2.weight', 'module.encoder.blocks.23.mlp.fc2.bias', 'module.encoder.blocks.23.ls2.gamma', 'module.encoder.norm.weight', 'module.encoder.norm.bias', 'module.decoder.0.norm1.weight', 'module.decoder.0.norm1.bias', 'module.decoder.0.attn.qkv.weight', 'module.decoder.0.attn.qkv.bias', 'module.decoder.0.attn.proj.weight', 'module.decoder.0.attn.proj.bias', 'module.decoder.0.attn.q_norm.weight', 'module.decoder.0.attn.q_norm.bias', 'module.decoder.0.attn.k_norm.weight', 'module.decoder.0.attn.k_norm.bias', 'module.decoder.0.ls1.gamma', 'module.decoder.0.norm2.weight', 'module.decoder.0.norm2.bias', 'module.decoder.0.mlp.fc1.weight', 'module.decoder.0.mlp.fc1.bias', 'module.decoder.0.mlp.fc2.weight', 'module.decoder.0.mlp.fc2.bias', 'module.decoder.0.ls2.gamma', 'module.decoder.1.norm1.weight', 'module.decoder.1.norm1.bias', 'module.decoder.1.attn.qkv.weight', 'module.decoder.1.attn.qkv.bias', 'module.decoder.1.attn.proj.weight', 'module.decoder.1.attn.proj.bias', 'module.decoder.1.attn.q_norm.weight', 'module.decoder.1.attn.q_norm.bias', 'module.decoder.1.attn.k_norm.weight', 'module.decoder.1.attn.k_norm.bias', 'module.decoder.1.ls1.gamma', 'module.decoder.1.norm2.weight', 'module.decoder.1.norm2.bias', 'module.decoder.1.mlp.fc1.weight', 'module.decoder.1.mlp.fc1.bias', 'module.decoder.1.mlp.fc2.weight', 'module.decoder.1.mlp.fc2.bias', 'module.decoder.1.ls2.gamma', 'module.decoder.2.norm1.weight', 'module.decoder.2.norm1.bias', 'module.decoder.2.attn.qkv.weight', 'module.decoder.2.attn.qkv.bias', 'module.decoder.2.attn.proj.weight', 'module.decoder.2.attn.proj.bias', 'module.decoder.2.attn.q_norm.weight', 'module.decoder.2.attn.q_norm.bias', 'module.decoder.2.attn.k_norm.weight', 'module.decoder.2.attn.k_norm.bias', 'module.decoder.2.ls1.gamma', 'module.decoder.2.norm2.weight', 'module.decoder.2.norm2.bias', 'module.decoder.2.mlp.fc1.weight', 'module.decoder.2.mlp.fc1.bias', 'module.decoder.2.mlp.fc2.weight', 'module.decoder.2.mlp.fc2.bias', 'module.decoder.2.ls2.gamma', 'module.decoder.3.norm1.weight', 'module.decoder.3.norm1.bias', 'module.decoder.3.attn.qkv.weight', 'module.decoder.3.attn.qkv.bias', 'module.decoder.3.attn.proj.weight', 'module.decoder.3.attn.proj.bias', 'module.decoder.3.attn.q_norm.weight', 'module.decoder.3.attn.q_norm.bias', 'module.decoder.3.attn.k_norm.weight', 'module.decoder.3.attn.k_norm.bias', 'module.decoder.3.ls1.gamma', 'module.decoder.3.norm2.weight', 'module.decoder.3.norm2.bias', 'module.decoder.3.mlp.fc1.weight', 'module.decoder.3.mlp.fc1.bias', 'module.decoder.3.mlp.fc2.weight', 'module.decoder.3.mlp.fc2.bias', 'module.decoder.3.ls2.gamma', 'module.decoder.4.norm1.weight', 'module.decoder.4.norm1.bias', 'module.decoder.4.attn.qkv.weight', 'module.decoder.4.attn.qkv.bias', 'module.decoder.4.attn.proj.weight', 'module.decoder.4.attn.proj.bias', 'module.decoder.4.attn.q_norm.weight', 'module.decoder.4.attn.q_norm.bias', 'module.decoder.4.attn.k_norm.weight', 'module.decoder.4.attn.k_norm.bias', 'module.decoder.4.ls1.gamma', 'module.decoder.4.norm2.weight', 'module.decoder.4.norm2.bias', 'module.decoder.4.mlp.fc1.weight', 'module.decoder.4.mlp.fc1.bias', 'module.decoder.4.mlp.fc2.weight', 'module.decoder.4.mlp.fc2.bias', 'module.decoder.4.ls2.gamma', 'module.decoder.5.norm1.weight', 'module.decoder.5.norm1.bias', 'module.decoder.5.attn.qkv.weight', 'module.decoder.5.attn.qkv.bias', 'module.decoder.5.attn.proj.weight', 'module.decoder.5.attn.proj.bias', 'module.decoder.5.attn.q_norm.weight', 'module.decoder.5.attn.q_norm.bias', 'module.decoder.5.attn.k_norm.weight', 'module.decoder.5.attn.k_norm.bias', 'module.decoder.5.ls1.gamma', 'module.decoder.5.norm2.weight', 'module.decoder.5.norm2.bias', 'module.decoder.5.mlp.fc1.weight', 'module.decoder.5.mlp.fc1.bias', 'module.decoder.5.mlp.fc2.weight', 'module.decoder.5.mlp.fc2.bias', 'module.decoder.5.ls2.gamma', 'module.decoder.6.norm1.weight', 'module.decoder.6.norm1.bias', 'module.decoder.6.attn.qkv.weight', 'module.decoder.6.attn.qkv.bias', 'module.decoder.6.attn.proj.weight', 'module.decoder.6.attn.proj.bias', 'module.decoder.6.attn.q_norm.weight', 'module.decoder.6.attn.q_norm.bias', 'module.decoder.6.attn.k_norm.weight', 'module.decoder.6.attn.k_norm.bias', 'module.decoder.6.ls1.gamma', 'module.decoder.6.norm2.weight', 'module.decoder.6.norm2.bias', 'module.decoder.6.mlp.fc1.weight', 'module.decoder.6.mlp.fc1.bias', 'module.decoder.6.mlp.fc2.weight', 'module.decoder.6.mlp.fc2.bias', 'module.decoder.6.ls2.gamma', 'module.decoder.7.norm1.weight', 'module.decoder.7.norm1.bias', 'module.decoder.7.attn.qkv.weight', 'module.decoder.7.attn.qkv.bias', 'module.decoder.7.attn.proj.weight', 'module.decoder.7.attn.proj.bias', 'module.decoder.7.attn.q_norm.weight', 'module.decoder.7.attn.q_norm.bias', 'module.decoder.7.attn.k_norm.weight', 'module.decoder.7.attn.k_norm.bias', 'module.decoder.7.ls1.gamma', 'module.decoder.7.norm2.weight', 'module.decoder.7.norm2.bias', 'module.decoder.7.mlp.fc1.weight', 'module.decoder.7.mlp.fc1.bias', 'module.decoder.7.mlp.fc2.weight', 'module.decoder.7.mlp.fc2.bias', 'module.decoder.7.ls2.gamma', 'module.decoder.8.norm1.weight', 'module.decoder.8.norm1.bias', 'module.decoder.8.attn.qkv.weight', 'module.decoder.8.attn.qkv.bias', 'module.decoder.8.attn.proj.weight', 'module.decoder.8.attn.proj.bias', 'module.decoder.8.attn.q_norm.weight', 'module.decoder.8.attn.q_norm.bias', 'module.decoder.8.attn.k_norm.weight', 'module.decoder.8.attn.k_norm.bias', 'module.decoder.8.ls1.gamma', 'module.decoder.8.norm2.weight', 'module.decoder.8.norm2.bias', 'module.decoder.8.mlp.fc1.weight', 'module.decoder.8.mlp.fc1.bias', 'module.decoder.8.mlp.fc2.weight', 'module.decoder.8.mlp.fc2.bias', 'module.decoder.8.ls2.gamma', 'module.decoder.9.norm1.weight', 'module.decoder.9.norm1.bias', 'module.decoder.9.attn.qkv.weight', 'module.decoder.9.attn.qkv.bias', 'module.decoder.9.attn.proj.weight', 'module.decoder.9.attn.proj.bias', 'module.decoder.9.attn.q_norm.weight', 'module.decoder.9.attn.q_norm.bias', 'module.decoder.9.attn.k_norm.weight', 'module.decoder.9.attn.k_norm.bias', 'module.decoder.9.ls1.gamma', 'module.decoder.9.norm2.weight', 'module.decoder.9.norm2.bias', 'module.decoder.9.mlp.fc1.weight', 'module.decoder.9.mlp.fc1.bias', 'module.decoder.9.mlp.fc2.weight', 'module.decoder.9.mlp.fc2.bias', 'module.decoder.9.ls2.gamma', 'module.decoder.10.norm1.weight', 'module.decoder.10.norm1.bias', 'module.decoder.10.attn.qkv.weight', 'module.decoder.10.attn.qkv.bias', 'module.decoder.10.attn.proj.weight', 'module.decoder.10.attn.proj.bias', 'module.decoder.10.attn.q_norm.weight', 'module.decoder.10.attn.q_norm.bias', 'module.decoder.10.attn.k_norm.weight', 'module.decoder.10.attn.k_norm.bias', 'module.decoder.10.ls1.gamma', 'module.decoder.10.norm2.weight', 'module.decoder.10.norm2.bias', 'module.decoder.10.mlp.fc1.weight', 'module.decoder.10.mlp.fc1.bias', 'module.decoder.10.mlp.fc2.weight', 'module.decoder.10.mlp.fc2.bias', 'module.decoder.10.ls2.gamma', 'module.decoder.11.norm1.weight', 'module.decoder.11.norm1.bias', 'module.decoder.11.attn.qkv.weight', 'module.decoder.11.attn.qkv.bias', 'module.decoder.11.attn.proj.weight', 'module.decoder.11.attn.proj.bias', 'module.decoder.11.attn.q_norm.weight', 'module.decoder.11.attn.q_norm.bias', 'module.decoder.11.attn.k_norm.weight', 'module.decoder.11.attn.k_norm.bias', 'module.decoder.11.ls1.gamma', 'module.decoder.11.norm2.weight', 'module.decoder.11.norm2.bias', 'module.decoder.11.mlp.fc1.weight', 'module.decoder.11.mlp.fc1.bias', 'module.decoder.11.mlp.fc2.weight', 'module.decoder.11.mlp.fc2.bias', 'module.decoder.11.ls2.gamma', 'module.decoder.12.norm1.weight', 'module.decoder.12.norm1.bias', 'module.decoder.12.attn.qkv.weight', 'module.decoder.12.attn.qkv.bias', 'module.decoder.12.attn.proj.weight', 'module.decoder.12.attn.proj.bias', 'module.decoder.12.attn.q_norm.weight', 'module.decoder.12.attn.q_norm.bias', 'module.decoder.12.attn.k_norm.weight', 'module.decoder.12.attn.k_norm.bias', 'module.decoder.12.ls1.gamma', 'module.decoder.12.norm2.weight', 'module.decoder.12.norm2.bias', 'module.decoder.12.mlp.fc1.weight', 'module.decoder.12.mlp.fc1.bias', 'module.decoder.12.mlp.fc2.weight', 'module.decoder.12.mlp.fc2.bias', 'module.decoder.12.ls2.gamma', 'module.decoder.13.norm1.weight', 'module.decoder.13.norm1.bias', 'module.decoder.13.attn.qkv.weight', 'module.decoder.13.attn.qkv.bias', 'module.decoder.13.attn.proj.weight', 'module.decoder.13.attn.proj.bias', 'module.decoder.13.attn.q_norm.weight', 'module.decoder.13.attn.q_norm.bias', 'module.decoder.13.attn.k_norm.weight', 'module.decoder.13.attn.k_norm.bias', 'module.decoder.13.ls1.gamma', 'module.decoder.13.norm2.weight', 'module.decoder.13.norm2.bias', 'module.decoder.13.mlp.fc1.weight', 'module.decoder.13.mlp.fc1.bias', 'module.decoder.13.mlp.fc2.weight', 'module.decoder.13.mlp.fc2.bias', 'module.decoder.13.ls2.gamma', 'module.decoder.14.norm1.weight', 'module.decoder.14.norm1.bias', 'module.decoder.14.attn.qkv.weight', 'module.decoder.14.attn.qkv.bias', 'module.decoder.14.attn.proj.weight', 'module.decoder.14.attn.proj.bias', 'module.decoder.14.attn.q_norm.weight', 'module.decoder.14.attn.q_norm.bias', 'module.decoder.14.attn.k_norm.weight', 'module.decoder.14.attn.k_norm.bias', 'module.decoder.14.ls1.gamma', 'module.decoder.14.norm2.weight', 'module.decoder.14.norm2.bias', 'module.decoder.14.mlp.fc1.weight', 'module.decoder.14.mlp.fc1.bias', 'module.decoder.14.mlp.fc2.weight', 'module.decoder.14.mlp.fc2.bias', 'module.decoder.14.ls2.gamma', 'module.decoder.15.norm1.weight', 'module.decoder.15.norm1.bias', 'module.decoder.15.attn.qkv.weight', 'module.decoder.15.attn.qkv.bias', 'module.decoder.15.attn.proj.weight', 'module.decoder.15.attn.proj.bias', 'module.decoder.15.attn.q_norm.weight', 'module.decoder.15.attn.q_norm.bias', 'module.decoder.15.attn.k_norm.weight', 'module.decoder.15.attn.k_norm.bias', 'module.decoder.15.ls1.gamma', 'module.decoder.15.norm2.weight', 'module.decoder.15.norm2.bias', 'module.decoder.15.mlp.fc1.weight', 'module.decoder.15.mlp.fc1.bias', 'module.decoder.15.mlp.fc2.weight', 'module.decoder.15.mlp.fc2.bias', 'module.decoder.15.ls2.gamma', 'module.decoder.16.norm1.weight', 'module.decoder.16.norm1.bias', 'module.decoder.16.attn.qkv.weight', 'module.decoder.16.attn.qkv.bias', 'module.decoder.16.attn.proj.weight', 'module.decoder.16.attn.proj.bias', 'module.decoder.16.attn.q_norm.weight', 'module.decoder.16.attn.q_norm.bias', 'module.decoder.16.attn.k_norm.weight', 'module.decoder.16.attn.k_norm.bias', 'module.decoder.16.ls1.gamma', 'module.decoder.16.norm2.weight', 'module.decoder.16.norm2.bias', 'module.decoder.16.mlp.fc1.weight', 'module.decoder.16.mlp.fc1.bias', 'module.decoder.16.mlp.fc2.weight', 'module.decoder.16.mlp.fc2.bias', 'module.decoder.16.ls2.gamma', 'module.decoder.17.norm1.weight', 'module.decoder.17.norm1.bias', 'module.decoder.17.attn.qkv.weight', 'module.decoder.17.attn.qkv.bias', 'module.decoder.17.attn.proj.weight', 'module.decoder.17.attn.proj.bias', 'module.decoder.17.attn.q_norm.weight', 'module.decoder.17.attn.q_norm.bias', 'module.decoder.17.attn.k_norm.weight', 'module.decoder.17.attn.k_norm.bias', 'module.decoder.17.ls1.gamma', 'module.decoder.17.norm2.weight', 'module.decoder.17.norm2.bias', 'module.decoder.17.mlp.fc1.weight', 'module.decoder.17.mlp.fc1.bias', 'module.decoder.17.mlp.fc2.weight', 'module.decoder.17.mlp.fc2.bias', 'module.decoder.17.ls2.gamma', 'module.decoder.18.norm1.weight', 'module.decoder.18.norm1.bias', 'module.decoder.18.attn.qkv.weight', 'module.decoder.18.attn.qkv.bias', 'module.decoder.18.attn.proj.weight', 'module.decoder.18.attn.proj.bias', 'module.decoder.18.attn.q_norm.weight', 'module.decoder.18.attn.q_norm.bias', 'module.decoder.18.attn.k_norm.weight', 'module.decoder.18.attn.k_norm.bias', 'module.decoder.18.ls1.gamma', 'module.decoder.18.norm2.weight', 'module.decoder.18.norm2.bias', 'module.decoder.18.mlp.fc1.weight', 'module.decoder.18.mlp.fc1.bias', 'module.decoder.18.mlp.fc2.weight', 'module.decoder.18.mlp.fc2.bias', 'module.decoder.18.ls2.gamma', 'module.decoder.19.norm1.weight', 'module.decoder.19.norm1.bias', 'module.decoder.19.attn.qkv.weight', 'module.decoder.19.attn.qkv.bias', 'module.decoder.19.attn.proj.weight', 'module.decoder.19.attn.proj.bias', 'module.decoder.19.attn.q_norm.weight', 'module.decoder.19.attn.q_norm.bias', 'module.decoder.19.attn.k_norm.weight', 'module.decoder.19.attn.k_norm.bias', 'module.decoder.19.ls1.gamma', 'module.decoder.19.norm2.weight', 'module.decoder.19.norm2.bias', 'module.decoder.19.mlp.fc1.weight', 'module.decoder.19.mlp.fc1.bias', 'module.decoder.19.mlp.fc2.weight', 'module.decoder.19.mlp.fc2.bias', 'module.decoder.19.ls2.gamma', 'module.decoder.20.norm1.weight', 'module.decoder.20.norm1.bias', 'module.decoder.20.attn.qkv.weight', 'module.decoder.20.attn.qkv.bias', 'module.decoder.20.attn.proj.weight', 'module.decoder.20.attn.proj.bias', 'module.decoder.20.attn.q_norm.weight', 'module.decoder.20.attn.q_norm.bias', 'module.decoder.20.attn.k_norm.weight', 'module.decoder.20.attn.k_norm.bias', 'module.decoder.20.ls1.gamma', 'module.decoder.20.norm2.weight', 'module.decoder.20.norm2.bias', 'module.decoder.20.mlp.fc1.weight', 'module.decoder.20.mlp.fc1.bias', 'module.decoder.20.mlp.fc2.weight', 'module.decoder.20.mlp.fc2.bias', 'module.decoder.20.ls2.gamma', 'module.decoder.21.norm1.weight', 'module.decoder.21.norm1.bias', 'module.decoder.21.attn.qkv.weight', 'module.decoder.21.attn.qkv.bias', 'module.decoder.21.attn.proj.weight', 'module.decoder.21.attn.proj.bias', 'module.decoder.21.attn.q_norm.weight', 'module.decoder.21.attn.q_norm.bias', 'module.decoder.21.attn.k_norm.weight', 'module.decoder.21.attn.k_norm.bias', 'module.decoder.21.ls1.gamma', 'module.decoder.21.norm2.weight', 'module.decoder.21.norm2.bias', 'module.decoder.21.mlp.fc1.weight', 'module.decoder.21.mlp.fc1.bias', 'module.decoder.21.mlp.fc2.weight', 'module.decoder.21.mlp.fc2.bias', 'module.decoder.21.ls2.gamma', 'module.decoder.22.norm1.weight', 'module.decoder.22.norm1.bias', 'module.decoder.22.attn.qkv.weight', 'module.decoder.22.attn.qkv.bias', 'module.decoder.22.attn.proj.weight', 'module.decoder.22.attn.proj.bias', 'module.decoder.22.attn.q_norm.weight', 'module.decoder.22.attn.q_norm.bias', 'module.decoder.22.attn.k_norm.weight', 'module.decoder.22.attn.k_norm.bias', 'module.decoder.22.ls1.gamma', 'module.decoder.22.norm2.weight', 'module.decoder.22.norm2.bias', 'module.decoder.22.mlp.fc1.weight', 'module.decoder.22.mlp.fc1.bias', 'module.decoder.22.mlp.fc2.weight', 'module.decoder.22.mlp.fc2.bias', 'module.decoder.22.ls2.gamma', 'module.decoder.23.norm1.weight', 'module.decoder.23.norm1.bias', 'module.decoder.23.attn.qkv.weight', 'module.decoder.23.attn.qkv.bias', 'module.decoder.23.attn.proj.weight', 'module.decoder.23.attn.proj.bias', 'module.decoder.23.attn.q_norm.weight', 'module.decoder.23.attn.q_norm.bias', 'module.decoder.23.attn.k_norm.weight', 'module.decoder.23.attn.k_norm.bias', 'module.decoder.23.ls1.gamma', 'module.decoder.23.norm2.weight', 'module.decoder.23.norm2.bias', 'module.decoder.23.mlp.fc1.weight', 'module.decoder.23.mlp.fc1.bias', 'module.decoder.23.mlp.fc2.weight', 'module.decoder.23.mlp.fc2.bias', 'module.decoder.23.ls2.gamma', 'module.decoder.24.norm1.weight', 'module.decoder.24.norm1.bias', 'module.decoder.24.attn.qkv.weight', 'module.decoder.24.attn.qkv.bias', 'module.decoder.24.attn.proj.weight', 'module.decoder.24.attn.proj.bias', 'module.decoder.24.attn.q_norm.weight', 'module.decoder.24.attn.q_norm.bias', 'module.decoder.24.attn.k_norm.weight', 'module.decoder.24.attn.k_norm.bias', 'module.decoder.24.ls1.gamma', 'module.decoder.24.norm2.weight', 'module.decoder.24.norm2.bias', 'module.decoder.24.mlp.fc1.weight', 'module.decoder.24.mlp.fc1.bias', 'module.decoder.24.mlp.fc2.weight', 'module.decoder.24.mlp.fc2.bias', 'module.decoder.24.ls2.gamma', 'module.decoder.25.norm1.weight', 'module.decoder.25.norm1.bias', 'module.decoder.25.attn.qkv.weight', 'module.decoder.25.attn.qkv.bias', 'module.decoder.25.attn.proj.weight', 'module.decoder.25.attn.proj.bias', 'module.decoder.25.attn.q_norm.weight', 'module.decoder.25.attn.q_norm.bias', 'module.decoder.25.attn.k_norm.weight', 'module.decoder.25.attn.k_norm.bias', 'module.decoder.25.ls1.gamma', 'module.decoder.25.norm2.weight', 'module.decoder.25.norm2.bias', 'module.decoder.25.mlp.fc1.weight', 'module.decoder.25.mlp.fc1.bias', 'module.decoder.25.mlp.fc2.weight', 'module.decoder.25.mlp.fc2.bias', 'module.decoder.25.ls2.gamma', 'module.decoder.26.norm1.weight', 'module.decoder.26.norm1.bias', 'module.decoder.26.attn.qkv.weight', 'module.decoder.26.attn.qkv.bias', 'module.decoder.26.attn.proj.weight', 'module.decoder.26.attn.proj.bias', 'module.decoder.26.attn.q_norm.weight', 'module.decoder.26.attn.q_norm.bias', 'module.decoder.26.attn.k_norm.weight', 'module.decoder.26.attn.k_norm.bias', 'module.decoder.26.ls1.gamma', 'module.decoder.26.norm2.weight', 'module.decoder.26.norm2.bias', 'module.decoder.26.mlp.fc1.weight', 'module.decoder.26.mlp.fc1.bias', 'module.decoder.26.mlp.fc2.weight', 'module.decoder.26.mlp.fc2.bias', 'module.decoder.26.ls2.gamma', 'module.decoder.27.norm1.weight', 'module.decoder.27.norm1.bias', 'module.decoder.27.attn.qkv.weight', 'module.decoder.27.attn.qkv.bias', 'module.decoder.27.attn.proj.weight', 'module.decoder.27.attn.proj.bias', 'module.decoder.27.attn.q_norm.weight', 'module.decoder.27.attn.q_norm.bias', 'module.decoder.27.attn.k_norm.weight', 'module.decoder.27.attn.k_norm.bias', 'module.decoder.27.ls1.gamma', 'module.decoder.27.norm2.weight', 'module.decoder.27.norm2.bias', 'module.decoder.27.mlp.fc1.weight', 'module.decoder.27.mlp.fc1.bias', 'module.decoder.27.mlp.fc2.weight', 'module.decoder.27.mlp.fc2.bias', 'module.decoder.27.ls2.gamma', 'module.decoder.28.norm1.weight', 'module.decoder.28.norm1.bias', 'module.decoder.28.attn.qkv.weight', 'module.decoder.28.attn.qkv.bias', 'module.decoder.28.attn.proj.weight', 'module.decoder.28.attn.proj.bias', 'module.decoder.28.attn.q_norm.weight', 'module.decoder.28.attn.q_norm.bias', 'module.decoder.28.attn.k_norm.weight', 'module.decoder.28.attn.k_norm.bias', 'module.decoder.28.ls1.gamma', 'module.decoder.28.norm2.weight', 'module.decoder.28.norm2.bias', 'module.decoder.28.mlp.fc1.weight', 'module.decoder.28.mlp.fc1.bias', 'module.decoder.28.mlp.fc2.weight', 'module.decoder.28.mlp.fc2.bias', 'module.decoder.28.ls2.gamma', 'module.decoder.29.norm1.weight', 'module.decoder.29.norm1.bias', 'module.decoder.29.attn.qkv.weight', 'module.decoder.29.attn.qkv.bias', 'module.decoder.29.attn.proj.weight', 'module.decoder.29.attn.proj.bias', 'module.decoder.29.attn.q_norm.weight', 'module.decoder.29.attn.q_norm.bias', 'module.decoder.29.attn.k_norm.weight', 'module.decoder.29.attn.k_norm.bias', 'module.decoder.29.ls1.gamma', 'module.decoder.29.norm2.weight', 'module.decoder.29.norm2.bias', 'module.decoder.29.mlp.fc1.weight', 'module.decoder.29.mlp.fc1.bias', 'module.decoder.29.mlp.fc2.weight', 'module.decoder.29.mlp.fc2.bias', 'module.decoder.29.ls2.gamma', 'module.decoder.30.norm1.weight', 'module.decoder.30.norm1.bias', 'module.decoder.30.attn.qkv.weight', 'module.decoder.30.attn.qkv.bias', 'module.decoder.30.attn.proj.weight', 'module.decoder.30.attn.proj.bias', 'module.decoder.30.attn.q_norm.weight', 'module.decoder.30.attn.q_norm.bias', 'module.decoder.30.attn.k_norm.weight', 'module.decoder.30.attn.k_norm.bias', 'module.decoder.30.ls1.gamma', 'module.decoder.30.norm2.weight', 'module.decoder.30.norm2.bias', 'module.decoder.30.mlp.fc1.weight', 'module.decoder.30.mlp.fc1.bias', 'module.decoder.30.mlp.fc2.weight', 'module.decoder.30.mlp.fc2.bias', 'module.decoder.30.ls2.gamma', 'module.decoder.31.norm1.weight', 'module.decoder.31.norm1.bias', 'module.decoder.31.attn.qkv.weight', 'module.decoder.31.attn.qkv.bias', 'module.decoder.31.attn.proj.weight', 'module.decoder.31.attn.proj.bias', 'module.decoder.31.attn.q_norm.weight', 'module.decoder.31.attn.q_norm.bias', 'module.decoder.31.attn.k_norm.weight', 'module.decoder.31.attn.k_norm.bias', 'module.decoder.31.ls1.gamma', 'module.decoder.31.norm2.weight', 'module.decoder.31.norm2.bias', 'module.decoder.31.mlp.fc1.weight', 'module.decoder.31.mlp.fc1.bias', 'module.decoder.31.mlp.fc2.weight', 'module.decoder.31.mlp.fc2.bias', 'module.decoder.31.ls2.gamma', 'module.decoder.32.norm1.weight', 'module.decoder.32.norm1.bias', 'module.decoder.32.attn.qkv.weight', 'module.decoder.32.attn.qkv.bias', 'module.decoder.32.attn.proj.weight', 'module.decoder.32.attn.proj.bias', 'module.decoder.32.attn.q_norm.weight', 'module.decoder.32.attn.q_norm.bias', 'module.decoder.32.attn.k_norm.weight', 'module.decoder.32.attn.k_norm.bias', 'module.decoder.32.ls1.gamma', 'module.decoder.32.norm2.weight', 'module.decoder.32.norm2.bias', 'module.decoder.32.mlp.fc1.weight', 'module.decoder.32.mlp.fc1.bias', 'module.decoder.32.mlp.fc2.weight', 'module.decoder.32.mlp.fc2.bias', 'module.decoder.32.ls2.gamma', 'module.decoder.33.norm1.weight', 'module.decoder.33.norm1.bias', 'module.decoder.33.attn.qkv.weight', 'module.decoder.33.attn.qkv.bias', 'module.decoder.33.attn.proj.weight', 'module.decoder.33.attn.proj.bias', 'module.decoder.33.attn.q_norm.weight', 'module.decoder.33.attn.q_norm.bias', 'module.decoder.33.attn.k_norm.weight', 'module.decoder.33.attn.k_norm.bias', 'module.decoder.33.ls1.gamma', 'module.decoder.33.norm2.weight', 'module.decoder.33.norm2.bias', 'module.decoder.33.mlp.fc1.weight', 'module.decoder.33.mlp.fc1.bias', 'module.decoder.33.mlp.fc2.weight', 'module.decoder.33.mlp.fc2.bias', 'module.decoder.33.ls2.gamma', 'module.decoder.34.norm1.weight', 'module.decoder.34.norm1.bias', 'module.decoder.34.attn.qkv.weight', 'module.decoder.34.attn.qkv.bias', 'module.decoder.34.attn.proj.weight', 'module.decoder.34.attn.proj.bias', 'module.decoder.34.attn.q_norm.weight', 'module.decoder.34.attn.q_norm.bias', 'module.decoder.34.attn.k_norm.weight', 'module.decoder.34.attn.k_norm.bias', 'module.decoder.34.ls1.gamma', 'module.decoder.34.norm2.weight', 'module.decoder.34.norm2.bias', 'module.decoder.34.mlp.fc1.weight', 'module.decoder.34.mlp.fc1.bias', 'module.decoder.34.mlp.fc2.weight', 'module.decoder.34.mlp.fc2.bias', 'module.decoder.34.ls2.gamma', 'module.decoder.35.norm1.weight', 'module.decoder.35.norm1.bias', 'module.decoder.35.attn.qkv.weight', 'module.decoder.35.attn.qkv.bias', 'module.decoder.35.attn.proj.weight', 'module.decoder.35.attn.proj.bias', 'module.decoder.35.attn.q_norm.weight', 'module.decoder.35.attn.q_norm.bias', 'module.decoder.35.attn.k_norm.weight', 'module.decoder.35.attn.k_norm.bias', 'module.decoder.35.ls1.gamma', 'module.decoder.35.norm2.weight', 'module.decoder.35.norm2.bias', 'module.decoder.35.mlp.fc1.weight', 'module.decoder.35.mlp.fc1.bias', 'module.decoder.35.mlp.fc2.weight', 'module.decoder.35.mlp.fc2.bias', 'module.decoder.35.ls2.gamma', 'module.point_decoder.projects.weight', 'module.point_decoder.projects.bias', 'module.point_decoder.blocks.0.norm1.weight', 'module.point_decoder.blocks.0.norm1.bias', 'module.point_decoder.blocks.0.attn.qkv.weight', 'module.point_decoder.blocks.0.attn.qkv.bias', 'module.point_decoder.blocks.0.attn.proj.weight', 'module.point_decoder.blocks.0.attn.proj.bias', 'module.point_decoder.blocks.0.norm2.weight', 'module.point_decoder.blocks.0.norm2.bias', 'module.point_decoder.blocks.0.mlp.fc1.weight', 'module.point_decoder.blocks.0.mlp.fc1.bias', 'module.point_decoder.blocks.0.mlp.fc2.weight', 'module.point_decoder.blocks.0.mlp.fc2.bias', 'module.point_decoder.blocks.1.norm1.weight', 'module.point_decoder.blocks.1.norm1.bias', 'module.point_decoder.blocks.1.attn.qkv.weight', 'module.point_decoder.blocks.1.attn.qkv.bias', 'module.point_decoder.blocks.1.attn.proj.weight', 'module.point_decoder.blocks.1.attn.proj.bias', 'module.point_decoder.blocks.1.norm2.weight', 'module.point_decoder.blocks.1.norm2.bias', 'module.point_decoder.blocks.1.mlp.fc1.weight', 'module.point_decoder.blocks.1.mlp.fc1.bias', 'module.point_decoder.blocks.1.mlp.fc2.weight', 'module.point_decoder.blocks.1.mlp.fc2.bias', 'module.point_decoder.blocks.2.norm1.weight', 'module.point_decoder.blocks.2.norm1.bias', 'module.point_decoder.blocks.2.attn.qkv.weight', 'module.point_decoder.blocks.2.attn.qkv.bias', 'module.point_decoder.blocks.2.attn.proj.weight', 'module.point_decoder.blocks.2.attn.proj.bias', 'module.point_decoder.blocks.2.norm2.weight', 'module.point_decoder.blocks.2.norm2.bias', 'module.point_decoder.blocks.2.mlp.fc1.weight', 'module.point_decoder.blocks.2.mlp.fc1.bias', 'module.point_decoder.blocks.2.mlp.fc2.weight', 'module.point_decoder.blocks.2.mlp.fc2.bias', 'module.point_decoder.blocks.3.norm1.weight', 'module.point_decoder.blocks.3.norm1.bias', 'module.point_decoder.blocks.3.attn.qkv.weight', 'module.point_decoder.blocks.3.attn.qkv.bias', 'module.point_decoder.blocks.3.attn.proj.weight', 'module.point_decoder.blocks.3.attn.proj.bias', 'module.point_decoder.blocks.3.norm2.weight', 'module.point_decoder.blocks.3.norm2.bias', 'module.point_decoder.blocks.3.mlp.fc1.weight', 'module.point_decoder.blocks.3.mlp.fc1.bias', 'module.point_decoder.blocks.3.mlp.fc2.weight', 'module.point_decoder.blocks.3.mlp.fc2.bias', 'module.point_decoder.blocks.4.norm1.weight', 'module.point_decoder.blocks.4.norm1.bias', 'module.point_decoder.blocks.4.attn.qkv.weight', 'module.point_decoder.blocks.4.attn.qkv.bias', 'module.point_decoder.blocks.4.attn.proj.weight', 'module.point_decoder.blocks.4.attn.proj.bias', 'module.point_decoder.blocks.4.norm2.weight', 'module.point_decoder.blocks.4.norm2.bias', 'module.point_decoder.blocks.4.mlp.fc1.weight', 'module.point_decoder.blocks.4.mlp.fc1.bias', 'module.point_decoder.blocks.4.mlp.fc2.weight', 'module.point_decoder.blocks.4.mlp.fc2.bias', 'module.point_decoder.linear_out.weight', 'module.point_decoder.linear_out.bias', 'module.point_head.proj.weight', 'module.point_head.proj.bias', 'module.conf_decoder.projects.weight', 'module.conf_decoder.projects.bias', 'module.conf_decoder.blocks.0.norm1.weight', 'module.conf_decoder.blocks.0.norm1.bias', 'module.conf_decoder.blocks.0.attn.qkv.weight', 'module.conf_decoder.blocks.0.attn.qkv.bias', 'module.conf_decoder.blocks.0.attn.proj.weight', 'module.conf_decoder.blocks.0.attn.proj.bias', 'module.conf_decoder.blocks.0.norm2.weight', 'module.conf_decoder.blocks.0.norm2.bias', 'module.conf_decoder.blocks.0.mlp.fc1.weight', 'module.conf_decoder.blocks.0.mlp.fc1.bias', 'module.conf_decoder.blocks.0.mlp.fc2.weight', 'module.conf_decoder.blocks.0.mlp.fc2.bias', 'module.conf_decoder.blocks.1.norm1.weight', 'module.conf_decoder.blocks.1.norm1.bias', 'module.conf_decoder.blocks.1.attn.qkv.weight', 'module.conf_decoder.blocks.1.attn.qkv.bias', 'module.conf_decoder.blocks.1.attn.proj.weight', 'module.conf_decoder.blocks.1.attn.proj.bias', 'module.conf_decoder.blocks.1.norm2.weight', 'module.conf_decoder.blocks.1.norm2.bias', 'module.conf_decoder.blocks.1.mlp.fc1.weight', 'module.conf_decoder.blocks.1.mlp.fc1.bias', 'module.conf_decoder.blocks.1.mlp.fc2.weight', 'module.conf_decoder.blocks.1.mlp.fc2.bias', 'module.conf_decoder.blocks.2.norm1.weight', 'module.conf_decoder.blocks.2.norm1.bias', 'module.conf_decoder.blocks.2.attn.qkv.weight', 'module.conf_decoder.blocks.2.attn.qkv.bias', 'module.conf_decoder.blocks.2.attn.proj.weight', 'module.conf_decoder.blocks.2.attn.proj.bias', 'module.conf_decoder.blocks.2.norm2.weight', 'module.conf_decoder.blocks.2.norm2.bias', 'module.conf_decoder.blocks.2.mlp.fc1.weight', 'module.conf_decoder.blocks.2.mlp.fc1.bias', 'module.conf_decoder.blocks.2.mlp.fc2.weight', 'module.conf_decoder.blocks.2.mlp.fc2.bias', 'module.conf_decoder.blocks.3.norm1.weight', 'module.conf_decoder.blocks.3.norm1.bias', 'module.conf_decoder.blocks.3.attn.qkv.weight', 'module.conf_decoder.blocks.3.attn.qkv.bias', 'module.conf_decoder.blocks.3.attn.proj.weight', 'module.conf_decoder.blocks.3.attn.proj.bias', 'module.conf_decoder.blocks.3.norm2.weight', 'module.conf_decoder.blocks.3.norm2.bias', 'module.conf_decoder.blocks.3.mlp.fc1.weight', 'module.conf_decoder.blocks.3.mlp.fc1.bias', 'module.conf_decoder.blocks.3.mlp.fc2.weight', 'module.conf_decoder.blocks.3.mlp.fc2.bias', 'module.conf_decoder.blocks.4.norm1.weight', 'module.conf_decoder.blocks.4.norm1.bias', 'module.conf_decoder.blocks.4.attn.qkv.weight', 'module.conf_decoder.blocks.4.attn.qkv.bias', 'module.conf_decoder.blocks.4.attn.proj.weight', 'module.conf_decoder.blocks.4.attn.proj.bias', 'module.conf_decoder.blocks.4.norm2.weight', 'module.conf_decoder.blocks.4.norm2.bias', 'module.conf_decoder.blocks.4.mlp.fc1.weight', 'module.conf_decoder.blocks.4.mlp.fc1.bias', 'module.conf_decoder.blocks.4.mlp.fc2.weight', 'module.conf_decoder.blocks.4.mlp.fc2.bias', 'module.conf_decoder.linear_out.weight', 'module.conf_decoder.linear_out.bias', 'module.conf_head.proj.weight', 'module.conf_head.proj.bias', 'module.camera_decoder.projects.weight', 'module.camera_decoder.projects.bias', 'module.camera_decoder.blocks.0.norm1.weight', 'module.camera_decoder.blocks.0.norm1.bias', 'module.camera_decoder.blocks.0.attn.qkv.weight', 'module.camera_decoder.blocks.0.attn.qkv.bias', 'module.camera_decoder.blocks.0.attn.proj.weight', 'module.camera_decoder.blocks.0.attn.proj.bias', 'module.camera_decoder.blocks.0.norm2.weight', 'module.camera_decoder.blocks.0.norm2.bias', 'module.camera_decoder.blocks.0.mlp.fc1.weight', 'module.camera_decoder.blocks.0.mlp.fc1.bias', 'module.camera_decoder.blocks.0.mlp.fc2.weight', 'module.camera_decoder.blocks.0.mlp.fc2.bias', 'module.camera_decoder.blocks.1.norm1.weight', 'module.camera_decoder.blocks.1.norm1.bias', 'module.camera_decoder.blocks.1.attn.qkv.weight', 'module.camera_decoder.blocks.1.attn.qkv.bias', 'module.camera_decoder.blocks.1.attn.proj.weight', 'module.camera_decoder.blocks.1.attn.proj.bias', 'module.camera_decoder.blocks.1.norm2.weight', 'module.camera_decoder.blocks.1.norm2.bias', 'module.camera_decoder.blocks.1.mlp.fc1.weight', 'module.camera_decoder.blocks.1.mlp.fc1.bias', 'module.camera_decoder.blocks.1.mlp.fc2.weight', 'module.camera_decoder.blocks.1.mlp.fc2.bias', 'module.camera_decoder.blocks.2.norm1.weight', 'module.camera_decoder.blocks.2.norm1.bias', 'module.camera_decoder.blocks.2.attn.qkv.weight', 'module.camera_decoder.blocks.2.attn.qkv.bias', 'module.camera_decoder.blocks.2.attn.proj.weight', 'module.camera_decoder.blocks.2.attn.proj.bias', 'module.camera_decoder.blocks.2.norm2.weight', 'module.camera_decoder.blocks.2.norm2.bias', 'module.camera_decoder.blocks.2.mlp.fc1.weight', 'module.camera_decoder.blocks.2.mlp.fc1.bias', 'module.camera_decoder.blocks.2.mlp.fc2.weight', 'module.camera_decoder.blocks.2.mlp.fc2.bias', 'module.camera_decoder.blocks.3.norm1.weight', 'module.camera_decoder.blocks.3.norm1.bias', 'module.camera_decoder.blocks.3.attn.qkv.weight', 'module.camera_decoder.blocks.3.attn.qkv.bias', 'module.camera_decoder.blocks.3.attn.proj.weight', 'module.camera_decoder.blocks.3.attn.proj.bias', 'module.camera_decoder.blocks.3.norm2.weight', 'module.camera_decoder.blocks.3.norm2.bias', 'module.camera_decoder.blocks.3.mlp.fc1.weight', 'module.camera_decoder.blocks.3.mlp.fc1.bias', 'module.camera_decoder.blocks.3.mlp.fc2.weight', 'module.camera_decoder.blocks.3.mlp.fc2.bias', 'module.camera_decoder.blocks.4.norm1.weight', 'module.camera_decoder.blocks.4.norm1.bias', 'module.camera_decoder.blocks.4.attn.qkv.weight', 'module.camera_decoder.blocks.4.attn.qkv.bias', 'module.camera_decoder.blocks.4.attn.proj.weight', 'module.camera_decoder.blocks.4.attn.proj.bias', 'module.camera_decoder.blocks.4.norm2.weight', 'module.camera_decoder.blocks.4.norm2.bias', 'module.camera_decoder.blocks.4.mlp.fc1.weight', 'module.camera_decoder.blocks.4.mlp.fc1.bias', 'module.camera_decoder.blocks.4.mlp.fc2.weight', 'module.camera_decoder.blocks.4.mlp.fc2.bias', 'module.camera_decoder.linear_out.weight', 'module.camera_decoder.linear_out.bias', 'module.camera_head.res_conv.0.res_conv1.weight', 'module.camera_head.res_conv.0.res_conv1.bias', 'module.camera_head.res_conv.0.res_conv2.weight', 'module.camera_head.res_conv.0.res_conv2.bias', 'module.camera_head.res_conv.0.res_conv3.weight', 'module.camera_head.res_conv.0.res_conv3.bias', 'module.camera_head.res_conv.1.res_conv1.weight', 'module.camera_head.res_conv.1.res_conv1.bias', 'module.camera_head.res_conv.1.res_conv2.weight', 'module.camera_head.res_conv.1.res_conv2.bias', 'module.camera_head.res_conv.1.res_conv3.weight', 'module.camera_head.res_conv.1.res_conv3.bias', 'module.camera_head.more_mlps.0.weight', 'module.camera_head.more_mlps.0.bias', 'module.camera_head.more_mlps.2.weight', 'module.camera_head.more_mlps.2.bias', 'module.camera_head.fc_t.weight', 'module.camera_head.fc_t.bias', 'module.camera_head.fc_rot.weight', 'module.camera_head.fc_rot.bias']) +[2026-05-02 01:11:27,515][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-02 01:11:27,520][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-02 01:11:27,520][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-02 01:11:27,520][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... +[2026-05-02 01:11:27,523][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + "decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + "decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + "decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + "decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + "decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + "decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", + "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + "decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + "decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + "decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + "decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + "decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + "decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + "decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + "point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + "point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + "conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + "conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + "camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + "camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + "decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", + "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + "decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + "decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", + "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + "point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + "camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + "camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-02 01:11:31,675][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-02 01:11:31,679][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/ +[2026-05-02 01:12:55,295][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/2175] eta: 2 days, 2:30:54 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 5439.7471 (5439.7471) Lcamera_frontend: 4.1901 (4.1901) Ldepth_frontend: 16.2068 (16.2068) Lpmap_frontend: 18.4157 (18.4157) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1823 (4.1823) Ldepth_mix: 16.2088 (16.2088) Lpmap_mix: 18.4142 (18.4142) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (4.1872) Ldepth_backend: 16.2021 (16.2021) Lpmap_backend: 18.4084 (18.4084) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5439.7471) time: 83.6111 data: 29.0064 max mem: 32998 +[2026-05-02 01:21:05,143][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/2175] eta: 1 day, 7:20:53 lr: 0.000000 epoch: 0.0023 (0.0023) step: 5.0000 (4.8182) loss: 5439.7471 (5074.5225) Lcamera_frontend: 4.1901 (3.9014) Ldepth_frontend: 16.5766 (17.1214) Lpmap_frontend: 18.4116 (18.3189) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1545 (3.8660) Ldepth_mix: 16.5722 (17.1337) Lpmap_mix: 18.4120 (18.3247) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (3.8722) Ldepth_backend: 16.5609 (17.1422) Lpmap_backend: 18.4084 (18.3294) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5074.5225) time: 52.1264 data: 2.6741 max mem: 78608 +[2026-05-02 01:29:36,320][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/2175] eta: 1 day, 6:54:57 lr: 0.000000 epoch: 0.0046 (0.0046) step: 10.0000 (9.8571) loss: 4455.1465 (4960.1760) Lcamera_frontend: 3.3703 (3.7843) Ldepth_frontend: 16.6046 (17.0302) Lpmap_frontend: 18.2110 (18.2589) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3676 (3.7677) Ldepth_mix: 16.6190 (17.0379) Lpmap_mix: 18.2049 (18.2631) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3687 (3.7813) Ldepth_backend: 16.6365 (17.0453) Lpmap_backend: 18.2003 (18.2679) Ltrack_backend: 0.0000 (0.0000) total: 4455.1465 (4960.1760) time: 50.0477 data: 0.0390 max mem: 78608 +[2026-05-02 01:38:31,581][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 30/2175] eta: 1 day, 7:08:01 lr: 0.000000 epoch: 0.0092 (0.0069) step: 20.0000 (14.8710) loss: 4440.9858 (5149.5577) Lcamera_frontend: 3.3703 (3.9496) Ldepth_frontend: 16.4207 (16.8971) Lpmap_frontend: 18.1098 (18.1753) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3676 (3.9317) Ldepth_mix: 16.4052 (16.9020) Lpmap_mix: 18.1147 (18.1778) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3687 (3.9404) Ldepth_backend: 16.3921 (16.9086) Lpmap_backend: 18.1122 (18.1819) Ltrack_backend: 0.0000 (0.0000) total: 4440.9858 (5149.5577) time: 52.3218 data: 0.0359 max mem: 78608 +[2026-05-02 01:47:19,122][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 40/2175] eta: 1 day, 7:03:40 lr: 0.000000 epoch: 0.0138 (0.0092) step: 30.0000 (19.9024) loss: 3602.1575 (4738.6459) Lcamera_frontend: 2.6445 (3.6082) Ldepth_frontend: 16.3623 (16.7774) Lpmap_frontend: 17.8377 (18.0587) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6469 (3.5945) Ldepth_mix: 16.3657 (16.7802) Lpmap_mix: 17.8406 (18.0600) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6506 (3.6002) Ldepth_backend: 16.3908 (16.7842) Lpmap_backend: 17.8430 (18.0624) Ltrack_backend: 0.0000 (0.0000) total: 3602.1575 (4738.6459) time: 53.1400 data: 0.0470 max mem: 78608 +[2026-05-02 01:56:01,602][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 50/2175] eta: 1 day, 6:54:03 lr: 0.000000 epoch: 0.0184 (0.0115) step: 40.0000 (24.8824) loss: 4052.8667 (4914.7079) Lcamera_frontend: 3.0227 (3.7566) Ldepth_frontend: 15.3567 (16.3900) Lpmap_frontend: 17.5850 (17.9529) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0263 (3.7473) Ldepth_mix: 15.3520 (16.3914) Lpmap_mix: 17.5934 (17.9540) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0425 (3.7521) Ldepth_backend: 15.3410 (16.3942) Lpmap_backend: 17.5838 (17.9559) Ltrack_backend: 0.0000 (0.0000) total: 4052.8667 (4914.7079) time: 52.5009 data: 0.0486 max mem: 78608 +[2026-05-02 02:04:35,533][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 60/2175] eta: 1 day, 6:39:48 lr: 0.000001 epoch: 0.0230 (0.0138) step: 50.0000 (29.8852) loss: 4611.7891 (4624.1606) Lcamera_frontend: 3.5242 (3.5173) Ldepth_frontend: 14.8783 (16.1767) Lpmap_frontend: 17.2696 (17.8315) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5254 (3.5091) Ldepth_mix: 14.8753 (16.1785) Lpmap_mix: 17.2746 (17.8327) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5254 (3.5133) Ldepth_backend: 14.8962 (16.1812) Lpmap_backend: 17.2799 (17.8347) Ltrack_backend: 0.0000 (0.0000) total: 4611.7891 (4624.1606) time: 51.8204 data: 0.0384 max mem: 78608 +[2026-05-02 02:12:54,678][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 70/2175] eta: 1 day, 6:19:49 lr: 0.000001 epoch: 0.0276 (0.0161) step: 60.0000 (34.9014) loss: 2338.9619 (4415.5456) Lcamera_frontend: 1.6208 (3.3463) Ldepth_frontend: 14.6870 (15.9957) Lpmap_frontend: 17.0434 (17.6934) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6168 (3.3389) Ldepth_mix: 14.6816 (15.9976) Lpmap_mix: 17.0603 (17.6945) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.6194 (3.3427) Ldepth_backend: 14.6728 (16.0002) Lpmap_backend: 17.0784 (17.6962) Ltrack_backend: 0.0000 (0.0000) total: 2338.9619 (4415.5456) time: 50.6515 data: 0.0365 max mem: 78608 +[2026-05-02 02:21:55,933][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 80/2175] eta: 1 day, 6:20:52 lr: 0.000001 epoch: 0.0322 (0.0184) step: 70.0000 (39.9136) loss: 2830.8052 (4507.4803) Lcamera_frontend: 2.0447 (3.4276) Ldepth_frontend: 13.6907 (15.6550) Lpmap_frontend: 16.6447 (17.5395) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0450 (3.4212) Ldepth_mix: 13.6989 (15.6557) Lpmap_mix: 16.6385 (17.5400) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0456 (3.4242) Ldepth_backend: 13.6893 (15.6568) Lpmap_backend: 16.6355 (17.5410) Ltrack_backend: 0.0000 (0.0000) total: 2830.8052 (4507.4803) time: 52.0161 data: 0.0351 max mem: 78608 +[2026-05-02 02:30:44,088][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 90/2175] eta: 1 day, 6:14:43 lr: 0.000001 epoch: 0.0368 (0.0207) step: 80.0000 (44.9011) loss: 4618.3120 (4626.6009) Lcamera_frontend: 3.5741 (3.5324) Ldepth_frontend: 11.9974 (15.2596) Lpmap_frontend: 16.1725 (17.3686) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5745 (3.5266) Ldepth_mix: 11.9872 (15.2598) Lpmap_mix: 16.1768 (17.3687) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5731 (3.5292) Ldepth_backend: 11.9790 (15.2601) Lpmap_backend: 16.1724 (17.3693) Ltrack_backend: 0.0000 (0.0000) total: 4618.3120 (4626.6009) time: 53.4687 data: 0.0379 max mem: 78608 +[2026-05-02 02:39:46,093][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 100/2175] eta: 1 day, 6:12:48 lr: 0.000001 epoch: 0.0414 (0.0230) step: 90.0000 (49.8911) loss: 5022.5645 (4623.4317) Lcamera_frontend: 3.9090 (3.5353) Ldepth_frontend: 11.1564 (14.8799) Lpmap_frontend: 15.8865 (17.1837) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9080 (3.5302) Ldepth_mix: 11.1525 (14.8791) Lpmap_mix: 15.8831 (17.1832) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9096 (3.5321) Ldepth_backend: 11.1428 (14.8785) Lpmap_backend: 15.8805 (17.1833) Ltrack_backend: 0.0000 (0.0000) total: 5022.5645 (4623.4317) time: 53.5079 data: 0.0404 max mem: 78608 +[2026-05-02 02:48:41,248][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 110/2175] eta: 1 day, 6:07:27 lr: 0.000001 epoch: 0.0460 (0.0253) step: 100.0000 (54.8829) loss: 4557.3750 (4644.9514) Lcamera_frontend: 3.5408 (3.5588) Ldepth_frontend: 10.7256 (14.4959) Lpmap_frontend: 15.1856 (17.0001) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5416 (3.5540) Ldepth_mix: 10.7130 (14.4946) Lpmap_mix: 15.1785 (16.9993) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5403 (3.5558) Ldepth_backend: 10.7029 (14.4933) Lpmap_backend: 15.1740 (16.9991) Ltrack_backend: 0.0000 (0.0000) total: 4557.3750 (4644.9514) time: 53.8579 data: 0.0397 max mem: 78608 +[2026-05-02 02:57:26,794][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 120/2175] eta: 1 day, 5:58:49 lr: 0.000001 epoch: 0.0506 (0.0276) step: 110.0000 (59.8926) loss: 5502.2861 (4683.7406) Lcamera_frontend: 4.3387 (3.5975) Ldepth_frontend: 9.6748 (14.0736) Lpmap_frontend: 14.5716 (16.7704) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.3366 (3.5929) Ldepth_mix: 9.6837 (14.0719) Lpmap_mix: 14.5608 (16.7693) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.3393 (3.5946) Ldepth_backend: 9.6808 (14.0705) Lpmap_backend: 14.5548 (16.7689) Ltrack_backend: 0.0000 (0.0000) total: 5502.2861 (4683.7406) time: 53.0349 data: 0.0371 max mem: 78608 +[2026-05-02 03:05:51,048][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 130/2175] eta: 1 day, 5:44:36 lr: 0.000001 epoch: 0.0552 (0.0299) step: 120.0000 (64.9008) loss: 4250.2085 (4629.8815) Lcamera_frontend: 3.3097 (3.5586) Ldepth_frontend: 8.8296 (13.6655) Lpmap_frontend: 14.1952 (16.5645) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3095 (3.5543) Ldepth_mix: 8.8287 (13.6635) Lpmap_mix: 14.1915 (16.5631) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3103 (3.5559) Ldepth_backend: 8.8278 (13.6615) Lpmap_backend: 14.1896 (16.5625) Ltrack_backend: 0.0000 (0.0000) total: 4250.2085 (4629.8815) time: 51.4899 data: 0.0381 max mem: 78608 +[2026-05-02 03:14:28,174][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 140/2175] eta: 1 day, 5:34:19 lr: 0.000001 epoch: 0.0598 (0.0322) step: 130.0000 (69.9078) loss: 4317.0518 (4674.2318) Lcamera_frontend: 3.3687 (3.6015) Ldepth_frontend: 7.5516 (13.2470) Lpmap_frontend: 14.0117 (16.3667) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3691 (3.5975) Ldepth_mix: 7.5389 (13.2443) Lpmap_mix: 14.0007 (16.3649) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3693 (3.5991) Ldepth_backend: 7.5254 (13.2416) Lpmap_backend: 13.9955 (16.3639) Ltrack_backend: 0.0000 (0.0000) total: 4317.0518 (4674.2318) time: 51.0678 data: 0.0376 max mem: 78608 +[2026-05-02 03:23:19,989][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 150/2175] eta: 1 day, 5:27:32 lr: 0.000001 epoch: 0.0644 (0.0345) step: 140.0000 (74.9139) loss: 3733.5522 (4523.8774) Lcamera_frontend: 2.9059 (3.4813) Ldepth_frontend: 7.4963 (12.9285) Lpmap_frontend: 13.4335 (16.1624) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9045 (3.4775) Ldepth_mix: 7.4956 (12.9255) Lpmap_mix: 13.4196 (16.1604) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9059 (3.4790) Ldepth_backend: 7.4941 (12.9226) Lpmap_backend: 13.4136 (16.1593) Ltrack_backend: 0.0000 (0.0000) total: 3733.5522 (4523.8774) time: 52.4452 data: 0.0389 max mem: 78608 +[2026-05-02 03:32:20,082][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 160/2175] eta: 1 day, 5:22:13 lr: 0.000001 epoch: 0.0690 (0.0368) step: 150.0000 (79.9130) loss: 3263.8560 (4528.5137) Lcamera_frontend: 2.5248 (3.4903) Ldepth_frontend: 7.2993 (12.5895) Lpmap_frontend: 13.3417 (15.9807) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.5234 (3.4867) Ldepth_mix: 7.2779 (12.5860) Lpmap_mix: 13.3475 (15.9782) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.5242 (3.4881) Ldepth_backend: 7.2775 (12.5829) Lpmap_backend: 13.3527 (15.9770) Ltrack_backend: 0.0000 (0.0000) total: 3263.8560 (4528.5137) time: 53.5947 data: 0.0386 max mem: 78608 +[2026-05-02 03:40:41,763][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 170/2175] eta: 1 day, 5:08:58 lr: 0.000002 epoch: 0.0736 (0.0391) step: 160.0000 (84.9064) loss: 5749.9067 (4649.9435) Lcamera_frontend: 4.6079 (3.5964) Ldepth_frontend: 6.2434 (12.2392) Lpmap_frontend: 13.3222 (15.8257) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.6020 (3.5931) Ldepth_mix: 6.2328 (12.2355) Lpmap_mix: 13.3193 (15.8231) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.6080 (3.5943) Ldepth_backend: 6.2291 (12.2318) Lpmap_backend: 13.3208 (15.8217) Ltrack_backend: 0.0000 (0.0000) total: 5749.9067 (4649.9435) time: 52.0885 data: 0.0339 max mem: 78608 +[2026-05-02 03:49:20,820][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 180/2175] eta: 1 day, 4:59:27 lr: 0.000002 epoch: 0.0782 (0.0414) step: 170.0000 (89.9006) loss: 5086.8892 (4638.7650) Lcamera_frontend: 4.0559 (3.5920) Ldepth_frontend: 5.7784 (11.9123) Lpmap_frontend: 13.0749 (15.6532) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0555 (3.5888) Ldepth_mix: 5.7628 (11.9082) Lpmap_mix: 13.0696 (15.6503) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.0546 (3.5900) Ldepth_backend: 5.7299 (11.9040) Lpmap_backend: 13.0706 (15.6488) Ltrack_backend: 0.0000 (0.0000) total: 5086.8892 (4638.7650) time: 51.0368 data: 0.0340 max mem: 78608 +[2026-05-02 03:58:24,017][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 190/2175] eta: 1 day, 4:54:12 lr: 0.000002 epoch: 0.0828 (0.0437) step: 180.0000 (94.8953) loss: 3726.2163 (4563.1622) Lcamera_frontend: 2.9281 (3.5331) Ldepth_frontend: 6.1579 (11.6645) Lpmap_frontend: 12.5926 (15.4879) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9263 (3.5299) Ldepth_mix: 6.1371 (11.6605) Lpmap_mix: 12.5918 (15.4849) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9270 (3.5312) Ldepth_backend: 6.1051 (11.6564) Lpmap_backend: 12.5942 (15.4835) Ltrack_backend: 0.0000 (0.0000) total: 3726.2163 (4563.1622) time: 53.1117 data: 0.0345 max mem: 78608 +[2026-05-02 04:07:02,783][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 200/2175] eta: 1 day, 4:44:34 lr: 0.000002 epoch: 0.0874 (0.0460) step: 190.0000 (99.8905) loss: 4229.8970 (4568.5722) Lcamera_frontend: 3.3685 (3.5425) Ldepth_frontend: 5.4301 (11.3417) Lpmap_frontend: 12.3605 (15.3152) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3628 (3.5394) Ldepth_mix: 5.4255 (11.3376) Lpmap_mix: 12.3565 (15.3123) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3673 (3.5406) Ldepth_backend: 5.4232 (11.3335) Lpmap_backend: 12.3588 (15.3110) Ltrack_backend: 0.0000 (0.0000) total: 4229.8970 (4568.5722) time: 53.0926 data: 0.0355 max mem: 78608 +[2026-05-02 04:15:48,714][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 210/2175] eta: 1 day, 4:36:07 lr: 0.000002 epoch: 0.0920 (0.0483) step: 200.0000 (104.8863) loss: 4594.2324 (4561.3104) Lcamera_frontend: 3.6749 (3.5406) Ldepth_frontend: 4.8983 (11.0646) Lpmap_frontend: 12.1301 (15.1640) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6718 (3.5376) Ldepth_mix: 4.8964 (11.0604) Lpmap_mix: 12.1285 (15.1607) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6750 (3.5389) Ldepth_backend: 4.8902 (11.0561) Lpmap_backend: 12.1297 (15.1593) Ltrack_backend: 0.0000 (0.0000) total: 4594.2324 (4561.3104) time: 52.2232 data: 0.0406 max mem: 78608 +[2026-05-02 04:25:02,671][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 220/2175] eta: 1 day, 4:31:48 lr: 0.000002 epoch: 0.0966 (0.0506) step: 210.0000 (109.8824) loss: 4568.3115 (4536.8525) Lcamera_frontend: 3.6364 (3.5237) Ldepth_frontend: 5.0702 (10.8510) Lpmap_frontend: 12.2075 (15.0261) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6373 (3.5207) Ldepth_mix: 5.0746 (10.8469) Lpmap_mix: 12.2109 (15.0228) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6351 (3.5220) Ldepth_backend: 5.0673 (10.8428) Lpmap_backend: 12.2192 (15.0215) Ltrack_backend: 0.0000 (0.0000) total: 4568.3115 (4536.8525) time: 53.9872 data: 0.0441 max mem: 78608 +[2026-05-02 04:34:07,899][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 230/2175] eta: 1 day, 4:25:50 lr: 0.000002 epoch: 0.1011 (0.0529) step: 220.0000 (114.8874) loss: 4276.0312 (4509.4486) Lcamera_frontend: 3.3828 (3.5047) Ldepth_frontend: 5.0702 (10.6244) Lpmap_frontend: 11.8485 (14.8669) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3771 (3.5017) Ldepth_mix: 5.0746 (10.6205) Lpmap_mix: 11.8452 (14.8637) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3815 (3.5030) Ldepth_backend: 5.0673 (10.6163) Lpmap_backend: 11.8477 (14.8626) Ltrack_backend: 0.0000 (0.0000) total: 4276.0312 (4509.4486) time: 54.9591 data: 0.0412 max mem: 78608 +[2026-05-02 04:42:57,679][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 240/2175] eta: 1 day, 4:17:32 lr: 0.000002 epoch: 0.1057 (0.0552) step: 230.0000 (119.8921) loss: 3473.5327 (4456.8284) Lcamera_frontend: 2.7443 (3.4641) Ldepth_frontend: 5.5202 (10.4253) Lpmap_frontend: 11.5392 (14.7365) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7350 (3.4610) Ldepth_mix: 5.5188 (10.4212) Lpmap_mix: 11.5297 (14.7331) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7436 (3.4625) Ldepth_backend: 5.5181 (10.4169) Lpmap_backend: 11.5313 (14.7320) Ltrack_backend: 0.0000 (0.0000) total: 3473.5327 (4456.8284) time: 53.7503 data: 0.0367 max mem: 78608 +[2026-05-02 04:51:33,396][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 250/2175] eta: 1 day, 4:07:24 lr: 0.000002 epoch: 0.1103 (0.0575) step: 240.0000 (124.8964) loss: 3473.5327 (4453.5763) Lcamera_frontend: 2.7443 (3.4646) Ldepth_frontend: 5.8839 (10.2340) Lpmap_frontend: 11.6054 (14.6051) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7350 (3.4615) Ldepth_mix: 5.8812 (10.2301) Lpmap_mix: 11.5968 (14.6018) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7436 (3.4630) Ldepth_backend: 5.8789 (10.2260) Lpmap_backend: 11.5976 (14.6008) Ltrack_backend: 0.0000 (0.0000) total: 3473.5327 (4453.5763) time: 52.2747 data: 0.0347 max mem: 78608 +[2026-05-02 04:59:44,266][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 260/2175] eta: 1 day, 3:54:20 lr: 0.000002 epoch: 0.1149 (0.0598) step: 250.0000 (129.9004) loss: 5020.2090 (4496.6866) Lcamera_frontend: 4.0252 (3.5039) Ldepth_frontend: 4.1422 (10.0185) Lpmap_frontend: 11.5727 (14.4809) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0093 (3.5008) Ldepth_mix: 4.1413 (10.0147) Lpmap_mix: 11.5749 (14.4776) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.0260 (3.5024) Ldepth_backend: 4.1399 (10.0107) Lpmap_backend: 11.5792 (14.4768) Ltrack_backend: 0.0000 (0.0000) total: 5020.2090 (4496.6866) time: 50.3262 data: 0.0330 max mem: 78608 +[2026-05-02 05:08:17,814][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 270/2175] eta: 1 day, 3:44:18 lr: 0.000002 epoch: 0.1195 (0.0621) step: 260.0000 (134.9041) loss: 5942.0122 (4565.5912) Lcamera_frontend: 4.8106 (3.5645) Ldepth_frontend: 3.6318 (9.8041) Lpmap_frontend: 11.6432 (14.3758) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.8094 (3.5614) Ldepth_mix: 3.6296 (9.8002) Lpmap_mix: 11.6372 (14.3724) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.8098 (3.5630) Ldepth_backend: 3.6232 (9.7961) Lpmap_backend: 11.6385 (14.3717) Ltrack_backend: 0.0000 (0.0000) total: 5942.0122 (4565.5912) time: 50.2172 data: 0.0332 max mem: 78608 +[2026-05-02 05:17:16,865][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 280/2175] eta: 1 day, 3:37:14 lr: 0.000003 epoch: 0.1241 (0.0644) step: 270.0000 (139.9075) loss: 5042.1792 (4552.1092) Lcamera_frontend: 4.0482 (3.5563) Ldepth_frontend: 4.0804 (9.6270) Lpmap_frontend: 11.2768 (14.2471) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0395 (3.5531) Ldepth_mix: 4.0774 (9.6233) Lpmap_mix: 11.2713 (14.2436) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.0364 (3.5548) Ldepth_backend: 4.0766 (9.6193) Lpmap_backend: 11.2759 (14.2430) Ltrack_backend: 0.0000 (0.0000) total: 5042.1792 (4552.1092) time: 52.6293 data: 0.0360 max mem: 78608 +[2026-05-02 05:26:10,597][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 290/2175] eta: 1 day, 3:29:28 lr: 0.000003 epoch: 0.1287 (0.0667) step: 280.0000 (144.9107) loss: 4218.5874 (4591.9725) Lcamera_frontend: 3.3580 (3.5921) Ldepth_frontend: 4.4536 (9.4639) Lpmap_frontend: 10.9399 (14.1407) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3524 (3.5889) Ldepth_mix: 4.4557 (9.4602) Lpmap_mix: 10.9365 (14.1372) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3588 (3.5907) Ldepth_backend: 4.4555 (9.4564) Lpmap_backend: 10.9411 (14.1367) Ltrack_backend: 0.0000 (0.0000) total: 4218.5874 (4591.9725) time: 53.6390 data: 0.0369 max mem: 78608 +[2026-05-02 05:34:56,198][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 300/2175] eta: 1 day, 3:20:46 lr: 0.000003 epoch: 0.1333 (0.0690) step: 290.0000 (149.9136) loss: 4716.9814 (4587.6797) Lcamera_frontend: 3.7770 (3.5910) Ldepth_frontend: 4.3568 (9.3207) Lpmap_frontend: 11.3872 (14.0354) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7716 (3.5877) Ldepth_mix: 4.3548 (9.3171) Lpmap_mix: 11.3845 (14.0318) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7765 (3.5897) Ldepth_backend: 4.3515 (9.3134) Lpmap_backend: 11.3857 (14.0314) Ltrack_backend: 0.0000 (0.0000) total: 4716.9814 (4587.6797) time: 52.9665 data: 0.0352 max mem: 78608 +[2026-05-02 05:43:29,144][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 310/2175] eta: 1 day, 3:10:49 lr: 0.000003 epoch: 0.1379 (0.0713) step: 300.0000 (154.9164) loss: 4312.5737 (4546.2671) Lcamera_frontend: 3.4524 (3.5586) Ldepth_frontend: 4.2806 (9.1994) Lpmap_frontend: 11.1786 (13.9453) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4459 (3.5553) Ldepth_mix: 4.2700 (9.1960) Lpmap_mix: 11.1753 (13.9417) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4528 (3.5573) Ldepth_backend: 4.2594 (9.1924) Lpmap_backend: 11.1797 (13.9413) Ltrack_backend: 0.0000 (0.0000) total: 4312.5737 (4546.2671) time: 51.9272 data: 0.0327 max mem: 78608 +[2026-05-02 05:52:26,159][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 320/2175] eta: 1 day, 3:03:15 lr: 0.000003 epoch: 0.1425 (0.0736) step: 310.0000 (159.9159) loss: 4932.9053 (4581.0649) Lcamera_frontend: 3.9604 (3.5900) Ldepth_frontend: 4.0451 (9.0433) Lpmap_frontend: 11.1708 (13.8544) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9503 (3.5867) Ldepth_mix: 4.0337 (9.0398) Lpmap_mix: 11.1636 (13.8506) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9610 (3.5887) Ldepth_backend: 4.0201 (9.0362) Lpmap_backend: 11.1626 (13.8503) Ltrack_backend: 0.0000 (0.0000) total: 4932.9053 (4581.0649) time: 52.4972 data: 0.0334 max mem: 78608 +[2026-05-02 06:00:53,516][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 330/2175] eta: 1 day, 2:52:51 lr: 0.000003 epoch: 0.1471 (0.0759) step: 320.0000 (164.9154) loss: 4796.5166 (4551.8984) Lcamera_frontend: 3.8577 (3.5674) Ldepth_frontend: 4.5250 (8.9516) Lpmap_frontend: 11.3759 (13.7780) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.8451 (3.5639) Ldepth_mix: 4.5174 (8.9482) Lpmap_mix: 11.3717 (13.7742) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8583 (3.5661) Ldepth_backend: 4.5046 (8.9447) Lpmap_backend: 11.3782 (13.7738) Ltrack_backend: 0.0000 (0.0000) total: 4796.5166 (4551.8984) time: 52.2163 data: 0.0336 max mem: 78608 +[2026-05-02 06:09:47,760][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 340/2175] eta: 1 day, 2:44:59 lr: 0.000003 epoch: 0.1517 (0.0782) step: 330.0000 (169.9150) loss: 3601.4968 (4512.0299) Lcamera_frontend: 2.8273 (3.5358) Ldepth_frontend: 5.4634 (8.8595) Lpmap_frontend: 11.3144 (13.7031) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8227 (3.5324) Ldepth_mix: 5.4572 (8.8561) Lpmap_mix: 11.3179 (13.6992) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8274 (3.5346) Ldepth_backend: 5.4512 (8.8524) Lpmap_backend: 11.3189 (13.6988) Ltrack_backend: 0.0000 (0.0000) total: 3601.4968 (4512.0299) time: 52.0785 data: 0.0387 max mem: 78608 +[2026-05-02 06:18:48,802][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 350/2175] eta: 1 day, 2:37:39 lr: 0.000003 epoch: 0.1563 (0.0805) step: 340.0000 (174.9145) loss: 2794.0251 (4472.1340) Lcamera_frontend: 2.1622 (3.5040) Ldepth_frontend: 4.9128 (8.7775) Lpmap_frontend: 11.3144 (13.6334) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.1565 (3.5005) Ldepth_mix: 4.9140 (8.7743) Lpmap_mix: 11.3179 (13.6294) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.1606 (3.5029) Ldepth_backend: 4.9150 (8.7707) Lpmap_backend: 11.3189 (13.6290) Ltrack_backend: 0.0000 (0.0000) total: 2794.0251 (4472.1340) time: 53.7642 data: 0.0415 max mem: 78608 +[2026-05-02 06:27:40,479][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 360/2175] eta: 1 day, 2:29:26 lr: 0.000003 epoch: 0.1609 (0.0828) step: 350.0000 (179.9141) loss: 3373.9941 (4459.4196) Lcamera_frontend: 2.6730 (3.4950) Ldepth_frontend: 4.7950 (8.6832) Lpmap_frontend: 11.3920 (13.5682) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6462 (3.4914) Ldepth_mix: 4.7883 (8.6800) Lpmap_mix: 11.3928 (13.5640) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6709 (3.4939) Ldepth_backend: 4.7817 (8.6766) Lpmap_backend: 11.3938 (13.5635) Ltrack_backend: 0.0000 (0.0000) total: 3373.9941 (4459.4196) time: 53.6358 data: 0.0411 max mem: 78608 +[2026-05-02 06:36:28,529][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 370/2175] eta: 1 day, 2:20:53 lr: 0.000003 epoch: 0.1655 (0.0851) step: 360.0000 (184.9137) loss: 4276.0830 (4470.9654) Lcamera_frontend: 3.3949 (3.5062) Ldepth_frontend: 4.7412 (8.5925) Lpmap_frontend: 11.3355 (13.5019) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3904 (3.5024) Ldepth_mix: 4.7394 (8.5894) Lpmap_mix: 11.3225 (13.4977) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3966 (3.5051) Ldepth_backend: 4.7426 (8.5861) Lpmap_backend: 11.3202 (13.4974) Ltrack_backend: 0.0000 (0.0000) total: 4276.0830 (4470.9654) time: 52.9862 data: 0.0393 max mem: 78608 +[2026-05-02 06:45:04,506][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 380/2175] eta: 1 day, 2:11:23 lr: 0.000003 epoch: 0.1701 (0.0874) step: 370.0000 (189.9134) loss: 4305.5044 (4465.4405) Lcamera_frontend: 3.3949 (3.5032) Ldepth_frontend: 4.7412 (8.5067) Lpmap_frontend: 11.0125 (13.4295) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3904 (3.4992) Ldepth_mix: 4.7394 (8.5037) Lpmap_mix: 11.0114 (13.4251) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3966 (3.5021) Ldepth_backend: 4.7426 (8.5005) Lpmap_backend: 11.0153 (13.4248) Ltrack_backend: 0.0000 (0.0000) total: 4305.5044 (4465.4405) time: 52.2010 data: 0.0353 max mem: 78608 +[2026-05-02 06:53:46,445][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 390/2175] eta: 1 day, 2:02:22 lr: 0.000004 epoch: 0.1747 (0.0897) step: 380.0000 (194.9130) loss: 5305.4634 (4517.7736) Lcamera_frontend: 4.2638 (3.5487) Ldepth_frontend: 4.2713 (8.3975) Lpmap_frontend: 10.0871 (13.3486) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.2560 (3.5446) Ldepth_mix: 4.2597 (8.3945) Lpmap_mix: 10.0863 (13.3441) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.2604 (3.5476) Ldepth_backend: 4.2458 (8.3912) Lpmap_backend: 10.0892 (13.3438) Ltrack_backend: 0.0000 (0.0000) total: 5305.4634 (4517.7736) time: 51.8949 data: 0.0362 max mem: 78608 +[2026-05-02 07:02:23,675][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 400/2175] eta: 1 day, 1:53:02 lr: 0.000004 epoch: 0.1793 (0.0920) step: 390.0000 (199.9127) loss: 5356.4937 (4545.7789) Lcamera_frontend: 4.3305 (3.5739) Ldepth_frontend: 3.6935 (8.2908) Lpmap_frontend: 10.0871 (13.2755) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.3204 (3.5695) Ldepth_mix: 3.6836 (8.2876) Lpmap_mix: 10.0863 (13.2709) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.3308 (3.5728) Ldepth_backend: 3.6729 (8.2844) Lpmap_backend: 10.0892 (13.2706) Ltrack_backend: 0.0000 (0.0000) total: 5356.4937 (4545.7789) time: 51.9578 data: 0.0344 max mem: 78608 +[2026-05-02 07:11:12,339][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 410/2175] eta: 1 day, 1:44:33 lr: 0.000004 epoch: 0.1839 (0.0943) step: 400.0000 (204.9124) loss: 4670.4385 (4541.8820) Lcamera_frontend: 3.7577 (3.5721) Ldepth_frontend: 4.1954 (8.2014) Lpmap_frontend: 10.6621 (13.2190) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7469 (3.5677) Ldepth_mix: 4.1909 (8.1982) Lpmap_mix: 10.6551 (13.2142) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7559 (3.5710) Ldepth_backend: 4.1866 (8.1950) Lpmap_backend: 10.6521 (13.2140) Ltrack_backend: 0.0000 (0.0000) total: 4670.4385 (4541.8820) time: 52.2946 data: 0.0385 max mem: 78608 +[2026-05-02 07:20:04,840][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 420/2175] eta: 1 day, 1:36:19 lr: 0.000004 epoch: 0.1885 (0.0966) step: 410.0000 (209.9121) loss: 4242.2949 (4550.7091) Lcamera_frontend: 3.3798 (3.5808) Ldepth_frontend: 4.4150 (8.1197) Lpmap_frontend: 10.9479 (13.1643) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3716 (3.5763) Ldepth_mix: 4.4128 (8.1165) Lpmap_mix: 10.9356 (13.1594) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3819 (3.5797) Ldepth_backend: 4.4077 (8.1132) Lpmap_backend: 10.9319 (13.1591) Ltrack_backend: 0.0000 (0.0000) total: 4242.2949 (4550.7091) time: 53.0581 data: 0.0417 max mem: 78608 +[2026-05-02 07:29:18,365][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 430/2175] eta: 1 day, 1:29:28 lr: 0.000004 epoch: 0.1931 (0.0989) step: 420.0000 (214.9118) loss: 3811.5688 (4513.2318) Lcamera_frontend: 3.0263 (3.5508) Ldepth_frontend: 4.7361 (8.0634) Lpmap_frontend: 10.7967 (13.1007) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0138 (3.5461) Ldepth_mix: 4.7320 (8.0604) Lpmap_mix: 10.7692 (13.0957) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0265 (3.5497) Ldepth_backend: 4.7326 (8.0572) Lpmap_backend: 10.7536 (13.0953) Ltrack_backend: 0.0000 (0.0000) total: 3811.5688 (4513.2318) time: 54.3012 data: 0.0388 max mem: 78608 +[2026-05-02 07:38:29,136][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 440/2175] eta: 1 day, 1:22:20 lr: 0.000004 epoch: 0.1977 (0.1011) step: 430.0000 (219.9116) loss: 1805.9749 (4472.2171) Lcamera_frontend: 1.3500 (3.5176) Ldepth_frontend: 5.6222 (8.0148) Lpmap_frontend: 10.7814 (13.0505) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.3179 (3.5129) Ldepth_mix: 5.6236 (8.0120) Lpmap_mix: 10.7807 (13.0453) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.3484 (3.5165) Ldepth_backend: 5.6249 (8.0090) Lpmap_backend: 10.7695 (13.0447) Ltrack_backend: 0.0000 (0.0000) total: 1805.9749 (4472.2171) time: 55.2140 data: 0.0408 max mem: 78608 +[2026-05-02 07:47:23,248][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 450/2175] eta: 1 day, 1:14:03 lr: 0.000004 epoch: 0.2023 (0.1034) step: 440.0000 (224.9135) loss: 4272.9741 (4475.0546) Lcamera_frontend: 3.4222 (3.5213) Ldepth_frontend: 4.8796 (7.9433) Lpmap_frontend: 10.7156 (12.9843) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3944 (3.5163) Ldepth_mix: 4.8895 (7.9405) Lpmap_mix: 10.6794 (12.9789) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4211 (3.5203) Ldepth_backend: 4.9035 (7.9378) Lpmap_backend: 10.6520 (12.9783) Ltrack_backend: 0.0000 (0.0000) total: 4272.9741 (4475.0546) time: 54.2411 data: 0.0387 max mem: 78608 +[2026-05-02 07:56:08,136][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 460/2175] eta: 1 day, 1:05:10 lr: 0.000004 epoch: 0.2069 (0.1057) step: 450.0000 (229.9154) loss: 4592.2461 (4456.6519) Lcamera_frontend: 3.6815 (3.5071) Ldepth_frontend: 4.3665 (7.8870) Lpmap_frontend: 10.1322 (12.9282) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6700 (3.5020) Ldepth_mix: 4.3620 (7.8844) Lpmap_mix: 10.1180 (12.9226) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6839 (3.5061) Ldepth_backend: 4.3593 (7.8818) Lpmap_backend: 10.1137 (12.9220) Ltrack_backend: 0.0000 (0.0000) total: 4592.2461 (4456.6519) time: 52.9476 data: 0.0369 max mem: 78608 +[2026-05-02 08:05:03,264][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 470/2175] eta: 1 day, 0:56:54 lr: 0.000004 epoch: 0.2115 (0.1080) step: 460.0000 (234.9172) loss: 4401.3013 (4452.7577) Lcamera_frontend: 3.4876 (3.5049) Ldepth_frontend: 5.2059 (7.8311) Lpmap_frontend: 10.1322 (12.8764) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4712 (3.4997) Ldepth_mix: 5.2072 (7.8286) Lpmap_mix: 10.1180 (12.8708) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4867 (3.5039) Ldepth_backend: 5.2040 (7.8261) Lpmap_backend: 10.1137 (12.8701) Ltrack_backend: 0.0000 (0.0000) total: 4401.3013 (4452.7577) time: 53.0006 data: 0.0382 max mem: 78608 +[2026-05-02 08:13:50,075][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 480/2175] eta: 1 day, 0:48:07 lr: 0.000004 epoch: 0.2161 (0.1103) step: 470.0000 (239.9189) loss: 4401.3013 (4430.6779) Lcamera_frontend: 3.4876 (3.4876) Ldepth_frontend: 4.5972 (7.7831) Lpmap_frontend: 9.6637 (12.8144) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4712 (3.4824) Ldepth_mix: 4.6037 (7.7807) Lpmap_mix: 9.6514 (12.8086) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4867 (3.4866) Ldepth_backend: 4.6081 (7.7782) Lpmap_backend: 9.6520 (12.8080) Ltrack_backend: 0.0000 (0.0000) total: 4401.3013 (4430.6779) time: 53.0968 data: 0.0384 max mem: 78608 +[2026-05-02 08:22:05,403][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 490/2175] eta: 1 day, 0:37:33 lr: 0.000005 epoch: 0.2207 (0.1126) step: 480.0000 (244.9206) loss: 4241.4062 (4436.7119) Lcamera_frontend: 3.4004 (3.4938) Ldepth_frontend: 4.5972 (7.7271) Lpmap_frontend: 9.6855 (12.7566) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3784 (3.4883) Ldepth_mix: 4.6037 (7.7248) Lpmap_mix: 9.6663 (12.7505) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3980 (3.4928) Ldepth_backend: 4.6081 (7.7225) Lpmap_backend: 9.6467 (12.7499) Ltrack_backend: 0.0000 (0.0000) total: 4241.4062 (4436.7119) time: 51.1069 data: 0.0374 max mem: 78608 +[2026-05-02 08:30:47,070][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 500/2175] eta: 1 day, 0:28:32 lr: 0.000005 epoch: 0.2253 (0.1149) step: 490.0000 (249.9222) loss: 4106.1992 (4414.7409) Lcamera_frontend: 3.2480 (3.4764) Ldepth_frontend: 4.6061 (7.6826) Lpmap_frontend: 9.8494 (12.7050) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2321 (3.4708) Ldepth_mix: 4.6072 (7.6805) Lpmap_mix: 9.8298 (12.6988) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2501 (3.4755) Ldepth_backend: 4.6085 (7.6784) Lpmap_backend: 9.8256 (12.6981) Ltrack_backend: 0.0000 (0.0000) total: 4106.1992 (4414.7409) time: 50.8490 data: 0.0448 max mem: 78608 +[2026-05-02 08:39:35,290][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 510/2175] eta: 1 day, 0:19:53 lr: 0.000005 epoch: 0.2299 (0.1172) step: 500.0000 (254.9237) loss: 3933.6536 (4417.0990) Lcamera_frontend: 3.1381 (3.4795) Ldepth_frontend: 4.8565 (7.6292) Lpmap_frontend: 9.5637 (12.6428) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1143 (3.4737) Ldepth_mix: 4.8704 (7.6272) Lpmap_mix: 9.5520 (12.6364) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1420 (3.4786) Ldepth_backend: 4.8826 (7.6252) Lpmap_backend: 9.5510 (12.6356) Ltrack_backend: 0.0000 (0.0000) total: 3933.6536 (4417.0990) time: 52.4931 data: 0.0484 max mem: 78608 +[2026-05-02 08:48:26,710][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 520/2175] eta: 1 day, 0:11:23 lr: 0.000005 epoch: 0.2345 (0.1195) step: 510.0000 (259.9251) loss: 4405.1108 (4418.4188) Lcamera_frontend: 3.5306 (3.4815) Ldepth_frontend: 4.9626 (7.5859) Lpmap_frontend: 9.6513 (12.5959) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5132 (3.4755) Ldepth_mix: 4.9785 (7.5840) Lpmap_mix: 9.6373 (12.5894) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5304 (3.4807) Ldepth_backend: 4.9895 (7.5821) Lpmap_backend: 9.6339 (12.5885) Ltrack_backend: 0.0000 (0.0000) total: 4405.1108 (4418.4188) time: 52.9814 data: 0.0400 max mem: 78608 +[2026-05-02 08:57:30,760][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 530/2175] eta: 1 day, 0:03:33 lr: 0.000005 epoch: 0.2391 (0.1218) step: 520.0000 (264.9266) loss: 4418.3706 (4418.6459) Lcamera_frontend: 3.5406 (3.4826) Ldepth_frontend: 4.6856 (7.5320) Lpmap_frontend: 10.3263 (12.5562) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5132 (3.4765) Ldepth_mix: 4.6748 (7.5302) Lpmap_mix: 10.3062 (12.5494) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5408 (3.4818) Ldepth_backend: 4.6644 (7.5283) Lpmap_backend: 10.3058 (12.5485) Ltrack_backend: 0.0000 (0.0000) total: 4418.3706 (4418.6459) time: 53.7734 data: 0.0369 max mem: 78608 +[2026-05-02 09:06:20,889][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 540/2175] eta: 23:54:57 lr: 0.000005 epoch: 0.2437 (0.1241) step: 530.0000 (269.9279) loss: 3846.4263 (4401.2905) Lcamera_frontend: 3.0784 (3.4690) Ldepth_frontend: 4.6856 (7.5025) Lpmap_frontend: 10.2784 (12.5089) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0581 (3.4628) Ldepth_mix: 4.6748 (7.5010) Lpmap_mix: 10.2581 (12.5019) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0781 (3.4681) Ldepth_backend: 4.6644 (7.4993) Lpmap_backend: 10.2489 (12.5010) Ltrack_backend: 0.0000 (0.0000) total: 3846.4263 (4401.2905) time: 53.7088 data: 0.0370 max mem: 78608 +[2026-05-02 09:14:58,674][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 550/2175] eta: 23:45:44 lr: 0.000005 epoch: 0.2483 (0.1264) step: 540.0000 (274.9292) loss: 4376.6650 (4425.2595) Lcamera_frontend: 3.4907 (3.4897) Ldepth_frontend: 5.0376 (7.4599) Lpmap_frontend: 10.0351 (12.4690) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4843 (3.4835) Ldepth_mix: 5.0383 (7.4585) Lpmap_mix: 10.0193 (12.4619) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4932 (3.4889) Ldepth_backend: 5.0365 (7.4568) Lpmap_backend: 10.0167 (12.4609) Ltrack_backend: 0.0000 (0.0000) total: 4376.6650 (4425.2595) time: 52.3955 data: 0.0353 max mem: 78608 +[2026-05-02 09:23:45,990][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 560/2175] eta: 23:37:00 lr: 0.000005 epoch: 0.2529 (0.1287) step: 550.0000 (279.9305) loss: 4772.4141 (4448.5945) Lcamera_frontend: 3.8427 (3.5101) Ldepth_frontend: 4.4398 (7.4164) Lpmap_frontend: 9.9974 (12.4268) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.8256 (3.5037) Ldepth_mix: 4.4245 (7.4150) Lpmap_mix: 9.9801 (12.4196) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8420 (3.5092) Ldepth_backend: 4.4170 (7.4134) Lpmap_backend: 9.9885 (12.4186) Ltrack_backend: 0.0000 (0.0000) total: 4772.4141 (4448.5945) time: 52.2548 data: 0.0374 max mem: 78608