diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..537b6a62deafef5d00d5c297eeee031b676956bf
--- /dev/null
+++ b/app.py
@@ -0,0 +1,102 @@
+from __future__ import absolute_import, division, print_function
+
+import os, sys
+import cv2
+import yaml
+import torch
+import numpy as np
+import torch.nn as nn
+import gradio as gr
+from huggingface_hub import hf_hub_download
+
+# ========== Make the project importable inside the Space ==========
+PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))  # app.py sits at the repo root
+sys.path.append(PROJECT_ROOT)
+
+from networks.models import make
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ====== HF weights repo (already uploaded) ======
+WEIGHTS_REPO = "Insta360-Research/DAP-weights"
+WEIGHTS_FILE = "model.pth"
+
+# ========== Visualization ==========
+def colorize_depth(depth, colormap=cv2.COLORMAP_JET):
+    depth = depth.astype(np.float32)
+    depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
+    depth_u8 = (depth_norm * 255).astype(np.uint8)
+    return cv2.applyColorMap(depth_u8, colormap)  # BGR
+
+# ========== Load the model (only once) ==========
+def load_model(config_path: str):
+    with open(config_path, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
+    model_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILE)
+    print(f"✅ Weights downloaded to: {model_path}")
+
+    state = torch.load(model_path, map_location=device)
+
+    model = make(config["model"])
+    if any(k.startswith("module") for k in state.keys()):
+        model = nn.DataParallel(model)
+
+    model = model.to(device)
+
+    model_state = model.state_dict()
+    model.load_state_dict({k: v for k, v in state.items() if k in model_state}, strict=False)
+    model.eval()
+    print("✅ Model loaded.")
+    return model
+
+# Change this to the config path inside your repo
+CONFIG_PATH = "config/infer.yaml"
+model = load_model(CONFIG_PATH)
+
+# ========== Single-image inference ==========
+@torch.no_grad()
+def predict(img_rgb: np.ndarray):
+    """
+    img_rgb: H x W x 3 (RGB), uint8
+    return: depth_color_rgb, depth_gray
+    """
+    if img_rgb is None:
+        return None, None
+
+    img = img_rgb.astype(np.float32) / 255.0
+    tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
+
+    outputs = model(tensor)
+
+    if isinstance(outputs, dict) and "pred_depth" in outputs:
+        # original mask logic
+        if "pred_mask" in outputs:
+            outputs["pred_mask"] = 1 - outputs["pred_mask"]
+            outputs["pred_mask"] = (outputs["pred_mask"] > 0.5)
+            outputs["pred_depth"][~outputs["pred_mask"]] = 1
+        pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
+    else:
+        pred = outputs[0].detach().cpu().squeeze().numpy()
+
+    pred_clip = np.clip(pred, 0.001, 1.0)
+    depth_gray = (pred_clip * 255).astype(np.uint8)
+
+    depth_color_bgr = colorize_depth(pred_clip, cv2.COLORMAP_JET)
+    depth_color_rgb = cv2.cvtColor(depth_color_bgr, cv2.COLOR_BGR2RGB)
+
+    return depth_color_rgb, depth_gray
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(type="numpy", label="Input Image"),
+    outputs=[
+        gr.Image(type="numpy", label="Depth (Color)"),
+        gr.Image(type="numpy", label="Depth (Gray)"),
+    ],
+    title="DAP Depth Prediction Demo",
+    description="Upload an image and get depth prediction.",
+)
+
+demo.launch()
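Once the Space is running, the demo can also be queried programmatically. A minimal sketch with `gradio_client` (the Space id below is a placeholder, `handle_file` assumes a recent gradio_client, and `/predict` is the default endpoint name for a single-function `gr.Interface`):

```python
# Sketch: call the deployed Gradio demo remotely (the Space id is hypothetical).
from gradio_client import Client, handle_file

client = Client("Insta360-Research/DAP-demo")  # placeholder Space id
depth_color, depth_gray = client.predict(
    handle_file("example.jpg"),  # any local RGB image
    api_name="/predict",
)
print(depth_color, depth_gray)  # local paths to the two rendered depth maps
```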
diff --git a/config/infer.yaml b/config/infer.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..ae2b5026913e54ccf935a96aa69bdea02f413d3f
--- /dev/null
+++ b/config/infer.yaml
@@ -0,0 +1,19 @@
+model:
+  name: dap
+  args:
+    midas_model_type: vitl
+    fine_tune_type: hypersim
+    min_depth: 0.01
+    max_depth: 1.0
+    train_decoder: True
+
+median_align: False
+load_weights_dir: checkpoints
+input:
+  height: 512
+  width: 1024
+inference:
+  batch_size: 1
+  num_workers: 1
+  save_colormap: True
+  colormap_type: jet
diff --git a/depth_anything_utils.py b/depth_anything_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..c1911789cd7ab307d9bed14217811a90d795e614
--- /dev/null
+++ b/depth_anything_utils.py
@@ -0,0 +1,249 @@
+import os
+import random
+from PIL import Image, ImageOps, ImageFilter
+import torch
+from torchvision import transforms
+import torch.nn.functional as F
+
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
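`apply_min_size` only ever upscales, multiplying both axes by the larger of the two ratios so the result covers the requested size; note it returns the new shape when it resizes but the sample dict when nothing needs to change. A quick sketch of that contract:

```python
import numpy as np
from depth_anything_utils import apply_min_size

sample = {
    "image": np.random.rand(100, 200, 3).astype(np.float32),
    "disparity": np.random.rand(100, 200).astype(np.float32),
    "mask": np.ones((100, 200), dtype=bool),
}
# scale = max(256/100, 256/200) = 2.56 -> new shape (256, 512); sample resized in place
new_shape = apply_min_size(sample, (256, 256))
print(new_shape, sample["image"].shape)  # (256, 512) (256, 512, 3)
```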
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + if "semseg_mask" in sample: + # sample["semseg_mask"] = cv2.resize( + # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST + # ) + sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0] + + if "mask" in sample: + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + # sample["mask"] = sample["mask"].astype(bool) + + # print(sample['image'].shape, sample['depth'].shape) + return sample + + +class NormalizeImage(object): + """Normlize 
+
+
+class NormalizeImage(object):
+    """Normalize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        if "semseg_mask" in sample:
+            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
+            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
+
+        return sample
\ No newline at end of file
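These transforms are meant to be chained. A sketch of the usual Depth-Anything-style preprocessing pipeline (the ImageNet mean/std values are an assumption carried over from the upstream repo):

```python
import numpy as np
from torchvision.transforms import Compose
from depth_anything_utils import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method="lower_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
img = np.random.rand(480, 640, 3).astype(np.float32)  # RGB scaled to [0, 1]
out = transform({"image": img})["image"]
print(out.shape, out.dtype)  # (3, 518, 686) float32, ready for torch.from_numpy
```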
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2.py b/depth_anything_v2_metric/depth_anything_v2/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4499a18330523aa3564b16be70e813de000c94
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov2.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+    if not depth_first and include_root:
+        fn(module=module, name=name)
+    for child_name, child_module in module.named_children():
+        child_name = ".".join((name, child_name)) if name else child_name
+        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+    if depth_first and include_root:
+        fn(module=module, name=name)
+    return module
+
+
+class BlockChunk(nn.ModuleList):
+    def forward(self, x):
+        for b in self:
+            x = b(x)
+        return x
+
+
+class DinoVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=None,  # for layerscale: None or 0 => no layerscale
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=Block,
+        ffn_layer="mlp",
+        block_chunks=1,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+        """
+        super().__init__()
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 1
+        self.n_blocks = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+        self.num_register_tokens = num_register_tokens
+        self.interpolate_antialias = interpolate_antialias
+        self.interpolate_offset = interpolate_offset
+
+        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token =
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + 
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for block_chunk in self.blocks:
+            for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                x = blk(x)
+                if i in blocks_to_take:
+                    output.append(x)
+                i += 1
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm=True
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        if self.chunked_blocks:
+            outputs = self._get_intermediate_layers_chunked(x, n)
+        else:
+            outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        if reshape:
+            B, _, w, h = x.shape
+            outputs = [
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        if is_training:
+            return ret
+        else:
+            return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=0.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+    """
+    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+    """
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1536,
+        depth=40,
+        num_heads=24,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def DINOv2(model_name):
+    model_zoo = {
+        "vits": vit_small,
+        "vitb": vit_base,
+        "vitl": vit_large,
+        "vitg": vit_giant2
+    }
+
+    return model_zoo[model_name](
+        img_size=518,
+        patch_size=14,
+        init_values=1.0,
+        ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+        block_chunks=0,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1
+    )
\ No newline at end of file
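For orientation, a sketch of driving the `DINOv2` factory directly (weights are randomly initialized here, the real checkpoint arrives via the DPT wrapper; the import path assumes the repo root is on `PYTHONPATH`):

```python
import torch
from depth_anything_v2_metric.depth_anything_v2.dinov2 import DINOv2

backbone = DINOv2("vits")  # ViT-S/14: embed_dim=384, 12 blocks
x = torch.randn(1, 3, 518, 518)  # H and W must be multiples of patch_size=14
with torch.no_grad():
    feats = backbone.get_intermediate_layers(x, n=4, reshape=True, return_class_token=True)
for patch_tokens, cls_token in feats:
    print(patch_tokens.shape, cls_token.shape)  # (1, 384, 37, 37) and (1, 384)
```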
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/attention.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import memory_efficient_attention, unbind, fmha
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+        if not XFORMERS_AVAILABLE:
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = x.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
\ No newline at end of file
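`MemEffAttention` keeps the plain `Attention` input/output contract and silently falls back to the naive path when xFormers is unavailable. A quick sketch (dims follow the ViT-S/14 configuration):

```python
import torch
from depth_anything_v2_metric.depth_anything_v2.dinov2_layers import MemEffAttention

attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
tokens = torch.randn(2, 1370, 384)  # cls token + 37*37 patch tokens
out = attn(tokens)  # uses memory_efficient_attention only if xFormers is installed
print(out.shape)    # torch.Size([2, 1370, 384])
```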
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/block.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha
+    from xformers.ops import scaled_index_add, index_select_cat
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: Tensor) -> Tensor:
+        def attn_residual_func(x: Tensor) -> Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: Tensor) -> Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path2(ffn_residual_func(x))  # was drop_path1 upstream, marked FIXME; both paths share one rate
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: Tensor,
+    residual_func: Callable[[Tensor], Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size =
max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, 
attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/drop_path.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/layer_scale.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/mlp.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/patch_embed.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/.docstr.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/.docstr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c822fbc0f2f84d90629a1f70502c7e3dfd5f7f4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/.docstr.yaml @@ -0,0 +1,6 @@ +paths: + - dinov3 +exclude: dinov3/tests +skip_init: True +skip_private: True +fail_under: 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/.github/workflows/lint.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/.github/workflows/lint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..53a5de1f2d5b756354b2e7ed88cb3c0d7442a959 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/.github/workflows/lint.yaml @@ -0,0 +1,47 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + run-linters: + name: Run linters + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.11 + cache: 'pip' + cache-dependency-path: '**/requirements*.txt' + - name: Install Python (development) dependencies + run: | + pip install -r requirements-dev.txt + - name: Run ruff (linter) + run: | + ruff check dinov3 + - name: Run ruff (formatter) + if: always() + run: | + ruff format --diff dinov3 + - name: Report docstring coverage + if: always() + run: | + docstr-coverage dinov3 + - name: Run mypy + if: always() + run: | + mypy --txt-report . 
+ [ -f index.txt ] && cat index.txt + - name: Run pylint + if: always() + run: | + pylint --exit-zero dinov3 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/.gitignore b/depth_anything_v2_metric/depth_anything_v2/dinov3/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bc9665c2e6aae2080bccc03c3e7bf485b2a2c25a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/.gitignore @@ -0,0 +1,18 @@ +build/ +dist/ +*.egg-info/ +**/__pycache__/ + +**/.ipynb_checkpoints +**/.ipynb_checkpoints/** + +**/notebooks + +# Ignore shell scripts +*.sh + +# Ignore swap files +*.swp + +# Ignore vscode directory +.vscode/ diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/CODE_OF_CONDUCT.md b/depth_anything_v2_metric/depth_anything_v2/dinov3/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. 
+ +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/CONTRIBUTING.md b/depth_anything_v2_metric/depth_anything_v2/dinov3/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c5f0c3e75e4921e5aa706898b4f6eaccd8e7d009 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to DINOv3 +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to DINOv3, you agree that your contributions will be licensed +under the LICENSE.md file in the root directory of this source tree. diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/LICENSE.md b/depth_anything_v2_metric/depth_anything_v2/dinov3/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..f531b1e6b5ab2318957bbf8ad1bda9f800a23e17 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/LICENSE.md @@ -0,0 +1,66 @@ +# DINOv3 License + +*Last Updated: August 19, 2025* + +**“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein. 
+ +**“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement. + +**“Documentation”** means the specifications, manuals and documentation accompanying +DINO Materials distributed by Meta. + +**“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +**“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +**“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom. + +**“Trade Controls”** means any of the following: Sanctions and applicable export and import controls. + +By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement. + +## 1. License Rights and Redistribution. + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials. + +b. Redistribution and Use. + +i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials. + +ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication. + +iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws. + +iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials. + +v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons. + +## 2. User Support. 
+ +Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + +## 3. Disclaimer of Warranty. + +UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS. + +## 4. Limitation of Liability. + +IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +## 5. Intellectual Property. + +a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials. + +## 6. Term and Termination. + +The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. + +## 7. Governing Law and Jurisdiction. + +This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + +## 8. Modifications and Amendments. + +Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. 
Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/MODEL_CARD.md b/depth_anything_v2_metric/depth_anything_v2/dinov3/MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..2bd45fb07beab1ad5f7a734608e79d5fd8899b42 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/MODEL_CARD.md @@ -0,0 +1,432 @@ +# Model Card for DINOv3 + +DINOv3 is a family of versatile vision foundation models that outperforms the specialized state of the art across a broad range of settings, without fine-tuning. DINOv3 produces high-quality dense features that achieve outstanding performance on various vision tasks, significantly surpassing previous self- and weakly-supervised foundation models. + +## Model Details + +These are Vision Transformer and ConvNeXt models trained following the method described in the DINOv3 paper. 12 models are provided: + +- 10 models pretrained on web data (LVD-1689M dataset) + - 1 ViT-7B trained from scratch, + - 5 ViT-S/S+/B/L/H+ models distilled from the ViT-7B, + - 4 ConvNeXt-{T/S/B/L} models distilled from the ViT-7B, +- 2 models pretrained on satellite data (SAT-493M dataset) + - 1 ViT-7B trained from scratch + - 1 ViT-L distilled from the ViT-7B + + +Each Transformer-based model takes an image as input and returns a class token, patch tokens (and register tokens). These models follow a ViT architecture, with a patch size of 16. For a 224x224 image, this results in 1 class token + 4 register tokens + 196 patch tokens = 201 tokens (for DINOv2 with registers this resulted in 1 + 4 + 256 = 261 tokens). + +The models can accept larger images provided the image shapes are multiples of the patch size (16). If this condition is not verified, the model will crop to the closest smaller multiple of the patch size. + +### Model Description + +- **Developed by:** Meta AI +- **Model type:** Vision Transformer, ConvNeXt +- **License:** [DINOv3 License](https://ai.meta.com/resources/models-and-libraries/dinov3-license/) + +### Model Sources + +- **Repository:** [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3) +- **Paper:** [https://arxiv.org/abs/2508.10104](https://arxiv.org/abs/2508.10104) + +## Uses + +The models are vision backbones providing multi-purpose features for downstream tasks. 
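+
+As a quick sanity check of the token arithmetic described under Model Details above (the helper below is ours, not part of the DINOv3 API):
+
+```python
+def dinov3_token_count(height: int, width: int, patch_size: int = 16, num_registers: int = 4) -> int:
+    # Sides that are not multiples of the patch size are cropped down
+    # to the closest smaller multiple before patch embedding.
+    grid_h, grid_w = height // patch_size, width // patch_size
+    return 1 + num_registers + grid_h * grid_w  # class token + register tokens + patch tokens
+
+assert dinov3_token_count(224, 224) == 201  # 1 + 4 + 196, as stated above
+```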
+
+### Direct Use
+
+The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
+
+- on image classification, using k-NN classifiers on the class token
+- on image classification, with logistic regression classifiers applied on the class token
+- on image classification, with a linear layer applied on the class token and the average of the patch tokens
+- on image retrieval using nearest neighbors
+- on geometric and semantic 3D keypoint correspondences
+- on depth estimation and semantic segmentation, using linear layers
+- on unsupervised object discovery
+- on video segmentation tracking
+- on video classification, using a small 4-layer attentive probe
+
+### Downstream Use
+
+While fine-tuning the models can yield some gains, it is recommended to keep this option as a last resort: the frozen features are expected to provide good performance out-of-the-box.
+
+## Bias, Risks, and Limitations
+
+Compared to DINOv2 and SEERv2, DINOv3 delivers somewhat consistent performance across income categories on geographical fairness and diversity, although with a notable performance drop in the low-income bucket compared to the highest-income bucket.
+
+DINOv3 also achieves relatively good scores across different regions, improving over its predecessor DINOv2. However, a relative difference is still observed between Europe and Africa.
+
+### Recommendations
+
+Fine-tuning is expected to increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+```python
+import torch
+
+model = torch.hub.load(
+    repo_or_dir='facebookresearch/dinov3',
+    model='<MODEL_NAME>',
+    weights='<CHECKPOINT_URL_OR_PATH>',
+)
+
+# where MODEL_NAME can be one of:
+# - dinov3_vits16
+# - dinov3_vits16plus
+# - dinov3_vitb16
+# - dinov3_vitl16
+# - dinov3_vith16plus
+# - dinov3_vit7b16
+# - dinov3_convnext_tiny
+# - dinov3_convnext_small
+# - dinov3_convnext_base
+# - dinov3_convnext_large
+
+# For instance
+dinov3_vits16 = torch.hub.load(
+    repo_or_dir='facebookresearch/dinov3',
+    model='dinov3_vits16',
+    weights='<CHECKPOINT_URL_OR_PATH>',
+)
+```
+
+## Training Details
+
+### Training Data
+
+- Web dataset (LVD-1689M): a curated dataset of 1,689 million images extracted from a large data pool of 17 billion web images collected from public posts on Instagram
+
+- Satellite dataset (SAT-493M): a dataset of 493 million 512x512 images sampled randomly from Maxar RGB ortho-rectified imagery at 0.6 meter resolution
+
+### Training Procedure
+
+**Training objective:**
+
+- DINO self-distillation loss with multi-crop
+- iBOT masked-image modeling loss
+- KoLeo regularization on [CLS] tokens
+- Gram anchoring
+
+- **Training regime:** PyTorch FSDP2 (with bf16 and fp8 matrix multiplications)
+
+**Distillation:**
+
+- Distillation follows the standard DINOv3 pretraining procedure, except the teacher is a frozen pretrained ViT-7B.
+
+## Evaluation
+
+**Results**
+
+The reader is referred to the associated paper for details on the evaluation protocols.
+
+*Results for ViT backbones pretrained (or distilled) on web (LVD-1689M)*
+
+| Model | IN-ReaL | IN-R | Obj.Net | Ox.-H | ADE20k | NYU↓ | DAVIS | NAVI | SPair |
+|---|---|---|---|---|---|---|---|---|---|
+| DINOv3 ViT-S/16 | 87.0 | 60.4 | 50.9 | 49.5 | 47.0 | 0.403 | 72.7 | 56.3 | 50.4 |
+| DINOv3 ViT-S+/16 | 88.0 | 68.8 | 54.6 | 50.0 | 48.8 | 0.399 | 75.5 | 57.1 | 55.2 |
+| DINOv3 ViT-B/16 | 89.3 | 76.7 | 64.1 | 58.5 | 51.8 | 0.373 | 77.2 | 58.8 | 57.2 |
+| DINOv3 ViT-L/16 | 90.2 | 88.1 | 74.8 | 63.1 | 54.9 | 0.352 | 79.9 | 62.3 | 61.3 |
+| DINOv3 ViT-H+/16 | 90.3 | 90.0 | 78.6 | 64.5 | 54.8 | 0.352 | 79.3 | 63.3 | 56.3 |
+| DINOv3 ViT-7B/16 | 90.4 | 91.1 | 91.1 | 72.8 | 55.9 | 0.309 | 79.7 | 64.4 | 58.7 |
+
+*(Global tasks: IN-ReaL, IN-R, Obj.Net, Ox.-H; dense tasks: ADE20k, NYU↓, DAVIS, NAVI, SPair.)*
+
+*Results for ConvNeXt backbones distilled on web (LVD-1689M)*
+
+| Model | IN-ReaL @256px | IN-ReaL @512px | IN-R @256px | IN-R @512px | Obj.Net @256px | Obj.Net @512px | ADE20k | NYU↓ |
+|---|---|---|---|---|---|---|---|---|
+| DINOv3 ConvNeXt Tiny | 86.6 | 87.7 | 73.7 | 74.1 | 52.6 | 58.7 | 42.7 | 0.448 |
+| DINOv3 ConvNeXt Small | 87.9 | 88.7 | 73.7 | 74.1 | 52.6 | 58.7 | 44.8 | 0.432 |
+| DINOv3 ConvNeXt Base | 88.5 | 89.2 | 77.2 | 78.2 | 56.2 | 61.3 | 46.3 | 0.420 |
+| DINOv3 ConvNeXt Large | 88.9 | 89.4 | 81.3 | 82.4 | 59.3 | 65.2 | 47.8 | 0.403 |
+
+*(Global tasks: IN-ReaL, IN-R and Obj.Net, each reported at 256px and 512px input resolution; dense tasks: ADE20k and NYU↓.)*
+
+*Results for ViT backbones pretrained (or distilled) on satellite (SAT-493M)*
+
+**(GEO-Bench) Classification**
+
+| Model | m-BEnet | m-brick-kiln | m-eurosat | m-forestnet | m-pv4ger | m-so2sat | mean |
+|---|---|---|---|---|---|---|---|
+| DINOv3 ViT-L/16 | 73.0 | 96.5 | 94.1 | 60.6 | 96.0 | 57.4 | 79.6 |
+| DINOv3 ViT-7B/16 | 74.0 | 97.2 | 94.8 | 62.3 | 96.1 | 62.1 | 81.1 |
+
+**(GEO-Bench) Segmentation**
+
+| Model | m-cashew | m-chesapeake | m-NeonTree | m-nz-cattle | m-pv4ger-seg | m-SA-crop | mean |
+|---|---|---|---|---|---|---|---|
+| DINOv3 ViT-L/16 | 94.2 | 75.6 | 61.8 | 83.7 | 95.2 | 36.8 | 74.5 |
+| DINOv3 ViT-7B/16 | 94.1 | 76.6 | 62.6 | 83.4 | 95.5 | 37.6 | 75.0 |
+ + +## Environmental Impact + +- **Hardware Type:** Nvidia H100 +- **Hours used:** 61,440 hours for ViT-7B model training +- **Cloud Provider:** Private infrastructure +- **Compute Region:** USA +- **Carbon Emitted:** 18t CO2eq + +## Technical Specifications + +### Model Architecture and Objective + +Vision Transformer models: + +- ViT-S (21M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, MLP FFN, RoPE +- ViT-S+ (29M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, SwiGLU FFN, RoPE +- ViT-B (86M parameters): patch size 16, embedding dimension 768, 4 register tokens, 12 heads, MLP FFN, RoPE +- ViT-L (300M parameters): patch size 16, embedding dimension 1024, 4 register tokens, 16 heads, MLP FFN, RoPE +- ViT-H+ (840M parameters): patch size 16, embedding dimension 1280, 4 register tokens, 20 heads, SwiGLU FFN, RoPE +- ViT-7B (6716M parameters): patch size 16, embedding dimension 4096, 4 register tokens, 32 heads, SwiGLU FFN, RoPE + +ConvNeXt models: + +- ConvNeXt Tiny (29M parameters) +- ConvNeXt Small (50M parameters) +- ConvNeXt Base (89M parameters) +- ConvNeXt Large (198M parameters) + +### Compute Infrastructure + +#### Hardware + +Nvidia H100 GPUs + +#### Software + +PyTorch 2.7 + +## More Information + +See the [blog post](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/) and the associated [website](https://ai.meta.com/dinov3/). + +## Citation + +**BibTeX** + +``` +@misc{simeoni2025dinov3, + title={{DINOv3}}, + author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr}, + year={2025}, + eprint={2508.10104}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2508.10104}, +} +``` diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/README.md b/depth_anything_v2_metric/depth_anything_v2/dinov3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..18b1a35ea4d1d20f9ef6fb795b71864558276395 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/README.md @@ -0,0 +1,734 @@ +🆕 [2025-08-14] :fire: DINOv3 backbones are now available in [Hugging Face Hub](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) and [supported](https://huggingface.co/docs/transformers/model_doc/dinov3) by the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library + +# DINOv3 🦖🦖🦖 + +**[Meta AI Research, FAIR](https://ai.meta.com/research/)** + +Oriane Siméoni, Huy V. Vo, Maximilian Seitzer, Federico Baldassarre, Maxime Oquab,
+Cijo Jose, Vasil Khalidov, Marc Szafraniec, Seungeun Yi, Michaël Ramamonjisoa,
+Francisco Massa, Daniel Haziza, Luca Wehrstedt, Jianyuan Wang,
+Timothée Darcet, Théo Moutakanni, Leonel Sentana, Claire Roberts,
+Andrea Vedaldi, Jamie Tolan, John Brandt, Camille Couprie,
+Julien Mairal, Hervé Jégou, Patrick Labatut, Piotr Bojanowski + +[ :scroll: [`Paper`](https://arxiv.org/abs/2508.10104)] [ :newspaper: [`Blog`](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/)] [ :globe_with_meridians: [`Website`](https://ai.meta.com/dinov3/)] [ :book: [`BibTeX`](#citing-dinov3)] + +Reference PyTorch implementation and models for DINOv3. For details, see the **[DINOv3](https://arxiv.org/abs/2508.10104)** paper. + +## Overview + +
+*Figure: High-resolution dense features. We visualize the cosine similarity maps obtained with DINOv3 output features between the patches marked with a red cross and all other patches.*
+
+
+An extended family of versatile vision foundation models that produce high-quality dense features and achieve outstanding performance on various vision tasks, outperforming the specialized state of the art across a broad range of settings without fine-tuning.
+
+## Pretrained models
+
+:information_source: Please follow the link provided below to get access to all the model weights: once accepted, an e-mail will be sent with the complete list of URLs pointing to all the available model weights (both backbones and adapters). These URLs can then be used to either:
+- download the model or adapter weights to a local filesystem and point `torch.hub.load()` to these local weights via the `weights` or `backbone_weights` parameters, or
+- directly invoke `torch.hub.load()` to download and load a backbone or an adapter from its URL, also via the `weights` or `backbone_weights` parameters.
+
+See the example code snippets below.
+
+:warning: Please use `wget` instead of a web browser to download the weights.
+
+ViT models pretrained on web dataset (LVD-1689M):
+
+| Model | Parameters | Pretraining Dataset | Download |
+|---|---|---|---|
+| ViT-S/16 distilled | 21M | LVD-1689M | [link] |
+| ViT-S+/16 distilled | 29M | LVD-1689M | [link] |
+| ViT-B/16 distilled | 86M | LVD-1689M | [link] |
+| ViT-L/16 distilled | 300M | LVD-1689M | [link] |
+| ViT-H+/16 distilled | 840M | LVD-1689M | [link] |
+| ViT-7B/16 | 6,716M | LVD-1689M | [link] |
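+
+For instance, once the e-mail with the checkpoint URLs has been received, a backbone can be loaded either directly from its URL or from a local copy downloaded with `wget` (the URL below is a placeholder, not a real link):
+
+```python
+import torch
+
+# Placeholder standing in for one of the URLs received by e-mail (or a local path)
+WEIGHTS = "https://<link-from-email>/dinov3_vitb16_pretrain_lvd1689m.pth"
+
+dinov3_vitb16 = torch.hub.load(
+    repo_or_dir="facebookresearch/dinov3",
+    model="dinov3_vitb16",
+    weights=WEIGHTS,
+)
+```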
+
+ConvNeXt models pretrained on web dataset (LVD-1689M):
+
+| Model | Parameters | Pretraining Dataset | Download |
+|---|---|---|---|
+| ConvNeXt Tiny | 29M | LVD-1689M | [link] |
+| ConvNeXt Small | 50M | LVD-1689M | [link] |
+| ConvNeXt Base | 89M | LVD-1689M | [link] |
+| ConvNeXt Large | 198M | LVD-1689M | [link] |
+
+ViT models pretrained on satellite dataset (SAT-493M):
+
+| Model | Parameters | Pretraining Dataset | Download |
+|---|---|---|---|
+| ViT-L/16 distilled | 300M | SAT-493M | [link] |
+| ViT-7B/16 | 6,716M | SAT-493M | [link] |
+
+
+### Pretrained backbones (via PyTorch [Hub](https://docs.pytorch.org/docs/stable/hub.html))
+
+Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
+
+```python
+import torch
+
+REPO_DIR = <PATH/TO/A/LOCAL/DINOV3/CLONE>
+
+# DINOv3 ViT models pretrained on web images
+dinov3_vits16 = torch.hub.load(REPO_DIR, 'dinov3_vits16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vits16plus = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vitb16 = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vith16plus = torch.hub.load(REPO_DIR, 'dinov3_vith16plus', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+
+# DINOv3 ConvNeXt models pretrained on web images
+dinov3_convnext_tiny = torch.hub.load(REPO_DIR, 'dinov3_convnext_tiny', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_convnext_small = torch.hub.load(REPO_DIR, 'dinov3_convnext_small', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_convnext_base = torch.hub.load(REPO_DIR, 'dinov3_convnext_base', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_convnext_large = torch.hub.load(REPO_DIR, 'dinov3_convnext_large', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+
+# DINOv3 ViT models pretrained on satellite imagery
+dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+dinov3_vit7b16 = torch.hub.load(REPO_DIR, 'dinov3_vit7b16', source='local', weights=<CHECKPOINT/URL/OR/PATH>)
+```
+
+### Pretrained backbones (via Hugging Face [Transformers](https://huggingface.co/docs/transformers/))
+
+All the backbones are available in the [DINOv3](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009) collection on Hugging Face Hub and supported via the Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) library. Please refer to the corresponding documentation for usage, but below are short examples that demonstrate how to obtain an image embedding with either [Pipeline] or the [AutoModel] class.
+
+```python
+from transformers import pipeline
+from transformers.image_utils import load_image
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+image = load_image(url)
+
+feature_extractor = pipeline(
+    model="facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
+    task="image-feature-extraction",
+)
+features = feature_extractor(image)
+```
+
+```python
+import torch
+from transformers import AutoImageProcessor, AutoModel
+from transformers.image_utils import load_image
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = load_image(url)
+
+pretrained_model_name = "facebook/dinov3-convnext-tiny-pretrain-lvd1689m"
+processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
+model = AutoModel.from_pretrained(
+    pretrained_model_name,
+    device_map="auto",
+)
+
+inputs = processor(images=image, return_tensors="pt").to(model.device)
+with torch.inference_mode():
+    outputs = model(**inputs)
+
+pooled_output = outputs.pooler_output
+print("Pooled output shape:", pooled_output.shape)
+```
+
+where `model` and `pretrained_model_name` above can be one of:
+- `facebook/dinov3-vits16-pretrain-lvd1689m`
+- `facebook/dinov3-vits16plus-pretrain-lvd1689m`
+- `facebook/dinov3-vitb16-pretrain-lvd1689m`
+- `facebook/dinov3-vitl16-pretrain-lvd1689m`
+- `facebook/dinov3-vith16plus-pretrain-lvd1689m`
+- `facebook/dinov3-vit7b16-pretrain-lvd1689m`
+- `facebook/dinov3-convnext-base-pretrain-lvd1689m`
+- `facebook/dinov3-convnext-large-pretrain-lvd1689m`
+- `facebook/dinov3-convnext-small-pretrain-lvd1689m`
+- `facebook/dinov3-convnext-tiny-pretrain-lvd1689m`
+- `facebook/dinov3-vitl16-pretrain-sat493m`
+- `facebook/dinov3-vit7b16-pretrain-sat493m`
+
+### Image transforms
+
+For models using the LVD-1689M weights (pretrained on web images), please use the following transform (standard ImageNet evaluation transform):
+
+```python
+from torchvision import transforms
+
+def make_transform(resize_size: int = 224):
+    to_tensor = transforms.ToTensor()
+    resize = transforms.Resize((resize_size, resize_size), antialias=True)
+    normalize = transforms.Normalize(
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+    )
+    return transforms.Compose([to_tensor, resize, normalize])
+```
+
+For models using the SAT-493M weights (pretrained on satellite imagery), please use the following transform:
+
+```python
+from torchvision import transforms
+
+def make_transform(resize_size: int = 224):
+    to_tensor = transforms.ToTensor()
+    resize = transforms.Resize((resize_size, resize_size), antialias=True)
+    normalize = transforms.Normalize(
+        mean=(0.430, 0.411, 0.296),
+        std=(0.213, 0.156, 0.143),
+    )
+    return transforms.Compose([to_tensor, resize, normalize])
+```
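+
+A minimal sketch of how these transforms combine with a backbone loaded earlier via `torch.hub` (the image path is a placeholder, and the exact structure of the backbone output is not documented here):
+
+```python
+import torch
+from PIL import Image
+
+img = Image.open("<PATH/TO/IMAGE>").convert("RGB")  # placeholder path
+x = make_transform(224)(img).unsqueeze(0)           # 1 x 3 x 224 x 224 batch
+
+with torch.inference_mode():
+    features = dinov3_vits16(x)  # backbone loaded in the PyTorch Hub section above
+```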
+
+### Pretrained heads - Image classification
+
+| Backbone | Pretraining Dataset | Head Dataset | Download |
+|---|---|---|---|
+| ViT-7B/16 | LVD-1689M | ImageNet | [link] |
+
+
+The (full) classifier models can be loaded via PyTorch Hub:
+
+```python
+import torch
+
+# DINOv3
+dinov3_vit7b16_lc = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_lc', source="local", weights=<CLASSIFIER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+```
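+
+As a rough end-to-end sketch (assuming the classifier maps an ImageNet-normalized batch to ImageNet-1k logits; `make_transform` is from the Image transforms section above):
+
+```python
+import torch
+from PIL import Image
+
+img = Image.open("<PATH/TO/IMAGE>").convert("RGB")  # placeholder path
+x = make_transform(224)(img).unsqueeze(0)
+
+with torch.inference_mode():
+    logits = dinov3_vit7b16_lc(x)    # assumed shape: 1 x 1000
+print(logits.argmax(dim=-1).item())  # predicted ImageNet-1k class index
+```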
+
+### Pretrained heads - Depther trained on SYNTHMIX dataset
+
+| Backbone | Pretraining Dataset | Head Dataset | Download |
+|---|---|---|---|
+| ViT-7B/16 | LVD-1689M | SYNTHMIX | [link] |
+
+
+```python
+depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+```
+
+Full example code of the depther on an image:
+
+```python
+from PIL import Image
+import torch
+from torchvision import transforms
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
+
+def get_img():
+    import requests
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+    return image
+
+def make_transform(resize_size: int | list[int] = 768):
+    to_tensor = transforms.ToTensor()
+    resize = transforms.Resize((resize_size, resize_size), antialias=True)
+    normalize = transforms.Normalize(
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+    )
+    return transforms.Compose([to_tensor, resize, normalize])
+
+depther = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_dd', source="local", weights=<DEPTHER/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+
+img_size = 1024
+img = get_img()
+transform = make_transform(img_size)
+with torch.inference_mode():
+    with torch.autocast('cuda', dtype=torch.bfloat16):
+        batch_img = transform(img)[None]
+        depths = depther(batch_img)
+
+plt.figure(figsize=(12, 6))
+plt.subplot(121)
+plt.imshow(img)
+plt.axis("off")
+plt.subplot(122)
+plt.imshow(depths[0, 0].cpu(), cmap=colormaps["Spectral"])
+plt.axis("off")
+```
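+
+To persist the prediction from the example above, the depth map can be normalized and written out as a 16-bit PNG (a generic sketch; whether the depther output is metric or relative depth is not specified here):
+
+```python
+import numpy as np
+from PIL import Image
+
+depth = depths[0, 0].float().cpu().numpy()
+normalized = (depth - depth.min()) / (np.ptp(depth) + 1e-6)  # scale to [0, 1]
+Image.fromarray((normalized * 65535).astype(np.uint16)).save("depth.png")
+```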
+
+### Pretrained heads - Detector trained on COCO2017 dataset
+
+| Backbone | Pretraining Dataset | Head Dataset | Download |
+|---|---|---|---|
+| ViT-7B/16 | LVD-1689M | COCO2017 | [link] |
+
+
+```python
+detector = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_de', source="local", weights=<DETECTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+```
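+
+No full inference example is provided for the detector; as a rough, untested sketch (the pre-processing and the structure of the returned detections are assumptions, with `make_transform` from the Image transforms section above):
+
+```python
+import torch
+from PIL import Image
+
+img = Image.open("<PATH/TO/IMAGE>").convert("RGB")  # placeholder path
+x = make_transform(768)(img).unsqueeze(0)
+
+with torch.inference_mode():
+    detections = detector(x)  # trained on COCO2017; output format not documented here
+```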
+
+### Pretrained heads - Segmentor trained on ADE20K dataset
+
+| Backbone | Pretraining Dataset | Head Dataset | Download |
+|---|---|---|---|
+| ViT-7B/16 | LVD-1689M | ADE20K | [link] |
+
+```python
+segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+```
+
+Full example code of the segmentor on an image:
+
+```python
+import sys
+sys.path.append(REPO_DIR)
+
+from PIL import Image
+import torch
+from torchvision import transforms
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
+from functools import partial
+from dinov3.eval.segmentation.inference import make_inference
+
+
+def get_img():
+    import requests
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+    return image
+
+def make_transform(resize_size: int | list[int] = 768):
+    to_tensor = transforms.ToTensor()
+    resize = transforms.Resize((resize_size, resize_size), antialias=True)
+    normalize = transforms.Normalize(
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+    )
+    return transforms.Compose([to_tensor, resize, normalize])
+
+segmentor = torch.hub.load(REPO_DIR, 'dinov3_vit7b16_ms', source="local", weights=<SEGMENTOR/CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+
+img_size = 896
+img = get_img()
+transform = make_transform(img_size)
+with torch.inference_mode():
+    with torch.autocast('cuda', dtype=torch.bfloat16):
+        batch_img = transform(img)[None]
+        pred_vit7b = segmentor(batch_img)  # raw predictions
+        # actual segmentation map
+        segmentation_map_vit7b = make_inference(
+            batch_img,
+            segmentor,
+            inference_mode="slide",
+            decoder_head_type="m2f",
+            rescale_to=(img.size[-1], img.size[-2]),
+            n_output_channels=150,
+            crop_size=(img_size, img_size),
+            stride=(img_size, img_size),
+            output_activation=partial(torch.nn.functional.softmax, dim=1),
+        ).argmax(dim=1, keepdim=True)
+plt.figure(figsize=(12, 6))
+plt.subplot(121)
+plt.imshow(img)
+plt.axis("off")
+plt.subplot(122)
+plt.imshow(segmentation_map_vit7b[0, 0].cpu(), cmap=colormaps["Spectral"])
+plt.axis("off")
+```
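+
+To inspect the result qualitatively, the predicted class map can be blended over the input image (a generic matplotlib sketch reusing the variables from the example above):
+
+```python
+import matplotlib.pyplot as plt
+
+seg = segmentation_map_vit7b[0, 0].cpu().numpy()
+plt.imshow(img.resize((seg.shape[1], seg.shape[0])))  # PIL resize takes (width, height)
+plt.imshow(seg, cmap="Spectral", alpha=0.5)           # overlay predicted ADE20K classes
+plt.axis("off")
+plt.show()
+```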
+
+### Pretrained heads - Zero-shot tasks with `dino.txt`
+
+| Backbone | Download |
+|---|---|
+| ViT-L/16 distilled | [link], vocabulary, vocabulary license |
+
+The (full) dino.txt model can be loaded via PyTorch Hub:
+
+```python
+import torch
+# DINOv3
+dinov3_vitl16_dinotxt_tet1280d20h24l, tokenizer = torch.hub.load(REPO_DIR, 'dinov3_vitl16_dinotxt_tet1280d20h24l', weights=<CHECKPOINT/URL/OR/PATH>, backbone_weights=<BACKBONE/CHECKPOINT/URL/OR/PATH>)
+```
+
+
+## Installation
+
+The training and evaluation code requires PyTorch version >= 2.7.1 as well as a few other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To set up all the required dependencies for training and evaluation, please follow the instructions below:
+
+*[micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov3` conda environment using the provided environment definition:
+
+```shell
+micromamba env create -f conda.yaml
+micromamba activate dinov3
+```
+
+## Getting started
+
+Several notebooks are provided to get started applying DINOv3:
+- [PCA of patch features](notebooks/pca.ipynb): display the PCA of DINOv3 patch features on a foreground object (rainbow visualizations from the paper) [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/pca.ipynb)
+- [Foreground segmentation](notebooks/foreground_segmentation.ipynb): train a linear foreground segmentation model based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/foreground_segmentation.ipynb)
+- [Dense and sparse matching](notebooks/dense_sparse_matching.ipynb): match patches from objects on two different images based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/dense_sparse_matching.ipynb)
+- [Segmentation tracking](notebooks/segmentation_tracking.ipynb): video segmentation tracking using a non-parametric method based on DINOv3 features [[Run in Google Colab]](https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/notebooks/segmentation_tracking.ipynb)
+
+## Data preparation
+
+### ImageNet-1k
+
+The root directory of the dataset should hold the following contents:
+
+- `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
+- `<ROOT>/test/[..]`
+- `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
+- `<ROOT>/train/n01440764/n01440764_10026.JPEG`
+- `<ROOT>/train/[...]`
+- `<ROOT>/train/n15075141/n15075141_9993.JPEG`
+- `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
+- `<ROOT>/val/[...]`
+- `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
+- `<ROOT>/labels.txt`
+
+The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
+
+- `<EXTRA>/class-ids-TRAIN.npy`
+- `<EXTRA>/class-ids-VAL.npy`
+- `<EXTRA>/class-names-TRAIN.npy`
+- `<EXTRA>/class-names-VAL.npy`
+- `<EXTRA>/entries-TEST.npy`
+- `<EXTRA>/entries-TRAIN.npy`
+- `<EXTRA>/entries-VAL.npy`
+
+These metadata files can be generated (once) with the following lines of Python code:
+
+```python
+from dinov3.data.datasets import ImageNet
+
+for split in ImageNet.Split:
+    dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
+    dataset.dump_extra()
+```
+
+Note that the root and extra directories do not have to be distinct directories.
+
+### ImageNet-22k
+
+Please adapt the [dataset class](dinov3/data/datasets/image_net_22k.py) to match your local setup.
+
+:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov3` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
+
+## Training
+
+### Fast setup: training DINOv3 ViT-L/16 on ImageNet-1k
+
+Run DINOv3 pre-training on 4 H100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
+    --nodes 4 \
+    --config-file dinov3/configs/train/vitl_im1k_lin834.yaml \
+    --output-dir <PATH/TO/OUTPUT/DIR> \
+    train.dataset_path=ImageNet22k:root=<ROOT>:extra=<EXTRA>
+```
+
+Training time is approximately 14 hours, and the resulting checkpoint should reach 82.0% on k-NN eval and 83.5% on linear eval.
+
+The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
+
+### Exact DINOv3 setup: training DINOv3 ViT-7B/16
+
+DINOv3 ViT-7B/16 is trained on a private dataset. The training involves 3 stages:
+- Pretraining
+- Gram anchoring
+- High-resolution adaptation
+
+#### Pretraining
+
+Launch DINOv3 ViT-7B/16 pretraining on 32 nodes (256 GPUs) in a SLURM cluster environment with submitit:
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
+    --nodes 32 \
+    --config-file dinov3/configs/train/dinov3_vit7b16_pretrain.yaml \
+    --output-dir <PATH/TO/OUTPUT/DIR> \
+    train.dataset_path=<DATASET>:root=<ROOT>:extra=<EXTRA>
+```
+
+#### Gram anchoring
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
+    --nodes 32 \
+    --config-file dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml \
+    --output-dir <PATH/TO/OUTPUT/DIR> \
+    train.dataset_path=<DATASET>:root=<ROOT>:extra=<EXTRA> \
+    gram.ckpt=<PATH/TO/PRETRAINING/CHECKPOINT>
+```
+
+#### High-resolution adaptation
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
+    --nodes 32 \
+    --config-file dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml \
+    --output-dir <PATH/TO/OUTPUT/DIR> \
+    train.dataset_path=<DATASET>:root=<ROOT>:extra=<EXTRA> \
+    gram.ckpt=<PATH/TO/GRAM/CHECKPOINT> \
+    student.resume_from_teacher_chkpt=<PATH/TO/TEACHER/CHECKPOINT>
+```
+
+## Multi-distillation
+
+### Test setup
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/train/train.py \
+    --nodes 1 \
+    --config-file dinov3/configs/train/multi_distillation_test.yaml \
+    --output-dir <PATH/TO/OUTPUT/DIR> \
+    --multi-distillation \
+    train.dataset_path=<DATASET>:root=<ROOT>:extra=<EXTRA>
+```
+
+## Evaluation
+
+The training code regularly saves the teacher weights.
+In order to evaluate the model, run the following evaluations on a single node:
+
+### Logistic regression classification on ImageNet-1k
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/log_regression.py \
+    model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
+    model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
+    output_dir=<PATH/TO/EVAL/OUTPUT/DIR> \
+    train.dataset=ImageNet:split=TRAIN:root=<ROOT>:extra=<EXTRA> \
+    eval.test_dataset=ImageNet:split=VAL:root=<ROOT>:extra=<EXTRA>
+```
+
+### k-NN classification on ImageNet-1k
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/knn.py \
+    model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
+    model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
+    output_dir=<PATH/TO/EVAL/OUTPUT/DIR> \
+    train.dataset=ImageNet:split=TRAIN:root=<ROOT>:extra=<EXTRA> \
+    eval.test_dataset=ImageNet:split=VAL:root=<ROOT>:extra=<EXTRA>
+```
+
+### Linear classification with data augmentation on ImageNet-1k
+
+```shell
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/linear.py \
+    model.config_file=<PATH/TO/OUTPUT/DIR>/config.yaml \
+    model.pretrained_weights=<PATH/TO/OUTPUT/DIR>/teacher_checkpoint.pth \
+    output_dir=<PATH/TO/EVAL/OUTPUT/DIR> \
+    train.dataset=ImageNet:split=TRAIN:root=<ROOT>:extra=<EXTRA> \
+    train.val_dataset=ImageNet:split=VAL:root=<ROOT>:extra=<EXTRA>
+```
+
+
+### Text alignment on DINOv3 using dino.txt
+
+Text alignment can be done following the method from `dino.txt`, aka [DINOv2 Meets Text](https://arxiv.org/abs/2412.16334).
+
+```shell
+# An example config for text alignment is provided in dinov3/eval/text/configs/dinov3_vitl_text.yaml
+PYTHONPATH=${PWD} python -m dinov3.run.submit dinov3/eval/text/train_dinotxt.py \
+    --nodes 4 \
+    trainer_config_file="<PATH/TO/TRAINER/CONFIG>" \
+    output-dir=<PATH/TO/OUTPUT/DIR>
+```
+
+Launching the above trains text alignment on 4 nodes with 8 GPUs each (32 GPUs in total).
+Please note that the text alignment model in the DINOv3 paper was trained on a private dataset; here we have given an example config in ```dinov3/eval/text/configs/dinov3_vitl_text.yaml``` using the ```CocoCaptions``` dataset for illustration purposes.
+Please adapt the provided ```CocoCaptions``` dataset class; the dataset can be found [here](https://www.kaggle.com/datasets/nikhil7280/coco-image-caption).
+
+## License
+
+DINOv3 code and model weights are released under the DINOv3 License. See [LICENSE.md](LICENSE.md) for additional details.
+
+## Contributing
+
+See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
+
+## Citing DINOv3
+
+If you find this repository useful, please consider giving a star :star: and citation :t-rex::
+
+```
+@misc{simeoni2025dinov3,
+      title={{DINOv3}},
+      author={Sim{\'e}oni, Oriane and Vo, Huy V.
and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr}, + year={2025}, + eprint={2508.10104}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2508.10104}, +} +``` diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/conda.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/conda.yaml new file mode 100644 index 0000000000000000000000000000000000000000..db4c17be5c918075002b11c48abbef4a9d3043cd --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/conda.yaml @@ -0,0 +1,23 @@ +name: dinov3 +channels: + - defaults + - conda-forge +dependencies: + - python=3.11 + - omegaconf + - pip + - pip: + - ftfy # needed for dino.txt + - iopath + - omegaconf + - pandas + - regex # needed for dino.txt + - pandas + - scikit-learn + - scikit-learn-intelex + - submitit + - termcolor + - torch + - torchvision + - torchmetrics + diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b70f615df55246ea150f0c76e02b0b8b98d16b2b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +__version__ = "0.0.1" diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..120712ccd2ec68848a027ca2a54b662f4e109205 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .checkpointer import ( + CheckpointRetentionPolicy, + cleanup_checkpoint, + find_all_checkpoints, + find_latest_checkpoint, + init_fsdp_model_from_checkpoint, + init_model_from_checkpoint_for_evals, + keep_checkpoint_copy, + keep_last_n_checkpoints, + load_checkpoint, + register_dont_save_hooks, + save_checkpoint, +) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/checkpointer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a984110c68be859f5612be5d290bf27d5a77a55 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/checkpointer/checkpointer.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +""" +Suggested file structure: + +output_dir/ +|-- ckpt/ +| |-- 0/ +| |-- 99/ +| |-- 199/ +| |-- 199_keep/ +| |-- 299/ +| `-- ... 
+`-- eval/
+    `-- 0/
+    `-- 99/
+    `-- ckpt/
+
+Distributed checkpointer docs:
+- https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+- https://pytorch.org/docs/stable/distributed.checkpoint.html
+"""
+
+import logging
+import shutil
+import subprocess
+import tempfile
+from enum import Enum
+from pathlib import Path
+from typing import List, Sequence, Set
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+import torch.distributed.checkpoint.filesystem as dcpfs
+import torch.distributed.checkpoint.state_dict as dcpsd
+from torch.distributed.checkpoint.stateful import Stateful
+
+logger = logging.getLogger("dinov3")
+
+
+class CheckpointRetentionPolicy(Enum):
+    ALL = "all"  # keep all checkpoints
+    BEST = "best"
+    LAST = "last"
+    LAST_AND_BEST = "last_and_best"
+    NONE = "none"  # do not keep any checkpoints
+
+    @property
+    def keep_filters(self) -> Set[str]:
+        """Files that match these patterns are not deleted by cleanup"""
+        if self == CheckpointRetentionPolicy.LAST:
+            return set(["final"])
+        if self == CheckpointRetentionPolicy.BEST:
+            return set(["best"])
+        if self == CheckpointRetentionPolicy.LAST_AND_BEST:
+            return set(["final", "best"])
+        if self == CheckpointRetentionPolicy.ALL:
+            return set()
+        return set()
+
+    @property
+    def max_to_keep(self) -> int | None:
+        """
+        Maximum "periodic" checkpoints to keep concurrently, i.e. saved with `step` and not `save`. `None` to keep all.
+        """
+        if self == CheckpointRetentionPolicy.ALL:
+            return None
+        return 1
+
+
+def save_checkpoint(
+    ckpt_dir: str | Path,  # output_dir/ckpt/199
+    *,
+    iteration: int | str,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer | None = None,
+    overwrite: bool = True,
+    process_group: dist.ProcessGroup = None,
+    **others: Stateful,
+):
+    """Save a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects."""
+    rank = torch.distributed.get_rank(group=process_group)
+
+    # Rank 0 checks if the checkpoint directory exists, but all ranks need to know if it exists,
+    # so they can raise an error when overwrite is False. If overwrite is True, rank 0 will delete it
+    # and other ranks wait for the deletion to finish.
+    ckpt_dir = Path(ckpt_dir)
+    ckpt_dir_exists = [ckpt_dir.exists() if rank == 0 else None]
+    src_rank = 0
+    if process_group is not None:
+        src_rank = torch.distributed.get_global_rank(group=process_group, group_rank=0)
+    torch.distributed.broadcast_object_list(ckpt_dir_exists, src=src_rank, group=process_group)
+    ckpt_dir_exists = ckpt_dir_exists[0]
+    if ckpt_dir_exists:
+        if overwrite:
+            if rank == 0:
+                if ckpt_dir.is_dir():
+                    shutil.rmtree(ckpt_dir)
+                else:
+                    ckpt_dir.unlink()
+                logger.info(f"Deleted: {ckpt_dir}")
+            torch.distributed.barrier(group=process_group)
+        else:
+            raise RuntimeError(f"Checkpoint already exists: {ckpt_dir}")
+
+    # Rank 0 creates a temporary directory for the checkpoint and broadcasts the name to all ranks.
+    ckpt_dir.parent.mkdir(parents=True, exist_ok=True)
+    ckpt_dir_tmp = [tempfile.mkdtemp(dir=ckpt_dir.parent, prefix=ckpt_dir.name) if rank == 0 else None]
+    torch.distributed.broadcast_object_list(ckpt_dir_tmp, src=src_rank, group=process_group)
+    ckpt_dir_tmp = Path(ckpt_dir_tmp[0])
+
+    to_save = {"iteration": iteration}
+    to_save["model"] = dcpsd.get_model_state_dict(model)
+    if optimizer is not None:
+        to_save["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
+    to_save.update(others)
+    dcp.save(
+        to_save,
+        storage_writer=dcpfs.FileSystemWriter(ckpt_dir_tmp),
+        process_group=process_group,
+    )
+
+    # Rank 0 renames the temporary directory to the final checkpoint directory. All ranks wait for the rename.
+    if rank == 0:
+        ckpt_dir_tmp.rename(ckpt_dir)
+    torch.distributed.barrier()
+
+    logger.info(f"Saved: {ckpt_dir}")
+
+
+def load_checkpoint(
+    ckpt_dir: str | Path,  # output_dir/ckpt/199
+    *,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer | None = None,
+    strict_loading: bool = True,
+    process_group: dist.ProcessGroup = None,
+    **others: Stateful,
+) -> int | None:
+    """
+    Load a plain/DDP/FSDP/FSDP2 model, its optimizer, an integer iteration and other stateful objects.
+    Can you take a checkpoint saved on N ranks and load it on M ranks? Sure you can!
+    Activation checkpointing and torch-compile can also be different between save and load, no problem.
+    """
+    ckpt_dir = Path(ckpt_dir)
+    to_load = {"iteration": None}
+    to_load["model"] = dcpsd.get_model_state_dict(model)
+    if optimizer is not None:
+        to_load["optimizer"] = dcpsd.get_optimizer_state_dict(model, optimizer)
+    to_load.update(others)
+    dcp.load(
+        to_load,
+        storage_reader=dcpfs.FileSystemReader(ckpt_dir),
+        planner=dcp.default_planner.DefaultLoadPlanner(allow_partial_load=not strict_loading),
+        process_group=process_group,
+    )
+    iteration = to_load["iteration"]
+    dcpsd.set_model_state_dict(model, to_load["model"])
+    if optimizer is not None:
+        dcpsd.set_optimizer_state_dict(model, optimizer, to_load["optimizer"])
+    logger.info(f"Loaded: {ckpt_dir}")
+    return iteration
+
+
+def register_dont_save_hooks(module: torch.nn.Module, dont_save: Sequence[str]):
+    """
+    Registers save/load state dict hooks such that the weights in `dont_save` are not persisted in the checkpoint.
+
+    Typical use case: a classification model composed of a frozen backbone and a trainable head.
+    If the frozen backbone is loaded from torch hub, it doesn't make sense to save a copy of it in each checkpoint.
+    """
+
+    def state_dict_post_hook(module, state_dict, prefix, local_metadata):
+        # Remove frozen weights so they won't get saved.
+        # If this module is not the top-level module, its weights will have a prefix in the state dict.
+        nonlocal _dont_save
+        for k in _dont_save:
+            del state_dict[prefix + k]
+
+    def load_state_dict_pre_hook(
+        module,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        # This pre hook exists only to pass the prefix to the post hook when loading the state dict.
+        nonlocal _prefix
+        assert _prefix is None
+        _prefix = prefix
+
+    def load_state_dict_post_hook(module, incompatible_keys):
+        # Remove the frozen weights from the missing keys so they don't raise an error.
+ nonlocal _prefix + assert _prefix is not None + to_remove = [] + for missing_key in incompatible_keys.missing_keys: + k = missing_key.removeprefix(_prefix) + k = k.replace("_checkpoint_wrapped_module.", "") # Added by activation checkpointing + if k in _dont_save: + to_remove.append(missing_key) + for r in to_remove: + incompatible_keys.missing_keys.remove(r) + _prefix = None + + _dont_save = set(name.replace("_checkpoint_wrapped_module.", "") for name in dont_save) + _prefix = None + module.register_state_dict_post_hook(state_dict_post_hook) + module.register_load_state_dict_pre_hook(load_state_dict_pre_hook) + module.register_load_state_dict_post_hook(load_state_dict_post_hook) + + +def find_all_checkpoints(ckpt_dir: Path | str) -> list[Path]: + """Find all checkpoints in a directory, i.e. subdirs with integer name. Sorted from first to last.""" + ckpt_dir = Path(ckpt_dir) + if not ckpt_dir.is_dir(): + return [] + checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir() and _is_int(p.name)] + checkpoints.sort(key=lambda p: int(p.name)) + return checkpoints + + +def find_latest_checkpoint(ckpt_dir: Path | str) -> Path | None: + """Find the latest checkpoint in a directory, i.e. the subdir with the highest integer name.""" + checkpoints = find_all_checkpoints(ckpt_dir) + if len(checkpoints) == 0: + return None + return checkpoints[-1] + + +def keep_last_n_checkpoints(ckpt_dir: Path | str, n: int | None): + """In a directory with integer-named subdirs, keep only the n subdirs with the highest number.""" + if n is None: + return + checkpoints = find_all_checkpoints(ckpt_dir) + for ckpt_dir in checkpoints[:-n]: + try: + shutil.rmtree(ckpt_dir) + logger.info(f"Deleted: {ckpt_dir}") + except Exception: + logger.exception(f"Failed to delete: {ckpt_dir}") + + +def keep_checkpoint_copy(src: Path | str): + """Copy a file/directory next to itself with a _keep suffix. 
Files are hardlinked."""
+    src = Path(src)
+    dst = src.parent / f"{src.name}_keep"
+    subprocess.check_output(["cp", "--recursive", "--link", src, dst])
+    logger.info(f"Copied: {src} -> {dst}")
+
+
+def _is_int(s: str) -> bool:
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
+
+
+# Initialize a FSDP2 model from DCP or PyTorch standard checkpoint
+def init_fsdp_model_from_checkpoint(
+    model: torch.nn.Module,
+    checkpoint_path: str,
+    skip_load_keys: List[str] | None = None,
+    keys_not_sharded: List[str] | None = None,
+    process_group: dist.ProcessGroup = None,
+):
+    # Treat missing key filters as empty so the `any(...)` checks below don't fail on None
+    skip_load_keys = skip_load_keys or []
+    keys_not_sharded = keys_not_sharded or []
+    if not Path(checkpoint_path).is_dir():  # PyTorch standard checkpoint
+        logger.info(f"Loading pretrained weights from {checkpoint_path}")
+        chkpt = torch.load(checkpoint_path, map_location="cpu")["teacher"]
+        from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+
+        if process_group is None:
+            world_mesh = init_device_mesh(
+                "cuda",
+                mesh_shape=(dist.get_world_size(),),
+                mesh_dim_names=("dp",),
+            )
+        else:
+            world_mesh = DeviceMesh.from_group(process_group, "cuda")
+        chkpt = {
+            key: (
+                torch.distributed.tensor.distribute_tensor(tensor, world_mesh, src_data_rank=None)
+                if not any(key_not_sharded in key for key_not_sharded in keys_not_sharded)
+                else tensor
+            )
+            for key, tensor in chkpt.items()
+        }
+        model.load_state_dict(
+            {
+                key: tensor
+                for key, tensor in chkpt.items()
+                if not any(skip_load_key in key for skip_load_key in skip_load_keys)
+            }
+        )
+    else:  # DCP checkpoint
+        load_checkpoint(ckpt_dir=checkpoint_path, model=model, process_group=process_group)
+
+
+# Initialize a standard non-distributed PyTorch model from a PyTorch standard checkpoint for evals
+def init_model_from_checkpoint_for_evals(
+    model: torch.nn.Module, pretrained_weights: str | Path, checkpoint_key: str | None = None
+):
+    state_dict = torch.load(pretrained_weights, map_location="cpu")
+    if checkpoint_key is not None and checkpoint_key in state_dict:
+        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+        state_dict = state_dict[checkpoint_key]
+    # remove `module.` prefix
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+    # remove `backbone.` prefix induced by multicrop wrapper
+    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+    msg = model.load_state_dict(state_dict, strict=False)
+    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def cleanup_checkpoint(ckpt_dir: str, checkpoint_retention_policy: CheckpointRetentionPolicy):
+    """
+    ckpt_dir is the directory containing the individual checkpoint directories (either at an iteration, best (validation performance) or final):
+
+    ckpt_dir/
+    |-- 0/
+    |   `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    |-- 99/
+    |   `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    |-- 199/
+    |   `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    |-- best/
+    |   `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    |-- 299/
+    |   `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    `-- final/
+        `-- checkpoint.pth or dcp_sharded_checkpoint_dir
+    """
+    ckpt_dir = Path(ckpt_dir)
+    if not ckpt_dir.is_dir():
+        return
+    checkpoint_filters = checkpoint_retention_policy.keep_filters
+    checkpoints = [p for p in ckpt_dir.iterdir() if p.is_dir()]
+    for checkpoint in checkpoints:
+        # Compare the directory name (e.g. "final", "best") against the retention filters
+        if checkpoint.name in checkpoint_filters:
+            continue
+        try:
+            shutil.rmtree(checkpoint)
+            logger.info(f"Deleted: {checkpoint}")
+        except Exception:
+            logger.exception(f"Failed to delete: {checkpoint}")
diff --git
a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3719887dfb7c33ad67c2546e33e2f1c9f426760c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .config import ( + DinoV3SetupArgs, + apply_scaling_rules_to_cfg, + exit_job, + get_cfg_from_args, + get_default_config, + setup_config, + setup_job, + setup_multidistillation, + write_config, +) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/config.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6aceaa86f9b952cc554cb6ae80ed16a242d67d0e --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/config.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import math +import os +import pathlib +import sys +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, List, Optional, Sequence, Tuple + +from omegaconf import DictConfig, OmegaConf + +import dinov3.distributed as distributed +from dinov3.logging import cleanup_logging, setup_logging +from dinov3.utils import fix_random_seeds, get_conda_env, get_sha + +logger = logging.getLogger("dinov3") + + +@dataclass +class DinoV3SetupArgs: + config_file: str + pretrained_weights: str | None = None + shard_unsharded_model: bool = False + output_dir: str = "" + opts: List[Any] = field(default_factory=lambda: []) + + def __post_init__(self): + # When loaded from benchmark.yaml, self.opts is a frozen omegaconf.ListConfig, + # which works everywhere except when we want to modify it or when + # we try to json-serialize it. So we convert it to a regular list here. + if OmegaConf.is_config(self.opts): + self.opts = OmegaConf.to_object(self.opts) + + +def apply_scaling_rules_to_cfg(cfg): # to fix + assert distributed.is_enabled(), "Setup distributed to get global size !" 
+ if "schedules" in cfg: + # For schedules v2, the scaling rules are applied when building the schedules, the config is not modified + return cfg + + if cfg.optim.scaling_rule == "linear_wrt_256": + old_lr = cfg.optim.lr + cfg.optim.lr *= cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0 + logger.info(f"linear scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}") + elif cfg.optim.scaling_rule == "sqrt_wrt_1024": + old_lr = cfg.optim.lr + cfg.optim.lr *= 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}") + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + output_dir = os.path.abspath(output_dir) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_default_config() -> DictConfig: + p = pathlib.Path(__file__).parent / "ssl_default_config.yaml" + return OmegaConf.load(p) + + +def get_cfg_from_args(args: DinoV3SetupArgs, multidistillation=False, strict=True): + overrides = [*args.opts] + if args.output_dir is not None: + overrides.append(f"train.output_dir={os.path.realpath(args.output_dir)}") + + # Config file + cfg = OmegaConf.load(args.config_file) + + # Command line overrides + opts_cfg = OmegaConf.from_cli(overrides) + + if multidistillation: + cfg = OmegaConf.merge(cfg, opts_cfg) + else: + # Default config + default_cfg = get_default_config() + if strict: + OmegaConf.set_struct(default_cfg, True) + cfg = OmegaConf.merge(default_cfg, cfg, opts_cfg) + return cfg + + +def setup_config(args: DinoV3SetupArgs, strict_cfg=True): + """ + Create configs and perform basic setups. + """ + # Create the cfg with OmegaConf + cfg = get_cfg_from_args(args, strict=strict_cfg) + # setup distributed, logging, and random seeds + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + # dump config before modifying so it can be reloaded + if args.output_dir is not None: + write_config(cfg, args.output_dir) + # modify the config inplace by applying scaling rules + apply_scaling_rules_to_cfg(cfg) + return cfg + + +def _enumerate_all_subgroup_ranks(all_subgroup_rank_spans: Sequence[Tuple[int, int]]): + """Expands a specification of process subgroups from spans to enumerated ranks. + + Args: + all_group_rank_spans: a sequence of rank spans (first rank, last rank), + one for each process group. Example: ((0, 1), (2, 3), (4, 7)). 
+ """ + for first, last in all_subgroup_rank_spans: + assert first <= last + return tuple(tuple(range(first, last + 1)) for first, last in all_subgroup_rank_spans) + + +def setup_multidistillation(args: DinoV3SetupArgs): + base_output_dir = args.output_dir + os.makedirs(args.output_dir, exist_ok=True) + # get config file for this rank + base_cfg = OmegaConf.load(args.config_file) + assert base_cfg.multidistillation.enabled + + global_batch_size = base_cfg.multidistillation.global_batch_size + + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_rank() + + # build process subgroups + all_subgroup_rank_spans = tuple( + (student.ranks_range[0], student.ranks_range[1] - 1) for student in base_cfg.multidistillation.students + ) + all_subgroup_ranks = _enumerate_all_subgroup_ranks(all_subgroup_rank_spans) + distributed.new_subgroups(all_subgroup_ranks) + + found = False + for student in base_cfg.multidistillation.students: + if rank in range(*student.ranks_range): + found = True + break + assert found, "rank of worker not in defined range" + + name = student.name + config_path = student.config_path + n_gpus = student.ranks_range[1] - student.ranks_range[0] + assert global_batch_size % n_gpus == 0 + total_n_gpus = distributed.get_world_size() + + args.output_dir = os.path.join(base_output_dir, name) + args.opts += [f"train.output_dir={args.output_dir}"] + args.opts += [f"train.batch_size_per_gpu={global_batch_size // total_n_gpus}"] + args.config_file = os.path.abspath(config_path) + default_cfg = get_default_config() + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, base_cfg, OmegaConf.from_cli(args.opts)) + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + + fix_random_seeds(seed + rank) + + write_config(cfg, args.output_dir) + apply_scaling_rules_to_cfg(cfg) + + return cfg + + +def setup_job( + output_dir: Optional[str] = None, + distributed_enabled: bool = True, + logging_enabled: bool = True, + seed: Optional[int] = 0, + restrict_print_to_main_process: bool = True, + distributed_timeout: timedelta | None = None, +): + """ + Setup methods that should be done in every fairvit job + Initializes logging, distributed, random seeds and other utilities. 
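+
+    Illustrative call order (a sketch, not part of the upstream file): run
+    setup_job() before setup_config(), since apply_scaling_rules_to_cfg()
+    asserts that distributed is enabled:
+
+        setup_job(output_dir="out")
+        cfg = setup_config(DinoV3SetupArgs(config_file="ssl.yaml", output_dir="out"))
+        # ... training ...
+        exit_job()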
+ """ + if output_dir is not None: + output_dir = os.path.realpath(output_dir) + os.makedirs(output_dir, exist_ok=True) + + if logging_enabled: + setup_logging( + output=output_dir, + level=logging.INFO, + log_to_stdout_only_in_main_process=restrict_print_to_main_process, + ) + + if distributed_enabled: + distributed.enable( + overwrite=True, + nccl_async_error_handling=True, + restrict_print_to_main_process=restrict_print_to_main_process, + timeout=distributed_timeout, + ) + + if seed is not None: + rank = distributed.get_rank() + fix_random_seeds(seed + rank) + + logger = logging.getLogger("dinov3") + logger.info("git:\n {}\n".format(get_sha())) + + # Log some python info + conda_env_name, conda_env_path = get_conda_env() + logger.info(f"conda env name: {conda_env_name}") + logger.info(f"conda env path: {conda_env_path}") + logger.info(f"python path: {sys.path}") + + +def exit_job(distributed_enabled: bool = True, logging_enabled: bool = True): + if distributed_enabled: + distributed.disable() + if logging_enabled: + cleanup_logging() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/ssl_default_config.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/ssl_default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87ae3a0a33887c8f501d4d94f724f06942bab2a5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/ssl_default_config.yaml @@ -0,0 +1,205 @@ +MODEL: + META_ARCHITECTURE: SSLMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true # Whether to ignore A-A and B-B global pairs, default as in DINOv2, ignored by SSLMetaArch + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 + koleo_loss_distributed: false + koleo_topk: 1 + koleo_distributed_replicas: 0 + koleo_distributed_loss_group_size: null # Size of the nearest neighbor set for distributed Koleo. If None, uses global batch size. 
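+# Note: these defaults are merged with the experiment config and command-line
+# overrides in get_cfg_from_args(); a hypothetical override using OmegaConf's
+# dotted CLI syntax: dino.koleo_loss_weight=0.05 train.batch_size_per_gpu=32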
+  koleo_distributed_loss_group_data: true # group data from adjacent ranks to make sure koleo is applied on the same data distribution
+  force_weight_norm: false
+  reweight_dino_local_loss: false # If true, reweight the local DINO loss using the schedule below
+  local_loss_weight_schedule: # Schedule for the local loss weight, enabled if reweight_dino_local_loss is true
+    start: 0.5
+    peak: 0.5
+    end: 0.5
+    warmup_epochs: 0
+ibot:
+  loss_weight: 1.0
+  mask_sample_probability: 0.5
+  mask_ratio_min_max:
+  - 0.1
+  - 0.5
+  mask_random_circular_shift: false
+  force_masking_even_with_zero_weight: false
+  separate_head: true
+  head_n_prototypes: 65536
+  head_bottleneck_dim: 256
+  head_norm_last_layer: false
+  head_nlayers: 3
+  head_hidden_dim: 2048
+gram:
+  use_loss: false # (bool): if true, the GRAM loss is used
+  compute_stats: false # (bool): if true, compute auxiliary stats
+  loss_weight: 1.0 # (float): weight of the loss
+  ema_teacher: false # (bool): use the EMA teacher as the GRAM teacher
+  ckpt: null # (str): checkpoint path of the GRAM teacher
+  it_load_ema_teacher: -1 # (int): iteration at which the EMA teacher is loaded into the GRAM teacher
+  rep_update: true # (bool): if true, the GRAM teacher is updated every gram.update_frequency iterations after gram.it_first_update
+  update_frequency: 50000 # (int): update frequency
+  it_first_update: 0 # (int): iteration of the first update
+  max_updates: null # (int): maximum number of updates to the GRAM teacher; if null, unlimited
+  normalized: true # (bool): normalize the features
+  img_level: false # (bool): if true, GRAM is computed at the image level, otherwise at the local batch level
+  remove_neg: false # (bool): if true, remove the negative similarities before applying the loss
+  remove_only_teacher_neg: false # (bool): remove only the teacher's negative similarities
+  tokens_used: all # (str): one of [all, masked, unmasked]
+  global_teacher_resize_method: bicubic # Method for resizing the outputs of the gram teacher
+  global_teacher_resize_antialias: false # Whether to use antialiasing when resizing the outputs of the gram teacher
+  loss_weight_schedule: null # (dict): if not null, use a schedule for the loss weight instead of `loss_weight`
+train:
+  batch_size_per_gpu: 64
+  dataset_path: ImageNet:split=TRAIN
+  data_config: null
+  output_dir: .
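+  # The effective global batch size is batch_size_per_gpu * world_size
+  # (e.g. 64 per GPU on 16 GPUs gives 1024); optim.scaling_rule below
+  # rescales the learning rate from it.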
+ saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + monitor_gradient_norm: false + chunk_schedule: [] + use_teacher_head: true + learn_from_teacher_tokens: false + centering: "sinkhorn_knopp" # or "sinkhorn_knopp" + checkpointing: false + checkpointing_full: false # aggressive checkpointing + compile: true + cudagraphs: false + sharded_eval_checkpoint: false + cache_dataset: false +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + pretrained_weights: '' + ffn_layer: "mlp" + ffn_ratio: 4.0 + resume_from_teacher_chkpt: "" + qkv_bias: true + proj_bias: true + ffn_bias: true + norm_layer: "layernorm" + n_storage_tokens: 0 + mask_k_bias: false + untie_cls_and_patch_norms: false # If true, use separate norms for CLS/reg and patch/mask tokens + untie_global_and_local_cls_norm: false # If true, use separate norms for local and global crop CLS token during training + in_chans: 3 + pos_embed_type: rope + pos_embed_rope_base: 100.0 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + pos_embed_rope_normalize_coords: separate # min, max, separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: null + pos_embed_rope_dtype: bf16 + fp8_enabled: False # Convert Linear layers to operate in fp8 precision + fp8_filter: "blocks" # Regex that must appear in module path; empty means everything +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 + in_chans: 3 +distillation: # teacher + enabled: false + full_cfg_path: "" + checkpoint_path: "" +multidistillation: + enabled: false +hrft: # non-hrft'd student + enabled: false + checkpoint_path: "" # teacher_checkpoint path +optim: + epochs: 100 + optimizer: adamw + weight_decay: 0.04 + weight_decay_end: 0.4 + lr: 0.001 + warmup_epochs: 10 + min_lr: 1.0e-06 + schedule_trunc_extra: 0.0 # Compute the schedule for (1 + schedule_trunc_extra) steps and truncate, .25 is a good choice + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.9 + multi_tensor_optim: true + dump_fsdp_weights_path: "" + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 + global_local_crop_pairs_ratios: 1.0 + gram_teacher_crops_size: null # If not None, return crops for gram teacher + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: true + gram_teacher_no_distortions: false # If True, no distortions are applied to gram teacher crops + rgb_mean: + - 0.485 + - 0.456 + - 0.406 + rgb_std: + - 0.229 + - 0.224 + - 0.225 +evaluation: + eval_period_iterations: 12500 + low_freq_every: 5 + config_files: # Must be in fairvit/eval/configs + high_freq: benchmark_high_frequency.yaml # More often + low_freq: benchmark_low_frequency.yaml # Less often +checkpointing: + period: 3750 + max_to_keep: 3 + keep_every: 99999999999999999 # Save a checkpoint every N iterations, regardless of max_to_keep and period + +# Example of constant schedules with schedules v2 +# # schedules: +# # lr: +# # start: 0.0 +# # peak: 1e-3 +# # end: 1e-6 +# # warmup_epochs: 10 +# # freeze_last_layer_epochs: 1 +# # weight_decay: +# # start: 0.04 +# # peak: 0.04 +# # end: 0.04 +# # warmup_epochs: 0 +# # momentum: +# # start: 0.992 
+# # peak: 0.992 +# # end: 0.992 +# # warmup_epochs: 0 +# # teacher_temp: +# # start: 0.04 +# # peak: 0.07 +# # end: 0.07 +# # warmup_epochs: 30 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05cf7dfef6fda0199e10b45a761f14201a3db6b5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_gram_anchor.yaml @@ -0,0 +1,203 @@ +MODEL: + META_ARCHITECTURE: SSLMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true + head_n_prototypes: 262144 + head_bottleneck_dim: 512 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 8192 + koleo_loss_weight: 0.1 + koleo_loss_distributed: false + koleo_topk: 1 + koleo_distributed_replicas: 0 + koleo_distributed_loss_group_size: null + koleo_distributed_loss_group_data: true + force_weight_norm: false + reweight_dino_local_loss: true + local_loss_weight_schedule: + start: 1 + peak: 1 + end: 0.5 + warmup_epochs: 1000 + cosine_epochs: 1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_n_prototypes: 98304 + head_bottleneck_dim: 384 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 4096 +gram: + use_loss: true + compute_stats: false + loss_weight: 1.0 + ema_teacher: false + ckpt: ignore + it_load_ema_teacher: -1 + rep_update: true + update_frequency: 10000 + it_first_update: 1010000 + max_updates: 3 + normalized: true + img_level: true + remove_neg: false + remove_only_teacher_neg: false + tokens_used: all + global_teacher_resize_method: bicubic + global_teacher_resize_antialias: false + loss_weight_schedule: + start: 0 + peak: 0 + end: 2.0 + warmup_epochs: 1000 + cosine_epochs: 1 +train: + batch_size_per_gpu: 16 + dataset_path: null + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1000 + monitor_gradient_norm: false + chunk_schedule: [] + cache_dataset: true + use_teacher_head: true + learn_from_teacher_tokens: false + centering: sinkhorn_knopp + checkpointing: true + checkpointing_full: true + compile: true + cudagraphs: false + cell_augmentation: false + cell_augmentation_type: hpa + sharded_eval_checkpoint: true +student: + arch: vit_7b + patch_size: 16 + drop_path_rate: 0.4 + layerscale: 1.0e-05 + patch_drop: 0.0 + pretrained_weights: '' + ffn_layer: swiglu64 + ffn_ratio: 3 + resume_from_teacher_chkpt: '' + qkv_bias: false + proj_bias: true + ffn_bias: true + norm_layer: layernormbf16 + n_storage_tokens: 4 + untie_cls_and_patch_norms: false + untie_global_and_local_cls_norm: true + mask_k_bias: true + in_chans: 3 + pos_embed_type: rope + pos_embed_rope_base: 100 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + pos_embed_rope_normalize_coords: separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: 2 + pos_embed_rope_dtype: fp32 + fp8_enabled: true + fp8_filter: blocks +teacher: + momentum_teacher: null + final_momentum_teacher: null + warmup_teacher_temp: null + teacher_temp: null + warmup_teacher_temp_epochs: null + in_chans: 3 
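+# Note: lr, momentum and teacher_temp are taken from the "schedules" block at
+# the bottom of this file (schedules v2), which is why the scalar optim/teacher
+# fields above are null; apply_scaling_rules_to_cfg() leaves configs that
+# define "schedules" unmodified.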
+distillation: + enabled: false + full_cfg_path: '' + checkpoint_path: '' +multidistillation: + enabled: false +hrft: + enabled: false + checkpoint_path: '' +optim: + epochs: 1200 + optimizer: adamw + weight_decay: null + weight_decay_end: null + lr: null + warmup_epochs: null + min_lr: null + schedule_trunc_extra: null + clip_grad: 30.0 + freeze_last_layer_epochs: null + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.98 + multi_tensor_optim: true + dump_fsdp_weights_path: '' + adamw_beta1: 0.9 + adamw_beta2: 0.99 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 256 + local_crops_size: 112 + gram_teacher_crops_size: 512 + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: false + gram_teacher_no_distortions: true + rgb_mean: + - 0.485 + - 0.456 + - 0.406 + rgb_std: + - 0.229 + - 0.224 + - 0.225 +checkpointing: + period: 1000 + max_to_keep: 3 + keep_every: 50000 +schedules: + lr: + start: 0 + peak: 3.0e-05 + end: 3.0e-05 + warmup_epochs: 100 + freeze_last_layer_epochs: 5 + weight_decay: + start: 0.04 + peak: 0.04 + end: 0.04 + warmup_epochs: 0 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 100 + momentum: + start: 0.999 + peak: 0.999 + end: 0.999 + warmup_epochs: 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..313a104a7cbf757cc0288f10d8d22eb960bfe644 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_high_res_adapt.yaml @@ -0,0 +1,224 @@ +MODEL: + META_ARCHITECTURE: SSLMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true + head_n_prototypes: 262144 + head_bottleneck_dim: 512 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 8192 + koleo_loss_weight: 0.1 + koleo_loss_distributed: true + koleo_topk: 1 + koleo_distributed_replicas: 0 + koleo_distributed_loss_group_size: 16 + force_weight_norm: false + reweight_dino_local_loss: true + local_loss_weight_schedule: + start: 0.5 + peak: 0.5 + end: 0.5 + warmup_epochs: 0 + cosine_epochs: 0 + koleo_distributed_loss_group_data: true +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_n_prototypes: 98304 + head_bottleneck_dim: 384 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 4096 +gram: + use_loss: true + compute_stats: false + loss_weight: 1.0 + ema_teacher: false + it_load_ema_teacher: -1 + rep_update: false + update_frequency: 10000 + it_first_update: 1010000 + max_updates: 3 + normalized: true + img_level: true + remove_neg: false + remove_only_teacher_neg: false + tokens_used: all + global_teacher_resize_method: bicubic + global_teacher_resize_antialias: false + loss_weight_schedule: + start: 1.5 + peak: 1.5 + end: 1.5 + warmup_epochs: 0 + cosine_epochs: 0 +train: + batch_size_per_gpu: 8 + dataset_path: null + saveckp_freq: 20 + seed: 0 + num_workers: 2 + OFFICIAL_EPOCH_LENGTH: 1000 + monitor_gradient_norm: 
false + chunk_schedule: [] + cache_dataset: true + use_teacher_head: true + learn_from_teacher_tokens: false + centering: sinkhorn_knopp + checkpointing: true + checkpointing_full: true + compile: true + cudagraphs: false + cell_augmentation: false + cell_augmentation_type: hpa + sharded_eval_checkpoint: true +student: + arch: vit_7b + patch_size: 16 + drop_path_rate: 0.4 + layerscale: 1.0e-05 + patch_drop: 0.0 + pretrained_weights: '' + ffn_layer: swiglu64 + ffn_ratio: 3 + resume_from_teacher_chkpt: '' + qkv_bias: false + proj_bias: true + ffn_bias: true + norm_layer: layernormbf16 + n_storage_tokens: 4 + untie_cls_and_patch_norms: false + untie_global_and_local_cls_norm: true + mask_k_bias: true + in_chans: 3 + pos_embed_type: rope + pos_embed_rope_base: 100 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + pos_embed_rope_normalize_coords: separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: 2 + pos_embed_rope_dtype: fp32 + fp8_enabled: true + fp8_filter: blocks +teacher: + momentum_teacher: null + final_momentum_teacher: null + warmup_teacher_temp: null + teacher_temp: null + warmup_teacher_temp_epochs: null + in_chans: 3 +distillation: + enabled: false + full_cfg_path: '' + checkpoint_path: '' +multidistillation: + enabled: false +hrft: + enabled: false + checkpoint_path: '' +optim: + epochs: 30 + optimizer: adamw + weight_decay: null + weight_decay_end: null + lr: null + warmup_epochs: null + min_lr: null + schedule_trunc_extra: null + clip_grad: 30.0 + freeze_last_layer_epochs: null + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.98 + multi_tensor_optim: true + dump_fsdp_weights_path: '' + adamw_beta1: 0.9 + adamw_beta2: 0.99 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: + - 512 + - 768 + - 768 + - 768 + - 768 + local_crops_size: + - 112 + - 112 + - 168 + - 224 + - 336 + global_local_crop_pairs_ratios: + - 0.3 + - 0.3 + - 0.3 + - 0.05 + - 0.05 + gram_teacher_crops_size: + - 768 + - 1152 + - 1152 + - 1152 + - 1152 + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: false + gram_teacher_no_distortions: true + rgb_mean: + - 0.485 + - 0.456 + - 0.406 + rgb_std: + - 0.229 + - 0.224 + - 0.225 +checkpointing: + period: 250 + max_to_keep: 3 + keep_every: 50000 +schedules: + lr: + start: 0 + peak: 0 + end: 1.25e-05 + warmup_epochs: 0 + freeze_last_layer_epochs: 0 + cosine_epochs: 10 + weight_decay: + start: 0.04 + peak: 0.04 + end: 0.04 + warmup_epochs: 0 + teacher_temp: + start: 0.07 + peak: 0.07 + end: 0.07 + warmup_epochs: 0 + momentum: + start: 0.999 + peak: 0.999 + end: 0.999 + warmup_epochs: 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_pretrain.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b93be46ecedb582befb9d4fcb7ed185410f2517 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vit7b16_pretrain.yaml @@ -0,0 +1,172 @@ +MODEL: + META_ARCHITECTURE: SSLMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true + head_n_prototypes: 262144 + 
head_bottleneck_dim: 512 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 8192 + koleo_loss_weight: 0.1 + koleo_loss_distributed: false + koleo_topk: 1 + koleo_distributed_replicas: 0 + koleo_distributed_loss_group_size: null + force_weight_norm: false +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_n_prototypes: 98304 + head_bottleneck_dim: 384 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 4096 +gram: + use_loss: false + compute_stats: false +train: + batch_size_per_gpu: 16 + dataset_path: null + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1000 + monitor_gradient_norm: false + chunk_schedule: [] + cache_dataset: true + use_teacher_head: true + learn_from_teacher_tokens: false + centering: sinkhorn_knopp + checkpointing: true + checkpointing_full: false + compile: true + cudagraphs: false + cell_augmentation: false + cell_augmentation_type: hpa + sharded_eval_checkpoint: true +student: + arch: vit_7b + patch_size: 16 + drop_path_rate: 0.4 + layerscale: 1.0e-05 + patch_drop: 0.0 + pretrained_weights: '' + ffn_layer: swiglu64 + ffn_ratio: 3 + resume_from_teacher_chkpt: '' + qkv_bias: false + proj_bias: true + ffn_bias: true + norm_layer: layernormbf16 + n_storage_tokens: 4 + untie_cls_and_patch_norms: false + untie_global_and_local_cls_norm: true + mask_k_bias: true + in_chans: 3 + pos_embed_type: rope + pos_embed_rope_base: 100 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + pos_embed_rope_normalize_coords: separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: 2 + pos_embed_rope_dtype: fp32 + fp8_enabled: true + fp8_filter: blocks +teacher: + momentum_teacher: null + final_momentum_teacher: null + warmup_teacher_temp: null + teacher_temp: null + warmup_teacher_temp_epochs: null + in_chans: 3 +distillation: + enabled: false + full_cfg_path: '' + checkpoint_path: '' +multidistillation: + enabled: false +hrft: + enabled: false + checkpoint_path: '' +optim: + epochs: 1000 + optimizer: adamw + weight_decay: null + weight_decay_end: null + lr: null + warmup_epochs: null + min_lr: null + schedule_trunc_extra: null + clip_grad: 30.0 + freeze_last_layer_epochs: null + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.98 + multi_tensor_optim: true + dump_fsdp_weights_path: '' + adamw_beta1: 0.9 + adamw_beta2: 0.99 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 256 + local_crops_size: 112 + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: false + rgb_mean: + - 0.485 + - 0.456 + - 0.406 + rgb_std: + - 0.229 + - 0.224 + - 0.225 +checkpointing: + period: 1000 + max_to_keep: 3 + keep_every: 50000 +schedules: + lr: + start: 0 + peak: 5.0e-05 + end: 5.0e-05 + warmup_epochs: 100 + freeze_last_layer_epochs: 5 + weight_decay: + start: 0.04 + peak: 0.04 + end: 0.04 + warmup_epochs: 0 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 100 + momentum: + start: 0.994 + peak: 0.994 + end: 0.994 + warmup_epochs: 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vitl16_lvd1689m_distilled.yaml 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vitl16_lvd1689m_distilled.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d81e605593401de00ea6ac68630714a2c707fcc --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/dinov3_vitl16_lvd1689m_distilled.yaml @@ -0,0 +1,251 @@ +MODEL: + META_ARCHITECTURE: MultiDistillationMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true + head_n_prototypes: 262144 + head_bottleneck_dim: 512 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 8192 + koleo_loss_weight: 0.1 + koleo_loss_distributed: false + koleo_topk: 1 + koleo_distributed_replicas: 0 + koleo_distributed_loss_group_size: null + koleo_distributed_loss_group_data: true + force_weight_norm: false + reweight_dino_local_loss: false + local_loss_weight_schedule: + start: 0.5 + peak: 0.5 + end: 0.5 + warmup_epochs: 0 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_n_prototypes: 98304 + head_bottleneck_dim: 384 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 4096 +coding_rate_loss: + use_cls_loss: false + cls_loss_weight: 0.2 + use_masked_patches_loss: false + masked_patches_loss_weight: 0.1 + epsilon: 8 +gram: + use_loss: false + compute_stats: false + loss_weight: 1.0 + ema_teacher: false + ckpt: null + it_load_ema_teacher: -1 + rep_update: true + update_frequency: 50000 + it_first_update: 0 + max_updates: null + normalized: true + img_level: false + remove_neg: false + remove_only_teacher_neg: false + tokens_used: all + global_teacher_resize_method: bicubic + global_teacher_resize_antialias: false + loss_weight_schedule: null +train: + batch_size_per_gpu: 3 + dataset_path: + output_dir: + saveckp_freq: 20 + seed: 0 + num_workers: 2 + OFFICIAL_EPOCH_LENGTH: 1250 + monitor_gradient_norm: false + chunk_schedule: [] + cache_dataset: true + use_teacher_head: true + learn_from_teacher_tokens: false + centering: sinkhorn_knopp + checkpointing: true + checkpointing_full: true + compile: true + cudagraphs: false + cell_augmentation: false + cell_augmentation_type: hpa + sharded_eval_checkpoint: false +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.0 + layerscale: 1.0e-05 + drop_path_uniform: true + drop_path_shape: uniform + patch_drop: 0.0 + pretrained_weights: '' + sin_cos_embeddings: false + fourier_embeddings: false + fourier_encoding_dim: 64 + multiple_pos_embeddings: false + cls_pos_embedding: false + reg_pos_embedding: false + ffn_layer: mlp + ffn_ratio: 4.0 + resume_from_teacher_chkpt: + block_chunks: 0 + qkv_bias: true + proj_bias: true + ffn_bias: true + norm_layer: layernormbf16 + n_storage_tokens: 4 + mask_attention: false + mask_register_attention: false + untie_cls_and_patch_norms: false + untie_global_and_local_cls_norm: false + interpolate_offset: 0.0 + interpolate_antialias: true + mask_k_bias: true + init_std_cls: 0.02 + init_std_reg: 0.02 + rescale_weights_by_layer_id: false + in_chans: 3 + pos_embed_grid_size: 48 + pos_embed_type: ropenew + pos_embed_rope_gamma: 1.0 + pos_embed_rope_init_multi_frequencies: false + pos_embed_rope_base: 100 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + 
pos_embed_rope_normalize_coords: separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: 2 + pos_embed_rope_dtype: bf16 + sparse24_ranges: [] + sparse24_filter: + - mlp + sparse24_default: false + fp8_enabled: false + fp8_filter: blocks +teacher: + momentum_teacher: 0.994 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 120 + in_chans: 3 +distillation: + enabled: true + full_cfg_path: + checkpoint_path: +multidistillation: + enabled: true + global_batch_size: 1920 + students: + - name: vits_mlp4_4 + config_path: + ranks_range: + - 0 + - 48 + - name: vitsp_swiglu6_1 + config_path: + ranks_range: + - 48 + - 96 + - name: vitb_mlp4_3 + config_path: + ranks_range: + - 96 + - 176 + - name: vitl_mlp4_1 + config_path: + ranks_range: + - 176 + - 296 +hrft: + enabled: false + checkpoint_path: '' +optim: + epochs: 20 + optimizer: adamw + weight_decay: 0.04 + weight_decay_end: 0.2 + lr: 0.0002 + warmup_epochs: 0 + min_lr: 1.0e-06 + schedule_trunc_extra: 0.0 + clip_grad: 3.0 + freeze_last_layer_epochs: 0 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.99 + multi_tensor_optim: true + dump_fsdp_weights_path: '' + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 256 + local_crops_size: 112 + global_local_crop_pairs_ratios: 1.0 + gram_teacher_crops_size: 256 + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: false + gram_teacher_no_distortions: false + rgb_mean: + - 0.485 + - 0.456 + - 0.406 + rgb_std: + - 0.229 + - 0.224 + - 0.225 +checkpointing: + period: 3750 + max_to_keep: 3 + keep_every: 99999999999999999 +schedules: + weight_decay: + start: 0.04 + peak: 0.04 + end: 0.04 + warmup_epochs: 0 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 0 + lr: + start: 0 + peak: 0 + end: 5.0e-05 + warmup_epochs: 0 + freeze_last_layer_epochs: 0 + cosine_epochs: 10 + momentum: + start: 0.994 + peak: 0.994 + end: 1.0 + warmup_epochs: 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multi_distillation_test.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multi_distillation_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..984c075be3557c808b625956c456eece82ed655f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multi_distillation_test.yaml @@ -0,0 +1,27 @@ +MODEL: + META_ARCHITECTURE: MultiDistillationMetaArch +multidistillation: + enabled: true + global_batch_size: 256 + students: + - name: vits + config_path: dinov3/configs/train/multidist_tests/vits_p16.yaml + ranks_range: + - 0 + - 4 + - name: vitb + config_path: dinov3/configs/train/multidist_tests/vitb_p16.yaml + ranks_range: + - 4 + - 8 +distillation: # teacher + enabled: true + full_cfg_path: dinov3/configs/train/vitl_im1k_lin834.yaml + checkpoint_path: ignore +train: + dataset_path: ImageNet:split=TRAIN + cache_dataset: false + centering: "sinkhorn_knopp" + compile: true +ibot: + separate_head: true diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vitb_p16.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vitb_p16.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..bbf715426d3c0dddd96cbd76b3cb67a91e68d4cf --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vitb_p16.yaml @@ -0,0 +1,7 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + checkpointing: true +student: + drop_path_rate: 0.1 + arch: vit_base diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vits_p16.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vits_p16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c7f831cae01de083174c7a36a13e99125dd2e61 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/multidist_tests/vits_p16.yaml @@ -0,0 +1,6 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN +student: + drop_path_rate: 0.1 + arch: vit_small diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/vitl_im1k_lin834.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/vitl_im1k_lin834.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46428be4ac7be9c4aefed38f88b26f34a2f60fa8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/configs/train/vitl_im1k_lin834.yaml @@ -0,0 +1,143 @@ +# tested on RSC: /checkpoint/dino/qas/rope/vitl16_im1k/ +# gives 82.2 im1k-knn, 83.3 im1k-linear +# runs with a total batch size of 2048 (64/gpu, 4 nodes here) +# runs at 0.57s/iter +MODEL: + META_ARCHITECTURE: SSLMetaArch + DEVICE: cuda + WEIGHTS: '' + DTYPE: float32 +compute_precision: + param_dtype: bf16 + reduce_dtype: fp32 + sharding_strategy: SHARD_GRAD_OP +dino: + loss_weight: 1.0 + global_ignore_diagonal: true + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 + koleo_loss_distributed: false + koleo_topk: 1 + koleo_distributed_replicas: 0 + force_weight_norm: false +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 +train: + batch_size_per_gpu: 64 + dataset_path: ImageNet:split=TRAIN + output_dir: /checkpoint/dino/qas/rope/vitl16_im1k + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + monitor_gradient_norm: false + chunk_schedule: [] + cache_dataset: true + use_teacher_head: true + learn_from_teacher_tokens: false + centering: sinkhorn_knopp + checkpointing: false + compile: true + cudagraphs: false + cell_augmentation: false + cell_augmentation_type: hpa +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + patch_drop: 0.0 + pretrained_weights: '' + ffn_layer: mlp + ffn_ratio: 4.0 + resume_from_teacher_chkpt: '' + qkv_bias: true + proj_bias: true + ffn_bias: true + norm_layer: layernorm + n_storage_tokens: 0 + mask_k_bias: false + in_chans: 3 + pos_embed_type: rope + pos_embed_rope_base: 100.0 + pos_embed_rope_min_period: null + pos_embed_rope_max_period: null + pos_embed_rope_normalize_coords: separate # min, max, separate + pos_embed_rope_shift_coords: null + pos_embed_rope_jitter_coords: null + pos_embed_rope_rescale_coords: null + pos_embed_rope_dtype: 
bf16 + fp8_enabled: False # Convert Linear layers to operate in fp8 precision + fp8_filter: "blocks" # Regex that must appear in module path; empty means everything +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 + in_chans: 3 +distillation: + enabled: false + full_cfg_path: '' + checkpoint_path: '' +multidistillation: + enabled: false +hrft: + enabled: false + checkpoint_path: '' +optim: + epochs: 100 + optimizer: adamw + weight_decay: 0.04 + weight_decay_end: 0.4 + lr: 0.001 + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + dino_head_wd_multiplier: 1.0 + layerwise_decay: 0.9 + multi_tensor_optim: true + dump_fsdp_weights_path: '' + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 + localcrops_subset_of_globalcrops: false + share_color_jitter: false + horizontal_flips: true +evaluation: + eval_period_iterations: 12500 + low_freq_every: 5 + config_files: + high_freq: benchmark_high_frequency.yaml + low_freq: benchmark_low_frequency.yaml +checkpointing: + period: 3750 + max_to_keep: 3 \ No newline at end of file diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8506ec58de482636b88d08b57674ed09e726595f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .adapters import DatasetWithEnumeratedTargets +from .augmentations import DataAugmentationDINO +from .collate import collate_data_and_cast +from .loaders import SamplerType, make_data_loader, make_dataset +from .meta_loaders import CombinedDataLoader +from .masking import MaskingGenerator +from .transforms import make_classification_eval_transform, make_classification_train_transform diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/adapters.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..9b56f975679dde8d3269586418ccd3d5bf367760 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/adapters.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from typing import Any, Optional, Tuple + +from torch.utils.data import Dataset + + +def extend_samples_with_index(dataset_class): + class DatasetWithIndex(dataset_class): + def __init__(self, **kwargs) -> None: + root = dataset_class.get_root() + super().__init__(root=root, **kwargs) + + def __getitem__(self, index: int): + image, target = super().__getitem__(index) + return image, target, index + + return DatasetWithIndex + + +class DatasetWithEnumeratedTargets(Dataset): + """ + If pad_dataset is set, pads based on torch's DistributedSampler implementation, which + with drop_last=False pads the last batch to be a multiple of the world size. 
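+
+    Example: with len(dataset) = 10 and num_replicas = 4, the padded size is
+    4 * ceil(10 / 4) = 12; the two padding samples are returned with index -1
+    so they can be recognized and dropped after gathering across ranks.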
+ https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91 + """ + + def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None): + self._dataset = dataset + self._size = len(self._dataset) + self._padded_size = self._size + self._pad_dataset = pad_dataset + if self._pad_dataset: + assert num_replicas is not None, "num_replicas should be set if pad_dataset is True" + self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas) + + def get_image_relpath(self, index: int) -> str: + assert self._pad_dataset or index < self._size + return self._dataset.get_image_relpath(index % self._size) + + def get_image_data(self, index: int) -> bytes: + assert self._pad_dataset or index < self._size + return self._dataset.get_image_data(index % self._size) + + def get_target(self, index: int) -> Tuple[Any, int]: + target = self._dataset.get_target(index % self._size) + if index >= self._size: + assert self._pad_dataset + return (-1, target) + return (index, target) + + def get_sample_decoder(self, index: int) -> Any: + assert self._pad_dataset or index < self._size + return self._dataset.get_sample_decoder(index % self._size) + + def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: + image, target = self._dataset[index % self._size] + if index >= self._size: + assert self._pad_dataset + return image, (-1, target) + target = index if target is None else target + return image, (index, target) + + def __len__(self) -> int: + return self._padded_size diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/augmentations.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..710cfd1b4de9c4d990279e6e4e4afec4ee8823b3 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/augmentations.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
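+#
+# Usage sketch for the class below (illustrative values taken from
+# ssl_default_config.yaml, not part of the upstream file):
+#
+#     aug = DataAugmentationDINO(
+#         global_crops_scale=(0.32, 1.0),
+#         local_crops_scale=(0.05, 0.32),
+#         local_crops_number=8,
+#         global_crops_size=224,
+#         local_crops_size=96,
+#     )
+#     out = aug(pil_image)  # out["global_crops"]: two (3, 224, 224) tensors,
+#                           # out["local_crops"]: eight (3, 96, 96) tensors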
+
+import logging
+
+import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import v2
+
+from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, GaussianBlur, make_normalize_transform
+
+logger = logging.getLogger("dinov3")
+
+
+class DataAugmentationDINO(object):
+    def __init__(
+        self,
+        global_crops_scale,
+        local_crops_scale,
+        local_crops_number,
+        global_crops_size=224,
+        local_crops_size=96,
+        gram_teacher_crops_size=None,
+        gram_teacher_no_distortions=False,
+        teacher_no_color_jitter=False,
+        local_crops_subset_of_global_crops=False,
+        patch_size=16,
+        share_color_jitter=False,
+        horizontal_flips=True,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+    ):
+        self.global_crops_scale = global_crops_scale
+        self.local_crops_scale = local_crops_scale
+        self.local_crops_number = local_crops_number
+        self.global_crops_size = global_crops_size
+        self.local_crops_size = local_crops_size
+        self.gram_teacher_crops_size = gram_teacher_crops_size
+        self.gram_teacher_no_distortions = gram_teacher_no_distortions
+        self.teacher_no_color_jitter = teacher_no_color_jitter
+        self.local_crops_subset_of_global_crops = local_crops_subset_of_global_crops
+        self.patch_size = patch_size
+        self.share_color_jitter = share_color_jitter
+        self.mean = mean
+        self.std = std
+
+        logger.info("###################################")
+        logger.info("Using data augmentation parameters:")
+        logger.info(f"global_crops_scale: {global_crops_scale}")
+        logger.info(f"local_crops_scale: {local_crops_scale}")
+        logger.info(f"local_crops_number: {local_crops_number}")
+        logger.info(f"global_crops_size: {global_crops_size}")
+        logger.info(f"local_crops_size: {local_crops_size}")
+        logger.info(f"gram_crops_size: {gram_teacher_crops_size}")
+        logger.info(f"gram_teacher_no_distortions: {gram_teacher_no_distortions}")
+        logger.info(f"teacher_no_color_jitter: {teacher_no_color_jitter}")
+        logger.info(f"local_crops_subset_of_global_crops: {local_crops_subset_of_global_crops}")
+        logger.info(f"patch_size if local_crops_subset_of_global_crops: {patch_size}")
+        logger.info(f"share_color_jitter: {share_color_jitter}")
+        logger.info(f"horizontal flips: {horizontal_flips}")
+        logger.info("###################################")
+
+        # Global crops and gram teacher crops can have different sizes. We first take a crop of the maximum size
+        # and then resize it to the desired size for global and gram teacher crops.
+        global_crop_max_size = max(global_crops_size, gram_teacher_crops_size if gram_teacher_crops_size else 0)
+
+        # random resized crop and flip
+        self.geometric_augmentation_global = v2.Compose(
+            [
+                v2.RandomResizedCrop(
+                    global_crop_max_size,
+                    scale=global_crops_scale,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
+            ]
+        )
+
+        resize_global = nn.Identity()  # Resize transform applied to global crops after random crop
+        self.resize_global_post_transf = (
+            nn.Identity()
+        )  # Resize transform applied to global crops after all other transforms
+        self.resize_gram_teacher = None  # Resize transform applied to crops for gram teacher
+        if gram_teacher_crops_size is not None:
+            # All resize transforms will do nothing if the crop size is already the desired size.
+            if gram_teacher_no_distortions:
+                # When no distortions are applied to the gram teacher crop, we can resize before the distortions.
+                # This is the preferred order, because it keeps the image size for the augmentations consistent,
+                # which matters e.g. for GaussianBlur.
+                resize_global = v2.Resize(
+                    global_crops_size,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                )
+            else:
+                # When distortions are applied to the gram teacher crop, we need to resize after them,
+                # because the distortions are shared between the global and gram teacher crops.
+                self.resize_global_post_transf = v2.Resize(
+                    global_crops_size,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                )
+
+            self.resize_gram_teacher = v2.Resize(
+                gram_teacher_crops_size,
+                interpolation=v2.InterpolationMode.BICUBIC,
+            )
+
+        self.geometric_augmentation_local = v2.Compose(
+            [
+                v2.RandomResizedCrop(
+                    local_crops_size,
+                    scale=local_crops_scale,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
+            ]
+        )
+
+        # color distortions / blurring
+        color_jittering = v2.Compose(
+            [
+                v2.RandomApply(
+                    [v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+                    p=0.8,
+                ),
+                v2.RandomGrayscale(p=0.2),
+            ]
+        )
+
+        global_transfo1_extra = GaussianBlur(p=1.0)
+
+        global_transfo2_extra = v2.Compose(
+            [
+                GaussianBlur(p=0.1),
+                v2.RandomSolarize(threshold=128, p=0.2),
+            ]
+        )
+
+        local_transfo_extra = GaussianBlur(p=0.5)
+
+        # normalization
+        self.normalize = v2.Compose(
+            [
+                v2.ToImage(),
+                v2.ToDtype(torch.float32, scale=True),
+                make_normalize_transform(mean=mean, std=std),
+            ]
+        )
+
+        if self.share_color_jitter:
+            self.color_jittering = color_jittering
+            self.global_transfo1 = v2.Compose([resize_global, global_transfo1_extra, self.normalize])
+            self.global_transfo2 = v2.Compose([resize_global, global_transfo2_extra, self.normalize])
+            self.local_transfo = v2.Compose([local_transfo_extra, self.normalize])
+        else:
+            self.global_transfo1 = v2.Compose(
+                [resize_global, color_jittering, global_transfo1_extra, self.normalize]
+            )
+            self.global_transfo2 = v2.Compose(
+                [resize_global, color_jittering, global_transfo2_extra, self.normalize]
+            )
+            self.local_transfo = v2.Compose([color_jittering, local_transfo_extra, self.normalize])
+
+    def __call__(self, image):
+        output = {}
+        output["weak_flag"] = True  # some residual from mugs
+
+        if self.share_color_jitter:
+            image = self.color_jittering(image)
+
+        # global crops:
+        im1_base = self.geometric_augmentation_global(image)
+        global_crop_1_transf = self.global_transfo1(im1_base)
+        global_crop_1 = self.resize_global_post_transf(global_crop_1_transf)
+
+        im2_base = self.geometric_augmentation_global(image)
+        global_crop_2_transf = self.global_transfo2(im2_base)
+        global_crop_2 = self.resize_global_post_transf(global_crop_2_transf)
+
+        output["global_crops"] = [global_crop_1, global_crop_2]
+
+        # global crops for teacher:
+        if self.teacher_no_color_jitter:
+            output["global_crops_teacher"] = [
+                self.normalize(im1_base),
+                self.normalize(im2_base),
+            ]
+        else:
+            output["global_crops_teacher"] = [global_crop_1, global_crop_2]
+
+        if self.gram_teacher_crops_size is not None:
+            # crops for gram teacher:
+            if self.gram_teacher_no_distortions:
+                gram_crop_1 = self.normalize(self.resize_gram_teacher(im1_base))
+                gram_crop_2 = self.normalize(self.resize_gram_teacher(im2_base))
+            else:
+                gram_crop_1 = self.resize_gram_teacher(global_crop_1_transf)
+                gram_crop_2 = self.resize_gram_teacher(global_crop_2_transf)
+            output["gram_teacher_crops"] = [gram_crop_1, gram_crop_2]
+
+        # local crops:
+        if self.local_crops_subset_of_global_crops:
+            _local_crops = [self.local_transfo(im1_base) for _ in range(self.local_crops_number // 2)] + [
+                self.local_transfo(im2_base) for _ in
range(self.local_crops_number // 2) + ] + + local_crops = [] + offsets = [] + gs = self.global_crops_size + ls = self.local_crops_size + for img in _local_crops: + rx, ry = np.random.randint(0, (gs - ls) // self.patch_size, 2) * self.patch_size + local_crops.append(img[:, rx : rx + ls, ry : ry + ls]) + offsets.append((rx, ry)) + + output["local_crops"] = local_crops + output["offsets"] = offsets + else: + local_crops = [ + self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) + ] + output["local_crops"] = local_crops + output["offsets"] = () + + return output diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/collate.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..8470008d456930e1675ecb1abea575aac745fade --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/collate.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import random + +import torch + + +def collate_data_and_cast( + samples_list, + mask_ratio_tuple, + mask_probability, + dtype, + n_tokens=None, + mask_generator=None, + random_circular_shift=False, + local_batch_size=None, +): + n_global_crops = len(samples_list[0][0]["global_crops"]) + n_local_crops = len(samples_list[0][0]["local_crops"]) + + collated_global_crops = torch.stack( + [s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list] + ) # [n_global_crops, B, ...] + collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) + if "gram_teacher_crops" in samples_list[0][0]: + collated_gram_teacher_crops = torch.stack( + [s[0]["gram_teacher_crops"][i] for i in range(n_global_crops) for s in samples_list] + ) # [n_global_crops, B, ...] 
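+        # Layout note: the stacks above iterate over the crop index first and the
+        # sample second, so the first B rows are crop 0 of every sample, the next
+        # B rows crop 1, and so on; a [n_crops * B, C, H, W] tensor can therefore
+        # be unflattened to (n_crops, B, C, H, W), as get_batch_subset() below does.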
+    else:
+        collated_gram_teacher_crops = None
+
+    if local_batch_size is not None:
+        # multi-distillation case: the number of masks differs, because the number of samples
+        # masked differs from the number of samples initially passed into the teacher
+        B = n_global_crops * local_batch_size
+    else:
+        B = len(collated_global_crops)
+    N = n_tokens
+    n_samples_masked = int(B * mask_probability)
+    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
+    upperbound = 0
+    masks_list = []
+    for i in range(0, n_samples_masked):
+        prob_max = probs[i + 1]
+        mask = torch.BoolTensor(mask_generator(int(N * prob_max)))
+        if random_circular_shift:  # apply a random circular shift to the mask
+            shift_x, shift_y = (
+                random.randint(0, mask.shape[0] - 1),
+                random.randint(0, mask.shape[1] - 1),
+            )
+            mask = torch.roll(mask, (shift_x, shift_y), (0, 1))
+        masks_list.append(mask)
+        upperbound += int(N * prob_max)
+    for _ in range(n_samples_masked, B):
+        masks_list.append(torch.BoolTensor(mask_generator(0)))
+
+    random.shuffle(masks_list)
+
+    collated_masks = torch.stack(masks_list).flatten(1)
+    mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+
+    out = {
+        "collated_global_crops": collated_global_crops.to(dtype),
+        "collated_local_crops": collated_local_crops.to(dtype),
+        "collated_masks": collated_masks,
+        "mask_indices_list": mask_indices_list,
+        "masks_weight": masks_weight,
+        "upperbound": upperbound,
+        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+    }
+    if collated_gram_teacher_crops is not None:
+        out["collated_gram_teacher_crops"] = collated_gram_teacher_crops.to(dtype)
+    return out
+
+
+# def get_batch_subset(collated_data_batch, target_bs):
+def get_batch_subset(collated_data_batch, divide_by):
+    old_bs = collated_data_batch["collated_global_crops"].shape[0] // 2
+    target_bs = (old_bs + divide_by - 1) // divide_by
+    collated_global_crops = (
+        collated_data_batch["collated_global_crops"].unflatten(0, (2, old_bs)).narrow(1, 0, target_bs).flatten(0, 1)
+    )
+    collated_local_crops = (
+        collated_data_batch["collated_local_crops"].unflatten(0, (-1, old_bs)).narrow(1, 0, target_bs).flatten(0, 1)
+    )
+
+    masks_old_bs = collated_data_batch["collated_masks"].shape[0] // 2
+    masks_target_bs = masks_old_bs // divide_by
+    collated_masks = (
+        collated_data_batch["collated_masks"]
+        .unflatten(0, (2, masks_old_bs))
+        .narrow(1, 0, masks_target_bs)
+        .flatten(0, 1)
+    )
+    mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+    while mask_indices_list.shape[0] == 0:
+        _unbind = list(collated_data_batch["collated_masks"].unbind(0))
+        random.shuffle(_unbind)
+        _bind = torch.stack(_unbind, dim=0)
+        collated_masks = _bind.unflatten(0, (2, masks_old_bs)).narrow(1, 0, masks_target_bs).flatten(0, 1)
+        mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+    upperbound = collated_data_batch["upperbound"]
+
+    new_batch = {
+        "collated_global_crops": collated_global_crops,
+        "collated_local_crops": collated_local_crops,
+        "collated_masks": collated_masks,
+        "mask_indices_list": mask_indices_list,
+        "masks_weight": masks_weight,
+        "upperbound": upperbound,
+        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+    }
+
+    if "global_batch_size" in
collated_data_batch.keys(): + new_batch["global_batch_size"] = collated_data_batch["global_batch_size"] // divide_by + + return new_batch diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9075e9fecb2fb4a37361df7106e9b6e6a56df41a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .ade20k import ADE20K +from .coco_captions import CocoCaptions +from .image_net import ImageNet +from .image_net_22k import ImageNet22k +from .nyu import NYU diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/ade20k.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..17ebe7320fa4e8138292bbf86529e60cf2339085 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/ade20k.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import os +from enum import Enum +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image + +from .decoders import Decoder, DenseTargetDecoder, ImageDataDecoder +from .extended import ExtendedVisionDataset + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def dirname(self) -> str: + return { + _Split.TRAIN: "training", + _Split.VAL: "validation", + }[self] + + +def _file_to_segmentation_path(file_name: str, segm_base_path: str) -> str: + file_name_noext = os.path.splitext(file_name)[0] + return os.path.join(segm_base_path, file_name_noext + ".png") + + +def _load_segmentation(root: str, split_file_names: List[str]): + segm_base_path = "annotations" + segmentation_paths = [_file_to_segmentation_path(file_name, segm_base_path) for file_name in split_file_names] + return segmentation_paths + + +def _load_file_paths(root: str, split: _Split) -> Tuple[List[str], List[str]]: + with open(os.path.join(root, f"ADE20K_object150_{split.value}.txt")) as f: + split_file_names = sorted(f.read().strip().split("\n")) + + all_segmentation_paths = _load_segmentation(root, split_file_names) + file_names = [os.path.join("images", el) for el in split_file_names] + return file_names, all_segmentation_paths + + +class ADE20K(ExtendedVisionDataset): + Split = Union[_Split] + Labels = Union[Image.Image] + + def __init__( + self, + split: "ADE20K.Split", + root: Optional[str] = None, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder: Decoder = ImageDataDecoder, + target_decoder: Decoder = DenseTargetDecoder, + ) -> None: + super().__init__( + root=root, + transforms=transforms, + transform=transform, + target_transform=target_transform, + image_decoder=image_decoder, + target_decoder=target_decoder, + ) + + self.image_paths, self.target_paths = _load_file_paths(root, split) + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.image_paths[index] + image_full_path = os.path.join(self.root, image_relpath) + with 
open(image_full_path, mode="rb") as f:
+            image_data = f.read()
+        return image_data
+
+    def get_target(self, index: int) -> Any:
+        target_relpath = self.target_paths[index]
+        target_full_path = os.path.join(self.root, target_relpath)
+        with open(target_full_path, mode="rb") as f:
+            target_data = f.read()
+        return target_data
+
+    def __len__(self) -> int:
+        return len(self.image_paths)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/coco_captions.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/coco_captions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9622982ed3ada27fb13b15c297126b0aacadf32b
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/coco_captions.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import json
+import os
+import random
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Union
+
+from .decoders import ImageDataDecoder, TargetDecoder
+from .extended import ExtendedVisionDataset
+
+# Dataset: https://www.kaggle.com/datasets/nikhil7280/coco-image-caption
+
+
+class _Split(Enum):
+    TRAIN = "train"
+    VAL = "val"
+
+
+def read_images_and_captions(root: str, split: _Split) -> List[Dict]:
+    image_dir = None
+    if _Split(split) == _Split.TRAIN:
+        annotations_full_path = os.path.join(
+            root, "annotations_trainval2014/annotations/captions_train2014.json"
+        )
+        image_dir = os.path.join(root, "train2014/train2014")
+    else:
+        annotations_full_path = os.path.join(
+            root, "annotations_trainval2017/annotations/captions_val2017.json"
+        )
+        image_dir = os.path.join(root, "val2017/val2017")
+    with open(annotations_full_path) as f:
+        all_annotations = json.load(f)
+    data = {}
+    for item in all_annotations["images"]:
+        id = item["id"]
+        data[id] = {
+            "id": None,
+            "image": os.path.join(image_dir, item["file_name"]),
+            "captions": [],
+        }
+    for item in all_annotations["annotations"]:
+        data[item["image_id"]]["id"] = item["image_id"]
+        data[item["image_id"]]["captions"].append(item["caption"])
+    return list(data.values())
+
+
+class CocoCaptions(ExtendedVisionDataset):
+    Split = Union[_Split]
+
+    def __init__(
+        self,
+        *,
+        split: "CocoCaptions.Split",
+        root: Optional[str] = None,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(
+            root=root,
+            transforms=transforms,
+            transform=transform,
+            target_transform=target_transform,
+            image_decoder=ImageDataDecoder,
+            target_decoder=TargetDecoder,
+        )
+
+        self.image_captions = read_images_and_captions(root, split)
+
+    def get_image_relpath(self, index: int) -> str:
+        image_path = self.image_captions[index]["image"]
+        return image_path
+
+    def get_image_data(self, index: int) -> bytes:
+        image_path = self.get_image_relpath(index)
+        with open(image_path, mode="rb") as f:
+            image_data = f.read()
+        return image_data
+
+    def get_target(self, index: int) -> str:
+        return random.choice(self.image_captions[index]["captions"])
+
+    def __len__(self) -> int:
+        return len(self.image_captions)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/decoders.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/decoders.py
new file mode 100644
index
0000000000000000000000000000000000000000..44715ee67f03dfc9aca1c2fde76c0d2f8488e198 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/decoders.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from io import BytesIO +from typing import Any + +from PIL import Image + + +class Decoder: + def decode(self) -> Any: + raise NotImplementedError + + +class ImageDataDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f).convert(mode="RGB") + + +class TargetDecoder(Decoder): + def __init__(self, target: Any): + self._target = target + + def decode(self) -> Any: + return self._target + + +class DenseTargetDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/extended.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/extended.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e0d5db29664b70136afed3eddb3f7326d8a59c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/extended.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from typing import Any, Tuple + +from torchvision.datasets import VisionDataset + +from .decoders import Decoder, ImageDataDecoder, TargetDecoder + + +class ExtendedVisionDataset(VisionDataset): + def __init__( + self, + image_decoder: Decoder = ImageDataDecoder, + target_decoder: Decoder = TargetDecoder, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) # type: ignore + self.image_decoder = image_decoder + self.target_decoder = target_decoder + + def get_image_data(self, index: int) -> bytes: + raise NotImplementedError + + def get_target(self, index: int) -> Any: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + try: + image_data = self.get_image_data(index) + image = self.image_decoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can not read image for sample {index}") from e + target = self.get_target(index) + target = self.target_decoder(target).decode() + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + raise NotImplementedError diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f148cecfb95f138db739f6f25cb2c81391b1ad79 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
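Editor's note: the `Decoder`/`ExtendedVisionDataset` pair above separates raw I/O from decoding. A subclass returns raw bytes from `get_image_data()` and an opaque target from `get_target()`, and `__getitem__` pushes both through the configured decoders. A minimal sketch of the pattern (illustration only, not part of the patch; `FlatFolderDataset` and its on-disk layout are hypothetical):

```python
# Hypothetical ExtendedVisionDataset subclass: images in one flat directory,
# with the file name reused as the target. Shows only the decoder hooks.
import os
from typing import Any

from dinov3.data.datasets.extended import ExtendedVisionDataset


class FlatFolderDataset(ExtendedVisionDataset):
    def __init__(self, root: str, **kwargs) -> None:
        super().__init__(root=root, **kwargs)
        self._files = sorted(os.listdir(root))

    def get_image_data(self, index: int) -> bytes:
        # Raw bytes only; the default ImageDataDecoder turns them into an RGB PIL image.
        with open(os.path.join(self.root, self._files[index]), "rb") as f:
            return f.read()

    def get_target(self, index: int) -> Any:
        # Passed through TargetDecoder unchanged.
        return self._files[index]

    def __len__(self) -> int:
        return len(self._files)
```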
+ +import csv +import logging +import os +from enum import Enum +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np + +from .decoders import ImageDataDecoder, TargetDecoder +from .extended import ExtendedVisionDataset + +logger = logging.getLogger("dinov3") +_Target = int + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" # NOTE: torchvision does not support the test split + + @property + def length(self) -> int: + split_lengths = { + _Split.TRAIN: 1_281_167, + _Split.VAL: 50_000, + _Split.TEST: 100_000, + } + return split_lengths[self] + + def get_dirname(self, class_id: Optional[str] = None) -> str: + return self.value if class_id is None else os.path.join(self.value, class_id) + + def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: + dirname = self.get_dirname(class_id) + if self == _Split.TRAIN: + basename = f"{class_id}_{actual_index}" + else: # self in (_Split.VAL, _Split.TEST): + basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" + return os.path.join(dirname, basename + ".JPEG") + + def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: + assert self != _Split.TEST + dirname, filename = os.path.split(image_relpath) + class_id = os.path.split(dirname)[-1] + basename, _ = os.path.splitext(filename) + actual_index = int(basename.split("_")[-1]) + return class_id, actual_index + + +class ImageNet(ExtendedVisionDataset): + Target = Union[_Target] + Split = Union[_Split] + + def __init__( + self, + *, + split: "ImageNet.Split", + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + root=root, + transforms=transforms, + transform=transform, + target_transform=target_transform, + image_decoder=ImageDataDecoder, + target_decoder=TargetDecoder, + ) + self._extra_root = extra + self._split = split + + self._entries = None + self._class_ids = None + self._class_names = None + + @property + def split(self) -> "ImageNet.Split": + return self._split + + def _get_extra_full_path(self, extra_path: str) -> str: + return os.path.join(self._extra_root, extra_path) + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_full_path = self._get_extra_full_path(extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_full_path = self._get_extra_full_path(extra_path) + os.makedirs(self._extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _entries_path(self) -> str: + return f"entries-{self._split.value.upper()}.npy" + + @property + def _class_ids_path(self) -> str: + return f"class-ids-{self._split.value.upper()}.npy" + + @property + def _class_names_path(self) -> str: + return f"class-names-{self._split.value.upper()}.npy" + + def _get_entries(self) -> np.ndarray: + if self._entries is None: + self._entries = self._load_extra(self._entries_path) + assert self._entries is not None + return self._entries + + def _get_class_ids(self) -> np.ndarray: + if self._split == _Split.TEST: + raise AssertionError("Class IDs are not available in TEST split") + if self._class_ids is None: + self._class_ids = self._load_extra(self._class_ids_path) + assert self._class_ids is not None + return self._class_ids + + def _get_class_names(self) -> np.ndarray: + if self._split == _Split.TEST: + raise AssertionError("Class names are not available in TEST 
split") + if self._class_names is None: + self._class_names = self._load_extra(self._class_names_path) + assert self._class_names is not None + return self._class_names + + def find_class_id(self, class_index: int) -> str: + class_ids = self._get_class_ids() + return str(class_ids[class_index]) + + def find_class_name(self, class_index: int) -> str: + class_names = self._get_class_names() + return str(class_names[class_index]) + + def get_image_data(self, index: int) -> bytes: + entries = self._get_entries() + actual_index = entries[index]["actual_index"] + + class_id = self.get_class_id(index) + + image_relpath = self.split.get_image_relpath(actual_index, class_id) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Optional[Target]: + entries = self._get_entries() + class_index = entries[index]["class_index"] + return None if self.split == _Split.TEST else int(class_index) + + def get_targets(self) -> Optional[np.ndarray]: + entries = self._get_entries() + return None if self.split == _Split.TEST else entries["class_index"] + + def get_class_id(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_id = entries[index]["class_id"] + return None if self.split == _Split.TEST else str(class_id) + + def get_class_name(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_name = entries[index]["class_name"] + return None if self.split == _Split.TEST else str(class_name) + + def __len__(self) -> int: + entries = self._get_entries() + assert len(entries) == self.split.length + return len(entries) + + def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: + labels_full_path = os.path.join(self.root, labels_path) + labels = [] + + try: + with open(labels_full_path, "r") as f: + reader = csv.reader(f) + for row in reader: + class_id, class_name = row + labels.append((class_id, class_name)) + except OSError as e: + raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e + + return labels + + def _dump_entries(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + dataset = None + sample_count = split.length + max_class_id_length, max_class_name_length = 0, 0 + else: + labels_path = "labels.txt" + logger.info(f'loading labels from "{labels_path}"') + labels = self._load_labels(labels_path) + + # NOTE: Using torchvision ImageFolder for consistency + from torchvision.datasets import ImageFolder + + dataset_root = os.path.join(self.root, split.get_dirname()) + dataset = ImageFolder(dataset_root) + sample_count = len(dataset) + max_class_id_length, max_class_name_length = -1, -1 + for sample in dataset.samples: + _, class_index = sample + class_id, class_name = labels[class_index] + max_class_id_length = max(len(class_id), max_class_id_length) + max_class_name_length = max(len(class_name), max_class_name_length) + + dtype = np.dtype( + [ + ("actual_index", " old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + actual_index = index + 1 + class_index = np.uint32(-1) + class_id, class_name = "", "" + entries_array[index] = (actual_index, class_index, class_id, class_name) + else: + class_names = {class_id: class_name for class_id, class_name in labels} + + assert dataset + old_percent = -1 + for index in range(sample_count): + percent = 100 * (index + 1) // sample_count + if percent > old_percent: + logger.info(f"creating entries: {percent}%") + 
old_percent = percent + + image_full_path, class_index = dataset.samples[index] + image_relpath = os.path.relpath(image_full_path, self.root) + class_id, actual_index = split.parse_image_relpath(image_relpath) + class_name = class_names[class_id] + entries_array[index] = (actual_index, class_index, class_id, class_name) + + logger.info(f'saving entries to "{self._entries_path}"') + self._save_extra(entries_array, self._entries_path) + + def _dump_class_ids_and_names(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + return + + entries_array = self._load_extra(self._entries_path) + + max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + max_class_name_length = max(len(str(class_name)), max_class_name_length) + + class_count = max_class_index + 1 + class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") + class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + class_ids_array[class_index] = class_id + class_names_array[class_index] = class_name + + logger.info(f'saving class IDs to "{self._class_ids_path}"') + self._save_extra(class_ids_array, self._class_ids_path) + + logger.info(f'saving class names to "{self._class_names_path}"') + self._save_extra(class_names_array, self._class_names_path) + + def dump_extra(self) -> None: + self._dump_entries() + self._dump_class_ids_and_names() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net_22k.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net_22k.py new file mode 100644 index 0000000000000000000000000000000000000000..f226f2300f6dc8f847e853f763b9ff0e5a543cb0 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/image_net_22k.py @@ -0,0 +1,301 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
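Editor's note: `ImageNet` above never walks the image tree at training time. `dump_extra()` is a one-off preprocessing step that writes the `entries-*.npy` / `class-*.npy` arrays into the `extra` directory, and every later lookup memory-maps them. A hedged usage sketch (the paths are placeholders, not from the patch):

```python
# One-time preparation of the "extra" metadata, then normal loading.
from dinov3.data.datasets import ImageNet

root = "/datasets/imagenet"         # expects train/, val/ and labels.txt underneath
extra = "/datasets/imagenet-extra"  # entries-TRAIN.npy etc. are written here

# First run only: builds entries-TRAIN.npy, class-ids-TRAIN.npy, class-names-TRAIN.npy.
ImageNet(split=ImageNet.Split.TRAIN, root=root, extra=extra).dump_extra()

# Later runs memory-map the cached arrays (np.load(..., mmap_mode="r") above).
dataset = ImageNet(split=ImageNet.Split.TRAIN, root=root, extra=extra)
image, target = dataset[0]  # PIL image, int class index
```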
+ +import os +import warnings +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from gzip import GzipFile +from io import BytesIO +from mmap import ACCESS_READ, mmap +from typing import Any, Callable, List, Optional, Set, Tuple + +import numpy as np + +from .extended import ExtendedVisionDataset + +_Labels = int + +_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors + + +@dataclass +class _ClassEntry: + block_offset: int + maybe_filename: Optional[str] = None + + +@dataclass +class _Entry: + class_index: int # noqa: E701 + start_offset: int + end_offset: int + filename: str + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def length(self) -> int: + return { + _Split.TRAIN: 11_797_647, + _Split.VAL: 561_050, + }[self] + + def entries_path(self): + return f"imagenet21kp_{self.value}.txt" + + +def _get_tarball_path(class_id: str) -> str: + return f"{class_id}.tar" + + +def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): + @lru_cache(maxsize=mmap_cache_size) + def _mmap_tarball(class_id: str) -> mmap: + tarball_path = _get_tarball_path(class_id) + tarball_full_path = os.path.join(tarballs_root, tarball_path) + with open(tarball_full_path) as f: + return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + return _mmap_tarball + + +class ImageNet22k(ExtendedVisionDataset): + _GZIPPED_INDICES: Set[int] = { + 841_545, + 1_304_131, + 2_437_921, + 2_672_079, + 2_795_676, + 2_969_786, + 6_902_965, + 6_903_550, + 6_903_628, + 7_432_557, + 7_432_589, + 7_813_809, + 8_329_633, + 10_296_990, + 10_417_652, + 10_492_265, + 10_598_078, + 10_782_398, + 10_902_612, + 11_203_736, + 11_342_890, + 11_397_596, + 11_589_762, + 11_705_103, + 12_936_875, + 13_289_782, + } + Labels = _Labels + + def __init__( + self, + *, + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, + ) -> None: + # NOTE: keyword arguments are required here; ExtendedVisionDataset takes the + # decoders as its leading positional parameters. + super().__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform) + self._extra_root = extra + + entries_path = self._get_entries_path(root) + self._entries = self._load_extra(entries_path) + + class_ids_path = self._get_class_ids_path(root) + self._class_ids = self._load_extra(class_ids_path) + + self._gzipped_indices = ImageNet22k._GZIPPED_INDICES + self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) + + def _get_entries_path(self, root: Optional[str] = None) -> str: + return "entries.npy" + + def _get_class_ids_path(self, root: Optional[str] = None) -> str: + return "class-ids.npy" + + def _find_class_ids(self, path: str) -> List[str]: + class_ids = [] + + with os.scandir(path) as entries: + for entry in entries: + root, ext = os.path.splitext(entry.name) + if ext != ".tar": + continue + class_ids.append(root) + + return sorted(class_ids) + + def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: + root = root if root is not None else self.root # no get_root() helper exists in this patch + entries: List[_Entry] = [] + class_ids = self._find_class_ids(root) + + for class_index, class_id in enumerate(class_ids): + path = os.path.join(root, "blocks", f"{class_id}.log") + class_entries = [] + + try: + with open(path) as f: + for line in f: + line = line.rstrip() + block, filename = line.split(":") + block_offset = int(block[6:]) + filename = filename[1:] + + maybe_filename = None + if filename != "** Block of NULs **": + maybe_filename = filename + _, ext = 
os.path.splitext(filename) + # assert ext == ".JPEG" + + class_entry = _ClassEntry(block_offset, maybe_filename) + class_entries.append(class_entry) + except OSError as e: + raise RuntimeError(f'can not read blocks file "{path}"') from e + + assert class_entries[-1].maybe_filename is None + + for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): + assert class_entry1.block_offset <= class_entry2.block_offset + start_offset = 512 * class_entry1.block_offset + end_offset = 512 * class_entry2.block_offset + assert class_entry1.maybe_filename is not None + filename = class_entry1.maybe_filename + entry = _Entry(class_index, start_offset, end_offset, filename) + # Skip invalid image files (PIL throws UnidentifiedImageError) + if filename == "n06470073_47249.JPEG": + continue + entries.append(entry) + + return entries, class_ids + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + os.makedirs(extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _tarballs_root(self) -> str: + return self.root + + def find_class_id(self, class_index: int) -> str: + return str(self._class_ids[class_index]) + + def get_image_data(self, index: int) -> bytes: + entry = self._entries[index] + class_id = entry["class_id"] + class_mmap = self._mmap_tarball(class_id) + + start_offset, end_offset = entry["start_offset"], entry["end_offset"] + try: + mapped_data = class_mmap[start_offset:end_offset] + data = mapped_data[512:] # Skip entry header block + + if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): + assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" + with GzipFile(fileobj=BytesIO(data)) as g: + data = g.read() + except Exception as e: + raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e + + return data + + def get_target(self, index: int) -> Any: + return int(self._entries[index]["class_index"]) + + def get_targets(self) -> np.ndarray: + return self._entries["class_index"] + + def get_class_id(self, index: int) -> str: + return str(self._entries[index]["class_id"]) + + def get_class_ids(self) -> np.ndarray: + return self._entries["class_id"] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().__getitem__(index) + + def __len__(self) -> int: + return len(self._entries) + + def _dump_entries(self, *args, **kwargs) -> None: + entries, class_ids = self._load_entries_class_ids(*args, **kwargs) + + max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 + for entry in entries: + class_id = class_ids[entry.class_index] + max_class_index = max(entry.class_index, max_class_index) + max_class_id_length = max(len(class_id), max_class_id_length) + max_filename_length = max(len(entry.filename), max_filename_length) + + dtype = np.dtype( + [ + ("class_index", "<u4"), + ("class_id", f"U{max_class_id_length}"), + ("start_offset", "<u4"), + ("end_offset", "<u4"), + ("filename", f"U{max_filename_length}"), + ] + ) + sample_count = len(entries) + entries_array = np.empty(sample_count, dtype=dtype) + for i, entry in enumerate(entries): + class_index = entry.class_index + class_id = class_ids[class_index] + start_offset = entry.start_offset + end_offset = entry.end_offset + filename = entry.filename + entries_array[i] = (class_index, class_id, start_offset, end_offset, filename) + + entries_path = self._get_entries_path(*args, **kwargs) + self._save_extra(entries_array, entries_path) + + def _dump_class_ids(self, *args, **kwargs) -> None: + entries_path = self._get_entries_path(*args, **kwargs) + entries_array = self._load_extra(entries_path) + + max_class_id_length, max_class_index = -1, -1 + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + max_class_index = max(int(class_index), max_class_index) + 
max_class_id_length = max(len(str(class_id)), max_class_id_length) + + class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + class_ids_array[class_index] = class_id + class_ids_path = self._get_class_ids_path(*args, **kwargs) + self._save_extra(class_ids_array, class_ids_path) + + def _dump_extra(self, *args, **kwargs) -> None: + self._dump_entries(*args, **kwargs) + self._dump_class_ids(*args, **kwargs) + + def dump_extra(self, root: Optional[str] = None) -> None: + return self._dump_extra(root) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/nyu.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/nyu.py new file mode 100644 index 0000000000000000000000000000000000000000..33b58986571637a71264c49b3389978f3e5879a4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/datasets/nyu.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import os +from enum import Enum +from typing import Any, Callable, Optional, Union + +from PIL import Image + +from .decoders import Decoder, DenseTargetDecoder, ImageDataDecoder +from .extended import ExtendedVisionDataset + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" + + @property + def data_fname(self) -> str: + _DATA_FNAMES = { + _Split.TRAIN: "nyu_train.txt", + _Split.VAL: "nyu_test.txt", + _Split.TEST: "nyu_test.txt", + } + return _DATA_FNAMES[self] + + +class NYU(ExtendedVisionDataset): + Split = Union[_Split] + Labels = Union[Image.Image] + + def __init__( + self, + *, + split: "NYU.Split", + root: Optional[str] = None, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder: Decoder = ImageDataDecoder, + target_decoder: Decoder = DenseTargetDecoder, + ) -> None: + super().__init__( + root=root, + transforms=transforms, + transform=transform, + target_transform=target_transform, + image_decoder=image_decoder, + target_decoder=target_decoder, + ) + self.image_paths = [] + self.target_paths = [] + with open(os.path.join(root, split.data_fname)) as f: + lines = f.readlines() + lines = sorted(lines) + for line in lines: + image_relpath, depth_relpath, _ = line.split() + image_relpath = image_relpath.strip("/") + depth_relpath = depth_relpath.strip("/") + self.image_paths.append(image_relpath) + self.target_paths.append(depth_relpath) + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.image_paths[index] + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Any: + target_relpath = self.target_paths[index] + target_full_path = os.path.join(self.root, target_relpath) + with open(target_full_path, mode="rb") as f: + target_data = f.read() + return target_data + + def __len__(self) -> int: + return len(self.image_paths) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/loaders.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..6a55ad6cfdd77023ca8663886462761b8f589009 --- /dev/null +++ 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/loaders.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from enum import Enum +from typing import Any, Callable, List, Optional, TypeVar + +import torch +from torch.utils.data import Sampler + +from .datasets import ADE20K, CocoCaptions, ImageNet, ImageNet22k, NYU +from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler + +logger = logging.getLogger("dinov3") + + +class SamplerType(Enum): + DISTRIBUTED = 0 + EPOCH = 1 + INFINITE = 2 + SHARDED_INFINITE = 3 + SHARDED_INFINITE_NEW = 4 + + +def _make_bool_str(b: bool) -> str: + return "yes" if b else "no" + + +def _make_sample_transform( + image_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +): + def transform(sample): + image, target = sample + if image_transform is not None: + image = image_transform(image) + if target_transform is not None: + target = target_transform(target) + return image, target + + return transform + + +def _parse_dataset_str(dataset_str: str): + tokens = dataset_str.split(":") + + name = tokens[0] + kwargs = {} + + for token in tokens[1:]: + key, value = token.split("=") + assert key in ("root", "extra", "split") + kwargs[key] = value + + if name == "ImageNet": + class_ = ImageNet + if "split" in kwargs: + kwargs["split"] = ImageNet.Split[kwargs["split"]] + elif name == "ImageNet22k": + class_ = ImageNet22k + elif name == "ADE20K": + class_ = ADE20K + if "split" in kwargs: + kwargs["split"] = ADE20K.Split[kwargs["split"]] + elif name == "CocoCaptions": + class_ = CocoCaptions + if "split" in kwargs: + kwargs["split"] = CocoCaptions.Split[kwargs["split"]] + elif name == "NYU": + class_ = NYU + if "split" in kwargs: + kwargs["split"] = NYU.Split[kwargs["split"]] + else: + raise ValueError(f'Unsupported dataset "{name}"') + + return class_, kwargs + + +def make_dataset( + *, + dataset_str: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, +): + """ + Creates a dataset with the specified parameters. + + Args: + dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). + transform: A transform to apply to images. + target_transform: A transform to apply to targets. + transforms: A transform to apply to both images and targets. + + Returns: + The created dataset. + """ + logger.info(f'using dataset: "{dataset_str}"') + + class_, kwargs = _parse_dataset_str(dataset_str) + dataset = class_(transform=transform, target_transform=target_transform, transforms=transforms, **kwargs) + + logger.info(f"# of dataset samples: {len(dataset):,d}") + + # Aggregated datasets do not expose (yet) these attributes, so add them. 
+ if not hasattr(dataset, "transform"): + dataset.transform = transform + if not hasattr(dataset, "target_transform"): + dataset.target_transform = target_transform + if not hasattr(dataset, "transforms"): + dataset.transforms = transforms + + return dataset + + +def _make_sampler( + *, + dataset, + type: Optional[SamplerType] = None, + shuffle: bool = False, + seed: int = 0, + size: int = -1, + advance: int = 0, +) -> Optional[Sampler]: + sample_count = len(dataset) + + if type == SamplerType.INFINITE: + logger.info("sampler: infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + return InfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + ) + elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): + logger.info("sampler: sharded infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW + return ShardedInfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, + ) + elif type == SamplerType.EPOCH: + logger.info("sampler: epoch") + if advance > 0: + raise NotImplementedError("sampler advance > 0 is not supported") + size = size if size > 0 else sample_count + logger.info(f"# of samples / epoch: {size:,d}") + return EpochSampler( + size=size, + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + ) + elif type == SamplerType.DISTRIBUTED: + logger.info("sampler: distributed") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + if advance > 0: + raise ValueError("sampler advance > 0 is invalid") + return torch.utils.data.DistributedSampler( + dataset=dataset, + shuffle=shuffle, + seed=seed, + drop_last=False, + ) + + logger.info("sampler: none") + return None + + +T = TypeVar("T") + + +def make_data_loader( + *, + dataset, + batch_size: int, + num_workers: int, + shuffle: bool = True, + seed: int = 0, + sampler_type: Optional[SamplerType] = SamplerType.INFINITE, + sampler_size: int = -1, + sampler_advance: int = 0, + drop_last: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable[[List[T]], Any]] = None, + worker_init_fn: Optional[Callable[[List[T]], Any]] = None, +): + """ + Creates a data loader with the specified parameters. + + Args: + dataset: A dataset (third party, LaViDa or WebDataset). + batch_size: The size of batches to generate. + num_workers: The number of workers to use. + shuffle: Whether to shuffle samples. + seed: The random seed to use. + sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. + sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. + sampler_advance: How many samples to skip (when applicable). + drop_last: Whether the last non-full batch of data should be dropped. + persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. + collate_fn: Function that performs batch collation + worker_init_fn: Optional init function for each dataloader worker. 
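+ + Example (editor's sketch, not part of the original patch; assumes a dataset built with make_dataset): + + loader = make_data_loader( + dataset=dataset, + batch_size=64, + num_workers=8, + shuffle=True, + sampler_type=SamplerType.SHARDED_INFINITE, + ) + for images, targets in loader: # infinite sampler: iterate until a step budget is reached + ...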
+ """ + + sampler = _make_sampler( + dataset=dataset, + type=sampler_type, + shuffle=shuffle, + seed=seed, + size=sampler_size, + advance=sampler_advance, + ) + + logger.info("using PyTorch data loader") + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + worker_init_fn=worker_init_fn, + ) + + try: + logger.info(f"# of batches: {len(data_loader):,d}") + except TypeError: # data loader has no length + logger.info("infinite data loader") + return data_loader diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/masking.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..691c31142c985dc9e71c5923fcaa60e199f21e83 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/masking.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +import random + +import numpy as np + + +class MaskingGenerator: + def __init__( + self, + input_size, + num_masking_patches=None, + min_num_patches=4, + max_num_patches=None, + min_aspect=0.3, + max_aspect=None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.min_num_patches, + self.max_num_patches, + self.num_masking_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, num_masking_patches=0): + mask = np.zeros(shape=self.get_shape(), dtype=bool) + mask_count = 0 + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return self.complete_mask_randomly(mask, num_masking_patches) + + def complete_mask_randomly(self, mask, num_masking_patches): + shape = mask.shape + m2 = mask.flatten() + to_add = 
np.random.choice(np.where(~m2)[0], size=num_masking_patches - m2.sum(), replace=False) + m2[to_add] = True + return m2.reshape(shape) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/meta_loaders.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/meta_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..b82faa3be1605050b14089c208b827c7d1e52f14 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/meta_loaders.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from typing import Any, Iterable, Iterator, List, Optional, Tuple, TypeVar + +import numpy as np + +logger = logging.getLogger("dinov3") +Loader = Iterable[List[Any]] +T = TypeVar("T") + + +class CombinedDataLoader: + """ + Combines data loaders using the provided sampling ratios + """ + + GLOBAL_HOMOGENEOUS = 0 + LOCAL_HOMOGENEOUS = 1 + + def __init__( + self, + loaders_with_ratios: Iterable[Tuple[Loader, float]], + batch_size: int, + combining_mode: int = 1, + seed: int = 65537, + name: Optional[str] = None, + logging_period: int = 100, + ): + if combining_mode not in [self.GLOBAL_HOMOGENEOUS, self.LOCAL_HOMOGENEOUS]: + raise ValueError(f"Unsupported value of combining_mode ({combining_mode})") + loaders, ratios = zip(*loaders_with_ratios) + assert np.all([loader.batch_size == batch_size for loader in loaders]), ( + f"All individual loaders must have the same batch size as the combined data loader for combining_mode={combining_mode}" + ) + self.loaders = loaders + self.ratios = ratios + self.batch_size = batch_size + self.combining_mode = combining_mode + self.initial_seed = seed + self.name = name if name is not None else "" + self.logging_period = logging_period + if combining_mode == self.GLOBAL_HOMOGENEOUS: + logger.info(f"Initialize CDL {self.name} with seed={seed}") + self.seed = seed + self.rng = np.random.default_rng(seed=seed) + else: + logger.info(f"Initialize CDL {self.name} with random seed") + self.seed = 0 + self.rng = np.random.default_rng() + self.loader_count = np.zeros(len(self.loaders)) + + def homogeneous_iterator(self) -> Iterator[List[Any]]: + iteration = 0 + iters = [iter(loader) for loader in self.loaders] + while True: + iteration += 1 + try: + idx = self.rng.choice(len(self.loaders), p=self.ratios) + self.loader_count[idx] += 1 + if iteration % self.logging_period == 0: + logger.info(f"Empirical ratios: CDL {self.name} {self.loader_count / self.loader_count.sum()}") + yield next(iters[idx]) + except StopIteration: + break + + def heterogeneous_iterator(self) -> Iterator[List[Any]]: + pass + + def __iter__(self) -> Iterator[List[Any]]: + if self.combining_mode in [self.GLOBAL_HOMOGENEOUS, self.LOCAL_HOMOGENEOUS]: + logger.info(f"Using homogeneous iterator for CDL {self.name}") + return self.homogeneous_iterator() + else: + raise ValueError(f"Unsupported value of combining_mode ({self.combining_mode})") diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/samplers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b58145ba547909d9fcbeeff0ada3438f351f3bf9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/samplers.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
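Editor's note: `CombinedDataLoader` draws each batch wholesale from one child loader, picking the loader index from the given ratios, so mixing happens at batch granularity rather than sample granularity. A hedged wiring sketch (`dataset_a`, `dataset_b` and the 70/30 split are assumptions, not from the patch):

```python
# Mixing two loaders 70/30; every batch comes entirely from one loader.
from dinov3.data.loaders import make_data_loader
from dinov3.data.meta_loaders import CombinedDataLoader

# dataset_a / dataset_b: any map-style datasets, assumed to exist.
loader_a = make_data_loader(dataset=dataset_a, batch_size=32, num_workers=4)
loader_b = make_data_loader(dataset=dataset_b, batch_size=32, num_workers=4)

combined = CombinedDataLoader(
    loaders_with_ratios=[(loader_a, 0.7), (loader_b, 0.3)],
    batch_size=32,  # must equal every child loader's batch_size (asserted in __init__)
    combining_mode=CombinedDataLoader.LOCAL_HOMOGENEOUS,
    name="train-mix",
)

for batch in combined:  # stops once a chosen child loader is exhausted
    ...
```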
+# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import itertools +import warnings +from typing import Any, Optional + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +from dinov3.distributed import get_rank, get_world_size + + +class EpochSampler(Sampler): + def __init__( + self, + *, + size: int, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + ): + self._size = size + self._sample_count = sample_count + self._shuffle = shuffle + self._seed = seed + self._start = get_rank() if start is None else start + self._step = get_world_size() if step is None else step + self._epoch = 0 + + def __iter__(self): + count = (self._size + self._sample_count - 1) // self._sample_count + tiled_indices = np.tile(np.arange(self._sample_count), count) + if self._shuffle: + seed = self._seed * self._epoch if self._seed != 0 else self._epoch + rng = np.random.default_rng(seed) + iterable = rng.choice(tiled_indices, self._size, replace=False) + else: + iterable = tiled_indices[: self._size] + + yield from itertools.islice(iterable, self._start, None, self._step) + + def __len__(self): + return (self._size - self._start + self._step - 1) // self._step + + def set_epoch(self, epoch): + self._epoch = epoch + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +class InfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = get_rank() if start is None else start + self._step = get_world_size() if step is None else step + self._advance = advance + + def __iter__(self): + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator().manual_seed(self._seed) + + while True: + iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) + yield from itertools.islice(iterable, self._start, None, self._step) + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. 
+def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}", stacklevel=1) + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}", stacklevel=1) + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = get_rank() if start is None else start + self._step = get_world_size() if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 1 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/transforms.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3104c827b5144a4bd87b7d997a191595d703e107 --- /dev/null +++ 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/data/transforms.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from typing import Sequence + +import torch +from torchvision.transforms import v2 + +logger = logging.getLogger("dinov3") + + +def make_interpolation_mode(mode_str: str) -> v2.InterpolationMode: + return {mode.value: mode for mode in v2.InterpolationMode}[mode_str] + + +class GaussianBlur(v2.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = v2.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +CROP_DEFAULT_SIZE = 224 +RESIZE_DEFAULT_SIZE = int(256 * CROP_DEFAULT_SIZE / 224) + + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> v2.Normalize: + return v2.Normalize(mean=mean, std=std) + + +def make_base_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> v2.Compose: + return v2.Compose( + [ + v2.ToDtype(torch.float32, scale=True), + make_normalize_transform(mean=mean, std=std), + ] + ) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = CROP_DEFAULT_SIZE, + interpolation=v2.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [v2.ToImage(), v2.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(v2.RandomHorizontalFlip(hflip_prob)) + transforms_list.append(make_base_transform(mean, std)) + transform = v2.Compose(transforms_list) + logger.info(f"Built classification train transform\n{transform}") + return transform + + +def make_resize_transform( + *, + resize_size: int, + resize_square: bool = False, + resize_large_side: bool = False, # Set the larger side to resize_size instead of the smaller + interpolation: v2.InterpolationMode = v2.InterpolationMode.BICUBIC, +): + assert not (resize_square and resize_large_side), "These two options can not be set together" + if resize_square: + logger.info("resizing image as a square") + size = (resize_size, resize_size) + transform = v2.Resize(size=size, interpolation=interpolation) + return transform + elif resize_large_side: + logger.info("resizing based on large side") + transform = v2.Resize(size=None, max_size=resize_size, interpolation=interpolation) + return transform + else: + transform = v2.Resize(resize_size, interpolation=interpolation) + return transform + + +# Derived from make_classification_eval_transform() with more control over resize and crop +def make_eval_transform( + *, + resize_size: int = RESIZE_DEFAULT_SIZE, + crop_size: int = CROP_DEFAULT_SIZE, + resize_square: bool = False, + resize_large_side: bool = False, # Set the larger side to resize_size instead of the 
smaller + interpolation: v2.InterpolationMode = v2.InterpolationMode.BICUBIC, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> v2.Compose: + transforms_list = [v2.ToImage()] + resize_transform = make_resize_transform( + resize_size=resize_size, + resize_square=resize_square, + resize_large_side=resize_large_side, + interpolation=interpolation, + ) + transforms_list.append(resize_transform) + if crop_size: + transforms_list.append(v2.CenterCrop(crop_size)) + transforms_list.append(make_base_transform(mean, std)) + transform = v2.Compose(transforms_list) + logger.info(f"Built eval transform\n{transform}") + return transform + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = RESIZE_DEFAULT_SIZE, + crop_size: int = CROP_DEFAULT_SIZE, + interpolation=v2.InterpolationMode.BICUBIC, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> v2.Compose: + return make_eval_transform( + resize_size=resize_size, + crop_size=crop_size, + interpolation=interpolation, + mean=mean, + std=std, + resize_square=False, + resize_large_side=False, + ) + + +def voc2007_classification_target_transform(label, n_categories=20): + one_hot = torch.zeros(n_categories, dtype=int) + for instance in label.instances: + one_hot[instance.category_id] = True + return one_hot + + +def imaterialist_classification_target_transform(label, n_categories=294): + one_hot = torch.zeros(n_categories, dtype=int) + one_hot[label.attributes] = True + return one_hot + + +def get_target_transform(dataset_str): + if "VOC2007" in dataset_str: + return voc2007_classification_target_transform + elif "IMaterialist" in dataset_str: + return imaterialist_classification_target_transform + return None diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c6e4f69488f492cc13d5584b1ddc17d7b3a063a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# isort: skip_file +from .torch_distributed_wrapper import ( + disable_distributed as disable, + enable_distributed as enable, + get_default_process_group, + get_process_subgroup, + get_rank, + get_subgroup_rank, + get_subgroup_size, + get_world_size, + is_distributed_enabled as is_enabled, + is_main_process, + is_subgroup_main_process, + new_subgroups, + save_in_main_process, + TorchDistributedEnvironment, +) + +from .torch_distributed_primitives import gather_all_tensors, reduce_dict diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_primitives.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..2cac0438ebcca4b0a496894be638641ea5a28471 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_primitives.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
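Editor's note: the builders in transforms.py compose `ToImage`, a resize, an optional center crop, and float32 conversion plus ImageNet normalization. A usage sketch (the file name and the dense-task sizes are placeholders, not from the patch):

```python
# Eval-time preprocessing with the builders from dinov3.data.transforms.
from PIL import Image

from dinov3.data.transforms import make_classification_eval_transform, make_eval_transform

# Classic ImageNet eval: resize shorter side to 256, center-crop to 224, normalize.
transform = make_classification_eval_transform()

# Square resize with no crop (crop_size=0 skips CenterCrop), e.g. for dense tasks.
dense_transform = make_eval_transform(resize_size=512, crop_size=0, resize_square=True)

image = Image.open("example.jpg").convert("RGB")
tensor = transform(image)  # float32 tensor of shape (3, 224, 224), normalized
```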
+# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.nn import functional as F + +from .torch_distributed_wrapper import get_default_process_group, get_world_size + + +def reduce_dict(input_dict: Dict[Any, torch.Tensor], average: bool = True) -> Dict[Any, torch.Tensor]: + """ + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dictionary with the same fields as + the input dictionary, after reduction. + + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + """ + world_size = get_world_size() + if world_size <= 1: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # Sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + stacked_values = torch.stack(values, dim=0) + dist.all_reduce(stacked_values) + if average: + stacked_values /= world_size + reduced_dict = {k: v for k, v in zip(names, stacked_values)} + return reduced_dict + + +def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]: + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + dist.all_gather(gathered_result, result, group) + return gathered_result + + +def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]: + """ + Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py + Gather all tensors from several ddp processes onto a list that is broadcasted to all processes. + + Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case + tensors are padded, gathered and then trimmed to secure equal workload for all processes. + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + list with size equal to the process group where element i corresponds to result tensor from process i + """ + if group is None: + group = get_default_process_group() + + # convert tensors to contiguous format + result = result.contiguous() + + world_size = get_world_size() + dist.barrier(group=group) + + # if the tensor is scalar, things are easy + if result.ndim == 0: + return _simple_gather_all_tensors(result, group, world_size) + + # 1. Gather sizes of all tensors + local_size = torch.tensor(result.shape, device=result.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(local_sizes, local_size, group=group) + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 2. If shapes are all the same, then do a simple gather: + if all_sizes_equal: + return _simple_gather_all_tensors(result, group, world_size) + + # 3. 
If not, we need to pad each local tensor to maximum size, gather and then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = F.pad(result, pad_dims) + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + dist.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_wrapper.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..36fbe1b7c8dac03c37a042964d74af51ac3c3e7a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/distributed/torch_distributed_wrapper.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import os +import random +import socket +import subprocess +from datetime import timedelta +from enum import Enum +from typing import List, Sequence + +import torch +import torch.distributed as dist + +logger = logging.getLogger("dinov3") + +_DEFAULT_PROCESS_GROUP = None +_PROCESS_SUBGROUP = None +_BUILTIN_PRINT = None + + +def is_distributed_enabled() -> bool: + """ + Returns: + True if distributed training is enabled. + """ + return dist.is_available() and dist.is_initialized() + + +def get_rank(group=None) -> int: + """ + Returns: + The rank of the current process within the specified process group. + """ + if not is_distributed_enabled(): + return 0 + return dist.get_rank(group=group) + + +def get_world_size(group=None) -> int: + """ + Returns: + The number of processes in the specified process group. + """ + if not is_distributed_enabled(): + return 1 + return dist.get_world_size(group=group) + + +def is_main_process(group=None) -> bool: + """ + Returns: + True if the current process is the main one in the specified process group. + """ + return get_rank(group) == 0 + + +def save_in_main_process(*args, **kwargs) -> None: + """Utility function to save only from the main process.""" + group = kwargs.pop("group", None) + if not is_main_process(group): + return + torch.save(*args, **kwargs) + + +def _restrict_print_to_main_process() -> None: + """This function disables printing when not in the main process.""" + import builtins as __builtin__ + + global _BUILTIN_PRINT + _BUILTIN_PRINT = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main_process() or force: + _BUILTIN_PRINT(*args, **kwargs) + + __builtin__.print = print + + +def _get_master_port(seed: int = 0) -> int: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(seed) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + return int(master_port_str) + + +def _get_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # A "" host address means INADDR_ANY i.e. binding to all interfaces. + # Note this is not compatible with IPv6. 
+ s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +def _parse_slurm_node_list(s: str) -> List[str]: + return subprocess.check_output(["scontrol", "show", "hostnames", s], text=True).splitlines() + + +class JobType(Enum): + TORCHELASTIC = "TorchElastic" + SLURM = "Slurm" + MANUAL = "manual" + + +class TorchDistributedEnvironment: + """ + Helper class to get (and set) distributed job information from the + environment. Identifies and supports (in this order): + - TorchElastic, + - Slurm, + - Manual launch (single-node). + """ + + def __init__(self): + if "TORCHELASTIC_RUN_ID" in os.environ: + # TorchElastic job created with torchrun + self.job_id = os.environ["TORCHELASTIC_RUN_ID"] + self.job_type = JobType.TORCHELASTIC + + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = int(os.environ["MASTER_PORT"]) + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + elif "SLURM_JOB_ID" in os.environ: + # Slurm job created with sbatch, submitit, etc... + self.job_id = int(os.environ["SLURM_JOB_ID"]) + self.job_type = JobType.SLURM + + node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) + nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) + assert len(nodes) == node_count + + self.master_addr = nodes[0] + self.master_port = _get_master_port(seed=self.job_id) + self.rank = int(os.environ["SLURM_PROCID"]) + self.world_size = int(os.environ["SLURM_NTASKS"]) + self.local_rank = int(os.environ["SLURM_LOCALID"]) + self.local_world_size = self.world_size // node_count + else: + # Single node and single job launched manually + self.job_id = None + self.job_type = JobType.MANUAL + + self.master_addr = "127.0.0.1" + self.master_port = _get_available_port() + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.local_world_size = 1 + + assert self.rank < self.world_size + assert self.local_rank < self.local_world_size + + def export( + self, + *, + overwrite: bool, + nccl_async_error_handling: bool = False, + ) -> "TorchDistributedEnvironment": + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. 
+ env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), + "LOCAL_WORLD_SIZE": str(self.local_world_size), + } + if nccl_async_error_handling: + env_vars.update( + { + "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1", # "TORCH_" prefix added in PyTorch 2.2 + } + ) + + if not overwrite: + for k, v in env_vars.items(): + # Only check for difference with preset environment variables + if k not in os.environ: + continue + if os.environ[k] == v: + continue + raise RuntimeError(f"Cannot export environment variables as {k} is already set") + + os.environ.update(env_vars) + return self + + @property + def is_main_process(self) -> bool: + return self.rank == 0 + + def __str__(self): + return ( + f"{self.job_type.value} job " + + (f"({self.job_id}) " if self.job_id else "") + + f"using {self.master_addr}:{self.master_port} " # noqa: E231 + f"(rank={self.rank}, world size={self.world_size})" + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"master_addr={self.master_addr}," # noqa: E231 + f"master_port={self.master_port}," # noqa: E231 + f"rank={self.rank}," # noqa: E231 + f"world_size={self.world_size}," # noqa: E231 + f"local_rank={self.local_rank}," # noqa: E231 + f"local_world_size={self.local_world_size}" + ")" + ) + + +def enable_distributed( + *, + set_cuda_current_device: bool = True, + overwrite: bool = False, + nccl_async_error_handling: bool = False, + restrict_print_to_main_process: bool = True, + timeout: timedelta | None = None, +): + """Enable distributed mode. + + Args: + set_cuda_current_device: If True, call torch.cuda.set_device() to set the + current PyTorch CUDA device to the one matching the local rank. + overwrite: If True, overwrites already set variables. Else fails. + nccl_async_error_handling: Enables NCCL asynchronous error handling. As a + side effect, this enables timing out PyTorch distributed operations + after a default 30 minutes delay). + restrict_print_to_main_process: If True, the print function of non-main processes + (ie rank>0) is disabled. Use print(..., force=True) to print anyway. + If False, nothing is changed and all processes can print as usual. + timeout: Timeout for operations executed against the process group. + Default value is 10 minutes for NCCL and 30 minutes for other backends. 
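+
+    Example (minimal sketch, single node; assumes a CUDA-capable machine since the
+    process group is initialized with the NCCL backend):
+        >>> enable_distributed()
+        >>> print(f"rank {get_rank()} / world size {get_world_size()}")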
+ """ + global _DEFAULT_PROCESS_GROUP + + if _DEFAULT_PROCESS_GROUP is not None: + raise RuntimeError("Distributed mode has already been enabled") + + torch_env = TorchDistributedEnvironment() + logger.info(f"PyTorch distributed environment: {torch_env}") + torch_env.export( + overwrite=overwrite, + nccl_async_error_handling=nccl_async_error_handling, + ) + + if set_cuda_current_device: + torch.cuda.set_device(torch_env.local_rank) + + dist.init_process_group(backend="nccl", timeout=timeout) + dist.barrier() + + if restrict_print_to_main_process: + _restrict_print_to_main_process() + + # Finalize setup + _DEFAULT_PROCESS_GROUP = torch.distributed.group.WORLD + + +def get_default_process_group(): + return _DEFAULT_PROCESS_GROUP + + +def disable_distributed() -> None: + global _BUILTIN_PRINT + if _BUILTIN_PRINT is not None: + import builtins as __builtin__ + + __builtin__.print = _BUILTIN_PRINT + + global _PROCESS_SUBGROUP + # checking here because get_process_subgroup can return _DEFAULT_PROCESS_GROUP + if _PROCESS_SUBGROUP is not None: + torch.distributed.destroy_process_group(_PROCESS_SUBGROUP) + _PROCESS_SUBGROUP = None + + global _DEFAULT_PROCESS_GROUP + if _DEFAULT_PROCESS_GROUP is not None: # not initialized + torch.distributed.destroy_process_group(_DEFAULT_PROCESS_GROUP) + _DEFAULT_PROCESS_GROUP = None + + +def new_subgroups(all_subgroup_ranks: Sequence[Sequence[int]]): + """Create new process subgroups according to the provided specification. + + Args: + all_subgroup_ranks: a sequence of rank sequences (first rank, ..., last rank), + one for each process subgroup. Example: ((0, 1), (2, 3), (4, 5, 6, 7)). + + Note: + This is similar to the (non-documented) new_subgroups_by_enumeration(). + This should be called once (and not sequentially) to create all subgroups. + """ + all_ranks = tuple(rank for subgroup_ranks in all_subgroup_ranks for rank in subgroup_ranks) + rank = get_rank() + assert len(all_ranks) == len(set(all_ranks)) + assert rank in all_ranks + + global _PROCESS_SUBGROUP + assert _PROCESS_SUBGROUP is None + + for subgroup_ranks in all_subgroup_ranks: + subgroup = torch.distributed.new_group(subgroup_ranks) + if rank in subgroup_ranks: + _PROCESS_SUBGROUP = subgroup + + +def get_process_subgroup(): + """ + Returns: + The process subgroup of this rank (or None). + """ + return _PROCESS_SUBGROUP or _DEFAULT_PROCESS_GROUP + + +def get_subgroup_rank() -> int: + """ + Returns: + The rank of the current process within its process subgroup. + """ + return get_rank(group=get_process_subgroup()) + + +def get_subgroup_size() -> int: + """ + Returns: + The number of processes in the process subgroup + """ + return get_world_size(group=get_process_subgroup()) + + +def is_subgroup_main_process() -> bool: + """ + Returns: + True if the current process is the main one within its process subgroup. + """ + return get_rank(group=get_process_subgroup()) == 0 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/env/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15278fd217438cd25fc12ba9eda40cc38182b0d8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/env/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ +import contextlib +import logging +import os +import tempfile +from typing import Optional + +import submitit.helpers + +logger = logging.getLogger("dinov3") + + +@contextlib.contextmanager +def clean_env(): + try: + # Hide torch.compile() variables from the launched evals + extra_names = ("TRITON_CACHE_DIR", "TORCHINDUCTOR_CACHE_DIR") + ctx = submitit.helpers.clean_env(extra_names=extra_names) + except TypeError as e: + logger.warning("Update submitit to the latest main branch\n%s", e) + ctx = submitit.helpers.clean_env() + with ctx: + yield + + +def set_triton_cache_dir(cache_dir: Optional[str] = None) -> None: + if cache_dir is None: + cache_dir = tempfile.mkdtemp() + os.environ["TRITON_CACHE_DIR"] = cache_dir diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/accumulators.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/accumulators.py new file mode 100644 index 0000000000000000000000000000000000000000..092cc7b671bfc4baa9482c0a911de104bc063ada --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/accumulators.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from collections import defaultdict +from typing import Dict, List, Optional + +import torch +from torch import Tensor + +from dinov3.distributed import gather_all_tensors # Gathers tensors of different sizes + +logger = logging.getLogger("dinov3") + + +def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor: + local_cat = torch.cat(tensor_list) + return torch.cat(gather_all_tensors(local_cat)) + + +class Accumulator: + def __init__(self) -> None: + pass + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + raise NotImplementedError + + def accumulate(self) -> Optional[Dict[str, Tensor]]: + raise NotImplementedError + + +class NoOpAccumulator(Accumulator): + def __init__(self) -> None: + pass + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + pass + + def accumulate(self) -> None: + return None + + +class ResultsAccumulator(Accumulator): + """ + Accumulate predictions and targets across processes + """ + + def __init__(self) -> None: + self._local_values: Dict[str, List[Tensor]] = defaultdict(list) + self._gathered_values: Dict[str, Tensor] = {} + self._gathered = False + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + assert len(preds) == len(target) == len(index) + assert not self._gathered, "Tensors have already been gathered in this helper" + self._local_values["preds"].append(preds) + self._local_values["target"].append(target) + self._local_values["index"].append(index) + self._gathered = False + + def _gather_tensors(self): + for k, tensor_list in self._local_values.items(): + self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list) + self._gathered = True + + def accumulate(self) -> Dict[str, Tensor]: + if not self._gathered: + self._gather_tensors() + preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]] + assert len(preds) == len(target) == len(index) and index.min() == 0 + preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device) + preds_ordered[index] = preds + target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device) + target_ordered[index] = target + return 
{"preds": preds_ordered, "target": target_ordered} diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/data.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/data.py new file mode 100644 index 0000000000000000000000000000000000000000..6c71628cffcf472faf9ef486cf361d3823b29be6 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/data.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from functools import lru_cache +from typing import Any, Callable, Optional + +import numpy as np +import torch +from torch.utils.data import Subset +from torchvision.datasets.vision import StandardTransform + +from dinov3.eval.utils import extract_features + +logger = logging.getLogger("dinov3") + + +class SubsetEx(Subset): + def _get_actual_index(self, index: int) -> int: + return self.indices[index] + + def get_target(self, index: int) -> Any: + actual_index = self._get_actual_index(index) + return self.dataset.get_target(actual_index) + + @property + def transforms(self): + return self.dataset.transforms + + +def get_target_transform(dataset) -> Optional[Callable]: + if hasattr(dataset, "transforms"): + if isinstance(dataset.transforms, StandardTransform): + return dataset.transforms.target_transform + raise ValueError("Dataset has a non-standard .transforms property") + if hasattr(dataset, "target_transform"): + return dataset.target_transform + return None + + +@lru_cache(maxsize=1) +def get_labels(dataset) -> torch.Tensor: + """ + Get the labels of a classification dataset, as a Tensor, using the `get_targets` method + if it is present or loading the labels one by one with `get_target`, if it exists. + If the dataset has a target transform, iterate over the whole dataset to get the + transformed labels for each element, then stack them as a torch tensor. + """ + logger.info("Getting dataset labels ...") + if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"): + if hasattr(dataset, "get_targets"): # Returns a np.array + labels = dataset.get_targets() + elif hasattr(dataset, "get_target"): + labels = [dataset.get_target(i) for i in range(len(dataset))] + target_transform = get_target_transform(dataset) + if target_transform is not None: + labels = [target_transform(label) for label in labels] + else: + # Target transform is applied in this case + labels = [dataset[i][1] for i in range(len(dataset))] + return torch.stack([torch.tensor(label, dtype=int) for label in labels]) + + +def get_num_classes(dataset) -> int: + """ + Get the labels of a dataset and compute the number of classes + """ + labels = get_labels(dataset) + if len(labels.shape) > 1: + return int(labels.shape[1]) + return int(labels.max() + 1) + + +def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]: + """ + Efficiently creates a mapping between the labels and tensors containing + the indices of all the dataset elements that share this label. + In the case of multiple labels, it is not guaranteed that there + will be exactly the specified percentage of labels. 
+ """ + if len(labels.shape) > 1: # labels are a one-hot encoding + assert len(labels.shape) == 2 + sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True) + else: + sorted_labels, indices = torch.sort(labels, stable=True) + unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True) + mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist()))) + return mapping + + +def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0): + """ + Shuffling a dataset by subsetting it with a random permutation of its indices + """ + random_generator = torch.Generator() + random_generator.manual_seed(seed) + random_indices = torch.randperm(len(dataset), generator=random_generator) + return SubsetEx(dataset, random_indices) + + +def _subset_dataset_per_class( + class_indices_mapping: dict[int, torch.Tensor], + n_or_percent_per_class: int | float, + dataset_size: int, + seed: int = 0, + is_percent: bool = False, +) -> torch.Tensor: + """ + Helper function to select a percentage of a dataset, equally distributed across classes, + or to take the same number of elements from each class of the dataset. + Returns a boolean mask tensor being True at indices of selected elements + """ + + random_generator = torch.Generator() + random_generator.manual_seed(seed) + + final_indices_bool = torch.zeros(dataset_size, dtype=bool) + for class_indices in class_indices_mapping.values(): + # Select at least one element + n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class + assert isinstance(n_for_class, int) + filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class] + final_indices_bool[class_indices[filtered_index]] = True + return final_indices_bool + + +def _multilabel_rebalance_subset( + class_indices_mapping: dict[int, torch.Tensor], + n_or_percent_per_class: int | float, + labels: torch.Tensor, + indices_bool: torch.Tensor, + dataset_size: int, + seed: int = 0, +) -> torch.Tensor: + """ + Helper function to refine a subset of a multi-label dataset (indices_bool) + to better match a target percentage of labels. + Returns a boolean mask tensor being True at indices of selected elements. + """ + + # Compute the number of selected labels in indices_bool + num_total_labels = labels.sum() + num_wanted_labels = int(num_total_labels * n_or_percent_per_class) + num_selected_labels = (labels[indices_bool] > 0).sum() + logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") + + # Compute a new percentage and new set selecting less images, therefore less labels, to match approximatelly the exact percentage of labels selected + n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels) + final_indices_bool = _subset_dataset_per_class( + class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True + ) + + # Compute the number of labels finally used + num_selected_labels = (labels[final_indices_bool] > 0).sum() + logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") + + return final_indices_bool + + +def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True): + """ + Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class. + If `shuffle` is False, taking the first elements of each class as the validaton set. 
+ """ + assert 0 < split_percentage < 1 + logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set") + if shuffle_train: + logger.info("Shuffling train dataset before splitting in train and validation sets") + train_dataset = _shuffle_dataset(train_dataset) + train_labels = get_labels(train_dataset) + class_indices_mapping = create_class_indices_mapping(train_labels) + val_mask = torch.zeros(len(train_labels), dtype=bool) + for class_indices in class_indices_mapping.values(): + # If there is only one element, it goes in the train set + n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0 + val_mask[class_indices[:n_for_val]] = True + + val_dataset = SubsetEx(train_dataset, val_mask.nonzero().flatten()) + train_dataset = SubsetEx(train_dataset, (~val_mask).nonzero().flatten()) + return train_dataset, val_dataset + + +def create_train_dataset_dict( + train_dataset, + few_shot_eval: bool = False, + few_shot_k_or_percent: float | None = None, + few_shot_n_tries: int = 1, +) -> dict[int, dict[int, Any]]: + """ + Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being + n elements or x% of a class. Produces a dict, which keys are number of random "tries" + and values are the dataset subset for this "try". + + Format is {"nth-try": dataset} + """ + if few_shot_eval is False: + assert few_shot_k_or_percent is None + assert few_shot_n_tries == 1 + return {0: train_dataset} + + assert few_shot_k_or_percent is not None + train_labels = get_labels(train_dataset) + class_indices_mapping = create_class_indices_mapping(train_labels) + train_dataset_dict: dict[int, Any] = {} + is_percent = few_shot_k_or_percent < 1 + if not is_percent: + few_shot_k_or_percent = int(few_shot_k_or_percent) + + for t in range(few_shot_n_tries): + t_subset_bool = _subset_dataset_per_class( + class_indices_mapping=class_indices_mapping, + n_or_percent_per_class=few_shot_k_or_percent, + dataset_size=len(train_labels), + is_percent=is_percent, + seed=t, + ) + if len(train_labels.shape) > 1 and is_percent: + t_subset_bool = _multilabel_rebalance_subset( + class_indices_mapping=class_indices_mapping, + n_or_percent_per_class=few_shot_k_or_percent, + dataset_size=len(train_labels), + labels=train_labels, + indices_bool=t_subset_bool, + seed=t, + ) + train_dataset_dict[t] = SubsetEx(train_dataset, t_subset_bool.nonzero().flatten()) + return train_dataset_dict + + +def extract_features_for_dataset_dict( + model, dataset_dict: dict[int, dict[int, Any]], batch_size: int, num_workers: int, gather_on_cpu=False +) -> dict[int, dict[str, torch.Tensor]]: + """ + Extract features for each subset of dataset in the context of few-shot evaluations + """ + few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {} + for try_n, dataset in dataset_dict.items(): + features, labels = extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu) + few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels} + return few_shot_data_dict + + +def pad_multilabel_and_collate(batch, pad_value=-1): + """ + This method pads and collates a batch of (image, (index, target)) tuples, coming from + DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes. + The targets are padded to the length of the longest target list in the batch. 
+ """ + maxlen = max(len(targets) for _, (_, targets) in batch) + padded_batch = [ + (image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value))) + for image, (index, targets) in batch + ] + return torch.utils.data.default_collate(padded_batch) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/checkpoint_utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03ec60ef6b1bbaf6b7d836f268e315b70aa25098 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/checkpoint_utils.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import os + +import torch +from torch.optim.optimizer import Optimizer + +logger = logging.getLogger("dinov3") + + +def unwrap_ddp_state_dict(model_state_dict): + is_ddp = all([k.startswith("module.") for k in model_state_dict.keys()]) + if is_ddp: + model_state_dict = {k.split("module.", 1)[-1]: v for (k, v) in model_state_dict.items()} + return model_state_dict + + +def load_checkpoint(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + state_dicts = {} + iteration = None + if "iteration" in checkpoint.keys(): + iteration = checkpoint["iteration"] + state_dicts["model"] = unwrap_ddp_state_dict(checkpoint["model"]) + if "optimizer" in checkpoint.keys(): + state_dicts["optimizer"] = checkpoint["optimizer"] + return state_dicts, iteration + + +def find_latest_checkpoint(path): + if not os.path.exists(path): + return None + list_checkpoints = sorted([filepath for filepath in os.listdir(path) if filepath.endswith("pth")]) + if os.path.exists(os.path.join(path, "model_final.pth")): + return os.path.join(path, "model_final.pth") + elif len(list_checkpoints) >= 1: + model_latest_iteration_path = list_checkpoints[-1] + return os.path.join(path, model_latest_iteration_path) + else: + logger.info("Could not find checkpoint to resume from, starting from scratch") + + +def save_checkpoint(path: str, iteration: int, model: torch.nn.Module, optimizer: Optimizer): + chkpt = { + "model": unwrap_ddp_state_dict(model.state_dict()), + "optimizer": optimizer.state_dict(), + "iteration": iteration, + } + torch.save(chkpt, os.path.join(path, f"model_{iteration:08d}.pth")) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/config.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/config.py new file mode 100644 index 0000000000000000000000000000000000000000..59f9f36f8def70ce76aa19d9adf2f4c3702f1aa5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/config.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from omegaconf import MISSING + +import torch + +from dinov3.eval.setup import ModelConfig + + +from dinov3.eval.depth.loss import LossType +from dinov3.eval.depth.models import DecoderConfig +from dinov3.eval.depth.transforms import make_depth_train_transforms, make_depth_eval_transforms +from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +class Dtype(Enum): + FLOAT32 = "float32" + BFLOAT16 = "bfloat16" + + @property + def autocast_dtype(self): + return { + Dtype.BFLOAT16: torch.bfloat16, + Dtype.FLOAT32: torch.float, + }[self] + + +class ResultExtension(Enum): + JPG = "jpg" + PNG = "png" + PTH = "pth" + + +@dataclass +class ResultConfig: + save_results: bool = False + extension: ResultExtension = ResultExtension.JPG + save_resolution: int | None = ( + None # if set, the output result image is resized to have its smallest size set to save_resolution + ) + overlay_alpha: float = 1.0 # if alpha == 1, masks are not overlaid on the original image + save_separate_files: bool = False # set to true to save individual files (image, prediction, gt) + + +@dataclass +class DatasetsConfig: + root: str = MISSING + train: str = "" + val: str = "" + test: str = "" + + +@dataclass +class OptimizerConfig: + lr: float = 1e-4 + beta1: float = 0.9 + beta2: float = 0.999 + weight_decay: float = 0.01 + gradient_clip: float = 35.0 + + +@dataclass +class SchedulerConfig: + type: str = "WarmupOneCycleLR" + total_iter: int = 38_400 # Total number of iterations for training + constructor_kwargs: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class TrainTransformConfig: + img_size: Any = None + random_crop: tuple[int, int] | None = None + brightness_range: tuple[float, float] = (0.9, 1.1) + rotation_angle: float = 2.5 # max rotation angle + fixed_crop: str = "FULL" + eval_mask: str = "FULL" + + +@dataclass +class EvalTransformConfig: + img_size: Any = None + fixed_crop: str = "FULL" + eval_mask: str = "FULL" + + +@dataclass +class TransformConfig: + train: TrainTransformConfig | None = None + eval: EvalTransformConfig = field(default_factory=EvalTransformConfig) + mean: tuple[float, float, float] = IMAGENET_DEFAULT_MEAN + std: tuple[float, float, float] = IMAGENET_DEFAULT_STD + normalization_constant: float = 1000.0 + + +@dataclass +class EvalConfig: + ignored_value: float = 0.0 # If depth pixels have this value in the dataset, they will be ignored + align_least_squares: bool = False # Choose whether to align predictions to ground truth during testing + min_depth: float = 0.001 # Minimum depth to be evaluated + max_depth: float = 10.0 # Maximum depth to be evaluated + use_tta: bool = True # apply test-time augmentation at evaluation time + eval_interval: int = 1600 # number of iterations between evaluations + + +@dataclass +class DepthConfig: + model: ModelConfig | None = None + bs: int = 2 + n_gpus: int = 8 + num_workers: int = 2 + seed: int = 321 + scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + datasets: DatasetsConfig = field(default_factory=DatasetsConfig) + decoder_head: DecoderConfig = field(default_factory=DecoderConfig) + model_dtype: Dtype | None = None + losses: dict[LossType, float] | None = None # For example {SIGLOSS: 1.0, GRADIENT_LOG_LOSS: 0.0} + transforms: TransformConfig = field(default_factory=TransformConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + 
metrics: list[str] = field(default_factory=lambda: ["rmse", "abs_rel", "a1"]) + result_config: ResultConfig = field(default_factory=ResultConfig) + load_from: str | None = None # path to .pt checkpoint to resume training from + output_dir: str = "" + + +def make_depth_train_transforms_from_config(config: DepthConfig): + assert config.datasets.train is not None + assert config.transforms.train is not None + transforms = make_depth_train_transforms( + img_size=config.transforms.train.img_size, + normalization_constant=config.transforms.normalization_constant, + random_crop_size=config.transforms.train.random_crop, + fixed_crop=config.transforms.train.fixed_crop, + brightness_range=config.transforms.train.brightness_range, + rotation_angle=config.transforms.train.rotation_angle, + mean=config.transforms.mean, + std=config.transforms.std, + ) + return transforms + + +def make_depth_eval_transforms_from_config(config: DepthConfig, split: str = "val"): + assert split in ["val", "test"] + transforms = make_depth_eval_transforms( + img_size=config.transforms.eval.img_size, + normalization_constant=config.transforms.normalization_constant, + fixed_crop=config.transforms.eval.fixed_crop, + tta=config.eval.use_tta if split == "test" else False, + mean=config.transforms.mean, + std=config.transforms.std, + ) + + return transforms diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu-synthmix-dpt-inference.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu-synthmix-dpt-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff5ca8622abfda9c85fc0ce8a4d7544db20cc50f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu-synthmix-dpt-inference.yaml @@ -0,0 +1,12 @@ +datasets: + test: 'NYU:split=VAL' +transforms: + eval: + img_size: 768 + fixed_crop: 'FULL' + eval_mask: 'NYU_EIGEN' +eval: + min_depth: 0.001 + max_depth: 10.0 + use_tta: true + align_least_squares: true diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a89e8d02a445cc1f8bafef2554a23599b4ab00 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/configs/config-nyu.yaml @@ -0,0 +1,46 @@ +bs: 2 +n_gpus: 8 +seed: 321 +scheduler: + total_iter: 38400 + type: 'WarmupOneCycleLR' + constructor_kwargs: + final_div_factor: 1000.0 + warmup_iters: 12800 + base_momentum: 0.85 + max_momentum: 0.95 +optimizer: + lr : 3e-4 + beta1: 0.9 + beta2: 0.99 + weight_decay: 1e-4 + gradient_clip: 35 +datasets: + train: 'NYU:split=TRAIN' + val: 'NYU:split=VAL' + test: 'NYU:split=VAL' +decoder_head: + min_depth: 0.001 + max_depth: 10.0 + use_backbone_norm: True + use_batchnorm: True + backbone_out_layers: LAST + type: linear + use_cls_token: False + n_output_channels: 256 +losses: + SIGLOSS: 1.0 +transforms: + train: + random_crop: [416, 544] + brightness_range: [0.75, 1.25] + fixed_crop: 'NYU' + eval_mask: 'FULL' + eval: + fixed_crop: 'FULL' + eval_mask: 'NYU_EIGEN' +eval: + min_depth: 0.001 + max_depth: 10.0 + use_tta: true + align_least_squares: false diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/data.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/data.py new file mode 100644 index 
0000000000000000000000000000000000000000..93953635c97c660afc3f3196e19e620ca73c47ef --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/data.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import random +from functools import partial +from typing import Any + +import numpy as np +import torch + +from dinov3.data import make_dataset, make_data_loader, DatasetWithEnumeratedTargets, SamplerType +import dinov3.distributed as distributed + + +logger = logging.getLogger("dinov3") + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + The seed of each worker equals to num_worker * rank + worker_id + user_seed + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) + + +def build_dataloader( + transforms: Any, + dataset_str: str, + device: int, + split: str = "train", + batch_size: int = 1, + n_gpus: int = 1, + num_workers: int = 2, + seed: int = 0, + use_init_fn=False, +): + """ + Build a dataloader from lavida descriptor strings. + One can specify either a list of descriptors or a single one. + When a list is used, the resulting dataset is + a concatenation of all the listed datasets. + + transforms: transforms for the dataset + dataset_str (str): a dataset descriptor, e.g. 'NYU:split=TRAIN' + device (int): id for the GPU rank + split (str): dataset split (choice: ['train', 'val', 'test']) + batch_size (int): batch size + n_gpus (int): number of ranks to use for distributed sampler + num_workers (int): number of workers for the dataloader + seed (int): random seed + use_init_fn (bool): if True, initializes workers with worker_init_fn + """ + assert split in ["train", "val", "test"] + is_train = split == "train" + ds = make_dataset(dataset_str=dataset_str, transforms=transforms) + logger.info(f"Dataset {split}:\n{ds}") + + if not is_train: + assert batch_size == 1, "Evaluation should only be done at batch size 1!" 
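+        # Wrapping with enumerated targets pads the dataset so every rank runs the
+        # same number of forward passes; padded samples carry a negative index and
+        # are dropped after prediction (see eval.py).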
+ ds = DatasetWithEnumeratedTargets(ds, pad_dataset=True, num_replicas=n_gpus) + + if use_init_fn and is_train: + init_fn = partial(worker_init_fn, num_workers=num_workers, rank=device, seed=seed + device) + else: + init_fn = None + dataloader = make_data_loader( + dataset=ds, + batch_size=batch_size, + sampler_type=SamplerType.DISTRIBUTED if distributed.is_enabled() else None, + drop_last=is_train, + shuffle=is_train, + persistent_workers=(not is_train), + worker_init_fn=init_fn, + seed=seed, + num_workers=num_workers, + ) + + if is_train: + return InfiniteDataloader(dataloader) + + return dataloader + + +class InfiniteDataloader: + def __init__(self, dataloader: torch.utils.data.DataLoader): + self.dataloader = dataloader + self.data_iterator = iter(dataloader) + self.sampler = dataloader.sampler + if not hasattr(self.sampler, "epoch"): + self.sampler.epoch = 0 # type: ignore + + def __iter__(self): + return self + + def __len__(self) -> int: + return len(self.dataloader) + + def __next__(self): + try: + data = next(self.data_iterator) + except StopIteration: + self.sampler.epoch += 1 + self.data_iterator = iter(self.dataloader) + data = next(self.data_iterator) + return data diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8cef2afa438924642360f035e7319a1f1ee2b2 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/datasets_utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/datasets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c50884997785dd3bb21e5fe5eeac70c170b5a02 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/datasets/datasets_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
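+
+"""Evaluation-crop utilities for depth benchmarks (e.g. the NYU Eigen crop)
+and construction of valid-pixel masks."""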
+
+from enum import Enum
+
+import torch
+
+
+class _EvalCropType(Enum):
+    NYU_EIGEN = "NYU_EIGEN"
+    FULL = "FULL"
+
+
+def make_valid_mask(input, eval_crop: _EvalCropType = _EvalCropType.FULL, ignored_value: float = 0.0):
+    """Following AdaBins, do garg_crop or eigen_crop for testing
+
+    Args:
+        input: input tensor in BxCxHxW format
+        eval_crop (_EvalCropType): evaluation crop used for evaluation
+        ignored_value (float): value from input to be ignored during evaluation
+    """
+    B, _, h, w = input.shape
+    eval_mask = torch.zeros(input.shape, device=input.device)
+    if eval_crop == _EvalCropType.NYU_EIGEN:
+        y1, y2, x1, x2 = 45, 471, 41, 601
+        orig_h, orig_w = 480, 640
+        y1_new = int((y1 / orig_h) * h)
+        y2_new = int((y2 / orig_h) * h)
+        x1_new = int((x1 / orig_w) * w)
+        x2_new = int((x2 / orig_w) * w)
+        eval_mask[:, :, y1_new:y2_new, x1_new:x2_new] = 1
+    else:
+        eval_mask.fill_(1)
+
+    # make mask from ignored values
+    ignored_value_mask = torch.ones((B, 1, h, w), device=eval_mask.device)
+    ignored_value_mask[(input == ignored_value).all(dim=1, keepdim=True)] = 0
+
+    eval_mask = eval_mask * ignored_value_mask
+    return eval_mask.bool()
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/eval.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f8e5e7e7ea67510c5d601735913d9a9921836f
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/eval.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+
+from typing import Any
+import torch
+import torch.utils
+import torch.utils.data
+
+import dinov3.distributed as distributed
+from dinov3.logging import MetricLogger
+
+
+from dinov3.eval.depth.config import (
+    DepthConfig,
+    ResultConfig,
+    make_depth_eval_transforms_from_config,
+)
+from dinov3.eval.depth.data import build_dataloader
+from dinov3.eval.depth.datasets.datasets_utils import _EvalCropType, make_valid_mask
+from dinov3.eval.depth.metrics import calculate_depth_metrics, _DepthMetric, DEPTH_METRICS
+from dinov3.eval.depth.transforms import Aug, LeftRightFlipAug
+from dinov3.eval.depth.utils import align_depth_least_square
+from dinov3.eval.depth.visualization_utils import depth_tensor_to_colorized_pil, save_predictions
+
+
+logger = logging.getLogger("dinov3")
+
+
+def inverse_tta_hook(transforms: Aug):
+    return lambda module, inputs, outputs: transforms.inverse(outputs)
+
+
+def evaluate_depther_with_config(
+    config: DepthConfig,
+    depther: torch.nn.Module,
+    device: Any,
+    reduce_results: bool = True,
+):
+    # 1- define dataset
+    transforms = make_depth_eval_transforms_from_config(config, split="test")
+    dataloader = build_dataloader(
+        dataset_str=config.datasets.test + f":root={config.datasets.root}",
+        transforms=transforms,
+        device=device,
+        split="test",
+        batch_size=1,
+        n_gpus=distributed.get_world_size(),
+    )
+    metrics = [metric for metric in DEPTH_METRICS if metric.name in config.metrics]
+
+    return evaluate_depther_with_dataloader(
+        dataloader,
+        depther,
+        device=device,
+        metrics=metrics,
+        eval_range=(config.eval.min_depth, config.eval.max_depth),
+        result_config=config.result_config,
+        ignored_value=config.eval.ignored_value,
+        eval_mask_type=config.transforms.eval.eval_mask,
+        save_dir=config.output_dir,
+        reduce_results=reduce_results,
+
align_least_squares=config.eval.align_least_squares, + use_tta=config.eval.use_tta, + ) + + +@torch.no_grad() +def evaluate_depther_with_dataloader( + dataloader: torch.utils.data.DataLoader, + depther: torch.nn.Module, + device: Any, + metrics: list[_DepthMetric], + eval_range: tuple[float, float], + result_config: ResultConfig, + save_dir="", + ignored_value: float = 0.0, + eval_mask_type: str = "NYU_EIGEN", + reduce_results: bool = True, + align_least_squares: bool = False, + use_tta: bool = False, +): + """ + Evaluate a dense estimation model with a dataloader + + Inputs: + - dataloader: a torch.utils.data.DataLoader + - depther: depth estimator to evaluate + - device: the (CUDA) device to evaluate on + - metrics: metrics to report during evaluation + - eval_range (float, float): depth evaluation range + - result_config (ResultConfig): contains parameters for results saving + - save_dir (str): saving directory for results (metrics and predictions) + - ignored_value (float): value to ignore from the ground truth + - eval_mask_type (str): evaluation mask. See _EvalCropType Enum for choices + - reduce_results (bool): if True, results are averaged across all samples (default=True) + - align_least_squares (bool): if True, aligns prediction in scale and shift with GT using least squares error minimization + - use_tta (bool): if True, uses left-right flipping test time augmentation (default False). + """ + + metric_names = [metric.name for metric in metrics] + all_metric_values_dict: dict[str, Any] = {metric: [] for metric in metric_names} + all_metric_values_dict["indices"] = [] + final_metric_values_dict = {} + + n_gpus = distributed.get_world_size() + + # build a metric_logger for validation + header = "Validation: " + metric_logger = MetricLogger(delimiter=" ") + all_losses: dict[str, list[float]] = {} + if use_tta: + hook = depther.register_forward_hook(inverse_tta_hook(LeftRightFlipAug(flip=True))) + else: + hook = None + + for batch_img, target in metric_logger.log_every(dataloader, 10, header=header): + index, gt_map = target + # batchify augmentations together + assert batch_img[0].shape[0] == 1 + batch_img = torch.cat(batch_img, dim=0).to(device) + gt_map = torch.cat(gt_map, dim=0).to(device) + + preds = depther(batch_img) + # Skip padded indices AFTER prediction, so that each rank can process the + # Same number of forwards, necessary for a sharded backbone + if index < 0: + continue + + # in case tta inflated the batch + B, C, _, _ = preds.shape + gt_map = gt_map[:B] + + # run post processing on ground truth and prediction + gt_map = torch.where( + torch.logical_or(gt_map >= eval_range[1], gt_map <= eval_range[0]), + ignored_value, + gt_map, + ) + + valid_mask = make_valid_mask( + gt_map, + eval_crop=_EvalCropType(eval_mask_type), + ignored_value=ignored_value, + ) + + # resize -if necessary- prediction to match size of gt + if gt_map.shape[-2:] != preds.shape[-2:]: + preds = torch.nn.functional.interpolate( + input=preds, + size=gt_map.shape[2:], + mode="bilinear", + align_corners=False, + ) + if align_least_squares: + preds = torch.stack([align_depth_least_square(gt_map, p, valid_mask)[0] for p in preds]) + preds = preds.to(device) + + if result_config.save_results: + assert preds.shape[0] == 1, "Cannot save results for more than one decoder" + save_predictions( + img=batch_img[: preds.shape[0]], + pred=preds[0], + gt=gt_map, + save_index=index, + result_config=result_config, + save_dir=save_dir, + pred_tensor_to_pil_fn=depth_tensor_to_colorized_pil, + ) + + preds = 
preds.clamp(min=eval_range[0], max=eval_range[1]) + + batch_metric_values = calculate_depth_metrics( + gt_map, + preds, + valid_mask, + list_metrics=metrics, + ) # NamedTuple with metrics as names + for metric in metric_names: + value = getattr(batch_metric_values, metric, None) + if value is not None: + all_metric_values_dict[metric].append(value) + all_metric_values_dict["indices"].append(index) + all_indices = torch.tensor(all_metric_values_dict["indices"], device=device) + if n_gpus > 1: + list_all_indices = distributed.gather_all_tensors(all_indices) + all_indices = torch.cat(list_all_indices, dim=0).cpu().to(torch.int32) + + out_results_dict = {} + + all_metric_values = torch.tensor( + [values_per_metric for values_per_metric in all_metric_values_dict.values()], + device=device, + ) + if n_gpus > 1: + all_metric_values = torch.cat(distributed.gather_all_tensors(all_metric_values), dim=1) + + final_metric_values = all_metric_values.nanmean(1) + final_metric_values_dict = dict(zip(metric_names, final_metric_values.cpu().numpy())) + + if reduce_results: + out_results_dict = {k: float(v) for (k, v) in final_metric_values_dict.items()} + else: + out_results_dict = { + metric_name: value for (metric_name, value) in zip(metric_names, all_metric_values.cpu().numpy().tolist()) + } + logger.info( + "Final scores: " + " ".join([f"{name}: {meter:.3f}" for name, meter in final_metric_values_dict.items()]) + ) + + if hook is not None: + hook.remove() + + return out_results_dict, all_losses, all_indices diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..301eb8693307ecade0f832317bf5cc1bf37a2a97 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/loss.py @@ -0,0 +1,164 @@ +from enum import Enum +from functools import partial + +import torch +from torch import nn + + +class LossType(Enum): + SIGLOSS = "sigloss" + GRADIENT_LOSS = "gradient_loss" + GRADIENT_LOG_LOSS = "gradient_log_loss" + L1 = "l1" + + def module(self, *args, **kwargs): + return { + LossType.SIGLOSS: partial( + SigLoss, + warm_up=True, + warm_iter=100, + ), # default parameters for the custom loss (BW compatibility) + LossType.GRADIENT_LOG_LOSS: GradientLogLoss, + LossType.GRADIENT_LOSS: GradientLoss, + LossType.L1: L1Loss, + }[self](*args, **kwargs) + + +class GradientLoss(nn.Module): + def __init__(self): + super().__init__() + self.eps = 0.001 + + def forward(self, input, target, valid_mask=None): + input_downscaled = [input] + [input[..., :: 2 * i, :: 2 * i] for i in range(1, 4)] + target_downscaled = [target] + [target[..., :: 2 * i, :: 2 * i] for i in range(1, 4)] + if valid_mask is not None: + mask_downscaled = [valid_mask] + [valid_mask[..., :: 2 * i, :: 2 * i] for i in range(1, 4)] + else: + mask_downscaled = [torch.ones_like(target, dtype=bool) for target in target_downscaled] + + gradient_loss = 0 + for input, target, mask in zip(input_downscaled, target_downscaled, mask_downscaled): + N = torch.sum(mask) + d_diff = torch.mul(input - target, mask) + + v_gradient = torch.abs(d_diff[..., 0:-2, :] - d_diff[..., 2:, :]) + v_mask = torch.mul(mask[..., 0:-2, :], mask[..., 2:, :]) + v_gradient = torch.mul(v_gradient, v_mask) + + h_gradient = torch.abs(d_diff[..., :, 0:-2] - d_diff[..., :, 2:]) + h_mask = torch.mul(mask[..., :, 0:-2], mask[..., :, 2:]) + h_gradient = torch.mul(h_gradient, h_mask) + gradient_loss += 
(torch.sum(h_gradient) + torch.sum(v_gradient)) / N
+
+        return gradient_loss
+
+
+class GradientLogLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.eps = 0.001
+
+    def forward(self, input, target, valid_mask=None):
+        input_downscaled = [input] + [input[..., :: 2 * i, :: 2 * i] for i in range(1, 4)]
+        target_downscaled = [target] + [target[..., :: 2 * i, :: 2 * i] for i in range(1, 4)]
+        if valid_mask is not None:
+            mask_downscaled = [valid_mask] + [valid_mask[..., :: 2 * i, :: 2 * i] for i in range(1, 4)]
+        else:
+            mask_downscaled = [torch.ones_like(target, dtype=bool) for target in target_downscaled]
+
+        gradient_loss = 0
+        for input, target, mask in zip(input_downscaled, target_downscaled, mask_downscaled):
+            N = torch.sum(mask)
+            input_log = torch.log(input + self.eps)
+            target_log = torch.log(target + self.eps)
+            log_d_diff = input_log - target_log
+
+            log_d_diff = torch.mul(log_d_diff, mask)
+
+            v_gradient = torch.abs(log_d_diff[..., 0:-2, :] - log_d_diff[..., 2:, :])
+            v_mask = torch.mul(mask[..., 0:-2, :], mask[..., 2:, :])
+            v_gradient = torch.mul(v_gradient, v_mask)
+
+            h_gradient = torch.abs(log_d_diff[..., :, 0:-2] - log_d_diff[..., :, 2:])
+            h_mask = torch.mul(mask[..., :, 0:-2], mask[..., :, 2:])
+            h_gradient = torch.mul(h_gradient, h_mask)
+            gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
+
+        return gradient_loss
+
+
+class L1Loss(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, target, valid_mask=None):
+        loss = nn.functional.l1_loss(input, target, reduction="none")
+        mask = valid_mask if (valid_mask is not None) else torch.ones_like(input, dtype=bool)
+        loss = loss * mask
+        return loss.sum() / (mask.sum() + 1e-7)
+
+
+class SigLoss(nn.Module):
+    """Sigloss
+
+    Adapted from BinsFormer, which adapted it from AdaBins
+    https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/7c0c89c8db07631fec1737f3087e4f1f540ccd53/depth/models/losses/sigloss.py#L8
+    """
+
+    def __init__(self, warm_up=True, warm_iter=100):
+        super(SigLoss, self).__init__()
+        self.loss_name = "SigLoss"
+        self.eps = 0.001  # avoid grad explode
+        self.warm_up = warm_up
+        self.warm_iter = warm_iter
+        self.warm_up_counter = 0
+
+    def sigloss(self, input, target, valid_mask=None):
+        if valid_mask is None:
+            valid_mask = torch.ones_like(target, dtype=bool)
+        input = input[valid_mask]
+        target = target[valid_mask]
+
+        g = torch.log(input + self.eps) - torch.log(target + self.eps)
+        Dg = 0.15 * torch.pow(torch.mean(g), 2)
+        if self.warm_up and self.warm_up_counter < self.warm_iter:
+            self.warm_up_counter += 1
+        else:
+            Dg += torch.var(g)
+        if Dg <= 0:
+            return torch.abs(Dg)
+        return torch.sqrt(Dg)
+
+    def forward(self, depth_pred, depth_gt, valid_mask=None):
+        """Forward function."""
+
+        return self.sigloss(depth_pred, depth_gt, valid_mask)
+
+
+class MultiLoss(nn.Module):
+    """
+    Losses adapted from https://www.cs.cornell.edu/projects/megadepth/
+
+    Args:
+        dict_losses: (dict[LossType, float]) a dict of losses in the format {LossType_1: Weight_1, ..., LossType_N: Weight_N}.
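+
+    Example (illustrative sketch; `depth_pred`, `depth_gt`, `valid_mask` are placeholder tensors):
+        >>> criterion = MultiLoss({LossType.SIGLOSS: 1.0, LossType.GRADIENT_LOG_LOSS: 0.5})
+        >>> loss = criterion(depth_pred, depth_gt, valid_mask)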
+ """ + + def __init__( + self, + dict_losses: dict[LossType, float], + ): + super(MultiLoss, self).__init__() + self.dict_losses = nn.ModuleDict({loss_type.name: loss_type.module() for loss_type in dict_losses.keys()}) + self.dict_weights = {loss_type.name: weight for (loss_type, weight) in dict_losses.items()} + self.eps = 0.001 # avoid grad explode + + def forward(self, depth_pred, depth_gt, valid_mask=None): + """Forward function.""" + + loss_depth = 0 + for loss_name in self.dict_losses.keys(): + weight = self.dict_weights[loss_name] + loss = self.dict_losses[loss_name](depth_pred, depth_gt, valid_mask) + loss_depth += weight * loss + return loss_depth diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/metrics.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4104c4b0829af1237e6c4733b77befb86fd1dc --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/metrics.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +from collections import namedtuple +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class _DepthMetric: + name: str + is_lower_better: bool + + @property + def worst_value(self) -> float: + return math.inf if self.is_lower_better else -math.inf + + def is_better(self, value1, value2) -> bool: + sign = 1 if self.is_lower_better else -1 + return sign * value1 < sign * value2 + + +DEPTH_METRICS = ( + _DepthMetric(name="a1", is_lower_better=False), + _DepthMetric(name="a2", is_lower_better=False), + _DepthMetric(name="a3", is_lower_better=False), + _DepthMetric(name="abs_rel", is_lower_better=True), + _DepthMetric(name="rmse", is_lower_better=True), + _DepthMetric(name="log_10", is_lower_better=True), + _DepthMetric(name="rmse_log", is_lower_better=True), + _DepthMetric(name="silog", is_lower_better=True), + _DepthMetric(name="sq_rel", is_lower_better=True), + _DepthMetric(name="mae", is_lower_better=True), +) + +DEPTH_METRICS_NAME = [metric.name for metric in DEPTH_METRICS] + +_DepthMetricValues = namedtuple("DepthMetricValues", [metric.name for metric in DEPTH_METRICS]) # type: ignore + + +def calculate_depth_metrics( + gt: torch.Tensor, + pred: torch.Tensor, + valid_mask: torch.Tensor | None = None, + list_metrics: list[_DepthMetric] = list(DEPTH_METRICS), +): + if gt.shape[0] == 0: + return [torch.nan] * len(DEPTH_METRICS) + + if valid_mask is not None: + valid_mask = torch.logical_and(valid_mask, gt > 0) + + gt = gt[valid_mask] + pred = pred[valid_mask] + + metrics_dict = {} + + metric_names = [metric.name for metric in list_metrics] + + thresh = torch.maximum((gt / pred), (pred / gt)) + metrics_dict["a1"] = (thresh < 1.25).float().mean() if "a1" in metric_names else torch.nan + metrics_dict["a2"] = (thresh < 1.25**2).float().mean() if "a2" in metric_names else torch.nan + metrics_dict["a3"] = (thresh < 1.25**3).float().mean() if "a3" in metric_names else torch.nan + + error = gt - pred + sq_error = error**2 + metrics_dict["mae"] = torch.mean(torch.abs(error)) if "mae" in metric_names else torch.nan + metrics_dict["abs_rel"] = torch.mean(torch.abs(error) / gt) if "abs_rel" in metric_names else torch.nan + metrics_dict["sq_rel"] = torch.mean(sq_error / gt) if "sq_rel" in metric_names else torch.nan + + metrics_dict["rmse"] = 
torch.sqrt(sq_error.mean()) if "rmse" in metric_names else torch.nan + + error_log = torch.log(gt) - torch.log(pred) + sq_error_log = error_log**2 + metrics_dict["rmse_log"] = torch.sqrt(sq_error_log.mean()) if "rmse_log" in metric_names else torch.nan + if "silog" in metric_names: + silog = torch.sqrt(torch.mean(sq_error_log) - torch.mean(error_log) ** 2) * 100 + if torch.isnan(silog): + silog = torch.tensor(0) + metrics_dict["silog"] = silog + else: + metrics_dict["silog"] = torch.nan + metrics_dict["log_10"] = ( + (torch.abs(torch.log10(gt) - torch.log10(pred))).mean() if "log_10" in metric_names else math.inf + ) + + return _DepthMetricValues(**metrics_dict) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a74cb7cd3d6f957117c9173fe38777c9c05cf2f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/__init__.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. +from dataclasses import dataclass, field +from functools import partial +from typing import Any + +import torch +from dinov3.eval.depth.checkpoint_utils import load_checkpoint + +from .dpt_head import DPTHead +from .linear_head import LinearHead +from .encoder import BackboneLayersSet, DinoVisionTransformerWrapper, PatchSizeAdaptationStrategy + + +@dataclass +class DecoderConfig: + min_depth: float = 0.001 + max_depth: float = 80 + bins_strategy: str = "linear" # (choice: ["linear", "log"]) distribution of bins across the range + norm_strategy: str = ( + "linear" # (choice: ["linear", "softmax", "sigmoid"]) activation used before normalization of depth-bin logits + ) + head_kwargs: dict[str, Any] = field(default_factory=dict) + backbone_out_layers: Any = ( + BackboneLayersSet.FOUR_EVEN_INTERVALS # One of BackboneLayersSet(Enum) or a list of indices e.g. [0, 1, 2, 3] + ) + adapt_to_patch_size: PatchSizeAdaptationStrategy = PatchSizeAdaptationStrategy.CENTER_PADDING + use_backbone_norm: bool = True + # decoder + type: str = "linear" # choices: linear or dpt + n_output_channels: int = 256 + use_batchnorm: bool = False + use_cls_token: bool = False + + +class FeaturesToDepth(torch.nn.Module): + def __init__( + self, + min_depth=0.001, + max_depth=80, + bins_strategy="linear", + norm_strategy="linear", + ): + """ + Module which converts a feature maps into a depth map + + Args: + min_depth (float): minimum depth, used to calibrate the depth range + max_depth (float): maximum depth, used to calibrate the depth range + bins_strategy (str): Choices are 'linear' or 'log', for Uniform or Scale Invariant distributions for depth bins. + See AdaBins [1] for more details. + norm_strategy (str): Choices are 'linear', 'softmax' or 'sigmoid', for the conversion of features to depth logits + scale_up (bool): If true, and only if regression by classification is not used, the result is multiplied by max_depth + + + Example: + x = depth_model(input_image) # N C H W + - If pure regression (C == 1), depth is obtained by scaling and/or shifting x + - If C > 1, bins are used: + Depth is obtained as a weighted sum of depth bins, where weights are predicted logits. 
(see AdaBins [1] for more details) + + [1] AdaBins: https://github.com/shariqfarooq123/AdaBins + """ + super().__init__() + self.min_depth = min_depth + self.max_depth = max_depth + assert bins_strategy in ["linear", "log"], "Support bins_strategy: linear, log" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + + def forward(self, x): + n_bins = x.shape[1] # N n_bins H W + if n_bins > 1: + if self.bins_strategy == "linear": + bins = torch.linspace(self.min_depth, self.max_depth, n_bins, device=x.device) + elif self.bins_strategy == "log": + bins = torch.linspace( + torch.log(torch.tensor(self.min_depth)), + torch.log(torch.tensor(self.max_depth)), + n_bins, + device=x.device, + ) + bins = torch.exp(bins) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(x) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(x, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(x) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + else: + # standard regression + output = torch.relu(x) + self.min_depth + return output + + +def make_head( + embed_dims: int | list[int], + n_output_channels: int, + use_batchnorm: bool = False, + use_cls_token: bool = False, + head_type: str = "linear", + **kwargs, +) -> torch.nn.Module: + if isinstance(embed_dims, int): + embed_dims = [embed_dims] + decoder: torch.nn.Module + if head_type == "linear": + decoder = LinearHead( + in_channels=embed_dims, + n_output_channels=n_output_channels, + use_batchnorm=use_batchnorm, + use_cls_token=use_cls_token, + ) + elif head_type == "dpt": + decoder = DPTHead( + in_channels=embed_dims, + n_output_channels=n_output_channels, + readout_type="project" if use_cls_token else "ignore", + use_batchnorm=use_batchnorm, + **kwargs, + ) + else: + raise NotImplementedError("only linear and DPT head supported") + return decoder + + +class EncoderDecoder(torch.nn.Module): + def __init__( + self, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +class Depther(torch.nn.Module): + def __init__( + self, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + min_depth: float, + max_depth: float, + bins_strategy: str = "linear", + norm_strategy: str = "linear", + autocast_dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + self.features_to_depth = FeaturesToDepth( + min_depth=min_depth, + max_depth=max_depth, + bins_strategy=bins_strategy, + norm_strategy=norm_strategy, + ) + if torch.cuda.is_available(): + self.autocast_ctx = partial(torch.autocast, device_type="cuda", dtype=autocast_dtype, enabled=True) + self.encoder.cuda() + self.decoder.cuda() + else: + self.autocast_ctx = partial(torch.autocast, device_type="cpu", enabled=True) + + def forward(self, x): + with self.autocast_ctx(): + x = self.encoder(x) + x = self.decoder(x) + x = self.features_to_depth(x) + return x + + +def build_depther( + backbone: torch.nn.Module, + backbone_out_layers: tuple[int, ...] 
| BackboneLayersSet, + n_output_channels: int, + use_backbone_norm: bool = False, + use_batchnorm: bool = False, + use_cls_token: bool = False, + adapt_to_patch_size: PatchSizeAdaptationStrategy = PatchSizeAdaptationStrategy.CENTER_PADDING, + head_type: str = "dpt", + autocast_dtype: torch.dtype = torch.float32, + # depth args + min_depth: float = 0.001, + max_depth: float = 10.0, + bins_strategy: str = "linear", + norm_strategy: str = "linear", + **kwargs, +): + encoder = DinoVisionTransformerWrapper( + backbone_model=backbone, + backbone_out_layers=backbone_out_layers, + use_backbone_norm=use_backbone_norm, + adapt_to_patch_size=adapt_to_patch_size, + ) + + decoder = make_head( + encoder.embed_dims, + n_output_channels=n_output_channels, + use_batchnorm=use_batchnorm, + use_cls_token=use_cls_token, + head_type=head_type, + **kwargs, + ) + + depther = Depther( + encoder=encoder, + decoder=decoder, + min_depth=min_depth, + max_depth=max_depth, + bins_strategy=bins_strategy, + norm_strategy=norm_strategy, + autocast_dtype=autocast_dtype, + ) + depther.eval() + return depther + + +def make_depther_from_config( + backbone, + config: DecoderConfig, + checkpoint_path: str | None = None, + autocast_dtype: torch.dtype = torch.float32, +) -> Depther: + depther = build_depther( + backbone, + backbone_out_layers=config.backbone_out_layers, + n_output_channels=config.n_output_channels, + use_backbone_norm=config.use_backbone_norm, + use_batchnorm=config.use_batchnorm, + use_cls_token=config.use_cls_token, + adapt_to_patch_size=config.adapt_to_patch_size, + head_type=config.type, + autocast_dtype=autocast_dtype, + min_depth=config.min_depth, + max_depth=config.max_depth, + bins_strategy=config.bins_strategy, + norm_strategy=config.norm_strategy, + **config.head_kwargs, + ) + + # resume from checkpoint + if checkpoint_path is not None: + state_dicts, _ = load_checkpoint(checkpoint_path) + # load head state dict, the backbone has its pretrained_weights + depther.decoder.load_state_dict(state_dicts["model"], strict=True) + + return depther diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/dpt_head.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6be42db6e7937b94b51aa5050c0e333a8a5dbd --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/dpt_head.py @@ -0,0 +1,532 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
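+# NOTE (editor's summary, not from the upstream sources): this file re-implements the DPT +# decoder. ReassembleBlocks turns four (patch_tokens, cls_token) pairs into feature maps at +# strides 4/8/16/32, FeatureFusionBlock merges them coarse-to-fine with a 2x upsample per +# stage, and UpConvHead doubles the resolution once more, so a 32x32 token grid (a 512 px +# input at patch size 16) comes back out at 512x512. See the illustrative Example blocks in +# the ConvModule, ReassembleBlocks and DPTHead docstrings below.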
+ +import torch +from torch import nn + + +def kaiming_init( + module: nn.Module, + a: float = 0, + mode: str = "fan_out", + nonlinearity: str = "relu", + bias: float = 0, + distribution: str = "normal", +) -> None: + assert distribution in ["uniform", "normal"] + if hasattr(module, "weight") and module.weight is not None: + if distribution == "uniform": + nn.init.kaiming_uniform_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def constant_init(module, val, bias=0): + if hasattr(module, "weight") and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). 
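+ Example (illustrative sketch, assuming the defaults; with norm_cfg=None no + norm layer is built and the conv gets a bias): + >>> block = ConvModule(3, 16, kernel_size=3, padding=1) + >>> block(torch.randn(1, 3, 32, 32)).shape # conv -> ReLU + torch.Size([1, 16, 32, 32])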
+ """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type="ReLU"), + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super().__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert act_cfg is None or isinstance(act_cfg, dict) + official_padding_mode = ["zeros", "circular"] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == {"conv", "norm", "act"} + + self.with_norm = norm_cfg is not None + self.with_activation = act_cfg is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + # if self.with_explicit_padding: + # pad_cfg = dict(type=padding_mode) + # self.padding_layer = build_padding_layer(pad_cfg, padding) + # to do Camille put back + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = nn.Conv2d( # build_conv_layer(#conv_cfg, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + # self.norm_name, norm = build_norm_layer( + # norm_cfg, norm_channels) # type: ignore + self.add_module("bn", torch.nn.SyncBatchNorm(norm_channels)) + # if self.with_bias: + # if isinstance(norm, (_BatchNorm, _InstanceNorm)): + # warnings.warn( + # 'Unnecessary conv bias before batch/instance norm') + self.norm_name = "bn" + else: + self.norm_name = None # type: ignore + + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() # type: ignore + # nn.Tanh has no 'inplace' argument + if act_cfg_["type"] not in ["Tanh", "PReLU", "Sigmoid", "HSigmoid", "Swish", "GELU"]: + act_cfg_.setdefault("inplace", inplace) + self.activate = nn.ReLU() # build_activation_layer(act_cfg_) + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. 
For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and self.act_cfg["type"] == "LeakyReLU": + nonlinearity = "leaky_relu" + a = self.act_cfg.get("negative_slope", 0.01) + else: + nonlinearity = "relu" + a = 0 + kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) + if self.with_norm: + # initialize the norm module itself, not the attribute name string + constant_init(self.norm, 1, bias=0) + + def forward(self, x: torch.Tensor, activate: bool = True, norm: bool = True, debug: bool = False) -> torch.Tensor: + for layer in self.order: + if debug: + breakpoint() + if layer == "conv": + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class UpConvHead(nn.Module): + """ + A 3-layer convolutional head with intermediate upsampling + + Args: + - features (int): number of input channels + - n_output_channels (int, default=256): number of output channels + - n_hidden_channels (int, default=32): number of channels in hidden layer + + The operations are + [ + Conv3x3(features, features // 2), + 2x-Upsampling, + Conv3x3(features // 2, hidden_channels), + ReLU, + Conv1x1(hidden_channels, n_output_channels), + ] + """ + + def __init__(self, features, n_output_channels, n_hidden_channels=32): + super(UpConvHead, self).__init__() + self.n_output_channels = n_output_channels + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, n_hidden_channels, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(n_hidden_channels, n_output_channels, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (List): ViT feature channels. + Default: [1024, 1024, 1024, 1024]. + out_channels (List): output channels of each stage. + Default: [128, 256, 512, 1024]. + readout_type (str): Type of readout operation. Default: 'project'. + init_cfg (dict, optional): Initialization config dict. Default: None.
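+ Example (illustrative sketch for a ViT with 1024-dim tokens on a 32x32 patch grid): + >>> blocks = ReassembleBlocks() + >>> feats = [(torch.randn(1, 1024, 32, 32), torch.randn(1, 1024)) for _ in range(4)] + >>> [tuple(o.shape[1:]) for o in blocks(feats)] + [(128, 128, 128), (256, 64, 64), (512, 32, 32), (1024, 16, 16)]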
+ """ + + def __init__( + self, + in_channels=[1024, 1024, 1024, 1024], + out_channels=[128, 256, 512, 1024], + readout_type="project", + use_batchnorm=False, + ): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels[channel_index], + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) + for channel_index, out_channel in enumerate(out_channels) + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for i in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels[i], in_channels[i]), nn.GELU())) + + self.batchnorm_layers = nn.ModuleList( + [nn.SyncBatchNorm(channel) if use_batchnorm else nn.Identity(channel) for channel in in_channels] + ) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.batchnorm_layers[i](x) + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None): + super(PreActResidualConvUnit, self).__init__() # init_cfg) + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. 
+ align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None): + super(FeatureFusionBlock, self).__init__() # init_cfg) + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True) + self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = torch.nn.functional.interpolate( + inputs[1], + size=(x.shape[2], x.shape[3]), + mode="bilinear", + align_corners=False, + ) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + + x = torch.nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + + x = self.project(x) + return x + + +class DPTHead(nn.Module): + """Vision Transformers for Dense Prediction. + This head is an implementation of the DPT decoder. + Args: + in_channels (List): The input dimensions of the ViT backbone. + Default: [1024, 1024, 1024, 1024]. + channels (int): Channels after modules, before the task-specific module + (`conv_depth`). Default: 256. + post_process_channels (List): Out channels of post process conv + layers. Default: [128, 256, 512, 1024]. + readout_type (str): Type of readout operation. Default: 'project'. + expand_channels (bool): Whether to expand the channels in post process + block. Default: False.
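+ Example (illustrative sketch, four ViT-L feature maps with class tokens): + >>> head = DPTHead(in_channels=(1024,) * 4, n_output_channels=256) + >>> feats = [(torch.randn(1, 1024, 32, 32), torch.randn(1, 1024)) for _ in range(4)] + >>> head(feats).shape # fusion upsamples 16 -> 256, UpConvHead doubles it again + torch.Size([1, 256, 512, 512])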
+ """ + + def __init__( + self, + in_channels=(1024, 1024, 1024, 1024), + channels=256, + post_process_channels=[128, 256, 512, 1024], + readout_type="project", + expand_channels=False, + n_output_channels=256, + use_batchnorm=False, # TODO + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + self.channels = channels + self.n_output_channels = n_output_channels + self.in_channels = in_channels + self.expand_channels = expand_channels + self.norm_cfg = None # TODO CHECK THIS + self.reassemble_blocks = ReassembleBlocks( + in_channels=in_channels, + out_channels=post_process_channels, + readout_type=readout_type, + use_batchnorm=use_batchnorm, + ) + + self.post_process_channels = [ + channel * (2**i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + self.act_cfg = {"type": "ReLU"} + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = UpConvHead(self.channels, self.n_output_channels) + + def forward_features(self, inputs): + assert len(inputs) == self.num_reassemble_blocks, ( + f"Expected {self.num_reassemble_blocks} inputs, got {len(inputs)}." + ) + x = [inp for inp in inputs] + + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + + out = self.project(out) + return out + + def forward(self, inputs): + out = self.forward_features(inputs) + return self.conv_depth(out) + + def predict(self, inputs, rescale_to=(512, 512)): + out = self.forward_features(inputs) + return self.conv_depth.predict(out, rescale_to=rescale_to) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/embed.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb5d275cd0ef20cec8c3d23c84e0e1dd07f466d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/embed.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import itertools +import math + +import torch + + +class CenterPadding(torch.nn.Module): + def __init__(self, multiple: int): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + # @torch.inference_mode() + def forward(self, x): + # expected shapes are ... 
x H x W, usually B x C x H x W + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:-3:-1])) + output = torch.nn.functional.pad(x, pads) + return output + + def extra_repr(self) -> str: + return f"multiple={self.multiple}" + + +class StretchToMultiple(torch.nn.Module): + def __init__(self, multiple: int): + super().__init__() + self.multiple = multiple + + def forward(self, x): + # expected shapes are ... x H x W, usually B x C x H x W + *shape, C, H, W = x.shape + new_H = math.ceil(H / self.multiple) * self.multiple + new_W = math.ceil(W / self.multiple) * self.multiple + if new_H != H or new_W != W: + x = x.reshape(-1, C, H, W) + x = torch.nn.functional.interpolate(x, size=(new_H, new_W), mode="bilinear") + x = x.reshape(*shape, C, new_H, new_W) + return x + + def extra_repr(self) -> str: + return f"multiple={self.multiple}" diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/encoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a09cd9ac0e0c07ae17e834f9b01560b026913667 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/encoder.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from enum import Enum + +from dinov3.eval.depth.models.embed import CenterPadding, StretchToMultiple +from torch import Tensor, nn + +logger = logging.getLogger("dinov3") + + +class BackboneLayersSet(Enum): + # Set of intermediate layers to take from the backbone + LAST = "LAST" # extracting only the last layer + FOUR_LAST = "FOUR_LAST" # extracting the last 4 layers + FOUR_EVEN_INTERVALS = "FOUR_EVEN_INTERVALS" # extracting outputs every 1/4 of the total number of blocks + + +def _get_backbone_out_indices( + model: nn.Module, + backbone_out_layers: list[int] | tuple[int, ...] | BackboneLayersSet = BackboneLayersSet.FOUR_EVEN_INTERVALS, +): + """ + Get indices for output layers of the ViT backbone. For now there are 3 options available: + BackboneLayersSet.LAST : only extract the last layer, used in segmentation tasks with a bn head.
+ BackboneLayersSet.FOUR_LAST : extract the last 4 layers, used in segmentation (multiscale setting) + BackboneLayersSet.FOUR_EVEN_INTERVALS : extract outputs every 1/4 of the total number of blocks + Reference outputs in 'FOUR_EVEN_INTERVALS' mode : + ViT/S (12 blocks): [2, 5, 8, 11] + ViT/B (12 blocks): [2, 5, 8, 11] + ViT/L (24 blocks): [5, 11, 17, 23] (correct), [4, 11, 17, 23] (incorrect) + ViT/g (40 blocks): [9, 19, 29, 39] + """ + n_blocks = getattr(model, "n_blocks", 1) + out_indices: list[int] + if isinstance(backbone_out_layers, (tuple, list)): + out_indices = list(backbone_out_layers) + elif backbone_out_layers == BackboneLayersSet.LAST: + out_indices = [n_blocks - 1] + elif backbone_out_layers == BackboneLayersSet.FOUR_LAST: + out_indices = [i for i in range(n_blocks - 4, n_blocks)] + elif backbone_out_layers == BackboneLayersSet.FOUR_EVEN_INTERVALS: + # XXX: Force (incorrect) out indices for backward-compatibility (ViT/L only) + if n_blocks == 24: + out_indices = [4, 11, 17, 23] + else: + out_indices = [i * (n_blocks // 4) - 1 for i in range(1, 5)] + assert all([out_index < n_blocks for out_index in out_indices]) + return out_indices + + +class PatchSizeAdaptationStrategy(Enum): + CENTER_PADDING = "center_padding" + STRETCH = "stretch" + NO_ADAPTATION = "never" + + +class DinoVisionTransformerWrapper(nn.Module): + """Vision Transformer.""" + + def __init__( + self, + backbone_model: nn.Module, + backbone_out_layers: str | tuple[int, ...] | BackboneLayersSet, + use_backbone_norm: bool = False, + adapt_to_patch_size: PatchSizeAdaptationStrategy = PatchSizeAdaptationStrategy.CENTER_PADDING, + ): + super().__init__() + + self.final_norm = use_backbone_norm + self.backbone = backbone_model + if isinstance(backbone_out_layers, str): + backbone_out_layers = BackboneLayersSet(backbone_out_layers) + self.backbone_out_indices = _get_backbone_out_indices(self.backbone, backbone_out_layers=backbone_out_layers) + + # If the backbone does not define embed_dims, use [embed_dim] * n_blocks + try: + embed_dims: list[int] = getattr(self.backbone, "embed_dims") + except AttributeError: + embed_dim: int = getattr(self.backbone, "embed_dim") + n_blocks: int = getattr(self.backbone, "n_blocks") + logger.warning(f"Backbone does not define embed_dims, using {[embed_dim] * n_blocks} instead") + embed_dims = [embed_dim] * n_blocks + self.embed_dims = [embed_dims[idx] for idx in self.backbone_out_indices] + + # How to adapt input images to the patch size of the model? 
+ try: + input_pad_size = getattr(self.backbone, "input_pad_size") + except AttributeError: + patch_size = getattr(self.backbone, "patch_size") + logger.warning(f"Backbone does not define input_pad_size, using {patch_size=} instead") + input_pad_size = patch_size + self.patch_size_adapter: nn.Module = nn.Identity() + if adapt_to_patch_size is PatchSizeAdaptationStrategy.CENTER_PADDING: + self.patch_size_adapter = CenterPadding(input_pad_size) + elif adapt_to_patch_size is PatchSizeAdaptationStrategy.STRETCH: + self.patch_size_adapter = StretchToMultiple(input_pad_size) + + # Freeze backbone + self.backbone.requires_grad_(False) + + def forward( + self, + x: Tensor, # [B, rgb, H, W] + ) -> list[tuple[Tensor, Tensor]]: + x = self.patch_size_adapter(x) + outputs = self.backbone.get_intermediate_layers( # type: ignore + x, + n=self.backbone_out_indices, + reshape=True, + return_class_token=True, + norm=self.final_norm, + ) # List of (patch feats [B, C, h, w], class token [B, C]) + return outputs diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/linear_head.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..76b592392a0e134b277fbe7d0cef68f2ba035355 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/models/linear_head.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LinearHead(nn.Module): + """Linear layer .""" + + def __init__( + self, + in_channels, + n_output_channels, + input_transform="resize", + align_corners=False, + use_batchnorm=False, + use_cls_token=True, + ): + super().__init__() + self.in_channels = in_channels + self.channels = sum(in_channels) + if use_cls_token: + self.channels *= 2 # concatenate CLS to patch tokens + self.input_transform = input_transform + self.align_corners = align_corners + self.n_output_channels = n_output_channels + self.use_cls_token = use_cls_token + + # batchnorm + self.batchnorm_layer = nn.SyncBatchNorm(self.channels) if use_batchnorm else nn.Identity(self.channels) + + # linear head + self.conv_depth = nn.Conv2d(self.channels, self.n_output_channels, kernel_size=1, padding=0, stride=1) + nn.init.normal_(self.conv_depth.weight, mean=0, std=0.01) + nn.init.constant_(self.conv_depth.bias, 0) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + if "resize" in self.input_transform: + inputs = [ + torch.nn.functional.interpolate( + input=x, + size=[s for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + return inputs + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. 
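+ Example of the full head (illustrative sketch; CLS tokens are concatenated + to the patch tokens before the 1x1 conv): + >>> head = LinearHead(in_channels=[1024] * 4, n_output_channels=256) + >>> feats = [(torch.randn(1, 1024, 32, 32), torch.randn(1, 1024)) for _ in range(4)] + >>> head(feats).shape + torch.Size([1, 256, 32, 32])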
+ """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if self.use_cls_token: + assert len(x) == 2, "Missing class tokens" + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + return x + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.batchnorm_layer(output) + output = self.conv_depth(output) + return output + + def predict(self, x, rescale_to=(512, 512)): + x = self.forward(x) + x = F.interpolate(input=x, size=rescale_to, mode="bilinear") + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/run.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/run.py new file mode 100644 index 0000000000000000000000000000000000000000..60ad6e20dfc2ab21aeb2c90276d2f770001794c5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/run.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import json +import logging +import os +import sys +from typing import Any, Dict + +import torch +from omegaconf import OmegaConf + +import dinov3.distributed as distributed +from dinov3.eval.depth.checkpoint_utils import find_latest_checkpoint +from dinov3.eval.depth.config import DepthConfig +from dinov3.eval.depth.eval import evaluate_depther_with_config +from dinov3.eval.depth.models import make_depther_from_config +from dinov3.eval.depth.train import train_model_with_backbone + +from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results +from dinov3.eval.setup import load_model_and_context +from dinov3.run.init import job_context +from dinov3.hub.depthers import _get_depther_config, dinov3_vit7b16_dd + +RESULTS_FILENAME = "results-depth.csv" +MAIN_METRICS = [".*_abs_rel", ".*_a1", ".*_rmse"] + + +logger = logging.getLogger("dinov3") + + +def _add_dataset_prefix_to_results(results_dict: Dict[str, float], dataset_name: str): + final_dict = {dataset_name + "_" + k: v for k, v in results_dict.items()} + return final_dict + + +def eval_depther_with_model(*, depther: torch.nn.Module, config: DepthConfig): + if config.load_from is None: + config.load_from = find_latest_checkpoint(config.output_dir) + + logger.info(f"Using config: \n {OmegaConf.to_yaml(config)}") + results_dict, _, _ = evaluate_depther_with_config( + config=config, + depther=depther, + device=distributed.get_rank(), + reduce_results=False, + ) + test_config_name = config.datasets.test.split(":", 1)[0] + test_save_dir = os.path.join(config.output_dir, test_config_name) + # reduce results + if distributed.is_main_process(): + if not os.path.exists(test_save_dir): + os.makedirs(test_save_dir) + with open(os.path.join(test_save_dir, "results.json"), "w") as f: + json.dump(results_dict, f, indent=4) + for metric, values in results_dict.items(): + results_dict[metric] = float(torch.Tensor(values).nanmean()) # result can be NaN if ground truth is all masked + summary = " \n====== Summary ======\n" + summary += ( + f"{test_config_name:<10} " + + " ".join([f"{metric}: {value:.3f}" for metric, value in results_dict.items()]) + + "\n" + ) + results_dict = 
_add_dataset_prefix_to_results(results_dict, test_config_name) + summary += "=====================" + logger.info(summary) + return results_dict + + +def benchmark_launcher(eval_args: dict[str, Any]) -> dict[str, Any]: + """Initialization of distributed and logging are preconditions for this method""" + if "config" in eval_args: + base_config_path = eval_args.pop("config") + output_dir = eval_args["output_dir"] + base_config = OmegaConf.load(base_config_path) + structured_config = OmegaConf.structured(DepthConfig) + depth_config: DepthConfig = OmegaConf.to_object( # type: ignore + OmegaConf.merge( + structured_config, + base_config, + OmegaConf.create(eval_args), + ) + ) + else: + depth_config, output_dir = args_dict_to_dataclass( + eval_args=eval_args, config_dataclass=DepthConfig, save_config=False + ) + OmegaConf.save(config=depth_config, f=os.path.join(output_dir, "depth_config.yaml")) + + config_autocast_dtype = depth_config.model_dtype.autocast_dtype if depth_config.model_dtype is not None else None + if depth_config.load_from == "dinov3_vit7b16_dd": + with torch.device("cuda" if torch.cuda.is_available() else "cpu"): + autocast_dtype = config_autocast_dtype or torch.float32 + # override config parameters with those of the pretrained depther + depther_config = _get_depther_config("dinov3_vit7b16") + depth_config.decoder_head = OmegaConf.to_object( # type: ignore + OmegaConf.merge( + depth_config.decoder_head, + depther_config, + ) + ) + + depther = dinov3_vit7b16_dd( + pretrained=True, + autocast_dtype=autocast_dtype, + ) + else: + with torch.device("cuda" if torch.cuda.is_available() else "cpu"): + assert depth_config.model is not None + model, model_context = load_model_and_context(depth_config.model, output_dir=output_dir) + autocast_dtype = config_autocast_dtype or model_context["autocast_dtype"] + + if depth_config.load_from: + depther = make_depther_from_config( + backbone=model, + config=depth_config.decoder_head, + checkpoint_path=depth_config.load_from, + autocast_dtype=autocast_dtype, + ) + logger.info(f"Depth config:\n {OmegaConf.to_yaml(depth_config)}") + else: + # train backbone + depther = train_model_with_backbone(depth_config, model, autocast_dtype) + + results_dict = eval_depther_with_model(depther=depther, config=depth_config) + write_results(results_dict, output_dir, RESULTS_FILENAME) + return results_dict + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + eval_args = cli_parser(argv) + with job_context(output_dir=eval_args["output_dir"]): + benchmark_launcher(eval_args=eval_args) + return 0 + + +if __name__ == "__main__": + main() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/schedulers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..394f841634ad19aeddc997f3ee8966ceb54b014e --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/schedulers.py @@ -0,0 +1,261 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
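+# Shape of the schedules below (editor's note, illustrative values only): +# annealing_cos(start, end, pct) = end + (start - end) * (1 + cos(pi * pct)) / 2, so with +# start=1.0, end=0.0 it yields 1.0 at pct=0, 0.5 at pct=0.5 and 0.0 at pct=1. A typical +# hookup, assuming an AdamW optimizer: +# optim = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999)) +# sched = WarmupOneCycleLR(optim, total_steps=10_000, warmup_iters=500, max_lr=1e-4) +# then call optim.step() followed by sched.step() once per training iteration.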
+ +from inspect import signature +import math +from typing import Any, Literal + +import torch +from packaging.version import Version +from torch.optim import lr_scheduler as torch_schedulers +from torch.optim.optimizer import Optimizer + + +TORCH_VERSION = Version(torch.__version__) + + +def annealing_cos(start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + +def annealing_linear(start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + return (end - start) * pct + start + + +class WarmupOneCycleLR(torch_schedulers.LRScheduler): + def __init__( + self, + optimizer: Optimizer, + total_steps: int = 0, + warmup_iters: int = 0, + warmup_ratio: float = 0.0, # XXX: warmup ratio to deprecate, previously defined in mmcv segmentation code + pct_start: float = 0.295, + max_lr: float | list[float] | None = None, + anneal_strategy: Literal["cos", "linear"] = "cos", + base_momentum: float | tuple[float, ...] = 0.85, + max_momentum: float | tuple[float, ...] = 0.95, + div_factor: float = 25.0, + final_div_factor: float = 1000.0, + use_beta1: bool = True, + update_momentum: bool = True, + last_epoch: int = -1, + ): + """ + A variant of OneCycleLR with a warmup on top which potentially + replaces the first phase of the original OneCycleLR. + """ + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + self.min_point = float(pct_start * total_steps) + self.base_momentum = base_momentum + self.max_momentum = max_momentum + self.total_steps = total_steps + self.use_beta1 = use_beta1 + self.anneal_strategy = anneal_strategy + self.final_div_factor = final_div_factor + self.update_momentum = update_momentum + assert self.anneal_strategy in [ + "cos", + "linear", + ], f"Only cosine and linear-annealing strategy supported, got {self.anneal_strategy}" + assert total_steps > 0 + + if isinstance(max_lr, float): + max_lr = [max_lr] # type: ignore + # Initialize learning rate variables and momentum + for idx, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + assert last_epoch == -1 + try: # this avoids getting an error when there are multiple lrs AND weight decay values + ml = max_lr[idx] # type: ignore + except IndexError: + ml = group["lr"] + assert isinstance(ml, float) # makes sure that the variable is well updated + group["initial_lr"] = ml / div_factor + group["max_lr"] = ml + group["min_lr"] = group["initial_lr"] / final_div_factor + # initialize learning rate + group["lr"] = ml / final_div_factor if self.warmup_iters > 0 else group["initial_lr"] + if self.use_beta1: + group["betas"] = (self.max_momentum, *group["betas"][1:]) + elif self.update_momentum: + group["momentum"] = self.max_momentum + + super().__init__(optimizer, last_epoch) + + def _anneal_func(self, *args, **kwargs): + if self.anneal_strategy == "cos": + return annealing_cos(*args, **kwargs) + elif self.anneal_strategy == "linear": + return annealing_linear(*args, **kwargs) + + def _compute_lr_momentum(self, optimizer_param_group): + # torch.optim.lr_scheduler.LRScheduler does an initial + # step that sets self._step_count = 1 + step_num = (self._step_count - 1) if self.last_epoch != -1 else 0 + momentum = 0 + if step_num < self.warmup_iters: + if self.warmup_ratio: # XXX (remove in the next BW-compatibility breaking cleanup) + k = (1 - step_num / self.warmup_iters) * (1 - self.warmup_ratio) + warmup_lr = optimizer_param_group["max_lr"] * (1 - k) + 
thelr = warmup_lr * (1 - step_num / self.total_steps) + else: + gmax = ( + optimizer_param_group["max_lr"] * (1 + math.cos(math.pi * step_num / float(self.total_steps))) / 2 + ) + thelr = optimizer_param_group["max_lr"] / self.final_div_factor + gmax * step_num / float( + self.warmup_iters + ) + else: + pct = (step_num - self.warmup_iters) / float(self.total_steps - self.warmup_iters) + step_num_to_use = step_num + momentum = self._anneal_func( + self.base_momentum, + self.max_momentum, + pct, + ) + if self.anneal_strategy == "cos": + step_num_to_use += 1 + thelr = self._anneal_func( + optimizer_param_group["max_lr"], + optimizer_param_group["min_lr"], + step_num_to_use / float(self.total_steps), + ) + return thelr, momentum + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + if TORCH_VERSION >= Version("2.4.0"): + torch_schedulers._warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}" # noqa: UP032 + ) + + for group in self.optimizer.param_groups: + computed_lr, computed_momentum = self._compute_lr_momentum(group) + lrs.append(computed_lr) # type: ignore[possibly-undefined] + if self.use_beta1: + group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] + elif self.update_momentum: + group["momentum"] = computed_momentum # type: ignore[possibly-undefined] + + return lrs + + +class WarmupMultiStepLR(torch_schedulers.LRScheduler): + def __init__( + self, + optimizer: Optimizer, + total_steps: int = 0, + milestones: list[float] = [0.5, 0.9, 1.0], + gamma: float = 0.1, + warmup_iters: int = 0, + max_lr: float | list[float] | None = None, + last_epoch: int = -1, + ): + """ + A variant of MultiStepLR with a warmup on top which potentially + replaces the first phase of the plain MultiStepLR schedule. + Instead of using epochs to define the milestones, this scheduler uses number of iterations, + as is the case when training dense heads. The two main parameters are: + - milestones (list of floats, between 0-1): indicates the % of iterations after which + the step schedule will be applied.
+ - gamma (float): factor to multiply the lr by, at each milestone + """ + self.milestones = milestones + self.milestone_index = 0 + self.gamma = gamma + self.warmup_iters = warmup_iters + self.total_steps = total_steps + assert total_steps > 0 + + max_lr_list = [max_lr] if isinstance(max_lr, float) else max_lr + # Initialize learning rate variables and momentum + for idx, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + assert last_epoch == -1 + max_lr = max_lr_list[idx] if max_lr_list else group["lr"] + group["initial_lr"] = max_lr + group["max_lr"] = max_lr + super().__init__(optimizer, last_epoch) + + def _compute_lr(self, optimizer_param_group): + if self.warmup_iters > 0 and self._step_count < self.warmup_iters: + thelr = optimizer_param_group["max_lr"] * (self._step_count / self.warmup_iters) + else: + if self._step_count >= self.total_steps * self.milestones[self.milestone_index]: + self.milestone_index += 1 + thelr = optimizer_param_group["max_lr"] * (self.gamma**self.milestone_index) + return thelr + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + if TORCH_VERSION >= Version("2.4.0"): + torch_schedulers._warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}" # noqa: UP032 + ) + for group in self.optimizer.param_groups: + computed_lr = self._compute_lr(group) + lrs.append(computed_lr) # type: ignore[possibly-undefined] + + return lrs + + +def build_scheduler( + scheduler_type: str, + optimizer: Optimizer, + lr: float, + total_iter: int, + constructor_kwargs: dict[str, Any], +): + _kwargs: dict[str, Any] = {} + _kwargs.update(**constructor_kwargs) + constructor_fn = SCHEDULERS_DICT[scheduler_type] + accepted_kwargs = signature(constructor_fn).parameters.keys() + keywords = list(constructor_kwargs.keys()) + for key in keywords: + if key not in accepted_kwargs: + # ignore arguments that are not part of kwargs + _kwargs.pop(key) + if scheduler_type in ["OneCycleLR", "WarmupOneCycleLR", "WarmupMultiStepLR"]: + _kwargs.update( + dict( + max_lr=lr, + total_steps=total_iter, + ) + ) + elif scheduler_type in [ + "ConstantLR", + "LinearLR", + "PolynomialLR", + ]: + # update the filtered kwargs that are actually passed to the constructor + _kwargs.update(dict(total_iters=total_iter)) + + return constructor_fn(optimizer, **_kwargs) + + +SCHEDULERS_DICT = { + "ConstantLR": torch_schedulers.ConstantLR, + "LinearLR": torch_schedulers.LinearLR, + "MultiStepLR": torch_schedulers.MultiStepLR, + "PolynomialLR": torch_schedulers.PolynomialLR, + "StepLR": torch_schedulers.StepLR, + "OneCycleLR": torch_schedulers.OneCycleLR, + "WarmupOneCycleLR": WarmupOneCycleLR, + "WarmupMultiStepLR": WarmupMultiStepLR, +} diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/train.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2f44095ef9b0b1f8af2c7e6ece00eec73c6f1dee --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/train.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement.
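+# Editor's note on the call sequence (assumed from run.py, illustrative): +# train_model_with_backbone(config, backbone, autocast_dtype) seeds the RNGs per rank, +# dumps the resolved config, then run_epochs() wires the depther, AdamW optimizer, +# scheduler, dataloaders and MultiLoss criterion into an IterBasedTrainer that steps +# until config.scheduler.total_iter, validating and checkpointing every +# config.eval.eval_interval iterations.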
+ +import logging +import os +from typing import Any, Callable + +import torch +import torch.distributed as dist +from omegaconf import OmegaConf +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.optimizer import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +import dinov3.distributed as distributed +from dinov3.eval.depth.data import build_dataloader +from dinov3.eval.depth.config import ( + DepthConfig, + ResultConfig, + make_depth_eval_transforms_from_config, + make_depth_train_transforms_from_config, +) + +from dinov3.eval.depth.checkpoint_utils import find_latest_checkpoint, load_checkpoint, save_checkpoint +from dinov3.eval.depth.datasets.datasets_utils import _EvalCropType, make_valid_mask +from dinov3.eval.depth.loss import MultiLoss +from dinov3.eval.depth.models import Depther, make_depther_from_config +from dinov3.eval.depth.metrics import DEPTH_METRICS +from dinov3.eval.depth.schedulers import build_scheduler +from dinov3.eval.depth.eval import evaluate_depther_with_dataloader + +from dinov3.eval.depth.utils import setup_model_ddp +from dinov3.logging import MetricLogger, SmoothedValue +from dinov3.utils import fix_random_seeds + +logger = logging.getLogger("dinov3") + + +class IterBasedTrainer: + def __init__( + self, + config: DepthConfig, + depther: Depther, + train_dataloader: Any, + val_dataloader: Any, + criterion: Callable, + metrics: list[str], + optimizer: Optimizer, + scheduler: LRScheduler, + ): + self._train_dataset_epoch = 0 # a counter for how many times all images in the dataset were seen + self.rank = distributed.get_rank() + + self.depther = depther + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.train_dataloader.sampler.set_epoch(self._train_dataset_epoch) + torch.backends.cudnn.benchmark = True + self.optimizer = optimizer + self.scheduler = scheduler + self.criterion = criterion + + self.global_step = 0 + self.total_iter = config.scheduler.total_iter + self._halt_trainer = False + + # filter out metrics that won't be tracked + self.metrics = [metric for metric in DEPTH_METRICS if metric.name in metrics] + self.config = config + + def train_on_batch(self, batch): + assert not self._halt_trainer + + device = self.rank + batch_img, depth_gt = batch + batch_img = batch_img.to(device) + depth_gt = depth_gt.to(device) + + valid_mask = make_valid_mask( + depth_gt, + eval_crop=_EvalCropType(self.config.transforms.train.eval_mask), + ignored_value=self.config.eval.ignored_value, + ).bool() + # mask out pixels outside of valid region + depth_gt[~valid_mask] = self.config.eval.ignored_value + + self.optimizer.zero_grad(set_to_none=True) + + # forward pass + pred = self.depther(batch_img) + # resize the prediction, if necessary, to match the ground-truth size + + if depth_gt.shape != pred.shape: + pred = torch.nn.functional.interpolate( + input=pred, + size=depth_gt.shape[2:], + mode="bilinear", + align_corners=False, + ) + loss = self.criterion(pred, depth_gt, valid_mask) + + # optimization + loss.backward() + torch.nn.utils.clip_grad_norm_( + [p for p in self.depther.parameters() if p.requires_grad], + self.config.optimizer.gradient_clip, + ) + self.optimizer.step() + self.scheduler.step() + + self.global_step += 1 + if self.global_step >= self.total_iter: + # stop accepting batches once the iteration budget is exhausted + self._halt_trainer = True + + # update epoch for the dataset once a full pass over it has completed + if self.global_step % len(self.train_dataloader) == 0: + self._train_dataset_epoch += 1 + self.train_dataloader.sampler.set_epoch(self._train_dataset_epoch) + + return loss + +
def validate(self): + """Run evaluation on the validation set.""" + depther = self.depther + # unwrap DDP + if isinstance(depther, DDP): + depther = depther.module + + self.depther.eval() + new_metric_values_dict, _, _ = evaluate_depther_with_dataloader( + dataloader=self.val_dataloader, + depther=depther, + device=self.rank, + metrics=self.metrics, + eval_range=(self.config.eval.min_depth, self.config.eval.max_depth), + result_config=ResultConfig(save_results=False), + save_dir=self.config.output_dir, + ignored_value=self.config.eval.ignored_value, + eval_mask_type=self.config.transforms.eval.eval_mask, + align_least_squares=self.config.eval.align_least_squares, + use_tta=False, + ) + + # put model back to train mode after validation + self.depther.decoder.train() + + +def run_epochs(config: DepthConfig, backbone: torch.nn.Module, autocast_dtype: torch.dtype): + n_gpus = distributed.get_world_size() + + # 1- define decoder(s) and optimizer + optim_param_groups = [] + depther = make_depther_from_config( + backbone, + config.decoder_head, + autocast_dtype=autocast_dtype, + ) + depther.train() + if torch.cuda.is_available(): + depther = depther.cuda() + optim_param_groups.append( + { + "params": depther.decoder.parameters(), + "lr": config.optimizer.lr, + "betas": (config.optimizer.beta1, config.optimizer.beta2), + "weight_decay": config.optimizer.weight_decay, + } + ) + depther.decoder = setup_model_ddp(depther.decoder, device=distributed.get_rank()) + optimizer = torch.optim.AdamW(optim_param_groups) + + # 2- define scheduler + scheduler = build_scheduler( + config.scheduler.type, + optimizer=optimizer, + lr=config.optimizer.lr, + total_iter=config.scheduler.total_iter, + constructor_kwargs=config.scheduler.constructor_kwargs, + ) + + # 3- define transforms and dataloaders + train_transforms = make_depth_train_transforms_from_config(config) + val_transforms = make_depth_eval_transforms_from_config(config, split="val") + train_dataloader = build_dataloader( + transforms=train_transforms, + dataset_str=getattr(config.datasets, "train") + f":root={config.datasets.root}", + device=torch.cuda.current_device(), + split="train", + batch_size=config.bs, + n_gpus=n_gpus, + num_workers=n_gpus, + use_init_fn=True, + ) + + val_dataloader = build_dataloader( + transforms=val_transforms, + dataset_str=getattr(config.datasets, "val") + f":root={config.datasets.root}", + device=torch.cuda.current_device(), + split="val", + batch_size=1, + n_gpus=n_gpus, + num_workers=n_gpus, + ) + + # 4- define criterion + assert config.losses is not None, "No loss defined for training" + criterion = MultiLoss(dict_losses=config.losses) + + trainer = IterBasedTrainer( + config, + depther=depther, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + criterion=criterion, + metrics=config.metrics, + optimizer=optimizer, + scheduler=scheduler, + ) + + if config.load_from is None: + config.load_from = find_latest_checkpoint(config.output_dir) # returns "" if the path does not exist + + if config.load_from: # "" (no checkpoint found) is falsy + logger.info(f"RESUMING CHECKPOINT from {config.load_from}") + chkpt, iteration = load_checkpoint(config.load_from) + depther.decoder.load_state_dict(chkpt["model"]) + optimizer.load_state_dict(chkpt["optimizer"]) + trainer.global_step = iteration or float("inf") # type: ignore + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("loss", SmoothedValue(window_size=4,
fmt="{value:.3f}")) + logger.info(f"Built trainer with start: {trainer.global_step} | total_iter {trainer.total_iter}") + + for batch in metric_logger.log_every( + trainer.train_dataloader, + 50, + header="Train: ", + start_iteration=trainer.global_step, + n_iterations=trainer.total_iter, + ): + if trainer.global_step >= trainer.total_iter: + break + + loss = trainer.train_on_batch(batch) + metric_logger.update(loss=loss) + + if trainer.global_step % config.eval.eval_interval == 0: + dist.barrier() + trainer.validate() + if distributed.is_main_process(): + save_checkpoint( + config.output_dir, + iteration=trainer.global_step, + model=trainer.depther.decoder, + optimizer=trainer.optimizer, + ) + metric_logger.synchronize_between_processes() + + # one last validation only if the number of total iterations is NOT divisible by eval interval: + if trainer.total_iter % config.eval.eval_interval: + trainer.validate() + metric_logger.synchronize_between_processes() + + logger.info("done!") + if distributed.is_main_process(): + save_checkpoint( + config.output_dir, + iteration=trainer.global_step, + model=trainer.depther.decoder, + optimizer=trainer.optimizer, + ) + + # load selected model checkpoint and return the model + dist.barrier() + depther.eval() + return depther + + +def train_model_with_backbone(config: DepthConfig, backbone: torch.nn.Module, autocast_dtype: torch.dtype): + fix_random_seeds(config.seed + distributed.get_rank()) + + depth_file_path = os.path.join(config.output_dir, "depth_config.yaml") + OmegaConf.save(config=config, f=depth_file_path) + logger.info(f"Config:\n{OmegaConf.to_yaml(config)}") + trained_model = run_epochs(config=config, backbone=backbone, autocast_dtype=autocast_dtype) + return trained_model diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/transforms.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..da27e32fce978982005a4af23e51cd7f0f48b8e5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/transforms.py @@ -0,0 +1,382 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from enum import Enum +from typing import Callable + +import numpy as np +import torch +import torchvision.transforms as T +from torchvision.transforms import v2 +from torchvision import tv_tensors + + +import torchvision.transforms.functional as TF +from PIL import Image + +from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +class _FixedCropType(Enum): + NYU = "NYU" + FULL = "FULL" + + +class Aug: + def __call__(self, x): + raise NotImplementedError + + def inverse(self, x): + raise NotImplementedError("This function has no inverse!") + + +class ColorAug(torch.nn.Module): + """Color augmentation used in depth estimation + + Args: + prob (float, optional): The color augmentation probability. Default: None. + gamma_range(tuple[float, float], optional): Gamma range for augmentation. Default: (0.9, 1.1). + brightness_range(tuple[float, float], optional): Brightness range for augmentation. Default: (0.9, 1.1). + color_range(tuple[float, float], optional): Color range for augmentation. Default: (0.9, 1.1). 
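+ Example (illustrative sketch; prob=1.0 forces the augmentation): + >>> aug = ColorAug(prob=1.0) + >>> out = aug(torch.rand(3, 64, 64)) # jittered copy, same shape and [0, 1] range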
+ """ + + def __init__(self, prob=None, gamma_range=(0.9, 1.1), brightness_range=(0.9, 1.1), color_range=(0.9, 1.1)): + super().__init__() + self.prob = prob + self.gamma_range = gamma_range + self.brightness_range = brightness_range + self.color_range = color_range + if prob is not None: + assert prob >= 0 and prob <= 1 + + def __call__(self, img): + """Call function to apply color augmentation. + + Args: + img: Data to transform. + + Returns: + img: Randomly colored data. + """ + aug = True if np.random.rand() < self.prob else False + if aug: + image = img.permute((1, 2, 0)) * 255 # 256, 256, 3 + + # gamma augmentation + gamma = np.random.uniform(min(*self.gamma_range), max(*self.gamma_range)) + image_aug = image**gamma + + # brightness augmentation + brightness = np.random.uniform(min(*self.brightness_range), max(*self.brightness_range)) + image_aug = image_aug * brightness + + # color augmentation + colors = np.random.uniform(min(*self.color_range), max(*self.color_range), size=3) + white = np.ones((image.shape[0], image.shape[1]), dtype=np.float32) + color_image = np.stack([white * colors[i] for i in range(3)], axis=2) + image_aug *= color_image + image_aug = np.clip(image_aug, 0, 255) + image_aug = image_aug / 255 + + return image_aug.permute((2, 0, 1)) + return img + + +class ColorAugV2(torch.nn.Module): + """Color augmentation used in depth estimation + + Args: + prob (float, optional): The color augmentation probability. Default: None. + gamma_range(tuple[float, float], optional): Gamma range for augmentation. Default: (0.9, 1.1). + brightness_range(tuple[float, float], optional): Brightness range for augmentation. Default: (0.9, 1.1). + color_range(tuple[float, float], optional): Color range for augmentation. Default: (0.9, 1.1). + """ + + def __init__(self, prob=None, gamma_range=(0.9, 1.1), brightness_range=(0.9, 1.1), color_range=(0.9, 1.1)): + super().__init__() + self.prob = prob + self.img_transform = ColorAug( + prob=prob, gamma_range=gamma_range, brightness_range=brightness_range, color_range=color_range + ) + + def __call__(self, img, label): + return self.img_transform(img), label + + def __repr__(self): + repr = "ColorAug(" + repr += f"\n\tgamma_range={self.img_transform.gamma_range}," + repr += f"\n\tbrightness_range={self.img_transform.brightness_range}," + repr += f"\n\tcolor_range={self.img_transform.color_range}," + repr += f"\n\tprob={self.prob}," + repr += ")" + return repr + + +class LeftRightFlipAug(Aug): + """ + Test time augmentation for depth estimation + from https://github.com/open-mmlab/mmcv/blob/main/mmcv/transforms/processing.py#L721 + + this is just returning two versions of the same image, and the according labels + """ + + def __init__( + self, + flip: bool = False, + ): + self._flip = flip + + def __call__(self, img, label=None): + """Call function to apply test time augment transforms on results. + + Args: + img: Data to transform. + + Returns: + list: A list of augmented data. 
+ """ + + do_flips = [False, True] if self._flip else [False] + results_images = [] + results_labels = [] + + for do_flip in do_flips: + image_aug = TF.hflip(img) if do_flip else img + results_images.append(image_aug) + label_aug = TF.hflip(label) if do_flip else label + results_labels.append(label_aug) + + return results_images, results_labels + + def inverse(self, stacked_left_right_pair: torch.Tensor) -> torch.Tensor: + if not self._flip: + return stacked_left_right_pair + + pre_aug_batch_size = stacked_left_right_pair.shape[0] // 2 + assert pre_aug_batch_size * 2 == stacked_left_right_pair.shape[0] + return ( + stacked_left_right_pair[:pre_aug_batch_size] + TF.hflip(stacked_left_right_pair[pre_aug_batch_size:]) + ) / 2 + + +class NormalizeDepth(torch.nn.Module): + def __init__(self, normalization_factor): + super().__init__() + self.factor = normalization_factor + assert self.factor > 1e-6, f"Normalization factor should be > 1e-6, got {self.factor}" + + def forward(self, img, label): + assert label is not None + label = Depth(label / self.factor) # have to rewrap otherwise it becomes a torch.Tensor + return img, label + + def __repr__(self): + repr = f"NormalizeDepth(normalization_factor={self.factor})" + return repr + + +class NYUCrop: + def __init__(self, crop_box: tuple[int, int, int, int] = (43, 45, 608, 472)): + """NYU standard krop when training monocular depth estimation on NYU dataset. + + Args: + crop_box: (x1, y1, x2, y2) of cropped region. + """ + self._orig_width = 640 + self._orig_height = 480 + self._x1, self._y1, self._x2, self._y2 = crop_box + + def __call__(self, img): + """Call function to apply NYUCrop on images.""" + orig_h, orig_w = 480, 640 + w, h = img.size if isinstance(img, Image.Image) else img.shape[-2:][::-1] + y1_new = int((self._y1 / orig_h) * h) + y2_new = int((self._y2 / orig_h) * h) + x1_new = int((self._x1 / orig_w) * w) + x2_new = int((self._x2 / orig_w) * w) + if isinstance(img, Image.Image): + output_img = img.crop((x1_new, y1_new, x2_new, y2_new)) + elif isinstance(img, (torch.Tensor, np.ndarray)): + output_img = img[..., y1_new:y2_new, x1_new:x2_new] + else: + raise NotImplementedError(f"got unsupported input type {type(img)}") + return output_img + + +class ResizeV2: + """ + Resize both image and label using different interpolation modes. 
+ """ + + def __init__( + self, + size, + image_interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + label_interpolation: T.InterpolationMode = T.InterpolationMode.NEAREST, + resize_label: bool = False, + ): + self.size = size + self.image_interpolation = image_interpolation + self.label_interpolation = label_interpolation + self.resize_label = resize_label + + def __call__(self, img, label): + img = T.Resize(size=self.size, interpolation=self.image_interpolation)(img) + if self.resize_label: + label = T.Resize(size=self.size, interpolation=self.label_interpolation)(label) + return img, label + + def __repr__(self): + repr = f"Resize(img_size={self.size},label=" + repr += "None)" if not self.resize_label else f"{self.size})" + return repr + + +class FixedCrop(torch.nn.Module): + def __init__(self, crop_type: _FixedCropType | str): + super().__init__() + if isinstance(crop_type, str): + crop_type = _FixedCropType(crop_type) + self.crop: Callable + if crop_type == _FixedCropType.NYU: + self.crop = NYUCrop() + elif crop_type == _FixedCropType.FULL: + self.crop = lambda x: x + self.crop_type = crop_type + + def forward(self, img, label): + img = self.crop(img) + if label is not None: + label = self.crop(label) + return img, label + + def __repr__(self): + repr = f"FixedCrop({self.crop_type})" + return repr + + +class MaybeApply(torch.nn.Module): + def __init__(self, transform, threshold: float = 0.5): + super().__init__() + self._transform = transform + self._threshold = threshold + + def forward(self, img, label): + x = np.random.rand() + if x < self._threshold: + return self._transform(img, label) + return img, label + + +class Depth(tv_tensors.Mask): + pass + + +class ToRGBDTensorPair(torch.nn.Module): + """Read segmentation mask from arrays or PIL images""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, img, label): + img = T.ToTensor()(img) + if isinstance(label, Image.Image): + label = Depth(label, dtype=torch.uint16) + return img, label + + +# Add custom parameters to Resize and Rotate transforms for Depth (use Nearest interpolation) +# https://docs.pytorch.org/vision/master/auto_examples/transforms/plot_custom_tv_tensors.html + + +@v2.functional.register_kernel(functional="resize", tv_tensor_cls=Depth) +def depth_resize(my_dp, size): + out = TF.resize(my_dp, size=size, interpolation=T.InterpolationMode.NEAREST, antialias=True) + return tv_tensors.wrap(out, like=my_dp) + + +@v2.functional.register_kernel(functional="rotate", tv_tensor_cls=Depth) +def depth_rotate(my_dp, angle, *args, **kwargs): + out = TF.rotate(my_dp, angle=angle, interpolation=T.InterpolationMode.NEAREST) + return tv_tensors.wrap(out, like=my_dp) + + +def make_depth_train_transforms( + *, + normalization_constant: float = 1.0, + rotation_angle: float = 2.5, + interpolation=T.InterpolationMode.BILINEAR, + img_size: int | tuple[int, int] | None = None, + random_crop_size: tuple[int, int] | None = (352, 704), + fixed_crop: str = "FULL", + mean: tuple[float, float, float] = IMAGENET_DEFAULT_MEAN, + std: tuple[float, float, float] = IMAGENET_DEFAULT_STD, + brightness_range: tuple[float, float] = (0.9, 1.1), +): + # Fixed geometric transforms + transforms_list: list[Callable] = [] + transforms_list.append(FixedCrop(_FixedCropType(fixed_crop))) + if img_size is not None: + transforms_list.append( + ResizeV2( + img_size, + image_interpolation=interpolation, + label_interpolation=T.InterpolationMode.NEAREST, + resize_label=True, + ) + ) + + # To (TV)tensor + 
transforms_list.append(ToRGBDTensorPair()) + transforms_list.append(NormalizeDepth(normalization_constant)) + + # Random geometric augmentations + transforms_list.append( + MaybeApply( + v2.Compose( + [ + v2.RandomRotation(degrees=rotation_angle, interpolation=interpolation), + ] + ), + threshold=0.5, + ) + ) + transforms_list.append(v2.RandomHorizontalFlip()) + transforms_list.append(v2.RandomCrop(random_crop_size)) + + # Random color augmentations + transforms_list.append(ColorAugV2(prob=0.5, brightness_range=brightness_range)) + + # Normalize image + transforms_list.append(v2.Normalize(mean=mean, std=std)) + + return v2.Compose(transforms_list) + + +def make_depth_eval_transforms( + *, + normalization_constant: float = 1.0, + img_size: int | tuple[int, int] | None = None, + image_interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + mean: tuple[float, float, float] = IMAGENET_DEFAULT_MEAN, + std: tuple[float, float, float] = IMAGENET_DEFAULT_STD, + fixed_crop: str = "FULL", + tta: bool = False, +): + transforms_list: list[Callable] = [] + # Apply the fixed evaluation crop + transforms_list.append(FixedCrop(fixed_crop)) + + # Convert image and depth to tensors + transforms_list.append(ToRGBDTensorPair()) + if img_size: + # don't resize the label for evaluation + transforms_list.append(ResizeV2(size=img_size, image_interpolation=image_interpolation, resize_label=False)) + + # Normalize input image and depth + transforms_list.append(v2.Normalize(mean=mean, std=std)) + transforms_list.append(NormalizeDepth(normalization_constant)) + transforms_list.append(LeftRightFlipAug(flip=tta)) + return v2.Compose(transforms_list) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..868152ab974c5698aedcea1734d9d2f2dca66e6f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
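For orientation, the eval pipeline assembled by `make_depth_eval_transforms` above follows the usual crop → tensor → resize → normalize order. A reduced sketch using only stock torchvision v2 transforms (single-image variants stand in for the repo's image/label pair transforms; assumes torchvision ≥ 0.16 for `ToImage`/`ToDtype`):

```python
import torch
from torchvision.transforms import v2

eval_tf = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),       # uint8 -> float in [0, 1]
    v2.Resize((480, 640)),                       # image-only resize, as in eval
    v2.Normalize(mean=[0.485, 0.456, 0.406],     # ImageNet statistics, as in
                 std=[0.229, 0.224, 0.225]),     # IMAGENET_DEFAULT_MEAN / _STD
])
img = torch.randint(0, 256, (3, 512, 704), dtype=torch.uint8)
print(eval_tf(img).shape)  # torch.Size([3, 480, 640])
```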
+ +import logging +import numpy as np + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +logger = logging.getLogger("dinov3") + + +def align_depth_least_square( + gt_arr: np.ndarray | torch.Tensor, + pred_arr: np.ndarray | torch.Tensor, + valid_mask_arr: np.ndarray | torch.Tensor, + max_resolution=None, +): + """ + Adapted from Marigold + https://github.com/prs-eth/Marigold/blob/62413d56099d36573b2de1eb8c429839734b7782/src/util/alignment.py#L8 + """ + ori_shape = pred_arr.shape # input shape + dtype = pred_arr.dtype + was_tensor = isinstance(pred_arr, torch.Tensor) + device = None + if was_tensor: + assert isinstance(gt_arr, torch.Tensor) and isinstance(valid_mask_arr, torch.Tensor) + pred_arr = pred_arr.to(torch.float32) # other dtypes are unsupported by the numpy solver + device = gt_arr.device + gt_arr = gt_arr.detach().cpu().numpy() + pred_arr = pred_arr.detach().cpu().numpy() + valid_mask_arr = valid_mask_arr.detach().cpu().numpy() + + gt = gt_arr.squeeze() # [H, W] + pred = pred_arr.squeeze() + valid_mask = valid_mask_arr.squeeze() + + # Downsample + if max_resolution is not None: + scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") + gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() + pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() + valid_mask = downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()).bool().numpy() + + assert gt.shape == pred.shape == valid_mask.shape, f"{gt.shape}, {pred.shape}, {valid_mask.shape}" + + gt_masked = gt[valid_mask].reshape((-1, 1)) + pred_masked = pred[valid_mask].reshape((-1, 1)) + + # numpy solver + _ones = np.ones_like(pred_masked) + A = np.concatenate([pred_masked, _ones], axis=-1) + try: + X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] + scale, shift = X + except np.linalg.LinAlgError: + scale = 1 + shift = 0 + logger.info(f"Found wrong depth: \n Pred m:{pred_arr.min()} M:{pred_arr.max()} mean: {pred_arr.mean()}") + logger.info(f"Gt m:{gt_arr.min()} M:{gt_arr.max()} mean: {gt_arr.mean()}") + + aligned_pred = pred_arr * scale + shift + + # restore dimensions and, for tensor inputs, dtype and device + aligned_pred = aligned_pred.reshape(ori_shape) + if was_tensor: + aligned_pred = torch.as_tensor(aligned_pred, dtype=dtype, device=device) + return aligned_pred, scale, shift + + +def setup_model_ddp(model: torch.nn.Module, device: torch.device | int): + model = DDP(model.to(device), device_ids=[device]) + logger.info(f"Model moved to rank {device}") + return model diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/visualization_utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e681f45bf90f43faa0eb67bc934bd869bd37d71a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/depth/visualization_utils.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement.
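The core of `align_depth_least_square` above is a two-parameter least-squares fit: find scale `s` and shift `t` minimizing ||s·pred + t − gt||² over valid pixels. A synthetic end-to-end check of that closed form (random data with a known affine distortion):

```python
import numpy as np

rng = np.random.default_rng(0)
gt = rng.uniform(0.5, 10.0, size=1000)                    # metric depth
pred = (gt - 0.3) / 2.0 + rng.normal(0, 0.01, size=1000)  # affine-distorted prediction

A = np.stack([pred, np.ones_like(pred)], axis=1)          # [N, 2] design matrix
(scale, shift), *_ = np.linalg.lstsq(A, gt, rcond=None)
print(round(float(scale), 2), round(float(shift), 2))     # ~2.0, ~0.3: distortion recovered
```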
+ +import os +from typing import Callable + +import matplotlib +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image + +from dinov3.eval.depth.config import ResultConfig, ResultExtension + +from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def alpha_blend(img_pil: Image.Image, mask_rgb: np.ndarray, alpha: float = 0.5) -> Image.Image: + img_rgba = img_pil.convert("RGBA") + mask_alpha = np.full_like(mask_rgb, np.round(255 * alpha))[..., 0:1] + mask_rgba = Image.fromarray(np.concatenate([mask_rgb, mask_alpha], axis=-1).astype(np.uint8)) + overlay = Image.alpha_composite(img_rgba, mask_rgba) + return overlay.convert("RGB") + + +def normalized_tensor_to_pil( + tensor: torch.Tensor, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, +) -> Image.Image: + """ + Transforms a normalized image tensor back into PIL image. + """ + assert tensor.ndim == 3 and tensor.shape[0] == 3, f"input should be 3xHxW, got {tensor.shape}" + std = torch.tensor(std, device=tensor.device)[:, None, None] + mean = torch.tensor(mean, device=tensor.device)[:, None, None] + unnormalized_tensor = torch.clamp(tensor * std + mean, 0.0, 1.0) + return transforms.functional.to_pil_image(unnormalized_tensor) + + +def depth_tensor_to_colorized_pil(depth_tensor: torch.Tensor, cmap="plasma", vmin=None, vmax=None): + # derived from https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/utils/color_depth.py + value = depth_tensor.detach().cpu().reshape(depth_tensor.shape[-2:]).numpy() + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0.0 + cmapper = matplotlib.cm.get_cmap(cmap) + value = cmapper(value, bytes=True) # ((1)xhxwx4) + value = value[:, :, :3] # bgr -> rgb + rgb_value = value # [..., ::-1] + return Image.fromarray(rgb_value) + + +def save_raw_predictions( + img: torch.Tensor, # [1, 3, H, W] + pred: torch.Tensor, # [1, C, H, W] + gt: torch.Tensor, # [1, C, H, W] + save_dir: str, + save_index: int, +) -> None: + torch.save( + { + "image": transforms.functional.to_tensor(normalized_tensor_to_pil(img[0])).cpu(), + "pred": pred.detach().cpu(), + "target": gt.cpu(), + }, + os.path.join(save_dir, f"results_{int(save_index)}.pth"), + ) + + +def get_prediction_images( + img: torch.Tensor, # [1, 3, H, W] + pred: torch.Tensor, # [1, C, H, W] + gt: torch.Tensor, # [1, C, H, W] + pred_tensor_to_pil: Callable[[torch.Tensor], Image.Image], + alpha: float = 1.0, +) -> tuple[Image.Image, Image.Image, Image.Image]: + img_pil = normalized_tensor_to_pil(img[0]) + assert pred.shape[0] == gt.shape[0] == 1 + pred_pil = pred_tensor_to_pil(pred[0].cpu()) + gt_pil = pred_tensor_to_pil(gt[0].cpu()) + if img_pil.size != gt_pil.size: + img_pil = img_pil.resize(gt_pil.size) + + if alpha < 1.0: + pred_pil = alpha_blend(img_pil, np.array(pred_pil), alpha) + gt_pil = alpha_blend(img_pil, np.array(gt_pil), alpha) + return img_pil, pred_pil, gt_pil + + +def resize_results( + img_pil: Image.Image, + pred_pil: Image.Image, + gt_pil: Image.Image, + resolution: int, +) -> tuple[Image.Image, Image.Image, Image.Image]: + img_pil = transforms.functional.resize( + img_pil, + resolution, + interpolation=transforms.InterpolationMode.BILINEAR, + ) + pred_pil, gt_pil = [ + transforms.functional.resize( + x, + resolution, + interpolation=transforms.InterpolationMode.NEAREST, + ) + for x in 
[pred_pil, gt_pil] + ] + return img_pil, pred_pil, gt_pil + + +def save_predictions( + img: torch.Tensor, # [1, 3, H, W] + pred: torch.Tensor, # [1, C, H, W] + gt: torch.Tensor, # [1, C, H, W] + result_config: ResultConfig, + save_dir: str, + save_index: int, + pred_tensor_to_pil_fn: Callable, +) -> None: + vis_save_dir = os.path.join(save_dir, "visualizations") + os.makedirs(vis_save_dir, exist_ok=True) + if result_config.extension == ResultExtension.PTH: + save_raw_predictions( + img=img, + pred=pred, + gt=gt, + save_dir=vis_save_dir, + save_index=save_index, + ) + return + + img_pil, pred_pil, gt_pil = get_prediction_images( + img=img, + pred=pred, + gt=gt, + pred_tensor_to_pil=pred_tensor_to_pil_fn, + alpha=result_config.overlay_alpha, + ) + + if result_config.save_resolution: + img_pil, pred_pil, gt_pil = resize_results(img_pil, pred_pil, gt_pil, result_config.save_resolution) + + ext = result_config.extension.value + if result_config.save_separate_files: + img_pil.save(os.path.join(vis_save_dir, f"image_{int(save_index)}.{ext}")) + pred_pil.save(os.path.join(vis_save_dir, f"pred_{int(save_index)}.{ext}")) + gt_pil.save(os.path.join(vis_save_dir, f"gt_{int(save_index)}.{ext}")) + + band = np.concatenate([np.array(segm) for segm in [pred_pil, gt_pil]], axis=1) # horizontal band + band = Image.fromarray(band) + band.save(os.path.join(vis_save_dir, f"results_{int(save_index)}.{ext}")) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/config.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/config.py new file mode 100644 index 0000000000000000000000000000000000000000..38cfea3f3e021880e41321dfd419c3e20cbeceb4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/config.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from dataclasses import dataclass + +from .models.position_encoding import PositionEncoding + + +@dataclass(kw_only=True) +class DetectionHeadConfig: + num_classes: int = 91 # 91 classes in COCO + # Deformable DETR tricks + with_box_refine: bool = True + two_stage: bool = True + # DINO DETR tricks + mixed_selection: bool = True + look_forward_twice: bool = True # was default False + # Hybrid Matching tricks + k_one2many: int = 6 # was 5 + lambda_one2many: float = 1.0 + num_queries_one2one: int = 300 # number of query slots for one_to_one matching + num_queries_one2many: int = 1500 # was 0, number of query slots for one_to_many matching + """ + Absolute coordinates & box regression reparameterization. + If true, we use absolute coordindates & reparameterization for bounding boxes. 
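Back in the visualization helpers, `depth_tensor_to_colorized_pil` min-max normalizes a depth map and pushes it through a matplotlib colormap. A compact sketch of the same idea (this uses the modern `matplotlib.colormaps` registry, matplotlib ≥ 3.6, rather than the deprecated `cm.get_cmap` the vendored code calls):

```python
import matplotlib
import numpy as np
from PIL import Image

depth = np.linspace(0.5, 10.0, 64 * 64, dtype=np.float32).reshape(64, 64)
norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)  # guard 0-division
rgba = matplotlib.colormaps["plasma"](norm, bytes=True)  # H x W x 4, uint8
img = Image.fromarray(rgba[..., :3])                     # drop alpha -> RGB
print(img.size, img.mode)  # (64, 64) RGB
```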
+ """ + reparam: bool = True + topk: int = 100 + + # * Backbone + # type of positional embedding to use on top of the image features + position_embedding: PositionEncoding = PositionEncoding.SINE + num_feature_levels: int = 1 # number of feature levels + + # * Transformer + dec_layers: int = 6 # number of decoding layers in the transformer + dim_feedforward: int = 2048 # intermediate size of the feedforward layers in the transformer blocks + hidden_dim: int = 256 # size of the embeddings (dimension of the transformer) + dropout: float = 0.0 # dropout applied in the transformer, was 0.1 + nheads: int = 8 # number of attention heads inside the transformer's attentions + norm_type: str = "pre_norm" + + # Loss + aux_loss: bool = True # auxiliary decoding losses (loss at each layer) + + # * dev: proposals + proposal_feature_levels: int = 4 # was 1 + proposal_min_size: int = 50 + # * dev decoder: global decoder + decoder_type: str = "global_rpe_decomp" # was deform + decoder_use_checkpoint: bool = False + decoder_rpe_hidden_dim: int = 512 + decoder_rpe_type: str = "linear" + + # Custom + add_transformer_encoder: bool = True + num_encoder_layers: int = 6 + layers_to_use: list[int] | None = None + blocks_to_train: list[int] | None = None + n_windows_sqrt: int = 0 + proposal_in_stride: int | None = None + proposal_tgt_strides: list[int] | None = None + backbone_use_layernorm: bool = False # whether to use layernorm on each layer of the backbone's features diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5375bc66e1ed841a7091b81a0dcf56d1993c1f87 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/backbone.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..edc1c69dd328769acfb7319db940bf08dced43ff --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/backbone.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. +# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. 
+""" +import logging +from typing import List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..util.misc import NestedTensor +from .position_encoding import build_position_encoding +from .utils import LayerNorm2D +from .windows import WindowsWrapper + +logger = logging.getLogger("dinov3") + + +class DINOBackbone(nn.Module): + def __init__( + self, + backbone_model: nn.Module, + train_backbone: bool, + blocks_to_train: Optional[List[str]] = None, + layers_to_use: Union[int, List] = 1, + use_layernorm: bool = True, + ): + super().__init__() + self.backbone = backbone_model + self.blocks_to_train = blocks_to_train + self.patch_size = self.backbone.patch_size + self.use_layernorm = use_layernorm + + for _, (name, parameter) in enumerate(self.backbone.named_parameters()): + train_condition = any(f".{b}." in name for b in self.blocks_to_train) if self.blocks_to_train else True + if (not train_backbone) or "mask_token" in name or (not train_condition): + parameter.requires_grad_(False) + + self.strides = [self.backbone.patch_size] + + # get embed_dim for each intermediate output + n_all_layers = self.backbone.n_blocks + blocks_to_take = ( + range(n_all_layers - layers_to_use, n_all_layers) if isinstance(layers_to_use, int) else layers_to_use + ) + + # if models do not define embed_dims, repeat embed_dim n_blocks times + embed_dims = getattr(self.backbone, "embed_dims", [self.backbone.embed_dim] * self.backbone.n_blocks) + embed_dims = [embed_dims[i] for i in range(n_all_layers) if i in blocks_to_take] + + if self.use_layernorm: + self.layer_norms = nn.ModuleList([LayerNorm2D(embed_dim) for embed_dim in embed_dims]) + + self.num_channels = [sum(embed_dims)] + self.layers_to_use = layers_to_use + + def forward(self, tensor_list: NestedTensor): + xs = self.backbone.get_intermediate_layers(tensor_list.tensors, n=self.layers_to_use, reshape=True) + if self.use_layernorm: + xs = [ln(x).contiguous() for ln, x in zip(self.layer_norms, xs)] + + xs = [torch.cat(xs, axis=1)] + + out: list[NestedTensor] = [] + for x in xs: + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out.append(NestedTensor(x, mask)) + return out + + +class BackboneWithPositionEncoding(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + out: List[NestedTensor] = list(self[0](tensor_list)) + pos = [self[1][idx](x).to(x.tensors.dtype) for idx, x in enumerate(out)] + return out, pos + + +def build_backbone(backbone_model, args): + position_embedding = build_position_encoding(args) + train_backbone = False + backbone = DINOBackbone( + backbone_model, train_backbone, args.blocks_to_train, args.layers_to_use, args.backbone_use_layernorm + ) + if args.n_windows_sqrt > 0: + logger.info(f"Wrapping with {args.n_windows_sqrt} x {args.n_windows_sqrt} windows") + backbone = WindowsWrapper( + backbone, n_windows_w=args.n_windows_sqrt, n_windows_h=args.n_windows_sqrt, patch_size=backbone.patch_size + ) + else: + logger.info("Not wrapping with windows") + + return BackboneWithPositionEncoding(backbone, position_embedding) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/detr.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/detr.py new file mode 100644 index 
0000000000000000000000000000000000000000..c75a74f89f3b8e8255b0e4ad8ec2c7c5321204e4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/detr.py @@ -0,0 +1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. +# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from ..util import box_ops +from ..util.misc import NestedTensor, _get_clones, inverse_sigmoid, nested_tensor_from_tensor_list +from .backbone import build_backbone +from .transformer import build_transformer + + +class PlainDETR(nn.Module): + """This is the Deformable DETR module that performs object detection""" + + def __init__( + self, + backbone, + transformer, + num_classes, + num_feature_levels, + aux_loss=True, + with_box_refine=False, + two_stage=False, + num_queries_one2one=300, + num_queries_one2many=0, + mixed_selection=False, + ): + """Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + num_queries_one2one: number of object queries for one-to-one matching part + num_queries_one2many: number of object queries for one-to-many matching part + mixed_selection: a trick for Deformable DETR two stage + + """ + super().__init__() + num_queries = num_queries_one2one + num_queries_one2many + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim * 2) + elif mixed_selection: + self.query_embed = nn.Embedding(num_queries, hidden_dim) + self.input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ] + ) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + self.num_queries_one2one = num_queries_one2one + self.mixed_selection = mixed_selection + + def forward(self, samples: NestedTensor): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. 
+ """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for layer, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[layer](src)) + masks.append(mask) + assert mask is not None + + query_embeds = None + if not self.two_stage or self.mixed_selection: + query_embeds = self.query_embed.weight[0 : self.num_queries, :] + + # make attn mask + """ attention mask to prevent information leakage + """ + self_attn_mask = torch.zeros( + [ + self.num_queries, + self.num_queries, + ], + dtype=bool, + device=src.device, + ) + self_attn_mask[ + self.num_queries_one2one :, + 0 : self.num_queries_one2one, + ] = True + self_attn_mask[ + 0 : self.num_queries_one2one, + self.num_queries_one2one :, + ] = True + + ( + hs, + init_reference, + inter_references, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + max_shape, + ) = self.transformer(srcs, masks, pos, query_embeds, self_attn_mask) + + outputs_classes_one2one = [] + outputs_coords_one2one = [] + outputs_classes_one2many = [] + outputs_coords_one2many = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + + outputs_classes_one2one.append(outputs_class[:, 0 : self.num_queries_one2one]) + outputs_classes_one2many.append(outputs_class[:, self.num_queries_one2one :]) + + outputs_coords_one2one.append(outputs_coord[:, 0 : self.num_queries_one2one]) + outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one :]) + + outputs_classes_one2one = torch.stack(outputs_classes_one2one) + outputs_coords_one2one = torch.stack(outputs_coords_one2one) + + outputs_classes_one2many = torch.stack(outputs_classes_one2many) + outputs_coords_one2many = torch.stack(outputs_coords_one2many) + + out = { + "pred_logits": outputs_classes_one2one[-1], + "pred_boxes": outputs_coords_one2one[-1], + "pred_logits_one2many": outputs_classes_one2many[-1], + "pred_boxes_one2many": outputs_coords_one2many[-1], + } + if self.aux_loss: + out["aux_outputs"] = self._set_aux_loss(outputs_classes_one2one, outputs_coords_one2one) + out["aux_outputs_one2many"] = self._set_aux_loss(outputs_classes_one2many, outputs_coords_one2many) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out["enc_outputs"] = { + "pred_logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord, + } + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
+ return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class PlainDETRReParam(PlainDETR): + def forward(self, samples: NestedTensor): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for layer, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[layer](src)) + masks.append(mask) + assert mask is not None + + query_embeds = None + if not self.two_stage or self.mixed_selection: + query_embeds = self.query_embed.weight[0 : self.num_queries, :] + + # make attn mask + """ attention mask to prevent information leakage + """ + self_attn_mask = torch.zeros( + [ + self.num_queries, + self.num_queries, + ], + dtype=bool, + device=src.device, + ) + self_attn_mask[ + self.num_queries_one2one :, + 0 : self.num_queries_one2one, + ] = True + self_attn_mask[ + 0 : self.num_queries_one2one, + self.num_queries_one2one :, + ] = True + + ( + hs, + init_reference, + inter_references, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + max_shape, + ) = self.transformer(srcs, masks, pos, query_embeds, self_attn_mask) + + outputs_classes_one2one = [] + outputs_coords_one2one = [] + outputs_classes_one2many = [] + outputs_coords_one2many = [] + + outputs_coords_old_one2one = [] + outputs_deltas_one2one = [] + outputs_coords_old_one2many = [] + outputs_deltas_one2many = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + outputs_coord = box_ops.box_xyxy_to_cxcywh(box_ops.delta2bbox(reference, tmp, max_shape)) + else: + raise NotImplementedError + + outputs_classes_one2one.append(outputs_class[:, 0 : self.num_queries_one2one]) + outputs_classes_one2many.append(outputs_class[:, self.num_queries_one2one :]) + + outputs_coords_one2one.append(outputs_coord[:, 0 : self.num_queries_one2one]) + outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one :]) + + outputs_coords_old_one2one.append(reference[:, : self.num_queries_one2one]) + outputs_coords_old_one2many.append(reference[:, self.num_queries_one2one :]) + outputs_deltas_one2one.append(tmp[:, : self.num_queries_one2one]) + outputs_deltas_one2many.append(tmp[:, self.num_queries_one2one :]) + + outputs_classes_one2one = torch.stack(outputs_classes_one2one) + outputs_coords_one2one = torch.stack(outputs_coords_one2one) + + 
outputs_classes_one2many = torch.stack(outputs_classes_one2many) + outputs_coords_one2many = torch.stack(outputs_coords_one2many) + + out = { + "pred_logits": outputs_classes_one2one[-1], + "pred_boxes": outputs_coords_one2one[-1], + "pred_logits_one2many": outputs_classes_one2many[-1], + "pred_boxes_one2many": outputs_coords_one2many[-1], + "pred_boxes_old": outputs_coords_old_one2one[-1], + "pred_deltas": outputs_deltas_one2one[-1], + "pred_boxes_old_one2many": outputs_coords_old_one2many[-1], + "pred_deltas_one2many": outputs_deltas_one2many[-1], + } + + if self.aux_loss: + out["aux_outputs"] = self._set_aux_loss( + outputs_classes_one2one, outputs_coords_one2one, outputs_coords_old_one2one, outputs_deltas_one2one + ) + out["aux_outputs_one2many"] = self._set_aux_loss( + outputs_classes_one2many, outputs_coords_one2many, outputs_coords_old_one2many, outputs_deltas_one2many + ) + + if self.two_stage: + out["enc_outputs"] = { + "pred_logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord_unact, + "pred_boxes_old": output_proposals, + "pred_deltas": enc_outputs_delta, + } + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord, outputs_coord_old, outputs_deltas): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + { + "pred_logits": a, + "pred_boxes": b, + "pred_boxes_old": c, + "pred_deltas": d, + } + for a, b, c, d in zip(outputs_class[:-1], outputs_coord[:-1], outputs_coord_old[:-1], outputs_deltas[:-1]) + ] + + +class PostProcess(nn.Module): + """This module converts the model's output into the format expected by the coco api""" + + def __init__(self, topk=100, reparam=False): + super().__init__() + self.topk = topk + self.reparam = reparam + + @torch.no_grad() + def forward(self, outputs, target_sizes, original_target_sizes=None): + """Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + assert not self.reparam or original_target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), self.topk, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + if self.reparam: + img_h, img_w = img_h[:, None, None], img_w[:, None, None] # [BS, 1, 1] + boxes[..., 0::2].clamp_(min=torch.zeros_like(img_w), max=img_w) + boxes[..., 1::2].clamp_(min=torch.zeros_like(img_h), max=img_h) + scale_h, scale_w = (original_target_sizes / target_sizes).unbind(1) + scale_fct = torch.stack([scale_w, scale_h, scale_w, scale_h], dim=1) + else: + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in 
zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build_model(backbone_model, args): + backbone = build_backbone(backbone_model, args) + transformer = build_transformer(args) + model_class = PlainDETR if (not args.reparam) else PlainDETRReParam + return model_class( + backbone, + transformer, + num_classes=args.num_classes, + num_feature_levels=args.num_feature_levels, + aux_loss=args.aux_loss, + with_box_refine=args.with_box_refine, + two_stage=args.two_stage, + num_queries_one2one=args.num_queries_one2one, + num_queries_one2many=args.num_queries_one2many, + mixed_selection=args.mixed_selection, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_ape_decoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_ape_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..32b2f9a4c196a81d1448aa32a9991a311a1237d5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_ape_decoder.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. 
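`PostProcess` above decodes detections by taking top-k over the flattened (query × class) score grid, then recovering the query slot with integer division and the class label with the remainder. A standalone sketch with tiny dimensions:

```python
import torch

num_queries, num_classes, topk = 4, 3, 2
logits = torch.randn(1, num_queries, num_classes)

scores, flat_idx = torch.topk(logits.sigmoid().view(1, -1), topk, dim=1)
query_slot = flat_idx // num_classes  # which query produced the detection
label = flat_idx % num_classes        # which class it is
print(scores.shape, query_slot, label)
```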
+# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from ..util.misc import _get_activation_fn, _get_clones, inverse_sigmoid + + +class GlobalCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, + query, + k_input_flatten, + v_input_flatten, + input_padding_mask=None, + ): + B_, N, C = k_input_flatten.shape + k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + B_, N, C = query.shape + q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn_mask = None + if input_padding_mask is not None: + attn_mask = input_padding_mask[:, None, None] * -100 + attn_mask = attn_mask.contiguous() # to enable efficient attention + + x = torch.nn.functional.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0, + scale=self.scale, + ) + x = x.transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_heads=8, + norm_type="post_norm", + ): + super().__init__() + + self.norm_type = norm_type + + # global cross attention + self.cross_attn = GlobalCrossAttention(d_model, n_heads) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_pre( + self, + tgt, + query_pos, + src, + src_pos_embed, + src_padding_mask=None, + self_attn_mask=None, + ): + # self attention + tgt2 = self.norm2(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), tgt2.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False + )[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + + # global cross attention + tgt2 = self.norm1(tgt) + tgt2 = self.cross_attn( + self.with_pos_embed(tgt2, query_pos), + self.with_pos_embed(src, src_pos_embed), + src, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout4(tgt2) + + return tgt 
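`GlobalDecoderLayer` supports both residual orderings, selected by `norm_type` above. Reduced to a single sublayer, the difference is only where `LayerNorm` sits:

```python
import torch
import torch.nn as nn

d = 8
norm, sublayer = nn.LayerNorm(d), nn.Linear(d, d)
x = torch.randn(2, d)

pre_norm = x + sublayer(norm(x))   # forward_pre: normalize the sublayer input
post_norm = norm(x + sublayer(x))  # forward_post: normalize the residual sum
print(pre_norm.shape, post_norm.shape)
```

This is also why the pre-norm branch of `GlobalDecoder` adds a `final_layer_norm`: with pre-norm, the residual stream itself is never normalized, so one last `LayerNorm` is applied before the prediction heads.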
+ + def forward_post( + self, + tgt, + query_pos, + src, + src_pos_embed, + src_padding_mask=None, + self_attn_mask=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False + )[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + self.with_pos_embed(src, src_pos_embed), + src, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + def forward( + self, + tgt, + query_pos, + src, + src_pos_embed, + src_padding_mask=None, + self_attn_mask=None, + ): + if self.norm_type == "pre_norm": + return self.forward_pre(tgt, query_pos, src, src_pos_embed, src_padding_mask, self_attn_mask) + if self.norm_type == "post_norm": + return self.forward_post(tgt, query_pos, src, src_pos_embed, src_padding_mask, self_attn_mask) + + +class GlobalDecoder(nn.Module): + def __init__( + self, + decoder_layer, + num_layers, + return_intermediate=False, + look_forward_twice=False, + use_checkpoint=False, + d_model=256, + norm_type="post_norm", + ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + self.look_forward_twice = look_forward_twice + self.use_checkpoint = use_checkpoint + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + self.norm_type = norm_type + if self.norm_type == "pre_norm": + self.final_layer_norm = nn.LayerNorm(d_model) + else: + self.final_layer_norm = None + + def _reset_parameters(self): + # stolen from Swin Transformer + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward( + self, + tgt, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + src_padding_mask=None, + self_attn_mask=None, + max_shape=None, + ): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if self.use_checkpoint: + output = checkpoint.checkpoint( + layer, + output, + query_pos, + src, + src_pos_embed, + src_padding_mask, + self_attn_mask, + ) + else: + output = layer( + output, + query_pos, + src, + src_pos_embed, + src_padding_mask, + self_attn_mask, + ) + + if self.final_layer_norm is not None: + output_after_norm = self.final_layer_norm(output) + else: + output_after_norm = output + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output_after_norm) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = 
new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output_after_norm) + intermediate_reference_points.append( + new_reference_points if self.look_forward_twice else reference_points + ) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output_after_norm, reference_points + + +def build_global_ape_decoder(args): + decoder_layer = GlobalDecoderLayer( + d_model=args.hidden_dim, + d_ffn=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + n_heads=args.nheads, + norm_type=args.norm_type, + ) + decoder = GlobalDecoder( + decoder_layer, + num_layers=args.dec_layers, + return_intermediate=True, + look_forward_twice=args.look_forward_twice, + use_checkpoint=args.decoder_use_checkpoint, + d_model=args.hidden_dim, + norm_type=args.norm_type, + ) + return decoder diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_rpe_decomp_decoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_rpe_decomp_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..667c6a10b011a423d58f09bfd79ca3bd669adb2b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/global_rpe_decomp_decoder.py @@ -0,0 +1,443 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. +# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from ..util.box_ops import box_xyxy_to_cxcywh, delta2bbox +from ..util.misc import _get_activation_fn, _get_clones, inverse_sigmoid + + +class GlobalCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + rpe_hidden_dim=512, + rpe_type="linear", + feature_stride=16, + reparam=False, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.rpe_type = rpe_type + self.feature_stride = feature_stride + self.reparam = reparam + + self.cpb_mlp1 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads) + self.cpb_mlp2 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads) + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def build_cpb_mlp(self, in_dim, hidden_dim, out_dim): + cpb_mlp = nn.Sequential( + nn.Linear(in_dim, hidden_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(hidden_dim, out_dim, bias=False) + ) + return cpb_mlp + + def forward( + self, + query, + reference_points, + k_input_flatten, + v_input_flatten, + input_spatial_shapes, + input_padding_mask=None, + ): + assert len(input_spatial_shapes) == 1, "This is designed for single-scale decoder." 
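Both decoder variants feed key padding into `scaled_dot_product_attention` as a large negative additive bias rather than a boolean mask. A minimal shape-checked sketch (tiny dimensions; `-100.0` as in the code above):

```python
import torch
import torch.nn.functional as F

B, H, Q, K, D = 1, 2, 3, 5, 4
q, k, v = (torch.randn(B, H, n, D) for n in (Q, K, K))
padding = torch.tensor([[False, False, False, True, True]])  # last two keys are padding

attn_mask = padding[:, None, None] * -100.0  # bool -> float bias, broadcast to [B, 1, 1, K]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 3, 4]); padded keys get ~zero attention weight
```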
+ h, w = input_spatial_shapes[0] + stride = self.feature_stride + + ref_pts = torch.cat( + [ + reference_points[:, :, :, :2] - reference_points[:, :, :, 2:] / 2, + reference_points[:, :, :, :2] + reference_points[:, :, :, 2:] / 2, + ], + dim=-1, + ) # B, nQ, 1, 4 + if not self.reparam: + ref_pts[..., 0::2] *= w * stride + ref_pts[..., 1::2] *= h * stride + pos_x = ( + torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=ref_pts.device)[None, None, :, None] * stride + ) # 1, 1, w, 1 + pos_y = ( + torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=ref_pts.device)[None, None, :, None] * stride + ) # 1, 1, h, 1 + + if self.rpe_type == "abs_log8": + delta_x = ref_pts[..., 0::2] - pos_x # B, nQ, w, 2 + delta_y = ref_pts[..., 1::2] - pos_y # B, nQ, h, 2 + delta_x = torch.sign(delta_x) * torch.log2(torch.abs(delta_x) + 1.0) / np.log2(8) + delta_y = torch.sign(delta_y) * torch.log2(torch.abs(delta_y) + 1.0) / np.log2(8) + elif self.rpe_type == "linear": + delta_x = ref_pts[..., 0::2] - pos_x # B, nQ, w, 2 + delta_y = ref_pts[..., 1::2] - pos_y # B, nQ, h, 2 + else: + raise NotImplementedError + + rpe_x, rpe_y = self.cpb_mlp1(delta_x), self.cpb_mlp2(delta_y) # B, nQ, w/h, nheads + rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3) # B, nQ, h, w, nheads -> B, nQ, h*w, nheads + rpe = rpe.permute(0, 3, 1, 2) + + B_, N, C = k_input_flatten.shape + k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + B_, N, C = query.shape + q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn_mask = rpe + if input_padding_mask is not None: + attn_mask += input_padding_mask[:, None, None] * -100 + attn_mask = attn_mask.contiguous() # to enable efficient attention + + x = torch.nn.functional.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0, + scale=self.scale, + ) + + x = x.transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_heads=8, + norm_type="post_norm", + rpe_hidden_dim=512, + rpe_type="box_norm", + feature_stride=16, + reparam=False, + ): + super().__init__() + + self.norm_type = norm_type + + # global cross attention + self.cross_attn = GlobalCrossAttention( + d_model, + n_heads, + rpe_hidden_dim=rpe_hidden_dim, + rpe_type=rpe_type, + feature_stride=feature_stride, + reparam=reparam, + ) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_pre( + self, + tgt, + query_pos, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask=None, + self_attn_mask=None, + ): + # self attention + tgt2 = self.norm2(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + 
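The decomposed relative position bias above scores horizontal and vertical box-to-pixel offsets with two separate MLPs (`cpb_mlp1`, `cpb_mlp2`) and broadcast-adds them into a per-head bias over the h×w grid. A shapes-only sketch with tiny dimensions (random offsets stand in for the real `delta_x`/`delta_y` computed from reference boxes):

```python
import torch
import torch.nn as nn

n_heads, n_queries, h, w = 2, 3, 4, 5

def make_cpb_mlp():
    # 2-D offset -> per-head bias, mirroring build_cpb_mlp above (hidden dim shrunk)
    return nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, n_heads, bias=False))

cpb_x, cpb_y = make_cpb_mlp(), make_cpb_mlp()
delta_x = torch.randn(1, n_queries, w, 2)  # offsets of the box x-edges to each column
delta_y = torch.randn(1, n_queries, h, 2)  # offsets of the box y-edges to each row

rpe_x, rpe_y = cpb_x(delta_x), cpb_y(delta_y)                   # [1, nQ, w|h, n_heads]
rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3)  # [1, nQ, h*w, n_heads]
print(rpe.permute(0, 3, 1, 2).shape)  # torch.Size([1, 2, 3, 20]) -> [B, heads, nQ, h*w]
```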
tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), tgt2.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False + )[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + + # global cross attention + tgt2 = self.norm1(tgt) + tgt2 = self.cross_attn( + self.with_pos_embed(tgt2, query_pos), + reference_points, + self.with_pos_embed(src, src_pos_embed), + src, + src_spatial_shapes, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout4(tgt2) + + return tgt + + def forward_post( + self, + tgt, + query_pos, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask=None, + self_attn_mask=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False + )[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + reference_points, + self.with_pos_embed(src, src_pos_embed), + src, + src_spatial_shapes, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + def forward( + self, + tgt, + query_pos, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask=None, + self_attn_mask=None, + ): + if self.norm_type == "pre_norm": + return self.forward_pre( + tgt, + query_pos, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask, + self_attn_mask, + ) + if self.norm_type == "post_norm": + return self.forward_post( + tgt, + query_pos, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask, + self_attn_mask, + ) + + +class GlobalDecoder(nn.Module): + def __init__( + self, + decoder_layer, + num_layers, + return_intermediate=False, + look_forward_twice=False, + use_checkpoint=False, + d_model=256, + norm_type="post_norm", + reparam=False, + ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + self.look_forward_twice = look_forward_twice + self.use_checkpoint = use_checkpoint + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.reparam = reparam + + self.norm_type = norm_type + if self.norm_type == "pre_norm": + self.final_layer_norm = nn.LayerNorm(d_model) + else: + self.final_layer_norm = None + + def _reset_parameters(self): + # stolen from Swin Transformer + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward( + self, + tgt, + reference_points, + src, + src_pos_embed, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + src_padding_mask=None, + self_attn_mask=None, + max_shape=None, + ): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): 
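+            # Scale normalized reference points by each image's valid (un-padded) ratio so
+            # boxes address the right region of the padded feature map; the reparam variant
+            # already works in absolute coordinates and skips this rescaling.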
+ if self.reparam: + reference_points_input = reference_points[:, :, None] + else: + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + if self.use_checkpoint: + output = checkpoint.checkpoint( + layer, + output, + query_pos, + reference_points_input, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask, + self_attn_mask, + ) + else: + output = layer( + output, + query_pos, + reference_points_input, + src, + src_pos_embed, + src_spatial_shapes, + src_padding_mask, + self_attn_mask, + ) + + if self.final_layer_norm is not None: + output_after_norm = self.final_layer_norm(output) + else: + output_after_norm = output + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output_after_norm) + if reference_points.shape[-1] == 4: + if self.reparam: + new_reference_points = box_xyxy_to_cxcywh(delta2bbox(reference_points, tmp, max_shape)) + else: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + if self.reparam: + raise NotImplementedError + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output_after_norm) + intermediate_reference_points.append( + new_reference_points if self.look_forward_twice else reference_points + ) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output_after_norm, reference_points + + +def build_global_rpe_decomp_decoder(args): + decoder_layer = GlobalDecoderLayer( + d_model=args.hidden_dim, + d_ffn=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + n_heads=args.nheads, + norm_type=args.norm_type, + rpe_hidden_dim=args.decoder_rpe_hidden_dim, + rpe_type=args.decoder_rpe_type, + feature_stride=args.proposal_in_stride, + reparam=args.reparam, + ) + decoder = GlobalDecoder( + decoder_layer, + num_layers=args.dec_layers, + return_intermediate=True, + look_forward_twice=args.look_forward_twice, + use_checkpoint=args.decoder_use_checkpoint, + d_model=args.hidden_dim, + norm_type=args.norm_type, + reparam=args.reparam, + ) + return decoder diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/position_encoding.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..53b2c03085cbcaa3db10ee680d2fee2b129ec426 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/position_encoding.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. 
+# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +from enum import Enum + +import torch +from torch import nn + +from ..util.misc import NestedTensor + + +class PositionEncoding(Enum): + LEARNED = "learned" + SINE = "sine" + SINE_UNNORM = "sine_unnorm" + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + else: + y_embed = (y_embed - 0.5) * self.scale + x_embed = (x_embed - 0.5) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
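+    Row and column indices are embedded separately (two nn.Embedding tables of 50
+    positions each) and concatenated per location, so feature maps larger than
+    50x50 cells would require enlarging the tables.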
+ """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding == PositionEncoding.SINE: # also called v2 + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding == PositionEncoding.LEARNED: # also called v3 + position_embedding = PositionEmbeddingLearned(N_steps) + elif args.position_embedding == PositionEncoding.SINE_UNNORM: # also called v4 + position_embedding = PositionEmbeddingSine(N_steps, normalize=False) + else: + raise ValueError(f"not supported {args.position_embedding}") + position_embedding = nn.ModuleList([position_embedding for _ in range(args.num_feature_levels)]) + + return position_embedding diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7684e86533d7f2b7694039cae29752481b4f8541 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import constant_, normal_, xavier_uniform_ + +from ..util.box_ops import box_xyxy_to_cxcywh, delta2bbox +from .global_ape_decoder import build_global_ape_decoder +from .global_rpe_decomp_decoder import build_global_rpe_decomp_decoder +from .transformer_encoder import TransformerEncoder, TransformerEncoderLayer +from .utils import LayerNorm2D + + +class Transformer(nn.Module): + def __init__( + self, + d_model=256, + nhead=8, + num_feature_levels=4, + two_stage=False, + two_stage_num_proposals=300, + mixed_selection=False, + norm_type="post_norm", + decoder_type="deform", + proposal_feature_levels=1, + proposal_in_stride=16, + proposal_tgt_strides=[8, 16, 32, 64], + proposal_min_size=50, + args=None, + # transformer_encoder + add_transformer_encoder=False, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + num_encoder_layers=6, + ): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + assert norm_type in ["pre_norm", "post_norm"], f"expected norm type is pre_norm or post_norm, get {norm_type}" + + if decoder_type == "global_ape": + self.decoder = build_global_ape_decoder(args) + elif decoder_type == "global_rpe_decomp": + self.decoder = build_global_rpe_decomp_decoder(args) + else: + raise NotImplementedError + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + self.reference_points = nn.Linear(d_model, 2) + + self.mixed_selection = mixed_selection + self.proposal_feature_levels = proposal_feature_levels + self.proposal_tgt_strides = proposal_tgt_strides + self.proposal_min_size = proposal_min_size + if two_stage and proposal_feature_levels > 1: + assert len(proposal_tgt_strides) == proposal_feature_levels + + self.proposal_in_stride = proposal_in_stride + self.enc_output_proj = nn.ModuleList([]) + for stride in proposal_tgt_strides: + if stride == proposal_in_stride: + self.enc_output_proj.append(nn.Identity()) + elif stride > proposal_in_stride: + scale = int(math.log2(stride / proposal_in_stride)) + layers = [] + for _ in range(scale - 1): + layers += [ + nn.Conv2d(d_model, d_model, kernel_size=2, stride=2), + LayerNorm2D(d_model), + nn.GELU(), + ] + layers.append(nn.Conv2d(d_model, d_model, kernel_size=2, stride=2)) + self.enc_output_proj.append(nn.Sequential(*layers)) + else: + scale = int(math.log2(proposal_in_stride / stride)) + layers = [] + for _ in range(scale - 1): + layers += [ + nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2), + LayerNorm2D(d_model), + nn.GELU(), + ] + layers.append(nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2)) + self.enc_output_proj.append(nn.Sequential(*layers)) + + # ENCODER TRANSFORMER + self.encoder = None + if add_transformer_encoder: + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + normalize_before, + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + 
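+        # The optional DETR-style encoder refines the flattened backbone tokens before
+        # decoding; when it is disabled, the decoder attends to the projected features
+        # directly (see forward()).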
self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if not self.two_stage: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.0) + normal_(self.level_embed) + + if hasattr(self.decoder, "_reset_parameters"): + self.decoder._reset_parameters() + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = self.d_model // 2 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + if self.proposal_feature_levels > 1: + memory, memory_padding_mask, spatial_shapes = self.expand_encoder_output( + memory, memory_padding_mask, spatial_shapes + ) + N_, S_, C_ = memory.shape + # base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + + max_shape = None + return output_memory, output_proposals, max_shape + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def expand_encoder_output(self, memory, memory_padding_mask, spatial_shapes): + assert len(spatial_shapes) == 1, f"Get encoder output of shape {spatial_shapes}, not sure how to expand" + + bs, _, c = memory.shape + h, w = spatial_shapes[0] + + _out_memory = memory.view(bs, h, w, c).permute(0, 3, 1, 2) + _out_memory_padding_mask = memory_padding_mask.view(bs, h, w) + + out_memory, out_memory_padding_mask, out_spatial_shapes = [], [], [] + for i in 
range(self.proposal_feature_levels): + mem = self.enc_output_proj[i](_out_memory) + mask = F.interpolate(_out_memory_padding_mask[None].float(), size=mem.shape[-2:]).to(torch.bool) + + out_memory.append(mem) + out_memory_padding_mask.append(mask.squeeze(0)) + out_spatial_shapes.append(mem.shape[-2:]) + + out_memory = torch.cat([mem.flatten(2).transpose(1, 2) for mem in out_memory], dim=1) + out_memory_padding_mask = torch.cat([mask.flatten(1) for mask in out_memory_padding_mask], dim=1) + return out_memory, out_memory_padding_mask, out_spatial_shapes + + def get_reference_points(self, memory, mask_flatten, spatial_shapes): + output_memory, output_proposals, max_shape = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_delta = None + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + return ( + reference_points, + max_shape, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + ) + + def forward(self, srcs, masks, pos_embeds, query_embed=None, self_attn_mask=None): + # TODO: we may remove this loop as we only have one feature level + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + level_start_index = None # not used so far + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + if self.encoder is not None: + memory = self.encoder(src_flatten, src_key_padding_mask=mask_flatten, pos=lvl_pos_embed_flatten) + else: + memory = src_flatten + + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + ( + reference_points, + max_shape, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + ) = self.get_reference_points(memory, mask_flatten, spatial_shapes) + init_reference_out = reference_points + pos_trans_out = torch.zeros((bs, self.two_stage_num_proposals, 2 * c), device=init_reference_out.device) + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(reference_points))) + + if not self.mixed_selection: + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + # query_embed here is the content embed for deformable DETR + tgt = query_embed.unsqueeze(0).expand(bs, -1, -1) + query_embed, _ = torch.split(pos_trans_out, c, dim=2) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = 
query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + init_reference_out = reference_points + max_shape = None + + # decoder + hs, inter_references = self.decoder( + tgt, + reference_points, + memory, + lvl_pos_embed_flatten, + spatial_shapes, + level_start_index, + valid_ratios, + query_embed, + mask_flatten, + self_attn_mask, + max_shape, + ) + + inter_references_out = inter_references + if self.two_stage: + return ( + hs, + init_reference_out, + inter_references_out, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + max_shape, + ) + return hs, init_reference_out, inter_references_out, None, None, None, None, None + + +class TransformerReParam(Transformer): + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + if self.proposal_feature_levels > 1: + memory, memory_padding_mask, spatial_shapes = self.expand_encoder_output( + memory, memory_padding_mask, spatial_shapes + ) + N_, S_, C_ = memory.shape + # base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + stride = self.proposal_tgt_strides[lvl] + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) * stride + wh = torch.ones_like(grid) * self.proposal_min_size * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + output_proposals = torch.cat(proposals, 1) + + H_, W_ = spatial_shapes[0] + stride = self.proposal_tgt_strides[0] + mask_flatten_ = memory_padding_mask[:, : H_ * W_].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1, keepdim=True) * stride + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1, keepdim=True) * stride + img_size = torch.cat([valid_W, valid_H, valid_W, valid_H], dim=-1) + img_size = img_size.unsqueeze(1) # [BS, 1, 4] + + output_proposals_valid = ((output_proposals > 0.01 * img_size) & (output_proposals < 0.99 * img_size)).all( + -1, keepdim=True + ) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1).repeat(1, 1, 1), max(H_, W_) * stride + ) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, max(H_, W_) * stride) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + + max_shape = (valid_H[:, None, :], valid_W[:, None, :]) + return output_memory, output_proposals, max_shape + + def get_reference_points(self, memory, mask_flatten, spatial_shapes): + output_memory, output_proposals, max_shape = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_delta = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = box_xyxy_to_cxcywh(delta2bbox(output_proposals, enc_outputs_delta, max_shape)) + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 
0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact + return ( + reference_points, + max_shape, + enc_outputs_class, + enc_outputs_coord_unact, + enc_outputs_delta, + output_proposals, + ) + + +def build_transformer(args): + model_class = Transformer if (not args.reparam) else TransformerReParam + return model_class( + d_model=args.hidden_dim, + nhead=args.nheads, + num_feature_levels=args.num_feature_levels, + two_stage=args.two_stage, + two_stage_num_proposals=args.num_queries_one2one + args.num_queries_one2many, + mixed_selection=args.mixed_selection, + norm_type=args.norm_type, + decoder_type=args.decoder_type, + proposal_feature_levels=args.proposal_feature_levels, + proposal_in_stride=args.proposal_in_stride, + proposal_tgt_strides=args.proposal_tgt_strides, + args=args, + proposal_min_size=args.proposal_min_size, + # transformer_encoder + add_transformer_encoder=args.add_transformer_encoder, + num_encoder_layers=args.num_encoder_layers, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer_encoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef9c2537bb7c7245e0e6c49846955370164bccb --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/transformer_encoder.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# https://github.com/facebookresearch/detr/blob/main/models/transformer.py +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. 
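+Only the encoder half is kept in this file; the matching decoders live in the
+global_ape_decoder / global_rpe_decomp_decoder modules.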
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +from typing import Optional + +from torch import Tensor, nn + +from ..util.misc import _get_activation_fn, _get_clones + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): + super().__init__() + # Keeping Dropout 0 in self attention as it makes the eval 10% faster without performance change + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, need_weights=False + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, need_weights=False + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f8e2f3986ca1fcee43a31ee45473d6bf65040b --- /dev/null 
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch.nn as nn + + +class LayerNorm2D(nn.Module): + def __init__(self, normalized_shape, norm_layer=nn.LayerNorm): + super().__init__() + self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity() + + def forward(self, x): + """ + x: N C H W + """ + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2) + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/windows.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/windows.py new file mode 100644 index 0000000000000000000000000000000000000000..5399d235b136baa98f0c75a57289dd406cb0d83b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/models/windows.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.transforms import v2 + +from ..util.misc import NestedTensor + + +class WindowsWrapper(torch.nn.Module): + """ + This wrapper will take an input (NestedTensor) at size (h, w) and split it + in `N = n_windows_h * n_windows_w` equally sized windows (the bottom and right windows might + be a little bit smaller), with sizes that are multiples of the patch size (as the input should be). + + Then, the input will be resized at the size of the top left window (h / n_windows_h, w / n_windows_w). + This resized input, plus the N windows, will be passed through the backbone. + Then, the features of the resized input will be resized to the original input size, while the + features of the windows will be concatenated side by side to reconstruct a feature map also + corresponding to the original image's size. + + Finally, both the features from the windows and from the resized images are stacked. 
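+    The resized pass contributes global context at low resolution, while the window
+    passes preserve native-resolution detail.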
+ Compared to the output of the backbone of size [B, C, H, W], the output here is [B, 2 * C, H, W] + """ + + def __init__(self, backbone, n_windows_w, n_windows_h, patch_size): + # Assuming image size is divisible by patch_size + super().__init__() + self._backbone = backbone + self._n_windows_w = n_windows_w + self._n_windows_h = n_windows_h + self._patch_size = patch_size + self.strides = backbone.strides + self.num_channels = [el * 2 for el in backbone.num_channels] # resized + windows + + def forward(self, tensor_list: NestedTensor): + tensors = tensor_list.tensors + original_h, original_w = tensors.shape[2], tensors.shape[3] + # Get height and width of the windows, such that it is a multiple of the patch size + window_h = math.ceil((original_h // self._n_windows_h) / self._patch_size) * self._patch_size + window_w = math.ceil((original_w // self._n_windows_w) / self._patch_size) * self._patch_size + all_h = [window_h] * (self._n_windows_h - 1) + [original_h - window_h * (self._n_windows_h - 1)] + all_w = [window_w] * (self._n_windows_w - 1) + [original_w - window_w * (self._n_windows_w - 1)] + all_h_cumsum = [0] + list(np.cumsum(all_h)) + all_w_cumsum = [0] + list(np.cumsum(all_w)) + window_patch_features = [[0 for _ in range(self._n_windows_w)] for _ in range(self._n_windows_h)] + + for ih in range(self._n_windows_h): + for iw in range(self._n_windows_w): + window_tensor = v2.functional.crop( + tensors, top=all_h_cumsum[ih], left=all_w_cumsum[iw], height=all_h[ih], width=all_w[iw] + ) + window_mask = v2.functional.crop( + tensor_list.mask, top=all_h_cumsum[ih], left=all_w_cumsum[iw], height=all_h[ih], width=all_w[iw] + ) + window_patch_features[ih][iw] = self._backbone(NestedTensor(tensors=window_tensor, mask=window_mask))[0] + + window_tensors = torch.cat( + [ + torch.cat([el.tensors for el in window_patch_features[ih]], dim=-1) # type: ignore + for ih in range(len(window_patch_features)) + ], + dim=-2, + ) + # Also compute the global features in a "preferential" setting, of lower resolution + resized_global_tensor = v2.functional.resize(tensors, size=(window_h, window_w)) + global_features = self._backbone( + NestedTensor(tensors=resized_global_tensor, mask=tensor_list.mask) + ) # mask is not used + + concat_tensors = torch.cat( + [v2.functional.resize(global_features[0].tensors, size=window_tensors.shape[-2:]), window_tensors], dim=1 + ) + global_mask = F.interpolate(tensor_list.mask[None].float(), size=concat_tensors.shape[-2:]).to(torch.bool)[0] + out = [NestedTensor(tensors=concat_tensors, mask=global_mask)] + return out diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5375bc66e1ed841a7091b81a0dcf56d1993c1f87 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/box_ops.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..de47dbc46a6c37bae27fbe279cd0877af6fd7ec2 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/box_ops.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. +# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Utilities for bounding box manipulation and GIoU. +""" +import numpy as np +import torch + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def delta2bbox( + proposals, deltas, max_shape=None, wh_ratio_clip=16 / 1000, clip_border=True, add_ctr_clamp=False, ctr_clamp=32 +): + dxy = deltas[..., :2] + dwh = deltas[..., 2:] + + # Compute width/height of each roi + pxy = proposals[..., :2] + pwh = proposals[..., 2:] + + dxy_wh = pwh * dxy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) + dwh = torch.clamp(dwh, max=max_ratio) + else: + dwh = dwh.clamp(min=-max_ratio, max=max_ratio) + + gxy = pxy + dxy_wh + gwh = pwh * dwh.exp() + x1y1 = gxy - (gwh * 0.5) + x2y2 = gxy + (gwh * 0.5) + bboxes = torch.cat([x1y1, x2y2], dim=-1) + if clip_border and max_shape is not None: + bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1]) + bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0]) + return bboxes + + +def bbox2delta(proposals, gt, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): + # hack for matcher + if proposals.size() != gt.size(): + proposals = proposals[:, None] + gt = gt[None] + + proposals = proposals.float() + gt = gt.float() + px, py, pw, ph = proposals.unbind(-1) + gx, gy, gw, gh = gt.unbind(-1) + + dx = (gx - px) / (pw + 0.1) + dy = (gy - py) / (ph + 0.1) + dw = torch.log(gw / (pw + 0.1)) + dh = torch.log(gh / (ph + 0.1)) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + # avoid unnecessary sync point if not needed + if means != (0.0, 0.0, 0.0, 0.0) or stds != (1.0, 1.0, 1.0, 1.0): + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means).div_(stds) + + return deltas diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/misc.py 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd2e7e758821de7363e856d053d8305ca780813 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/detection/util/misc.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------ +# Plain-DETR +# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia. +# Licensed under The MIT License [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import copy +from typing import List, Optional + +import dinov3.distributed as distributed +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
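+
+    Example (illustrative; the loss names are arbitrary):
+        >>> loss_dict = {"loss_ce": loss_ce, "loss_bbox": loss_bbox}
+        >>> reduced = reduce_dict(loss_dict)  # every rank now sees the same averaged values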
+ """ + world_size = distributed.get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking=False): + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs): + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + def __len__(self): + return len(self.tensors) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
+ """ + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def get_total_grad_norm(parameters, norm_type=2): + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + device = parameters[0].grad.device + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), + norm_type, + ) + return total_norm + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + +def get_param_dict(model, args, return_name=False, use_layerwise_decay=False): + # sanity check: a variable could not match backbone_names and linear_proj_names at the same time + for n, p in model.named_parameters(): + if match_name_keywords(n, args.lr_backbone_names) and match_name_keywords(n, args.lr_linear_proj_names): + raise ValueError + + param_dicts = [ + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and not match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr, + "weight_decay": args.weight_decay, + }, + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and not match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr_backbone, + "weight_decay": args.weight_decay, + }, + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and match_name_keywords(n, args.lr_linear_proj_names) + and not match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr * args.lr_linear_proj_mult, + "weight_decay": args.weight_decay, + }, + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr, + "weight_decay": args.weight_decay * args.wd_norm_mult, + }, + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr_backbone, + "weight_decay": args.weight_decay * args.wd_norm_mult, + }, + { + "params": [ + p if not return_name else n + for n, p in model.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and match_name_keywords(n, args.lr_linear_proj_names) + and match_name_keywords(n, args.wd_norm_names) + and p.requires_grad + ], + "lr": args.lr * args.lr_linear_proj_mult, + "weight_decay": args.weight_decay * args.wd_norm_mult, + }, + ] + return param_dicts + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu 
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/helpers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0868dde9bf585327b5b087d070e14b394a0a7f22
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/helpers.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+import os
+from typing import Any
+
+import torch
+from omegaconf import OmegaConf
+
+import dinov3.distributed
+from dinov3.eval import results
+
+logger = logging.getLogger("dinov3")
+
+CONFIG_FILE_KEY = "config_file"
+EVAL_CONFIG_FNAME = "eval_config.yaml"
+
+
+def write_results(results_dict, output_dir, results_filename) -> None:
+    """Save results, only on the main process when CUDA is available."""
+    if torch.cuda.is_available() and not dinov3.distributed.is_main_process():
+        return
+    results_path = os.path.join(output_dir, results_filename)
+    logger.info(f"Saving results to {results_path}")
+    results.save_from_dict(results_dict=results_dict, results_path=results_path)
+
+
+def args_dict_to_dataclass(eval_args: dict[str, object], config_dataclass, save_config: bool = True) -> tuple[Any, str]:
+    """
+    eval_args: arguments passed to create the eval config.
+        `CONFIG_FILE_KEY` is a reserved name to load a set of parameters from a config file.
+    config_dataclass: a dataclass used to define the config arguments, types and default values
+    save_config: whether to save the config in a file named `EVAL_CONFIG_FNAME` in the output_dir
+    """
+    if CONFIG_FILE_KEY in eval_args:
+        config_file = eval_args.pop(CONFIG_FILE_KEY)
+        eval_args_config = OmegaConf.merge(OmegaConf.load(config_file), OmegaConf.create(eval_args))
+    else:
+        eval_args_config = OmegaConf.create(eval_args)
+
+    structured_config = OmegaConf.merge(OmegaConf.structured(config_dataclass), eval_args_config)
+    logger.info(f"Evaluation Configuration:\n{OmegaConf.to_yaml(structured_config)}")
+    output_dir = structured_config.output_dir
+
+    if save_config and dinov3.distributed.is_main_process():
+        OmegaConf.save(config=structured_config, f=os.path.join(output_dir, EVAL_CONFIG_FNAME))
+
+    return OmegaConf.to_object(structured_config), output_dir
+
+
+def cli_parser(argv: list[str]) -> dict[str, Any]:
+    """
+    Parse argv into a dict of eval arguments (model building arguments included).
+    - `argv` can come from the command line directly, or from a subset of the command line arguments,
+      as in dinov3.run.submitit
+    - `output_dir` can either be passed as `output_dir=` or `--output-dir=` (to support dinov3.run.submitit)
+    """
+    cli_eval_args_dict = OmegaConf.to_container(OmegaConf.from_cli(argv))
+    if "output_dir" not in cli_eval_args_dict:
+        cli_eval_args_dict["output_dir"] = cli_eval_args_dict.pop("--output-dir", ".")
+    return cli_eval_args_dict
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/knn.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b994133cc5800c519bd22d5997cf0f23129da45
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/knn.py
@@ -0,0 +1,384 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import json +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn +from omegaconf import MISSING +from torch.nn.functional import one_hot, softmax + +import dinov3.distributed as distributed +from dinov3.data import SamplerType, make_data_loader, make_dataset +from dinov3.data.adapters import DatasetWithEnumeratedTargets +from dinov3.data.transforms import ( + CROP_DEFAULT_SIZE, + RESIZE_DEFAULT_SIZE, + get_target_transform, + make_classification_eval_transform, +) +from dinov3.distributed import gather_all_tensors +from dinov3.eval.data import ( + create_train_dataset_dict, + extract_features_for_dataset_dict, + get_num_classes, + pad_multilabel_and_collate, +) +from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results +from dinov3.eval.metrics import ClassificationMetricType, build_classification_metric +from dinov3.eval.setup import ModelConfig, load_model_and_context +from dinov3.eval.utils import ModelWithNormalize, average_metrics, evaluate +from dinov3.eval.utils import save_results as default_save_results_func +from dinov3.run.init import job_context + +logger = logging.getLogger("dinov3") + + +RESULTS_FILENAME = "results-knn.csv" +MAIN_METRICS = [".* Top 1"] + + +@dataclass +class TrainConfig: + dataset: str = MISSING # train dataset path + batch_size: int = 256 # batch size for train set feature extraction + num_workers: int = 5 # number of workers for train set feature extraction + ks: Tuple[int, ...] = (10, 20, 100, 200) # values of k to evaluate + temperature: float = 0.07 + """ + Whether to skip the first nearest neighbor for each image in the test set. + Useful when training and testing on the same dataset split. + """ + skip_first_nn: bool = False + + +@dataclass +class EvalConfig: + test_dataset: str = MISSING # test dataset path + test_metric_type: ClassificationMetricType = ClassificationMetricType.MEAN_ACCURACY + batch_size: int | None = None # batch size for evaluation, None to use train batch size + num_workers: int = 5 # number of workers for evaluation + + +@dataclass +class TransformConfig: + resize_size: int = RESIZE_DEFAULT_SIZE + crop_size: int = CROP_DEFAULT_SIZE + + +@dataclass +class FewShotConfig: + enable: bool = False # whether to use few-shot evaluation + k_or_percent: Optional[float] = None # number of elements or % to take per class + n_tries: int = 1 # number of tries for few-shot evaluation + + +@dataclass +class KnnEvalConfig: + model: ModelConfig + train: TrainConfig = field(default_factory=TrainConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + transform: TransformConfig = field(default_factory=TransformConfig) + few_shot: FewShotConfig = field(default_factory=FewShotConfig) + save_results: bool = False # save predictions and targets in the output directory + output_dir: str = "" + + +class KnnModule(torch.nn.Module): + """ + Gets knn of test features from all processes on a chunk of the train features + + Each rank gets a chunk of the train features as well as a chunk of the test features. + In `compute_neighbors`, for each rank one after the other, its chunk of test features + is sent to all devices, partial knns are computed with each chunk of train features + then collated back on the original device. 
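+
+    E.g. with a world size of 4, each rank holds one quarter of the train features
+    (stored transposed for the similarity matmul) and, for every broadcast test chunk,
+    contributes its partial top-k, which is gathered and re-ranked on the owning rank.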
+ """ + + def __init__(self, *, train_features, train_labels, device, ks, T, num_classes=1000, skip_first_nn=False): + super().__init__() + + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.device = device + self.train_features_rank_T = train_features.chunk(self.world_size)[self.rank].T.to(self.device) + # Labels can either be integers, or in a one-hot format + self.candidates = train_labels.chunk(self.world_size)[self.rank].unsqueeze(0).to(self.device) + + self.ks = ks + self.max_k = max(self.ks) + skip_first_nn + self.T = T + self.num_classes = num_classes + self.skip_first_nn = skip_first_nn + + if self.skip_first_nn: + logger.info("Skipping the first nearest neighbor of each element in the test dataset") + + def _get_knn_sims_and_labels(self, similarity, train_labels): + topk_sims, indices = similarity.topk(min(self.max_k, similarity.shape[1]), largest=True, sorted=True) + if len(train_labels.shape) == 3: # If the labels are in one_hot format + indices = indices.unsqueeze(2).expand(-1, -1, self.num_classes) # Orignally [bs, max_k] + neighbors_labels = torch.gather(train_labels, 1, indices) + return topk_sims, neighbors_labels + + def _similarity_for_rank(self, features_rank, source_rank): + """ + Broadcasts `features_rank` from `source_rank` and compute similarities + with the train features chunks from all ranks + """ + # Send the features from `source_rank` to all ranks + broadcast_shape = torch.tensor(features_rank.shape).to(self.device) + torch.distributed.broadcast(broadcast_shape, source_rank) + + broadcasted = features_rank + if self.rank != source_rank: + broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) + torch.distributed.broadcast(broadcasted, source_rank) + + # Compute the neighbors for `source_rank` among `train_features_rank_T` + similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) + candidate_labels = self.candidates.expand(len(similarity_rank), *self.candidates.shape[1:]) + return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) + + def compute_neighbors(self, features_rank): + """ + If we are on rank `rank`, we broadcast the test features to other ranks, compute similarities + with their chunks of the train features, then gather these partial similarities back on `rank` + """ + topk_sims_rank, neighbors_labels_rank = None, None + for rank in range(self.world_size): + partial_topk_sims, partial_neighbors_labels = self._similarity_for_rank(features_rank, rank) + gathered_topk_sims = torch.cat(gather_all_tensors(partial_topk_sims), dim=1) + gathered_neighbor_labels = torch.cat(gather_all_tensors(partial_neighbors_labels), dim=1) + if self.rank == rank: # Performing a second top-k to get k neighbors from the gathered k * world_size + topk_sims_rank, neighbors_labels_rank = self._get_knn_sims_and_labels( + gathered_topk_sims, gathered_neighbor_labels + ) + return topk_sims_rank, neighbors_labels_rank + + def forward(self, features_rank): + """ + Compute the results on all values of `self.ks` neighbors from the full `self.max_k` + """ + assert all(k <= self.max_k for k in self.ks) + + topk_sims, neighbors_labels = self.compute_neighbors(features_rank) + if self.skip_first_nn: + topk_sims, neighbors_labels = topk_sims[:, 1:], neighbors_labels[:, 1:] + batch_size = neighbors_labels.shape[0] + topk_sims_transform = softmax(topk_sims / self.T, 1) + voting_coefficient = topk_sims_transform.view(batch_size, -1, 1) + if len(neighbors_labels.shape) == 2: # If the 
labels are not yet one hot + neighbors_labels = one_hot(neighbors_labels, num_classes=self.num_classes) + matmul = torch.mul(neighbors_labels, voting_coefficient) + probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.ks} + return probas_for_k + + +class DictKeysModule(torch.nn.Module): + def __init__(self, keys): + super().__init__() + self.keys = keys + + def forward(self, features_dict, targets): + for k in self.keys: + features_dict = features_dict[k] + return {"preds": features_dict, "target": targets} + + +def make_transform(config: TransformConfig): + if config.resize_size / config.crop_size != 256 / 224: + logger.warning( + f"Default resize / crop ratio is 256 / 224, here we have {config.resize_size} / {config.crop_size}" + ) + transform = make_classification_eval_transform(resize_size=config.resize_size, crop_size=config.crop_size) + return transform + + +def make_test_data_loader(config: EvalConfig, transform): + # Create test data loader. Do not extract features in advance due to difficulties with multilabel datasets. + multilabel_collate_fn = config.test_metric_type == ClassificationMetricType.ANY_MATCH_ACCURACY + test_dataset = make_dataset( + dataset_str=config.test_dataset, + transform=transform, + target_transform=get_target_transform(config.test_dataset), + ) + assert isinstance(config.batch_size, int) # eval batch size has been replaced by train batch size if None + + return make_data_loader( + dataset=DatasetWithEnumeratedTargets(test_dataset, pad_dataset=True, num_replicas=distributed.get_world_size()), + batch_size=config.batch_size, + num_workers=config.num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + collate_fn=pad_multilabel_and_collate if multilabel_collate_fn else None, + ) + + +def eval_knn( + *, + model, + train_data_dict, + test_data_loader, + metric_collection, + knn_config: TrainConfig, + num_classes: int, + save_results_func=None, +): + logger.info("Start the k-NN classification.") + eval_metrics_dict: Dict[int, Dict[int, Dict[str, float]]] = {} # {k: {try: {metric_name: metric_value}}} + save_results = save_results_func is not None + device = torch.cuda.current_device() + partial_knn_module = partial( + KnnModule, + device=device, + num_classes=num_classes, + T=knn_config.temperature, + skip_first_nn=knn_config.skip_first_nn, + ) + + for try_ in train_data_dict.keys(): + train_features, train_labels = train_data_dict[try_]["train_features"], train_data_dict[try_]["train_labels"] + ks = sorted(set([el if el < len(train_features) else len(train_features) for el in knn_config.ks])) + knn_module = partial_knn_module(train_features=train_features, train_labels=train_labels, ks=ks) + postprocessors, metrics = {k: DictKeysModule([k]) for k in ks}, {k: metric_collection.clone() for k in ks} + _, eval_metrics, accumulated_results = evaluate( + torch.nn.Sequential(model, knn_module), + test_data_loader, + postprocessors, + metrics, + device, + accumulate_results=save_results, + ) + for k in ks: + if save_results: + if len(train_data_dict) > 1: + split_results_saver = partial(save_results_func, filename_suffix=f"try_{try_}_k_{k}") + else: + split_results_saver = partial(save_results_func, filename_suffix=f"k_{k}") + split_results_saver(**accumulated_results[k]) + + if k not in eval_metrics_dict: + eval_metrics_dict[k] = {} + eval_metrics_dict[k][try_] = {metric: v.item() * 100.0 for metric, v in eval_metrics[k].items()} + + if len(train_data_dict) > 1: + return {k: 
average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()} + + return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()} + + +def _log_and_format_results_dict(input_results_dict, few_shot_n_tries: int) -> Dict[str, float]: + results_dict = {} + for knn_ in input_results_dict.keys(): + if few_shot_n_tries == 1: + top1 = input_results_dict[knn_]["top-1"] + results_dict[f"{knn_} Top 1"] = top1 + results_string = f"{knn_} NN classifier result: Top1: {top1:.2f}" + if "top-5" in input_results_dict[knn_]: + top5 = input_results_dict[knn_]["top-5"] + results_dict[f"{knn_} Top 5"] = top5 + results_string += f" Top5: {top5:.2f}" + else: + top1_mean, top1_std = input_results_dict[knn_]["top-1_mean"], input_results_dict[knn_]["top-1_std"] + results_dict[f"{knn_} Top 1"] = top1_mean + results_string = f"{knn_} NN classifier result: Top1 Avg: {top1_mean:.2f}, Top1 Std {top1_std:.2f}" + if "top-5_mean" in input_results_dict[knn_]: + top5_mean, top5_std = input_results_dict[knn_]["top-5_mean"], input_results_dict[knn_]["top-5_std"] + results_dict[f"{knn_} Top 5"] = top5_mean + results_string += f" Top5 Avg: {top5_mean:.2f}, Top5 Std {top5_std:.2f}" + logger.info(results_string) + return results_dict + + +def eval_knn_with_model(*, model: torch.nn.Module, autocast_dtype, config: KnnEvalConfig): + start = time.time() + cudnn.benchmark = True + + # Setting up datasets + transform = make_transform(config.transform) + train_dataset = make_dataset( + dataset_str=config.train.dataset, + transform=transform, + target_transform=get_target_transform(config.train.dataset), + ) + train_dataset_dict = create_train_dataset_dict( + train_dataset, + few_shot_eval=config.few_shot.enable, + few_shot_k_or_percent=config.few_shot.k_or_percent, + few_shot_n_tries=config.few_shot.n_tries, + ) + + # Setting up metrics + num_classes = get_num_classes(train_dataset) + metric_collection = build_classification_metric(config.eval.test_metric_type, num_classes=num_classes) + config.eval.batch_size = config.eval.batch_size or config.train.batch_size + test_data_loader = make_test_data_loader(config.eval, transform) + + # Setting up save results function + save_results_func = None + if config.save_results: + save_results_func = partial(default_save_results_func, output_dir=config.output_dir) + + model = ModelWithNormalize(model) + with torch.autocast("cuda", dtype=autocast_dtype): + logger.info("Extracting features for train set...") + train_data_dict = extract_features_for_dataset_dict( + model, train_dataset_dict, config.train.batch_size, config.train.num_workers, gather_on_cpu=True + ) + results_dict_knn = eval_knn( + model=model, + train_data_dict=train_data_dict, + test_data_loader=test_data_loader, + metric_collection=metric_collection, + knn_config=config.train, + num_classes=num_classes, + save_results_func=save_results_func, + ) + results_dict = _log_and_format_results_dict(results_dict_knn, config.few_shot.n_tries) + + # TODO: Remove as cleaner writers are used + metrics_file_path = os.path.join(config.output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + logger.info(f"Knn evaluation done in {int(time.time() - start)}s") + return results_dict + + +def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]: + """Initialization of distributed and logging are preconditions for this method""" + dataclass_config, output_dir = 
args_dict_to_dataclass(eval_args=eval_args, config_dataclass=KnnEvalConfig) + model, model_context = load_model_and_context(dataclass_config.model, output_dir=output_dir) + results_dict = eval_knn_with_model( + model=model, config=dataclass_config, autocast_dtype=model_context["autocast_dtype"] + ) + write_results(results_dict, output_dir, RESULTS_FILENAME) + return results_dict + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + eval_args = cli_parser(argv) + with job_context(output_dir=eval_args["output_dir"]): + benchmark_launcher(eval_args=eval_args) + return 0 + + +if __name__ == "__main__": + main() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/linear.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e091d24398581bd40de913008a19767decb2c8aa --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/linear.py @@ -0,0 +1,688 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import json +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +from omegaconf import MISSING +from torch.nn.parallel import DistributedDataParallel + +import dinov3.distributed as distributed +from dinov3.checkpointer import ( + CheckpointRetentionPolicy, + cleanup_checkpoint, + find_latest_checkpoint, + keep_last_n_checkpoints, +) +from dinov3.data import SamplerType, make_data_loader, make_dataset +from dinov3.data.adapters import DatasetWithEnumeratedTargets +from dinov3.data.transforms import ( + CROP_DEFAULT_SIZE, + RESIZE_DEFAULT_SIZE, + make_classification_eval_transform, + make_classification_train_transform, +) +from dinov3.eval.data import create_train_dataset_dict, get_num_classes, pad_multilabel_and_collate +from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results +from dinov3.eval.metrics import ClassificationMetricType, build_classification_metric +from dinov3.eval.setup import ModelConfig, load_model_and_context +from dinov3.eval.utils import LossType, ModelWithIntermediateLayers, average_metrics, evaluate +from dinov3.eval.utils import save_results as default_save_results_func +from dinov3.logging import MetricLogger, SmoothedValue +from dinov3.run.init import job_context + +logger = logging.getLogger("dinov3") + +RESULTS_FILENAME = "results-linear.csv" +# Can be several keys, depending on if multiple test sets are chosen and if doing few-shot +MAIN_METRICS = [".*_accuracy(_mean)?"] + + +class OptimizerType(Enum): + SGD = "sgd" + ADAMW = "adamw" + + def get_optimizer(self, optim_param_groups): + if self == OptimizerType.ADAMW: + optimizer = torch.optim.AdamW(optim_param_groups, weight_decay=0) + else: + optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) + return optimizer + + +class SchedulerType(Enum): + COSINE_ANNEALING = "cosine_annealing" + ONE_CYCLE = "one_cycle" + + def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter): + if self == SchedulerType.ONE_CYCLE: + lr_list = [optim_param_groups[i]["lr"] for i in range(len(optim_param_groups))] + scheduler = 
torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=lr_list, steps_per_epoch=epoch_length, epochs=epochs + ) + else: + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + return scheduler + + +_DEFAULT_LR_LIST: Tuple[float, ...] = (1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1) + + +@dataclass +class TrainConfig: + dataset: str = MISSING # train dataset path + val_dataset: str = MISSING # val dataset path + val_metric_type: ClassificationMetricType = ClassificationMetricType.MEAN_ACCURACY + batch_size: int = 128 # batch size (per GPU) + num_workers: int = 8 + # Linear Head Parameters + learning_rates: Tuple[float, ...] = _DEFAULT_LR_LIST # learning rates to grid search + n_last_blocks_list: Tuple[int] = (1,) # number of backbone last blocks used for the linear classifier + loss_type: LossType = LossType.CROSS_ENTROPY + optimizer_type: OptimizerType = OptimizerType.SGD + scheduler_type: SchedulerType = SchedulerType.COSINE_ANNEALING + epochs: int = 10 # number of training epochs + epoch_length: int = 1250 # length of an epoch in number of iterations + save_checkpoint_iterations: int | None = ( + None # number of iterations between two checkpoint saves (default: one epoch) + ) + eval_period_iterations: int | None = None # number of iterations between two evaluations (default: one epoch) + checkpoint_retention_policy: CheckpointRetentionPolicy = CheckpointRetentionPolicy.NONE # keep checkpoints or not + resume: bool = True # whether to resume from existing checkpoints + classifier_fpath: Optional[str] = None # path to a file containing pretrained linear classifiers + + +@dataclass +class EvalConfig: + test_datasets: Tuple[str, ...] = () # additional test dataset paths + test_metric_types: Tuple[ClassificationMetricType, ...] 
= () + batch_size: int = 256 # batch size (per GPU) + num_workers: int = 8 + + +@dataclass +class TransformConfig: + resize_size: int = RESIZE_DEFAULT_SIZE + crop_size: int = CROP_DEFAULT_SIZE + + +@dataclass +class FewShotConfig: + enable: bool = False # whether to use few-shot evaluation + k_or_percent: Optional[float] = None # number of elements or % to take per class + n_tries: int = 1 # number of tries for few-shot evaluation + + +@dataclass +class LinearEvalConfig: + model: ModelConfig + train: TrainConfig = field(default_factory=TrainConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + transform: TransformConfig = field(default_factory=TransformConfig) + few_shot: FewShotConfig = field(default_factory=FewShotConfig) + save_results: bool = False # save predictions and targets in the output directory + output_dir: str = "" + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_world_size()) / 256.0 + + +def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): + linear_classifiers_dict = nn.ModuleDict() + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [True]: + for _lr in learning_rates: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] + linear_classifier = LinearClassifier( + out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes + ) + 
linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +def make_eval_transform(config: TransformConfig): + if config.resize_size / config.crop_size != 256 / 224: + logger.warning( + f"Default resize / crop ratio is 256 / 224, here we have {config.resize_size} / {config.crop_size}" + ) + transform = make_classification_eval_transform(resize_size=config.resize_size, crop_size=config.crop_size) + return transform + + +def make_eval_data_loader( + *, + test_dataset_str, + transform_config, + batch_size, + num_workers, + metric_type, +): + transform = make_eval_transform(transform_config) + test_dataset = make_dataset(dataset_str=test_dataset_str, transform=transform) + + class_mapping = None + if hasattr(test_dataset, "get_imagenet_class_mapping"): + class_mapping = test_dataset.get_imagenet_class_mapping() + + test_data_loader = make_data_loader( + dataset=DatasetWithEnumeratedTargets(test_dataset, pad_dataset=True, num_replicas=distributed.get_world_size()), + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=pad_multilabel_and_collate if metric_type == ClassificationMetricType.ANY_MATCH_ACCURACY else None, + ) + return test_data_loader, class_mapping + + +@dataclass +class Evaluator: + batch_size: int + num_workers: int + transform_config: TransformConfig + dataset_str: str + metric_type: ClassificationMetricType + metrics_file_path: str + training_num_classes: int + save_results_func: Optional[Callable] + + def __post_init__(self): + self.data_loader, self.class_mapping = make_eval_data_loader( + test_dataset_str=self.dataset_str, + batch_size=self.batch_size, + num_workers=self.num_workers, + transform_config=self.transform_config, + metric_type=self.metric_type, + ) + self.main_metric_name = f"{self.dataset_str}_accuracy" + + @torch.no_grad() + def _evaluate_linear_classifiers( + self, + *, + feature_model, + linear_classifiers, + iteration, + prefixstring="", + best_classifier_on_val=None, + accumulate_results=False, + ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]: + logger.info("running validation !") + + num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes + metric = build_classification_metric(self.metric_type, num_classes=num_classes) + postprocessors = { + k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items() + } + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp, accumulated_results = evaluate( + feature_model, + self.data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + accumulate_results=accumulate_results, + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for _, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == 
best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + accumulated_best_results = None + if accumulated_results is not None: + accumulated_best_results = accumulated_results[best_classifier] + + if distributed.is_main_process(): + with open(self.metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict, accumulated_best_results + + def evaluate_and_maybe_save( + self, + feature_model, + linear_classifiers, + iteration: int, + best_classifier_on_val: Optional[Any] = None, + save_filename_suffix: str = "", + prefixstring: str = "", + ): + logger.info(f"Testing on {self.dataset_str}") + save_results = self.save_results_func is not None + full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + iteration=iteration, + prefixstring=prefixstring, + best_classifier_on_val=best_classifier_on_val, + accumulate_results=save_results, + ) + if self.save_results_func is not None: + self.save_results_func( + filename_suffix=f"{self.dataset_str}{save_filename_suffix}", **accumulated_best_results + ) + + results_dict = { + self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"], + "best_classifier": full_results_dict["best_classifier"]["name"], + } + return results_dict + + +def make_evaluators( + eval_config: EvalConfig, + val_metric_type: ClassificationMetricType, + val_dataset: str, + transform_config: TransformConfig, + metrics_file_path: str, + training_num_classes: int, + save_results_func: Optional[Callable], +): + test_metric_types = eval_config.test_metric_types + if len(test_metric_types) == 0: + test_metric_types = (val_metric_type,) * len(eval_config.test_datasets) + else: + assert len(test_metric_types) == len(eval_config.test_datasets) + val_evaluator, *test_evaluators = [ + Evaluator( + dataset_str=dataset_str, + batch_size=eval_config.batch_size, + num_workers=eval_config.num_workers, + transform_config=transform_config, + metric_type=metric_type, + metrics_file_path=metrics_file_path, + training_num_classes=training_num_classes, + save_results_func=save_results_func, + ) + for dataset_str, metric_type in zip( + (val_dataset,) + tuple(eval_config.test_datasets), + (val_metric_type,) + tuple(test_metric_types), + ) + ] + return val_evaluator, test_evaluators + + +def setup_linear_training( + *, + config: TrainConfig, + sample_output: torch.Tensor, + training_num_classes: int, + checkpoint_output_dir: str, +): + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + config.n_last_blocks_list, + config.learning_rates, + config.batch_size, + training_num_classes, + ) + max_iter = config.epochs * config.epoch_length + optimizer = config.optimizer_type.get_optimizer(optim_param_groups=optim_param_groups) + scheduler = config.scheduler_type.get_scheduler( + optimizer=optimizer, + optim_param_groups=optim_param_groups, + epoch_length=config.epoch_length, + epochs=config.epochs, + max_iter=max_iter, + ) + start_iter = 0 + best_accuracy = -1 + if config.resume and ( + last_checkpoint_dir := find_latest_checkpoint(config.classifier_fpath or checkpoint_output_dir) + ): + logger.info(f"Checkpoint found 
{last_checkpoint_dir}") + checkpoint = torch.load(last_checkpoint_dir / "checkpoint.pth") + start_iter = checkpoint.get("iteration", -1) + 1 + best_accuracy = checkpoint.get("best_accuracy", -1) + linear_classifiers.load_state_dict(checkpoint["linear_classifiers"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + if config.loss_type == LossType.BINARY_CROSS_ENTROPY: + criterion = nn.BCEWithLogitsLoss() + else: + criterion = nn.CrossEntropyLoss() + + return ( + linear_classifiers, + start_iter, + max_iter, + criterion, + optimizer, + scheduler, + best_accuracy, + ) + + +def train_linear_classifiers( + *, + feature_model, + train_dataset, + train_config: TrainConfig, + training_num_classes: int, + val_evaluator: Evaluator, + checkpoint_output_dir: str, +): + (linear_classifiers, start_iter, max_iter, criterion, optimizer, scheduler, best_accuracy,) = setup_linear_training( + config=train_config, + sample_output=feature_model(train_dataset[0][0].unsqueeze(0).cuda()), + training_num_classes=training_num_classes, + checkpoint_output_dir=checkpoint_output_dir, + ) + checkpoint_period = train_config.save_checkpoint_iterations or train_config.epoch_length + eval_period = train_config.eval_period_iterations or train_config.epoch_length + + sampler_type = SamplerType.INFINITE + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=train_config.batch_size, + num_workers=train_config.num_workers, + shuffle=True, + seed=0, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6g}")) + header = "Training" + for data, labels in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + features = feature_model(data) + outputs = linear_classifiers(features) + + if len(labels.shape) > 1: + labels = labels.float() + losses = {f"loss_{k}": criterion(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + # compute the gradients + optimizer.zero_grad() + loss.backward() + + # step + optimizer.step() + scheduler.step() + + # log + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # Checkpointing + is_last_iteration = (iteration + 1) == max_iter + is_ckpt_iteration = ((iteration + 1) % checkpoint_period == 0) or is_last_iteration + if is_ckpt_iteration: + ckpt_dir = Path(checkpoint_output_dir).expanduser() + if distributed.is_subgroup_main_process(): + ckpt_sub_dir = "final" if is_last_iteration else str(iteration) + (ckpt_dir / ckpt_sub_dir).mkdir(parents=True, exist_ok=True) + checkpoint = { + "iteration": iteration, + "linear_classifiers": linear_classifiers.state_dict(), + "optimizer": optimizer.state_dict(), + "best_accuracy": best_accuracy, + } + torch.save(checkpoint, ckpt_dir / ckpt_sub_dir / "checkpoint.pth") + keep_last_n_checkpoints(ckpt_dir, train_config.checkpoint_retention_policy.max_to_keep) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + val_results_dict = val_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + prefixstring=f"ITER: {iteration}", + 
iteration=iteration, + ) + val_accuracy = val_results_dict[val_evaluator.main_metric_name] + if val_accuracy >= best_accuracy: + best_accuracy = val_accuracy + (ckpt_dir / "best").mkdir(parents=True, exist_ok=True) + checkpoint = { + "iteration": iteration, + "linear_classifiers": linear_classifiers.state_dict(), + "optimizer": optimizer.state_dict(), + "best_accuracy": best_accuracy, + } + torch.save(checkpoint, ckpt_dir / "best" / "checkpoint.pth") + torch.distributed.barrier() + + iteration = iteration + 1 + + return feature_model, linear_classifiers, iteration + + +def make_train_transform(config: TransformConfig): + train_transform = make_classification_train_transform(crop_size=config.crop_size) + return train_transform + + +def make_train_dataset(train_dataset: str, transform_config: TransformConfig): + train_transform = make_train_transform(transform_config) + return make_dataset(dataset_str=train_dataset, transform=train_transform) + + +def eval_linear_with_model(*, model: torch.nn.Module, autocast_dtype, config: LinearEvalConfig): + start = time.time() + cudnn.benchmark = True + + train_dataset = make_train_dataset(config.train.dataset, config.transform) + training_num_classes = get_num_classes(train_dataset) + train_dataset_dict = create_train_dataset_dict( + train_dataset, + few_shot_eval=config.few_shot.enable, + few_shot_k_or_percent=config.few_shot.k_or_percent, + few_shot_n_tries=config.few_shot.n_tries, + ) + n_last_blocks = max(config.train.n_last_blocks_list) + autocast_ctx = partial(torch.autocast, device_type="cuda", enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + + save_results_func = None + if config.save_results: + save_results_func = partial(default_save_results_func, output_dir=config.output_dir) + + metrics_file_path = os.path.join(config.output_dir, "results_eval_linear.json") + val_evaluator, test_evaluators = make_evaluators( + eval_config=config.eval, + val_metric_type=config.train.val_metric_type, + val_dataset=config.train.val_dataset, + transform_config=config.transform, + metrics_file_path=metrics_file_path, + training_num_classes=training_num_classes, + save_results_func=save_results_func, + ) + results_dict = {} + checkpoint_output_dirs: list = [] + for _try in train_dataset_dict.keys(): + if len(train_dataset_dict) > 1: + checkpoint_output_dir = os.path.join(config.output_dir, f"checkpoints_{_try}") + save_filename_suffix = f"_{_try}" + else: + checkpoint_output_dir = os.path.join(config.output_dir, "checkpoints") + save_filename_suffix = "" + os.makedirs(checkpoint_output_dir, exist_ok=True) + + feature_model, linear_classifiers, iteration = train_linear_classifiers( + feature_model=feature_model, + train_dataset=train_dataset_dict[_try], + train_config=config.train, + training_num_classes=training_num_classes, + val_evaluator=val_evaluator, + checkpoint_output_dir=checkpoint_output_dir, + ) + checkpoint_output_dirs.append(checkpoint_output_dir) + results_dict[_try] = val_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + iteration=iteration, + save_filename_suffix=save_filename_suffix, + ) + for test_evaluator in test_evaluators: + eval_results_dict = test_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + iteration=iteration, + best_classifier_on_val=results_dict[_try]["best_classifier"], + save_filename_suffix=save_filename_suffix, + ) + results_dict[_try] = 
{**eval_results_dict, **results_dict[_try]} + + if len(train_dataset_dict) > 1: + results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"]) + else: + results_dict = {**results_dict[_try]} + + for checkpoint_output_dir in checkpoint_output_dirs: + if distributed.is_subgroup_main_process(): + cleanup_checkpoint(checkpoint_output_dir, config.train.checkpoint_retention_policy) + + logger.info("Test Results Dict " + str(results_dict)) + logger.info(f"Linear evaluation done in {int(time.time() - start)}s") + return results_dict + + +def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]: + """Initialization of distributed and logging are preconditions for this method""" + dataclass_config, output_dir = args_dict_to_dataclass(eval_args=eval_args, config_dataclass=LinearEvalConfig) + model, model_context = load_model_and_context(dataclass_config.model, output_dir=output_dir) + results_dict = eval_linear_with_model( + model=model, config=dataclass_config, autocast_dtype=model_context["autocast_dtype"] + ) + write_results(results_dict, output_dir, RESULTS_FILENAME) + return results_dict + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + eval_args = cli_parser(argv) + with job_context(output_dir=eval_args["output_dir"]): + benchmark_launcher(eval_args=eval_args) + return 0 + + +if __name__ == "__main__": + main() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/log_regression.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a3e7c8100c539ff71a13f713eedd7bb739810a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/log_regression.py @@ -0,0 +1,422 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import sys +import time +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Optional + +import torch +import torch.backends.cudnn as cudnn +import torch.distributed +from omegaconf import MISSING +from torch import nn +from torch.utils.data import TensorDataset +from torchmetrics import MetricTracker + +from dinov3.data import SamplerType, make_data_loader, make_dataset +from dinov3.data.adapters import DatasetWithEnumeratedTargets +from dinov3.data.transforms import CROP_DEFAULT_SIZE, get_target_transform, make_classification_eval_transform +from dinov3.distributed import get_rank, get_world_size +from dinov3.eval.data import ( + create_train_dataset_dict, + extract_features_for_dataset_dict, + get_num_classes, + split_train_val_datasets, +) +from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results +from dinov3.eval.metrics import ClassificationMetricType, build_classification_metric +from dinov3.eval.setup import ModelConfig, load_model_and_context +from dinov3.eval.utils import average_metrics, evaluate, extract_features +from dinov3.eval.utils import save_results as default_save_results_func +from dinov3.run.init import job_context +from dinov3.utils.dtype import as_torch_dtype + +logger = logging.getLogger("dinov3") + + +RESULTS_FILENAME = "results-log-regression.csv" +MAIN_METRICS = ["top-1(_mean)?"] + + +try: + from sklearnex import patch_sklearn + + patch_sklearn() +except ImportError: + logger.warning("Can't import sklearnex. 
If installed, that speeds up scikit-learn 10-100x") + +try: + from sklearn.linear_model import LogisticRegression as sklearnLogisticRegression + from sklearn.multiclass import OneVsRestClassifier +except ImportError: + logger.warning("Can't import scikit-learn. This is necessary for evaluating log regression") + raise ImportError + + +C_POWER_RANGE = torch.linspace(-6, 5, 45) +_CPU_DEVICE = torch.device("cpu") + + +@dataclass +class TrainConfig: + dataset: str = MISSING # train dataset path + val_dataset: Optional[str] = None # val dataset path. If None, choose hyperparameters on 10% of the train set. + val_metric_type: ClassificationMetricType = ClassificationMetricType.MEAN_ACCURACY + batch_size: int = 256 # batch size for train and val set feature extraction + num_workers: int = 5 # number of workers for train and val set feature extraction + tol: float = 1e-12 # tolerance in logistic regression + train_features_device: str = "cpu" # device to gather train features (cpu, cuda, cuda:0, etc.) + train_dtype: str = "float64" # data type to convert the train features to + max_train_iters: int = 1_000 # maximum number of train iterations in logistic regression + + +@dataclass +class EvalConfig: + test_dataset: str = MISSING # test dataset path + batch_size: int | None = None # use train.batch_size if None + num_workers: int = 5 + test_metric_type: Optional[ClassificationMetricType] = None + + +@dataclass +class TransformConfig: + resize_size: int = CROP_DEFAULT_SIZE + crop_size: int = CROP_DEFAULT_SIZE + + +@dataclass +class FewShotConfig: + enable: bool = False # whether to use few-shot evaluation + k_or_percent: Optional[float] = None # number of elements or % to take per class + n_tries: int = 1 # number of tries for few-shot evaluation + + +@dataclass +class LogregEvalConfig: + model: ModelConfig + train: TrainConfig = field(default_factory=TrainConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + transform: TransformConfig = field(default_factory=TransformConfig) + few_shot: FewShotConfig = field(default_factory=FewShotConfig) + save_results: bool = False # save predictions and targets in the output directory + output_dir: str = "" + + +class LogRegModule(nn.Module): + def __init__(self, C, multi_label=False, logreg_config=TrainConfig): + super().__init__() + self.dtype = as_torch_dtype(logreg_config.train_dtype) + self.device = torch.device(logreg_config.train_features_device) + assert self.device == _CPU_DEVICE, f"SKLearn can only work on CPU device, got {self.device}" + self.estimator = sklearnLogisticRegression( + penalty="l2", + solver="lbfgs", + C=C, + max_iter=logreg_config.max_train_iters, + n_jobs=-1, + tol=logreg_config.tol, + ) + if multi_label: + self.estimator = OneVsRestClassifier(self.estimator, n_jobs=-1) + + def forward(self, samples, targets): + samples_device = samples.device + samples = samples.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + samples = samples.numpy() + probas = self.estimator.predict_proba(samples) + return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} + + def fit(self, train_features, train_labels): + train_features = train_features.to(dtype=self.dtype, device=self.device) + train_labels = train_labels.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + # both cuml and sklearn only work with numpy arrays on CPU + train_features = train_features.numpy() + train_labels = train_labels.numpy() + self.estimator.fit(train_features, train_labels) + + +def 
evaluate_logreg_model(*, logreg_model, test_metric, test_data_loader, save_results_func=None): + key = "metrics" # We need only one key as we have only one metric + postprocessors, metrics = {key: logreg_model}, {key: test_metric} + _, eval_metrics, accumulated_results = evaluate( + nn.Identity(), + test_data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + accumulate_results=save_results_func is not None, + ) + if save_results_func is not None: + save_results_func(**accumulated_results[key]) + return eval_metrics + + +def train_for_C(*, C, train_features, train_labels, logreg_config: TrainConfig): + logreg_model = LogRegModule(C, multi_label=len(train_labels.shape) > 1, logreg_config=logreg_config) + logreg_model.fit(train_features, train_labels) + return logreg_model + + +def sweep_C_values( + *, + train_features, + train_labels, + val_data_loader, + val_metric, + logreg_config: TrainConfig, +): + metric_tracker = MetricTracker(val_metric, maximize=True) + ALL_C = 10**C_POWER_RANGE + logreg_models: Dict[float, Any] = {} + + train_features_device = torch.device(logreg_config.train_features_device) + train_dtype = as_torch_dtype(logreg_config.train_dtype) + train_features = train_features.to(dtype=train_dtype, device=train_features_device) + train_labels = train_labels.to(device=train_features_device) + + for i in range(get_rank(), len(ALL_C), get_world_size()): + C = ALL_C[i].item() + logger.info( + f"Training for C = {C:.4g}, dtype={train_dtype}, " + f"features: {train_features.shape}, {train_features.dtype}, " + f"labels: {train_labels.shape}, {train_labels.dtype}" + ) + logreg_models[C] = train_for_C( + C=C, + train_features=train_features, + train_labels=train_labels, + logreg_config=logreg_config, + ) + + gather_list: List[Dict[float, Any]] = [{} for _ in range(get_world_size())] + torch.distributed.all_gather_object(gather_list, logreg_models) + + for logreg_dict in gather_list: + logreg_models.update(logreg_dict) + gather_list.clear() + + for i in range(len(ALL_C)): + metric_tracker.increment() + C = ALL_C[i].item() + evals = evaluate_logreg_model( + logreg_model=logreg_models.pop(C), + test_metric=metric_tracker, + test_data_loader=val_data_loader, + ) + logger.info(f"Trained for C = {C:.4g}, accuracies = {evals}") + best_stats, which_epoch = metric_tracker.best_metric(return_step=True) + best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} + if which_epoch["top-1"] == i: + best_C = C + logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.4g}") + + return best_stats, best_C + + +def make_logreg_data_loader(batch_size: int, num_workers: int, features: torch.Tensor, labels: torch.Tensor): + return make_data_loader( + dataset=DatasetWithEnumeratedTargets( + TensorDataset(features, labels), pad_dataset=True, num_replicas=get_world_size() + ), + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + + +def get_best_logreg_with_features( + *, + train_features: torch.Tensor, + train_labels: torch.Tensor, + val_features: torch.Tensor, + val_labels: torch.Tensor, + val_metric, + concatenate_train_val: bool, + train_config: TrainConfig, +): + val_data_loader = make_logreg_data_loader( + train_config.batch_size, train_config.num_workers, val_features, val_labels + ) + _, best_C_t = sweep_C_values( + train_features=train_features, + train_labels=train_labels, + val_data_loader=val_data_loader, + val_metric=val_metric, + logreg_config=train_config, + ) + if concatenate_train_val: 
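+        # The best C was chosen on the held-out split above; the final model is
+        # refit below on train + val at that C. (For few-shot evaluation the
+        # caller passes concatenate_train_val=False, so this step is skipped.)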
+ logger.info("Best parameter found, concatenating features") + train_features = torch.cat((train_features, val_features)) + train_labels = torch.cat((train_labels, val_labels)) + + logger.info("Training final model") + + logreg_model = train_for_C( + C=best_C_t, + logreg_config=train_config, + train_features=train_features, + train_labels=train_labels, + ) + return logreg_model + + +def make_transform(config: TransformConfig): + if config.resize_size / config.crop_size != 1: + logger.warning(f"Default resize / crop ratio is 1, here we have {config.resize_size} / {config.crop_size}") + transform = make_classification_eval_transform(resize_size=config.resize_size, crop_size=config.crop_size) + return transform + + +def make_train_val_datasets(train_config: TrainConfig, few_shot_config: FewShotConfig, transform): + train_dataset = make_dataset( + dataset_str=train_config.dataset, + transform=transform, + target_transform=get_target_transform(train_config.dataset), + ) + if train_config.val_dataset is not None: + val_dataset = make_dataset( + dataset_str=train_config.val_dataset, + transform=transform, + target_transform=get_target_transform(train_config.val_dataset), + ) + else: + split_percentage = 0.01 if few_shot_config.enable else 0.1 + train_dataset, val_dataset = split_train_val_datasets(train_dataset, split_percentage=split_percentage) + + train_dataset_dict = create_train_dataset_dict( + train_dataset, + few_shot_eval=few_shot_config.enable, + few_shot_k_or_percent=few_shot_config.k_or_percent, + few_shot_n_tries=few_shot_config.n_tries, + ) + num_classes = get_num_classes(train_dataset) + return train_dataset_dict, val_dataset, num_classes + + +def make_test_dataset_and_data_loader(model, config: EvalConfig, transform, gather_on_cpu: bool): + test_dataset = make_dataset( + dataset_str=config.test_dataset, + transform=transform, + target_transform=get_target_transform(config.test_dataset), + ) + test_features, test_labels = extract_features( + model, test_dataset, config.batch_size, config.num_workers, gather_on_cpu=gather_on_cpu + ) + assert isinstance(config.batch_size, int) # eval batch size has been replaced by train batch size if None + test_data_loader = make_logreg_data_loader(config.batch_size, config.num_workers, test_features, test_labels) + return test_dataset, test_data_loader + + +def eval_log_regression_with_model(*, model: torch.nn.Module, autocast_dtype, config: LogregEvalConfig): + """ + Implements the "standard" process for log regression evaluation: + The value of C is chosen by training on train_dataset and evaluating on + val_dataset. Then, the final model is trained on a concatenation of + train_dataset and val_dataset, and is evaluated on test_dataset. 
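+    The sweep covers C = 10 ** linspace(-6, 5, 45), sharded across ranks so that
+    each rank fits only a subset of the candidate models (see `sweep_C_values`).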
+ If there is no val_dataset, the value of C is the one that yields + the best results on a random 10% subset of the train dataset + """ + start = time.time() + cudnn.benchmark = True + + transform = make_transform(config.transform) + config.eval.batch_size = config.eval.batch_size or config.train.batch_size # use train batch size for eval if None + + # Setting up train and val datasets + train_dataset_dict, val_dataset, num_classes = make_train_val_datasets(config.train, config.few_shot, transform) + + # Extracting features + with torch.autocast("cuda", dtype=autocast_dtype): + gather_on_cpu = torch.device(config.train.train_features_device) == _CPU_DEVICE + train_data_dict = extract_features_for_dataset_dict( + model, train_dataset_dict, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu + ) + logger.info("Choosing hyperparameters on the val dataset") + val_features, val_labels = extract_features( + model, val_dataset, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu + ) + test_dataset, test_data_loader = make_test_dataset_and_data_loader(model, config.eval, transform, gather_on_cpu) + + # Moves the model to cpu in-place. Deleting the variable would only delete a reference and not free any space. + model.cpu() # all features are extracted, we won't use the backbone anymore + torch.cuda.empty_cache() + + # Setting up metrics + val_metric = build_classification_metric(config.train.val_metric_type, num_classes=num_classes, dataset=val_dataset) + test_metric_type = config.eval.test_metric_type or config.train.val_metric_type + test_metric = build_classification_metric(test_metric_type, num_classes=num_classes, dataset=test_dataset) + + # Setting up save results function + save_results_func = None + if config.save_results: + save_results_func = partial(default_save_results_func, output_dir=config.output_dir) + + results_dict = {} + for _try in train_data_dict.keys(): + logreg_model = get_best_logreg_with_features( + train_features=train_data_dict[_try]["train_features"], + train_labels=train_data_dict[_try]["train_labels"], + val_features=val_features, + val_labels=val_labels, + val_metric=val_metric, + concatenate_train_val=not config.few_shot.enable, + train_config=config.train, + ) + if len(train_data_dict) > 1 and save_results_func is not None: # add suffix + split_results_saver = partial(save_results_func, filename_suffix=str(_try)) + else: + split_results_saver = save_results_func # type: ignore + + eval_metrics = evaluate_logreg_model( + logreg_model=logreg_model, + test_metric=test_metric.clone(), + test_data_loader=test_data_loader, + save_results_func=split_results_saver, + ) + results_dict[_try] = {k: v.item() * 100.0 for k, v in eval_metrics["metrics"].items()} + + if len(train_data_dict) > 1: + results_dict = average_metrics(results_dict) + else: + results_dict = {**results_dict[_try]} + + logger.info(f"Log regression evaluation done in {int(time.time() - start)}s") + logger.info("Training of the supervised logistic regression on frozen features completed.") + results_string = "\n".join([f"{k}: {results_dict[k]:.4g}" for k in sorted(results_dict.keys())]) + logger.info("Results:\n" + results_string) + + torch.distributed.barrier() + return results_dict + + +def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]: + """Initialization of distributed and logging are preconditions for this method""" + dataclass_config, output_dir = args_dict_to_dataclass(eval_args=eval_args, config_dataclass=LogregEvalConfig) + 
model, model_context = load_model_and_context(dataclass_config.model, output_dir=output_dir) + results_dict = eval_log_regression_with_model( + model=model, config=dataclass_config, autocast_dtype=model_context["autocast_dtype"] + ) + write_results(results_dict, output_dir, RESULTS_FILENAME) + return results_dict + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + eval_args = cli_parser(argv) + with job_context(output_dir=eval_args["output_dir"]): + benchmark_launcher(eval_args=eval_args) + return 0 + + +if __name__ == "__main__": + main() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d76da62937bd31345524fd683bd964757651d55 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .classification import ( + AveragingMethod, + ClassificationMetricType, + MacroAveragedMeanReciprocalRank, + MeanAveragePrecisionVOC2007, + accuracy, + build_classification_metric, + build_topk_accuracy_metric, +) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/classification.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6ee25c24827005f6890c486c0eb16b43a5cd65 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/classification.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
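+
+# Usage sketch (hypothetical shapes and values): most metric types below resolve
+# to a torchmetrics MetricCollection keyed by "top-k", e.g.
+#   metric = build_classification_metric(ClassificationMetricType.MEAN_ACCURACY, num_classes=1000)
+#   metric.update(preds, target)   # preds [B, 1000] logits, target [B] labels
+#   metric.compute()               # {"top-1": ..., "top-5": ...}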
+ +import logging +from enum import Enum +from typing import Any, Dict, Optional + +import numpy as np +import torch +from torch import Tensor +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassAUROC, + MulticlassF1Score, + MulticlassRecall, + MultilabelAveragePrecision, + MultilabelF1Score, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.utilities.data import dim_zero_cat, select_topk + +from .imagenet_c import ImageNet_C_Metric + +logger = logging.getLogger("dinov3") + + +class ClassificationMetricType(Enum): + AUROC = "auroc" + MEAN_ACCURACY = "mean_accuracy" + MEAN_RECALL = "mean_recall" + MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" + MEAN_PER_CLASS_RECALL = "mean_per_class_recall" + PER_CLASS_ACCURACY = "per_class_accuracy" + MEAN_AVERAGE_PRECISION_VOC_2007 = "map_voc2007" + ANY_MATCH_ACCURACY = "any_match_accuracy" + GROUPBY_ANY_MATCH_ACCURACY_1 = "groupby_any_match_accuracy_1" + GROUPBY_ANY_MATCH_ACCURACY_5 = "groupby_any_match_accuracy_5" + MEAN_MULTICLASS_F1 = "mean_multiclass_f1" + MEAN_PER_CLASS_MULTICLASS_F1 = "mean_per_class_multiclass_f1" + MEAN_MULTILABEL_F1 = "mean_multilabel_f1" + MEAN_PER_CLASS_MULTILABEL_F1 = "mean_per_class_multilabel_f1" + IMAGENET_C_METRIC = "imagenet_c_metric" + MACRO_AVERAGED_MEAN_RECIPROCAL_RANK = "macro_averaged_mean_reciprocal_rank" + MACRO_MULTILABEL_AVERAGE_PRECISION = "macro_multilabel_average_precision" + + @property + def averaging_method(self): + return getattr(AveragingMethod, self.name, None) + + @property + def is_topk_accuracy_metric(self): + return self.value in ("mean_accuracy", "mean_per_class_accuracy", "per_class_accuracy") + + @property + def is_topk_recall_metric(self): + return self.value in ("mean_recall", "mean_per_class_recall") + + @property + def is_multilabel(self): + return self.value in ( + "map_voc2007", + "any_match_accuracy", + "groupby_any_match_accuracy_1", + "groupby_any_match_accuracy_5", + "mean_multilabel_f1", + "mean_per_class_multilabel_f1", + ) + + def __str__(self): + return self.value + + +class AveragingMethod(Enum): + MEAN_ACCURACY = "micro" + MEAN_RECALL = "micro" + MEAN_PER_CLASS_ACCURACY = "macro" + MEAN_PER_CLASS_RECALL = "macro" + PER_CLASS_ACCURACY = "none" + MEAN_MULTICLASS_F1 = "micro" + MEAN_PER_CLASS_MULTICLASS_F1 = "macro" + MEAN_MULTILABEL_F1 = "micro" + MEAN_PER_CLASS_MULTILABEL_F1 = "macro" + + def __str__(self): + return self.value + + +def _make_default_ks(num_classes: int): + return (1, 5) if num_classes >= 5 else (1,) + + +def build_classification_metric( + metric_type: ClassificationMetricType, *, num_classes: int, ks: Optional[tuple] = None, dataset=None +): + if metric_type.is_topk_accuracy_metric: + ks = ks or _make_default_ks(num_classes) + return build_topk_accuracy_metric(average_type=metric_type.averaging_method, num_classes=num_classes, ks=ks) + elif metric_type.is_topk_recall_metric: + ks = ks or _make_default_ks(num_classes) + return build_topk_recall_metric(average_type=metric_type.averaging_method, num_classes=num_classes, ks=ks) + elif metric_type == ClassificationMetricType.MEAN_AVERAGE_PRECISION_VOC_2007: + assert ks is None + map_voc2007 = MeanAveragePrecisionVOC2007(num_labels=int(num_classes)) + return MetricCollection({"top-1": map_voc2007}) + elif metric_type == ClassificationMetricType.ANY_MATCH_ACCURACY: + ks = ks or _make_default_ks(num_classes) + return build_topk_any_match_accuracy_metric(num_classes=num_classes, ks=ks) + elif metric_type == 
ClassificationMetricType.GROUPBY_ANY_MATCH_ACCURACY_1: + return GroupByAnyMatchAccuracy(top_k=1, num_classes=int(num_classes), dataset=dataset) + elif metric_type == ClassificationMetricType.GROUPBY_ANY_MATCH_ACCURACY_5: + return GroupByAnyMatchAccuracy(top_k=5, num_classes=int(num_classes), dataset=dataset) + elif metric_type == ClassificationMetricType.IMAGENET_C_METRIC: + return ImageNet_C_Metric() + elif metric_type == ClassificationMetricType.AUROC: + return MetricCollection({"top-1": MulticlassAUROC(num_classes=int(num_classes))}) + elif metric_type == ClassificationMetricType.MACRO_MULTILABEL_AVERAGE_PRECISION: + return MetricCollection({"top-1": MultilabelAveragePrecision(num_labels=int(num_classes), average="macro")}) + + elif metric_type in ( + ClassificationMetricType.MEAN_MULTICLASS_F1, + ClassificationMetricType.MEAN_PER_CLASS_MULTICLASS_F1, + ): + return MetricCollection( + {"top-1": MulticlassF1Score(num_classes=int(num_classes), average=metric_type.averaging_method.value)} + ) + elif metric_type in ( + ClassificationMetricType.MEAN_MULTILABEL_F1, + ClassificationMetricType.MEAN_PER_CLASS_MULTILABEL_F1, + ): + return MetricCollection( + {"top-1": MultilabelF1Score(num_labels=int(num_classes), average=metric_type.averaging_method.value)} + ) + elif metric_type == ClassificationMetricType.MACRO_AVERAGED_MEAN_RECIPROCAL_RANK: + return MetricCollection({"top-1": MacroAveragedMeanReciprocalRank(num_classes=int(num_classes))}) + raise ValueError(f"Unknown metric type {metric_type}") + + +def build_topk_accuracy_metric(average_type: AveragingMethod, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_recall_metric(average_type: AveragingMethod, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassRecall(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_any_match_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = {f"top-{k}": AnyMatchAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} + return MetricCollection(metrics) + + +class MeanAveragePrecisionVOC2007(MultilabelPrecisionRecallCurve): + """ + VOC2007 11-points mAP Evaluation defined on page 11 of + The PASCAL Visual Object Classes (VOC) Challenge (Everingham et al., 2010) + """ + + def __init__(self, *args, recall_level_count: int = 11, **kwargs): + super().__init__(*args, **kwargs) + self.recall_thresholds = torch.linspace(0, 1, recall_level_count) + + def compute(self): + precision, recall, _ = super().compute() + interpolated_precisions = torch.stack( + [torch.max(precision[i][recall[i] >= r]) for r in self.recall_thresholds for i in range(len(precision))] + ) + return torch.mean(interpolated_precisions) + + +class AnyMatchAccuracy(Metric): + """ + This computes an accuracy where an element is considered correctly + predicted if one of the predictions is in a list of targets + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.top_k = top_k + self.add_state("tp", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> 
None:  # type: ignore
+        # preds [B, D]
+        # target [B, A]
+        # preds_oh [B, D] with 0 and 1
+        # select top K highest probabilities, use one hot representation
+        preds_oh = select_topk(preds, self.top_k)
+        # target_oh [B, D + 1] with 0 and 1
+        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
+        target = target.long()
+        # for undefined targets (-1) use a fake value `num_classes`
+        target[target == -1] = self.num_classes
+        # fill targets, use one hot representation
+        target_oh.scatter_(1, target, 1)
+        # target_oh [B, D] (remove the fake target at index `num_classes`)
+        target_oh = target_oh[:, :-1]
+        # tp [B] with 0 and 1
+        tp = (preds_oh * target_oh == 1).sum(dim=1)
+        # at least one match between prediction and target
+        tp.clip_(max=1)
+        # ignore instances where no targets are defined
+        mask = target_oh.sum(dim=1) > 0
+        tp = tp[mask]
+        self.tp.append(tp)  # type: ignore
+
+    def compute(self) -> Tensor:
+        tp = dim_zero_cat(self.tp)  # type: ignore
+        return tp.float().mean()
+
+
+class GroupByAnyMatchAccuracy(AnyMatchAccuracy):
+    def __init__(
+        self,
+        dataset,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        assert hasattr(dataset, "get_groupby_labels"), "The dataset should have a `get_groupby_labels` method"
+        self._groupby_labels: Dict[str, np.ndarray] = dataset.get_groupby_labels()
+        assert hasattr(dataset, "get_mapped_targets"), "The dataset should have a `get_mapped_targets` method"
+        self._mapped_targets: torch.Tensor = torch.from_numpy(dataset.get_mapped_targets())
+        self.add_state("indices", [], dist_reduce_fx="cat")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        self.indices.append(target)  # targets are indices in this case
+        super().update(preds, self._mapped_targets[target.tolist()].to(preds.device))
+
+    def groupby_metric(self, variable: np.ndarray, indices: np.ndarray, tp: torch.Tensor) -> Dict[Any, Tensor]:
+        groupby_dict = {}
+        for v in set(variable):
+            index = np.where(variable[indices] == v)[0]
+            groupby_dict[v] = tp[index].mean()
+        return groupby_dict
+
+    def compute(self) -> Tensor:
+        tp = dim_zero_cat(self.tp).float()  # type: ignore
+        indices = dim_zero_cat(self.indices).cpu().numpy()  # type: ignore
+        global_score = tp.mean()
+        results_dict = {"top-1": global_score}
+        for label_name, label_value in self._groupby_labels.items():
+            groupby_results = self.groupby_metric(label_value, indices, tp)
+            printable_results = {k: f"{100. * v.item():.4g}" for k, v in groupby_results.items()}
+            logger.info(f"Scores by {label_name} {printable_results}\n")
+            results_dict = {**results_dict, **groupby_results}
+        return results_dict
+
+
+class MacroAveragedMeanReciprocalRank(Metric):
+    """
+    Computes the macro-averaged mean reciprocal rank (MRR) metric.
+    The rank of a sample is the position at which its target label appears when
+    the prediction scores are sorted from most to least probable. The reciprocal
+    rank (1 / rank) lies in (0, 1] and measures how well the model does: the
+    higher the reciprocal rank, the better the model. The reciprocal ranks are
+    summed per target label, and each sum is divided by the number of samples
+    carrying that label, giving a per-label mean reciprocal rank. These per-label
+    values are then averaged across all labels to obtain the macro-averaged
+    mean reciprocal rank.
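+    For example, if a sample's target label has the 4th highest prediction
+    score, its reciprocal rank is 1 / 4 = 0.25; a label whose three samples
+    have reciprocal ranks 1.0, 0.5 and 0.25 contributes a per-label value of
+    1.75 / 3 ≈ 0.583 to the final macro average.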
This metric is useful when we have label imbalance and we want to give equal importance
+    to rare labels as well as frequent labels.
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    def __init__(
+        self,
+        num_classes: int,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.num_classes = num_classes
+        self.add_state("per_class_mrr", default=torch.zeros(self.num_classes, dtype=torch.float), dist_reduce_fx="sum")
+        self.add_state(
+            "per_class_num_samples", default=torch.zeros(self.num_classes, dtype=torch.float), dist_reduce_fx="sum"
+        )
+
+    def update(self, preds: Tensor, target: torch.LongTensor) -> None:  # type: ignore
+        # preds: FloatTensor [B, num_classes]
+        # target: LongTensor [B] target labels
+        # rank of the target = number of scores >= the target's score; its reciprocal lies in (0, 1]
+        rank_scores = 1 / (preds >= preds.gather(1, target[:, None].expand_as(preds))).sum(dim=1)
+
+        unique_targets = target.unique().tolist()
+        target_remap = {key: val for val, key in enumerate(unique_targets)}
+        target_inv_remap = {val: key for val, key in enumerate(unique_targets)}
+        remapped_targets = torch.LongTensor(list(map(target_remap.get, target.tolist()))).to(target.device)
+        unique_remapped_targets, remapped_target_count = remapped_targets.unique(sorted=True, return_counts=True)
+        sum_rank_scores = torch.zeros_like(unique_remapped_targets, dtype=torch.float).scatter_add_(
+            0, remapped_targets, rank_scores
+        )
+        unique_targets = torch.LongTensor(list(map(target_inv_remap.get, unique_remapped_targets.tolist()))).to(
+            target.device
+        )
+        self.per_class_mrr.index_add_(0, unique_targets, sum_rank_scores)
+        self.per_class_num_samples.index_add_(0, unique_targets, remapped_target_count.float())
+
+    def compute(self) -> Tensor:
+        return (self.per_class_mrr / (self.per_class_num_samples + 1e-6)).mean()
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/imagenet_c.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/imagenet_c.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ab201d24db418c4256738b2a6c7d8c307a47bba
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/metrics/imagenet_c.py
@@ -0,0 +1,225 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
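+
+# Relative mCE, as computed by compute_relative_average_scores below: for each
+# corruption type, average the error (1 - top-1 accuracy) over severity levels 1-5,
+# divide by AlexNet's averaged error for the same corruption, then average over
+# corruption types. E.g. an averaged error of 0.30 against an AlexNet reference of
+# 0.60 contributes 0.30 / 0.60 = 0.5 to the final mean; lower is better.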
+ +import logging +from typing import Any, Dict, Optional + +import numpy as np +import torch +from torch import Tensor +from torchmetrics import Metric + +logger = logging.getLogger("dinov3") + + +# corruption type (str) -> level (int) -> score (float) +Scores = Dict[str, Dict[int, float]] +# corruption type (str) -> score (float) +AverageScores = Dict[str, float] + + +ALEXNET_INVERSE_SCORES: Scores = { + "GAUSSIAN_NOISE": { + 1: 0.69528, + 2: 0.82542, + 3: 0.93554, + 4: 0.98138, + 5: 0.99452, + }, + "SHOT_NOISE": { + 1: 0.71224, + 2: 0.85108, + 3: 0.93574, + 4: 0.98182, + 5: 0.99146, + }, + "IMPULSE_NOISE": { + 1: 0.78374, + 2: 0.89808, + 3: 0.9487, + 4: 0.9872, + 5: 0.99548, + }, + "DEFOCUS_BLUR": { + 1: 0.656239999999999, + 2: 0.73202, + 3: 0.85036, + 4: 0.91364, + 5: 0.94714, + }, + "GLASS_BLUR": { + 1: 0.64308, + 2: 0.75054, + 3: 0.88806, + 4: 0.91622, + 5: 0.93344, + }, + "MOTION_BLUR": { + 1: 0.5843, + 2: 0.70048, + 3: 0.82108, + 4: 0.8975, + 5: 0.92638, + }, + "ZOOM_BLUR": { + 1: 0.70008, + 2: 0.769919999999999, + 3: 0.80784, + 4: 0.84198, + 5: 0.87198, + }, + "SNOW": { + 1: 0.71726, + 2: 0.88392, + 3: 0.86468, + 4: 0.9187, + 5: 0.94952, + }, + "FROST": { + 1: 0.6139, + 2: 0.797339999999999, + 3: 0.8879, + 4: 0.89942, + 5: 0.9343, + }, + "FOG": { + 1: 0.67474, + 2: 0.7605, + 3: 0.84378, + 4: 0.8726, + 5: 0.945, + }, + "BRIGHTNESS": { + 1: 0.4514, + 2: 0.48502, + 3: 0.54048, + 4: 0.62166, + 5: 0.724399999999999, + }, + "CONTRAST": { + 1: 0.64548, + 2: 0.7615, + 3: 0.88874, + 4: 0.9776, + 5: 0.9927, + }, + "ELASTIC_TRANSFORM": { + 1: 0.52596, + 2: 0.70116, + 3: 0.55686, + 4: 0.64076, + 5: 0.80554, + }, + "PIXELATE": { + 1: 0.52218, + 2: 0.5462, + 3: 0.737279999999999, + 4: 0.87092, + 5: 0.91262, + }, + "JPEG_COMPRESSION": { + 1: 0.510019999999999, + 2: 0.54718, + 3: 0.57294, + 4: 0.654579999999999, + 5: 0.74778, + }, + "SPECKLE_NOISE": { + 1: 0.66192, + 2: 0.7444, + 3: 0.90246, + 4: 0.94548, + 5: 0.97268, + }, + "GAUSSIAN_BLUR": { + 1: 0.54732, + 2: 0.70444, + 3: 0.82574, + 4: 0.89864, + 5: 0.9594, + }, + "SPATTER": { + 1: 0.47196, + 2: 0.621939999999999, + 3: 0.75052, + 4: 0.84132, + 5: 0.90182, + }, + "SATURATE": { + 1: 0.59342, + 2: 0.65514, + 3: 0.51174, + 4: 0.70834, + 5: 0.8226, + }, +} + +N_LEVELS = 5 +CORRUPTION_LEVEL_TO_ID = { + (k, level): i * N_LEVELS + level - 1 + for i, k in enumerate(sorted(ALEXNET_INVERSE_SCORES.keys())) + for level in range(1, 1 + N_LEVELS) +} +ID_TO_CORRUPTION_LEVEL = {i: k for k, i in CORRUPTION_LEVEL_TO_ID.items()} + + +def compute_relative_average_scores(scores: Scores, inv_scores_ref: Scores = ALEXNET_INVERSE_SCORES) -> AverageScores: + rel_scores = {} + for corruption_type in inv_scores_ref.keys(): + if corruption_type not in scores: + logger.info(f"No results for split {corruption_type}") + continue + inv_scores_for_type = [] + inv_scores_ref_for_type = [] + for level in range(1, 1 + N_LEVELS): + if level not in scores[corruption_type]: + continue + # append inverse score (confusion error) + inv_scores_for_type.append(1.0 - scores[corruption_type][level]) + inv_scores_ref_for_type.append(inv_scores_ref[corruption_type][level]) + rel_scores[corruption_type] = np.mean(inv_scores_for_type) / np.mean(inv_scores_ref_for_type) + + mce_score = np.mean([v for _, v in rel_scores.items()]) + return mce_score + + +class ImageNet_C_Metric(Metric): + + is_differentiable: bool = False + higher_is_better: Optional[bool] = False + full_state_update: bool = False + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + 
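+        # one (tp, total) counter per (corruption type, severity) pair: 19 types x 5 levels = 95 slots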
self.add_state("tp", torch.zeros(len(CORRUPTION_LEVEL_TO_ID)), dist_reduce_fx="sum") + self.add_state("total", torch.zeros(len(CORRUPTION_LEVEL_TO_ID)), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + from large_vision_dataset.datasets.image_net_c import CORRUPTION_TYPES + + target_labels, corruption_types, levels = target.unbind(1) + tps = torch.argmax(preds, dim=1) == target_labels.to(preds.device) + index = torch.tensor( + [ + CORRUPTION_LEVEL_TO_ID[(CORRUPTION_TYPES[ct].upper(), level.item())] + for ct, level in zip(corruption_types, levels) + ], + device=preds.device, + ) + self.total += torch.bincount(index, minlength=len(CORRUPTION_LEVEL_TO_ID)) + self.tp += torch.bincount(index, weights=tps, minlength=len(CORRUPTION_LEVEL_TO_ID)) + + def compute(self) -> Tensor: + flattened_scores = (self.tp / self.total).float().cpu().numpy() + scores: Scores = {} + for i, score in enumerate(flattened_scores): + corruption_type, level = ID_TO_CORRUPTION_LEVEL[i] + if corruption_type not in scores: + scores[corruption_type] = {} + scores[corruption_type][level] = score + + mce_score = compute_relative_average_scores(scores) + return {"top-1": torch.tensor(mce_score, device=self.tp.device)} diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/results.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/results.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8448197d73597dd031ebf110b6630936a0684b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/results.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import json +import logging +import os +from contextlib import nullcontext +from enum import Enum +from os import PathLike +from typing import IO, Any, Callable, Dict, List, Optional, Sequence, Union + +import pandas as pd +import yaml # type: ignore + +logger = logging.getLogger("dinov3") + + +# This type represents a list of results, e.g. baselines for an evaluation. 
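+# Backed by a pandas DataFrame so results can be stacked, filtered and saved uniformly.
+# Typical round trip (sketch with made-up numbers): save_from_dict({"top-1": 81.1}, "out.csv")
+# followed by load("out.csv").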
+Results: Any = pd.DataFrame + +try: + import openpyxl # noqa: 401 + + HAS_OPENPYXL = True +except ImportError: + HAS_OPENPYXL = False + logger.warning("can't import openpyxl package") + + +PathOrFileLikeObject = Union[str, PathLike, IO] + + +class FileFormat(Enum): + CSV = "csv" + JSON_LINES = "json-lines" + EXCEL = "excel" + YAML = "yaml" + + @staticmethod + def guess(path: Union[str, PathLike]) -> "FileFormat": + _, ext = os.path.splitext(path) + supported_exts = { + ".csv": FileFormat.CSV, + ".jsonl": FileFormat.JSON_LINES, + ".excel": FileFormat.EXCEL, + ".yaml": FileFormat.YAML, + } + if ext not in supported_exts: + raise ValueError(f"Passed path has extension {ext}, only {list(supported_exts.keys())} are supported.") + return supported_exts[ext] + + +_INT_DTYPES = [ + pd.Int8Dtype(), + pd.Int16Dtype(), + pd.Int32Dtype(), + pd.UInt8Dtype(), + pd.UInt16Dtype(), + pd.UInt32Dtype(), + pd.Int64Dtype(), +] + +_FLOAT_DTYPES = [ + pd.Float32Dtype(), + pd.Float64Dtype(), +] + +_TO_STRING_DTYPES = [ + pd.BooleanDtype(), +] + +_VALID_DTYPES = [ + pd.StringDtype(), + pd.Int64Dtype(), + pd.Float64Dtype(), +] + + +def _map_dtypes(results: Results) -> Results: + results = results.convert_dtypes( + infer_objects=True, + convert_string=True, + convert_integer=True, + convert_boolean=True, + convert_floating=True, + ) + for column_name in results.columns: + if results.dtypes[column_name] in _INT_DTYPES: + results[column_name] = results[column_name].astype(pd.Int64Dtype()) + elif results.dtypes[column_name] in _FLOAT_DTYPES: + results[column_name] = results[column_name].astype(pd.Float64Dtype()) + elif results.dtypes[column_name] in _TO_STRING_DTYPES: + results[column_name] = results[column_name].astype(pd.StringDtype()) + + return results + + +def _validate_column(results: Results, *, name: str, dtype: Union[str, type]) -> bool: + try: + loc = results.columns.get_loc(name) + except KeyError: + return False + return results.dtypes[loc] == dtype + + +def _validate(results: Results) -> bool: + for column_name in results.columns: + dtype = results.dtypes[column_name] + if dtype not in _VALID_DTYPES: + return False + + return True + + +def _assert_valid_dtypes(results: Results) -> None: + assert _validate(results), f"All dtypes from {results.dtypes} must be in {_VALID_DTYPES}" + + +Scalar = Union[str, int, float] + + +def _map_scalar(x: Scalar) -> List[Scalar]: + return [x] + + +def _map_scalar_list(x: List[Scalar]) -> List[Scalar]: + return x + + +def make(data: Dict[str, Union[str, int, float]]) -> Results: + """Construct results from a dictionary of scalars or lists of scalars.""" + + map_value: Callable[..., List[Scalar]] + if all((isinstance(value, Sequence) for key, value in data.items())): + map_value = _map_scalar_list + else: + map_value = _map_scalar + results = pd.DataFrame({key: map_value(value) for key, value in data.items()}) + results = _map_dtypes(results) + _assert_valid_dtypes(results) + return results + + +def vstack(*results_sequence: Sequence[Results]) -> Results: + """Concatenate (vertically) results.""" + + return pd.concat(results_sequence, axis=0, ignore_index=True) + + +def load(f: PathOrFileLikeObject, file_format: Optional[FileFormat] = None) -> Results: + """Load results from a file via a path-like object or from a file-like object.""" + + if isinstance(f, (str, PathLike)): + file_format = FileFormat.guess(f) + elif file_format is None: + raise ValueError("No file format specified for file-like object") + + assert file_format is not None + if file_format == FileFormat.CSV: + 
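+        # empty fields come back as NA; save() below writes NA out as empty strings (na_rep=""), keeping the round trip symmetric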
results = pd.read_csv(f, sep=",", na_values="", header=0)
+    elif file_format == FileFormat.JSON_LINES:
+        results = pd.read_json(f, lines=True)
+    elif file_format == FileFormat.EXCEL:
+        results = pd.read_excel(f)
+    elif file_format == FileFormat.YAML:
+        with open(f) as file:  # type: ignore
+            results = pd.DataFrame.from_dict(yaml.safe_load(file), orient="index")
+    else:
+        raise ValueError(f"Unsupported file format: {file_format}")
+
+    results = _map_dtypes(results)
+    _assert_valid_dtypes(results)
+    return results
+
+
+def load_collection(f: PathOrFileLikeObject) -> Dict[str, Results]:
+    """Load a collection of results from a file via a path-like object or from a file-like object."""
+
+    results_collection = pd.read_excel(f, sheet_name=None)
+
+    for sheet_name, results in results_collection.items():
+        results = _map_dtypes(results)
+        _assert_valid_dtypes(results)
+        results_collection[sheet_name] = results
+    return results_collection
+
+
+def save(
+    results: Sequence[Results],
+    f: PathOrFileLikeObject,
+    file_format: Optional[FileFormat] = None,
+) -> None:
+    """Save results to a file via a path-like object or to a file-like object."""
+
+    _assert_valid_dtypes(results)
+
+    if isinstance(f, (str, PathLike)):
+        file_format = FileFormat.guess(f)
+    elif file_format is None:
+        raise ValueError("No file format specified for file-like object")
+
+    assert file_format is not None
+    if file_format == FileFormat.CSV:
+        results.to_csv(f, index=False, header=True, sep=",", na_rep="")  # type: ignore
+    elif file_format == FileFormat.JSON_LINES:
+        # NOTE: pandas escapes '/' characters
+        s = results.to_json(orient="records", lines=True, indent=None)  # type: ignore
+        if isinstance(f, (str, PathLike)):
+            context = open(f, "w")  # type: ignore
+        else:
+            context = nullcontext(enter_result=f)  # type: ignore
+        with context as f:
+            for line in s.splitlines():
+                line = json.dumps(json.loads(line), separators=(",", ":"))
+                f.write(line + "\n")
+    elif file_format == FileFormat.EXCEL:
+        results.to_excel(f, header=True, index=False, na_rep="")  # type: ignore
+    elif file_format == FileFormat.YAML:
+        with open(f, "w") as fp:  # type: ignore
+            yaml.safe_dump(results.to_dict(orient="index"), fp, default_flow_style=False)  # type: ignore
+    else:
+        raise ValueError(f"Unsupported file format: {file_format}")
+
+
+def save_from_dict(
+    results_dict: Dict[str, Union[str, int, float]],
+    results_path: PathOrFileLikeObject,
+) -> None:
+    results = make(results_dict)
+    save(results, results_path)
+
+
+def save_collection(
+    results_collection: Dict[str, Results],
+    f: PathOrFileLikeObject,
+) -> None:
+    """Save a collection of results to a file via a path-like object or to a file-like object."""
+
+    if not HAS_OPENPYXL:
+        logger.warning("openpyxl needs to be installed, skipping...")
+        return
+
+    with pd.ExcelWriter(f, engine="openpyxl", mode="w") as writer:
+        for sheet_name, results in results_collection.items():
+            _assert_valid_dtypes(results)
+            results.to_excel(writer, sheet_name=sheet_name, header=True, index=False, na_rep="")
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/config.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..70db1260a220a3808f160095dda97b845af600f8
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/config.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from dataclasses import dataclass, field +from enum import Enum +from omegaconf import MISSING +from typing import Any + +import torch + +from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from dinov3.eval.segmentation.models import BackboneLayersSet +from dinov3.eval.setup import ModelConfig + + +DEFAULT_MEAN = tuple(mean * 255 for mean in IMAGENET_DEFAULT_MEAN) +DEFAULT_STD = tuple(std * 255 for std in IMAGENET_DEFAULT_STD) + + +class ModelDtype(Enum): + FLOAT32 = "float32" + BFLOAT16 = "bfloat16" + + @property + def autocast_dtype(self): + return { + ModelDtype.BFLOAT16: torch.bfloat16, + ModelDtype.FLOAT32: torch.float32, + }[self] + + +@dataclass +class OptimizerConfig: + lr: float = 1e-3 + beta1: float = 0.9 + beta2: float = 0.999 + weight_decay: float = 1e-2 + gradient_clip: float = 35.0 + + +@dataclass +class SchedulerConfig: + type: str = "WarmupOneCycleLR" + total_iter: int = 40_000 # Total number of iterations for training + constructor_kwargs: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class DatasetConfig: + root: str = MISSING # Path to the dataset folder + train: str = "" # Dataset descriptor, e.g. "ADE20K:split=TRAIN" + val: str = "" + + +@dataclass +class DecoderConfig: + type: str = "m2f" # Decoder type must be one of [linear, m2f] + backbone_out_layers: BackboneLayersSet = BackboneLayersSet.LAST + use_batchnorm: bool = True + use_cls_token: bool = False + use_backbone_norm: bool = True # Uses the backbone's output normalization on all layers + num_classes: int = 150 # Number of segmentation classes + hidden_dim: int = 2048 # Hidden dimension, only used for M2F head + + +@dataclass +class TrainConfig: + diceloss_weight: float = 0.0 + celoss_weight: float = 1.0 + + +@dataclass +class TrainTransformConfig: + img_size: Any = None + random_img_size_ratio_range: tuple[float] | None = None + crop_size: tuple[int] | None = None + flip_prob: float = 0.0 + + +@dataclass +class EvalTransformConfig: + img_size: Any = None + tta_ratios: tuple[float] = (1.0,) + + +@dataclass +class TransformConfig: + train: TrainTransformConfig = field(default_factory=TrainTransformConfig) + eval: EvalTransformConfig = field(default_factory=EvalTransformConfig) + mean: tuple[float] = DEFAULT_MEAN + std: tuple[float] = DEFAULT_STD + + +@dataclass +class EvalConfig: + compute_metric_per_image: bool = False + reduce_zero_label: bool = True # For ADE20K, ignores 0 label (=background/unlabeled) + mode: str = "slide" + crop_size: int | None = 512 + stride: int | None = 341 + eval_interval: int = 40000 + use_tta: bool = False # apply test-time augmentation at evaluation time + + +@dataclass +class SegmentationConfig: + model: ModelConfig | None = None # config of the DINOv3 backbone + bs: int = 2 + n_gpus: int = 8 + num_workers: int = 6 # number of workers to use / GPU + model_dtype: ModelDtype = ModelDtype.FLOAT32 + seed: int = 100 + datasets: DatasetConfig = field(default_factory=DatasetConfig) + metric_to_save: str = "mIoU" # Name of the metric to save + decoder_head: DecoderConfig = field(default_factory=DecoderConfig) + scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + transforms: TransformConfig = field(default_factory=TransformConfig) + train: TrainConfig = field(default_factory=TrainConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + # 
Additional Parameters + output_dir: str | None = None + load_from: str | None = None # path to .pt checkpoint to resume training from or evaluate from diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-linear-training.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-linear-training.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93e9bc82b6566aba4cb1706f8073089fee6620d9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-linear-training.yaml @@ -0,0 +1,53 @@ +# Config for ade20k, linear training + +bs: 2 +n_gpus: 8 +metric_to_save: 'mIoU' +model_dtype: FLOAT32 +scheduler: + total_iter: 40000 + type: 'WarmupOneCycleLR' + constructor_kwargs: + warmup_iters: 1500 + warmup_ratio: 1e-6 + final_div_factor: .inf + pct_start: 0 + anneal_strategy: 'cos' + use_beta1: False + update_momentum: False +optimizer: + lr : 1e-3 + beta1: 0.9 + beta2: 0.999 + weight_decay: 1e-3 + gradient_clip: 'inf' +datasets: + root: "" # Path to the ADE20K dataset + train: "ADE20K:split=TRAIN" + val: "ADE20K:split=VAL" +train: + diceloss_weight: 0.0 + celoss_weight: 1.0 +decoder_head: + type: "linear" + backbone_out_layers: LAST + use_cls_token: False + use_batchnorm: True + use_backbone_norm: True + num_classes: 150 +transforms: + train: + img_size: 512 + random_img_size_ratio_range: [0.5, 2.0] + crop_size: [512, 512] + flip_prob: 0.5 + eval: + img_size: 512 +eval: + compute_metric_per_image: False + reduce_zero_label: True + mode: "slide" + crop_size: 512 + stride: 341 + eval_interval: 5000 + use_tta: False diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-m2f-inference.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-m2f-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8d8a8ca485978486ea5faea69007d252f86a701 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/configs/config-ade20k-m2f-inference.yaml @@ -0,0 +1,21 @@ +# Config for ade20k, M2F inference + +metric_to_save: 'mIoU' +datasets: + root: "" # Path to the ADE20K dataset + val: "ADE20K:split=VAL" +decoder_head: + type: "m2f" + backbone_out_layers: FOUR_EVEN_INTERVALS + num_classes: 150 +transforms: + eval: + img_size: 896 + tta_ratios: [0.9, 0.95, 1.0, 1.05, 1.1] +eval: + compute_metric_per_image: False + reduce_zero_label: True + mode: "slide" + crop_size: 896 + stride: 596 + use_tta: True diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/eval.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..c213f19031688ead18360dfcb2bf0144f02bf41d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/eval.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
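+
+# Typical entry point (sketch): build a DINOv3 backbone, load a SegmentationConfig
+# (e.g. from the YAML configs above) with OmegaConf, then call
+# test_segmentation(backbone, config); it builds the decoder, wraps the model in
+# DistributedDataParallel and returns metrics such as {"mIoU": ..., "fscore": ...}.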
+ +from functools import partial +import logging + +import torch + +import dinov3.distributed as distributed +from dinov3.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader, make_dataset +from dinov3.eval.segmentation.inference import make_inference +from dinov3.eval.segmentation.metrics import ( + calculate_intersect_and_union, + calculate_segmentation_metrics, +) +from dinov3.eval.segmentation.models import build_segmentation_decoder +from dinov3.eval.segmentation.transforms import make_segmentation_eval_transforms +from dinov3.hub.segmentors import dinov3_vit7b16_ms +from dinov3.logging import MetricLogger + +logger = logging.getLogger("dinov3") + +RESULTS_FILENAME = "results-semantic-segmentation.csv" +MAIN_METRICS = ["mIoU"] + + +def evaluate_segmentation_model( + segmentation_model: torch.nn.Module, + test_dataloader, + device, + eval_res, + eval_stride, + decoder_head_type, + num_classes, + autocast_dtype, +): + segmentation_model = segmentation_model.to(device) + segmentation_model.eval() + all_metric_values = [] + metric_logger = MetricLogger(delimiter=" ") + + for batch_img, (_, gt) in metric_logger.log_every(test_dataloader, 10, header="Validation: "): + batch_img = [img.to(device).to(dtype=autocast_dtype) for img in batch_img] + gt = gt.to(device)[0] + aggregated_preds = torch.zeros(1, num_classes, gt.shape[-2], gt.shape[-1]) + for img_idx, img in enumerate(batch_img): + aggregated_preds += make_inference( + img, + segmentation_model.module, + inference_mode="slide", + decoder_head_type=decoder_head_type, + rescale_to=gt.shape[-2:], + n_output_channels=num_classes, + crop_size=(eval_res, eval_res), + stride=(eval_stride, eval_stride), + apply_horizontal_flip=(img_idx and img_idx >= len(batch_img) / 2), + output_activation=partial(torch.nn.functional.softmax, dim=1), + ) + aggregated_preds = (aggregated_preds / len(batch_img)).argmax(dim=1, keepdim=True).to(device) + intersect_and_union = calculate_intersect_and_union( + aggregated_preds[0], + gt, + num_classes=num_classes, + reduce_zero_label=True, + ) + all_metric_values.append(intersect_and_union) + del img, gt, aggregated_preds, intersect_and_union + + all_metric_values = torch.stack(all_metric_values) + if distributed.is_enabled(): + all_metric_values = torch.cat(distributed.gather_all_tensors((all_metric_values))) + final_metrics = calculate_segmentation_metrics( + all_metric_values, + metrics=["mIoU", "dice", "fscore"], + ) + final_metrics = {k: round(v.cpu().item() * 100, 2) for k, v in final_metrics.items()} + logger.info(final_metrics) + return final_metrics + + +def test_segmentation(backbone, config): + # 1- construct a segmentation decoder + if config.load_from == "dinov3_vit7b16_ms": # torch hub descriptor + # Load public m2f head checkpoints + logger.info("Loading the 7B backbone and the M2F adapter with torchhub") + segmentation_model = dinov3_vit7b16_ms(autocast_dtype=config.model_dtype.autocast_dtype, check_hash=True) + else: + segmentation_model = build_segmentation_decoder( + backbone, + config.decoder_head.backbone_out_layers, + config.decoder_head.type, + hidden_dim=config.decoder_head.hidden_dim, # Only used for instantiating a M2F head + num_classes=config.decoder_head.num_classes, + autocast_dtype=config.model_dtype.autocast_dtype, + ) + state_dict = torch.load(config.load_from, map_location="cpu")["model"] + _, _ = segmentation_model.load_state_dict(state_dict, strict=False) + device = distributed.get_rank() + segmentation_model = 
torch.nn.parallel.DistributedDataParallel(segmentation_model.to(device), device_ids=[device])
+
+    # 2- dataloader for testing
+    eval_res = config.eval.crop_size
+    eval_stride = config.eval.stride
+    transforms = make_segmentation_eval_transforms(
+        img_size=eval_res,
+        inference_mode="slide",
+        use_tta=config.eval.use_tta,
+        tta_ratios=config.transforms.eval.tta_ratios,
+    )
+
+    test_dataset = DatasetWithEnumeratedTargets(
+        make_dataset(
+            dataset_str=f"{config.datasets.val}:root={config.datasets.root}",
+            transforms=transforms,
+        )
+    )
+
+    test_sampler_type = None
+    if distributed.is_enabled():
+        test_sampler_type = SamplerType.DISTRIBUTED
+
+    test_dataloader = make_data_loader(
+        dataset=test_dataset,
+        batch_size=1,
+        num_workers=6,
+        sampler_type=test_sampler_type,
+        drop_last=False,
+        shuffle=False,
+        persistent_workers=True,
+    )
+
+    # 3- make inference
+    return evaluate_segmentation_model(
+        segmentation_model=segmentation_model,
+        test_dataloader=test_dataloader,
+        device=device,
+        eval_res=eval_res,
+        eval_stride=eval_stride,
+        decoder_head_type=config.decoder_head.type,
+        num_classes=config.decoder_head.num_classes,
+        autocast_dtype=config.model_dtype.autocast_dtype,
+    )
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/inference.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0eabd5948e9875a99a624a9c06f59e4c2a95a2
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/inference.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.transforms import functional as Fv
+
+
+def make_inference(
+    x: torch.Tensor,
+    segmentation_model: nn.Module,
+    inference_mode: str = "whole",
+    decoder_head_type: str = "linear",
+    rescale_to=(512, 512),
+    n_output_channels: int = 256,
+    crop_size: Optional[Tuple[int]] = None,
+    stride: Optional[Tuple[int]] = None,
+    apply_horizontal_flip: bool = False,
+    num_max_forward: int = 1,
+    output_activation: Callable | None = None,
+):
+    """Make inference on a given image, and revert horizontal-flip TTA if applicable.
+    If `inference_mode` = whole, one single prediction is made for the image.
+    If `inference_mode` = slide, the image is cropped into multiple slices and the latter are
+    used to make prediction following a sliding window method.
+
+    Args:
+        x (tensor): input image to make inference on.
+        segmentation_model (nn.Module): model to use for evaluating on dense tasks;
+            requires a `predict` method.
+        inference_mode (str, optional): Do inference on the whole image (mode="whole"), or by
+            adopting a sliding window approach to aggregate the results on
+            smaller patches of the input image (mode="slide"). Defaults to "whole".
+        rescale_to (tuple, optional): Resizing the output of the model prediction to the
+            shape of the ground truth. Defaults to (512, 512).
+        n_output_channels (int): number of output classes
+        crop_size (tuple, optional): [h_crop, w_crop]
+        stride (tuple, optional): [h_stride, w_stride]
+        apply_horizontal_flip (bool): Determines if horizontal flip TTA was applied for
+            the prediction. Defaults to False.
+ output_activation (callable): Output activation to use on top of the predictions. + - softmax is used when each pixel belongs to a single class (multiclass), + - sigmoid is used when pixel can belong to multiple classes (multilabel). Defaults to None (identity). + Returns: + Tensor: The segmentation results created from the input image. + """ + assert inference_mode in ["whole", "slide"] + if inference_mode == "slide": + # crop size and stride are needed for sliding inference + assert crop_size is not None + assert stride is not None + pred = F.interpolate( + slide_inference( + x, + segmentation_model, + decoder_head_type, + n_output_channels=n_output_channels, + crop_size=crop_size, + stride=stride, + num_max_forward=num_max_forward, + ), + size=rescale_to, + mode="bilinear", + align_corners=False, + ) + else: + pred = segmentation_model.predict( + F.interpolate( + x, + size=(512, 512), + mode="bilinear", + align_corners=False, + ), + rescale_to=rescale_to, + ) + if decoder_head_type == "m2f": + mask_pred, mask_cls = pred["pred_masks"], pred["pred_logits"] + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.float), mask_pred.to(torch.float)) + if apply_horizontal_flip: + pred = Fv.hflip(pred) + if output_activation: + pred = output_activation(pred) + return pred + + +def slide_inference( + inputs: torch.Tensor, + segmentation_model: nn.Module, + decoder_head_type: str = "linear", + n_output_channels: int = 256, + crop_size: Tuple = (512, 512), + stride: Tuple = (341, 341), + num_max_forward: int = 1, +): + """Inference by sliding-window with overlap. + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + segmentation_model (nn.Module): model to use for evaluating on dense tasks. + n_output_channels (int): number of output channels + crop_size (tuple): (h_crop, w_crop) + stride (tuple): (h_stride, w_stride) + Returns: + Tensor: The output results from model of each input image. 
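+    Example:
+        A 1024x1024 input with crop_size=(512, 512) and stride=(341, 341) is covered by a
+        3x3 grid of overlapping windows; window logits are summed into a full-size canvas
+        and divided per-pixel by the number of covering windows (count_mat).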
+ """ + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, C, h_img, w_img = inputs.shape + if h_crop > h_img and w_crop > w_img: # Meaning we are doing < 1.0 TTA + h_crop, w_crop = min(h_img, w_img), min(h_img, w_img) + assert batch_size == 1 # As of now, the code assumes that a single image is passed at a time at inference time + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((1, n_output_channels, h_img, w_img)).cpu() + count_mat = inputs.new_zeros((1, 1, h_img, w_img)).to(torch.int8).cpu() + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + crop_pred = segmentation_model.predict(crop_img, rescale_to=crop_img.shape[2:]) + if decoder_head_type == "m2f": + mask_pred, mask_cls = crop_pred["pred_masks"], crop_pred["pred_logits"] + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + crop_pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.bfloat16), mask_pred.to(torch.bfloat16)) + del mask_cls, mask_pred + preds += F.pad(crop_pred, (int(x1), int(preds.shape[-1] - x2), int(y1), int(preds.shape[-2] - y2))).cpu() + count_mat[:, :, y1:y2, x1:x2] += 1 + del crop_img, crop_pred + # Optional buffer to ensure each gpu does the same number of operations for sharded models + for _ in range(h_grids * w_grids, num_max_forward): + dummy_input = inputs.new_zeros((1, C, h_crop, w_crop)) + _ = segmentation_model.predict(dummy_input, rescale_to=dummy_input.shape[2:]) + assert (count_mat == 0).sum() == 0 + return preds / count_mat diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b73b1efc002bc295cd635dca50cd66a9bfd25a29 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/loss.py @@ -0,0 +1,296 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +def reduce_loss(loss, reduction) -> torch.Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = nn._reduction.get_enum(reduction) + # None: 0, element-wise mean: 1, sum: 2 + assert reduction_enum in [0, 1, 2] + if reduction_enum == 0: + return loss + if reduction_enum == 1: + return loss.mean() + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction="mean", avg_factor=None) -> torch.Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. 
+ """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == "mean": + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != "none": + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction="mean", avg_factor=None, **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + class_weight = np.load(class_weight) + + return class_weight + + +@weighted_loss +def dice_loss(pred, target, valid_mask, smooth=1, exponent=2, class_weight=None, ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], target[..., i], valid_mask=valid_mask, smooth=smooth, exponent=exponent + ) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +@weighted_loss +def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth + den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth + + return 1 - num / den + + +class DiceLoss(nn.Module): + """DiceLoss. + + This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1 + exponent (float): An float number to calculate denominator + value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. 
If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_dice'. + """ + + def __init__( + self, + smooth=1, + exponent=2, + reduction="mean", + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name="loss_dice", + **kwargs, + ): + super(DiceLoss, self).__init__() + self.smooth = smooth + self.exponent = exponent + self.reduction = reduction + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + self._loss_name = loss_name + + def forward(self, pred, target, avg_factor=None, reduction_override=None, **kwargs): + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot(torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index, + ) + return loss + + +@weighted_loss +def multilabel_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, class_weight=None, ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], target[:, i], valid_mask=valid_mask, smooth=smooth, exponent=exponent + ) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +class MultilabelDiceLoss(DiceLoss): + def forward(self, pred, target, avg_factor=None, reduction_override=None, **kwargs): + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.sigmoid(pred) + if False: + valid_mask = (target[..., self.ignore_index] == 0).long() + else: + valid_mask = torch.ones_like(target[:, 0]).long() + + loss = self.loss_weight * multilabel_dice_loss( + pred, + target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index, + ) + return loss + + +class CrossEntropyLoss(nn.Module): + def __init__( + self, + weight=None, + class_weight=None, + loss_weight=1.0, + reduction="mean", + avg_factor=None, + ignore_index=255, + avg_non_ignore=False, + ): + super(CrossEntropyLoss, self).__init__() + self.weight = weight + self.class_weight = class_weight + self.loss_weight = loss_weight + self.reduction = reduction + self.avg_factor = avg_factor + self.ignore_index = ignore_index + self.avg_non_ignore = avg_non_ignore + + def forward(self, pred, label): + loss = F.cross_entropy(pred, label, 
weight=self.class_weight, reduction="none", ignore_index=self.ignore_index)
+
+        if (self.avg_factor is None) and self.avg_non_ignore and self.reduction == "mean":
+            avg_factor = label.numel() - (label == self.ignore_index).sum().item()
+        else:
+            avg_factor = None
+
+        loss = weight_reduce_loss(loss, weight=self.weight, reduction=self.reduction, avg_factor=avg_factor)
+        return self.loss_weight * loss
+
+
+class MultiSegmentationLoss(nn.Module):
+    """
+    Combine different losses used in segmentation.
+    """
+
+    def __init__(self, diceloss_weight=0.0, celoss_weight=0.0):
+        super(MultiSegmentationLoss, self).__init__()
+
+        # dice loss takes precedence over cross-entropy if both weights are positive
+        if diceloss_weight > 0:
+            self.loss = MultilabelDiceLoss(loss_weight=diceloss_weight)
+        elif celoss_weight > 0:
+            self.loss = CrossEntropyLoss(reduction="mean", loss_weight=celoss_weight)
+        else:
+            # no-op loss; must accept (pred, gt) like the other branches
+            self.loss = lambda *args: 0.0
+
+    def forward(self, pred, gt):
+        """Forward function."""
+        return self.loss(pred, gt)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/metrics.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..608d7ca98ec279fa96498fe1e7578efda827bd96
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/metrics.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+import numpy as np
+import pandas as pd
+import torch
+
+
+pd.set_option("display.max_rows", 200)
+
+logger = logging.getLogger("dinov3")
+
+
+SEGMENTATION_METRICS = ["mIoU", "acc", "aAcc", "dice", "fscore", "precision", "recall"]
+
+
+def calculate_segmentation_metrics(
+    pre_eval_results,
+    metrics=["mIoU"],
+    beta=1,
+):
+    """Calculate the segmentation metrics after aggregating all the intermediate results.
+
+    Args:
+        pre_eval_results (list): Lists of (area_intersect, area_union, area_pred_label, and area_label).
+            These are intermediate results to compute the final metrics such as iou, fscore, etc.
+        metrics (list): Metrics to compute. Defaults to ["mIoU"].
+        beta (int): Parameter for computing F-score. Defaults to 1 (for computing F1-score).
+
+    Returns:
+        Dictionary of final metrics.
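+        Example:
+            With two classes and dataset totals intersect = [50, 30] and union = [100, 60],
+            per-class IoU is [0.5, 0.5], so the returned "mIoU" entry is 0.5 (callers such
+            as eval.py scale by 100 and round).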
+ """ + pre_eval_results = tuple(zip(*pre_eval_results)) + assert len(pre_eval_results) == 4 + total_area_intersect = sum(pre_eval_results[0]) + total_area_union = sum(pre_eval_results[1]) + total_area_pred_label = sum(pre_eval_results[2]) + total_area_label = sum(pre_eval_results[3]) + metrics_dict = total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics=metrics, + beta=beta, + ) + df = pd.DataFrame( + { + "Class Index": np.arange(len(metrics_dict["mIoU"])), + "mIoU": 100 * metrics_dict["mIoU"].cpu().numpy(), + } + ) + logger.info(f"mIoU per class:\n{df.to_string(index=False)}") + return { + "mIoU": metrics_dict["mIoU"].nanmean(), + "acc": metrics_dict["acc"].nanmean(), + "aAcc": metrics_dict["aAcc"].nanmean(), + "dice": metrics_dict["dice"].nanmean(), + "fscore": metrics_dict["fscore"].nanmean(), + "precision": metrics_dict["precision"].nanmean(), + "recall": metrics_dict["recall"].nanmean(), + } + + +def preprocess_nonzero_labels(label, ignore_index=255): + label_new = label.clone() + label_new[label_new == ignore_index] += 1 + label_new -= 1 + label_new[label_new == -1] = ignore_index + return label_new + + +def calculate_intersect_and_union(pred_label, label, num_classes, ignore_index=255, reduce_zero_label=False): + """Calculate intersection and Union. + Args: + pred_label (torch.Tensor): Prediction segmentation map + label (torch.Tensor): Ground truth segmentation map + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + reduce_zero_label (bool): Indicates whether or not label 0 is to be ignored. + """ + pred_label = pred_label.float() # Enables float tensor operations + if reduce_zero_label: + label = preprocess_nonzero_labels(label, ignore_index=ignore_index) + + mask = label != ignore_index + pred_label = pred_label[mask] + label = label[mask] + intersect = pred_label[pred_label == label] + area_intersect = torch.histc(intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_pred_label = torch.histc(pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_label = torch.histc(label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_union = area_pred_label + area_label - area_intersect + + return torch.stack([area_intersect, area_union, area_pred_label, area_label]) + + +def total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics=["mIoU"], + beta=1, +): + """Calculate evaluation metrics + Args: + total_area_intersect (torch.Tensor): The intersection of prediction and + ground truth histogram on all classes. + total_area_union (torch.Tensor): The union of prediction and ground truth + histogram on all classes. + total_area_pred_label (torch.Tensor): The prediction histogram on all + classes. + total_area_label (torch.Tensor): The ground truth histogram on all classes. + metrics (list[str] | str): Metrics to be evaluated, + can be 'mIoU', 'mDice', or 'mFscore'. + beta (int): Parameter for computing F-score. Defaults to 1 (for computing F1-score). + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). 
+ """ + + def f_score(precision, recall, beta=1): + score = (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall) + return score + + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ["mIoU", "dice", "fscore"] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError(f"metrics {metrics} is not supported") + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = dict({"aAcc": all_acc}) + for metric in metrics: + if metric == "mIoU": + ret_metrics["mIoU"] = total_area_intersect / total_area_union + ret_metrics["acc"] = total_area_intersect / total_area_label + elif metric == "dice": + ret_metrics["dice"] = 2 * total_area_intersect / (total_area_pred_label + total_area_label) + ret_metrics["acc"] = total_area_intersect / total_area_label + elif metric == "fscore": + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([f_score(x[0], x[1], beta) for x in zip(precision, recall)]) + ret_metrics["fscore"] = f_value + ret_metrics["precision"] = precision + ret_metrics["recall"] = recall + return ret_metrics diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..202386d2cef7aac6809fef6a6f8facfe7a061fc5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/__init__.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from enum import Enum +from functools import partial + +import torch + +from dinov3.eval.segmentation.models.backbone.dinov3_adapter import DINOv3_Adapter +from dinov3.eval.segmentation.models.heads.linear_head import LinearHead +from dinov3.eval.segmentation.models.heads.mask2former_head import Mask2FormerHead +from dinov3.eval.utils import ModelWithIntermediateLayers + + +class BackboneLayersSet(Enum): + """ + Set of intermediate layers to take from the backbone. + """ + + LAST = "LAST" # extracting only the last layer + FOUR_LAST = "FOUR_LAST" # extracting the four last layers + FOUR_EVEN_INTERVALS = "FOUR_EVEN_INTERVALS" # extracting outputs every 1/4 of the total number of blocks + + +def _get_backbone_out_indices( + model: torch.nn.Module, + backbone_out_layers: BackboneLayersSet = BackboneLayersSet.FOUR_EVEN_INTERVALS, +): + """ + Get indices for output layers of the ViT backbone. For now there are 3 options available: + BackboneLayersSet.LAST : only extract the last layer, used in segmentation tasks with a bn head. 
+ BackboneLayersSet.FOUR_EVEN_INTERVALS : extract outputs every 1/4 of the total number of blocks + Reference outputs in 'FOUR_EVEN_INTERVALS' mode : + ViT/S (12 blocks): [2, 5, 8, 11] + ViT/B (12 blocks): [2, 5, 8, 11] + ViT/L (24 blocks): [5, 11, 17, 23] (classic), [4, 11, 17, 23] (used in the paper) + ViT/g (40 blocks): [9, 19, 29, 39] + """ + n_blocks = getattr(model, "n_blocks", 1) + if backbone_out_layers == BackboneLayersSet.LAST: + out_indices = [n_blocks - 1] + elif backbone_out_layers == BackboneLayersSet.FOUR_LAST: + out_indices = [i for i in range(n_blocks - 4, n_blocks)] + elif backbone_out_layers == BackboneLayersSet.FOUR_EVEN_INTERVALS: + # Take indices that were used in the paper (for ViT/L only) + if n_blocks == 24: + out_indices = [4, 11, 17, 23] + else: + out_indices = [i * (n_blocks // 4) - 1 for i in range(1, 5)] + assert all([out_index < n_blocks for out_index in out_indices]) + return out_indices + + +class FeatureDecoder(torch.nn.Module): + def __init__(self, segmentation_model: torch.nn.ModuleList, autocast_ctx): + super().__init__() + self.segmentation_model = segmentation_model + self.autocast_ctx = autocast_ctx + + def forward(self, inputs): + with self.autocast_ctx(): + for module in self.segmentation_model: + inputs = module.forward(inputs) + return inputs + + def predict(self, inputs, rescale_to=(512, 512)): + with torch.inference_mode(): + with self.autocast_ctx(): + out = self.segmentation_model[0](inputs) # backbone forward + out = self.segmentation_model[1].predict(out, rescale_to=rescale_to) # decoder head prediction + return out + + +def build_segmentation_decoder( + backbone_model, + backbone_out_layers=BackboneLayersSet.FOUR_EVEN_INTERVALS, + decoder_type="linear", + hidden_dim=2048, + num_classes=150, + autocast_dtype=torch.float32, +): + backbone_indices_to_use = _get_backbone_out_indices(backbone_model, backbone_out_layers) + autocast_ctx = partial(torch.autocast, device_type="cuda", enabled=True, dtype=autocast_dtype) + if decoder_type == "m2f": + backbone_model = DINOv3_Adapter( + backbone_model, + interaction_indexes=backbone_indices_to_use, + ) + backbone_model.eval() + embed_dim = backbone_model.backbone.embed_dim + patch_size = backbone_model.patch_size + decoder = Mask2FormerHead( + input_shape={ + "1": [embed_dim, patch_size * 4, patch_size * 4, 4], + "2": [embed_dim, patch_size * 2, patch_size * 2, 4], + "3": [embed_dim, patch_size, patch_size, 4], + "4": [embed_dim, int(patch_size / 2), int(patch_size / 2), 4], + }, + hidden_dim=hidden_dim, + num_classes=num_classes, + ignore_value=255, + ) + elif decoder_type == "linear": + backbone_model = ModelWithIntermediateLayers( + backbone_model, + n=backbone_indices_to_use, + autocast_ctx=autocast_ctx, + reshape=True, + return_class_token=False, + ) + # Important: we freeze the backbone + backbone_model.requires_grad_(False) + embed_dim = backbone_model.feature_model.embed_dim + if isinstance(embed_dim, int): + if backbone_out_layers in [BackboneLayersSet.FOUR_LAST, BackboneLayersSet.FOUR_EVEN_INTERVALS]: + embed_dim = [embed_dim] * 4 + else: + embed_dim = [embed_dim] + decoder = LinearHead( + in_channels=embed_dim, + n_output_channels=num_classes, + ) + else: + raise ValueError(f'Unsupported decoder "{decoder_type}"') + + segmentation_model = FeatureDecoder( + torch.nn.ModuleList( + [ + backbone_model, + decoder, + ] + ), + autocast_ctx=autocast_ctx, + ) + return segmentation_model diff --git 
a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/backbone/dinov3_adapter.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/backbone/dinov3_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..14906bb7aec55ec519e82a88c80518acdfdd998c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/backbone/dinov3_adapter.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp + +from functools import partial + +from dinov3.eval.segmentation.models.utils.ms_deform_attn import MSDeformAttn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x, patch_size): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device + ) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = 
nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() + x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() + x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0.0, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False, + ): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): + def _inner_forward(query, feat): + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlockWithCls(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + 
level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential( + *[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + self.conv2 = nn.Sequential( + *[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv3 = nn.Sequential( + *[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv4 = nn.Sequential( + *[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs + + +class DINOv3_Adapter(nn.Module): + def __init__( + self, + backbone, + interaction_indexes=[9, 19, 29, 39], + pretrain_size=512, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + drop_path_rate=0.3, + init_values=0.0, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=0.5, + add_vit_feature=True, + use_extra_extractor=True, + with_cp=True, + ): + super(DINOv3_Adapter, self).__init__() + self.backbone = backbone + # Important: we freeze the backbone + self.backbone.requires_grad_(False) + + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.backbone.embed_dim + self.patch_size = self.backbone.patch_size + print("embed dim", embed_dim) + print("interaction_indexes", self.interaction_indexes) + print("patch_size", self.patch_size) + + block_fn = InteractionBlockWithCls + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + 
with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=( + (True if i == len(self.interaction_indexes) - 1 else False) and use_extra_extractor + ), + with_cp=with_cp, + ) + for i in range(len(self.interaction_indexes)) + ] + ) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + torch.nn.init.normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 + ).permute(0, 3, 1, 2) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + ) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + + c = torch.cat([c2, c3, c4], dim=1) + + # Code for matching with oss + H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 + H_toks, W_toks = x.shape[2] // self.patch_size, x.shape[3] // self.patch_size + bs, C, h, w = x.shape + + with torch.autocast("cuda", torch.bfloat16): + with torch.no_grad(): + all_layers = self.backbone.get_intermediate_layers( + x, n=self.interaction_indexes, return_class_token=True + ) + + x_for_shape, _ = all_layers[0] + bs, _, dim = x_for_shape.shape + del x_for_shape + + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 5:, + ], + ) + + outs = list() + for i, layer in enumerate(self.interactions): + x, cls = all_layers[i] + _, c, _ = layer( + x, + c, + cls, + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) + + # Split & Reshape + c2 = c[:, 0 : c2.size(1), :] + c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1) :, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + + x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) + x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) + x3 = 
F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) + x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + + return {"1": f1, "2": f2, "3": f3, "4": f4} diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5375bc66e1ed841a7091b81a0dcf56d1993c1f87 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/linear_head.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..eadcad5031466e6494f1d56b8d693e700b73ea9d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/linear_head.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LinearHead(nn.Module): + """Linear head: a 1x1 convolution over concatenated (optionally batch-normalized) multi-level features.""" + + def __init__( + self, + in_channels, + n_output_channels, + use_batchnorm=True, + use_cls_token=False, + ): + super().__init__() + self.in_channels = in_channels + self.channels = sum(in_channels) + if use_cls_token: + self.channels *= 2 # concatenate CLS to patch tokens + self.n_output_channels = n_output_channels + self.use_cls_token = use_cls_token + self.batchnorm_layer = nn.SyncBatchNorm(self.channels) if use_batchnorm else nn.Identity(self.channels) + self.conv = nn.Conv2d(self.channels, self.n_output_channels, kernel_size=1, padding=0, stride=1) + self.dropout = nn.Dropout2d(0.1) + nn.init.normal_(self.conv.weight, mean=0, std=0.01) + nn.init.constant_(self.conv.bias, 0) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + inputs = [ + torch.nn.functional.interpolate( + input=x, + size=inputs[0].shape[2:], + mode="bilinear", + align_corners=False, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + return inputs + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head.
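# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the diff above): the adapter
# forward() that just ended returns a 4-level pyramid at strides 4/8/16/32.
# The arithmetic below checks the shapes for an assumed 512x512 input with
# patch_size 16, and also shows why DWConv splits its tokens 16n/4n/n: the
# concatenated c2/c3/c4 live at (2H, 2W), (H, W) and (H/2, W/2) respectively.
bs, embed_dim = 2, 1024              # assumed ViT/L-like embedding width
H_c, W_c = 512 // 16, 512 // 16      # the stride-16 grid used inside the adapter

n_c2, n_c3, n_c4 = (2 * H_c) * (2 * W_c), H_c * W_c, (H_c // 2) * (W_c // 2)
assert (n_c2, n_c3, n_c4) == (16 * n_c4, 4 * n_c4, n_c4)  # the 16n/4n/n split in DWConv

expected_shapes = {
    "1": (bs, embed_dim, 4 * H_c, 4 * W_c),    # stride 4: self.up(c2) + stem's c1
    "2": (bs, embed_dim, 2 * H_c, 2 * W_c),    # stride 8
    "3": (bs, embed_dim, H_c, W_c),            # stride 16
    "4": (bs, embed_dim, H_c // 2, W_c // 2),  # stride 32
}
# ---------------------------------------------------------------------------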
+ """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if self.use_cls_token: + assert len(x) == 2, "Missing class tokens" + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + return x + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.dropout(output) + output = self.batchnorm_layer(output) + output = self.conv(output) + return output + + def predict(self, x, rescale_to=(512, 512)): + """ + Predict function used in evaluation. + No dropout is used, and the output is rescaled to the ground truth + for computing metrics. + """ + x = self._forward_feature(x) + x = self.batchnorm_layer(x) + x = self.conv(x) + x = F.interpolate(input=x, size=rescale_to, mode="bilinear") + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_head.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2352af7277216ce0e55a62e7f99bedb3acbf8cb1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_head.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Dict, Tuple + +from torch import nn +from torch.nn import functional as F + +from dinov3.eval.segmentation.models.heads.pixel_decoder import MSDeformAttnPixelDecoder +from dinov3.eval.segmentation.models.heads.mask2former_transformer_decoder import MultiScaleMaskedTransformerDecoder + + +class Mask2FormerHead(nn.Module): + def __init__( + self, + input_shape: Dict[str, Tuple[int]], # ShapeSpec: [channels, height, width, stride] + hidden_dim: int = 2048, + num_classes: int = 150, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_in_feature: str = "multi_scale_pixel_decoder", + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. 
+ transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + orig_input_shape = input_shape + input_shape = sorted(input_shape.items(), key=lambda x: x[1][-1]) + self.in_features = [k for k, _ in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = MSDeformAttnPixelDecoder( + input_shape=orig_input_shape, + transformer_dropout=0.0, + transformer_nheads=16, + transformer_dim_feedforward=4096, + transformer_enc_layers=6, + conv_dim=hidden_dim, + mask_dim=hidden_dim, + norm="GN", + transformer_in_features=["1", "2", "3", "4"], + common_stride=4, + ) + self.predictor = MultiScaleMaskedTransformerDecoder( + in_channels=hidden_dim, + mask_classification=True, + num_classes=num_classes, + hidden_dim=hidden_dim, + num_queries=100, + nheads=16, + dim_feedforward=4096, + dec_layers=9, + pre_norm=False, + mask_dim=hidden_dim, + enforce_input_project=False, + ) + + self.transformer_in_feature = transformer_in_feature + self.num_classes = num_classes + + def forward_features(self, features, mask=None): + return self.layers(features, mask) + + def forward(self, features, mask=None): + output = self.forward_features(features, mask) + return output + + def predict(self, features, mask=None, rescale_to=(512, 512)): + output = self.forward_features(features, mask) + output["pred_masks"] = F.interpolate( + output["pred_masks"], + size=rescale_to, + mode="bilinear", + align_corners=False, + ) + return output + + def layers(self, features, mask=None): + mask_features, _, multi_scale_features = self.pixel_decoder.forward_features(features) + predictions = self.predictor(multi_scale_features, mask_features, mask) + return predictions diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_transformer_decoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ff3190f8cb8fe0708ba7b0e4990c1cc7fff7f0 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/mask2former_transformer_decoder.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +# Adapted from: https://github.com/facebookresearch/detr/blob/master/models/detr.py + +from typing import Optional +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from dinov3.eval.segmentation.models.utils.position_encoding import PositionEmbeddingSine + + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`. + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. 
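# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the diff): the ShapeSpec
# convention taken by Mask2FormerHead above is [channels, height, width, stride];
# build_segmentation_decoder (earlier in this diff) fills it from the adapter
# outputs as sketched here, with embed_dim / patch_size values assumed for
# illustration. Note the last (stride) entry is 4 for every level in this diff,
# so the sorted(...) call in Mask2FormerHead.__init__ keeps the insertion order.
embed_dim, patch_size = 1024, 16
input_shape = {
    "1": [embed_dim, patch_size * 4, patch_size * 4, 4],
    "2": [embed_dim, patch_size * 2, patch_size * 2, 4],
    "3": [embed_dim, patch_size, patch_size, 4],
    "4": [embed_dim, patch_size // 2, patch_size // 2, 4],
}
# ---------------------------------------------------------------------------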
+ nn.init.constant_(module.bias, 0) + + +class Conv2d(torch.nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class SelfAttentionLayer(nn.Module): + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre( + self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward( + self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = 
self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre( + self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward( + self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MultiScaleMaskedTransformerDecoder(nn.Module): + def __init__( + self, + in_channels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. 
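# ---------------------------------------------------------------------------
# Editor's note (minimal sketch, not part of the diff): the normalize_before /
# pre_norm flag used by the attention and FFN layers above only changes where
# LayerNorm sits relative to the residual add. The two orderings implemented by
# forward_post / forward_pre reduce to the following (dropout omitted):
import torch
import torch.nn as nn

d = 8
norm, sublayer = nn.LayerNorm(d), nn.Linear(d, d)  # stand-in for attention or FFN
x = torch.randn(3, d)
post = norm(x + sublayer(x))   # post-norm: normalize after the residual add
pre = x + sublayer(norm(x))    # pre-norm: normalize only the sublayer input
# ---------------------------------------------------------------------------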
+ Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_cross_attention_layers.append( + CrossAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_ffn_layers.append( + FFNLayer( + d_model=hidden_dim, + dim_feedforward=dim_feedforward, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.post_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. 
+ self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (we always use 3 scales) + self.num_feature_levels = 3 + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + # output FFNs + if self.mask_classification: + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + def forward(self, x, mask_features, mask=None): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_class = [] + predictions_mask = [] + + # prediction heads on learnable query features + outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[0] + ) + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + # attention: cross-attention first + output = self.transformer_cross_attention_layers[i]( + output, + src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], + query_pos=query_embed, + ) + + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed + ) + + # FFN + output = self.transformer_ffn_layers[i](output) + + outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels] + ) + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + + assert len(predictions_class) == self.num_layers + 1 + + out = { + "pred_logits": predictions_class[-1], + "pred_masks": predictions_mask[-1], + "aux_outputs": self._set_aux_loss( + predictions_class if self.mask_classification else None, predictions_mask + ), + } + return out + + def forward_prediction_heads(self, output, mask_features, attn_mask_target_size): + decoder_output = self.post_norm(output) + decoder_output = decoder_output.transpose(0, 1) + outputs_class = self.class_embed(decoder_output) + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + # must use bool type + # If a BoolTensor is provided, positions with ``True`` 
are not allowed to attend while ``False`` values will be unchanged. + attn_mask = ( + attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5 + ).bool() + attn_mask = attn_mask.detach() + + return outputs_class, outputs_mask, attn_mask + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + if self.mask_classification: + return [{"pred_logits": a, "pred_masks": b} for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/pixel_decoder.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..693001db18b44f50ab966d5470d1f2fb1ec6ed03 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/heads/pixel_decoder.py @@ -0,0 +1,413 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import normal_ +from torch.amp import autocast + +from dinov3.eval.segmentation.models.utils.batch_norm import get_norm +from dinov3.eval.segmentation.models.utils.position_encoding import PositionEmbeddingSine +from dinov3.eval.segmentation.models.utils.transformer import _get_clones, _get_activation_fn +from dinov3.eval.segmentation.models.utils.ms_deform_attn import MSDeformAttn + + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`. + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + + +class Conv2d(torch.nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + # torchscript does not support SyncBatchNorm yet + # https://github.com/pytorch/pytorch/issues/40507 + # and we skip these codes in torchscript since: + # 1. currently we only support torchscript in evaluation mode + # 2. 
features needed by exporting module to torchscript are added in PyTorch 1.6 or + # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. + # if not torch.jit.is_scripting(): + # # Dynamo doesn't support context managers yet + # is_dynamo_compiling = check_if_dynamo_compiling() + # if not is_dynamo_compiling: + # with warnings.catch_warnings(record=True): + # if x.numel() == 0 and self.training: + # # https://github.com/pytorch/pytorch/issues/12013 + # assert not isinstance( + # self.norm, torch.nn.SyncBatchNorm + # ), "SyncBatchNorm does not support empty inputs!" + + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +# MSDeformAttn Transformer encoder in deformable detr +class MSDeformAttnTransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=256, + nhead=8, + num_encoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + num_feature_levels=4, + enc_n_points=4, + ): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + + encoder_layer = MSDeformAttnTransformerEncoderLayer( + d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points + ) + self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers) + + self.level_encoding = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + normal_(self.level_encoding) + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, pos_embeds): + masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs] + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_encoding[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder( + src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten + ) + + return memory, spatial_shapes, level_start_index + + +class MSDeformAttnTransformerEncoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, 
n_heads=8, n_points=4): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class MSDeformAttnTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + output = src + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + for _, layer in enumerate(self.layers): + output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) + + return output + + +# @SEM_SEG_HEADS_REGISTRY.register() +class MSDeformAttnPixelDecoder(nn.Module): + # @configurable + def __init__( + self, + input_shape: Dict[str, Tuple[int]], # ShapeSpec: [channels, height, width, stride] + *, + transformer_dropout: float, + transformer_nheads: int, + transformer_dim_feedforward: int, + transformer_enc_layers: int, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + # deformable transformer encoder args + transformer_in_features: List[str], + common_stride: int, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + transformer_dropout: dropout probability in transformer + transformer_nheads: number of heads in transformer + transformer_dim_feedforward: dimension of feedforward network + transformer_enc_layers: number of transformer encoder layers + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. 
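# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the diff): how the encoder
# above flattens a feature pyramid into one token sequence. level_start_index
# marks where each level begins in the flattened axis, exactly as computed in
# MSDeformAttnTransformerEncoderOnly.forward (and in deform_inputs earlier).
import torch

maps = [torch.randn(2, 256, 16, 16), torch.randn(2, 256, 8, 8), torch.randn(2, 256, 4, 4)]
spatial_shapes = torch.as_tensor([m.shape[-2:] for m in maps], dtype=torch.long)
src_flatten = torch.cat([m.flatten(2).transpose(1, 2) for m in maps], dim=1)  # (2, 336, 256)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
# -> tensor([0, 256, 320]): levels start at token 0, 256 (= 16*16), 320 (= 256 + 8*8)
# ---------------------------------------------------------------------------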
+ norm (str or callable): normalization for all conv layers + """ + super().__init__() + transformer_input_shape = {k: v for k, v in input_shape.items() if k in transformer_in_features} + + # this is the input shape of pixel decoder # ShapeSpec: [channels, height, width, stride] + input_shape = sorted(input_shape.items(), key=lambda x: x[1][-1]) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + self.feature_strides = [v[-1] for k, v in input_shape] + self.feature_channels = [v[0] for k, v in input_shape] + + # this is the input shape of transformer encoder (could use less features than pixel decoder + transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1][-1]) + self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" + transformer_in_channels = [v[0] for k, v in transformer_input_shape] + self.transformer_feature_strides = [v[-1] for k, v in transformer_input_shape] # to decide extra FPN layers + + self.transformer_num_feature_levels = 3 # TODO switch with len(self.transformer_in_features) + if self.transformer_num_feature_levels > 1: + input_proj_list = [] + # from low resolution to high resolution (res5 -> res2) + for in_channels in transformer_in_channels[::-1][:-1]: # TODO remove [:-1] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, conv_dim, kernel_size=1), + nn.GroupNorm(32, conv_dim), + ) + ) + self.input_convs = nn.ModuleList(input_proj_list) + else: + self.input_convs = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), + nn.GroupNorm(32, conv_dim), + ) + ] + ) + + for proj in self.input_convs: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + self.encoder = MSDeformAttnTransformerEncoderOnly( + d_model=conv_dim, + dropout=transformer_dropout, + nhead=transformer_nheads, + dim_feedforward=transformer_dim_feedforward, + num_encoder_layers=transformer_enc_layers, + num_feature_levels=self.transformer_num_feature_levels, + ) + N_steps = conv_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + self.mask_dim = mask_dim + # use 1x1 conv instead + self.mask_feature = Conv2d( + conv_dim, + mask_dim, + kernel_size=1, + stride=1, + padding=0, + ) + c2_xavier_fill(self.mask_feature) + + self.maskformer_num_feature_levels = 3 # always use 3 scales + self.common_stride = common_stride + + # extra fpn levels + stride = min(self.transformer_feature_strides) + self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(self.feature_channels[:1]): # TODO self.num_fpn_levels]): + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d(in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + c2_xavier_fill(lateral_conv) + c2_xavier_fill(output_conv) + # self.add_module("lateral_convs".format(idx + 1), lateral_conv) # TODO replace "adapter_{}" + # self.add_module("output_convs".format(idx + 1), output_conv) # TODO replace layer_{}"" + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward 
clearer. + self.lateral_convs = nn.ModuleList(lateral_convs[::-1]) + self.output_convs = nn.ModuleList(output_convs[::-1]) + + @autocast(device_type="cuda", enabled=False) + def forward_features(self, features): + srcs = [] + pos = [] + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.transformer_in_features[::-1][:-1]): # TODO remove [:-1] + x = features[f].float() # deformable detr does not support half precision + srcs.append(self.input_convs[idx](x)) + pos.append(self.pe_layer(x)) + + y, spatial_shapes, level_start_index = self.encoder(srcs, pos) + bs = y.shape[0] + + split_size_or_sections = [None] * self.transformer_num_feature_levels + for i in range(self.transformer_num_feature_levels): + if i < self.transformer_num_feature_levels - 1: + split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] + else: + split_size_or_sections[i] = y.shape[1] - level_start_index[i] + y = torch.split(y, split_size_or_sections, dim=1) + + out = [] + multi_scale_features = [] + num_cur_levels = 0 + for i, z in enumerate(y): + out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) + + # append `out` with extra FPN levels + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[0]): # TODO re put [:self.num_fpn_levels][::-1]): + x = features[f].float() + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) + y = output_conv(y) + out.append(y) + + for o in out: + if num_cur_levels < self.maskformer_num_feature_levels: + multi_scale_features.append(o) + num_cur_levels += 1 + + return self.mask_feature(out[-1]), out[0], multi_scale_features diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/batch_norm.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc0c94cf166e5595c503c2d2ff9b5e4a8c8a998 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/batch_norm.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F +from torch.nn import BatchNorm2d + +import dinov3.distributed as distributed + + +class FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + It contains non-trainable buffers called + "weight" and "bias", "running_mean", "running_var", + initialized to perform identity transformation. + + The pre-trained backbone models from Caffe2 only contain "weight" and "bias", + which are computed from the original four parameters of BN. + The affine transform `x * weight + bias` will perform the equivalent + computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. + When loading a backbone model from Caffe2, "running_mean" and "running_var" + will be left unchanged as identity transformation. 
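# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the diff): the inverse of the
# flattening step, as used by forward_features above -- torch.split with the
# per-level token counts recovers one spatial map per pyramid level.
import torch

bs, c = 2, 256
sizes = [256, 64, 16]                # tokens per level for 16x16, 8x8, 4x4 grids
y = torch.randn(bs, sum(sizes), c)   # flat encoder memory
shapes = [(16, 16), (8, 8), (4, 4)]
outs = [z.transpose(1, 2).view(bs, c, h, w)
        for z, (h, w) in zip(torch.split(y, sizes, dim=1), shapes)]
# outs[0]: (2, 256, 16, 16), outs[1]: (2, 256, 8, 8), outs[2]: (2, 256, 4, 4)
# ---------------------------------------------------------------------------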
+ + Other pre-trained backbone models may contain all 4 parameters. + + The forward is implemented by `F.batch_norm(..., training=False)`. + """ + + _version = 3 + + def __init__(self, num_features, eps=1e-5): + super().__init__() + self.num_features = num_features + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features) - eps) + self.register_buffer("num_batches_tracked", None) + + def forward(self, x): + if x.requires_grad: + # When gradients are needed, F.batch_norm will use extra memory + # because its backward op computes gradients for weight/bias as well. + scale = self.weight * (self.running_var + self.eps).rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + out_dtype = x.dtype # may be half + return x * scale.to(out_dtype) + bias.to(out_dtype) + else: + # When gradients are not needed, F.batch_norm is a single fused op + # and provide more optimization opportunities. + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + training=False, + eps=self.eps, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # No running_mean/var in early versions + # This will silent the warnings + if prefix + "running_mean" not in state_dict: + state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) + if prefix + "running_var" not in state_dict: + state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) + + @classmethod + def convert_frozen_batchnorm(cls, module): + """ + Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. + + Args: + module (torch.nn.Module): + + Returns: + If module is BatchNorm/SyncBatchNorm, returns a new module. + Otherwise, in-place convert module and return it. + + Similar to convert_sync_batchnorm in + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py + """ + bn_module = nn.modules.batchnorm + bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) + res = module + if isinstance(module, bn_module): + res = cls(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + res.num_batches_tracked = module.num_batches_tracked + else: + for name, child in module.named_children(): + new_child = cls.convert_frozen_batchnorm(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + @classmethod + def convert_frozenbatchnorm2d_to_batchnorm2d(cls, module: nn.Module) -> nn.Module: + """ + Convert all FrozenBatchNorm2d to BatchNorm2d + + Args: + module (torch.nn.Module): + + Returns: + If module is FrozenBatchNorm2d, returns a new module. + Otherwise, in-place convert module and return it. 
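# ---------------------------------------------------------------------------
# Editor's note (numerical check, not part of the diff): the affine folding in
# FrozenBatchNorm2d.forward above is equivalent to inference-mode batch_norm,
# which is what the gradient branch relies on.
import torch
import torch.nn.functional as F

C = 4
x = torch.randn(2, C, 5, 5)
w, b = torch.rand(C) + 0.5, torch.randn(C)
mean, var, eps = torch.randn(C), torch.rand(C) + 0.5, 1e-5

scale = w * (var + eps).rsqrt()
folded = x * scale.reshape(1, -1, 1, 1) + (b - mean * scale).reshape(1, -1, 1, 1)
reference = F.batch_norm(x, mean, var, w, b, training=False, eps=eps)
assert torch.allclose(folded, reference, atol=1e-5)
# ---------------------------------------------------------------------------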
+ + This is needed for quantization. + """ + + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features, module.eps) + + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data.clone().detach() + res.running_var.data = module.running_var.data.clone().detach() + res.eps = module.eps + res.num_batches_tracked = module.num_batches_tracked + else: + for name, child in module.named_children(): + new_child = cls.convert_frozenbatchnorm2d_to_batchnorm2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def get_norm(norm, out_channels): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "BN": BatchNorm2d, + # Fixed in https://github.com/pytorch/pytorch/pull/36382 + "SyncBN": nn.SyncBatchNorm, + "FrozenBN": FrozenBatchNorm2d, + "GN": lambda channels: nn.GroupNorm(32, channels), + # for debugging: + "nnSyncBN": nn.SyncBatchNorm, + "naiveSyncBN": NaiveSyncBatchNorm, + # expose stats_mode N as an option to caller, required for zero-len inputs + "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"), + "LN": lambda channels: LayerNorm(channels), + }[norm] + return norm(out_channels) + + +class NaiveSyncBatchNorm(BatchNorm2d): + """ + In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient + when the batch size on each worker is different. + (e.g., when scale augmentation is used, or when it is applied to mask head). + + This is a slower but correct alternative to `nn.SyncBatchNorm`. + + Note: + There isn't a single definition of Sync BatchNorm. + + When ``stats_mode==""``, this module computes overall statistics by using + statistics of each worker with equal weight. The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (N, H, W). This mode does not support inputs with zero batch size. + + When ``stats_mode=="N"``, this module computes overall statistics by weighting + the statistics of each worker by their ``N``. The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (H, W). It is slower than ``stats_mode==""``. + + Even though the result of this module may not be the true statistics of all samples, + it may still be reasonable because it might be preferrable to assign equal weights + to all workers, regardless of their (H, W) dimension, instead of putting larger weight + on larger images. From preliminary experiments, little difference is found between such + a simplified implementation and an accurate computation of overall mean & variance. 
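# ---------------------------------------------------------------------------
# Editor's note (hypothetical usage sketch, not part of the diff): how the
# get_norm registry above resolves its string keys; None and "" yield None,
# which is why the pixel decoder earlier in this diff sets use_bias = (norm == "").
# from dinov3.eval.segmentation.models.utils.batch_norm import get_norm
gn = get_norm("GN", 256)        # nn.GroupNorm(32, 256)
bn = get_norm("BN", 256)        # nn.BatchNorm2d(256)
nothing = get_norm("", 256)     # None -> conv layers fall back to using a bias
# ---------------------------------------------------------------------------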
+ """ + + def __init__(self, *args, stats_mode="", **kwargs): + super().__init__(*args, **kwargs) + assert stats_mode in ["", "N"] + self._stats_mode = stats_mode + + def forward(self, input): + if distributed.get_world_size() == 1 or not self.training: + return super().forward(input) + + B, C = input.shape[0], input.shape[1] + + half_input = input.dtype == torch.float16 + if half_input: + # fp16 does not have good enough numerics for the reduction here + input = input.float() + mean = torch.mean(input, dim=[0, 2, 3]) + meansqr = torch.mean(input * input, dim=[0, 2, 3]) + + if self._stats_mode == "": + assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.' + vec = torch.cat([mean, meansqr], dim=0) + vec = torch.distributed.nn.all_reduce(vec) * (1.0 / dist.get_world_size()) + mean, meansqr = torch.split(vec, C) + momentum = self.momentum + else: + if B == 0: + vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype) + vec = vec + input.sum() # make sure there is gradient w.r.t input + else: + vec = torch.cat( + [ + mean, + meansqr, + torch.ones([1], device=mean.device, dtype=mean.dtype), + ], + dim=0, + ) + vec = torch.distributed.nn.all_reduce(vec * B) + + total_batch = vec[-1].detach() + momentum = total_batch.clamp(max=1) * self.momentum # no update if total_batch is 0 + mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C) # avoid div-by-zero + + var = meansqr - mean * mean + invstd = torch.rsqrt(var + self.eps) + scale = self.weight * invstd + bias = self.bias - mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + + self.running_mean += momentum * (mean.detach() - self.running_mean) + self.running_var += momentum * (var.detach() - self.running_var) + ret = input * scale + bias + if half_input: + ret = ret.half() + return ret + + +class CycleBatchNormList(nn.ModuleList): + """ + Implement domain-specific BatchNorm by cycling. + + When a BatchNorm layer is used for multiple input domains or input + features, it might need to maintain a separate test-time statistics + for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`. + + This module implements it by using N separate BN layers + and it cycles through them every time a forward() is called. + + NOTE: The caller of this module MUST guarantee to always call + this module by multiple of N times. Otherwise its test-time statistics + will be incorrect. + """ + + def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs): + """ + Args: + length: number of BatchNorm layers to cycle. + bn_class: the BatchNorm class to use + kwargs: arguments of the BatchNorm class, such as num_features. + """ + self._affine = kwargs.pop("affine", True) + super().__init__([bn_class(**kwargs, affine=False) for k in range(length)]) + if self._affine: + # shared affine, domain-specific BN + channels = self[0].num_features + self.weight = nn.Parameter(torch.ones(channels)) + self.bias = nn.Parameter(torch.zeros(channels)) + self._pos = 0 + + def forward(self, x): + ret = self[self._pos](x) + self._pos = (self._pos + 1) % len(self) + + if self._affine: + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + return ret * w + b + else: + return ret + + def extra_repr(self): + return f"affine={self._affine}" + + +class LayerNorm(nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). 
+ https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ms_deform_attn.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..fd662c9b4ac38f2658944a66eef49869fdea96f5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ms_deform_attn.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function +from torch.amp import custom_fwd, custom_bwd + +from torch.autograd.function import once_differentiable +from torch.nn.init import constant_, xavier_uniform_ + +try: + import MultiScaleDeformableAttention as MSDA +except ImportError: + # if we just care about inference, we don't need + # the compiled extension for multi-scale deformable attention + MSDA = None + + +class MSDeformAttnFunction(Function): + @staticmethod + @custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + ctx.im2col_step = im2col_step + output = ms_deform_attn_core_pytorch( + value, + value_spatial_shapes, + # value_level_start_index, + sampling_locations, + attention_weights, + ) + ctx.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if MSDA is None: + raise RuntimeError( + "MultiScaleDeformableAttention is not available, " + "please compile with CUDA if you want to train a " + "segmentation head with deformable attention" + ) + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ 
-> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): + """Multi-Scale Deformable Attention Module. + + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make " + "the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.ratio = ratio + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, int(d_model * ratio)) + self.output_proj = nn.Linear(int(d_model * ratio), d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + + value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + 
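+        # sampling_locations is (N, Len_q, n_heads, n_levels, n_points, 2) in
+        # normalized [0, 1] coordinates. As a worked example (hypothetical
+        # values): a reference point (0.5, 0.5) on a 2x2 level has
+        # offset_normalizer (2, 2), so a learned offset of (1, 1) samples at
+        # (0.5 + 1/2, 0.5 + 1/2) = (1.0, 1.0), the bottom-right corner.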
output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce67775be80031ef924d64478192cbf616fc988d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/ms_deform_attn_func.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000000000000000000000000000000000000..fa343cbed0f1e28af7f9006fe5bc9cbaabbe6aa8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step + ) + ctx.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7449bf1949737d220c5917d1df6d924842251c6d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/ms_deform_attn.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b7d70c21f89146b461758065994d5ae414594d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/modules/ms_deform_attn.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = 
MSDeformAttnFunction.apply(
+            value,
+            input_spatial_shapes,
+            input_level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )
+        output = self.output_proj(output)
+        return output
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/setup.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7dc7d64e6966bea6e6800988c7bb2f3f231e2e5
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/setup.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "src")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+    else:
+        raise NotImplementedError("CUDA is not available")
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+    ext_modules = [
+        extension(
+            "MultiScaleDeformableAttention",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+    return ext_modules
+
+
+setup(
+    name="MultiScaleDeformableAttention",
+    version="1.0",
+    author="Weijie Su",
+    url="https://github.com/fundamentalvision/Deformable-DETR",
+    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+    packages=find_packages(
+        exclude=(
+            "configs",
+            "tests",
+        )
+    ),
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d24ce5c2792d56575ddb4e58d9e572eb77d6a028
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
+
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.h b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf952dc46782435448e3120dbfc485544e82a620
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,38 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
+
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fe112e464c0ba1c8f10a6c1c393bb23a8c846387
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,158 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+    const int batch_n = im2col_step_;
+    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto columns = output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
+            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                columns.data<scalar_t>());
+
+        }));
+    }
+
+    output = output.view({batch, num_query, num_heads*channels});
+
+    return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto grad_value = at::zeros_like(value);
+    auto grad_sampling_loc = at::zeros_like(sampling_loc);
+    auto grad_attn_weight = at::zeros_like(attn_weight);
+
+    const int batch_n = im2col_step_;
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto grad_output_g = grad_output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
+            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+                grad_output_g.data<scalar_t>(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+        }));
+    }
+
+    return {
+        grad_value, grad_sampling_loc, grad_attn_weight
+    };
+}
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.h b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..080c6a2b17733da8a8b04a4a64c3aa1e81904b26
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,35 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_im2col_cuda.cuh b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..736841fcfbeaf2d6e494fff7e556b6d8b8b73499
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1332 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n)                        \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+      i < (n);                                        \
+      i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+  return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+  }
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+                                               const int &height, const int &width, const int &nheads, const int &channels,
+                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                               const scalar_t &top_grad,
+                                               const scalar_t &attn_weight,
+                                               scalar_t* &grad_value,
+                                               scalar_t* grad_sampling_loc,
+                                               scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value);
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  *grad_attn_weight = top_grad * val;
+  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+                                                  const int &height, const int &width, const int &nheads, const int &channels,
+                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                                  const scalar_t &top_grad,
+                                                  const scalar_t &attn_weight,
+                                                  scalar_t* &grad_value,
+                                                  scalar_t* grad_sampling_loc,
+                                                  scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value);
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
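+  // Unlike ms_deform_attn_col2im_bilinear above, which writes its two
+  // location gradients and one weight gradient into caller-provided scratch
+  // (later reduced across the thread block), this _gm variant accumulates
+  // straight into global memory with atomicAdd, trading atomic contention
+  // for zero shared-memory use.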
+  atomicAdd(grad_attn_weight, top_grad * val);
+  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    scalar_t *data_col_ptr = data_col + index;
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    scalar_t col = 0;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+        }
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+      }
+    }
+    *data_col_ptr = col;
+  }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
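+    // grad_sampling_loc / grad_attn_weight now point at this sample's output
+    // slots: two floats (x, y) per sampling location and one scalar per
+    // attention weight, advanced by the strides declared below.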
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          for (unsigned int tid = 1; tid < blockSize; ++tid)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s=blockSize/2; s>0; s>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
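+    // The reduction further down halves the number of active threads each
+    // step (s >>= 1); the extra "spre" bookkeeping folds in the leftover
+    // element whenever the active count is odd, so blockDim.x need not be a
+    // power of two.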
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
+    const int n, const scalar_t *grad_col, const scalar_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+    const int batch_size, const int spatial_size, const int num_heads,
+    const int channels, const int num_levels, const int num_query, const int num_point,
+    scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
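
Both shm_reduce_v2 variants carve their caches out of one dynamically sized extern __shared__ buffer: two scalars per thread for the sampling-location gradient, then one per thread for the attention weight. A quick check of the byte count that the launcher further below passes as the third launch parameter (assuming 4-byte scalars):

    def shm_bytes(num_threads, sizeof_scalar=4):
        loc_cache = 2 * num_threads   # cache_grad_sampling_loc (x and y per thread)
        w_cache = num_threads         # cache_grad_attn_weight
        return (loc_cache + w_cache) * sizeof_scalar

    assert shm_bytes(256) == 3072     # num_threads * 3 * sizeof(scalar_t)
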
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
+    const int n, const scalar_t *grad_col, const scalar_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+    const int batch_size, const int spatial_size, const int num_heads,
+    const int channels, const int num_levels, const int num_query, const int num_point,
+    scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
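
The only difference from the preceding kernel is the final write: when channels exceeds the block size, several blocks reduce disjoint channel chunks of the same sampling point, so their partial sums must be combined with atomicAdd rather than a plain store. A toy illustration of why an overwriting store would be wrong:

    partials = [0.25, 0.5, 0.125, 0.125]   # one tree-reduced value per block
    overwritten = partials[-1]             # plain-store semantics: last writer wins
    accumulated = sum(partials)            # atomicAdd semantics
    assert accumulated == 1.0 and overwritten != accumulated
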
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(
+    const int n, const scalar_t *grad_col, const scalar_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+    const int batch_size, const int spatial_size, const int num_heads,
+    const int channels, const int num_levels, const int num_query, const int num_point,
+    scalar_t *grad_value, scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear_gm(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            grad_sampling_loc, grad_attn_weight);
+        }
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+                               const scalar_t* data_value,
+                               const int64_t* data_spatial_shapes,
+                               const int64_t* data_level_start_index,
+                               const scalar_t* data_sampling_loc,
+                               const scalar_t* data_attn_weight,
+                               const int batch_size, const int spatial_size,
+                               const int num_heads, const int channels,
+                               const int num_levels, const int num_query, const int num_point,
+                               scalar_t* data_col)
+{
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  const int num_threads = CUDA_NUM_THREADS;
+  ms_deformable_im2col_gpu_kernel<scalar_t>
+      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+          num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+          batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
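
Both launchers assign one CUDA thread per output scalar and derive the grid size from that count. Assuming the usual GET_BLOCKS(N, T) = ceil(N / T) helper and CUDA_NUM_THREADS = 1024, which are conventionally defined near the top of this file but not visible in this hunk, the geometry works out as:

    def get_blocks(n, threads):
        return (n + threads - 1) // threads    # ceil division

    batch_size, num_query, num_heads, channels = 2, 300, 8, 32
    num_kernels = batch_size * num_query * num_heads * channels   # one thread per output scalar
    assert num_kernels == 153600
    assert get_blocks(num_kernels, 1024) == 150                   # 150 blocks of 1024 threads
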
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+                               const scalar_t* grad_col,
+                               const scalar_t* data_value,
+                               const int64_t* data_spatial_shapes,
+                               const int64_t* data_level_start_index,
+                               const scalar_t* data_sampling_loc,
+                               const scalar_t* data_attn_weight,
+                               const int batch_size, const int spatial_size,
+                               const int num_heads, const int channels,
+                               const int num_levels, const int num_query, const int num_point,
+                               scalar_t* grad_value,
+                               scalar_t* grad_sampling_loc,
+                               scalar_t* grad_attn_weight)
+{
+  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  if (channels > 1024)
+  {
+    if ((channels & 1023) == 0)
+    {
+      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             num_threads * 3 * sizeof(scalar_t), stream>>>(
+              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+              channels, num_levels, num_query, num_point,
+              grad_value, grad_sampling_loc, grad_attn_weight);
+    }
+    else
+    {
+      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+              channels, num_levels, num_query, num_point,
+              grad_value, grad_sampling_loc, grad_attn_weight);
+    }
+  }
+  else
+  {
+    switch(channels)
+    {
+      case 1:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 2:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 4:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 8:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 16:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 32:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 64:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 128:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 256:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 512:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      case 1024:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                channels, num_levels, num_query, num_point,
+                grad_value, grad_sampling_loc, grad_attn_weight);
+        break;
+      default:
+        if (channels < 64)
+        {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                 num_threads * 3 * sizeof(scalar_t), stream>>>(
+                  num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                  data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                  channels, num_levels, num_query, num_point,
+                  grad_value, grad_sampling_loc, grad_attn_weight);
+        }
+        else
+        {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                 num_threads * 3 * sizeof(scalar_t), stream>>>(
+                  num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+                  data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+                  channels, num_levels, num_query, num_point,
+                  grad_value, grad_sampling_loc, grad_attn_weight);
+        }
+    }
+  }
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+
+}
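
Summarizing the backward dispatch just defined: the kernel is chosen purely from the channel count. A Python sketch of that decision tree, useful as a profiling reference (names shortened from the CUDA kernels above):

    def pick_col2im_kernel(channels):
        if channels > 1024:
            return ("shm_reduce_v2_multi_blocks" if channels % 1024 == 0
                    else "gm")                      # plain global-memory atomics
        if channels in (1, 2, 4, 8, 16, 32):
            return "shm_blocksize_aware_reduce_v1"  # serial shared-memory reduce
        if channels in (64, 128, 256, 512, 1024):
            return "shm_blocksize_aware_reduce_v2"  # tree shared-memory reduce
        return "shm_reduce_v1" if channels < 64 else "shm_reduce_v2"

    assert pick_col2im_kernel(256) == "shm_blocksize_aware_reduce_v2"
    assert pick_col2im_kernel(71) == "shm_reduce_v2"
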
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/ms_deform_attn.h b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..5fd1f667dba53df0c4797cc7101c17a4a2662a3f
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/ms_deform_attn.h
@@ -0,0 +1,67 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/vision.cpp b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1af72c89dfc2e4c03732d4bda80b05184a44a39 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/src/vision.cpp @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This software may be used and distributed in accordance with +// the terms of the DINOv3 License Agreement. + +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/test.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e18f4742b3d7651fd12474fcfafa7bf2c89b49e8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/ops/test.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H * W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ( + ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()) + .detach() + .cpu() + ) + output_cuda = ( + MSDeformAttnFunction.apply( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 
2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = ( + MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck( + func, + ( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ), + ) + + print(f"* {gradok} check_gradient_numerical(D={channels})") + + +if __name__ == "__main__": + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/position_encoding.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f6abc79bcb285f9e4ad8436dc36a92dc7ee8b9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/position_encoding.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py +""" +Various positional encodings for the transformer. +""" + +import math + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/transformer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6ff187d144ee593c8c8b72d07bccbe12fa8685 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/models/utils/transformer.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +""" +Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" + +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is 
not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: 
Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/run.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/run.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4ef010167e8bfb0fb64f9d109bd4e8e81fc479 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/run.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
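
The encoder and decoder layers above implement both residual orderings: forward_post applies LayerNorm after the residual sum (DETR's default, normalize_before=False), while forward_pre normalizes the input of each sub-block. A minimal sketch of the difference on a single feed-forward sub-block (shapes are illustrative, not this repo's API):

    import torch
    import torch.nn as nn

    d_model = 512
    norm, drop = nn.LayerNorm(d_model), nn.Dropout(0.1)
    ff = nn.Sequential(nn.Linear(d_model, 2048), nn.ReLU(), nn.Linear(2048, d_model))

    x = torch.randn(10, 2, d_model)        # (seq, batch, dim), as used in this file
    post = norm(x + drop(ff(x)))           # forward_post ordering
    pre = x + drop(ff(norm(x)))            # forward_pre ordering

Pre-norm keeps an unnormalized residual stream, which is why the Transformer above optionally adds a final encoder LayerNorm only when normalize_before is set.
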
+ +import logging +from omegaconf import OmegaConf +import os +import sys +from typing import Any + +from dinov3.eval.segmentation.config import SegmentationConfig +from dinov3.eval.segmentation.eval import test_segmentation +from dinov3.eval.segmentation.train import train_segmentation +from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results +from dinov3.eval.setup import load_model_and_context +from dinov3.run.init import job_context + + +logger = logging.getLogger("dinov3") + +RESULTS_FILENAME = "results-semantic-segmentation.csv" +MAIN_METRICS = ["mIoU"] + + +def run_segmentation_with_dinov3( + backbone, + config, +): + if config.load_from: + logger.info("Testing model performance on a pretrained decoder head") + return test_segmentation(backbone=backbone, config=config) + assert config.decoder_head.type == "linear", "Only linear head is supported for training" + return train_segmentation(backbone=backbone, config=config) + + +def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]: + """Initialization of distributed and logging are preconditions for this method""" + if "config" in eval_args: # using a config yaml file, useful for training + base_config_path = eval_args.pop("config") + output_dir = eval_args["output_dir"] + base_config = OmegaConf.load(base_config_path) + structured_config = OmegaConf.structured(SegmentationConfig) + dataclass_config: SegmentationConfig = OmegaConf.to_object( + OmegaConf.merge( + structured_config, + base_config, + OmegaConf.create(eval_args), + ) + ) + else: # either using default values, or only adding some args to the command line + dataclass_config, output_dir = args_dict_to_dataclass(eval_args=eval_args, config_dataclass=SegmentationConfig) + backbone = None + if dataclass_config.model: + backbone, _ = load_model_and_context(dataclass_config.model, output_dir=output_dir) + else: + assert dataclass_config.load_from == "dinov3_vit7b16_ms" + logger.info(f"Segmentation Config:\n{OmegaConf.to_yaml(dataclass_config)}") + segmentation_file_path = os.path.join(output_dir, "segmentation_config.yaml") + OmegaConf.save(config=dataclass_config, f=segmentation_file_path) + results_dict = run_segmentation_with_dinov3(backbone=backbone, config=dataclass_config) + write_results(results_dict, output_dir, RESULTS_FILENAME) + return results_dict + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + eval_args = cli_parser(argv) + with job_context(output_dir=eval_args["output_dir"]): + benchmark_launcher(eval_args=eval_args) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/schedulers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..907f3bf8daf07fe17c2c7fc3127a933d813cb8e6 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/schedulers.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
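
benchmark_launcher above layers three config sources with OmegaConf, later sources winning: the structured SegmentationConfig defaults, then the YAML file, then the command-line overrides. A toy reproduction of that precedence (TinyConfig is a made-up stand-in for SegmentationConfig):

    from dataclasses import dataclass
    from omegaconf import OmegaConf

    @dataclass
    class TinyConfig:
        lr: float = 0.01
        bs: int = 8

    base = OmegaConf.structured(TinyConfig)         # schema plus defaults
    from_yaml = OmegaConf.create({"lr": 0.001})     # stands in for OmegaConf.load(path)
    cli = OmegaConf.create({"bs": 32})              # stands in for parsed eval_args
    merged = OmegaConf.to_object(OmegaConf.merge(base, from_yaml, cli))
    assert merged.lr == 0.001 and merged.bs == 32
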
+ +from inspect import signature +import math +from typing import Any, Literal + +import torch +from packaging.version import Version +from torch.optim import lr_scheduler as torch_schedulers +from torch.optim.optimizer import Optimizer + +TORCH_VERSION = Version(torch.__version__) + + +def annealing_cos(start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + +def annealing_linear(start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + return (end - start) * pct + start + + +class WarmupOneCycleLR(torch_schedulers.LRScheduler): + def __init__( + self, + optimizer: Optimizer, + max_lr: float | None = None, + total_steps: int = 0, + warmup_iters: int = 0, + warmup_ratio: float = 0.0, + pct_start: float = 0.295, + anneal_strategy: Literal["cos", "linear"] = "cos", + base_momentum: float = 0.85, + max_momentum: float = 0.95, + div_factor: float = 25.0, + final_div_factor: float = 1000.0, + use_beta1: bool = True, + update_momentum: bool = True, + last_epoch: int = -1, + ): + """ + A variant of OneCycleLR with a warmup on top which potentially + replaces the first phase of the original OneCycleLR. + """ + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + self.max_lr = max_lr + self.min_point = float(pct_start * total_steps) + self.base_momentum = base_momentum + self.max_momentum = max_momentum + self.total_steps = total_steps + self.use_beta1 = use_beta1 + self.anneal_strategy = anneal_strategy + self.final_div_factor = final_div_factor + self.update_momentum = update_momentum + assert self.anneal_strategy in [ + "cos", + "linear", + ], f"Only cosine and linear-annealing strategy supported, got {self.anneal_strategy}" + assert total_steps > 0 + + # Initialize learning rate variables and momentum + for group in optimizer.param_groups: + if "initial_lr" not in group: + assert last_epoch == -1 + ml = group["lr"] + assert isinstance(ml, float) # makes sure that the variable is well updated + group["initial_lr"] = ml / div_factor + group["max_lr"] = ml + group["min_lr"] = group["initial_lr"] / final_div_factor + # initialize learning rate + group["lr"] = ml / final_div_factor if self.warmup_iters > 0 else group["initial_lr"] + if self.use_beta1: + group["betas"] = (self.max_momentum, *group["betas"][1:]) + elif self.update_momentum: + group["momentum"] = self.max_momentum + + super().__init__(optimizer, last_epoch) + + def _anneal_func(self, *args, **kwargs): + if self.anneal_strategy == "cos": + return annealing_cos(*args, **kwargs) + elif self.anneal_strategy == "linear": + return annealing_linear(*args, **kwargs) + + def _compute_lr_momentum(self, optimizer_param_group): + # torch.optim.lr_scheduler.LRScheduler does an initial + # step that sets self._step_count = 1 + step_num = (self._step_count - 1) if self.last_epoch != -1 else 0 + momentum = 0 + if step_num < self.warmup_iters: + if self.warmup_ratio: + k = (1 - step_num / self.warmup_iters) * (1 - self.warmup_ratio) + warmup_lr = optimizer_param_group["max_lr"] * (1 - k) + thelr = warmup_lr * (1 - step_num / self.total_steps) + else: + gmax = ( + optimizer_param_group["max_lr"] * (1 + math.cos(math.pi * step_num / float(self.total_steps))) / 2 + ) + thelr = optimizer_param_group["max_lr"] / self.final_div_factor + gmax * step_num / float( + self.warmup_iters + ) + else: + pct = (step_num - self.warmup_iters) / float(self.total_steps - self.warmup_iters) + 
step_num_to_use = step_num
+                momentum = self._anneal_func(
+                    self.base_momentum,
+                    self.max_momentum,
+                    pct,
+                )
+                if self.anneal_strategy == "cos":
+                    step_num_to_use += 1
+                thelr = self._anneal_func(
+                    optimizer_param_group["max_lr"],
+                    optimizer_param_group["min_lr"],
+                    step_num_to_use / float(self.total_steps),
+                )
+            return thelr, momentum
+
+    def get_lr(self):
+        """Compute the learning rate of each parameter group."""
+        if TORCH_VERSION >= Version("2.4.0"):
+            torch_schedulers._warn_get_lr_called_within_step(self)
+
+        lrs = []
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError(
+                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"  # noqa: UP032
+            )
+
+        for group in self.optimizer.param_groups:
+            computed_lr, computed_momentum = self._compute_lr_momentum(group)
+            lrs.append(computed_lr)  # type: ignore[possibly-undefined]
+            if self.use_beta1:
+                group["betas"] = (computed_momentum, *group["betas"][1:])  # type: ignore[possibly-undefined]
+            elif self.update_momentum:
+                group["momentum"] = computed_momentum  # type: ignore[possibly-undefined]
+
+        return lrs
+
+
+def build_scheduler(
+    scheduler_type: str,
+    optimizer: Optimizer,
+    lr: float,
+    total_iter: int,
+    constructor_kwargs: dict[str, Any],
+):
+    _kwargs = {}
+    _kwargs.update(**constructor_kwargs)
+    constructor_fn = SCHEDULERS_DICT[scheduler_type]
+    accepted_kwargs = signature(constructor_fn).parameters.keys()
+    keywords = list(constructor_kwargs.keys())
+    for key in keywords:
+        if key not in accepted_kwargs:
+            # ignore arguments that are not part of kwargs
+            _kwargs.pop(key)
+    if scheduler_type in ["OneCycleLR", "WarmupOneCycleLR", "WarmupMultiStepLR"]:
+        _kwargs.update(
+            dict(
+                max_lr=lr,
+                total_steps=total_iter,
+            )
+        )
+    elif scheduler_type in [
+        "ConstantLR",
+        "LinearLR",
+        "PolynomialLR",
+    ]:
+        # update _kwargs (not constructor_kwargs) so total_iters actually
+        # reaches the constructor call below
+        _kwargs.update(dict(total_iters=total_iter))
+
+    return constructor_fn(optimizer, **_kwargs)
+
+
+SCHEDULERS_DICT = {
+    "ConstantLR": torch_schedulers.ConstantLR,
+    "LinearLR": torch_schedulers.LinearLR,
+    "MultiStepLR": torch_schedulers.MultiStepLR,
+    "PolynomialLR": torch_schedulers.PolynomialLR,
+    "StepLR": torch_schedulers.StepLR,
+    "OneCycleLR": torch_schedulers.OneCycleLR,
+    "WarmupOneCycleLR": WarmupOneCycleLR,
+}
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/train.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb73d0646a2b84fa3f02ab5c26c25cbae0024339
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/train.py
@@ -0,0 +1,327 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
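
The cosine annealing helper in schedulers.py above can be sanity-checked in isolation; it is restated verbatim here so the snippet runs standalone:

    import math

    def annealing_cos(start, end, pct):
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    assert annealing_cos(1.0, 0.0, 0.0) == 1.0             # cycle start
    assert abs(annealing_cos(1.0, 0.0, 0.5) - 0.5) < 1e-9  # halfway point
    assert abs(annealing_cos(1.0, 0.0, 1.0)) < 1e-9        # cycle end
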
+ +from functools import partial +import logging +import numpy as np +import os +import random + +import torch +import torch.distributed as dist + +from dinov3.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader, make_dataset +import dinov3.distributed as distributed +from dinov3.eval.segmentation.eval import evaluate_segmentation_model +from dinov3.eval.segmentation.loss import MultiSegmentationLoss +from dinov3.eval.segmentation.metrics import SEGMENTATION_METRICS +from dinov3.eval.segmentation.models import build_segmentation_decoder +from dinov3.eval.segmentation.schedulers import build_scheduler +from dinov3.eval.segmentation.transforms import make_segmentation_eval_transforms, make_segmentation_train_transforms +from dinov3.logging import MetricLogger, SmoothedValue + +logger = logging.getLogger("dinov3") + + +class InfiniteDataloader: + def __init__(self, dataloader: torch.utils.data.DataLoader): + self.dataloader = dataloader + self.data_iterator = iter(dataloader) + self.sampler = dataloader.sampler + if not hasattr(self.sampler, "epoch"): + self.sampler.epoch = 0 # type: ignore + + def __iter__(self): + return self + + def __len__(self) -> int: + return len(self.dataloader) + + def __next__(self): + try: + data = next(self.data_iterator) + except StopIteration: + self.sampler.epoch += 1 + self.data_iterator = iter(self.dataloader) + data = next(self.data_iterator) + return data + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + The seed of each worker equals to num_worker * rank + worker_id + user_seed + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) + + +def validate( + segmentation_model: torch.nn.Module, + val_dataloader, + device, + autocast_dtype, + eval_res, + eval_stride, + decoder_head_type, + num_classes, + global_step, + metric_to_save, + current_best_metric_to_save_value, +): + new_metric_values_dict = evaluate_segmentation_model( + segmentation_model, + val_dataloader, + device, + eval_res, + eval_stride, + decoder_head_type, + num_classes, + autocast_dtype, + ) + logger.info(f"Step {global_step}: {new_metric_values_dict}") + # `segmentation_model` is a module list of [backbone, decoder] + # Only put the head in train mode + segmentation_model.module.segmentation_model[1].train() + is_better = False + if new_metric_values_dict[metric_to_save] > current_best_metric_to_save_value: + is_better = True + return is_better, new_metric_values_dict + + +def train_step( + segmentation_model: torch.nn.Module, + batch, + device, + scaler, + optimizer, + optimizer_gradient_clip, + scheduler, + criterion, + model_dtype, + global_step, +): + # a) load batch + batch_img, (_, gt) = batch + batch_img = batch_img.to(device) # B x C x h x w + gt = gt.to(device) # B x (num_classes if multilabel) x h x w + optimizer.zero_grad(set_to_none=True) + + # b) forward pass + with torch.autocast("cuda", dtype=model_dtype, enabled=True if model_dtype is not None else False): + pred = segmentation_model(batch_img) # B x num_classes x h x w + gt = torch.squeeze(gt).long() # Adapt gt dimension to enable loss calculation + + # c) compute loss + if gt.shape[-2:] != pred.shape[-2:]: + pred = torch.nn.functional.interpolate(input=pred, size=gt.shape[-2:], mode="bilinear", 
align_corners=False) + loss = criterion(pred, gt) + + # d) optimization + if scaler is not None: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(segmentation_model.module.parameters(), optimizer_gradient_clip) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(segmentation_model.module.parameters(), optimizer_gradient_clip) + optimizer.step() + + if global_step > 0: # inheritance from old mmcv code + scheduler.step() + + return loss + + +def train_segmentation( + backbone, + config, +): + assert config.decoder_head.type == "linear", "Only linear head is supported for training" + # 1- load the segmentation decoder + logger.info("Initializing the segmentation model") + segmentation_model = build_segmentation_decoder( + backbone, + config.decoder_head.backbone_out_layers, + "linear", + num_classes=config.decoder_head.num_classes, + autocast_dtype=config.model_dtype.autocast_dtype, + ) + global_device = distributed.get_rank() + local_device = torch.cuda.current_device() + segmentation_model = torch.nn.parallel.DistributedDataParallel( + segmentation_model.to(local_device), device_ids=[local_device] + ) # should be local rank + model_parameters = filter(lambda p: p.requires_grad, segmentation_model.parameters()) + logger.info(f"Number of trainable parameters: {sum(p.numel() for p in model_parameters)}") + + # 2- create data transforms + dataloaders + train_transforms = make_segmentation_train_transforms( + img_size=config.transforms.train.img_size, + random_img_size_ratio_range=config.transforms.train.random_img_size_ratio_range, + crop_size=config.transforms.train.crop_size, + flip_prob=config.transforms.train.flip_prob, + reduce_zero_label=config.eval.reduce_zero_label, + ) + val_transforms = make_segmentation_eval_transforms( + img_size=config.transforms.eval.img_size, + inference_mode=config.eval.mode, + ) + + train_dataset = DatasetWithEnumeratedTargets( + make_dataset( + dataset_str=f"{config.datasets.train}:root={config.datasets.root}", + transforms=train_transforms, + ) + ) + train_sampler_type = None + if distributed.is_enabled(): + train_sampler_type = SamplerType.DISTRIBUTED + init_fn = partial( + worker_init_fn, num_workers=config.num_workers, rank=global_device, seed=config.seed + global_device + ) + train_dataloader = InfiniteDataloader( + make_data_loader( + dataset=train_dataset, + batch_size=config.bs, + num_workers=config.num_workers, + sampler_type=train_sampler_type, + shuffle=True, + persistent_workers=False, + worker_init_fn=init_fn, + ) + ) + + val_dataset = DatasetWithEnumeratedTargets( + make_dataset( + dataset_str=f"{config.datasets.val}:root={config.datasets.root}", + transforms=val_transforms, + ) + ) + val_sampler_type = None + if distributed.is_enabled(): + val_sampler_type = SamplerType.DISTRIBUTED + val_dataloader = make_data_loader( + dataset=val_dataset, + batch_size=1, + num_workers=config.num_workers, + sampler_type=val_sampler_type, + drop_last=False, + shuffle=False, + persistent_workers=True, + ) + + # 3- define and create scaler, optimizer, scheduler, loss + scaler = None + if config.model_dtype.autocast_dtype is not None: + scaler = torch.amp.GradScaler("cuda") + + optimizer = torch.optim.AdamW( + [ + { + "params": filter(lambda p: p.requires_grad, segmentation_model.parameters()), + "lr": config.optimizer.lr, + "betas": (config.optimizer.beta1, config.optimizer.beta2), + "weight_decay": config.optimizer.weight_decay, + } + ] + ) + scheduler = 
build_scheduler( + config.scheduler.type, + optimizer=optimizer, + lr=config.optimizer.lr, + total_iter=config.scheduler.total_iter, + constructor_kwargs=config.scheduler.constructor_kwargs, + ) + criterion = MultiSegmentationLoss( + diceloss_weight=config.train.diceloss_weight, celoss_weight=config.train.celoss_weight + ) + total_iter = config.scheduler.total_iter + global_step = 0 + global_best_metric_values = {metric: 0.0 for metric in SEGMENTATION_METRICS} + + # 5- train the model + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("loss", SmoothedValue(window_size=4, fmt="{value:.3f}")) + for batch in metric_logger.log_every( + train_dataloader, + 50, + header="Train: ", + start_iteration=global_step, + n_iterations=total_iter, + ): + if global_step >= total_iter: + break + loss = train_step( + segmentation_model, + batch, + local_device, + scaler, + optimizer, + config.optimizer.gradient_clip, + scheduler, + criterion, + config.model_dtype.autocast_dtype, + global_step, + ) + global_step += 1 + metric_logger.update(loss=loss) + if global_step % config.eval.eval_interval == 0: + dist.barrier() + is_better, best_metric_values_dict = validate( + segmentation_model, + val_dataloader, + local_device, + config.model_dtype.autocast_dtype, + config.eval.crop_size, + config.eval.stride, + config.decoder_head.type, + config.decoder_head.num_classes, + global_step, + config.metric_to_save, + global_best_metric_values[config.metric_to_save], + ) + if is_better: + logger.info(f"New best metrics at Step {global_step}: {best_metric_values_dict}") + global_best_metric_values = best_metric_values_dict + + # one last validation only if the number of total iterations is NOT divisible by eval interval: + if total_iter % config.eval.eval_interval: + is_better, best_metric_values_dict = validate( + segmentation_model, + val_dataloader, + local_device, + config.model_dtype.autocast_dtype, + config.eval.crop_size, + config.eval.stride, + config.decoder_head.type, + config.decoder_head.num_classes, + global_step, + config.metric_to_save, + global_best_metric_values[config.metric_to_save], + ) + if is_better: + logger.info(f"New best metrics at Step {global_step}: {best_metric_values_dict}") + global_best_metric_values = best_metric_values_dict + logger.info("Training is done!") + # segmentation_model is a module list of [backbone, decoder] + # Only save the decoder head + torch.save( + { + "model": {k: v for k, v in segmentation_model.module.state_dict().items() if "segmentation_model.1" in k}, + "optimizer": optimizer.state_dict(), + }, + os.path.join(config.output_dir, "model_final.pth"), + ) + logger.info(f"Final best metrics: {global_best_metric_values}") + return global_best_metric_values diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/transforms.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..2b877177f6721902f94c324ec9ef286ccc9b9d15 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/segmentation/transforms.py @@ -0,0 +1,474 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
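
The InfiniteDataloader in train.py above turns a finite, epoch-based loader into an endless batch stream, bumping sampler.epoch on each wrap so a DistributedSampler would reshuffle. A dependency-free model of that wrap-around behavior (names are illustrative, not this repo's API):

    class _Sampler:
        epoch = 0

    class InfiniteIter:
        def __init__(self, data):
            self.data = data
            self.sampler = _Sampler()
            self.it = iter(self.data)

        def __next__(self):
            try:
                return next(self.it)
            except StopIteration:
                self.sampler.epoch += 1          # triggers reshuffling upstream
                self.it = iter(self.data)
                return next(self.it)

    stream = InfiniteIter([1, 2, 3])
    assert [next(stream) for _ in range(7)] == [1, 2, 3, 1, 2, 3, 1]
    assert stream.sampler.epoch == 2
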
+
+import numpy as np
+from PIL import Image
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms as T
+from torchvision.transforms import functional as Fv
+from torchvision.transforms import v2
+from torchvision.tv_tensors import Mask
+
+from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, make_normalize_transform
+from dinov3.eval.segmentation.metrics import preprocess_nonzero_labels
+
+
+class PhotoMetricDistortion(torch.nn.Module):
+    """Apply photometric distortions to an image sequentially; every transformation
+    is applied with a probability of 0.5.
+    1. random brightness
+    2. random contrast (mode 0)
+    3. random saturation
+    4. random hue
+    5. random contrast (mode 1)
+    Args:
+        brightness_delta (int): delta of brightness.
+        contrast_range (tuple): range of contrast.
+        saturation_range (tuple): range of saturation.
+        hue_range (tuple): range of hue.
+    """
+
+    def __init__(
+        self,
+        brightness_delta: int = 32,
+        contrast_range: Sequence[float] = (0.5, 1.5),
+        saturation_range: Sequence[float] = (0.5, 1.5),
+        hue_range: Sequence[float] = (-0.5, 0.5),
+    ):
+        super().__init__()
+        self.brightness_delta = brightness_delta
+        self.contrast_lower, self.contrast_upper = contrast_range
+        self.saturation_lower, self.saturation_upper = saturation_range
+        self.hue_lower, self.hue_upper = hue_range
+
+    def convert(self, img: np.ndarray, alpha: float = 1.0, beta: float = 0.0) -> np.ndarray:
+        """Multiply by alpha, add beta, and clip to [0, 255]."""
+        img = img.astype(np.float32) * alpha + beta
+        img = np.clip(img, 0, 255)
+        return img.astype(np.uint8)
+
+    def brightness(self, img: np.ndarray) -> np.ndarray:
+        if np.random.randint(2):
+            return self.convert(img, beta=np.random.uniform(-self.brightness_delta, self.brightness_delta))
+        return img
+
+    def contrast(self, img: np.ndarray) -> np.ndarray:
+        if np.random.randint(2):
+            return self.convert(img, alpha=np.random.uniform(self.contrast_lower, self.contrast_upper))
+        return img
+
+    def saturation(self, img: np.ndarray) -> np.ndarray:
+        if np.random.randint(2):
+            saturation_factor = np.random.uniform(self.saturation_lower, self.saturation_upper)
+            img_tensor = torch.tensor(img.astype(np.uint8)).permute((2, 0, 1))
+            img_tensor = Fv.adjust_saturation(img_tensor, saturation_factor)
+            img = img_tensor.permute((1, 2, 0)).numpy()
+        return img
+
+    def hue(self, img: np.ndarray) -> np.ndarray:
+        if np.random.randint(2):
+            hue_factor = np.random.uniform(self.hue_lower, self.hue_upper)
+            img_tensor = torch.tensor(img.astype(np.uint8)).permute((2, 0, 1))
+            img_tensor = Fv.adjust_hue(img_tensor, hue_factor)
+            img = img_tensor.permute((1, 2, 0)).numpy()
+        return img
+
+    def forward(self, img, label) -> Tuple[torch.Tensor, Any]:
+        """Transform function to perform photometric distortion on images."""
+        # Operations need numpy arrays
+        img = img.permute((1, 2, 0)).numpy()
+        # random brightness
+        img = self.brightness(img)
+        # mode == 0 --> do random contrast first
+        # mode == 1 --> do random contrast last
+        mode = np.random.randint(2)
+        if mode == 1:
+            img = self.contrast(img)
+        # random saturation
+        img = self.saturation(img)
+        # random hue
+        img = self.hue(img)
+        # random contrast
+        if mode == 0:
+            img = self.contrast(img)
+        return torch.tensor(img.astype(np.float32)).permute((2, 0, 1)), label
+
+
+class ReduceZeroLabel(torch.nn.Module):
+    """Operation on the labels when class 0 is to
be ignored.""" + + def __init__(self, ignore_index=255): + super().__init__() + self.ignore_index = ignore_index + + def forward(self, img, label): + label = preprocess_nonzero_labels(label, ignore_index=self.ignore_index) + return img, label + + +class MaybeApplyImageLabel(torch.nn.Module): + """Apply a given operation on both image and label + given a probability threshold. + Args: + _transform (torchvision.transforms): type of transform to apply. + Since this transform is applied on both image and label, + it has to be deterministic (e.g. horizontal flip, non-random crop). + _threshold (float): probability of applying the above transform.""" + + def __init__(self, transform, threshold: float = 0.5): + super().__init__() + self._transform = transform + self._threshold = threshold + + def __call__(self, img, label): + x = np.random.rand() + if x < self._threshold: + return self._transform(img), self._transform(label) + return img, label + + +class FixedSideResize: + """Resize an image, given a fixed value for the small side. + Args: + small_size (int): small size to resize an image to. + example: if small_size = 512, an image of size (300, 400) will be resized to (512, 683) + image_interpolation (T.InterpolationMode): Interpolation mode when resizing a given image. + label_interpolation (T.InterpolationMode): Interpolation mode when resizing a given label. + random_img_size_ratio_range (tuple(min, max)): If used, for a given image, a random ratio + between the range is used to multiply to `small_size` for resizing + inference_mode (str): Dataset inference mode. + If value is "whole", resize both image and label for a single prediction on the resized image. + If value is "slide", resize image, do sliding inference on it, then scale it back to the + original image size for final prediction - the label doesn't need to be resized. 
+ Returns: + image, label (PIL.Image, tensor.Tensor): resized image and label + """ + + def __init__( + self, + small_size, + image_interpolation, + label_interpolation, + random_img_size_ratio_range=None, + inference_mode="whole", + use_tta=False, + tta_img_size_ratio_range=[1.0], + ): + self.small_size = small_size + self.image_interpolation = image_interpolation + self.label_interpolation = label_interpolation + self.random_img_size_ratio_range = random_img_size_ratio_range + self.inference_mode = inference_mode + self.use_tta = use_tta + self.tta_img_size_ratio_range = tta_img_size_ratio_range + + def _random_sample_ratio(self): + min_ratio, max_ratio = self.random_img_size_ratio_range + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + return int(self.small_size * ratio) + + def _resize(self, img, label, small_size): + init_width, init_height = img.size + if init_height > init_width: + new_width = small_size + new_height = int(small_size * init_height / init_width + 0.5) + else: + new_height = small_size + new_width = int(small_size * init_width / init_height + 0.5) + + img = T.Resize(size=(new_height, new_width), interpolation=self.image_interpolation)(img) + if self.inference_mode == "whole": + label = T.Resize(size=(new_height, new_width), interpolation=self.label_interpolation)(label) + return img, label + + def __call__(self, img, label): + if not self.use_tta: + small_size = self.small_size + if self.random_img_size_ratio_range: + small_size = self._random_sample_ratio() + return self._resize(img, label, small_size) + + tta_img_list = [] # Used only if TTA + for tta_ratio in self.tta_img_size_ratio_range: + tta_size = int(self.small_size * tta_ratio) + if tta_ratio < 1: + tta_size = int(np.ceil(tta_size / 32)) * 32 + tta_img, _ = self._resize(img, label, tta_size) + tta_img_list.append(tta_img) + return tta_img_list, label + + +class ResizeV2: + """ + Resize both image and label using different interpolation modes. + """ + + def __init__(self, size, image_interpolation, label_interpolation): + self.size = size + self.image_interpolation = image_interpolation + self.label_interpolation = label_interpolation + + def __call__(self, img, label): + img = T.Resize(size=self.size, interpolation=self.image_interpolation)(img) + label = T.Resize(size=self.size, interpolation=self.label_interpolation)(label) + return img, label + + +class CustomResize(torch.nn.Module): + def __init__( + self, + img_resize, + image_interpolation, + label_interpolation, + random_img_size_ratio_range=None, + inference_mode="whole", + use_tta=False, + tta_img_size_ratio_range=[1.0], + ): + super().__init__() + if isinstance(img_resize, int): + self.resize_function = FixedSideResize( + small_size=img_resize, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + random_img_size_ratio_range=random_img_size_ratio_range, + inference_mode=inference_mode, + use_tta=use_tta, + tta_img_size_ratio_range=tta_img_size_ratio_range, + ) + else: + self.resize_function = ResizeV2( + size=img_resize, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + ) + + def forward(self, img, label): + return self.resize_function(img, label) + + +class RandomCropWithLabel(torch.nn.Module): + """Randomly crop the image & segmentation label. + Args: + crop_size (tuple(h, w)): Expected size after cropping. + cat_max_ratio (float): The maximum ratio that a single category could + occupy in the cropped image. Default value is 0.75. 
+ ignore_index (int): Index to ignore when measuring the category ratio + in a cropped image + Returns: + cropped_img (torch.Tensor), Optional[crop_bbox](tuple) + """ + + def __init__(self, crop_size, cat_max_ratio=0.75, ignore_index=255): + super().__init__() + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[-2] - self.crop_size[0], 0) + margin_w = max(img.shape[-1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop given a crop bounding box""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[:, crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def forward(self, img, label): + """Find an adequate crop for a given image and crop it""" + # Create a random crop_bbox + new_crop_bbox = self.get_crop_bbox(img) + if self.cat_max_ratio < 1.0: + # Check that the ratio of label_counts / nb_pixels created + # with the random crop_bbox is under `cat_max_ratio` + # Repeat until 10 times to find a good crop_bbox + for _ in range(10): + seg_temp = self.crop(label, new_crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio: + break + new_crop_bbox = self.get_crop_bbox(img) + + return self.crop(img, new_crop_bbox), self.crop(label, new_crop_bbox) + + +class HorizontalFlipAug(torch.nn.Module): + def forward(self, img_list, label): + """Call function to apply test time augment transforms on results. + + Args: + img (PIL image | torch.Tensor | List[PIL image]): Data to transform. + + Returns: + list: A list of augmented data. 
+ """ + if isinstance(img_list, Image.Image): + img_list = [img_list] + augmented_img_list = [Fv.hflip(img) for img in img_list] + img_list.extend(augmented_img_list) + return img_list, label + + def inverse(self, stacked_left_right_pair): + pre_aug_batch_size = len(stacked_left_right_pair) // 2 + assert pre_aug_batch_size * 2 == len(stacked_left_right_pair) + orig_img_list = stacked_left_right_pair[:pre_aug_batch_size] + orig_img_list.extend([Fv.hflip(img) for img in stacked_left_right_pair[pre_aug_batch_size:]]) + return orig_img_list + + +class PadTensor(torch.nn.Module): + """Pad a given tensor to the desired shape""" + + def __init__(self, pad_shape=[512, 512], img_pad_value=0, label_pad_value=255): + super().__init__() + self.pad_shape = pad_shape + self.img_pad_value = img_pad_value + self.label_pad_value = label_pad_value + + def forward(self, img, label): + h, w = img.shape[-2:] + new_h, new_w = self.pad_shape[0] - h, self.pad_shape[1] - w + img = F.pad(input=img, pad=(0, new_w, 0, new_h), mode="constant", value=self.img_pad_value) + label = F.pad(input=label, pad=(0, new_w, 0, new_h), mode="constant", value=self.label_pad_value) + return img, label + + +class NormalizeImage(torch.nn.Module): + def __init__(self, mean, std): + super().__init__() + self.normalize_function = make_normalize_transform(mean=mean, std=std) + + def forward(self, img, label): + return self.normalize_function(img.float()), label + + +class TransformImages(torch.nn.Module): + """Given a list of operations, apply them on a tensor or a list of transforms. + Transforms apply on images. Always return a list of tensors for coherent output format. + Args: + _transform (List[torchvision.transforms]): transforms to apply. + """ + + def __init__(self, transforms): + super().__init__() + self._transforms = transforms + + def forward(self, img, label): + if isinstance(img, (torch.Tensor, Image.Image)): + img = [img] + for transform in self._transforms: + # only apply transforms on the augmented images + img = [transform(im, label)[0] for im in img] + return img, label + + +class MaskToTensor(torch.nn.Module): + """Read segmentation mask from arrays or PIL images""" + + def forward(self, img, label): + if isinstance(label, np.ndarray): + return img, Mask(label).permute(2, 0, 1) + return img, Mask(label) + + +def make_segmentation_train_transforms( + *, + img_size: Optional[Union[List[int], int]] = None, + image_interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + label_interpolation: T.InterpolationMode = T.InterpolationMode.NEAREST, + random_img_size_ratio_range: Optional[List[float]] = None, + crop_size: Optional[Tuple[int]] = None, + flip_prob: float = 0.0, + reduce_zero_label: bool = False, + mean: Sequence[float] = [mean * 255 for mean in IMAGENET_DEFAULT_MEAN], + std: Sequence[float] = [std * 255 for std in IMAGENET_DEFAULT_STD], +): + # Label conversion to tensor + transforms_list = [MaskToTensor()] # type: List[Any] + # Resizing + if img_size is not None: + transforms_list.append( + CustomResize( + img_resize=img_size, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + inference_mode="whole", # when training, always resize image + label + random_img_size_ratio_range=random_img_size_ratio_range, + ) + ) + # Conversion to torch.Tensor + transforms_list.extend([v2.PILToTensor()]) + + # Reducing zero labels + if reduce_zero_label: + transforms_list.append(ReduceZeroLabel()) + + # Random crop + if crop_size: + 
transforms_list.append(RandomCropWithLabel(crop_size=crop_size)) + + # Rest of the image and label-specific transforms + transforms_list.extend( + [ + MaybeApplyImageLabel(transform=Fv.hflip, threshold=flip_prob), + PhotoMetricDistortion(), + NormalizeImage(mean=mean, std=std), + ] + ) + + # Pad if cropping was done previously + if crop_size: + transforms_list.append(PadTensor(pad_shape=crop_size, img_pad_value=0, label_pad_value=255)) + + return v2.Compose(transforms_list) + + +def make_segmentation_eval_transforms( + *, + img_size: Optional[Union[List[int], int]] = None, + inference_mode: str = "whole", + image_interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + label_interpolation: T.InterpolationMode = T.InterpolationMode.NEAREST, + use_tta: bool = False, + tta_ratios: Sequence[float] = [1.0], + mean: Sequence[float] = [mean * 255 for mean in IMAGENET_DEFAULT_MEAN], + std: Sequence[float] = [std * 255 for std in IMAGENET_DEFAULT_STD], +): + # Label conversion to tensor + transforms_list = [MaskToTensor()] # type: List[Any] + # Optional resizing + if img_size is not None: + transforms_list.append( + CustomResize( + img_resize=img_size, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + inference_mode=inference_mode, + use_tta=use_tta, + tta_img_size_ratio_range=tta_ratios, + ) + ) + + if use_tta: + transforms_list.append(HorizontalFlipAug()) + # Always return a list of tensors for prediction at evaluation time + transforms_list.append(TransformImages(transforms=[v2.PILToTensor(), NormalizeImage(mean=mean, std=std)])) + + return v2.Compose(transforms_list) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/setup.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..be4e20fced3e5a35e68a06eca787e1207c665846 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/setup.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from dataclasses import dataclass +from typing import Tuple, TypedDict + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn + +from dinov3.configs import DinoV3SetupArgs, setup_config +from dinov3.models import build_model_for_eval + + +@dataclass +class ModelConfig: + # Loading a local file + config_file: str | None = None + pretrained_weights: str | None = None + # Loading a DINOv3 or v2 model from torch.hub + dino_hub: str | None = None + + +class BaseModelContext(TypedDict): + """ + An object that contains the context of a model (autocast, description, ...) 
+ """ + + autocast_dtype: torch.dtype # default could be torch.float + + +def load_model_and_context(model_config: ModelConfig, output_dir: str) -> tuple[torch.nn.Module, BaseModelContext]: + if model_config.dino_hub is not None: + assert model_config.pretrained_weights is None and model_config.config_file is None + if "dinov3" in model_config.dino_hub: + repo = "dinov3" + elif "dinov2" in model_config.dino_hub: + repo = "dinov2" + else: + raise ValueError + model = torch.hub.load(f"facebookresearch/{repo}", model_config.dino_hub) + base_model_context = BaseModelContext(autocast_dtype=torch.float) + else: + model, base_model_context = setup_and_build_model( + config_file=model_config.config_file, + pretrained_weights=model_config.pretrained_weights, + output_dir=output_dir, + ) + + model.cuda() + model.eval() + return model, base_model_context + + +def get_autocast_dtype(config): + teacher_dtype_str = config.compute_precision.param_dtype + if teacher_dtype_str == "bf16": + return torch.bfloat16 + else: + return torch.float + + +def setup_and_build_model( + config_file: str, + pretrained_weights: str | None = None, + shard_unsharded_model: bool = False, + output_dir: str = "", + opts: list | None = None, + **ignored_kwargs, +) -> Tuple[nn.Module, BaseModelContext]: + cudnn.benchmark = True + del ignored_kwargs + setup_args = DinoV3SetupArgs( + config_file=config_file, + pretrained_weights=pretrained_weights, + shard_unsharded_model=shard_unsharded_model, + output_dir=output_dir, + opts=opts or [], + ) + config = setup_config(setup_args, strict_cfg=False) + model = build_model_for_eval(config, setup_args.pretrained_weights) + autocast_dtype = get_autocast_dtype(config) + return model, BaseModelContext(autocast_dtype=autocast_dtype) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/ac_comp_parallelize.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/ac_comp_parallelize.py new file mode 100644 index 0000000000000000000000000000000000000000..42617b9ea7d59bddd983078accd827b7c3b54e35 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/ac_comp_parallelize.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ +import logging +from contextlib import suppress +from functools import partial + +import torch +import torch.nn as nn +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import register_fsdp_forward_method +from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState +from torch.utils.checkpoint import create_selective_checkpoint_contexts + +logger = logging.getLogger("dinov3") + + +def map_modules_and_blocks(models: list[nn.Module], callable) -> None: + for m in models: + for block_id, block in enumerate(m.blocks): + m.blocks[block_id] = callable(block, is_backbone_block=True) + + +def ac_compile_parallelize_and_init( + clip_model: nn.Module, + world_mesh: DeviceMesh, + do_compile: bool, + use_activation_checkpointing: bool, + use_full_activation_checkpointing: bool, + use_cuda_graphs: bool, + param_dtype_str: str = "bf16", + reduce_dtype_str: str = "fp32", +) -> None: + """ + Order of the wrappers: + 1/ Activation checkpointing on blocks + 2/ Compile blocks + 3/ FSDP blocks + global model + """ + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + + # 1/ AC on blocks + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + ) + + trained_models = [] + inference_only_models = [] + for model in [clip_model.visual_model, clip_model.text_model]: + if not model.freeze_backbone: + trained_models.append(model.backbone) + else: + inference_only_models.append(model.backbone) + trained_models.append(model.head) + + for model in trained_models: + if use_activation_checkpointing: + if use_full_activation_checkpointing: + _checkpointing_wrapper = checkpoint_wrapper + logger.info( + "using selective checkpointing on backbone with full checkpointing policy" + ) + else: + _save_list = [ + # mm + torch.ops.aten.mm.default, + torch.ops.aten._scaled_mm.default, + # attentions + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + with suppress( + AttributeError + ): # ignore exception if op is missing (old xFormers) + _save_list.append(torch.ops.xformers_flash3.flash_fwd.default) + _checkpointing_wrapper = partial( + checkpoint_wrapper, + context_fn=partial( + create_selective_checkpoint_contexts, _save_list + ), + preserve_rng_state=True, + ) + logger.info( + "using selective checkpointing on backbone with selective policy" + ) + for i, b in enumerate(model.blocks): + if not isinstance(b, nn.Identity): + model.blocks[i] = _checkpointing_wrapper(b) + + # 2/ Compile blocks + def compile_block(block: nn.Module) -> nn.Module: + if do_compile: + if use_cuda_graphs: + block.compile( + fullgraph=True, dynamic=False, options={"triton.cudagraphs": True} + ) + else: + block.compile() + return block + + def compile_backbone(backbone: nn.Module) -> nn.Module: + for block_id, block in enumerate(backbone.blocks): + backbone.blocks[block_id] = compile_block(block) + + def compile_head(head: nn.Module) -> nn.Module: + for block_id in range(head.num_blocks): + head.blocks[block_id] = compile_block(head.blocks[block_id]) + if do_compile and isinstance(head.linear_projection, nn.Linear): + head.linear_projection.compile() + + compile_backbone(clip_model.visual_model.backbone) + compile_backbone(clip_model.text_model.backbone) + compile_head(clip_model.visual_model.head) + 
compile_head(clip_model.text_model.head) + DTYPE_MAP = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + mp_policy = MixedPrecisionPolicy( + param_dtype=DTYPE_MAP[param_dtype_str], + reduce_dtype=DTYPE_MAP[reduce_dtype_str], + ) + fsdp_config = {"mesh": world_mesh["dp"], "mp_policy": mp_policy} + + for block in clip_model.visual_model.backbone.blocks: + fully_shard(block, **fsdp_config, reshard_after_forward=True) + for i in range(clip_model.visual_model.head.num_blocks): + fully_shard( + clip_model.visual_model.head.blocks[i], + **fsdp_config, + reshard_after_forward=True, + ) + fully_shard( + clip_model.visual_model.head.linear_projection, + **fsdp_config, + reshard_after_forward=True, + ) + fully_shard( + clip_model.visual_model.backbone, **fsdp_config, reshard_after_forward=True + ) + fully_shard(clip_model.visual_model.head, **fsdp_config, reshard_after_forward=True) + register_fsdp_forward_method( + clip_model.visual_model.backbone, "get_intermediate_layers" + ) + for block in clip_model.text_model.backbone.blocks: + fully_shard(block, **fsdp_config, reshard_after_forward=True) + for i in range(clip_model.text_model.head.num_blocks): + fully_shard( + clip_model.text_model.head.blocks[i], + **fsdp_config, + reshard_after_forward=True, + ) + fully_shard( + clip_model.text_model.head.linear_projection, + **fsdp_config, + reshard_after_forward=True, + ) + fully_shard( + clip_model.text_model.backbone, **fsdp_config, reshard_after_forward=True + ) + fully_shard(clip_model.text_model.head, **fsdp_config, reshard_after_forward=True) + + clip_model.to_empty(device="cuda") + clip_model.init_weights() + + for model in inference_only_models: + fsdp_state: FSDPState = model._get_fsdp_state() + if not fsdp_state._fsdp_param_group: + continue + mi = fsdp_state._fsdp_param_group.post_forward_mesh_info + fsdp_state._lazy_init() + fsdp_state._fsdp_param_group.post_forward_mesh_info = mi diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/build_dinotxt.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/build_dinotxt.py new file mode 100644 index 0000000000000000000000000000000000000000..0c879bd7a837bac5b065512a2dc36f0692b4867d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/build_dinotxt.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
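The function above applies FSDP2's `fully_shard` bottom-up: each transformer block is sharded first, then the enclosing backbone and head modules, so parameters are gathered per block rather than all at once. A minimal sketch of that wrapping order on toy modules (assumes `torch.distributed` is already initialized, e.g. under `torchrun`; the blocks are stand-ins):

```python
import torch
import torch.distributed
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda", mesh_shape=(torch.distributed.get_world_size(),), mesh_dim_names=("dp",)
)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(4)])  # stand-in blocks
for block in model:
    fully_shard(block, mesh=mesh["dp"], mp_policy=mp_policy, reshard_after_forward=True)
fully_shard(model, mesh=mesh["dp"], mp_policy=mp_policy)  # root module wrapped last
```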
+ +import logging +from pathlib import Path +from typing import Any, Dict, List + +import dinov3.distributed as distributed +import torch +from dinov3.checkpointer import load_checkpoint, register_dont_save_hooks +from dinov3.data import ( + make_classification_eval_transform, + make_classification_train_transform, +) +from torch.distributed import DeviceMesh +from torch.distributed._composable.replicate import replicate +from torch.distributed.device_mesh import init_device_mesh + +from dinov3.eval.text.tokenizer import get_tokenizer + +from dinov3.eval.text.ac_comp_parallelize import ac_compile_parallelize_and_init +from dinov3.eval.text.dinotxt_model import DINOTxt, DINOTxtConfig + +logger = logging.getLogger("dinov3") + + +# This allows us to load OSS DINOv2 models from pretrained weights using DINOv3 ViT +def rename_register_token( + chkpt: Dict[str, Any], n_register_tokens: int, embed_dim: int +) -> Dict[str, Any]: + if "register_tokens" in chkpt: + chkpt["storage_tokens"] = chkpt["register_tokens"] + del chkpt["register_tokens"] + else: + chkpt["storage_tokens"] = torch.zeros(1, n_register_tokens, embed_dim) + return chkpt + + +def load_backbone_checkpoint( + model: torch.nn.Module, + checkpoint_path: str, + world_mesh: DeviceMesh, + skip_load_prefixes: List[str] = [], +): + if not Path(checkpoint_path).is_dir(): # PyTorch standard checkpoint + logger.info(f"Loading pretrained weights from {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location="cpu") + if "register_tokens" in state_dict: + state_dict["storage_tokens"] = state_dict["register_tokens"] + del state_dict["register_tokens"] + if "teacher" in state_dict: + state_dict = state_dict["teacher"] + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + state_dict = { + k: ( + torch.distributed.tensor.distribute_tensor( + v, world_mesh, src_data_rank=None + ) + if not k.startswith("rope_embed.periods") and "qkv.bias_mask" not in k + else v + ) + for k, v in state_dict.items() + } + model.load_state_dict( + { + k: v + for k, v in state_dict.items() + if not any(k.startswith(prefix) for prefix in skip_load_prefixes) + } + ) + else: # DCP checkpoint + load_checkpoint(checkpoint_path, model) + + +def compile_parallelize_and_init( + model: torch.nn.Module, + model_config: DINOTxtConfig, + world_mesh: DeviceMesh, + use_fsdp: bool, + do_compile: bool, + use_ac: bool, + use_full_ac: bool, + use_cuda_graphs: bool, + param_dtype_str: str = "bf16", + reduce_dtype_str: str = "fp32", +) -> None: + if not use_fsdp: + logger.info("Wrap in DDP, compile and initialize the model") + if do_compile: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + replicate(model, device_mesh=world_mesh, bucket_cap_mb=100) + if do_compile: + model.compile() + model = model.to_empty(device="cuda") + model.init_weights() + else: + logger.info("Wrap in FSDP, compile and initialize the model") + ac_compile_parallelize_and_init( + model, + world_mesh, + do_compile, + use_ac, + use_full_ac, + use_cuda_graphs, + param_dtype_str, + reduce_dtype_str, + ) + if model.visual_model.freeze_backbone: + vision_backbone_pretrained_weights = ( + model_config.vision_backbone_pretrained_weights + ) + logger.info( + f"Loading visual backbone pretrained-weights from: {vision_backbone_pretrained_weights}" + ) + load_backbone_checkpoint( + model.visual_model.backbone, + vision_backbone_pretrained_weights, + world_mesh, + ["dino_loss", "ibot_patch_loss", 
"dino_head", "ibot_head"], + ) + model.visual_model.backbone = model.visual_model.backbone.eval() + for param in model.visual_model.backbone.parameters(): + param.requires_grad = False + logger.info("Froze visual backbone!") + register_dont_save_hooks( + model, + dont_save=[ + k + for k, _ in model.state_dict().items() + if k.startswith("visual_model.backbone") + ], + ) + if model.text_model.freeze_backbone: + text_backbone_pretrained_weights = model_config.text_backbone_pretrained_weights + logger.info( + f"Loading text backbone pretrained-weights from: {text_backbone_pretrained_weights}" + ) + load_backbone_checkpoint( + model.text_model.backbone, text_backbone_pretrained_weights, world_mesh + ) + logger.info("Assigned pretrained-weights to text backbone..") + logger.info("Freezing text backbone") + model.text_model.backbone = model.text_model.backbone.eval() + for param in model.text_model.backbone.parameters(): + param.requires_grad = False + logger.info("Froze text backbone!") + register_dont_save_hooks( + model, + dont_save=[ + k + for k, _ in model.state_dict().items() + if k.startswith("text_model.backbone") + ], + ) + + +def build_model_and_tokenizer( + model_config: DINOTxtConfig, + use_fsdp: bool = True, + do_compile: bool = False, + use_ac: bool = True, + use_full_ac: bool = False, + use_cuda_graphs: bool = False, + param_dtype_str: str = "bf16", + reduce_dtype_str: str = "fp32", +): + with torch.device("meta"): + model = DINOTxt(model_config=model_config, device="meta") + world_mesh = init_device_mesh( + "cuda", + mesh_shape=(distributed.get_world_size(),), + mesh_dim_names=("dp",), + ) + compile_parallelize_and_init( + model, + model_config, + world_mesh, + use_fsdp, + do_compile, + use_ac, + use_full_ac, + use_cuda_graphs, + param_dtype_str, + reduce_dtype_str, + ) + tokenizer = get_tokenizer(model_config.text_vocab_path_or_url) + return ( + model, + make_classification_train_transform( + crop_size=model_config.vision_model_train_img_size + ), + tokenizer, + ) + + +def build_model_for_eval( + model_config: DINOTxtConfig, + pretrained_weights: str, + use_fsdp: bool = True, + do_compile: bool = True, + param_dtype_str: str = "bf16", + reduce_dtype_str: str = "fp32", +): + with torch.device("meta"): + model = DINOTxt(model_config=model_config) + world_mesh = init_device_mesh( + "cuda", + mesh_shape=(distributed.get_world_size(),), + mesh_dim_names=("dp",), + ) + compile_parallelize_and_init( + model, + model_config, + world_mesh, + use_fsdp, + do_compile, + False, + False, + False, + param_dtype_str, + reduce_dtype_str, + ) + load_checkpoint(pretrained_weights, model=model) + model.eval() + tokenizer = get_tokenizer(model_config.text_vocab_path_or_url) + crop_size = model_config.vision_model_train_img_size + resize_size = int(256 * crop_size / 224) + return ( + model, + make_classification_eval_transform( + resize_size=resize_size, crop_size=crop_size + ), + tokenizer, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/clip_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/clip_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f4c80c5491c3724ee62022ce325db12db45b3e --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/clip_loss.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ +from typing import Callable, Optional, Tuple + +import torch + + +def _cycle_over_all_chunks( + my_chunk: torch.Tensor, + pg: torch.distributed.ProcessGroup, + step_fn: Callable[ + [torch.Tensor, int, Optional[torch.distributed.Work], Optional[torch.Tensor]], + Optional[torch.Tensor], + ], +): + next_rank = (pg.rank() + 1) % pg.size() + prev_rank = (pg.rank() - 1) % pg.size() + + extra_req: Optional[torch.distributed.Work] = None + dst_extra_chunk: Optional[torch.Tensor] = None + + dst_chunk = torch.empty_like(my_chunk) + for iter_ in range(pg.size()): + src_chunk = my_chunk if iter_ == 0 else dst_chunk + dst_chunk = torch.empty_like(my_chunk) + + if iter_ < pg.size() - 1: + send_op = torch.distributed.P2POp( + torch.distributed.isend, src_chunk, next_rank, group=pg + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, dst_chunk, prev_rank, group=pg + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + else: + reqs = [] + + src_extra_chunk = step_fn( + src_chunk, (pg.rank() - iter_) % pg.size(), extra_req, dst_extra_chunk + ) + if src_extra_chunk is not None: + dst_extra_chunk = torch.empty_like(src_extra_chunk) + send_op = torch.distributed.P2POp( + torch.distributed.isend, src_extra_chunk, next_rank, group=pg + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, dst_extra_chunk, prev_rank, group=pg + ) + (extra_req,) = torch.distributed.batch_isend_irecv([send_op, recv_op]) + else: + extra_req = None + dst_extra_chunk = None + + for req in reqs: + req.wait() + + return extra_req, dst_extra_chunk + + +class MemoryEfficientClipLoss(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + pg: torch.distributed.ProcessGroup, + image_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, + ) -> torch.Tensor: + image_partial_lses_for_me = torch.empty( + (pg.size(), image_features.shape[0]), + dtype=torch.float32, + device=image_features.device, + ) + text_partial_lses_for_others = torch.empty( + (pg.size(), text_features.shape[0]), + dtype=torch.float32, + device=text_features.device, + ) + + positives: Optional[torch.Tensor] = None + + def my_step( + incoming: torch.Tensor, + other_rank: int, + _req: Optional[torch.distributed.Work], + _extra: Optional[torch.Tensor], + ) -> None: + nonlocal positives + logits = logit_scale * (image_features @ incoming.T) + if other_rank == pg.rank(): + positives = torch.diag(logits) + torch.logsumexp(logits, dim=1, out=image_partial_lses_for_me[other_rank]) + torch.logsumexp(logits, dim=0, out=text_partial_lses_for_others[other_rank]) + + _cycle_over_all_chunks(text_features, pg, my_step) + + text_partial_lses_for_me = torch.empty_like(text_partial_lses_for_others) + torch.distributed.all_to_all_single( + text_partial_lses_for_me, text_partial_lses_for_others, group=pg + ) + + image_lses_for_me = torch.logsumexp(image_partial_lses_for_me, dim=0) + text_lses_for_me = torch.logsumexp(text_partial_lses_for_me, dim=0) + + assert positives is not None + ctx.save_for_backward( + image_features, + text_features, + logit_scale, + positives, + image_lses_for_me, + text_lses_for_me, + ) + ctx.pg = pg # type: ignore[attr-defined] + + return (-(2 * positives - image_lses_for_me - text_lses_for_me).mean() / 2).to( + positives.dtype + ) + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, *grad_outputs: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], ...]: + pg: torch.distributed.ProcessGroup = ctx.pg # type: 
ignore[attr-defined] + image_features: torch.Tensor + text_features: torch.Tensor + logit_scale: torch.Tensor + positives: torch.Tensor + image_lses_for_me: torch.Tensor + text_lses_for_me: torch.Tensor + ( + image_features, + text_features, + logit_scale, + positives, + image_lses_for_me, + text_lses_for_me, + ) = ctx.saved_tensors # type: ignore[attr-defined] + + (grad,) = grad_outputs + grad /= 2 * positives.numel() + + text_lse_for_others = text_lses_for_me.new_empty( + (pg.size(),) + text_lses_for_me.shape + ) + torch.distributed.all_gather_into_tensor( + text_lse_for_others, text_lses_for_me, group=pg + ) + + grad_image_features = torch.zeros_like(image_features) + grad_logit_scale = torch.zeros_like(logit_scale) + + def my_step( + incoming: torch.Tensor, + other_rank: int, + req: Optional[torch.distributed.Work], + grad_text_features: Optional[torch.Tensor], + ) -> torch.Tensor: + raw_logits = image_features @ incoming.T + logits = logit_scale * raw_logits + + grad_logits = ( + (logits - image_lses_for_me[:, None]).exp() + + (logits - text_lse_for_others[other_rank, None, :]).exp() + ).to(logits.dtype) + if other_rank == pg.rank(): + torch.diagonal(grad_logits).sub_(2) + + grad_logit_scale.add_((raw_logits * grad_logits).sum()) + grad_raw_logits = grad_logits * logit_scale + + grad_image_features.addmm_(grad_raw_logits, incoming) + if req is None: + grad_text_features = torch.matmul(grad_raw_logits.T, image_features) + else: + req.wait() + assert grad_text_features is not None + grad_text_features.addmm_(grad_raw_logits.T, image_features) + + return grad_text_features + + req, grad_text_features = _cycle_over_all_chunks(text_features, pg, my_step) + req.wait() + + return ( + None, + grad * grad_image_features, + grad * grad_text_features, + grad * grad_logit_scale, + ) + + +def memory_efficient_clip_loss( + image_features: torch.Tensor, + text_features: torch.Tensor, + logit_scale: torch.Tensor, + *, + group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + return MemoryEfficientClipLoss.apply( + group, image_features.float(), text_features.float(), logit_scale.float() + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/dinov3_vitl_text.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/dinov3_vitl_text.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76fb1c060f0c36d358b40d6a9bc4664869af918d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/dinov3_vitl_text.yaml @@ -0,0 +1,58 @@ +train_dataset_str: #Example: CocoCaptions:split=TRAIN:root= +embed_dim: 2048 +vision_backbone_config: #Example: /dinov3/configs/train/dinov3_vitl16_lvd1689m_distilled.yaml +vision_backbone_pretrained_weights: #Example ~/.cache/torch/hub/checkpoints/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth #Example: +vision_model_train_img_size: 224 +vision_model_use_class_token: true +vision_model_freeze_backbone: true +vision_model_num_head_blocks: 2 +vision_model_head_blocks_drop_path: 0.3 +vision_model_use_linear_projection: false +vision_model_use_patch_tokens: true +vision_model_patch_tokens_pooler_type: mean +vision_model_patch_token_layer: 1 +vision_model_use_gram_loss: false +vision_model_patch_sampling_rate_for_gram_loss: 1.0 +vision_model_normalize_patch_tokens_for_gram_loss: true +vision_model_gram_loss_weight: 1.0 +text_backbone_config: #Example: /dinov3/eval/text/configs/text_backbone.yaml +text_backbone_pretrained_weights: null +text_model_freeze_backbone: false 
+text_model_num_head_blocks: 0 +text_model_head_blocks_drop_prob: 0.0 +text_model_head_blocks_is_causal: true +text_model_tokens_pooler_type: argmax +text_model_use_linear_projection: true +text_vocab_path_or_url: # https://dl.fbaipublicfiles.com/dinov3/thirdparty/bpe_simple_vocab_16e6.txt.gz +init_logit_scale: 2.659260036932778 +init_logit_bias: null +freeze_logit_scale: false +output_dict: false +no_resume: false +lr_scheduler_type: cosine +lr: 0.0007 +weight_decay: 0.0001 +batch_size: 256 +beta1: 0.9 +beta2: 0.99 +eps: 1.0e-08 +eval_only: false +dataset_use_cache: false +max_checkpoints_to_keep: 5 +max_iteration: 50000 +warmup_length: 2000 +checkpointing_period: 500 +eval_freq: 5000 +gc_freq: 100 +eval_pretrained_weights: '' +output_dir: #Example: ~/tmp/dinov3_dinnotxt_vitl16 +seed: 11 +do_compile: true +use_fsdp: true +use_ac: true +use_full_ac: false +use_cuda_graphs: false +param_dtype_str: bf16 +reduce_dtype_str: fp32 +profiling: false +dtype_str: bf16 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/text_backbone.yaml b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/text_backbone.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76f18233932d3671936bb72db3e7d628b6b1aa02 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/configs/text_backbone.yaml @@ -0,0 +1,10 @@ +model_name: 1280d20h24l +context_length: 77 +vocab_size: 49408 +dim: 1280 +num_heads: 20 +num_layers: 24 +ffn_ratio: 4.0 +is_causal: true +dropout_prob: 0 +ls_init_value: null diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/dinotxt_model.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/dinotxt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1056dd09c4f890d3f780cf9725f39db49764a228 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/dinotxt_model.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from dinov3.eval.text.text_tower import build_text_model +from dinov3.eval.text.vision_tower import build_vision_model + + +@dataclass +class DINOTxtConfig: + embed_dim: int + vision_backbone_config: str | None = None + text_backbone_config: str | None = None + vision_backbone_pretrained_weights: str | None = None + text_backbone_pretrained_weights: str | None = None + vision_model_freeze_backbone: bool = True + vision_model_train_img_size: int = 224 + vision_model_use_class_token: bool = True + vision_model_use_patch_tokens: bool = False + vision_model_num_head_blocks: int = 0 + vision_model_head_blocks_drop_path: float = 0.3 + vision_model_use_linear_projection: bool = False + vision_model_patch_tokens_pooler_type: str = "mean" + vision_model_patch_token_layer: int = 1 # which layer to take patch tokens from + # 1 - last layer, 2 - second last layer, etc. 
+ text_model_freeze_backbone: bool = False + text_model_num_head_blocks: int = 0 + text_model_head_blocks_is_causal: bool = False + text_model_head_blocks_drop_prob: float = 0.0 + text_model_tokens_pooler_type: str = "first" + text_model_use_linear_projection: bool = False + text_vocab_path_or_url: Optional[str] = None + init_logit_scale: float = math.log(1 / 0.07) + init_logit_bias: Optional[float] = None + freeze_logit_scale: bool = False + + +class DINOTxt(nn.Module): + def __init__( + self, + model_config: DINOTxtConfig, + vision_backbone: Optional[nn.Module] = None, + text_backbone: Optional[nn.Module] = None, + device=None, + ): + super().__init__() + self.model_config = model_config + self.visual_model = build_vision_model( + model_config.embed_dim, + model_config.vision_backbone_config, + model_config.vision_model_freeze_backbone, + model_config.vision_model_num_head_blocks, + model_config.vision_model_head_blocks_drop_path, + model_config.vision_model_use_class_token, + model_config.vision_model_use_patch_tokens, + model_config.vision_model_patch_token_layer, + model_config.vision_model_patch_tokens_pooler_type, + model_config.vision_model_use_linear_projection, + backbone=vision_backbone, + ) + self.text_model = build_text_model( + model_config.embed_dim, + model_config.text_backbone_config, + model_config.text_model_freeze_backbone, + model_config.text_model_num_head_blocks, + model_config.text_model_head_blocks_is_causal, + model_config.text_model_head_blocks_drop_prob, + model_config.text_model_tokens_pooler_type, + model_config.text_model_use_linear_projection, + backbone=text_backbone, + ) + self.logit_scale = nn.Parameter(torch.empty(1, device=device)) + if model_config.freeze_logit_scale: + self.logit_scale.requires_grad = False + + def init_weights(self): + torch.nn.init.constant(self.logit_scale, self.model_config.init_logit_scale) + self.visual_model.init_weights() + self.text_model.init_weights() + + def encode_image_with_patch_tokens( + self, + image: torch.Tensor, + normalize: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + features, patch_tokens, backbone_patch_tokens = self.visual_model(image) + return ( + F.normalize(features, dim=-1) if normalize else features, + patch_tokens, + backbone_patch_tokens, + ) + + def encode_image( + self, + image: torch.Tensor, + normalize: bool = False, + ) -> torch.Tensor: + features, _, _ = self.visual_model(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor: + features = self.text_model(text) + return F.normalize(features, dim=-1) if normalize else features + + def get_logits( + self, image: torch.Tensor, text: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_features = self.encode_text(text, normalize=True) + image_features = self.encode_image(image, normalize=True) + image_logits = self.logit_scale.exp() * image_features @ text_features.T + text_logits = image_logits.T + return image_logits, text_logits + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + text_features = self.encode_text(text, normalize=True) + image_features, patch_tokens, backbone_patch_tokens = ( + self.encode_image_with_patch_tokens(image, normalize=True) + ) + return ( + image_features, + text_features, + self.logit_scale.exp(), + patch_tokens, + backbone_patch_tokens, + ) diff --git 
a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/gram_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/gram_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc80795bb6d9d28d826780b97fda5a9dc32dfa4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/gram_loss.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +import torch.nn.functional as F + + +def gram_loss_fn( + backbone_patch_tokens: torch.Tensor, + patch_tokens: torch.Tensor, + patch_sampling_rate: float = 1.0, + normalize: bool = True, +) -> torch.Tensor: + num_patches, dim = patch_tokens.shape[1:] + idx = torch.randperm(num_patches)[: int(num_patches * patch_sampling_rate)] + patch_tokens = patch_tokens[:, idx, :] + backbone_patch_tokens = backbone_patch_tokens[:, idx, :] + if normalize: + patch_tokens = F.normalize(patch_tokens, dim=-1) + backbone_patch_tokens = F.normalize(backbone_patch_tokens, dim=-1) + return torch.nn.MSELoss()( + patch_tokens @ patch_tokens.transpose(-2, -1), + backbone_patch_tokens @ backbone_patch_tokens.transpose(-2, -1), + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_tower.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_tower.py new file mode 100644 index 0000000000000000000000000000000000000000..77e24781ada31e50f457292588ce28dd85a56d56 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_tower.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
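A quick sanity check of `gram_loss_fn` above (a toy example, not part of the diff): when the head tokens equal the backbone tokens, the two Gram matrices coincide and the loss is exactly zero, since both token sets are sub-sampled with the same random indices.

```python
import torch

from dinov3.eval.text.gram_loss import gram_loss_fn

backbone_tokens = torch.randn(2, 196, 768)  # (batch, num_patches, dim)
head_tokens = backbone_tokens.clone()       # pretend the head preserved backbone geometry
loss = gram_loss_fn(backbone_tokens, head_tokens, patch_sampling_rate=0.5)
print(loss.item())  # 0.0
```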
+
+import logging
+from typing import Optional
+
+import torch
+
+from dinov3.eval.text.text_transformer import TextTransformer
+from dinov3.layers import CausalSelfAttentionBlock
+from torch import nn
+
+logger = logging.getLogger("dinov3")
+
+
+class TextHead(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        embed_dim: int,
+        num_heads: int,
+        num_blocks: int,
+        block_drop_prob: float,
+        is_causal: bool,
+        use_linear_projection: bool,
+    ):
+        super().__init__()
+        block_list = [nn.Identity()]
+        self.ln_final = nn.Identity()
+        if num_blocks > 0:
+            logger.info(f"Adding {num_blocks} text tower transformer head blocks")
+            block_list = [
+                CausalSelfAttentionBlock(
+                    dim=input_dim,
+                    num_heads=num_heads,
+                    is_causal=is_causal,
+                    dropout_prob=block_drop_prob,
+                )
+                for _ in range(num_blocks)
+            ]
+            self.ln_final = nn.LayerNorm(input_dim)
+        self.blocks = nn.ModuleList(block_list)
+        self.num_blocks = num_blocks
+        self.linear_projection = nn.Identity()
+        if input_dim != embed_dim or use_linear_projection:
+            logger.info(
+                f"Text tower: Using a linear projection from {input_dim} to {embed_dim}"
+            )
+            self.linear_projection = nn.Linear(input_dim, embed_dim, bias=False)
+
+    def init_weights(self):
+        if self.num_blocks > 0:
+            for i in range(self.num_blocks):
+                self.blocks[i].init_weights()
+            self.ln_final.reset_parameters()
+        if isinstance(self.linear_projection, nn.Linear):
+            nn.init.normal_(
+                self.linear_projection.weight,
+                std=self.linear_projection.in_features**-0.5,
+            )
+
+    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
+        for block in self.blocks:
+            text_tokens = block(text_tokens)
+        text_tokens = self.ln_final(text_tokens)
+        return self.linear_projection(text_tokens)
+
+
+class TextTower(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        freeze_backbone: bool,
+        embed_dim: int,
+        num_head_blocks: int,
+        head_blocks_is_causal: bool,
+        head_blocks_block_drop_prob: float,
+        tokens_pooler_type: str,
+        use_linear_projection: bool,
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.freeze_backbone = freeze_backbone
+        backbone_out_dim = backbone.embed_dim
+        logger.info(f"Text backbone embedding dimension: {backbone_out_dim}")
+        self.head = TextHead(
+            backbone_out_dim,
+            embed_dim,
+            self.backbone.num_heads,
+            num_head_blocks,
+            head_blocks_block_drop_prob,
+            head_blocks_is_causal,
+            use_linear_projection,
+        )
+        self.tokens_pooler_type = tokens_pooler_type
+
+    def init_weights(self):
+        self.backbone.init_weights()
+        self.head.init_weights()
+
+    def forward(self, token_indices: torch.Tensor) -> torch.Tensor:
+        text_tokens = self.backbone(token_indices)
+        text_tokens = self.head(text_tokens)
+        if self.tokens_pooler_type == "first":
+            features = text_tokens[:, 0]
+        elif self.tokens_pooler_type == "last":
+            features = text_tokens[:, -1]
+        elif self.tokens_pooler_type == "argmax":
+            assert token_indices is not None
+            features = text_tokens[
+                torch.arange(text_tokens.shape[0]), token_indices.argmax(dim=-1)
+            ]
+        else:
+            raise ValueError(f"Unknown text tokens pooler type: {self.tokens_pooler_type}")
+        return features
+
+
+def build_text_backbone(
+    cfg,
+) -> torch.nn.Module:
+    logger.info("Setting up a text transformer")
+    model = TextTransformer(
+        context_length=cfg.context_length,
+        vocab_size=cfg.vocab_size,
+        dim=cfg.dim,
+        num_heads=cfg.num_heads,
+        num_layers=cfg.num_layers,
+        ffn_ratio=cfg.ffn_ratio,
+        is_causal=cfg.is_causal,
+        ls_init_value=cfg.ls_init_value,
+        dropout_prob=cfg.dropout_prob,
+    )
+    logger.info(f"Setting up a custom text transformer {cfg.model_name}")
+    return model
+
+
+def build_text_model(
+    embed_dim: int,
+    backbone_model_config: str,
+    freeze_backbone: bool,
+    num_head_blocks: int,
+    head_blocks_is_causal: bool,
+    head_blocks_drop_prob: float,
+    tokens_pooler_type: str,
+    use_linear_projection: bool,
+    backbone: Optional[nn.Module] = None,
+):
+    if backbone is None:
+        if backbone_model_config is not None:
+            from omegaconf import OmegaConf
+
+            cfg = OmegaConf.load(backbone_model_config)
+            backbone = build_text_backbone(cfg)
+        else:
+            raise RuntimeError(
+                "Failed to create text backbone: either `backbone` or `backbone_model_config` must not be None"
+            )
+    return TextTower(
+        backbone,
+        freeze_backbone,
+        embed_dim,
+        num_head_blocks,
+        head_blocks_is_causal,
+        head_blocks_drop_prob,
+        tokens_pooler_type,
+        use_linear_projection,
+    )
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_transformer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..66183365f8198bf2a1213330087d6f2417868a9a
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/text_transformer.py
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+from dinov3.layers import CausalSelfAttentionBlock
+
+
+class TextTransformer(nn.Module):
+    def __init__(
+        self,
+        context_length: int,
+        vocab_size: int,
+        dim: int,
+        num_heads: int,
+        num_layers: int,
+        ffn_ratio: float,
+        is_causal: bool,
+        ls_init_value: Optional[float] = None,
+        act_layer: Callable = nn.GELU,
+        norm_layer: Callable = nn.LayerNorm,
+        dropout_prob: float = 0.0,
+    ):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embed_dim = dim
+        self.num_heads = num_heads
+
+        self.token_embedding = nn.Embedding(vocab_size, dim)
+        self.positional_embedding = nn.Parameter(torch.empty(context_length, dim))
+        self.dropout = nn.Dropout(dropout_prob)
+        self.num_layers = num_layers
+        block_list = [
+            CausalSelfAttentionBlock(
+                dim=dim,
+                num_heads=num_heads,
+                ffn_ratio=ffn_ratio,
+                ls_init_value=ls_init_value,
+                is_causal=is_causal,
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+                dropout_prob=dropout_prob,
+            )
+            for _ in range(num_layers)
+        ]
+        self.blocks = nn.ModuleList(block_list)
+        self.ln_final = norm_layer(dim)
+
+    def init_weights(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+        init_attn_std = self.embed_dim**-0.5
+        init_proj_std = (self.embed_dim**-0.5) * ((2 * self.num_layers) ** -0.5)
+        init_fc_std = (2 * self.embed_dim) ** -0.5
+        for block in self.blocks:
+            block.init_weights(init_attn_std, init_proj_std, init_fc_std)
+        self.ln_final.reset_parameters()
+
+    def forward(self, token_indices: torch.Tensor) -> torch.Tensor:
+        _, N = token_indices.size()
+        x = self.token_embedding(token_indices) + self.positional_embedding[:N]
+        x = self.dropout(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.ln_final(x)
+        return x
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/tokenizer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5834418f0c8220280d6059c89bb6ae55f5972789
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/tokenizer.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+from typing import List, Union
+
+import torch
+from dinov3.thirdparty.CLIP.clip.simple_tokenizer import SimpleTokenizer
+
+
+class Tokenizer(SimpleTokenizer):
+    def __init__(self, vocab_path: str):
+        SimpleTokenizer.__init__(self, bpe_path=vocab_path)
+
+    def tokenize(
+        self, texts: Union[str, List[str]], context_length: int = 77
+    ) -> torch.LongTensor:
+        """
+        Returns the tokenized representation of the given input string(s)
+
+        Parameters
+        ----------
+        texts : Union[str, List[str]]
+            An input string or a list of input strings to tokenize
+        context_length : int
+            The context length to use; all CLIP models use 77 as the context length
+
+        Returns
+        -------
+        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+        sot_token = self.encoder["<|startoftext|>"]
+        eot_token = self.encoder["<|endoftext|>"]
+        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                tokens = tokens[:context_length]  # Truncate
+                tokens[-1] = eot_token
+            result[i, : len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+
+def get_tokenizer(bpe_path_or_url: str) -> Tokenizer:
+    import urllib
+    from io import BytesIO
+
+    if urllib.parse.urlparse(bpe_path_or_url).scheme:
+        try:
+            with urllib.request.urlopen(bpe_path_or_url) as response:
+                file_buf = BytesIO(response.read())
+            return Tokenizer(vocab_path=file_buf)
+        except Exception as e:
+            raise FileNotFoundError(
+                f"Failed to download file from url {bpe_path_or_url} with error: {e}"
+            )
+    else:
+        with open(bpe_path_or_url, "rb") as f:
+            file_buf = BytesIO(f.read())
+        return Tokenizer(vocab_path=file_buf)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/train_dinotxt.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/train_dinotxt.py
new file mode 100644
index 0000000000000000000000000000000000000000..295b04c475719c6a75ee260f0cd1bbcc00de42dd
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/train_dinotxt.py
@@ -0,0 +1,339 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
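Example use of the tokenizer above; the vocabulary URL is the one referenced in `dinov3_vitl_text.yaml` and is assumed reachable:

```python
from dinov3.eval.text.tokenizer import get_tokenizer

VOCAB_URL = "https://dl.fbaipublicfiles.com/dinov3/thirdparty/bpe_simple_vocab_16e6.txt.gz"
tokenizer = get_tokenizer(VOCAB_URL)
tokens = tokenizer.tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77]); each row is <sot> ... <eot>, zero-padded
```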
+ +import gc +import logging +from functools import partial +import math +import os +import sys +from pathlib import Path +from typing import Callable + +import dinov3.distributed as distributed +import torch +from dinov3.checkpointer import ( + find_latest_checkpoint, + keep_last_n_checkpoints, + load_checkpoint, + save_checkpoint, +) +from dinov3.configs import setup_job +from dinov3.data import SamplerType, make_data_loader, make_dataset +from dinov3.eval.text.build_dinotxt import build_model_and_tokenizer +from dinov3.eval.text.clip_loss import memory_efficient_clip_loss +from dinov3.eval.text.dinotxt_model import DINOTxt, DINOTxtConfig +from dinov3.eval.text.gram_loss import gram_loss_fn +from dinov3.logging import MetricLogger, setup_logging +from dinov3.train.cosine_lr_scheduler import linear_warmup_cosine_decay +from omegaconf import OmegaConf +from torch import optim + +logger = logging.getLogger("dinov3") + + +def unwrap_model(model): + return getattr(model, "module", model) + + +def test( + model: DINOTxt, + iteration: str, + output_dir: str, +): + eval_dir = Path(output_dir) / "eval" / str(iteration) + if distributed.is_subgroup_main_process(): + eval_dir.mkdir(parents=True, exist_ok=True) + + ckpt_dir = eval_dir / str("sharded_model_checkpoint") + save_checkpoint(ckpt_dir, iteration=iteration, model=model) + + +def apply_learning_rate(optimizer, lr: float): + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def exclude(n: str, p: torch.Tensor) -> bool: + return p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or "logit_scale" in n + + +def include(n: str, p: torch.Tensor) -> bool: + return not exclude(n, p) + + +def train( + train_dataset, + model: DINOTxt, + tokenizer: Callable, + max_iteration: int, + warmup_length: int, + checkpointing_period: int, + output_dir: str, + dtype_str: str, + sampler_type: SamplerType, + lr_scheduler_type: str, + lr: float, + weight_decay: float, + batch_size: int = 256, + beta1: float = 0.9, + beta2: float = 0.99, + eps: float = 1e-8, + num_workers: int = 10, + eval_freq: int = 1000, + gc_freq: int = 100, + use_gram_loss: bool = False, + patch_sampling_rate_for_gram_loss: float = 0.5, + normalize_patch_tokens_for_gram_loss: bool = False, + gram_loss_weight: float = 1.0, + max_checkpoints_to_keep: int = None, + resume: bool = False, + seed: int = 11, +): + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + gain_or_bias_params_names = [ + n for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + rest_params_names = [ + n for n, p in named_parameters if include(n, p) and p.requires_grad + ] + logger.info(f"Gain or bias params: {gain_or_bias_params_names}") + logger.info(f"Rest params: {rest_params_names}") + logger.info( + f"Learning rate: {lr}, batch_size_per_gpu: {batch_size}, weight_decay: {weight_decay}" + ) + optimizer = optim.AdamW( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": weight_decay}, + ], + lr=lr, + betas=(beta1, beta2), + eps=eps, + ) + learning_rates = linear_warmup_cosine_decay( + 0, lr, 0, warmup_iterations=warmup_length, total_iterations=max_iteration + ) + logger.info( + f"Init lr scheduler: {lr_scheduler_type}, warmup length: {warmup_length}, base_lr: {lr}, max iter: {max_iteration}" + ) + + if ( + resume + and (ckpt_dir := 
find_latest_checkpoint(os.path.join(output_dir, "ckpt"))) + is not None + ): + iteration = load_checkpoint(ckpt_dir, model=model, optimizer=optimizer) + start_iteration = iteration + 1 + del iteration, ckpt_dir + else: + logger.info("Initializing from scratch") + start_iteration = 0 + + def collate_fn(batch): + images, captions = list(zip(*batch))[:2] + return torch.stack(images), tokenizer.tokenize(captions) + + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + seed=seed, + sampler_type=sampler_type, + sampler_advance=start_iteration, + drop_last=True, + persistent_workers=True, + collate_fn=collate_fn, + ) + rank = distributed.get_rank() + world_size = distributed.get_world_size() + logger.info( + f"Init loss function: rank: {distributed.get_rank()}, world size: {world_size}" + ) + clip_loss = partial(memory_efficient_clip_loss, group=torch.distributed.group.WORLD) + cur_iteration = start_iteration + logger.info(f"Starting training from iteration {start_iteration}...") + header = "Training" + metric_logger = MetricLogger(delimiter=" ") + gc.disable() + device_id = rank % torch.cuda.device_count() + + for batch in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iteration, + start_iteration, + ): + images, text_tokens = batch + images = images.to(device=device_id, non_blocking=True) + text_tokens = text_tokens.to(device=device_id, non_blocking=True) + ( + image_embeddings, + text_embeddings, + logit_scale, + patch_tokens, + backbone_patch_tokens, + ) = model(images, text_tokens) + contrastive_loss = clip_loss(image_embeddings, text_embeddings, logit_scale) + total_loss = contrastive_loss + if use_gram_loss: + gram_loss = gram_loss_fn( + patch_tokens, + backbone_patch_tokens, + patch_sampling_rate_for_gram_loss, + normalize_patch_tokens_for_gram_loss, + ) + total_loss = contrastive_loss + gram_loss_weight * gram_loss + + if total_loss.isnan(): + msg = f"Loss is NaN at iteration {cur_iteration}, aborting..." + logger.error(msg) + raise RuntimeError(msg) + apply_learning_rate(optimizer=optimizer, lr=learning_rates[cur_iteration]) + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + # This clamping trick is from the OpenCLIP repository, which MetaCLIP follows; originally used in CLIP training. + # NOTE: we clamp to 4.6052 = ln(100), as in the original paper. 
+ with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + metric_logger.update(contrastive_loss=contrastive_loss.item()) + if use_gram_loss: + metric_logger.update(gram_loss=gram_loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(logit_scale=logit_scale.item()) + is_last_iteration = (cur_iteration + 1) == max_iteration + is_ckpt_iteration = ( + (cur_iteration + 1) % checkpointing_period == 0 + ) or is_last_iteration + if is_ckpt_iteration: + ckpt_dir = Path(output_dir, "ckpt").expanduser() + save_checkpoint( + ckpt_dir / str(cur_iteration), + iteration=cur_iteration, + model=model, + optimizer=optimizer, + ) + if distributed.is_main_process(): + keep_last_n_checkpoints(ckpt_dir, max_checkpoints_to_keep) + if eval_freq > 0 and (cur_iteration + 1) % eval_freq == 0: + test( + model, + iteration=f"training_{cur_iteration}", + output_dir=output_dir, + ) + torch.cuda.synchronize() + if (cur_iteration + 1) % gc_freq == 0: + logger.info("Garbage collection...") + gc.collect() + cur_iteration += 1 + + +def write_config( + model_config: DINOTxtConfig, output_dir, name="clip_model_config.yaml" +): + logger.info(OmegaConf.to_yaml(model_config)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=model_config, f=f) + return saved_cfg_path + + +def main(argv=None): + if argv is None: + argv = sys.argv[1:] + args_dict = OmegaConf.to_container(OmegaConf.from_cli(argv)) + logger.info(args_dict) + config = OmegaConf.load(args_dict["trainer_config_file"]) + logger.info(config) + if "output_dir" in args_dict: + config.output_dir = args_dict["output_dir"] + setup_job(output_dir=config.output_dir, seed=config.seed) + setup_logging(output=os.path.join(config.output_dir, "nan_logs"), name="nan_logger") + logger.info("Trainer config:") + logger.info(config) + model_config = DINOTxtConfig( + embed_dim=config.embed_dim, + text_backbone_config=config.text_backbone_config, + vision_backbone_config=config.vision_backbone_config, + text_backbone_pretrained_weights=config.text_backbone_pretrained_weights, + vision_backbone_pretrained_weights=config.vision_backbone_pretrained_weights, + vision_model_train_img_size=config.vision_model_train_img_size, + vision_model_use_class_token=config.vision_model_use_class_token, + vision_model_use_patch_tokens=config.vision_model_use_patch_tokens, + vision_model_num_head_blocks=config.vision_model_num_head_blocks, + vision_model_head_blocks_drop_path=config.vision_model_head_blocks_drop_path, + vision_model_use_linear_projection=config.vision_model_use_linear_projection, + vision_model_patch_tokens_pooler_type=config.vision_model_patch_tokens_pooler_type, + vision_model_patch_token_layer=config.vision_model_patch_token_layer, + text_model_freeze_backbone=config.text_model_freeze_backbone, + text_model_num_head_blocks=config.text_model_num_head_blocks, + text_model_head_blocks_is_causal=config.text_model_head_blocks_is_causal, + text_model_head_blocks_drop_prob=config.text_model_head_blocks_drop_prob, + text_model_tokens_pooler_type=config.text_model_tokens_pooler_type, + text_model_use_linear_projection=config.text_model_use_linear_projection, + text_vocab_path_or_url=config.text_vocab_path_or_url, + init_logit_scale=config.init_logit_scale, + freeze_logit_scale=config.freeze_logit_scale, + init_logit_bias=config.init_logit_bias, + ) + write_config(model_config=model_config, 
output_dir=config.output_dir) + model, transform, tokenizer = build_model_and_tokenizer( + model_config, + use_fsdp=config.use_fsdp, + do_compile=config.do_compile, + use_ac=config.use_ac, + use_cuda_graphs=config.use_cuda_graphs, + ) + + train_dataset = make_dataset( + dataset_str=config.train_dataset_str, + transform=transform, + ) + sampler_type = ( + SamplerType.SHARDED_INFINITE + if config.dataset_use_cache + else SamplerType.INFINITE + ) + train( + train_dataset=train_dataset, + model=model, + tokenizer=tokenizer, + max_iteration=config.max_iteration, + warmup_length=config.warmup_length, + checkpointing_period=config.checkpointing_period, + output_dir=config.output_dir, + dtype_str=config.dtype_str, + lr_scheduler_type=config.lr_scheduler_type, + lr=config.lr, + weight_decay=config.weight_decay, + batch_size=config.batch_size, + beta1=config.beta1, + beta2=config.beta2, + eps=config.eps, + sampler_type=sampler_type, + eval_freq=config.eval_freq, + gc_freq=config.gc_freq, + max_checkpoints_to_keep=config.max_checkpoints_to_keep, + use_gram_loss=config.vision_model_use_gram_loss, + patch_sampling_rate_for_gram_loss=config.vision_model_patch_sampling_rate_for_gram_loss, + normalize_patch_tokens_for_gram_loss=config.vision_model_normalize_patch_tokens_for_gram_loss, + gram_loss_weight=config.vision_model_gram_loss_weight, + resume=not config.no_resume, + ) + + +if __name__ == "__main__": + main() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/vision_tower.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/vision_tower.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b8b389ce8675f5d598921e535b92f82c5c4399 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/text/vision_tower.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
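
Since `main()` above parses its command line with `OmegaConf.from_cli`, a launch passes plain `key=value` pairs. A hypothetical invocation is sketched below; the config path is a placeholder, and the full trainer-config schema is whatever set of fields `main()` reads.

```python
# Hypothetical programmatic launch of the DINOTxt trainer defined above.
# "configs/dinotxt_train.yaml" is a placeholder path; "output_dir" is the optional override.
from dinov3.eval.text.train_dinotxt import main

main(["trainer_config_file=configs/dinotxt_train.yaml", "output_dir=/tmp/dinotxt_run"])
```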
+ +import logging +from functools import partial +from typing import Optional, Tuple + +import torch +from torch import nn + +from dinov3.layers import SelfAttentionBlock, SwiGLUFFN +from dinov3.models.vision_transformer import init_weights_vit +from dinov3.utils import named_apply + +logger = logging.getLogger("dinov3") + + +class VisionHead(nn.Module): + def __init__( + self, + input_dim: int, + embed_dim: int, + num_heads: int, + num_blocks: int, + blocks_drop_path: float, + use_class_token: bool, + use_patch_tokens: bool, + use_linear_projection: bool, + ): + super().__init__() + block_list = [nn.Identity()] + self.ln_final = nn.Identity() + if num_blocks > 0: + block_list = [ + SelfAttentionBlock( + input_dim, + num_heads, + ffn_layer=partial(SwiGLUFFN, align_to=64), + init_values=1e-5, + drop_path=blocks_drop_path, + ) + for _ in range(num_blocks) + ] + self.ln_final = nn.LayerNorm(input_dim) + self.blocks = nn.ModuleList(block_list) + self.num_blocks = num_blocks + multiplier = 2 if use_class_token and use_patch_tokens else 1 + self.linear_projection = nn.Identity() + if multiplier * input_dim != embed_dim or use_linear_projection: + logger.info( + f"Vision Tower: Using a linear projection from {input_dim} to {embed_dim}" + ) + assert embed_dim % multiplier == 0, ( + f"Expects {embed_dim} to be divisible by {multiplier}" + ) + self.linear_projection = nn.Linear( + input_dim, embed_dim // multiplier, bias=False + ) + + def init_weights(self): + if self.num_blocks > 0: + for i in range(self.num_blocks): + block = self.blocks[i] + named_apply(init_weights_vit, block) + self.ln_final.reset_parameters() + if isinstance(self.linear_projection, nn.Linear): + nn.init.normal_( + self.linear_projection.weight, + std=self.linear_projection.in_features**-0.5, + ) + + def forward(self, image_tokens: torch.Tensor) -> torch.Tensor: + # FIXME(cijose) ROPE embeddings are not used in DINOv2, refactor to use it in the future + for block in self.blocks: + image_tokens = block(image_tokens) + image_tokens = self.ln_final(image_tokens) + return self.linear_projection(image_tokens) + + +class VisionTower(nn.Module): + def __init__( + self, + backbone: nn.Module, + freeze_backbone: bool, + embed_dim: int, + num_head_blocks: int, + head_blocks_block_drop_path: float, + use_class_token: bool, + use_patch_tokens: bool, + patch_token_layer: int, + patch_tokens_pooler_type: str, + use_linear_projection: bool, + ): + super().__init__() + self.backbone = backbone + self.freeze_backbone = freeze_backbone + self.use_class_token = use_class_token + self.use_patch_tokens = use_patch_tokens + self.patch_token_layer = patch_token_layer + self.patch_tokens_pooler_type = patch_tokens_pooler_type + self.num_register_tokens = 0 + if hasattr(self.backbone, "num_register_tokens"): + self.num_register_tokens = self.backbone.num_register_tokens + elif hasattr(self.backbone, "n_storage_tokens"): + self.num_register_tokens = self.backbone.n_storage_tokens + backbone_out_dim = self.backbone.embed_dim + logger.info(f"Visual backbone embedding dimension: {backbone_out_dim}") + self.head = VisionHead( + backbone_out_dim, + embed_dim, + self.backbone.num_heads, + num_head_blocks, + head_blocks_block_drop_path, + use_class_token, + use_patch_tokens, + use_linear_projection, + ) + + def init_weights(self): + self.backbone.init_weights() + self.head.init_weights() + + def get_backbone_features( + self, images: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tokens = self.backbone.get_intermediate_layers( + images, 
+ n=self.patch_token_layer, + return_class_token=True, + return_extra_tokens=True, + ) + class_token = tokens[-1][1] + patch_tokens = tokens[0][0] + register_tokens = tokens[0][2] + return class_token, patch_tokens, register_tokens + + def get_class_and_patch_tokens( + self, images: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + class_token, patch_tokens, register_tokens = self.get_backbone_features(images) + image_tokens = self.head( + torch.cat([class_token.unsqueeze(1), register_tokens, patch_tokens], dim=1) + ) + return ( + image_tokens[:, 0], + image_tokens[:, self.num_register_tokens + 1 :], + patch_tokens, + ) + + def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + class_token, patch_tokens, backbone_patch_tokens = ( + self.get_class_and_patch_tokens(images) + ) + features = [] + if self.use_class_token: + features.append(class_token) + if self.use_patch_tokens: + if self.patch_tokens_pooler_type == "mean": + features.append(torch.mean(patch_tokens, dim=1)) + elif self.patch_tokens_pooler_type == "max": + features.append(torch.max(patch_tokens, dim=1).values) + else: + raise ValueError( + f"Unknown patch tokens pooler type: {self.patch_tokens_pooler_type}" + ) + return torch.cat(features, dim=-1), patch_tokens, backbone_patch_tokens + + +def build_vision_model( + embed_dim: int, + backbone_model_config: str, + freeze_backbone: bool, + num_head_blocks: int, + blocks_drop_path: float, + use_class_token: bool, + use_patch_tokens: bool, + patch_token_layer: int, + patch_tokens_pooler_type: str, + use_linear_projection: bool, + backbone: Optional[nn.Module] = None, +): + if backbone is None: + if backbone_model_config is not None: + from omegaconf import OmegaConf + + from dinov3.models import build_model_from_cfg as build_vision_backbone + + cfg = OmegaConf.load(backbone_model_config) + backbone, _ = build_vision_backbone(cfg, only_teacher=True) + else: + raise RuntimeError( + "Failed to create vision backbone: either backbone or backbone_model_config must be provided" + ) + return VisionTower( + backbone, + freeze_backbone, + embed_dim, + num_head_blocks, + blocks_drop_path, + use_class_token, + use_patch_tokens, + patch_token_layer, + patch_tokens_pooler_type, + use_linear_projection, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c72096a22f896b443f072f84ef1d01a1b19e09fd --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/eval/utils.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
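
A small, self-contained sketch of the patch-token pooling that `VisionTower.forward` above switches on ("mean" vs "max"); the shapes are illustrative only.

```python
import torch

# Illustrative shapes only: (batch, num_patches, dim) patch tokens.
patch_tokens = torch.randn(4, 196, 1024)
mean_pooled = patch_tokens.mean(dim=1)        # "mean" pooler branch
max_pooled = patch_tokens.max(dim=1).values   # "max" pooler branch
print(mean_pooled.shape, max_pooled.shape)    # torch.Size([4, 1024]) twice
```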
+ +import gc +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torchmetrics import Metric + +import dinov3.distributed as distributed +from dinov3.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +from dinov3.eval.accumulators import NoOpAccumulator, ResultsAccumulator +from dinov3.logging import MetricLogger + +logger = logging.getLogger("dinov3") + + +class LossType(Enum): + CROSS_ENTROPY = "cross_entropy" + BINARY_CROSS_ENTROPY = "binary_cross_entropy" + + +class ModelWithNormalize(torch.nn.Module): + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + self._model = model + + def forward(self, samples): + return nn.functional.normalize(self._model(samples), dim=1, p=2) + + +class ModelWithMultiScale(torch.nn.Module): + def __init__(self, model: torch.nn.Module, mode: str = "bilinear") -> None: + super().__init__() + self._model = model + self._mode = mode + + def forward(self, samples): + output = None + for scale in (1, 0.5**0.5, 0.5): + if scale == 1: + resized_samples = samples.clone() + else: + resized_samples = nn.functional.interpolate( + samples, scale_factor=scale, mode=self._mode, align_corners=False + ) + scale_output = self._model(resized_samples).clone() + if output is None: + output = scale_output + else: + output += scale_output + return output / 3 + + +def wrap_model( + model: nn.Module, + *, + normalize: bool = True, + multi_scale: bool = False, +) -> nn.Module: + logger.info("multi-scale: {}".format("enabled" if multi_scale else "disabled")) + if multi_scale: + model = ModelWithMultiScale(model) + + logger.info("normalize: {}".format("enabled" if normalize else "disabled")) + if normalize: + model = ModelWithNormalize(model) + return model + + +class ModelWithIntermediateLayers(nn.Module): + def __init__(self, feature_model, n, autocast_ctx, reshape=False, return_class_token=True): + super().__init__() + self.feature_model = feature_model + self.feature_model.eval() + self.n = n # Layer indices (Sequence) or n last layers (int) to take + self.autocast_ctx = autocast_ctx + self.reshape = reshape + self.return_class_token = return_class_token + + def forward(self, images): + with torch.inference_mode(): + with self.autocast_ctx(): + features = self.feature_model.get_intermediate_layers( + images, + n=self.n, + reshape=self.reshape, + return_class_token=self.return_class_token + ) + return features + + +@torch.inference_mode() +def evaluate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, Metric], + device: torch.device, + criterion: Optional[nn.Module] = None, + accumulate_results: bool = False, +): + gc.collect() # Avoids garbage collection errors in DataLoader workers + model.eval() + if criterion is not None: + criterion.eval() + + for metric in metrics.values(): + metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator + accumulators = {k: accumulator_class() for k in postprocessors.keys()} + + # Dataset needs to be wrapped in dinov3.data.DatasetWithEnumeratedTargets + for samples, (index, targets), *_ in metric_logger.log_every(data_loader, 10, header): + samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0] + if len(index) == 0: + continue + + outputs = model(samples.to(device)) + index = 
index.to(device) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + # accumulator.accumulate() returns None for the NoOpAccumulator + accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()} + + return metric_logger_stats, stats, accumulated_results + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_world_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) + + +@torch.inference_mode() +def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + features_rank = model(samples).float() + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features Shape {features.shape}") + logger.info(f"Labels Shape {all_labels.shape}") + + return features, all_labels + + +def save_features_dict(features_dict: Dict[str, torch.Tensor], path: str) -> None: + logger.info(f'saving features to "{path}"') + + for key, value in features_dict.items(): + assert isinstance(key, str) + assert isinstance(value, torch.Tensor) + + _, ext = os.path.splitext(path) + if ext == ".pt": + 
torch.save(features_dict, path) + elif ext == ".npy": + numpy_features_dict = { # Convert to NumPy arrays (if possible) + key: value.cpu().numpy() for key, value in features_dict.items() + } + np.save(path, numpy_features_dict, allow_pickle=True) + else: + raise ValueError(f'Unsupported features dict extension "{ext}"') + + +def load_features_dict(path: str) -> Dict[str, torch.Tensor]: + logger.info(f'loading features from "{path}"') + + _, ext = os.path.splitext(path) + if ext == ".pt": + features_dict = torch.load(path) + elif ext == ".npy": + numpy_features_dict = np.load(path, allow_pickle=True).item() + features_dict = {key: torch.from_numpy(value) for key, value in numpy_features_dict.items()} + else: + raise ValueError(f'Unsupported features dict extension "{ext}"') + + for key, value in features_dict.items(): + assert isinstance(key, str) + assert isinstance(value, torch.Tensor) + + return features_dict + + +def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []): + """ + Function that computes the average and the std on a metrics dict. + A linear evaluation dictionary contains "best_classifier", + so this specific key is removed for computing aggregated metrics. + """ + output_metrics_dict = {} + metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys] + for metric in metrics: + stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()]) + output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item() + output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item() + + return output_metrics_dict + + +def save_results( + preds: torch.Tensor, + target: torch.Tensor, + output_dir: str, + filename_suffix: Optional[str] = None, +) -> None: + """ + Helper to save predictions from a model and their associated targets, aligned by their index + """ + filename_suffix = "" if filename_suffix is None else f"_{filename_suffix}" + preds_filename = f"preds{filename_suffix}.npy" + target_filename = f"target{filename_suffix}.npy" + preds_path = os.path.join(output_dir, preds_filename) + target_path = os.path.join(output_dir, target_filename) + logger.info(f"Saving to {preds_path}") + np.save(preds_path, preds.cpu().numpy()) + logger.info(f"Saving to {target_path}") + np.save(target_path, target.cpu().numpy()) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/fsdp/ac_compile_parallelize.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/fsdp/ac_compile_parallelize.py new file mode 100644 index 0000000000000000000000000000000000000000..1644f98593a225e6c91a4768e0c9245bf24c9638 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/fsdp/ac_compile_parallelize.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
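
A hedged usage sketch for `ModelWithIntermediateLayers` from the eval utils above. The backbone, the hub import path, and the dummy batch are placeholder assumptions; `autocast_ctx` must be a zero-argument factory that returns a context manager, and the snippet assumes a CUDA device is available.

```python
from functools import partial

import torch

from dinov3.eval.utils import ModelWithIntermediateLayers
from dinov3.hub.backbones import dinov3_vits16  # factory defined later in this diff

# Placeholder backbone and batch; pretrained=False avoids any weight download.
backbone = dinov3_vits16(pretrained=False).cuda()
autocast_ctx = partial(torch.autocast, device_type="cuda", dtype=torch.bfloat16)
feature_model = ModelWithIntermediateLayers(backbone, n=1, autocast_ctx=autocast_ctx)
features = feature_model(torch.randn(1, 3, 224, 224).cuda())
# features is a tuple of per-layer outputs, here (patch_tokens, class_token) pairs.
```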
+ +import logging +from functools import partial +from typing import Any, List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import register_fsdp_forward_method +from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState +from torch.utils.checkpoint import create_selective_checkpoint_contexts + +from dinov3.utils import utils + +logger = logging.getLogger("dinov3") + + +def map_modules_and_blocks(models: list[nn.ModuleDict], callable) -> None: + for m in models: + assert isinstance(m, nn.ModuleDict) + for k in m.keys(): + if k == "backbone": + assert isinstance(m[k].blocks, nn.ModuleList) + for block_id, block in enumerate(m[k].blocks): + m[k].blocks[block_id] = callable(block, is_backbone_block=True) + else: + m[k] = callable(m[k], is_backbone_block=False) + + +def ac_compile_parallelize( + trained_model: nn.ModuleDict, + inference_only_models: List[nn.ModuleDict], + cfg: Any, + trained_model_process_group: Optional[dist.ProcessGroup] = None, + inference_only_models_process_groups: Optional[List[dist.ProcessGroup]] = None, +) -> None: + """ + Order of the wrappers: + 1/ Activation checkpointing on blocks + 2/ Compile blocks + 3/ FSDP blocks + global model + """ + assert ( + isinstance(trained_model, nn.ModuleDict) and "backbone" in trained_model.keys() + ), f"{trained_model} does not contain a backbone?" + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if utils.has_batchnorms(trained_model): + raise NotImplementedError + + # 1/ AC on blocks + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper + + backbone = trained_model.backbone + if cfg.train.checkpointing: + if cfg.train.checkpointing_full: + _checkpointing_wrapper = checkpoint_wrapper + logger.info("using selective checkpointing on backbone with full checkpointing policy") + else: + _save_list = [ + # mm + torch.ops.aten.mm.default, + torch.ops.aten._scaled_mm.default, + # attentions + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + _checkpointing_wrapper = partial( + checkpoint_wrapper, + context_fn=partial(create_selective_checkpoint_contexts, _save_list), + preserve_rng_state=True, + ) + logger.info("using selective checkpointing on backbone with selective policy") + for i, b in enumerate(backbone.blocks): + backbone.blocks[i] = _checkpointing_wrapper(b) + + # 2/ Compile blocks + all_models = [trained_model] + inference_only_models + if trained_model_process_group is None and inference_only_models_process_groups is None: + all_pgs = [None] * len(all_models) + elif trained_model_process_group is None: + all_pgs = [None] + inference_only_models_process_groups + elif inference_only_models_process_groups is None: + all_pgs = [trained_model_process_group] + [None] * len(inference_only_models) + else: + all_pgs = [trained_model_process_group] + inference_only_models_process_groups + + def wrap_compile_block(m: nn.Module, is_backbone_block: bool) -> nn.Module: + if cfg.train.compile: + if is_backbone_block and cfg.train.cudagraphs: + m.compile(fullgraph=True, dynamic=False, options={"triton.cudagraphs": True}) + else: + m.compile() + return m + + map_modules_and_blocks(all_models, 
wrap_compile_block) + + # 3/ Wrap submodules with FSDP + world_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist.get_world_size(),), + mesh_dim_names=("dp",), + ) + DTYPE_MAP = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + mp_policy = MixedPrecisionPolicy( + param_dtype=DTYPE_MAP[cfg.compute_precision.param_dtype], + reduce_dtype=DTYPE_MAP[cfg.compute_precision.reduce_dtype], + ) + + for m, pg in zip(all_models, all_pgs): + if pg is None: + world_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist.get_world_size(),), + mesh_dim_names=("dp",), + ) + else: + world_mesh = DeviceMesh.from_group(pg, "cuda") + fsdp_config = {"mesh": world_mesh, "mp_policy": mp_policy} + for k in m.keys(): + if k != "backbone": + m[k] = fully_shard(m[k], **fsdp_config, reshard_after_forward=True) + m[k].set_reduce_scatter_divide_factor(1) + continue + # Backbone - FSDP every block + blocks = m[k].blocks + + assert isinstance(blocks, nn.ModuleList) + for block_id, block in enumerate(blocks): + block_reshard: int | bool = True + # if m is trained_model and dist.get_world_size() % 8 == 0 and dist.get_world_size() > 8: + # block_reshard = 8 + blocks[block_id] = fully_shard(block, **fsdp_config, reshard_after_forward=block_reshard) + blocks[block_id].set_reduce_scatter_divide_factor(1) + prev_block: FSDPState + next_block: FSDPState + for prev_block, next_block in zip(blocks, blocks[1:]): + prev_block.set_modules_to_forward_prefetch([next_block]) + next_block.set_modules_to_backward_prefetch([prev_block]) + fully_shard(m.backbone, **fsdp_config, reshard_after_forward=True).set_reduce_scatter_divide_factor(1) + register_fsdp_forward_method(m.backbone, "get_intermediate_layers") + + # 4/ Move to `cuda` device + for model in all_models: + model.to_empty(device="cuda") + + # 5/ FSDP2: Reshard immediately after forward for inference-only models + for model in inference_only_models: + for k in model.keys(): + fsdp_state: FSDPState = model[k]._get_fsdp_state() + if not fsdp_state._fsdp_param_group: + continue + mi = fsdp_state._fsdp_param_group.post_forward_mesh_info + fsdp_state._lazy_init() + fsdp_state._fsdp_param_group.post_forward_mesh_info = mi diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/backbones.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2a9b2c832d5b13343eaad9d418176a73f63e71 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/backbones.py @@ -0,0 +1,614 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
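
The `cfg` object consumed by `ac_compile_parallelize` above is only read at a handful of keys. A minimal OmegaConf sketch covering exactly those fields is given below; the real training-config schema lives elsewhere in the repo, so treat the structure as an assumption mirroring what this function reads.

```python
from omegaconf import OmegaConf

# Only the fields ac_compile_parallelize actually reads; everything else is omitted.
cfg = OmegaConf.create(
    {
        "train": {
            "checkpointing": True,        # enable activation checkpointing on backbone blocks
            "checkpointing_full": False,  # False -> selective policy built around _save_list
            "compile": True,              # torch.compile each block
            "cudagraphs": False,          # optional CUDA-graph capture for backbone blocks
        },
        "compute_precision": {"param_dtype": "bf16", "reduce_dtype": "fp32"},
    }
)
```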
+ +import os +from enum import Enum +from typing import List, Optional, Union +from urllib.parse import urlparse +from pathlib import Path + +import torch + +from .utils import DINOV3_BASE_URL + + +class Weights(Enum): + LVD1689M = "LVD1689M" + SAT493M = "SAT493M" + + +def is_url(path: str) -> bool: + parsed = urlparse(path) + return parsed.scheme in ("https", "file") + + +def convert_path_or_url_to_url(path: str) -> str: + if is_url(path): + return path + return Path(path).expanduser().resolve().as_uri() + + +def _make_dinov3_vit_model_arch( + *, + patch_size: int = 16, + compact_arch_name: str = "vitb", +): + if "plus" in compact_arch_name: + model_arch = compact_arch_name.replace("plus", f"{patch_size}plus") + else: + model_arch = f"{compact_arch_name}{patch_size}" + return model_arch + + +def _make_dinov3_vit_model_url( + *, + patch_size: int = 16, + compact_arch_name: str = "vitb", + version: Optional[str] = None, + weights: Union[Weights, str] = Weights.LVD1689M, + hash: Optional[str] = None, +): + model_name = "dinov3" + model_arch = _make_dinov3_vit_model_arch(patch_size=patch_size, compact_arch_name=compact_arch_name) + version_suffix = f"_{version}" if version else "" + weights_name = weights.value.lower() + hash_suffix = f"-{hash}" if hash else "" + model_dir = f"{model_name}_{model_arch}" + model_filename = f"{model_name}_{model_arch}_pretrain_{weights_name}{version_suffix}{hash_suffix}.pth" + return os.path.join(DINOV3_BASE_URL, model_dir, model_filename) + + +def _make_dinov3_vit( + *, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + compact_arch_name: str = "vitb", + pos_embed_rope_base: float = 100.0, + pos_embed_rope_min_period: float | None = None, + pos_embed_rope_max_period: float | None = None, + pos_embed_rope_normalize_coords: str = "separate", + pos_embed_rope_shift_coords: float | None = None, + pos_embed_rope_jitter_coords: float | None = None, + pos_embed_rope_rescale_coords: float | None = None, + pos_embed_rope_dtype: str = "fp32", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + ffn_ratio: float = 4.0, + qkv_bias: bool = True, + drop_path_rate: float = 0.0, + layerscale_init: float | None = None, + norm_layer: str = "layernorm", + ffn_layer: str = "mlp", + ffn_bias: bool = True, + proj_bias: bool = True, + n_storage_tokens: int = 0, + mask_k_bias: bool = False, + pretrained: bool = True, + version: Optional[str] = None, + weights: Union[Weights, str] = Weights.LVD1689M, + hash: Optional[str] = None, + check_hash: bool = False, + **kwargs, +): + from ..models.vision_transformer import DinoVisionTransformer + + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + pos_embed_rope_base=pos_embed_rope_base, + pos_embed_rope_min_period=pos_embed_rope_min_period, + pos_embed_rope_max_period=pos_embed_rope_max_period, + pos_embed_rope_normalize_coords=pos_embed_rope_normalize_coords, + pos_embed_rope_shift_coords=pos_embed_rope_shift_coords, + pos_embed_rope_jitter_coords=pos_embed_rope_jitter_coords, + pos_embed_rope_rescale_coords=pos_embed_rope_rescale_coords, + pos_embed_rope_dtype=pos_embed_rope_dtype, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path_rate=drop_path_rate, + layerscale_init=layerscale_init, + norm_layer=norm_layer, + ffn_layer=ffn_layer, + ffn_bias=ffn_bias, + proj_bias=proj_bias, + n_storage_tokens=n_storage_tokens, + mask_k_bias=mask_k_bias, + ) + vit_kwargs.update(**kwargs) + model = 
DinoVisionTransformer(**vit_kwargs) + if pretrained: + if type(weights) is Weights and weights not in {Weights.LVD1689M, Weights.SAT493M}: + raise ValueError(f"Unsupported weights for the backbone: {weights}") + elif type(weights) is Weights: + url = _make_dinov3_vit_model_url( + patch_size=patch_size, + compact_arch_name=compact_arch_name, + version=version, + weights=weights, + hash=hash, + ) + else: + url = convert_path_or_url_to_url(weights) + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + model.load_state_dict(state_dict, strict=True) + else: + model.init_weights() + return model + + +def _make_dinov3_convnext_model_url( + *, + compact_arch_name: str = "convnext_base", + weights: Union[Weights, str] = Weights.LVD1689M, + hash: Optional[str] = None, +): + model_name = "dinov3" + weights_name = weights.value.lower() + hash_suffix = f"-{hash}" if hash else "" + + model_dir = f"{model_name}_{compact_arch_name}" + model_filename = f"{model_name}_{compact_arch_name}_pretrain_{weights_name}{hash_suffix}.pth" + return os.path.join(DINOV3_BASE_URL, model_dir, model_filename) + + +def _make_dinov3_convnext( + in_chans: int = 3, + depths: List[int] = [3, 3, 27, 3], + dims: List[int] = [128, 256, 512, 1024], + compact_arch_name: str = "convnext_base", + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + hash: Optional[str] = None, + **kwargs, +): + from ..models.convnext import ConvNeXt + + model_kwargs = dict( + in_chans=in_chans, + depths=depths, + dims=dims, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ) + model_kwargs.update(**kwargs) + model = ConvNeXt(**model_kwargs) + if pretrained: + if type(weights) is Weights and weights not in {Weights.LVD1689M, Weights.SAT493M}: + raise ValueError(f"Unsupported weights for the backbone: {weights}") + elif type(weights) is Weights: + url = _make_dinov3_convnext_model_url( + compact_arch_name=compact_arch_name, + weights=weights, + hash=hash, + ) + else: + url = convert_path_or_url_to_url(weights) + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + return model + + +def dinov3_vits16( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if "hash" not in kwargs: + kwargs["hash"] = "08c60483" + kwargs["version"] = None + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=384, + depth=12, + num_heads=6, + ffn_ratio=4, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="mlp", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + pretrained=pretrained, + weights=weights, + compact_arch_name="vits", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vits16plus( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if "hash" not in kwargs: + kwargs["hash"] = "4057cbaa" + kwargs["version"] = None + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + 
pos_embed_rope_dtype="fp32", + embed_dim=384, + depth=12, + num_heads=6, + ffn_ratio=6, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="swiglu", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + pretrained=pretrained, + weights=weights, + compact_arch_name="vitsplus", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vitb16( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if "hash" not in kwargs: + kwargs["hash"] = "73cec8be" + kwargs["version"] = None + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=768, + depth=12, + num_heads=12, + ffn_ratio=4, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="mlp", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + pretrained=pretrained, + weights=weights, + compact_arch_name="vitb", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vitl16( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + untie_global_and_local_cls_norm = False + if weights == Weights.LVD1689M: + if "hash" not in kwargs: + kwargs["hash"] = "8aa4cbdd" + elif weights == Weights.SAT493M: + if "hash" not in kwargs: + kwargs["hash"] = "eadcf0ff" + untie_global_and_local_cls_norm = True + elif type(weights) is str: + import re + + pattern = r"-(.{8}).pth" + matches = re.findall(pattern, weights) + if len(matches) != 1: + raise ValueError(f"Unexpected weights specification for the ViT-L backbone: {weights}") + hash = matches[0] + if hash == "eadcf0ff": + untie_global_and_local_cls_norm = True + kwargs["version"] = None + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=1024, + depth=24, + num_heads=16, + ffn_ratio=4, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="mlp", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + untie_global_and_local_cls_norm=untie_global_and_local_cls_norm, + pretrained=pretrained, + weights=weights, + compact_arch_name="vitl", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vitl16plus( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if "hash" not in kwargs: + kwargs["hash"] = "46503df0" + + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=1024, + depth=24, + num_heads=16, + ffn_ratio=6.0, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="swiglu", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + pretrained=pretrained, + weights=weights, + compact_arch_name="vitlplus", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vith16plus( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, 
+ **kwargs, +): + if "hash" not in kwargs: + kwargs["hash"] = "7c1da9a5" + + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=1280, + depth=32, + num_heads=20, + ffn_ratio=6.0, + qkv_bias=True, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="swiglu", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + pretrained=pretrained, + weights=weights, + compact_arch_name="vithplus", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_vit7b16( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if weights == Weights.LVD1689M: + if "hash" not in kwargs: + kwargs["hash"] = "a955f4ea" + elif weights == Weights.SAT493M: + if "hash" not in kwargs: + kwargs["hash"] = "a6675841" + kwargs["version"] = None + untie_global_and_local_cls_norm = True + return _make_dinov3_vit( + img_size=224, + patch_size=16, + in_chans=3, + pos_embed_rope_base=100, + pos_embed_rope_normalize_coords="separate", + pos_embed_rope_rescale_coords=2, + pos_embed_rope_dtype="fp32", + embed_dim=4096, + depth=40, + num_heads=32, + ffn_ratio=3, + qkv_bias=False, + drop_path_rate=0.0, + layerscale_init=1.0e-05, + norm_layer="layernormbf16", + ffn_layer="swiglu64", + ffn_bias=True, + proj_bias=True, + n_storage_tokens=4, + mask_k_bias=True, + untie_global_and_local_cls_norm=untie_global_and_local_cls_norm, + pretrained=pretrained, + weights=weights, + compact_arch_name="vit7b", + check_hash=check_hash, + **kwargs, + ) + + +def dinov3_convnext_tiny( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + **kwargs, +): + _hash_convnext = "21b726bb" + if "hash" not in kwargs: + kwargs["hash"] = _hash_convnext + + from ..models.convnext import convnext_sizes + + size_dict = convnext_sizes["tiny"] + + model = _make_dinov3_convnext( + in_chans=3, + depths=size_dict["depths"], + dims=size_dict["dims"], + compact_arch_name="convnext_tiny", + drop_path_rate=0, + layer_scale_init_value=1e-6, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + if not pretrained: + model.init_weights() + return model + + +def dinov3_convnext_small( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + **kwargs, +): + _hash_convnext = "296db49d" + if "hash" not in kwargs: + kwargs["hash"] = _hash_convnext + + from ..models.convnext import convnext_sizes + + size_dict = convnext_sizes["small"] + + model = _make_dinov3_convnext( + in_chans=3, + depths=size_dict["depths"], + dims=size_dict["dims"], + compact_arch_name="convnext_small", + drop_path_rate=0, + layer_scale_init_value=1e-6, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + if not pretrained: + model.init_weights() + return model + + +def dinov3_convnext_base( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + **kwargs, +): + _hash_convnext = "801f2ba9" + if "hash" not in kwargs: + kwargs["hash"] = _hash_convnext + + from ..models.convnext import convnext_sizes + + size_dict = convnext_sizes["base"] + + model = _make_dinov3_convnext( + in_chans=3, + depths=size_dict["depths"], + dims=size_dict["dims"], + compact_arch_name="convnext_base", + drop_path_rate=0, + layer_scale_init_value=1e-6, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + if not pretrained: + 
model.init_weights() + return model + + +def dinov3_convnext_large( + *, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD1689M, + **kwargs, +): + _hash_convnext = "61fa432d" + if "hash" not in kwargs: + kwargs["hash"] = _hash_convnext + + from ..models.convnext import convnext_sizes + + size_dict = convnext_sizes["large"] + + model = _make_dinov3_convnext( + in_chans=3, + depths=size_dict["depths"], + dims=size_dict["dims"], + compact_arch_name="convnext_large", + drop_path_rate=0, + layer_scale_init_value=1e-6, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + if not pretrained: + model.init_weights() + return model diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/classifiers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..1595f38cf5b5064ac6341bbb2a989617dda4e199 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/classifiers.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import os +from enum import Enum +from typing import Optional + +import torch +import torch.nn as nn + +from .backbones import ( + dinov3_vit7b16, + Weights as BackboneWeights, + convert_path_or_url_to_url, +) + +from .utils import DINOV3_BASE_URL + + +class ClassifierWeights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov3_linear_classification_head( + *, + backbone_name: str = "dinov3_vit7b16", + embed_dim: int = 8192, + pretrained: bool = True, + classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, + check_hash: bool = False, + **kwargs, +): + linear_head = nn.Linear(embed_dim, 1_000) + if pretrained: + if type(classifier_weights) is ClassifierWeights: + assert classifier_weights == ClassifierWeights.IMAGENET1K, ( + f"Unsupported weights for linear classifier: {classifier_weights}" + ) + weights_name = classifier_weights.value.lower() + hash = kwargs["hash"] if "hash" in kwargs else "90d8ed92" + model_filename = f"{backbone_name}_{weights_name}_linear_head-{hash}.pth" + url = os.path.join(DINOV3_BASE_URL, backbone_name, model_filename) + else: + url = convert_path_or_url_to_url(classifier_weights) + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + linear_head.load_state_dict(state_dict, strict=True) + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + + def forward(self, x): + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + linear_input = torch.cat( + [ + cls_token, + patch_tokens.mean(dim=1), + ], + dim=1, + ) + return self.linear_head(linear_input) + + +def _make_dinov3_linear_classifier( + *, + backbone_name: str = "dinov3_vit7b16", + pretrained: bool = True, + classifier_weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + if backbone_name == "dinov3_vit7b16": + backbone = dinov3_vit7b16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash) + else: + raise AssertionError(f"Unsupported backbone: {backbone_name}, linear classifiers 
are provided only for ViT-7b") + embed_dim = backbone.embed_dim + linear_head = _make_dinov3_linear_classification_head( + backbone_name=backbone_name, + embed_dim=2 * embed_dim, + pretrained=pretrained, + classifier_weights=classifier_weights, + **kwargs, + ) + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head) + + +def dinov3_vit7b16_lc( + *, + pretrained: bool = True, + weights: ClassifierWeights | str = ClassifierWeights.IMAGENET1K, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + """ + Linear classifier on top of a DINOv3 ViT-7B/16 backbone pretrained on the LVD-1689M dataset and trained on ImageNet-1k. + """ + return _make_dinov3_linear_classifier( + backbone_name="dinov3_vit7b16", + pretrained=pretrained, + classifier_weights=weights, + backbone_weights=backbone_weights, + check_hash=check_hash, + **kwargs, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/depthers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..25c67e080be4f33fbf9b4a956e27c4be5184821a --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/depthers.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from enum import Enum +from typing import Optional, Tuple + +import torch +from dinov3.eval.depth.models import DecoderConfig, make_depther_from_config + +from .utils import DINOV3_BASE_URL +from .backbones import ( + Weights as BackboneWeights, + dinov3_vitl16, + dinov3_vit7b16, + convert_path_or_url_to_url, +) + + +class DepthWeights(Enum): + SYNTHMIX = "SYNTHMIX" + + +def _get_depth_range(dataset: DepthWeights): + depth_ranges = { + DepthWeights.SYNTHMIX: (0.001, 100.0), + } + return depth_ranges[dataset] + + +_DPT_HEAD_CONFIG_DICT = dict( + use_backbone_norm=True, + use_batchnorm=True, + use_cls_token=False, + n_output_channels=256, + depth_weights=DepthWeights.SYNTHMIX, + backbone_weights=BackboneWeights.LVD1689M, +) + + +def _get_out_layers(backbone_name): + if "vitl" in backbone_name: + return [4, 11, 17, 23] + elif "vit7b" in backbone_name: + return [9, 19, 29, 39] + else: + raise ValueError(f"Unrecognized backbone name {backbone_name}") + + +def _get_post_process_channels(backbone_name): + if "vitl" in backbone_name: + return [1024, 1024, 1024, 1024] + elif "vit7b" in backbone_name: + return [2048, 2048, 2048, 2048] + + +_BACKBONE_DICT = { + "dinov3_vit7b16": dinov3_vit7b16, + "dinov3_vitl16": dinov3_vitl16, +} + + +def _get_depther_config( + backbone_name: str = "dinov3_vit7b16", + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + out_index = _get_out_layers(backbone_name) + post_process_channels = _get_post_process_channels(backbone_name) + + depth_range = depth_range or _get_depth_range(DepthWeights(_DPT_HEAD_CONFIG_DICT["depth_weights"])) + min_depth, max_depth = depth_range + depther_config = DecoderConfig( + min_depth=min_depth, + max_depth=max_depth, + backbone_out_layers=out_index, + n_output_channels=_DPT_HEAD_CONFIG_DICT["n_output_channels"], # type: ignore + use_backbone_norm=bool(_DPT_HEAD_CONFIG_DICT["use_backbone_norm"]), + use_batchnorm=bool(_DPT_HEAD_CONFIG_DICT["use_batchnorm"]), + use_cls_token=bool(_DPT_HEAD_CONFIG_DICT["use_cls_token"]), + type="dpt", + # DPTHead args + head_kwargs=dict( + 
channels=512, + post_process_channels=post_process_channels, + ), + **kwargs, + ) + return depther_config + + +def _make_dinov3_dpt_depther( + *, + backbone_name: str = "dinov3_vit7b16", + pretrained: bool = True, + depther_weights: DepthWeights | str = DepthWeights.SYNTHMIX, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + depth_range: Optional[Tuple[float, float]] = None, + check_hash: bool = False, + autocast_dtype: torch.dtype = torch.float32, + **kwargs, +): + backbone: torch.nn.Module = _BACKBONE_DICT[backbone_name]( + pretrained=pretrained, + weights=backbone_weights, + ) + + depther = make_depther_from_config( + backbone, + config=_get_depther_config(backbone_name, depth_range), + autocast_dtype=autocast_dtype, + ) + + if pretrained: + if isinstance(depther_weights, DepthWeights): + assert depther_weights == DepthWeights.SYNTHMIX, f"Unsupported depther weights {depther_weights}" + weights_name = depther_weights.value.lower() + hash = kwargs["hash"] if "hash" in kwargs else "02040be1" + url = DINOV3_BASE_URL + f"/{backbone_name}/{backbone_name}_{weights_name}_dpt_head-{hash}.pth" + else: + url = convert_path_or_url_to_url(depther_weights) + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + depther.decoder.load_state_dict(checkpoint, strict=True) + return depther + + +def dinov3_vit7b16_dd( + *, + pretrained: bool = True, + weights: DepthWeights | str = DepthWeights.SYNTHMIX, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + autocast_dtype: torch.dtype = torch.float32, + **kwargs, +): + return _make_dinov3_dpt_depther( + backbone_name="dinov3_vit7b16", + pretrained=pretrained, + depther_weights=weights, + backbone_weights=backbone_weights, + check_hash=check_hash, + autocast_dtype=autocast_dtype, + **kwargs, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/detectors.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/detectors.py new file mode 100644 index 0000000000000000000000000000000000000000..f4861af0187570f8a8dc3fd55c4d8d860e58dcc2 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/detectors.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
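
A hypothetical construction sketch for the DPT depther factory above: `pretrained=False` skips the weight download, and `depth_range` flows through `**kwargs` into `_make_dinov3_dpt_depther`. The depth range shown is the SYNTHMIX default from `_get_depth_range`.

```python
from dinov3.hub.depthers import dinov3_vit7b16_dd

# Sketch only: builds the ViT-7B/16 DPT depther without downloading weights.
depther = dinov3_vit7b16_dd(pretrained=False, depth_range=(0.001, 100.0))
depther.eval()
```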
+ +import os +from enum import Enum + +import torch + +from dinov3.eval.detection.config import DetectionHeadConfig +from dinov3.eval.detection.models.detr import PostProcess, build_model +from dinov3.eval.detection.models.position_encoding import PositionEncoding + +from .backbones import Weights as BackboneWeights, dinov3_vit7b16, dinov3_vitl16plus, convert_path_or_url_to_url +from .utils import DINOV3_BASE_URL + + +class DetectionWeights(Enum): + COCO2017 = "COCO2017" + + +class DetectorWithProcessor(torch.nn.Module): + """ + takes as input a list of (3, H, W) normalized image tensors and outputs + a list of dicts with keys "scores", "labels" and "boxes" (format XYXY) + """ + + def __init__(self, detector, postprocessor): + super().__init__() + self.detector = detector + self.postprocessor = postprocessor + + def forward(self, samples: list[torch.Tensor]): + outputs = self.detector(samples) + sizes_tensor = torch.tensor([sample.shape[1:] for sample in samples], device=samples[0].device) # N * [3, H, W] + return self.postprocessor(outputs, target_sizes=sizes_tensor, original_target_sizes=sizes_tensor) + + +def _make_dinov3_detector( + *, + backbone_name: str, + pretrained: bool = True, + detector_weights: str | DetectionWeights, + backbone_weights: str | BackboneWeights, + check_hash: bool = False, + **kwargs, +): + detection_kwargs = dict( + with_box_refine=True, + two_stage=True, + mixed_selection=True, + look_forward_twice=True, + k_one2many=6, + lambda_one2many=1.0, + num_queries_one2one=1500, + num_queries_one2many=1500, + reparam=True, + position_embedding=PositionEncoding.SINE, + num_feature_levels=1, + dec_layers=6, + dim_feedforward=2048, + dropout=0.0, + norm_type="pre_norm", + proposal_feature_levels=4, + proposal_min_size=50, + decoder_type="global_rpe_decomp", + decoder_use_checkpoint=False, + decoder_rpe_hidden_dim=512, + decoder_rpe_type="linear", + layers_to_use=None, + blocks_to_train=None, + add_transformer_encoder=True, + num_encoder_layers=6, + backbone_use_layernorm=False, + num_classes=91, # 91 classes in COCO + aux_loss=True, + topk=1500, + hidden_dim=768, + nheads=8, + ) + config = DetectionHeadConfig(**detection_kwargs) + backbone_class = dict(dinov3_vit7b16=dinov3_vit7b16, dinov3_vitl16plus=dinov3_vitl16plus)[backbone_name] + n_windows_sqrt = dict(dinov3_vit7b16=3, dinov3_vitl16plus=2)[backbone_name] + backbone = backbone_class(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash) + backbone.eval() + + config.n_windows_sqrt = n_windows_sqrt + config.proposal_in_stride = backbone.patch_size + config.proposal_tgt_strides = [int(m * backbone.patch_size) for m in (0.5, 1, 2, 4)] + + if config.layers_to_use is None: + # e.g. 
[2, 5, 8, 11] for a backbone with 12 blocks, similar to depth evaluation + config.layers_to_use = [m * backbone.n_blocks // 4 - 1 for m in range(1, 5)] + + detector = build_model(backbone, config) + if pretrained: + if type(detector_weights) is DetectionWeights and detector_weights == DetectionWeights.COCO2017: + assert detector_weights == DetectionWeights.COCO2017, f"Unsupported detector weights {detector_weights}" + detection_weights_name = detector_weights.value.lower() + hash = kwargs["hash"] if "hash" in kwargs else "b0235ff7" + model_filename = f"{backbone_name}_{detection_weights_name}_detr_head-{hash}.pth" + url = os.path.join(DINOV3_BASE_URL, backbone_name, model_filename) + else: + url = convert_path_or_url_to_url(detector_weights) + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash)["model"] + detector.load_state_dict(state_dict, strict=False) + # Necessary for inference + detector.num_queries = detector.num_queries_one2one + detector.transformer.two_stage_num_proposals = detector.num_queries + + postprocessor = PostProcess(config.topk, config.reparam) + model = DetectorWithProcessor(detector=detector, postprocessor=postprocessor) + return model + + +def dinov3_vit7b16_de( + *, + pretrained: bool = True, + weights: DetectionWeights | str = DetectionWeights.COCO2017, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + **kwargs, +): + return _make_dinov3_detector( + backbone_name="dinov3_vit7b16", + pretrained=pretrained, + detector_weights=weights, + backbone_weights=backbone_weights, + check_hash=check_hash, + **kwargs, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/dinotxt.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/dinotxt.py new file mode 100644 index 0000000000000000000000000000000000000000..cff38f562453a2c5cf2411c8d9c379f8e7b407e9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/dinotxt.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
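+
+# Usage sketch (added for illustration; not part of the upstream file). The
+# entrypoint below pairs a frozen ViT-L/16 vision backbone with a 24-layer,
+# width-1280 text transformer in a shared 2048-d embedding space and returns a
+# (model, tokenizer) pair:
+#
+#   model, tokenizer = dinov3_vitl16_dinotxt_tet1280d20h24l(pretrained=True)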
+
+import math
+from typing import Any, Tuple, Union
+from enum import Enum
+
+import torch
+from torch import nn
+
+from .backbones import dinov3_vitl16, Weights as BackboneWeights, convert_path_or_url_to_url
+from .utils import DINOV3_BASE_URL
+
+
+class DINOTxtWeights(Enum):
+    LVTD2300M = "LVTD2300M"
+
+
+# returns dinotxt model and tokenizer
+def dinov3_vitl16_dinotxt_tet1280d20h24l(
+    *,
+    pretrained: bool = True,
+    weights: Union[DINOTxtWeights, str] = DINOTxtWeights.LVTD2300M,
+    backbone_weights: Union[BackboneWeights, str] = BackboneWeights.LVD1689M,
+    bpe_path_or_url: str = "https://dl.fbaipublicfiles.com/dinov3/thirdparty/bpe_simple_vocab_16e6.txt.gz",
+    check_hash: bool = False,
+) -> Tuple[nn.Module, Any]:
+    from dinov3.eval.text.dinotxt_model import DINOTxt, DINOTxtConfig
+    from dinov3.eval.text.text_transformer import TextTransformer
+    from dinov3.eval.text.tokenizer import get_tokenizer
+
+    dinotxt_config = DINOTxtConfig(
+        embed_dim=2048,
+        vision_model_freeze_backbone=True,
+        vision_model_train_img_size=224,
+        vision_model_use_class_token=True,
+        vision_model_use_patch_tokens=True,
+        vision_model_num_head_blocks=2,
+        vision_model_head_blocks_drop_path=0.3,
+        vision_model_use_linear_projection=False,
+        vision_model_patch_tokens_pooler_type="mean",
+        vision_model_patch_token_layer=1,  # which layer to take patch tokens from
+        # 1 - last layer, 2 - second last layer, etc.
+        text_model_freeze_backbone=False,
+        text_model_num_head_blocks=0,
+        text_model_head_blocks_is_causal=False,
+        text_model_head_blocks_drop_prob=0.0,
+        text_model_tokens_pooler_type="argmax",
+        text_model_use_linear_projection=True,
+        init_logit_scale=math.log(1 / 0.07),
+        init_logit_bias=None,
+        freeze_logit_scale=False,
+    )
+    vision_backbone = dinov3_vitl16(pretrained=pretrained, weights=backbone_weights)
+    text_backbone = TextTransformer(
+        context_length=77,
+        vocab_size=49408,
+        dim=1280,
+        num_heads=20,
+        num_layers=24,
+        ffn_ratio=4,
+        is_causal=True,
+        ls_init_value=None,
+        dropout_prob=0.0,
+    )
+    model = DINOTxt(model_config=dinotxt_config, vision_backbone=vision_backbone, text_backbone=text_backbone)
+    if pretrained:
+        model.visual_model.backbone = vision_backbone
+        model.eval()
+        if type(weights) is DINOTxtWeights and weights == DINOTxtWeights.LVTD2300M:
+            url = f"{DINOV3_BASE_URL}/dinov3_vitl16/dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"
+        elif type(weights) is DINOTxtWeights and weights != DINOTxtWeights.LVTD2300M:
+            raise AssertionError(f"Unsupported weights for DINOTxt: {weights}")
+        else:
+            url = convert_path_or_url_to_url(weights)
+        vision_head_and_text_encoder_state_dict = torch.hub.load_state_dict_from_url(url, check_hash=check_hash)
+        model.load_state_dict(vision_head_and_text_encoder_state_dict, strict=False)
+    else:
+        model.init_weights()
+    return model, get_tokenizer(bpe_path_or_url=bpe_path_or_url)
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/segmentors.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/segmentors.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6b40b7e4362f6938945c97be0c2e0d993d8a6f0
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/segmentors.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
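+
+# Usage sketch (added for illustration; not part of the upstream file). The
+# builder below attaches a Mask2Former-style ("m2f") decoder to a DINOv3
+# backbone and, when pretrained=True, loads the ADE20K head checkpoint:
+#
+#   segmentor = dinov3_vit7b16_ms(pretrained=True).eval()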
+ +import os +from enum import Enum + +import torch +from dinov3.eval.segmentation.models import build_segmentation_decoder + +from .backbones import ( + dinov3_vit7b16, + dinov3_vitl16, + Weights as BackboneWeights, + convert_path_or_url_to_url, +) +from .utils import DINOV3_BASE_URL + + +class SegmentorWeights(Enum): + ADE20K = "ADE20K" + + +def _make_dinov3_m2f_segmentor( + *, + backbone_name: str = "dinov3_vit7b16", + pretrained: bool = True, + segmentor_weights: SegmentorWeights | str = SegmentorWeights.ADE20K, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + autocast_dtype: torch.dtype = torch.bfloat16, + **kwargs, +): + if backbone_name == "dinov3_vit7b16": + backbone_model = dinov3_vit7b16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash) + elif backbone_name == "dinov3_vitl16": + backbone_model = dinov3_vitl16(pretrained=pretrained, weights=backbone_weights, check_hash=check_hash) + else: + raise AssertionError(f"No pretrained segmentation checkpoint available for {backbone_name}") + + hidden_dim = 2048 if "hidden_dim" not in kwargs else kwargs["hidden_dim"] + segmentor = build_segmentation_decoder( + backbone_model=backbone_model, + decoder_type="m2f", + hidden_dim=hidden_dim, + autocast_dtype=autocast_dtype, + ) + if pretrained: + if type(segmentor_weights) is SegmentorWeights: + assert segmentor_weights == SegmentorWeights.ADE20K, f"Unsupported weights for segmentor: {segmentor_weights}" + segmentor_weights_name = segmentor_weights.value.lower() + hash = kwargs["hash"] if "hash" in kwargs else "bf307cb1" + model_filename = f"{backbone_name}_{segmentor_weights_name}_m2f_head-{hash}.pth" + url = os.path.join(DINOV3_BASE_URL, backbone_name, model_filename) + else: + url = convert_path_or_url_to_url(segmentor_weights) + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + missing_keys, unexpected_keys = segmentor.load_state_dict(state_dict, strict=False) + assert len([k for k in missing_keys if "backbone" not in k]) == 0 + assert len(unexpected_keys) == 0 + + return segmentor + + +def dinov3_vit7b16_ms( + *, + pretrained: bool = True, + weights: SegmentorWeights | str = SegmentorWeights.ADE20K, + backbone_weights: BackboneWeights | str = BackboneWeights.LVD1689M, + check_hash: bool = False, + autocast_dtype: torch.dtype = torch.bfloat16, + **kwargs, +): + return _make_dinov3_m2f_segmentor( + backbone_name="dinov3_vit7b16", + pretrained=pretrained, + segmentor_weights=weights, + backbone_weights=backbone_weights, + check_hash=check_hash, + autocast_dtype=autocast_dtype, + **kwargs, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3de79fcf1f032fe3d5d382c27c089cec5c1a0cb9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/hub/utils.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
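+
+# Note (added for clarity): the task heads above (dpt, detr, m2f) compose their
+# checkpoint URLs from this base as
+# {DINOV3_BASE_URL}/{backbone_name}/{backbone_name}_{weights}_{head}-{hash}.pth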
+ +DINOV3_BASE_URL = "https://dl.fbaipublicfiles.com/dinov3" diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b82c261aaba6f0b7b871662f3549062e69928f8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .attention import CausalSelfAttention, LinearKMaskedBias, SelfAttention +from .block import CausalSelfAttentionBlock, SelfAttentionBlock +from .ffn_layers import Mlp, SwiGLUFFN +from .fp8_linear import convert_linears_to_fp8 +from .layer_scale import LayerScale +from .patch_embed import PatchEmbed +from .rms_norm import RMSNorm +from .rope_position_encoding import RopePositionEmbedding diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/attention.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8335b0176fa7ea15808a558fb2c802e06a9bbf9d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/attention.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from dinov3.utils import cat_keep_shapes, uncat_with_shapes +from torch import Tensor, nn + + +# RoPE-related functions: +def rope_rotate_half(x: Tensor) -> Tensor: + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + +def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2] + return (x * cos) + (rope_rotate_half(x) * sin) + + +class LinearKMaskedBias(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + o = self.out_features + assert o % 3 == 0 + if self.bias is not None: + self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan)) + + def forward(self, input: Tensor) -> Tensor: + masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None + return F.linear(input, self.weight, masked_bias) + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + mask_k_bias: bool = False, + device=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear + self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device) + self.proj_drop = nn.Dropout(proj_drop) + + def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: + # All operations will use the dtype of rope, the output is cast back to the 
dtype of q and k + q_dtype = q.dtype + k_dtype = k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q = q.to(dtype=rope_dtype) + k = k.to(dtype=rope_dtype) + N = q.shape[-2] + prefix = N - sin.shape[-2] + assert prefix >= 0 + q_prefix = q[:, :, :prefix, :] + q = rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head] + k_prefix = k[:, :, :prefix, :] + k = rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head] + q = q.to(dtype=q_dtype) + k = k.to(dtype=k_dtype) + return q, k + + def forward(self, x: Tensor, attn_bias=None, rope: Tensor = None) -> Tensor: + qkv = self.qkv(x) + attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope) + x = self.proj(attn_v) + x = self.proj_drop(x) + return x + + def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]: + assert len(x_list) == len(rope_list) # should be enforced by the Block + x_flat, shapes, num_tokens = cat_keep_shapes(x_list) + qkv_flat = self.qkv(x_flat) + qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens) + att_out = [] + for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)): + att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)) + x_flat, shapes, num_tokens = cat_keep_shapes(att_out) + x_flat = self.proj(x_flat) + return uncat_with_shapes(x_flat, shapes, num_tokens) + + def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor: + assert attn_bias is None + B, N, _ = qkv.shape + C = self.qkv.in_features + + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = [t.transpose(1, 2) for t in [q, k, v]] + if rope is not None: + q, k = self.apply_rope(q, k, rope) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = x.transpose(1, 2) + return x.reshape([B, N, C]) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def init_weights( + self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0 + ) -> None: + init_attn_std = init_attn_std or (self.dim**-0.5) + init_proj_std = init_proj_std or init_attn_std * factor + nn.init.normal_(self.qkv.weight, std=init_attn_std) + nn.init.normal_(self.proj.weight, std=init_proj_std) + if self.qkv.bias is not None: + nn.init.zeros_(self.qkv.bias) + if self.proj.bias is not None: + nn.init.zeros_(self.proj.bias) + + def forward(self, x: Tensor, is_causal: bool = True) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = [t.transpose(1, 2) for t in [q, k, v]] + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal + ) + x = x.transpose(1, 2).contiguous().view(B, N, C) + x = self.proj_drop(self.proj(x)) + return x diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/block.py 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..2051a1a607ab8c10c2476fa9d019225f4ba6ab5b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/block.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from typing import Callable, List, Optional + +import torch +from torch import Tensor, nn + +from dinov3.utils import cat_keep_shapes, uncat_with_shapes + +from .attention import CausalSelfAttention, SelfAttention +from .ffn_layers import Mlp +from .layer_scale import LayerScale # , DropPath + +torch._dynamo.config.automatic_dynamic_shapes = False +torch._dynamo.config.accumulated_cache_size_limit = 1024 + + +class SelfAttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + ffn_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = SelfAttention, + ffn_layer: Callable[..., nn.Module] = Mlp, + mask_k_bias: bool = False, + device=None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + mask_k_bias=mask_k_bias, + device=device, + ) + self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * ffn_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + device=device, + ) + self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() + + self.sample_drop_ratio = drop_path + + @staticmethod + def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None: + if rope is None: + return None + + sin, cos = rope + assert sin.ndim == cos.ndim + if sin.ndim == 4: + # If the rope embedding has a batch dimension (is different for each batch element), index into it + return sin[indices], cos[indices] # [batch, heads, patches, embed_dim] + else: + # No batch dimension, do not index + return sin, cos # [heads, patches, embed_dim] or [patches, embed_dim] + + def _forward(self, x: Tensor, rope=None) -> Tensor: + """ + This is the reference implementation for a single tensor, matching what is done below for a list. + We call the list op on [x] instead of this function. 
+ """ + b, _, _ = x.shape + sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1) + residual_scale_factor = b / sample_subset_size + + if self.training and self.sample_drop_ratio > 0.0: + indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size] + + x_subset_1 = x[indices_1] + rope_subset = self._maybe_index_rope(rope, indices_1) + residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset) + + x_attn = torch.index_add( + x, + dim=0, + source=self.ls1(residual_1), + index=indices_1, + alpha=residual_scale_factor, + ) + + indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size] + + x_subset_2 = x_attn[indices_2] + residual_2 = self.mlp(self.norm2(x_subset_2)) + + x_ffn = torch.index_add( + x_attn, + dim=0, + source=self.ls2(residual_2), + index=indices_2, + alpha=residual_scale_factor, + ) + else: + x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) + + return x_ffn + + def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]: + """ + This list operator concatenates the tokens from the list of inputs together to save + on the elementwise operations. Torch-compile memory-planning allows hiding the overhead + related to concat ops. + """ + b_list = [x.shape[0] for x in x_list] + sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list] + residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)] + + if self.training and self.sample_drop_ratio > 0.0: + indices_1_list = [ + (torch.randperm(b, device=x.device))[:sample_subset_size] + for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes) + ] + x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)] + + if rope_list is not None: + rope_subset_list = [ + self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list) + ] + else: + rope_subset_list = rope_list + + flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list) + norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens) + residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list) + + x_attn_list = [ + torch.index_add( + x, + dim=0, + source=self.ls1(residual_1), + index=indices_1, + alpha=residual_scale_factor, + ) + for x, residual_1, indices_1, residual_scale_factor in zip( + x_list, residual_1_list, indices_1_list, residual_scale_factors + ) + ] + + indices_2_list = [ + (torch.randperm(b, device=x.device))[:sample_subset_size] + for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes) + ] + x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)] + flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list) + norm2_flat = self.norm2(flattened) + norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens) + + residual_2_list = self.mlp.forward_list(norm2_list) + + x_ffn = [ + torch.index_add( + x_attn, + dim=0, + source=self.ls2(residual_2), + index=indices_2, + alpha=residual_scale_factor, + ) + for x_attn, residual_2, indices_2, residual_scale_factor in zip( + x_attn_list, residual_2_list, indices_2_list, residual_scale_factors + ) + ] + else: + x_out = [] + for x, rope in zip(x_list, rope_list): + x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) + x_out.append(x_ffn) + x_ffn = x_out + + return x_ffn + + def forward(self, x_or_x_list, 
rope_or_rope_list=None) -> List[Tensor]: + if isinstance(x_or_x_list, Tensor): + # for reference: + # return self._forward(x_or_x_list, rope=rope_or_rope_list) + # in order to match implementations we call the list op: + return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0] + elif isinstance(x_or_x_list, list): + if rope_or_rope_list is None: + rope_or_rope_list = [None for x in x_or_x_list] + # return [self._forward(x, rope=rope) for x, rope in zip(x_or_x_list, rope_or_rope_list)] + return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list) + else: + raise AssertionError + + +class CausalSelfAttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + ffn_ratio: float = 4.0, + ls_init_value: Optional[float] = None, + is_causal: bool = True, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + dropout_prob: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.is_causal = is_causal + self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() + self.attention_norm = norm_layer(dim) + self.attention = CausalSelfAttention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob) + + self.ffn_norm = norm_layer(dim) + ffn_hidden_dim = int(dim * ffn_ratio) + self.feed_forward = Mlp( + in_features=dim, + hidden_features=ffn_hidden_dim, + drop=dropout_prob, + act_layer=act_layer, + ) + + self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() + + def init_weights( + self, + init_attn_std: float | None = None, + init_proj_std: float | None = None, + init_fc_std: float | None = None, + factor: float = 1.0, + ) -> None: + init_attn_std = init_attn_std or (self.dim**-0.5) + init_proj_std = init_proj_std or init_attn_std * factor + init_fc_std = init_fc_std or (2 * self.dim) ** -0.5 + self.attention.init_weights(init_attn_std, init_proj_std) + self.attention_norm.reset_parameters() + nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std) + nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std) + self.ffn_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + ): + + x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal)) + x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn))) + return x_ffn diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/dino_head.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bb71f35fc7ecf15e31963eb76d21626ccdea9b90 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/dino_head.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
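+
+# Shape sketch (added for illustration; argument values are examples only). The
+# head is an MLP bottleneck followed by an L2 normalization and a bias-free
+# linear layer over the prototypes:
+#
+#   head = DINOHead(in_dim=1024, out_dim=65536)
+#   logits = head(torch.rand(8, 1024))  # -> shape [8, 65536]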
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False) + + def init_weights(self) -> None: + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, no_last_layer=False, only_last_layer=False): + if not only_last_layer: + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + if not no_last_layer: + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/ffn_layers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/ffn_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..749e6be2cd2cd8040bb3233a61a9a306eaea6ca8 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/ffn_layers.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
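+
+# Sizing note (derived from SwiGLUFFN below): the gated hidden width is
+# int(hidden_features * 2 / 3), rounded up to a multiple of align_to; e.g.
+# hidden_features=4096 gives d=2730, padded to swiglu_hidden_features=2736.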
+ +from typing import Callable, List, Optional + +import torch.nn.functional as F +from torch import Tensor, nn + +from dinov3.utils import cat_keep_shapes, uncat_with_shapes + + +class ListForwardMixin(object): + def forward(self, x: Tensor): + raise NotImplementedError + + def forward_list(self, x_list: List[Tensor]) -> List[Tensor]: + x_flat, shapes, num_tokens = cat_keep_shapes(x_list) + x_flat = self.forward(x_flat) + return uncat_with_shapes(x_flat, shapes, num_tokens) + + +class Mlp(nn.Module, ListForwardMixin): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + device=None, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module, ListForwardMixin): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Optional[Callable[..., nn.Module]] = None, + drop: float = 0.0, + bias: bool = True, + align_to: int = 8, + device=None, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + d = int(hidden_features * 2 / 3) + swiglu_hidden_features = d + (-d % align_to) + self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) + self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) + self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/fp8_linear.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/fp8_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd88a179ef7b6d638dc211a1e592ac8a726b4e9 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/fp8_linear.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import re + +import torch + +from dinov3.layers.attention import LinearKMaskedBias +from dinov3.utils import named_replace + +# avoid division by zero when calculating scale +EPS = 1e-12 + + +def scale(t, amax_t): + max_v = torch.finfo(torch.float8_e4m3fn).max + scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v + t_fp8 = (t / scale_t).to(torch.float8_e4m3fn) + return t_fp8, scale_t + + +def matmul(first, amax_first, second_t, amax_second_t, bias): + first_fp8, scale_first = scale(first, amax_first) + second_t_fp8, scale_second_t = scale(second_t, amax_second_t) + # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite + # slow. Hence we fall back to an "unscaled" matmul, which uses cuBLAS, and + # apply the scale manually afterwards. 
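+    # In effect: out ~= (a / s_a) @ (b / s_b)^T is computed in fp8 with unit
+    # scales, and the factor s_a * s_b^T is multiplied back in below, matching
+    # the scaled matmul up to fp8 rounding error.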
+ output = torch._scaled_mm( + first_fp8, + second_t_fp8.t(), + scale_a=scale_first.new_ones((1, 1)), + scale_b=scale_second_t.t().new_ones((1, 1)), + bias=None, + out_dtype=torch.bfloat16, + use_fast_accum=False, + ) + output = (output * scale_first * scale_second_t.t()).to(torch.bfloat16) + if bias is not None: + output = output + bias + return output + + +@torch.compiler.allow_in_graph +class Fp8LinearFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b_t, bias): + amax_a = a.abs().amax(dim=-1, keepdim=True) + amax_b_t = b_t.abs().amax(dim=-1, keepdim=True) + out = matmul(a, amax_a, b_t, amax_b_t, bias) + + ctx.a_requires_grad = a.requires_grad + ctx.b_requires_grad = b_t.requires_grad + ctx.bias_requires_grad = bias.requires_grad if bias is not None else False + + ctx.save_for_backward(a, b_t, amax_b_t.max()) + + return out + + @staticmethod + def backward(ctx, grad_out): + a, b_t, amax_b = ctx.saved_tensors + + if ctx.a_requires_grad: + b = b_t.t().contiguous() + amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True) + amax_b = amax_b.repeat(b.shape[0], 1) + grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None) + else: + grad_a = None + if ctx.b_requires_grad: + grad_b = grad_out.t() @ a + else: + grad_b = None + if ctx.bias_requires_grad: + grad_bias = grad_out.sum(dim=0) + else: + grad_bias = None + + return grad_a, grad_b, grad_bias + + +class Fp8Linear(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias) + out = out.unflatten(0, input.shape[:-1]) + return out + + +class Fp8LinearKMaskedBias(LinearKMaskedBias): + def forward(self, input: torch.Tensor) -> torch.Tensor: + masked_bias = self.bias * self.bias_mask if self.bias is not None else None + out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, masked_bias) + out = out.unflatten(0, input.shape[:-1]) + return out + + +def convert_linears_to_fp8(root_module: torch.nn.Module, *, filter: str) -> torch.nn.Module: + filter_re = re.compile(filter) + total_count = 0 + + def replace(module: torch.nn.Module, name: str) -> torch.nn.Module: + nonlocal total_count + if not isinstance(module, torch.nn.Linear) or not filter_re.search(name): + return module + if type(module) == torch.nn.Linear: + new_cls = Fp8Linear + elif type(module) == LinearKMaskedBias: + new_cls = Fp8LinearKMaskedBias + else: + assert False, str(type(module)) + if module.in_features % 64 != 0 or module.out_features % 64 != 0: + # This is not a strict requirement, but H100 TensorCores for fp8 + # operate on tiles of 64 elements anyways, and Inductor sometimes + # pads inner dims to become multiples of 64. Also, if one day we + # switch back to cuBLAS, it artificially requires dims to be + # multiples of 16. 
+ raise RuntimeError( + "fp8 requires all dimensions to be multiples of 64 " "(consider using ffn_layer=swiglu64 or higher)" + ) + new_module = new_cls( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + device=module.weight.device, + ) + new_module.weight = module.weight + new_module.bias = module.bias + total_count += 1 + return new_module + + out = named_replace(replace, root_module) + assert total_count > 0, "fp8: no layer found to convert" + # Force re-compile everything + torch._dynamo.reset_code_caches() + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + return out diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/layer_scale.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..0b72b7c64c9cc38fd4e3db63e9c90f0158caa36c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/layer_scale.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from typing import Union + +import torch +from torch import Tensor, nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + device=None, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(torch.empty(dim, device=device)) + self.init_values = init_values + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/patch_embed.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..760343f14cd0c1c2bbb2c70d43f82eb0bb1fddf4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math +from typing import Callable, Tuple, Union + +from torch import Tensor, nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
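+        flatten_embedding: Whether to flatten the patch grid into (B, N, D);
+            if False, the output keeps its spatial layout as (B, H', W', D).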
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Callable | None = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + # patch_H, patch_W = self.patch_size + # assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + # assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + def reset_parameters(self): + k = 1 / (self.in_chans * (self.patch_size[0] ** 2)) + nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k)) + if self.proj.bias is not None: + nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k)) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rms_norm.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0a89c47c5e71687cadbf47fef567b2c6a2b3b4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rms_norm.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +from torch import Tensor, nn + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def reset_parameters(self) -> None: + nn.init.constant_(self.weight, 1) + + def _norm(self, x: Tensor) -> Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rope_position_encoding.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rope_position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..2635d09e7732fb4c146f57d3aef19c2d3e5668ec --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/rope_position_encoding.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
+ +import math +from typing import Literal + +import numpy as np +import torch +from torch import Tensor, nn + + +# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights +# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`. +class RopePositionEmbedding(nn.Module): + def __init__( + self, + embed_dim: int, + *, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + raise ValueError("Either `base` or `min_period`+`max_period` must be provided.") + + D_head = embed_dim // num_heads + self.base = base + self.min_period = min_period + self.max_period = max_period + self.D_head = D_head + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + + # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher + self.dtype = dtype # Don't rely on self.periods.dtype + self.register_buffer( + "periods", + torch.empty(D_head // 4, device=device, dtype=dtype), + persistent=True, + ) + self._init_weights() + + def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + + # Prepare coords in range [-1, +1] + if self.normalize_coords == "max": + max_HW = max(H, W) + coords_h = torch.arange(0.5, H, **dd) / max_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / max_HW # [W] + elif self.normalize_coords == "min": + min_HW = min(H, W) + coords_h = torch.arange(0.5, H, **dd) / min_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / min_HW # [W] + elif self.normalize_coords == "separate": + coords_h = torch.arange(0.5, H, **dd) / H # [H] + coords_w = torch.arange(0.5, W, **dd) / W # [W] + else: + raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2] + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] + + # Shift coords by adding a uniform value in [-shift, shift] + if self.training and self.shift_coords is not None: + shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords) + coords += shift_hw[None, :] + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if self.training and self.jitter_coords is not None: + jitter_max = np.log(self.jitter_coords) + jitter_min = -jitter_max + jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() + coords *= jitter_hw[None, :] + + # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] + if self.training and self.rescale_coords is not None: + rescale_max = np.log(self.rescale_coords) + rescale_min = -rescale_max + rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() + coords *= rescale_hw + + # Prepare angles and 
sin/cos + angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [HW, 2, D//4] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) # [HW, D] + sin = torch.sin(angles) # [HW, D] + + return (sin, cos) # 2 * [HW, D] + + def _init_weights(self): + device = self.periods.device + dtype = self.dtype + if self.base is not None: + periods = self.base ** ( + 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) + ) # [D//4] + else: + base = self.max_period / self.min_period + exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) # [D//4] range [0, 1] + periods = base**exponents # range [1, max_period / min_period] + periods = periods / base # range [min_period / max_period, 1] + periods = periods * self.max_period # range [min_period, max_period] + self.periods.data = periods diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/sparse_linear.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/sparse_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb9e103182d10ed94a2dfa13e58060516ff3dbe --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/layers/sparse_linear.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import xformers.ops as xops + +from dinov3.utils import named_apply, named_replace + +logger = logging.getLogger("dinov3") + + +class LinearW24(torch.nn.Linear): + ALGO = "largest_abs_values_greedy" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.sparsity_enabled = False + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.sparsity_enabled: + return super().forward(input) + + input_shape = input.shape + input = input.flatten(end_dim=-2) + dim0 = input.shape[0] + if dim0 % 8 != 0: + # NOTE: This should be torch-compiled away + input = F.pad(input, [0, 0, 0, -dim0 % 8]) + w_sparse = xops.sparsify24( + self.weight, + algo=self.ALGO, + gradient="ste", + backend="cusparselt", + ) + return F.linear(input, w_sparse, self.bias,)[ + :dim0 + ].unflatten(dim=0, sizes=input_shape[:-1]) + + +def replace_linears_with_sparse_linear(root_module: nn.Module, *, filter_fn: Callable[[str], bool]) -> nn.Module: + total_count = 0 + + def replace(module: nn.Module, name: str) -> nn.Module: + nonlocal total_count + if not isinstance(module, nn.Linear) or not filter_fn(name): + return module + assert type(module) == nn.Linear, "Subtypes not supported" + new_module = LinearW24( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + device=module.weight.device, + ) + new_module.weight = module.weight + new_module.bias = module.bias + total_count += 1 + return new_module + + out = named_replace(replace, root_module) + assert total_count > 0, "2:4 sparsity: no layer found to sparsify" + return out + + +def update_24sparsity(root_module: nn.Module, enabled: bool) -> int: + num_modified = 0 + + def maybe_apply_sparsity(module: nn.Module, name: str) -> nn.Module: + nonlocal num_modified + if not isinstance(module, LinearW24): + return module + num_modified += 1 + module.sparsity_enabled = enabled + logger.info(f"- 
{'' if module.sparsity_enabled else 'de'}sparsifying {name}") + return module + + named_apply(maybe_apply_sparsity, root_module) + # Force re-compile everything + torch._dynamo.reset_code_caches() + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + return num_modified diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..118925070ffda3ae73b6c15b0549a772f07d063d --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/__init__.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import functools +import logging +import os +import sys +from typing import Optional + +from termcolor import colored + +from dinov3.distributed import TorchDistributedEnvironment + +from dinov3.logging.helpers import MetricLogger, SmoothedValue + +_LEVEL_COLORED_KWARGS = { + logging.DEBUG: {"color": "green", "attrs": ["bold"]}, + logging.INFO: {"color": "green"}, + logging.WARNING: {"color": "yellow"}, + logging.ERROR: {"color": "red"}, + logging.CRITICAL: {"color": "red", "attrs": ["bold"]}, +} + + +class _LevelColoredFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super().formatMessage(record) + + colored_kwargs = _LEVEL_COLORED_KWARGS.get(record.levelno) + if colored_kwargs is None: + return log + + msg = record.msg % record.args if record.msg == "%s" else record.msg + index = log.rfind(msg, len(log) - len(msg)) + # Can happen in some cases, like if the msg contains `%s` which + # have been replaced in `formatMessage`. Fallback to no colors + if index == -1: + return log + prefix = log[:index] + prefix = colored(prefix, **colored_kwargs) + return prefix + msg + + +# So that calling _configure_logger multiple times won't add many handlers +@functools.lru_cache() +def _configure_logger( + name: Optional[str] = None, + *, + level: int = logging.DEBUG, + output: Optional[str] = None, + color: bool = True, + log_to_stdout_only_in_main_process: bool = True, +): + """ + Configure a logger. + + Adapted from Detectron2. + + Args: + name: The name of the logger to configure. + level: The logging level to use. + output: A file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + color: Whether stdout output should be colored (ignored if stdout is not a terminal). + log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout, + regardless of this flag. If False, other ranks will also log to their stdout. + + Returns: + The configured logger. 
+ """ + + # Disable colored output if the stdout is not a terminal + color = color and os.isatty(sys.stdout.fileno()) + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + + # Loosely match Google glog format: + # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg + # but use a shorter timestamp and include the logger name: + # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg + fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " + fmt_message = "%(message)s" + fmt = fmt_prefix + fmt_message + datefmt = "%Y%m%d %H:%M:%S" + plain_formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + torch_env = TorchDistributedEnvironment() + + # rank 0 always logs to stdout, for other ranks it depends on log_to_stdout_only_in_main_process + should_log_to_stdout = torch_env.is_main_process or not log_to_stdout_only_in_main_process + if should_log_to_stdout: + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(logging.DEBUG) + + formatter: logging.Formatter + if color: + formatter = _LevelColoredFormatter( + fmt=fmt, + datefmt=datefmt, + ) + else: + formatter = plain_formatter + + handler.setFormatter(formatter) + logger.addHandler(handler) + + # file logging for all workers + if output: + if os.path.splitext(output)[-1] in (".txt", ".log"): + filename = output + else: + filename = os.path.join(output, "logs", "log.txt") + + if not torch_env.is_main_process: + filename = filename + f".rank{torch_env.rank}" + + os.makedirs(os.path.dirname(filename), exist_ok=True) + + handler = logging.StreamHandler(open(filename, "a")) + handler.setLevel(logging.DEBUG) + handler.setFormatter(plain_formatter) + logger.addHandler(handler) + + logger.debug(f"PyTorch distributed environment: {torch_env}") + return logger + + +def setup_logging( + output: Optional[str] = None, + *, + name: Optional[str] = None, + level: int = logging.DEBUG, + color: bool = True, + capture_warnings: bool = True, + log_to_stdout_only_in_main_process: bool = True, +) -> None: + """ + Setup logging. + + Args: + output: A file name or a directory to save log files. If None, log + files will not be saved. If output ends with ".txt" or ".log", it + is assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + color: Whether stdout output should be colored (ignored if stdout is not a terminal). + capture_warnings: Whether warnings should be captured as logs. + log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout, + regardless of this flag. If False, other ranks will also log to their stdout. 
+ """ + logging.captureWarnings(capture_warnings) + # Ensure the path is canonical to properly use the cache of `_configure_logger` + output = output if output is None else os.path.realpath(output) + _configure_logger( + name, + level=level, + output=output, + color=color, + log_to_stdout_only_in_main_process=log_to_stdout_only_in_main_process, + ) + + +def cleanup_logging(*, name: Optional[str] = None) -> None: + logger = logging.getLogger(name) + for handler in logger.handlers: + handler.flush() + handler.close() + logger.removeHandler(handler) + + # clears the cache of `_configure_logger` to allow re-initialization + _configure_logger.cache_clear() diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/helpers.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b01e1f8352a4235ef1707ebb0eddbc6d80d729ef --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/logging/helpers.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import datetime +import json +import logging +import time +from collections import defaultdict, deque + +import torch + +import dinov3.distributed as distributed + +logger = logging.getLogger("dinov3") + + +class MetricLogger(object): + def __init__(self, delimiter="\t", output_file=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.output_file = output_file + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def dump_in_output_file(self, iteration, iter_time, data_time): + if self.output_file is None or not distributed.is_main_process(): + return + dict_to_dump = dict( + iteration=iteration, + iter_time=iter_time, + data_time=data_time, + ) + dict_to_dump.update({k: v.median for k, v in self.meters.items()}) + with open(self.output_file, "a") as f: + f.write(json.dumps(dict_to_dump) + "\n") + pass + + def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): + i = start_iteration + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + + if n_iterations is None: + n_iterations = len(iterable) + + space_fmt = ":" + str(len(str(n_iterations))) + "d" + + log_list = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_list += ["mem: {current_memory:.0f}"] + log_list += ["(max mem: {max_memory:.0f})"] + + log_msg = self.delimiter.join(log_list) + MB = 1024.0 * 1024.0 + for obj in iterable: + if i >= 
n_iterations: + break + + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == n_iterations - 1: + self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) + eta_seconds = iter_time.global_avg * (n_iterations - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + current_memory=torch.cuda.memory_allocated() / MB, + max_memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + s_it = total_time / n_iterations if n_iterations > 0 else 0 + logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, s_it)) + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, num=1): + self.deque.append(value) + self.count += num + self.total += value * num + + def synchronize_between_processes(self): + """ + Distributed synchronization of the metric + Warning: does not synchronize the deque! + """ + if not distributed.is_enabled(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() # returns float("nan") when d is empty + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() # returns float("nan") when d is empty + + @property + def global_avg(self): + if self.count == 0: + return float("nan") + return self.total / self.count + + @property + def max(self): + if len(self.deque) == 0: + return float("nan") + return max(self.deque) + + @property + def value(self): + if len(self.deque) == 0: + return float("nan") + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f52a54a7da4f5a48c22b71008136f4e6f6c3a92 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
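+
+# Losses used in DINOv3 pretraining: the DINO CLS-token loss, the Gram
+# (feature-similarity) loss, the iBOT masked-patch loss, and the KoLeo
+# spreading regularizers.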
+
+from .dino_clstoken_loss import DINOLoss
+from .gram_loss import GramLoss
+from .ibot_patch_loss import iBOTPatchLoss
+from .koleo_loss import KoLeoLoss, KoLeoLossDistributed
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/dino_clstoken_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/dino_clstoken_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..40bc68d91b67e42090d6d1ff61569bcdab0f0a24
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/dino_clstoken_loss.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn
+
+from dinov3.distributed import get_process_subgroup, get_subgroup_size
+
+
+class DINOLoss(nn.Module):
+    def __init__(
+        self,
+        out_dim,
+        student_temp=0.1,
+        center_momentum=0.9,
+    ):
+        super().__init__()
+        self.student_temp = student_temp
+        self.center_momentum = center_momentum
+        self.register_buffer("center", torch.full((1, out_dim), math.nan))
+        self.updated = True
+        self.reduce_handle = None
+        self.len_teacher_output = None
+        self.async_batch_center = None
+
+    def init_weights(self) -> None:
+        self.center.zero_()
+
+    @torch.no_grad()
+    def softmax_center_teacher(self, teacher_output, teacher_temp, update_centers=True):
+        if update_centers:
+            self.apply_center_update()
+        # teacher centering and sharpening
+        return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)
+
+    @torch.no_grad()
+    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
+        # teacher_output: [batch, prototypes]
+        teacher_output = teacher_output.float()
+        world_size = get_subgroup_size() if dist.is_initialized() else 1
+        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
+        B = Q.shape[1] * world_size  # number of samples to assign
+        K = Q.shape[0]  # how many prototypes
+
+        # make the matrix sum to 1
+        sum_Q = torch.sum(Q)
+        if dist.is_initialized():
+            dist.all_reduce(sum_Q, group=get_process_subgroup())
+        Q /= sum_Q
+
+        for _ in range(n_iterations):
+            # normalize each row: total weight per prototype must be 1/K
+            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+            if dist.is_initialized():
+                dist.all_reduce(sum_of_rows, group=get_process_subgroup())
+            Q /= sum_of_rows
+            Q /= K
+
+            # normalize each column: total weight per sample must be 1/B
+            Q /= torch.sum(Q, dim=0, keepdim=True)
+            Q /= B
+
+        Q *= B  # the columns must sum to 1 so that Q is an assignment
+        return Q.t()
+
+    def forward(self, student_logits, teacher_probs, ignore_diagonal=False):
+        """
+        Cross-entropy between softmax outputs of the teacher and student networks.
+ student_logits: [student crops, batch, prototypes] + teacher_probs: [teacher crops, batch, prototypes] must sum to 1 over the last dim + + loss = 0 + count = 0 + for each sample `b` in the batch: + for each student crop `s` of this sample: + for each teacher crop `t` of this sample: + if ignore_diagonal and s == t: + continue + loss += cross_entropy(softmax(student_logits[s, b] / student_temp), teacher_probs[t, b]) + count += 1 + return loss / count + """ + student_crops, B, K = student_logits.shape + teacher_crops, _, _ = teacher_probs.shape + student_logits = F.log_softmax(student_logits.float() / self.student_temp, dim=-1) + if not ignore_diagonal: + loss = -torch.einsum("s b k, t b k -> ", student_logits, teacher_probs) + return loss / (B * student_crops * teacher_crops) + else: + loss = -torch.einsum("s b k, t b k -> s t", student_logits, teacher_probs) + min_st = min(student_crops, teacher_crops) + loss = torch.diagonal_scatter(loss, loss.new_zeros(min_st)) + return loss.sum() / (B * student_crops * teacher_crops - B * min_st) + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True, group=get_process_subgroup()) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = get_subgroup_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/gram_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/gram_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2f0ac8726a137b0e963aba2a0f6bb1d19c3e6b --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/gram_loss.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GramLoss(nn.Module): + """Implementation of the gram loss""" + + def __init__( + self, + apply_norm=True, + img_level=True, + remove_neg=True, + remove_only_teacher_neg=False, + ): + super().__init__() + + # Loss + self.mse_loss = torch.nn.MSELoss() + + # Parameters + self.apply_norm = apply_norm + self.remove_neg = remove_neg + self.remove_only_teacher_neg = remove_only_teacher_neg + + if self.remove_neg or self.remove_only_teacher_neg: + assert self.remove_neg != self.remove_only_teacher_neg + + def forward(self, output_feats, target_feats, img_level=True): + """Compute the MSE loss between the gram matrix of the input and target features. 
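+        Here the "gram matrix" is the patch-to-patch similarity matrix of a feature set
+        (cosine similarities when `apply_norm` is True); the MSE matches the student's
+        similarity structure to the teacher's.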
+ + Args: + output_feats: Pytorch tensor (B, N, dim) or (B*N, dim) if img_level == False + target_feats: Pytorch tensor (B, N, dim) or (B*N, dim) if img_level == False + img_level: bool, if true gram computed at the image level only else over the entire batch + Returns: + loss: scalar + """ + + # Dimensions of the tensor should be (B, N, dim) + if img_level: + assert len(target_feats.shape) == 3 and len(output_feats.shape) == 3 + + # Float casting + output_feats = output_feats.float() + target_feats = target_feats.float() + + # SSL correlation + if self.apply_norm: + target_feats = F.normalize(target_feats, dim=-1) + + if not img_level and len(target_feats.shape) == 3: + # Flatten (B, N, D) into (B*N, D) + target_feats = target_feats.flatten(0, 1) + + # Compute similarities + target_sim = torch.matmul(target_feats, target_feats.transpose(-1, -2)) + + # Patch correlation + if self.apply_norm: + output_feats = F.normalize(output_feats, dim=-1) + + if not img_level and len(output_feats.shape) == 3: + # Flatten (B, N, D) into (B*N, D) + output_feats = output_feats.flatten(0, 1) + + # Compute similarities + student_sim = torch.matmul(output_feats, output_feats.transpose(-1, -2)) + + if self.remove_neg: + target_sim[target_sim < 0] = 0.0 + student_sim[student_sim < 0] = 0.0 + + elif self.remove_only_teacher_neg: + # Remove only the negative sim values of the teacher + target_sim[target_sim < 0] = 0.0 + student_sim[(student_sim < 0) & (target_sim < 0)] = 0.0 + + return self.mse_loss(student_sim, target_sim) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/ibot_patch_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/ibot_patch_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5412f0abf1675cb2f4d945bb4ef9e30ba32061c2 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/ibot_patch_loss.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import math + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +from dinov3.distributed import get_process_subgroup, get_subgroup_size + + +def lossfunc(t, s, temp): # noqa: F811 + return torch.sum(t.float() * F.log_softmax(s.float() / temp, dim=-1), dim=-1) + + +class SinkhornKnoppTeacher(nn.Module): + """ + NOTE: This is a module and not a function in the `iBOTPatchLoss` class + This is because we want to torch.compile it, and torch.compil-ing a single + function with the `@torch.compile` decorator is bad. + It's better to `module.compile()` it, as we can control when we enable or + disable compilation globally. 
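+    (It is compiled via `self.sinkhorn_knopp_teacher.compile()` in `iBOTPatchLoss.__init__` below.)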
+ """ + + @torch.no_grad() + def forward(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B, group=get_process_subgroup()) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q, group=get_process_subgroup()) + Q /= sum_Q + + for _ in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows, group=get_process_subgroup()) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the colomns must sum to 1 so that Q is an assignment + return Q.t() + + +class iBOTPatchLoss(nn.Module): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.full((1, 1, patch_out_dim), math.nan)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + self.sinkhorn_knopp_teacher = SinkhornKnoppTeacher() + self.sinkhorn_knopp_teacher.compile() + + def init_weights(self) -> None: + self.center.zero_() + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp, update_centers=True): + if update_centers: + self.apply_center_update() + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. 
+ student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + t = teacher_patch_tokens + s = student_patch_tokens + loss = lossfunc(t, s, self.student_temp) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + return -loss.mean() + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + t = teacher_patch_tokens_masked + s = student_patch_tokens_masked + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = lossfunc(t, s, self.student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True, group=get_process_subgroup()) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = get_subgroup_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/koleo_loss.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/koleo_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..dee9194dbb0533b01a2012af6841606b6186847f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/loss/koleo_loss.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import torch +import torch.distributed as torch_dist +import torch.nn as nn +import torch.nn.functional as F + +import dinov3.distributed as dist + + +class KoLeoLoss(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + + def pairwise_NNs_inner(self, x): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. 
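+        The diagonal (self-similarity) is filled with -1 before the argmax so a point
+        is never reported as its own nearest neighbor.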
+ """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, x.t()) + n = x.shape[0] + dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 + _, indices = torch.max(dots, dim=1) # max inner prod -> min distance + return indices + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.autocast("cuda", enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) + indices = self.pairwise_NNs_inner(student_output) + distances = self.pdist(student_output, student_output[indices]) # BxD, BxD -> B + loss = -torch.log(distances + eps).mean() + return loss + + +class KoLeoLossDistributed(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self, topk=1, loss_group_size: int | None = None): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + self.topk = topk + self.loss_group_size = loss_group_size # Size of the nearest neighbor set. If None, uses global batch size. + + def pairwise_NNs_inner(self, x, all_x, rank): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. + """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, all_x.t()) # local_B x global_B + local_B, global_B = dots.shape + dots.view(-1)[rank * local_B :: (global_B + 1)].fill_(-1) # Trick to fill diagonal with -1 + _, indices = torch.topk(dots, dim=1, k=self.topk) # max inner prod -> min distance + return indices + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.autocast("cuda", enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) # local_B x D + + if dist.is_enabled(): + all_student_outputs = torch.cat(torch_dist.nn.all_gather(student_output), dim=0) # global_B x D + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + all_student_outputs = student_output + world_size = 1 + rank = 0 + + # Group the global batch into groups of size `loss_group_size` and use the features of the group + # the local rank falls into as the nearest neighbor set for the local rank + local_B = len(student_output) + global_B = len(all_student_outputs) + loss_group_size = self.loss_group_size if self.loss_group_size is not None else global_B + if loss_group_size % local_B != 0: + raise ValueError( + f"Loss group size size {loss_group_size} must be a multiple of local batch size {local_B}." + ) + if global_B % loss_group_size != 0: + raise ValueError( + f"Global batch size {global_B} must be divisible by loss group size {loss_group_size}." 
+ ) + n_groups = global_B // loss_group_size + ranks_per_group = world_size // n_groups + rank_in_group = rank % ranks_per_group + group = rank // ranks_per_group + all_student_outputs = all_student_outputs.view(n_groups, loss_group_size, student_output.shape[1]) + all_student_outputs = all_student_outputs[group] # loss_group_size x D + + with torch.no_grad(): + indices = self.pairwise_NNs_inner(student_output, all_student_outputs, rank_in_group) # local_B x topk + + student_output_expanded = ( + student_output.unsqueeze(1).repeat(1, self.topk, 1).flatten(0, 1) + ) # (local_B * topk) x D + distances = self.pdist(student_output_expanded, all_student_outputs[indices].flatten(0, 1)) # BxD, BxD -> B + loss = -torch.log(distances.float() + eps).mean() + + return loss diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5427e2d3d729a3ac43af96cd48fd7ffeb33a868 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +from pathlib import Path + +from typing import Union + +import torch +import torch.nn as nn + +from dinov3.layers.fp8_linear import convert_linears_to_fp8 + +from . import vision_transformer as vits + +logger = logging.getLogger("dinov3") + + +def init_fp8(model: nn.Module, args) -> nn.Module: + if not args.fp8_enabled: + logger.info("fp8 matmuls: OFF (disabled in config)") + return model + logger.info("fp8 matmuls: ON") + # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based + # reduction kernel and a "persistent" reduction kernel. Since fp8 has some + # multi-pass steps (e.g., first get amax, then scale), persistent kernels + # should perform better. 
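+    # (Editor note: `multi_kernel = 1` below is the Inductor flag that enables this
+    # auto-tuning between the two reduction kernels.)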
+ torch._inductor.config.triton.multi_kernel = 1 + return convert_linears_to_fp8(model, filter=args.fp8_filter) + + +def build_model(args, only_teacher=False, img_size=224, device=None): + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + pos_embed_rope_base=args.pos_embed_rope_base, + pos_embed_rope_min_period=args.pos_embed_rope_min_period, + pos_embed_rope_max_period=args.pos_embed_rope_max_period, + pos_embed_rope_normalize_coords=args.pos_embed_rope_normalize_coords, + pos_embed_rope_shift_coords=args.pos_embed_rope_shift_coords, + pos_embed_rope_jitter_coords=args.pos_embed_rope_jitter_coords, + pos_embed_rope_rescale_coords=args.pos_embed_rope_rescale_coords, + qkv_bias=args.qkv_bias, + layerscale_init=args.layerscale, + norm_layer=args.norm_layer, + ffn_layer=args.ffn_layer, + ffn_bias=args.ffn_bias, + proj_bias=args.proj_bias, + n_storage_tokens=args.n_storage_tokens, + mask_k_bias=args.mask_k_bias, + untie_cls_and_patch_norms=args.untie_cls_and_patch_norms, + untie_global_and_local_cls_norm=args.untie_global_and_local_cls_norm, + device=device, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + teacher = init_fp8(teacher, args) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + ) + embed_dim = student.embed_dim + else: + raise NotImplementedError(f"Unrecognized architecture {args.arch}") + student = init_fp8(student, args) + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher: bool = False): + outputs = build_model( + cfg.student, + only_teacher=only_teacher, + img_size=cfg.crops.global_crops_size + if isinstance(cfg.crops.global_crops_size, int) + else max(cfg.crops.global_crops_size), + device="meta", + ) + if only_teacher: + teacher, embed_dim = outputs + return teacher, embed_dim + else: + student, teacher, embed_dim = outputs + return student, teacher, embed_dim + + +def build_model_for_eval( + config, + pretrained_weights: Union[str, Path] | None, + shard_unsharded_model: bool = False, # If the model is not sharded, shard it. 
No effect if already sharded on disk
+):
+    model, _ = build_model_from_cfg(config, only_teacher=True)
+    if pretrained_weights is None or pretrained_weights == "":
+        logger.info("No pretrained weights")
+        model.init_weights()
+    elif Path(pretrained_weights).is_dir():
+        logger.info("PyTorch DCP checkpoint")
+        from dinov3.checkpointer import load_checkpoint
+        from dinov3.fsdp.ac_compile_parallelize import ac_compile_parallelize
+
+        moduledict = nn.ModuleDict({"backbone": model})
+        # Wrap with FSDP
+        ac_compile_parallelize(moduledict, inference_only_models=[], cfg=config)
+        # Move to CUDA
+        model.to_empty(device="cuda")
+        # Load checkpoint
+        load_checkpoint(pretrained_weights, model=moduledict, strict_loading=True)
+        shard_unsharded_model = False
+    else:
+        logger.info("PyTorch consolidated checkpoint")
+        from dinov3.checkpointer import init_model_from_checkpoint_for_evals
+
+        # consolidated checkpoint codepath
+        model.to_empty(device="cuda")
+        init_model_from_checkpoint_for_evals(model, pretrained_weights, "teacher")
+        if shard_unsharded_model:
+            logger.info("Sharding model")
+            from dinov3.fsdp.ac_compile_parallelize import ac_compile_parallelize  # import needed in this branch as well
+
+            moduledict = nn.ModuleDict({"backbone": model})
+            ac_compile_parallelize(moduledict, inference_only_models=[], cfg=config)
+    model.eval()
+    return model
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/convnext.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..7271ae51863e0215a6533624b396e955b7349f99
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/convnext.py
@@ -0,0 +1,340 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+from functools import partial
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn.init
+from torch import Tensor, nn
+
+
+logger = logging.getLogger("dinov3")
+
+
+def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False) -> Tensor:
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None) -> None:
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x: Tensor) -> Tensor:
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Block(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
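+            (LayerScale scales the residual branch by a learnable per-channel `gamma`;
+            it is disabled when the init value is <= 0, see `self.gamma` below.)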
+
+    Source: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+    """
+
+    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+
+    Source: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+
+
+class ConvNeXt(nn.Module):
+    r"""ConvNeXt
+
+    Code adapted from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+
+    A PyTorch impl of : `A ConvNet for the 2020s` -
+    https://arxiv.org/pdf/2201.03545.pdf
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        patch_size (int | None): Pseudo patch size. Used to resize feature maps to those of a ViT with a given patch size.
If None, no resizing is performed + """ + + def __init__( + self, + # original ConvNeXt arguments + in_chans: int = 3, + depths: List[int] = [3, 3, 9, 3], + dims: List[int] = [96, 192, 384, 768], + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + # DINO arguments + patch_size: int | None = None, + **ignored_kwargs, + ): + super().__init__() + if len(ignored_kwargs) > 0: + logger.warning(f"Ignored kwargs: {ignored_kwargs}") + del ignored_kwargs + + # ==== ConvNeXt's original init ===== + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [x for x in np.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[ + Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + # ==== End of ConvNeXt's original init ===== + + # ==== DINO adaptation ==== + self.head = nn.Identity() # remove classification head + self.embed_dim = dims[-1] + self.embed_dims = dims # per layer dimensions + self.n_blocks = len(self.downsample_layers) # 4 + self.chunked_blocks = False + self.n_storage_tokens = 0 # no registers + + self.norms = nn.ModuleList([nn.Identity() for i in range(3)]) + self.norms.append(self.norm) + + self.patch_size = patch_size + self.input_pad_size = 4 # first convolution with kernel_size = 4, stride = 4 + + def init_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.LayerNorm): + module.reset_parameters() + if isinstance(module, LayerNorm): + module.weight = nn.Parameter(torch.ones(module.normalized_shape)) + module.bias = nn.Parameter(torch.zeros(module.normalized_shape)) + if isinstance(module, (nn.Conv2d, nn.Linear)): + torch.nn.init.trunc_normal_(module.weight, std=0.02) + nn.init.constant_(module.bias, 0) + + def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]: + if isinstance(x, torch.Tensor): + return self.forward_features_list([x], [masks])[0] + else: + return self.forward_features_list(x, masks) + + def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]: + output = [] + for x, masks in zip(x_list, masks_list): + h, w = x.shape[-2:] + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + x_pool = x.mean([-2, -1]) # global average pooling, (N, C, H, W) -> (N, C) + x = torch.flatten(x, 2).transpose(1, 2) + + # concat [CLS] and patch tokens as (N, HW + 1, C), then normalize + x_norm = self.norm(torch.cat([x_pool.unsqueeze(1), x], dim=1)) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_storage_tokens": x_norm[:, 1 : self.n_storage_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.n_storage_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + + return 
output + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + def _get_intermediate_layers(self, x, n=1): + h, w = x.shape[-2:] + output, total_block_len = [], len(self.downsample_layers) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i in range(total_block_len): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + if i in blocks_to_take: + x_pool = x.mean([-2, -1]) + x_patches = x + if self.patch_size is not None: + # Resize output feature maps to that of a ViT with given patch_size + x_patches = nn.functional.interpolate( + x, + size=(h // self.patch_size, w // self.patch_size), + mode="bilinear", + antialias=True, + ) + output.append( + [ + x_pool, # CLS (B x C) + x_patches, # B x C x H x W + ] + ) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x, + n: Union[int, Sequence] = 1, # Layers or n last layers to take, + reshape: bool = False, + return_class_token: bool = False, + norm: bool = True, + ): + outputs = self._get_intermediate_layers(x, n) + + if norm: + nchw_shapes = [out[-1].shape for out in outputs] + if isinstance(n, int): + norms = self.norms[-n:] + else: + norms = [self.norms[i] for i in n] + outputs = [ + ( + norm(cls_token), # N x C + norm(patches.flatten(-2, -1).permute(0, 2, 1)), # N x HW x C + ) + for (cls_token, patches), norm in zip(outputs, norms) + ] + if reshape: + outputs = [ + (cls_token, patches.permute(0, 2, 1).reshape(*nchw).contiguous()) + for (cls_token, patches), nchw in zip(outputs, nchw_shapes) + ] + elif not reshape: + # force B x N x C format for patch tokens + outputs = [(cls_token, patches.flatten(-2, -1).permute(0, 2, 1)) for (cls_token, patches) in outputs] + class_tokens = [out[0] for out in outputs] + outputs = [out[1] for out in outputs] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + +convnext_sizes = { + "tiny": dict( + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + ), + "small": dict( + depths=[3, 3, 27, 3], + dims=[96, 192, 384, 768], + ), + "base": dict( + depths=[3, 3, 27, 3], + dims=[128, 256, 512, 1024], + ), + "large": dict( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + ), +} + + +def get_convnext_arch(arch_name): + size_dict = None + query_sizename = arch_name.split("_")[1] + try: + size_dict = convnext_sizes[query_sizename] + except KeyError: + raise NotImplementedError("didn't recognize vit size string") + + return partial( + ConvNeXt, + **size_dict, + ) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/vision_transformer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..51316e518850cd607ffddc40c40bd39e8f2a03c2 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/models/vision_transformer.py @@ -0,0 +1,416 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
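+
+# Usage sketch (editor's illustration, assuming a float image batch `images` of shape
+# (B, 3, 224, 224); the factory functions are defined at the bottom of this file):
+#   model = vit_large(patch_size=16, n_storage_tokens=4)
+#   model.init_weights()
+#   maps = model.get_intermediate_layers(images, n=4, reshape=True)  # last 4 blocks as NCHW maps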
+ +import logging +from functools import partial +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.init +from torch import Tensor, nn + +from dinov3.layers import LayerScale, Mlp, PatchEmbed, RMSNorm, RopePositionEmbedding, SelfAttentionBlock, SwiGLUFFN +from dinov3.utils import named_apply + +logger = logging.getLogger("dinov3") + +ffn_layer_dict = { + "mlp": Mlp, + "swiglu": SwiGLUFFN, + "swiglu32": partial(SwiGLUFFN, align_to=32), + "swiglu64": partial(SwiGLUFFN, align_to=64), + "swiglu128": partial(SwiGLUFFN, align_to=128), +} + +norm_layer_dict = { + "layernorm": partial(nn.LayerNorm, eps=1e-6), + "layernormbf16": partial(nn.LayerNorm, eps=1e-5), + "rmsnorm": RMSNorm, +} + +dtype_dict = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def init_weights_vit(module: nn.Module, name: str = ""): + if isinstance(module, nn.Linear): + torch.nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + if hasattr(module, "bias_mask") and module.bias_mask is not None: + o = module.out_features + module.bias_mask.fill_(1) + module.bias_mask[o // 3 : 2 * o // 3].fill_(0) + if isinstance(module, nn.LayerNorm): + module.reset_parameters() + if isinstance(module, LayerScale): + module.reset_parameters() + if isinstance(module, PatchEmbed): + module.reset_parameters() + if isinstance(module, RMSNorm): + module.reset_parameters() + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + *, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + pos_embed_rope_base: float = 100.0, + pos_embed_rope_min_period: float | None = None, + pos_embed_rope_max_period: float | None = None, + pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", + pos_embed_rope_shift_coords: float | None = None, + pos_embed_rope_jitter_coords: float | None = None, + pos_embed_rope_rescale_coords: float | None = None, + pos_embed_rope_dtype: str = "bf16", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + ffn_ratio: float = 4.0, + qkv_bias: bool = True, + drop_path_rate: float = 0.0, + layerscale_init: float | None = None, + norm_layer: str = "layernorm", + ffn_layer: str = "mlp", + ffn_bias: bool = True, + proj_bias: bool = True, + n_storage_tokens: int = 0, + mask_k_bias: bool = False, + untie_cls_and_patch_norms: bool = False, + untie_global_and_local_cls_norm: bool = False, + device: Any | None = None, + **ignored_kwargs, + ): + super().__init__() + if len(ignored_kwargs) > 0: + logger.warning(f"Ignored kwargs: {ignored_kwargs}") + del ignored_kwargs + + norm_layer_cls = norm_layer_dict[norm_layer] + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + flatten_embedding=False, + ) + + self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device)) + self.n_storage_tokens = n_storage_tokens + if self.n_storage_tokens > 0: + self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device)) + logger.info(f"using base={pos_embed_rope_base} for rope new") + logger.info(f"using min_period={pos_embed_rope_min_period} for rope new") + logger.info(f"using max_period={pos_embed_rope_max_period} for rope new") + logger.info(f"using 
normalize_coords={pos_embed_rope_normalize_coords} for rope new") + logger.info(f"using shift_coords={pos_embed_rope_shift_coords} for rope new") + logger.info(f"using rescale_coords={pos_embed_rope_rescale_coords} for rope new") + logger.info(f"using jitter_coords={pos_embed_rope_jitter_coords} for rope new") + logger.info(f"using dtype={pos_embed_rope_dtype} for rope new") + self.rope_embed = RopePositionEmbedding( + embed_dim=embed_dim, + num_heads=num_heads, + base=pos_embed_rope_base, + min_period=pos_embed_rope_min_period, + max_period=pos_embed_rope_max_period, + normalize_coords=pos_embed_rope_normalize_coords, + shift_coords=pos_embed_rope_shift_coords, + jitter_coords=pos_embed_rope_jitter_coords, + rescale_coords=pos_embed_rope_rescale_coords, + dtype=dtype_dict[pos_embed_rope_dtype], + device=device, + ) + logger.info(f"using {ffn_layer} layer as FFN") + ffn_layer_cls = ffn_layer_dict[ffn_layer] + ffn_ratio_sequence = [ffn_ratio] * depth + blocks_list = [ + SelfAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + ffn_ratio=ffn_ratio_sequence[i], + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=drop_path_rate, + norm_layer=norm_layer_cls, + act_layer=nn.GELU, + ffn_layer=ffn_layer_cls, + init_values=layerscale_init, + mask_k_bias=mask_k_bias, + device=device, + ) + for i in range(depth) + ] + + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + # This norm is applied to everything, or when untying, to patch and mask tokens. + self.norm = norm_layer_cls(embed_dim) + + self.untie_cls_and_patch_norms = untie_cls_and_patch_norms + if untie_cls_and_patch_norms: + # When untying, this norm is applied to CLS tokens and registers. + self.cls_norm = norm_layer_cls(embed_dim) + else: + self.cls_norm = None + + self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm + if untie_global_and_local_cls_norm: + # When untying, this norm is applied to local CLS tokens and registers. + # This norm is never used during eval. 
+ self.local_cls_norm = norm_layer_cls(embed_dim) + else: + self.local_cls_norm = None + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device)) + + def init_weights(self): + self.rope_embed._init_weights() + nn.init.normal_(self.cls_token, std=0.02) + if self.n_storage_tokens > 0: + nn.init.normal_(self.storage_tokens, std=0.02) + nn.init.zeros_(self.mask_token) + named_apply(init_weights_vit, self) + + def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int]]: + x = self.patch_embed(x) + B, H, W, _ = x.shape + x = x.flatten(1, 2) + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + cls_token = self.cls_token + else: + cls_token = self.cls_token + 0 * self.mask_token + if self.n_storage_tokens > 0: + storage_tokens = self.storage_tokens + else: + storage_tokens = torch.empty( + 1, + 0, + cls_token.shape[-1], + dtype=cls_token.dtype, + device=cls_token.device, + ) + + x = torch.cat( + [ + cls_token.expand(B, -1, -1), + storage_tokens.expand(B, -1, -1), + x, + ], + dim=1, + ) + + return x, (H, W) + + def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]: + x = [] + rope = [] + for t_x, t_masks in zip(x_list, masks_list): + t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks) + x.append(t2_x) + rope.append(hw_tuple) + for _, blk in enumerate(self.blocks): + if self.rope_embed is not None: + rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope] + else: + rope_sincos = [None for r in rope] + x = blk(x, rope_sincos) + all_x = x + output = [] + for idx, (x, masks) in enumerate(zip(all_x, masks_list)): + if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm: + if self.untie_global_and_local_cls_norm and self.training and idx == 1: + # Assume second entry of list corresponds to local crops. + # We only ever apply this during training. + x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1]) + elif self.untie_cls_and_patch_norms: + x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1]) + else: + x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1]) + x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :]) + else: + x_norm = self.norm(x) + x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1] + x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :] + output.append( + { + "x_norm_clstoken": x_norm_cls_reg[:, 0], + "x_storage_tokens": x_norm_cls_reg[:, 1:], + "x_norm_patchtokens": x_norm_patch, + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]: + if isinstance(x, torch.Tensor): + return self.forward_features_list([x], [masks])[0] + else: + return self.forward_features_list(x, masks) + + def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int = 1) -> List[Tensor]: + x, (H, W) = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + if self.rope_embed is not None: + rope_sincos = self.rope_embed(H=H, W=W) + else: + rope_sincos = None + x = blk(x, rope_sincos) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + *, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + return_extra_tokens: bool = False, + norm: bool = True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs_normed = [] + for out in outputs: + if self.untie_cls_and_patch_norms: + x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1]) + x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :]) + outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1)) + else: + outputs_normed.append(self.norm(out)) + outputs = outputs_normed + class_tokens = [out[:, 0] for out in outputs] + extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs] + outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs] + if reshape: + B, _, h, w = x.shape + outputs = [ + out.reshape(B, h // self.patch_size, w // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if not return_class_token and not return_extra_tokens: + return tuple(outputs) + elif return_class_token and not return_extra_tokens: + return tuple(zip(outputs, class_tokens)) + elif not return_class_token and return_extra_tokens: + return tuple(zip(outputs, extra_tokens)) + elif return_class_token and return_extra_tokens: + return tuple(zip(outputs, class_tokens, extra_tokens)) + + def forward(self, *args, is_training: bool = False, **kwargs) -> List[Dict[str, Tensor]] | Tensor: + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + ffn_ratio=4, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + ffn_ratio=4, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + ffn_ratio=4, + **kwargs, + ) + return model + + +def vit_so400m(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1152, + depth=27, + num_heads=18, + ffn_ratio=3.777777778, + **kwargs, + ) + return model + + +def vit_huge2(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1280, + depth=32, + num_heads=20, + ffn_ratio=4, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + ffn_ratio=4, + **kwargs, + ) + return model + + +def vit_7b(patch_size=16, **kwargs): + model = 
DinoVisionTransformer( + patch_size=patch_size, + embed_dim=4096, + depth=40, + num_heads=32, + ffn_ratio=3, + **kwargs, + ) + return model diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/init.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/init.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d96d718f5b3b63702cd0b40bbaefcab0dfecca --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/init.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import contextlib +from datetime import timedelta +from typing import Optional + +from dinov3.configs import exit_job, setup_job + + +@contextlib.contextmanager +def job_context( + output_dir: Optional[str] = None, + distributed_enabled: bool = True, + logging_enabled: bool = True, + seed: Optional[int] = 0, + restrict_print_to_main_process: bool = True, + distributed_timeout: timedelta | None = None, +): + setup_job( + output_dir=output_dir, + distributed_enabled=distributed_enabled, + logging_enabled=logging_enabled, + seed=seed, + restrict_print_to_main_process=restrict_print_to_main_process, + distributed_timeout=distributed_timeout, + ) + try: + yield + finally: + exit_job(distributed_enabled=distributed_enabled, logging_enabled=logging_enabled) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/submit.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/submit.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad14edc26b02dd16e82bb636ac9cd43467082cd --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/run/submit.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
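+
+# Submits a Python callable (default name: "main") from a user-supplied module as a
+# Slurm job via submitit; see `get_run_parser` below for the accepted launcher arguments.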
+ +import argparse +import logging +import os +from pathlib import Path + +from dinov3.logging import setup_logging +from dinov3.utils.cluster import ( + get_slurm_account, + get_slurm_executor_parameters, + get_slurm_partition, + get_slurm_qos, + get_user_checkpoint_path, +) +from dinov3.utils.custom_callable import load_custom_callable + +logger = logging.getLogger("dinov3") + + +def get_submitit_parser(): + slurm_partition = get_slurm_partition() + slurm_account = get_slurm_account() + slurm_qos = get_slurm_qos() + parser = argparse.ArgumentParser("Submitit arguments", add_help=False) + parser.add_argument( + "--ngpus", + default=8, + type=int, + help="Number of gpus to request on each node, default: %(default)s", + ) + parser.add_argument( + "--nodes", + default=1, + type=int, + help="Number of nodes to request, default: %(default)s", + ) + parser.add_argument( + "--timeout", + default=2800, + type=int, + help="Duration of the job, default: %(default)s", + ) + parser.add_argument( + "--slurm-partition", + default=slurm_partition, + type=str, + help="Partition where to submit, default: %(default)s", + ) + parser.add_argument( + "--slurm-qos", + default=slurm_qos, + metavar="SLURM_QOS", + type=str, + dest="slurm_qos", + help="slurm QoS to use for jobs in cluster environment, default: %(default)s", + ) + parser.add_argument( + "--slurm-array-parallelism", + default=256, + type=int, + help="Maximum number of jobs that will be executed in parallel, default: %(default)s", + ) + parser.add_argument( + "--slurm-nice", + default=0, + type=int, + help="Adjusted scheduling priority within Slurm, default: %(default)s", + ) + parser.add_argument( + "--slurm-account", + default=slurm_account, + type=str, + help="Slurm account name, default: %(default)s", + ) + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. 
priority message, default: '%(default)s'", + ) + parser.add_argument( + "--exclude", + default="", + type=str, + help="Nodes to exclude, default: '%(default)s'", + ) + parser.add_argument( + "--output-dir", + type=str, + help="output dir", + ) + return parser + + +def get_run_parser(): + parser = argparse.ArgumentParser("Launcher arguments", parents=[get_submitit_parser()]) + parser.add_argument( + "module_path", + type=str, + help="Full path to the program/script to be launched in parallel, " + "followed by all the arguments for the training script.", + ) + parser.add_argument( + "--callable-name", + type=str, + default="main", + help="Name of the callable to execute in the script", + ) + return parser + + +def get_shared_folder() -> Path: + user_checkpoint_path = get_user_checkpoint_path() + if user_checkpoint_path is None: + raise RuntimeError("Path to user checkpoint cannot be determined") + path = user_checkpoint_path / "experiments" + path.mkdir(exist_ok=True) + return path + + +class CheckpointableSubmitter: + def __init__(self, module_path, callable_name, args, output_dir): + self.args = args + self.callable_name = callable_name + self.module_path = os.path.realpath(module_path) + self.output_dir = os.path.realpath(output_dir) + + def __call__(self): + self._setup_args() + callable_ = load_custom_callable(self.module_path, self.callable_name) + callable_(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.callable_name} from {self.module_path} with {self.args}") + empty_class = type(self)(self.module_path, self.callable_name, self.args, self.output_dir) + return submitit.helpers.DelayedSubmission(empty_class) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.output_dir = str(self.output_dir).replace("%j", str(job_env.job_id)) + if "--output-dir" not in self.args: + self.args.insert(0, f"--output-dir={self.output_dir}") + + # Setup logging with exact same arguments as in fairvit/run/init.py + # to use lru_cache memoization and avoid setting up the logger twice + setup_logging(output=self.output_dir, level=logging.INFO) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Module Path: {self.module_path}") + logger.info(f"Callable Name: {self.callable_name}") + logger.info(f'Args: {" ".join(self.args)}') + + +def submit_jobs(class_to_submit, output_dir, submitit_args, name="fairvit"): + import submitit + + Path(output_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=output_dir, slurm_max_num_timeout=30) + + kwargs = {} + if submitit_args.comment: + kwargs["slurm_comment"] = submitit_args.comment + if submitit_args.exclude: + kwargs["slurm_exclude"] = submitit_args.exclude + + executor_params = get_slurm_executor_parameters( + nodes=submitit_args.nodes, + num_gpus_per_node=submitit_args.ngpus, + timeout_min=submitit_args.timeout, # max is 60 * 72 + slurm_signal_delay_s=120, + slurm_partition=submitit_args.slurm_partition, + slurm_qos=submitit_args.slurm_qos, + # slurm_account=submitit_args.slurm_account, + slurm_additional_parameters=dict(nice=submitit_args.slurm_nice), + **kwargs, + ) + executor.update_parameters(name=name, **executor_params) + job = executor.submit(class_to_submit) + + logger.info(f"Submitted job_id: {job.job_id}") + str_output_dir = os.path.abspath(output_dir).replace("%j", str(job.job_id)) + logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") + + +def main(): + 
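+    """Parse launcher and script arguments, then submit the target callable as a Slurm job."""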
setup_logging(level=logging.INFO)
+    args, script_args = get_run_parser().parse_known_args()
+    assert os.path.exists(args.module_path), "The module path does not exist"
+
+    file_name = os.path.splitext(os.path.split(args.module_path)[1])[0]
+    name = f"{file_name}:{args.callable_name}"
+
+    if args.output_dir is None:
+        args.output_dir = get_shared_folder() / "%j"
+
+    class_to_submit = CheckpointableSubmitter(args.module_path, args.callable_name, script_args, args.output_dir)
+    submit_jobs(class_to_submit, args.output_dir, args, name=name)
+
+
+if __name__ == "__main__":
+    main()
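A minimal sketch of the submitit checkpointing protocol that CheckpointableSubmitter implements above (assuming only the real submitit package; the Trainer class, folder, and partition below are hypothetical): on preemption or timeout, submitit calls checkpoint() on the submitted callable and requeues the returned DelayedSubmission.

    import submitit


    class Trainer:
        # Hypothetical stand-in for the callable loaded via load_custom_callable.
        def __init__(self, args):
            self.args = args

        def __call__(self):
            print(f"training with {self.args}")  # long-running work goes here

        def checkpoint(self):
            # Called by submitit on preemption/timeout; the returned
            # DelayedSubmission is requeued as a fresh job.
            return submitit.helpers.DelayedSubmission(Trainer(self.args))


    executor = submitit.AutoExecutor(folder="submitit_logs")  # hypothetical folder
    executor.update_parameters(timeout_min=60, slurm_partition="dev")  # hypothetical partition
    job = executor.submit(Trainer(["--output-dir=/tmp/run"]))
    print(job.job_id)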
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/thirdparty/CLIP/clip/simple_tokenizer.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/thirdparty/CLIP/clip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e424249d57012833247c058e9e9f581e2b8097c9
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/thirdparty/CLIP/clip/simple_tokenizer.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+# References:
+#   https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
+
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    This also avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r"\s+", " ", text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+        merges = merges[1 : 49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+        for merge in merges:
+            vocab.append("".join(merge))
+        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE,
+        )
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except Exception:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = "".join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
+        return text
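A usage sketch for the tokenizer above, assuming the import path from this diff: this is the standard CLIP BPE tokenizer, in which the "</w>" suffix marks end-of-word in the vocabulary, so decode() maps it back to a space. encode() lowercases and normalizes whitespace, and does not add <|startoftext|>/<|endoftext|> by itself.

    from dinov3.thirdparty.CLIP.clip.simple_tokenizer import SimpleTokenizer  # path as in this diff

    tok = SimpleTokenizer()             # loads the bundled bpe_simple_vocab_16e6.txt.gz
    ids = tok.encode("A photo of a cat")
    print(ids)                          # plain BPE ids, no start/end tokens
    print(tok.decode(ids))              # -> "a photo of a cat " (lowercased, </w> mapped to space)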
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e2108eb60628bb32958c7b373e76f68a10e891
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+from .multidist_meta_arch import MultiDistillationMetaArch
+from .ssl_meta_arch import SSLMetaArch
+from .train import get_args_parser, main
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/cosine_lr_scheduler.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/cosine_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c61d7d01f573bca1e1111586244ef00587dd1f3
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/cosine_lr_scheduler.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+
+import numpy as np
+
+logger = logging.getLogger("dinov3")
+
+
+class CosineScheduler(object):
+    def __init__(
+        self,
+        base_value,
+        final_value,
+        total_iters,
+        warmup_iters=0,
+        start_warmup_value=0,
+        freeze_iters=0,
+        trunc_extra=0.0,
+    ):
+        super().__init__()
+        self.final_value = np.float64(final_value)
+        self.total_iters = total_iters
+
+        freeze_schedule = np.zeros((freeze_iters))
+
+        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+        if trunc_extra == 0.0:
+            iters = np.arange(total_iters - warmup_iters - freeze_iters)
+            schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+        else:
+            cosine_steps = total_iters - warmup_iters - freeze_iters
+            iters = np.linspace(0, np.pi, int((1 + trunc_extra) * cosine_steps))[:cosine_steps]
+            schedule = np.cos(iters)
+            schedule = (schedule + 1) / 2
+            schedule = (schedule - schedule[-1]) / (1 - schedule[-1])
+            schedule = schedule * (base_value - final_value) + final_value
+
+        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule), dtype=np.float64)
+
+        assert len(self.schedule) == self.total_iters
+
+    def __getitem__(self, it):
+        if it >= self.total_iters:
+            return self.final_value
+        else:
+            return self.schedule[it]
+
+
+def linear_warmup_cosine_decay(
+    start: float,
+    peak: float,
+    end: float,
+    warmup_iterations: int,
+    total_iterations: int,
+    cosine_iterations: int | None = None,
+) -> np.ndarray:
+    """
+    Create a learning rate schedule with linear warmup, a cosine decay, and an optional constant part at the end.
+
+    Args:
+        start (float): Initial learning rate.
+        peak (float): Learning rate after linear warmup.
+        end (float): Final learning rate after cosine.
+        warmup_iterations (int): Number of iterations for linear warmup.
+        total_iterations (int): Total number of iterations for the schedule.
+        cosine_iterations (int | None): Number of iterations for cosine.
+            If None, cosine part will be over remaining iterations after warmup.
+    Returns:
+        np.ndarray: Learning rate schedule as a numpy array.
+    """
+    linear = np.linspace(start, peak, warmup_iterations, endpoint=False)
+    if cosine_iterations is None:
+        cosine_iterations = total_iterations - warmup_iterations
+    cosine = np.cos(np.linspace(0, np.pi, cosine_iterations))
+    cosine = (cosine + 1) / 2
+    cosine = (peak - end) * cosine + end
+    remaining_iterations = total_iterations - cosine_iterations - warmup_iterations
+    assert remaining_iterations >= 0
+    constant = np.full((remaining_iterations,), fill_value=end)
+    return np.concatenate([linear, cosine, constant])
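A small numeric sketch of the two schedules above (the parameter values are illustrative):

    import numpy as np

    # CosineScheduler: freeze -> linear warmup -> cosine decay, clamped to final_value past total_iters.
    sched = CosineScheduler(base_value=1e-3, final_value=1e-5, total_iters=1000, warmup_iters=100)
    print(sched[0])     # 0.0   (start_warmup_value)
    print(sched[100])   # 1e-3  (peak, right after warmup)
    print(sched[2000])  # 1e-5  (indexing past total_iters returns final_value)

    # linear_warmup_cosine_decay: 100 warmup + 600 cosine + 300 constant tail at `end`.
    lrs = linear_warmup_cosine_decay(
        start=0.0, peak=1e-3, end=1e-5,
        warmup_iterations=100, total_iterations=1000, cosine_iterations=600,
    )
    assert len(lrs) == 1000 and np.isclose(lrs[100], 1e-3) and np.all(lrs[700:] == 1e-5)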
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/multidist_meta_arch.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/multidist_meta_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2238c4a9f974134dd2bf98e6df41de73fdea49a
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/multidist_meta_arch.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+
+import torch
+from torch import Tensor
+
+from .ssl_meta_arch import SSLMetaArch
+
+logger = logging.getLogger("dinov3")
+
+
+class MultiDistillationMetaArch(SSLMetaArch):
+    """
+    Multidistillation version of SSLMetaArchCompilableGram:
+    - baked-in scales for DINO, KOLEO, and IBOT losses
+    - always global and local crops
+    - always separate heads for DINO and IBOT
+    - always sinkhorn-knopp centering for DINO and IBOT
+    - always per-GPU computation of KOLEO loss (non-distributed)
+    - DINO, IBOT, and KOLEO are always computed even if their weight is 0.0
+    """
+
+    def forward_backward(
+        self, data, *, teacher_temp, iteration: int = 0, **ignored_kwargs
+    ) -> tuple[Tensor, dict[str, float | Tensor]]:
+        del ignored_kwargs
+        metrics_dict = {}
+
+        # Shapes
+        n_global_crops = 2
+        n_local_crops = self.n_local_crops  # self.cfg.crops.local_crops_number
+        B_teacher = B = data["collated_local_crops"].shape[0] // n_local_crops
+        assert data["collated_global_crops"].shape[0] == n_global_crops * B
+        metrics_dict["batch_size"] = B
+
+        global_crops = data["collated_global_crops"].cuda(non_blocking=True)
+        local_crops = data["collated_local_crops"].cuda(non_blocking=True)
+        masks = data["collated_masks"].cuda(non_blocking=True)
+        mask_indices_list = data["mask_indices_list"].cuda(non_blocking=True)
+        masks_weight = data["masks_weight"].cuda(non_blocking=True)
+        n_masked_patches_tensor = data["n_masked_patches"].cuda(non_blocking=True)
+        global_batch_size = data["global_batch_size"]
+
+        # Multidistillation codepath:
+        global_crops_subgroup = self.broadcast_to_subgroups(
+            global_crops.view(n_global_crops, -1, *global_crops.shape[1:]),
+            1,
+            global_batch_size=global_batch_size,
+        ).view(-1, *global_crops.shape[1:])
+        local_crops_subgroup = self.broadcast_to_subgroups(
+            local_crops.view(n_local_crops, -1, *local_crops.shape[1:]),
+            1,
+            global_batch_size=global_batch_size,
+        ).view(-1, *local_crops.shape[1:])
+        B = local_crops_subgroup.shape[0] // n_local_crops
+
+        # Teacher output (will trigger an all-gather to unshard)
+        teacher_global = self.get_teacher_output(
+            global_crops.unflatten(0, (n_global_crops, B_teacher)),
+            teacher_temp=teacher_temp,
+            n_masked_patches_tensor=n_masked_patches_tensor,
+            mask_indices_list=mask_indices_list,
+            upperbound=data["upperbound"],
+            global_batch_size=global_batch_size,
+        )
+
+        # Student output (will trigger an all-gather to unshard)
+        student_global, 
student_local = self.get_student_output( + global_crops=global_crops_subgroup.unflatten(0, (n_global_crops, B)), + local_crops=local_crops_subgroup.unflatten(0, (n_local_crops, B)), + upperbound=data["upperbound"], + masks=masks, + mask_indices_list=mask_indices_list, + ) + # End of multidistillation codepath + + # Compute losses and backprop + loss_accumulator, loss_dict = self.compute_losses( + teacher_global=teacher_global, + student_global=student_global, + student_local=student_local, + masks=masks, + mask_indices_list=mask_indices_list, + masks_weight=masks_weight, + gram_global=None, + iteration=iteration, + ) + + self.backprop_loss(loss_accumulator) + + # Return total weighted loss and a dict of metrics to log + return loss_accumulator, metrics_dict | loss_dict + + @torch.no_grad() + def get_teacher_output( + self, + images, + *, + upperbound, + mask_indices_list, + teacher_temp, + n_masked_patches_tensor, + global_batch_size, + ): + n_crops, B_teacher, rgb, H, W = images.shape + + backbone_out = self.teacher.backbone(images.flatten(0, 1), is_training=True) + cls = backbone_out["x_norm_clstoken"] # [n_crops * B, D] + reg = backbone_out["x_storage_tokens"] # [n_crops * B, R, D] + ibot_patch = backbone_out["x_norm_patchtokens"] # [n_crops * B, P, D] + + R, D = reg.shape[-2:] + + # Multidistillation codepath: + # IBOT head only on patches that are masked for the student + n_tokens = ibot_patch.shape[1] + masked_patch_after_head = self.teacher.ibot_head(ibot_patch.flatten(0, 1), no_last_layer=True) + masked_patch_after_head = masked_patch_after_head.view(n_crops, -1, *masked_patch_after_head.shape[1:]) + masked_patch_after_head = self.broadcast_to_subgroups( + masked_patch_after_head, + over_dim=1, + global_batch_size=global_batch_size * n_tokens, + ) + buffer = torch.index_select(masked_patch_after_head.flatten(0, 1), dim=0, index=mask_indices_list) + masked_patch_after_head = self.teacher.ibot_head(buffer, only_last_layer=True) + + # DINO head on CLS tokens + cls_after_head = self.teacher.dino_head(cls, no_last_layer=True) # [n_crops * B, K] + cls_after_head = cls_after_head.view(n_crops, -1, *cls_after_head.shape[1:]) + cls_after_head = self.broadcast_to_subgroups(cls_after_head, over_dim=1, global_batch_size=global_batch_size) + B = cls_after_head.shape[1] + cls_after_head = cls_after_head.flatten(0, 1) + cls_after_head = self.teacher.dino_head(cls_after_head, only_last_layer=True) # [n_crops * B, K] + # End of multidistillation codepath + + # Center with sinkhorn-knopp + cls_centered = self.dino_loss.sinkhorn_knopp_teacher( + cls_after_head, teacher_temp=teacher_temp + ) # [n_crops * B, K] + cls_centered = cls_centered.unflatten(0, (n_crops, B)) # [n_crops, B, K] + masked_patch_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_patch_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) # [n_masked_patches, K] + + return { + "cls_after_head": cls_after_head.unflatten(0, [n_crops, B]), # [n_crops, B, K] + "cls_centered": cls_centered, # [n_crops, B, K] + "masked_patch_centered": masked_patch_centered, # [n_masked_patches, K] + } diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/param_groups.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..d8712a82042db21ae1e21e24f4ed1968205a9d32 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/param_groups.py @@ -0,0 +1,180 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+
+import logging
+from collections import defaultdict
+
+logger = logging.getLogger("dinov3")
+
+
+def get_vit_lr_decay_rate(
+    name,
+    lr_decay_rate=1.0,
+    num_layers=12,
+    force_is_backbone=False,
+    chunked_blocks=False,
+):
+    """
+    Calculate lr decay rate for different ViT blocks.
+    Args:
+        name (string): parameter name.
+        lr_decay_rate (float): base lr decay rate.
+        num_layers (int): number of ViT blocks.
+    Returns:
+        lr decay rate for the given parameter.
+    """
+    layer_id = num_layers + 1
+    if name.startswith("backbone") or force_is_backbone:
+        if (
+            ".pos_embed" in name
+            or ".patch_embed" in name
+            or ".mask_token" in name
+            or ".cls_token" in name
+            or ".storage_tokens" in name
+        ):
+            layer_id = 0
+        elif force_is_backbone and (
+            "pos_embed" in name
+            or "patch_embed" in name
+            or "mask_token" in name
+            or "cls_token" in name
+            or "storage_tokens" in name
+        ):
+            layer_id = 0
+        elif ".blocks." in name and ".residual." not in name:
+            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+        elif chunked_blocks and "blocks." in name and "residual." not in name:
+            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+        elif "blocks." in name and "residual." not in name:
+            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+    return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0, dino_head_wd_multiplier=1.0):
+    chunked_blocks = False
+    if hasattr(model, "n_blocks"):
+        logger.info("chunked fsdp")
+        n_blocks = model.n_blocks
+        chunked_blocks = model.chunked_blocks
+    elif hasattr(model, "blocks"):
+        logger.info("first code branch")
+        n_blocks = len(model.blocks)
+    elif hasattr(model, "backbone"):
+        logger.info("second code branch")
+        n_blocks = len(model.backbone.blocks)
+    else:
+        logger.info("else code branch")
+        n_blocks = 0
+    all_param_groups = []
+
+    for name, param in model.named_parameters():
+        name = remove_fsdp_compile_names(name)
+        if not param.requires_grad:
+            continue
+        decay_rate = get_vit_lr_decay_rate(
+            name,
+            lr_decay_rate,
+            num_layers=n_blocks,
+            force_is_backbone=n_blocks > 0,
+            chunked_blocks=chunked_blocks,
+        )
+        d = {
+            "name": name,
+            "params": param,
+            "is_last_layer": False,
+            "lr_multiplier": decay_rate,
+            "wd_multiplier": 1.0,
+        }
+
+        if "dino_head" in name:
+            d["wd_multiplier"] = dino_head_wd_multiplier
+
+        if "last_layer" in name:
+            d["is_last_layer"] = True
+
+        # No weight-decay on biases, norm parameters, layer scale gamma, learned tokens and embeddings
+        if name.endswith("bias") or "norm" in name or "gamma" in name or "fourier_w" in name:
+            d["wd_multiplier"] = 0.0
+
+        if "patch_embed" in name:
+            d["lr_multiplier"] *= patch_embed_lr_mult
+
+        all_param_groups.append(d)
+        logger.info(f"{name}: lr_multiplier: {d['lr_multiplier']}, wd_multiplier: {d['wd_multiplier']}")
+
+    return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+    fused_params_groups = defaultdict(lambda: {"params": []})
+    for d in all_params_groups:
+        identifier = ""
+        for k in keys:
+            identifier += k + str(d[k]) + "_"
+
+        for k in keys:
+            fused_params_groups[identifier][k] = d[k]
+        fused_params_groups[identifier]["params"].append(d["params"])
+
+    return fused_params_groups.values()
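To make the layer-wise decay rule above concrete, a worked sketch with lr_decay_rate=0.9 and a 12-block ViT (the parameter names are illustrative): embeddings get the strongest discount, the deepest block stays closest to the base LR, and non-backbone parameters keep a multiplier of 1.0.

    print(get_vit_lr_decay_rate("backbone.patch_embed.proj.weight", 0.9, 12))   # layer_id 0  -> 0.9**13 ~ 0.254
    print(get_vit_lr_decay_rate("backbone.blocks.0.attn.qkv.weight", 0.9, 12))  # layer_id 1  -> 0.9**12 ~ 0.282
    print(get_vit_lr_decay_rate("backbone.blocks.11.mlp.fc1.weight", 0.9, 12))  # layer_id 12 -> 0.9**1  = 0.9
    print(get_vit_lr_decay_rate("dino_head.mlp.0.weight", 0.9, 12))             # not backbone -> 0.9**0 = 1.0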
def get_params_groups_with_decay_fsdp(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0, dino_head_wd_multiplier=1.0):
+    if hasattr(model, "module"):  # SimpleFSDP
+        is_backbone = hasattr(model.module, "blocks")
+        n_blocks = len(model.module.blocks) if is_backbone else 0
+    else:  # FSDP2
+        is_backbone = hasattr(model, "blocks")
+        n_blocks = len(model.blocks) if is_backbone else 0
+
+    all_param_groups = []
+
+    for name, param in model.named_parameters():
+        name = remove_fsdp_compile_names(name)
+        if not param.requires_grad:
+            continue
+        decay_rate = get_vit_lr_decay_rate(
+            name,
+            lr_decay_rate,
+            num_layers=n_blocks,
+            force_is_backbone=n_blocks > 0,
+            chunked_blocks=False,
+        )
+        d = {
+            "name": name,
+            "params": param,
+            "is_last_layer": False,
+            "lr_multiplier": decay_rate,
+            "wd_multiplier": 1.0,
+        }
+
+        if "dino_head" in name:
+            d["wd_multiplier"] = dino_head_wd_multiplier
+
+        if "last_layer" in name:
+            d["is_last_layer"] = True
+
+        # No weight-decay on biases, norm parameters, layer scale gamma, learned tokens and embeddings
+        if name.endswith("bias") or "norm" in name or "gamma" in name or "fourier_w" in name:
+            d["wd_multiplier"] = 0.0
+
+        if "patch_embed" in name:
+            d["lr_multiplier"] *= patch_embed_lr_mult
+
+        all_param_groups.append(d)
+        logger.info(f"{name}: lr_multiplier: {d['lr_multiplier']}, wd_multiplier: {d['wd_multiplier']}")
+
+    return all_param_groups
+
+
+def remove_fsdp_compile_names(name: str):
+    name = name.replace("_fsdp_wrapped_module.", "")  # Added by FSDP
+    name = name.replace("_checkpoint_wrapped_module.", "")  # Added by activation checkpointing for xFSDP
+    name = name.replace("parametrizations.", "")  # Added by xFSDP
+    name = name.removesuffix(".original")  # Added by xFSDP
+    name = name.replace("module.", "")  # Added by xFSDP
+    name = name.replace("_orig_mod.", "")  # Added by torch.compile
+    return name
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/ssl_meta_arch.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/ssl_meta_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..172d80b9ef2d87e534785b15904d0d9e472dcc27
--- /dev/null
+++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/ssl_meta_arch.py
@@ -0,0 +1,815 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This software may be used and distributed in accordance with
+# the terms of the DINOv3 License Agreement.
+ +import gc +import logging +from functools import partial + +import torch +from omegaconf import OmegaConf +from torch import Tensor, nn + +import dinov3.distributed as distributed +from dinov3.checkpointer import init_fsdp_model_from_checkpoint +from dinov3.configs import get_default_config +from dinov3.data import DataAugmentationDINO +from dinov3.fsdp.ac_compile_parallelize import ac_compile_parallelize +from dinov3.layers.dino_head import DINOHead +from dinov3.loss import DINOLoss, GramLoss, KoLeoLoss, KoLeoLossDistributed, iBOTPatchLoss +from dinov3.models import build_model_from_cfg +from dinov3.train.cosine_lr_scheduler import linear_warmup_cosine_decay +from dinov3.train.param_groups import fuse_params_groups, get_params_groups_with_decay_fsdp +from dinov3.utils import count_parameters + +logger = logging.getLogger("dinov3") + + +class SSLMetaArch(nn.Module): + """ + Modified version of SSLMetaArchCompilable including gram loss: + - Gram loss is used only if gram.use_loss is set to true + """ + + def __init__(self, cfg): + super().__init__() + + # assert cfg.multidistillation.enabled is False + assert cfg.crops.local_crops_number > 0 + assert cfg.ibot.separate_head is True + assert cfg.train.centering == "sinkhorn_knopp" + + # For some reason FULL_SHARD doesn't work + assert cfg.compute_precision.sharding_strategy == "SHARD_GRAD_OP" + + self.cfg = cfg + + student_model_dict = dict() + teacher_model_dict = dict() + gram_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + torch.cuda.empty_cache() + gc.collect() + gram_backbone, _ = build_model_from_cfg(cfg, only_teacher=True) + logger.info(f"Number of parameters: {count_parameters(student_backbone)}") + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + gram_model_dict["backbone"] = gram_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + self.embed_dim = embed_dim # D + self.dino_out_dim = cfg.dino.head_n_prototypes # K + + logger.info("OPTIONS -- DINO") + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- global_ignore_diagonal: {cfg.dino.global_ignore_diagonal}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + logger.info(f"OPTIONS -- DINO -- head_norm_last_layer: {cfg.dino.head_norm_last_layer}") + dino_head_class = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + student_model_dict["dino_head"] = dino_head_class() + teacher_model_dict["dino_head"] = dino_head_class() + self.dino_loss = DINOLoss(self.dino_out_dim) + + logger.info("OPTIONS -- KOLEO") + logger.info(f"OPTIONS -- KOLEO -- loss_weight: {cfg.dino.koleo_loss_weight}") + logger.info(f"OPTIONS -- KOLEO -- distributed: {cfg.dino.koleo_loss_distributed}") + if cfg.dino.koleo_loss_distributed: + logger.info(f"OPTIONS -- KOLEO -- topk: {cfg.dino.koleo_topk}") + logger.info( + f"OPTIONS -- KOLEO -- distributed_loss_group_size: {cfg.dino.koleo_distributed_loss_group_size}" + ) + assert cfg.dino.koleo_distributed_replicas == 0, ( + "Option `dino.koleo_distributed_replicas` is no longer supported" + ) + self.koleo_loss = 
KoLeoLossDistributed( + topk=cfg.dino.koleo_topk, + loss_group_size=cfg.dino.koleo_distributed_loss_group_size, + ) + else: + assert cfg.dino.koleo_topk == 1, "Non-distributed KoLeo loss only supports `dino.koleo_topk=1`" + self.koleo_loss = KoLeoLoss() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + + assert 0 <= cfg.ibot.mask_ratio_min_max[0] < cfg.ibot.mask_ratio_min_max[1] <= 1, ( + "provide a valid cfg.ibot.mask_ratio_min_max" + ) + assert 0 <= cfg.ibot.mask_sample_probability <= 1, "provide a positive mask probability for ibot" + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + logger.info(f"OPTIONS -- IBOT -- head_norm_last_layer: {cfg.ibot.head_norm_last_layer}") + ibot_head_class = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head_class() + teacher_model_dict["ibot_head"] = ibot_head_class() + self.ibot_patch_loss = iBOTPatchLoss(cfg.ibot.head_n_prototypes) + + # Build student and teacher models + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + self.model_ema = self.teacher # this may be overwritten for distillation + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + if cfg.distillation.enabled: + self._setup_distillation() + # No grad is needed for these two + self.teacher.requires_grad_(False) + self.model_ema.requires_grad_(False) + self.ema_params_lists = None + + # getting config params fixed: + self.n_local_crops = self.cfg.crops.local_crops_number + self.is_distillation_enabled = self.cfg.distillation.enabled + self.dino_global_ignore_diagonal = self.cfg.dino.global_ignore_diagonal + self.dino_loss_weight = self.cfg.dino.loss_weight + self.dino_koleo_loss_weight = self.cfg.dino.koleo_loss_weight + self.ibot_loss_weight = self.cfg.ibot.loss_weight + + # Local loss reweighting + if self.cfg.dino.reweight_dino_local_loss: + iter_per_epoch = cfg.train.OFFICIAL_EPOCH_LENGTH + total_iterations = iter_per_epoch * cfg.optim.epochs + schedule_cfg = cfg.dino.local_loss_weight_schedule + self.dino_local_loss_schedule = linear_warmup_cosine_decay( + start=schedule_cfg.start, + peak=schedule_cfg.peak, + end=schedule_cfg.end, + warmup_iterations=iter_per_epoch * schedule_cfg.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * schedule_cfg.cosine_epochs if "cosine_epochs" in schedule_cfg else None + ), + ) + + # Gram + self.gram_use_loss = self.cfg.gram.use_loss + self.gram_ema_teacher = False + self.has_gram_teacher = False + self.gram_teacher_initialized = False + if self.gram_use_loss: + # Gram regularization + self.gram_loss = GramLoss( + apply_norm=self.cfg.gram.normalized, + remove_only_teacher_neg=self.cfg.gram.remove_only_teacher_neg, + remove_neg=self.cfg.gram.remove_neg, + ) + # Construct gram teacher + self.has_gram_teacher = True if not cfg.gram.ema_teacher else False + if 
self.has_gram_teacher: + self.gram_teacher = nn.ModuleDict(gram_model_dict) + self.gram_teacher.requires_grad_(False) + logger.info(f"Gram teacher parameter at init: {next(self.gram_teacher.named_parameters())}") + else: + self.gram_teacher = None + + self.gram_loss_weight = self.cfg.gram.loss_weight + if self.cfg.gram.get("loss_weight_schedule"): + iter_per_epoch = cfg.train.OFFICIAL_EPOCH_LENGTH + total_iterations = iter_per_epoch * cfg.optim.epochs + schedule_cfg = self.cfg.gram.loss_weight_schedule + self.gram_loss_schedule = linear_warmup_cosine_decay( + start=schedule_cfg.start, + peak=schedule_cfg.peak, + end=schedule_cfg.end, + warmup_iterations=iter_per_epoch * schedule_cfg.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * schedule_cfg.cosine_epochs if "cosine_epochs" in schedule_cfg else None + ), + ) + logger.info(f"Applying gram loss weight schedule instead of `cfg.gram.loss_weight`: {schedule_cfg}") + else: + self.gram_loss_schedule = None + self.gram_ema_teacher = self.cfg.gram.ema_teacher # If true use the EMA_teacher as gram_teacher + self.gram_ckpt = self.cfg.gram.ckpt # Checkpoint to the first gram teacher model + self.gram_img_level = self.cfg.gram.img_level # Apply the loss on the image, if false on the batch + self.gram_tokens_used = self.cfg.gram.tokens_used # Any value in ["all", "masked", "unmasked"] + # Update the teacher frequently + self.gram_rep_update = self.cfg.gram.rep_update # bool, if yes the gram teacher will be updated at the freq + self.gram_update_frequency = self.cfg.gram.update_frequency # defined by this var update_frequency + self.gram_it_first_update = self.cfg.gram.it_first_update # after iteration it_first_update is passed. + self.gram_it_load_ema_teacher = ( + self.cfg.gram.it_load_ema_teacher + ) # after iteration it_load_ema the ema teacher is loaded into the gram teacher + self.gram_compute_stats = self.cfg.gram.compute_stats # whether to compute auxiliary stats + self.gram_params_lists = None + + if self.gram_ema_teacher and self.gram_ckpt is not None: + raise ValueError( + "Cannot use both `gram.ema_teacher` and `gram.ckpt` at the same time. Please set one of them to False." + ) + if self.gram_ckpt is None and self.gram_it_load_ema_teacher < 0: + raise ValueError( + "If no gram checkpoint is provided, `gram.it_load_ema_teacher` must be set to a non-negative value." 
+            )
+
+        assert not (self.gram_ema_teacher and self.gram_rep_update)
+        assert self.gram_tokens_used in ["all", "masked", "unmasked"]
+        # Using masked/unmasked tokens is currently not handled at the image level
+        if self.gram_tokens_used in ["masked", "unmasked"]:
+            assert self.gram_img_level is False
+
+        logger.info("OPTIONS -- GRAM")
+        logger.info(f"OPTIONS -- GRAM -- loss_weight: {cfg.gram.loss_weight}")
+        logger.info(f"OPTIONS -- GRAM -- ema teacher: {cfg.gram.ema_teacher}")
+        logger.info(f"OPTIONS -- GRAM -- ckpt: {cfg.gram.ckpt}")
+        if self.cfg.gram.rep_update:
+            logger.info(f"OPTIONS -- GRAM -- repeated update: {cfg.gram.rep_update}")
+            logger.info(f"OPTIONS -- GRAM -- update freq: {cfg.gram.update_frequency}")
+            logger.info(f"OPTIONS -- GRAM -- iteration first update: {cfg.gram.it_first_update}")
+
+        logger.info(f"OPTIONS -- GRAM -- tokens_used: {cfg.gram.tokens_used}")
+        logger.info(f"OPTIONS -- GRAM -- apply normalization: {cfg.gram.normalized}")
+        logger.info(f"OPTIONS -- GRAM -- img_level: {cfg.gram.img_level}")
+        logger.info(f"OPTIONS -- GRAM -- remove_neg: {cfg.gram.remove_neg}")
+        logger.info(f"OPTIONS -- GRAM -- remove_only_teacher_neg: {cfg.gram.remove_only_teacher_neg}")
+
+        if cfg.crops.gram_teacher_crops_size is None and self.has_gram_teacher:
+            raise ValueError("cfg.crops.gram_teacher_crops_size must be set to use gram loss")
+        if cfg.crops.gram_teacher_crops_size is not None and self.gram_ema_teacher:
+            raise ValueError("cfg.crops.gram_teacher_crops_size should be None when gram.ema_teacher=True")
+
+        self.student_crop_size = cfg.crops.global_crops_size
+        self.gram_global_teacher_resize_method = cfg.gram.global_teacher_resize_method
+        self.gram_global_teacher_resize_antialias = cfg.gram.global_teacher_resize_antialias
+        logger.info(f"OPTIONS -- global crops student/teacher size: {self.student_crop_size}")
+        logger.info(f"OPTIONS -- global crops GRAM teacher size: {cfg.crops.gram_teacher_crops_size}")
+        logger.info(f"OPTIONS -- global crops GRAM teacher resize method: {cfg.gram.global_teacher_resize_method}")
+        logger.info(
+            f"OPTIONS -- global crops GRAM teacher resize antialias: {cfg.gram.global_teacher_resize_antialias}"
+        )
+
+    def _setup_distillation(self):
+        logger.info(f"Performing distillation from {self.cfg.distillation.full_cfg_path}")
+
+        default_cfg = get_default_config()
+        distillation_cfg = OmegaConf.load(self.cfg.distillation.full_cfg_path)
+        distillation_cfg = OmegaConf.merge(default_cfg, distillation_cfg)
+
+        assert distillation_cfg.ibot.separate_head is True
+        assert distillation_cfg.ibot.head_n_prototypes == self.cfg.ibot.head_n_prototypes
+        assert distillation_cfg.dino.head_n_prototypes == self.cfg.dino.head_n_prototypes
+        assert distillation_cfg.student.patch_size == self.cfg.student.patch_size
+
+        teacher_model_dict = dict()
+
+        backbone, embed_dim = build_model_from_cfg(distillation_cfg, only_teacher=True)
+        teacher_model_dict["backbone"] = backbone
+
+        teacher_model_dict["dino_head"] = DINOHead(
+            in_dim=embed_dim,
+            out_dim=distillation_cfg.dino.head_n_prototypes,
+            hidden_dim=distillation_cfg.dino.head_hidden_dim,
+            bottleneck_dim=distillation_cfg.dino.head_bottleneck_dim,
+            nlayers=distillation_cfg.dino.head_nlayers,
+        )
+        teacher_model_dict["ibot_head"] = DINOHead(
+            in_dim=embed_dim,
+            out_dim=distillation_cfg.ibot.head_n_prototypes,
+            hidden_dim=distillation_cfg.ibot.head_hidden_dim,
+            bottleneck_dim=distillation_cfg.ibot.head_bottleneck_dim,
+            nlayers=distillation_cfg.ibot.head_nlayers,
+        )
+        self.teacher = nn.ModuleDict(teacher_model_dict)
+
+    def 
init_weights(self) -> None: + # All weights are set to `nan` to ensure we initialize everything explicitly + self.student.backbone.init_weights() + self.student.dino_head.init_weights() + self.student.ibot_head.init_weights() + self.dino_loss.init_weights() + self.ibot_patch_loss.init_weights() + self.model_ema.load_state_dict(self.student.state_dict()) + if self.has_gram_teacher: + if self.gram_ckpt is not None: + logger.info(f"Loading pretrained weights from {self.gram_ckpt}") + init_fsdp_model_from_checkpoint( + self.gram_teacher, + self.gram_ckpt, + skip_load_keys=[ + "dino_head", + "ibot_head", + "dino_loss.center", + "ibot_patch_loss.center", + ], + keys_not_sharded=["backbone.rope_embed.periods", "qkv.bias_mask"], + process_group=distributed.get_default_process_group(), + ) + self.gram_teacher_initialized = True + else: + raise ValueError(f"Provide a correct path to {self.gram_ckpt}") + self.gram_teacher.requires_grad_(False) + self.gram_teacher.eval() + if self.cfg.student.resume_from_teacher_chkpt: + logger.info(f"Loading pretrained weights from {self.cfg.student.resume_from_teacher_chkpt}") + init_fsdp_model_from_checkpoint( + self.student, + self.cfg.student.resume_from_teacher_chkpt, + skip_load_keys=["dino_loss.center", "ibot_patch_loss.center"], + keys_not_sharded=["backbone.rope_embed.periods", "qkv.bias_mask"], + process_group=distributed.get_process_subgroup(), + ) + self.model_ema.load_state_dict(self.student.state_dict()) + if self.cfg.distillation.enabled: + if self.cfg.distillation.checkpoint_path != "ignore": + logger.info(f"Loading teacher to distil from : {self.cfg.distillation.checkpoint_path}") + init_fsdp_model_from_checkpoint( + self.teacher, + self.cfg.distillation.checkpoint_path, + skip_load_keys=["dino_loss.center", "ibot_patch_loss.center"], + keys_not_sharded=["backbone.rope_embed.periods", "qkv.bias_mask"], + ) + else: + logger.info("Init teacher to distil from, used for testing purpose only") + self.teacher.backbone.init_weights() + self.teacher.dino_head.init_weights() + self.teacher.ibot_head.init_weights() + logger.info(f"Performing distillation from: {self.teacher}") + + def forward_backward( + self, data, *, teacher_temp, iteration=0, **ignored_kwargs + ) -> tuple[Tensor, dict[str, float | Tensor]]: + del ignored_kwargs + metrics_dict = {} + + # Shapes + n_global_crops = 2 + n_local_crops = self.n_local_crops # self.cfg.crops.local_crops_number + B = data["collated_local_crops"].shape[0] // n_local_crops + assert data["collated_global_crops"].shape[0] == n_global_crops * B + metrics_dict["local_batch_size"] = B + metrics_dict["global_batch_size"] = data["global_batch_size"] + + global_crops = data["collated_global_crops"].cuda(non_blocking=True) + local_crops = data["collated_local_crops"].cuda(non_blocking=True) + masks = data["collated_masks"].cuda(non_blocking=True) + mask_indices_list = data["mask_indices_list"].cuda(non_blocking=True) + masks_weight = data["masks_weight"].cuda(non_blocking=True) + n_masked_patches_tensor = data["n_masked_patches"].cuda(non_blocking=True) + + if self.has_gram_teacher: + assert "collated_gram_teacher_crops" in data, ( + "no gram teacher crops in the data, have you set cfg.crops.gram_teacher_crops_size?" 
+ ) + gram_teacher_crops = data["collated_gram_teacher_crops"].cuda(non_blocking=True) + else: + gram_teacher_crops = None + + # Teacher output (will trigger an all-gather to unshard) + teacher_global = self.get_teacher_output( + global_crops.unflatten(0, (n_global_crops, B)), + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + mask_indices_list=mask_indices_list, + upperbound=data["upperbound"], + ) + + # Student output (will trigger an all-gather to unshard) + student_global, student_local = self.get_student_output( + global_crops=global_crops.unflatten(0, (n_global_crops, B)), + local_crops=local_crops.unflatten(0, (n_local_crops, B)), + upperbound=data["upperbound"], + masks=masks, + mask_indices_list=mask_indices_list, + ) + + # Gram output + if self.gram_use_loss: + gram_global = self.get_gram_teacher_output( + gram_teacher_crops.unflatten(0, (n_global_crops, B)) if gram_teacher_crops is not None else None, + masks=masks, + teacher_global=teacher_global, + student_global=student_global, + student_global_crops_size=global_crops.shape[-1], + ) + else: + gram_global = {} + + # Compute losses and backprop + loss_accumulator, loss_dict = self.compute_losses( + teacher_global=teacher_global, + student_global=student_global, + student_local=student_local, + gram_global=gram_global, + masks=masks, + mask_indices_list=mask_indices_list, + masks_weight=masks_weight, + iteration=iteration, + ) + + self.backprop_loss(loss_accumulator) + + # Return total weighted loss and a dict of metrics to log + return loss_accumulator, metrics_dict | loss_dict + + @torch.no_grad() + def get_teacher_output( + self, + images, + *, + upperbound, + mask_indices_list, + teacher_temp, + n_masked_patches_tensor, + ): + n_crops, B, rgb, H, W = images.shape + images = images.flatten(0, 1) + + backbone_out = self.teacher.backbone(images, is_training=True) + cls = backbone_out["x_norm_clstoken"] # [n_crops * B, D] + reg = backbone_out["x_storage_tokens"] # [n_crops * B, R, D] + ibot_patch = backbone_out["x_norm_patchtokens"] # [n_crops * B, P, D] + + # IBOT head only on patches that are masked for the student + buffer = torch.index_select(ibot_patch.flatten(0, 1), dim=0, index=mask_indices_list) + masked_patch_after_head = self.teacher.ibot_head(buffer) + + # DINO head on CLS tokens + cls_after_head = self.teacher.dino_head(cls) # [n_crops * B, K] + + # Center with sinkhorn-knopp + cls_centered = self.dino_loss.sinkhorn_knopp_teacher( + cls_after_head, teacher_temp=teacher_temp + ) # [n_crops * B, K] + cls_centered = cls_centered.unflatten(0, (n_crops, B)) # [n_crops, B, K] + masked_patch_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_patch_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) # [n_masked_patches, K] + + return { + "cls_pre_head": cls.unflatten(0, [n_crops, B]), # [n_crops, B, D] + "reg_pre_head": reg.unflatten(0, [n_crops, B]), # [n_crops, B, R, D] + "patch_pre_head": ibot_patch.unflatten(0, [n_crops, B]), # [n_crops, B, P, D] + "cls_after_head": cls_after_head.unflatten(0, [n_crops, B]), # [n_crops, B, K] + "cls_centered": cls_centered, # [n_crops, B, K] + "masked_patch_centered": masked_patch_centered, # [n_masked_patches, K] + } + + def get_gram_teacher_output(self, images, *, masks, teacher_global, student_global, student_global_crops_size): + # Get student patch features + student_patches = student_global["patch_pre_head"].flatten(0, 1) # [n_crops * B, P, D] + + # Get gram targets + if self.gram_ema_teacher: 
+ teacher_patches = teacher_global["patch_pre_head"].flatten(0, 1) # [n_crops * B, P, D] + else: + if not self.gram_teacher_initialized: + raise ValueError("Gram teacher has not been initialized. Load a checkpoint or from the EMA teacher.") + n_crops, B, rgb, H, W = images.shape + images = images.flatten(0, 1) # [n_crops * B, rgb, H, W] + + with torch.no_grad(): + backbone_out = self.gram_teacher.backbone(images, is_training=True) + teacher_patches = backbone_out["x_norm_patchtokens"] # [n_crops * B, P_T, D] + + # Downsample Gram teacher features if needed + if teacher_patches.shape[1] != student_patches.shape[1]: + N = H // self.cfg.student.patch_size + assert teacher_patches.shape[1] == N**2 + N_student = student_global_crops_size // self.cfg.student.patch_size + assert student_patches.shape[1] == N_student**2 + patches_hw = teacher_patches.transpose(-2, -1).unflatten(-1, (N, N)) # [n_crops * B, D, N, N] + patches_hw = torch.nn.functional.interpolate( + patches_hw, + size=(N_student, N_student), + mode=self.gram_global_teacher_resize_method, + align_corners=False, + antialias=self.gram_global_teacher_resize_antialias, + ) + teacher_patches = patches_hw.flatten(-2, -1).transpose( + -2, -1 + ) # [n_crops * B, N_student * N_student, D] + assert teacher_patches.shape == student_patches.shape + + # Select the patches to be considered in the loss + orig_student_patches = student_patches + orig_teacher_patches = teacher_patches + if self.gram_tokens_used == "masked": + student_patches = student_patches[masks] + teacher_patches = teacher_patches[masks] + elif self.gram_tokens_used == "unmasked": + student_patches = student_patches[~masks] + teacher_patches = teacher_patches[~masks] + + return { + "student_patches": student_patches, # [n_crops * B, P, D] or [n_selected_patches, D] + "teacher_patches": teacher_patches, # [n_crops * B, P, D] or [n_selected_patches, D] + # Unmasked patches, for computing statistics + "orig_student_patches": orig_student_patches, # [n_crops * B, P, D] + "orig_teacher_patches": orig_teacher_patches, # [n_crops * B, P, D] + } + + def get_student_output(self, *, global_crops, local_crops, upperbound, masks, mask_indices_list): + n_global_crops, B, rgb, H, W = global_crops.shape + n_local_crops, B, rgb, H, W = local_crops.shape + + global_crops = global_crops.flatten(0, 1) + + # Forward global and local crops through the student backbone jointly + global_out, local_out = self.student.backbone( + [global_crops, local_crops.flatten(0, 1)], + masks=[masks if not self.is_distillation_enabled else None, None], + is_training=True, + ) + g_cls, g_reg, g_patch = ( + global_out["x_norm_clstoken"], + global_out["x_storage_tokens"], + global_out["x_norm_patchtokens"], + ) + l_cls, l_reg, l_patch = ( + local_out["x_norm_clstoken"], + local_out["x_storage_tokens"], + local_out["x_norm_patchtokens"], + ) + + # IBOT head only on masked patches + masked_patches_pre_head = torch.index_select(g_patch.flatten(0, 1), dim=0, index=mask_indices_list) + global_masked_patch_after_head = self.student.ibot_head(masked_patches_pre_head) + + # DINO head on CLS tokens (all in one pass) + buffer = [ + g_cls, # [n_global_crops * B, D] + l_cls, # [n_local_crops * B, D] + ] + sizes = [x.shape[0] for x in buffer] + buffer = torch.cat(buffer, dim=0) # [n_global_crops * B + n_local_crops * B, D] + buffer = self.student.dino_head(buffer) # [n_global_crops * B + n_local_crops * B, K] + buffer = torch.split_with_sizes(buffer, sizes, dim=0) + + global_out = { + "cls_pre_head": g_cls.unflatten(0, 
[n_global_crops, B]), # [n_global_crops, B, D] + "reg_pre_head": g_reg.unflatten(0, [n_global_crops, B]), # [n_global_crops, B, R, D] + "patch_pre_head": g_patch.unflatten(0, [n_global_crops, B]), # [n_global_crops, B, P, D] + "cls_after_head": buffer[0].unflatten(0, [n_global_crops, B]), # [n_global_crops, B, K], + "masked_patch_after_head": global_masked_patch_after_head, # [n_masked_patches, K] + "masked_patch_pre_head": masked_patches_pre_head, # [n_masked_patches, D] + } + local_out = { + "cls_pre_head": l_cls.unflatten(0, [n_local_crops, B]), # [n_local_crops, B, D] + "reg_pre_head": l_reg.unflatten(0, [n_local_crops, B]), # [n_local_crops, B, R, D] + "patch_pre_head": l_patch.unflatten(0, [n_local_crops, B]), # [n_local_crops, B, P, D] + "cls_after_head": buffer[1].unflatten(0, [n_local_crops, B]), # [n_local_crops, B, K], + } + + return global_out, local_out + + def compute_losses( + self, + *, + teacher_global, + student_global, + student_local, + gram_global, + masks, + mask_indices_list, + masks_weight, + iteration, + ): + n_global_crops = student_global["cls_after_head"].shape[0] + n_local_crops = student_local["cls_after_head"].shape[0] + loss_dict = {} + loss_accumulator = 0.0 + + # Loss scales like in DINOv2, these are multiplied with the loss weights from the config + dino_global_terms = ( + n_global_crops * (n_global_crops - 1) if self.dino_global_ignore_diagonal else n_global_crops**2 + ) + dino_local_terms = n_global_crops * n_local_crops + dino_global_scale = dino_global_terms / (dino_global_terms + dino_local_terms) + dino_local_scale = dino_local_terms / (dino_global_terms + dino_local_terms) + koleo_scale = n_global_crops + + # DINO local loss: compare post-head CLS tokens: student(local crops) vs. teacher(global crops) + dino_local_crops_loss = self.dino_loss( + student_logits=student_local["cls_after_head"], + teacher_probs=teacher_global["cls_centered"], + ) + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # Reweighting of DINO loss + if self.cfg.dino.reweight_dino_local_loss: + local_weight = self.dino_local_loss_schedule[iteration] + else: + local_weight = 1.0 + + loss_dict["dino_local_loss_weight"] = local_weight + loss_accumulator += self.dino_loss_weight * dino_local_scale * local_weight * dino_local_crops_loss + + # DINO global loss: compare post-head CLS tokens: student(global crops) vs. 
teacher(global crops) + dino_global_crops_loss = self.dino_loss( + student_logits=student_global["cls_after_head"], + teacher_probs=teacher_global["cls_centered"], + ignore_diagonal=self.dino_global_ignore_diagonal, + ) + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + loss_accumulator += self.dino_loss_weight * dino_global_scale * dino_global_crops_loss + + # Koleo: regularize pre-head CLS tokens of student(global crops) + koleo_loss = sum(self.koleo_loss(x) for x in student_global["cls_pre_head"]) / n_global_crops + loss_dict["koleo_loss"] = koleo_loss + loss_accumulator += self.dino_koleo_loss_weight * koleo_scale * koleo_loss + + # IBOT loss + ibot_patch_loss = self.ibot_patch_loss.forward_masked( + student_global["masked_patch_after_head"], + teacher_global["masked_patch_centered"], + student_masks_flat=masks, + n_masked_patches=mask_indices_list.shape[0], + masks_weight=masks_weight, + ) + loss_dict["ibot_loss"] = ibot_patch_loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + # Gram loss + if self.gram_use_loss: + gram_loss = self.gram_loss( + gram_global["student_patches"], + gram_global["teacher_patches"], + img_level=self.gram_img_level, + ) + + if self.gram_loss_schedule is not None: + gram_loss_weight = self.gram_loss_schedule[iteration] + else: + gram_loss_weight = self.gram_loss_weight + + loss_dict["gram_loss_weight"] = gram_loss_weight + loss_accumulator += gram_loss * gram_loss_weight + loss_dict["gram_loss"] = gram_loss + + if self.gram_compute_stats: + with torch.no_grad(): + # Save stats over masked / unmasked tokens + gram_loss_masked = self.gram_loss( + gram_global["orig_student_patches"][masks].detach(), + gram_global["orig_teacher_patches"][masks], + img_level=False, + ) + loss_dict["stats_only/masked_gram_loss"] = gram_loss_masked + gram_loss_unmasked = self.gram_loss( + gram_global["orig_student_patches"][~masks].detach(), + gram_global["orig_teacher_patches"][~masks], + img_level=False, + ) + loss_dict["stats_only/unmasked_gram_loss"] = gram_loss_unmasked + + return loss_accumulator, loss_dict + + @torch.no_grad() + def gram_load_ema_teacher(self): + if self.has_gram_teacher: + skip_load_prefixes = ["dino_head.", "ibot_head."] + self.gram_teacher.load_state_dict( + { + k: v + for k, v in self.model_ema.state_dict().items() + if not any(k.startswith(prefix) for prefix in skip_load_prefixes) + } + ) + self.gram_teacher.requires_grad_(False) + self.gram_teacher.eval() + self.gram_teacher_initialized = True + + def train(self): + super().train() + self.teacher.eval() + if self.has_gram_teacher: + self.gram_teacher.eval() + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + loss.backward() + + def update_ema(self, m): + if self.ema_params_lists is None: + student_param_list = [] + teacher_param_list = [] + for k in self.student.keys(): + for ms, mt in zip(self.student[k].parameters(), self.model_ema[k].parameters()): + student_param_list += [ms] + teacher_param_list += [mt] + self.ema_params_lists = (student_param_list, teacher_param_list) + else: + student_param_list, teacher_param_list = self.ema_params_lists + with torch.no_grad(): + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def update_gram(self, m=0): + if not self.has_gram_teacher: + return + logger.info("Updating gram teacher with teacher weights.") + if self.gram_params_lists is None: + teacher_param_list = [] + gramteacher_param_list = [] + for k in 
self.gram_teacher.keys():
+                for mgt, mt in zip(self.gram_teacher[k].parameters(), self.teacher[k].parameters()):
+                    gramteacher_param_list += [mgt]
+                    teacher_param_list += [mt]
+            self.gram_params_lists = (gramteacher_param_list, teacher_param_list)
+        else:
+            gramteacher_param_list, teacher_param_list = self.gram_params_lists
+
+        with torch.no_grad():
+            torch._foreach_mul_(gramteacher_param_list, m)
+            torch._foreach_add_(gramteacher_param_list, teacher_param_list, alpha=1 - m)
+
+    def build_data_augmentation_dino(self, cfg):
+        return DataAugmentationDINO(
+            cfg.crops.global_crops_scale,
+            cfg.crops.local_crops_scale,
+            cfg.crops.local_crops_number,
+            global_crops_size=cfg.crops.global_crops_size,
+            local_crops_size=cfg.crops.local_crops_size,
+            gram_teacher_crops_size=cfg.crops.gram_teacher_crops_size,
+            gram_teacher_no_distortions=cfg.crops.gram_teacher_no_distortions,
+            local_crops_subset_of_global_crops=cfg.crops.localcrops_subset_of_globalcrops,
+            share_color_jitter=cfg.crops.share_color_jitter,
+            horizontal_flips=cfg.crops.horizontal_flips,
+            mean=cfg.crops.rgb_mean,
+            std=cfg.crops.rgb_std,
+        )
+
+    def get_maybe_fused_params_for_submodel(self, m: nn.Module):
+        params_groups = get_params_groups_with_decay_fsdp(
+            model=m,
+            lr_decay_rate=self.cfg.optim.layerwise_decay,
+            patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult,
+            dino_head_wd_multiplier=self.cfg.optim.dino_head_wd_multiplier,
+        )
+        if self.cfg.optim.multi_tensor_optim:
+            fused_params_groups = fuse_params_groups(params_groups)
+            logger.info("fusing param groups")
+
+            for g in fused_params_groups:
+                g["foreach"] = True
+                g["fused"] = True
+            return fused_params_groups
+        else:
+            return params_groups
+
+    def get_params_groups(self):
+        all_params_groups = []
+        for name, m in self.student.items():
+            logger.info(f"Getting parameter groups for {name}")
+            all_params_groups += self.get_maybe_fused_params_for_submodel(m)
+        return all_params_groups
+
+    def prepare_for_distributed_training(self) -> None:
+        process_subgroup = distributed.get_process_subgroup()
+        default_process_group = distributed.get_default_process_group()
+        inference_only_models = [self.model_ema]
+        inference_only_models_process_groups = [process_subgroup]
+        if self.has_gram_teacher:
+            inference_only_models.append(self.gram_teacher)
+            inference_only_models_process_groups.append(default_process_group)
+        if self.cfg.distillation.enabled:
+            inference_only_models.append(self.teacher)
+            inference_only_models_process_groups.append(default_process_group)
+        ac_compile_parallelize(
+            trained_model=self.student,
+            inference_only_models=inference_only_models,
+            cfg=self.cfg,
+            trained_model_process_group=process_subgroup,
+            inference_only_models_process_groups=inference_only_models_process_groups,
+        )
+
+    def broadcast_to_subgroups(self, tensor, over_dim, global_batch_size=None):
+        """
+        Take a tensor from the default process group, all-gather it, concatenate the results, then scatter the chunks within a smaller process subgroup.
+        """
+        world_size = distributed.get_world_size()
+        subgroup_size = distributed.get_subgroup_size()
+        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
+
+        torch.distributed.all_gather(gathered, tensor)
+        catted = torch.cat(gathered, dim=over_dim)
+        if global_batch_size is not None:
+            catted = catted.narrow(dim=over_dim, start=0, length=global_batch_size)
+
+        return catted.chunk(subgroup_size, dim=over_dim)[distributed.get_subgroup_rank()].clone()
diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/train.py 
b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..fe78876571b7c0b09b889422fb50e6e5683759f4 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/train/train.py @@ -0,0 +1,637 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import argparse +import copy +import gc +import logging +import math +import os +import sys +from functools import partial +from pathlib import Path + +import torch +import torch.distributed +from torch.distributed._tensor import DTensor + +import dinov3.distributed as distributed +from dinov3.checkpointer import ( + find_latest_checkpoint, + keep_checkpoint_copy, + keep_last_n_checkpoints, + load_checkpoint, + register_dont_save_hooks, + save_checkpoint, +) +from dinov3.configs import setup_config, setup_job, setup_multidistillation +from dinov3.data import ( + MaskingGenerator, + SamplerType, + collate_data_and_cast, + make_data_loader, + make_dataset, + CombinedDataLoader, +) +from dinov3.logging import MetricLogger, setup_logging +from dinov3.train.cosine_lr_scheduler import CosineScheduler, linear_warmup_cosine_decay +from dinov3.train.multidist_meta_arch import MultiDistillationMetaArch +from dinov3.train.ssl_meta_arch import SSLMetaArch + +assert torch.__version__ >= (2, 1) +torch.backends.cuda.matmul.allow_tf32 = True # pytorch 1.12 sets this to false by default +torch.backends.cudnn.benchmark = False # True + +logger = logging.getLogger("dinov3") + + +def get_args_parser(add_help: bool = True): + parser = argparse.ArgumentParser("DINOv3 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "--eval_pretrained_weights", + type=str, + default="", + help="Path to pretrained weights", + ) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". 
+ """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--output-dir", + default="./local_dino", + type=str, + help="Path to save logs and checkpoints.", + ) + parser.add_argument("--seed", default=0, type=int, help="RNG seed") + parser.add_argument( + "--benchmark-codebase", + action="store_true", + help="test the codebase for a few iters", + ) + parser.add_argument("--test-ibot", action="store_true", help="test ibot") + parser.add_argument("--profiling", action="store_true", help="do profiling") + parser.add_argument("--dump-fsdp-weights", action="store_true", help="dump fsdp weights") + parser.add_argument("--record_ref_losses", action="store_true", help="record reference losses") + parser.add_argument("--ref_losses_path", default="", type=str) + parser.add_argument("--multi-distillation", action="store_true", help="run multi-distillation") + + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + if "schedules" in cfg: + logger.info("Using schedules v2") + return build_schedulers_v2(cfg) + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim["lr"], + final_value=cfg.optim["min_lr"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + trunc_extra=cfg.optim["schedule_trunc_extra"], + ) + wd = dict( + base_value=cfg.optim["weight_decay"], + final_value=cfg.optim["weight_decay_end"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + trunc_extra=cfg.optim["schedule_trunc_extra"], + ) + momentum = dict( + base_value=cfg.teacher["momentum_teacher"], + final_value=cfg.teacher["final_momentum_teacher"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + trunc_extra=cfg.optim["schedule_trunc_extra"], + ) + teacher_temp = dict( + base_value=cfg.teacher["teacher_temp"], + final_value=cfg.teacher["teacher_temp"], + total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher["warmup_teacher_temp"], + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH] = ( + 0 # mimicking the original schedules + ) + logger.info("Schedulers ready.") + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def build_schedulers_v2(cfg): + iter_per_epoch = cfg.train.OFFICIAL_EPOCH_LENGTH + total_iterations = cfg.train.OFFICIAL_EPOCH_LENGTH * cfg.optim.epochs + logger.info(f"Total training iterations {total_iterations}") + + # LR scaling rules + lr_peak = cfg.schedules.lr.peak + lr_end = cfg.schedules.lr.end + if cfg.optim.scaling_rule == "linear_wrt_256": + lr_peak *= cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0 + lr_end *= cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0 + logger.info( + f"Scaling rule {cfg.optim.scaling_rule}, LR peak {cfg.schedules.lr.peak} -> {lr_peak}, LR end {cfg.schedules.lr.end} -> {lr_end}" + ) + elif cfg.optim.scaling_rule == "sqrt_wrt_1024": + lr_peak 
*= 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0) + lr_end *= 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0) + logger.info( + f"Scaling rule {cfg.optim.scaling_rule}, LR peak {cfg.schedules.lr.peak} -> {lr_peak}, LR end {cfg.schedules.lr.end} -> {lr_end}" + ) + else: + logger.info(f"No scaling rule for {cfg.optim.scaling_rule=}") + + lr = linear_warmup_cosine_decay( + start=cfg.schedules.lr.start, + peak=lr_peak, + end=lr_end, + warmup_iterations=iter_per_epoch * cfg.schedules.lr.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * cfg.schedules.lr.cosine_epochs if "cosine_epochs" in cfg.schedules.lr else None + ), + ) + last_layer_lr = lr.copy() + last_layer_lr[: iter_per_epoch * cfg.schedules.lr.freeze_last_layer_epochs] = 0 + weight_decay = linear_warmup_cosine_decay( + start=cfg.schedules.weight_decay.start, + peak=cfg.schedules.weight_decay.peak, + end=cfg.schedules.weight_decay.end, + warmup_iterations=iter_per_epoch * cfg.schedules.weight_decay.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * cfg.schedules.weight_decay.cosine_epochs + if "cosine_epochs" in cfg.schedules.weight_decay + else None + ), + ) + momentum = linear_warmup_cosine_decay( + start=cfg.schedules.momentum.start, + peak=cfg.schedules.momentum.peak, + end=cfg.schedules.momentum.end, + warmup_iterations=iter_per_epoch * cfg.schedules.momentum.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * cfg.schedules.momentum.cosine_epochs if "cosine_epochs" in cfg.schedules.momentum else None + ), + ) + teacher_temp = linear_warmup_cosine_decay( + start=cfg.schedules.teacher_temp.start, + peak=cfg.schedules.teacher_temp.peak, + end=cfg.schedules.teacher_temp.end, + warmup_iterations=iter_per_epoch * cfg.schedules.teacher_temp.warmup_epochs, + total_iterations=total_iterations, + cosine_iterations=( + iter_per_epoch * cfg.schedules.teacher_temp.cosine_epochs + if "cosine_epochs" in cfg.schedules.teacher_temp + else None + ), + ) + return lr, weight_decay, momentum, teacher_temp, last_layer_lr + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + if is_last_layer: + param_group["lr"] = last_layer_lr * lr_multiplier + else: + param_group["lr"] = lr * lr_multiplier + + +def do_test(cfg, model, iteration, process_group, do_low_freq=False): + # dump a sharded checkpoint + eval_dir = Path(cfg.train.output_dir) / "eval" / str(iteration) + if distributed.is_subgroup_main_process(): + eval_dir.mkdir(parents=True, exist_ok=True) + if cfg.train.sharded_eval_checkpoint: + ckpt_path = eval_dir / "sharded_teacher_checkpoint" + if distributed.is_subgroup_main_process(): + ckpt_path.mkdir(parents=True, exist_ok=True) + torch.distributed.barrier() + teacher_backbone = model.model_ema + save_checkpoint( + ckpt_dir=ckpt_path, iteration=iteration, model=teacher_backbone, overwrite=True, process_group=process_group + ) + if not distributed.is_subgroup_main_process(): + return + else: + new_state_dict = model.model_ema.state_dict() + for k, tensor in list(new_state_dict.items()): + if isinstance(tensor, DTensor): + new_state_dict[k] = tensor.full_tensor() + if not distributed.is_subgroup_main_process(): + 
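# non-main subgroup ranks return here without writing; only the main process saves the consolidated teacher checkpoint below +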
return + # save teacher checkpoint + ckpt_path = eval_dir / "teacher_checkpoint.pth" + torch.save({"teacher": new_state_dict}, ckpt_path) + logger.info("Saved eval checkpoint: %s", ckpt_path) + + +def build_data_loader_from_cfg( + cfg, + model, + start_iter, +): + # Collate function + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + if cfg.multidistillation.enabled: + assert cfg.multidistillation.global_batch_size % distributed.get_subgroup_size() == 0 + local_batch_size = cfg.multidistillation.global_batch_size // distributed.get_subgroup_size() + dataloader_batch_size_per_gpu = ( + cfg.multidistillation.global_batch_size + (distributed.get_world_size() - 1) + ) // distributed.get_world_size() + else: + local_batch_size = None # will default to the standard local batch size matching the data batch size + dataloader_batch_size_per_gpu = cfg.train.batch_size_per_gpu + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + dtype={ + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + }[cfg.compute_precision.param_dtype], + n_tokens=n_tokens, + mask_generator=mask_generator, + random_circular_shift=cfg.ibot.mask_random_circular_shift, + local_batch_size=local_batch_size, + ) + batch_size = dataloader_batch_size_per_gpu + num_workers = cfg.train.num_workers + dataset_path = cfg.train.dataset_path + dataset = make_dataset( + dataset_str=dataset_path, + transform=model.build_data_augmentation_dino(cfg), + target_transform=lambda _: (), + ) + + if isinstance(dataset, torch.utils.data.IterableDataset): + sampler_type = SamplerType.INFINITE + else: + sampler_type = SamplerType.SHARDED_INFINITE if cfg.train.cache_dataset else SamplerType.INFINITE + + data_loader = make_data_loader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + seed=cfg.train.seed + start_iter + 1, + sampler_type=sampler_type, + sampler_advance=start_iter * dataloader_batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + return data_loader + + +def build_multi_resolution_data_loader_from_cfg( + cfg, + model, + start_iter, + seed=65537, +): + global_crops_sizes = ( + [cfg.crops.global_crops_size] if isinstance(cfg.crops.global_crops_size, int) else cfg.crops.global_crops_size + ) + local_crops_sizes = ( + [cfg.crops.local_crops_size] if isinstance(cfg.crops.local_crops_size, int) else cfg.crops.local_crops_size + ) + gram_teacher_crops_sizes = ( + [cfg.crops.gram_teacher_crops_size] + if cfg.crops.gram_teacher_crops_size is None or isinstance(cfg.crops.gram_teacher_crops_size, int) + else cfg.crops.gram_teacher_crops_size + ) + loader_ratios = ( + [cfg.crops.global_local_crop_pairs_ratios] + if type(cfg.crops.global_local_crop_pairs_ratios) in [int, float] + else cfg.crops.global_local_crop_pairs_ratios + ) + assert len(global_crops_sizes) == len(local_crops_sizes) == len(gram_teacher_crops_sizes) == len(loader_ratios) + + loaders = [] + for increment, (global_crops_size_i, local_crops_size_i, gram_teacher_crops_size_i) in enumerate( + zip(global_crops_sizes, local_crops_sizes, gram_teacher_crops_sizes) + ): + cfg_i = copy.deepcopy(cfg) + cfg_i.crops.global_crops_size = global_crops_size_i + 
cfg_i.crops.local_crops_size = local_crops_size_i + cfg_i.crops.gram_teacher_crops_size = gram_teacher_crops_size_i + cfg_i.train.seed = cfg.train.seed + increment + 1 + loaders.append(build_data_loader_from_cfg(cfg=cfg_i, model=model, start_iter=start_iter)) + + if len(loaders) == 1: + data_loader = loaders[0] + else: + data_loader = CombinedDataLoader( + loaders_with_ratios=zip(loaders, loader_ratios), + batch_size=cfg.train.batch_size_per_gpu, + combining_mode=0, + seed=seed, + name="MultiResDL", + ) + return data_loader + + +def do_train(cfg, model, resume=False): + process_subgroup = distributed.get_process_subgroup() + ckpt_dir = Path(cfg.train.output_dir, "ckpt").expanduser() + ckpt_dir.mkdir(parents=True, exist_ok=True) + + model.train() + # Optimizer + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + if cfg.multidistillation.enabled: + register_dont_save_hooks( + model, + dont_save=[k for k, _ in model.state_dict().items() if k.startswith("teacher")], + ) + model.init_weights() + start_iter = 0 + if resume and (last_checkpoint_dir := find_latest_checkpoint(ckpt_dir)): + logger.info(f"Checkpoint found {last_checkpoint_dir}") + start_iter = ( + load_checkpoint( + last_checkpoint_dir, + model=model, + optimizer=optimizer, + strict_loading=False, + process_group=process_subgroup, + ) + + 1 + ) + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + if cfg.multidistillation.enabled: + global_batch_size = cfg.multidistillation.global_batch_size + else: + global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_world_size() + + # Build data loader + data_loader = build_multi_resolution_data_loader_from_cfg( + cfg=cfg, + model=model, + start_iter=start_iter, + ) + + # Metric logging + logger.info("Starting training from iteration %d", start_iter) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + # Manual garbage collection + gc.disable() + gc.collect() + + # Training loop + student = model.student + iteration = start_iter + num_gram_updates = 0 + if ( + cfg.gram.use_loss + and model.has_gram_teacher + and cfg.gram.rep_update + and start_iter > 0 + and start_iter >= cfg.gram.it_first_update + ): + # If `start_iter == it_first_update`, we have performed one gram teacher update after + # iteration `start_iter - 1`, except if we are starting training from scratch and `start_iter == 0`. 
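+ # Worked example (illustrative numbers only): with it_first_update=1000, update_frequency=500 + # and start_iter=2000, num_gram_updates = ceil((2001 - 1000) / 500) = 3.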
+ num_gram_updates = math.ceil((start_iter + 1 - cfg.gram.it_first_update) / cfg.gram.update_frequency) + logger.info(f"Gram was updated {num_gram_updates} times before iteration {start_iter}") + consecutive_nan_count = 0 + for data in metric_logger.log_every( + data_loader, + print_freq=10, + header="Training", + n_iterations=max_iter, + start_iteration=start_iter, + ): + it = iteration + data["global_batch_size"] = global_batch_size + if iteration > max_iter: + return + + # Garbage collection (trigger manually so it happens on all ranks at the same time) + if (iteration + 1) % 150 == 0: + logger.info("Garbage collection") + gc.collect() + + if cfg.gram.use_loss and model.gram_it_load_ema_teacher == it: + logger.info(f"Loading EMA teacher into Gram teacher before iteration {it}") + model.gram_load_ema_teacher() + + # Learning rates and other schedules + lr = lr_schedule[it] + wd = wd_schedule[it] + mom = momentum_schedule[it] + teacher_temp = teacher_temp_schedule[it] + last_layer_lr = last_layer_lr_schedule[it] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # Forward backward + optimizer.zero_grad(set_to_none=True) + total_loss, metrics_dict = model.forward_backward(data, teacher_temp=teacher_temp, iteration=it) + + # Gradient clipping + if cfg.optim.clip_grad: + for k, v in student.items(): + grad_norm = torch.nn.utils.clip_grad_norm_( + v.parameters(), + max_norm=cfg.optim.clip_grad, + ) + metrics_dict[f"{k}_grad_norm"] = ( + grad_norm.full_tensor().item() + if isinstance(grad_norm, torch.distributed.tensor.DTensor) + else grad_norm.item() + ) + + # Reduce total_loss to check for NaNs, reduce metrics for logging + total_loss_all_ranks = total_loss.new_empty(distributed.get_subgroup_size()) + torch.distributed.all_gather_into_tensor( + total_loss_all_ranks, + total_loss.detach(), + group=distributed.get_process_subgroup(), + ) + total_loss = total_loss_all_ranks.mean() + metrics_values = torch.stack( + [torch.as_tensor(v, dtype=torch.float32, device=total_loss.device).detach() for v in metrics_dict.values()] + ) + torch.distributed.all_reduce( + metrics_values, + op=torch.distributed.ReduceOp.AVG, + group=distributed.get_process_subgroup(), + ) + metrics_dict = dict(zip(metrics_dict.keys(), metrics_values)) + if total_loss_all_ranks.isnan().any(): + consecutive_nan_count += 1 + which_ranks = total_loss_all_ranks.isnan().nonzero().flatten().tolist() + logger.warning("NaN loss detected on ranks: %s", which_ranks) + logger.warning("Consecutive NaNs: %d", consecutive_nan_count) + metrics_dict_str = "\n".join([f"{k}: {v}" for k, v in metrics_dict.items()]) + logger.warning("All-reduced metrics:\n%s", metrics_dict_str) + if consecutive_nan_count > 2 and not cfg.multidistillation.enabled: + msg = "Too many consecutive nans detected in loss, aborting..." 
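+ # more than two NaN iterations in a row: treat the run as diverged and fail fast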
+ logger.error(msg) + raise RuntimeError(msg) + else: + consecutive_nan_count = 0 + # Step optimizer + optimizer.step() + model.update_ema(mom) + + # [GRAM] Update gram teacher when using gram teacher and frequent updates + if ( + cfg.gram.use_loss + and model.gram_rep_update + and (it + 1) >= model.gram_it_first_update + and (it + 1) % model.gram_update_frequency == 0 + and (cfg.gram.max_updates is None or num_gram_updates < cfg.gram.max_updates) + ): + logger.info(f"Updating Gram teacher from EMA teacher after iteration {it}") + model.update_gram() + num_gram_updates += 1 + + # Log metrics + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(total_loss=total_loss, **metrics_dict) + + # Submit evaluation jobs + if ( + cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0 + # and iteration != max_iter - 1 + ): + do_test(cfg, model, f"training_{iteration}", process_group=process_subgroup) + torch.cuda.synchronize() + + # Checkpointing + if (iteration + 1) % cfg.checkpointing.period == 0: + torch.cuda.synchronize() + save_checkpoint( + ckpt_dir / str(iteration), + iteration=iteration, + model=model, + optimizer=optimizer, + overwrite=True, + process_group=process_subgroup, + ) + if distributed.is_subgroup_main_process(): + keep_last_n_checkpoints(ckpt_dir, cfg.checkpointing.max_to_keep) + if "keep_every" in cfg.checkpointing and (iteration + 1) % cfg.checkpointing.keep_every == 0: + keep_checkpoint_copy(ckpt_dir / str(iteration)) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(argv=None): + if argv is None: + args = get_args_parser().parse_args() + else: + args = get_args_parser().parse_args(argv[1:]) + args.output_dir = sys.argv[1] + if args.multi_distillation: + print("performing multidistillation run") + cfg = setup_multidistillation(args) + torch.distributed.barrier() + logger.info("setup_multidistillation done") + assert cfg.MODEL.META_ARCHITECTURE == "MultiDistillationMetaArch" + else: + setup_job(output_dir=args.output_dir, seed=args.seed) + cfg = setup_config(args, strict_cfg=False) + logger.info(cfg) + setup_logging( + output=os.path.join(os.path.abspath(args.output_dir), "nan_logs"), + name="nan_logger", + ) + meta_arch = { + "SSLMetaArch": SSLMetaArch, + "MultiDistillationMetaArch": MultiDistillationMetaArch, + }.get(cfg.MODEL.META_ARCHITECTURE, None) + if meta_arch is None: + raise ValueError(f"Unknown MODEL.META_ARCHITECTURE {cfg.MODEL.META_ARCHITECTURE}") + logger.info(f"Making meta arch {meta_arch.__name__}") + with torch.device("meta"): + model = meta_arch(cfg) + model.prepare_for_distributed_training() + # Fill all values with `nans` so that we identify + # non-initialized values + model._apply( + lambda t: torch.full_like( + t, + fill_value=math.nan if t.dtype.is_floating_point else (2 ** (t.dtype.itemsize * 8 - 1)), + device="cuda", + ), + recurse=True, + ) + logger.info(f"Model after distributed:\n{model}") + if args.eval_only: + model.init_weights() + iteration = ( + model.get_checkpointer_class()(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, f"manual_{iteration}") + do_train(cfg, model, resume=not args.no_resume) + + +if __name__ == "__main__": + main() diff --git 
a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/__init__.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2794ac284e9bba24c6cee6a3eb5ecf7722f8734c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +from .dtype import as_torch_dtype +from .utils import ( + cat_keep_shapes, + count_parameters, + fix_random_seeds, + get_conda_env, + get_sha, + named_apply, + named_replace, + uncat_with_shapes, +) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/cluster.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..25df3bc556241b1efc22908c7c832d5f3751682f --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/cluster.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import os +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + CW = "cw" + + +def _guess_cluster_type() -> ClusterType: + return ClusterType.CW + + +def get_cluster_type( + cluster_type: Optional[ClusterType] = None, +) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_slurm_account(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + return { + ClusterType.CW: "fair_amaia_cw_explore", + }[cluster_type] + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.CW: "", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path( + cluster_type: Optional[ClusterType] = None, +) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_qos(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + return { + ClusterType.CW: "explore", + }.get(cluster_type) + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.CW: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, + num_gpus_per_node: int, + cluster_type: Optional[ClusterType] = None, + **kwargs, +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply 
cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.CW: + params["cpus_per_task"] = 16 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/custom_callable.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/custom_callable.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7c2f762835a6027f94006fad3360cf19ca4be3 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/custom_callable.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import contextlib +import importlib +import inspect +import os +import sys +from pathlib import Path + + +@contextlib.contextmanager +def _load_modules_from_dir(dir_: str): + sys.path.insert(0, dir_) + try: + yield + finally: + # restore sys.path even if the import inside the context raises + sys.path.pop(0) + + +def load_custom_callable(module_path: str | Path, callable_name: str): + module_full_path = os.path.realpath(module_path) + assert os.path.exists(module_full_path), f"module {module_full_path} does not exist" + module_dir, module_filename = os.path.split(module_full_path) + module_name, _ = os.path.splitext(module_filename) + + with _load_modules_from_dir(module_dir): + module = importlib.import_module(module_name) + if inspect.getfile(module) != module_full_path: + importlib.reload(module) + callable_ = getattr(module, callable_name) + + return callable_ + + +@contextlib.contextmanager +def change_working_dir_and_pythonpath(new_dir): + old_dir = Path.cwd() + new_dir = Path(new_dir).expanduser().resolve().as_posix() + old_pythonpath = sys.path.copy() + sys.path.insert(0, new_dir) + os.chdir(new_dir) + try: + yield + finally: + os.chdir(old_dir) + sys.path = old_pythonpath diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/dtype.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..b2795ed42b512ce51890c8592db5c364b18a5f4c --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/dtype.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement.
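+ + # Usage sketch (behavior read off the mapping below): as_torch_dtype("float32") -> torch.float32, + # as_torch_dtype(np.dtype("int64")) -> torch.int64; torch dtypes pass through unchanged.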
+ + from typing import Dict, Union + + import numpy as np + import torch + + TypeSpec = Union[str, np.dtype, torch.dtype] + + + _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/utils.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b267033f49dfa0819406f48152f490b7c17ac94 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/dinov3/utils/utils.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +import logging +import os +import random +import subprocess +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor, nn + +logger = logging.getLogger("dinov3") + + +def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]: + shapes = [x.shape for x in x_list] + num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list] + flattened = torch.cat([x.flatten(0, -2) for x in x_list]) + return flattened, shapes, num_tokens + + +def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]: + outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0) + shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes] + outputs_reshaped = [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)] + return outputs_reshaped + + +def named_replace( + fn: Callable, + module: nn.Module, + name: str = "", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + if not depth_first and include_root: + module = fn(module=module, name=name) + for child_name_o, child_module in list(module.named_children()): + child_name = ".".join((name, child_name_o)) if name else child_name_o + new_child = named_replace( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + setattr(module, child_name_o, new_child) + + if depth_first and include_root: + module = fn(module=module, name=name) + return module + + +def named_apply( + fn: Callable, + module: nn.Module, + name: str = "", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def
fix_random_seeds(seed: int = 31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha() -> str: + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def get_conda_env() -> Tuple[Optional[str], Optional[str]]: + conda_env_name = os.environ.get("CONDA_DEFAULT_ENV") + conda_env_path = os.environ.get("CONDA_PREFIX") + return conda_env_name, conda_env_path + + +def count_parameters(module: nn.Module) -> int: + c = 0 + for m in module.parameters(): + c += m.nelement() + return c + + +def has_batchnorms(model: nn.Module) -> bool: + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for _, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/hubconf.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..ac08688eb0aada70c9bd1284cff1079f5f0f57e5 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/hubconf.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement.
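+ + # Usage sketch (assumes a local clone of this repo; "/path/to/dinov3" is a placeholder + # and weight handling may differ): + # import torch + # backbone = torch.hub.load("/path/to/dinov3", "dinov3_vitl16", source="local", pretrained=False)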
+ +from dinov3.hub.backbones import ( + dinov3_convnext_base, + dinov3_convnext_large, + dinov3_convnext_small, + dinov3_convnext_tiny, + dinov3_vit7b16, + dinov3_vitb16, + dinov3_vith16plus, + dinov3_vitl16, + dinov3_vitl16plus, + dinov3_vits16, + dinov3_vits16plus, +) +from dinov3.hub.classifiers import dinov3_vit7b16_lc +from dinov3.hub.detectors import dinov3_vit7b16_de +from dinov3.hub.dinotxt import dinov3_vitl16_dinotxt_tet1280d20h24l +from dinov3.hub.segmentors import dinov3_vit7b16_ms + +from dinov3.hub.depthers import dinov3_vit7b16_dd + +dependencies = ["torch", "numpy"] diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/pyproject.toml b/depth_anything_v2_metric/depth_anything_v2/dinov3/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..ede1516806550631618dcccdc74d10b4c0035984 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/pyproject.toml @@ -0,0 +1,45 @@ +[tool.mypy] +python_version = "3.11" +ignore_missing_imports = true +files = "dinov3" +exclude = '''(?x)( + ^dinov3/tests/([^/]+/)*test_.*\.py$ # Unit tests +)''' + +[tool.pylint.master] +persistent = false +score = false + +[tool.pylint.messages_control] +disable = "all" +enable = [ + "miscellaneous", + "similarities", +] + +[tool.pylint.similarities] +ignore-comments = true +ignore-docstrings = true +ignore-imports = true +min-similarity-lines = 8 + +[tool.pylint.reports] +reports = false + +[tool.pylint.miscellaneous] +notes = [ + "FIXME", + "XXX", + "TODO", +] + +[tool.ruff] +line-length = 120 +target-version = "py311" + +[tool.ruff.lint] +ignore = ["E203", "E501"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +"hubconf.py" = ["F401"] diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements-dev.txt b/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements-dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..44cb0c6c6fd22ddb04add6a91dc4baa253b425e1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements-dev.txt @@ -0,0 +1,4 @@ +docstr-coverage==2.3.2 +mypy[reports]==1.17 +pylint==3.3.8 +ruff==0.12.8 diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements.txt b/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c14d0fac8a9b2294fd2e9b5b3189f3c74b7a719e --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/requirements.txt @@ -0,0 +1,9 @@ +ftfy +omegaconf +regex +scikit-learn +submitit +termcolor +torch +torchmetrics +torchvision diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3/setup.py b/depth_anything_v2_metric/depth_anything_v2/dinov3/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9302c1249c4c57680aa7b8a6c33582932fa22364 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3/setup.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. 
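+ + # Install sketch (assumption, based on the extras defined below): + # pip install -e . # package plus runtime requirements + # pip install -e ".[dev]" # additionally pulls the pinned dev tools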
+ + from pathlib import Path + import re + from typing import List, Tuple + + from setuptools import setup, find_packages + + + NAME = "dinov3" + DESCRIPTION = "" + + URL = "https://github.com/facebookresearch/dinov3" + AUTHOR = "Meta AI" + REQUIRES_PYTHON = ">=3.11" + HERE = Path(__file__).parent + + + try: + with open(HERE / "README.md", encoding="utf-8") as f: + long_description = "\n" + f.read() + except FileNotFoundError: + long_description = DESCRIPTION + + + def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: + requirements = [] + extra_indices = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\r\n") + if line.startswith("--extra-index-url "): + extra_indices.append(line[18:]) + continue + requirements.append(line) + return requirements, extra_indices + + + def get_package_version() -> str: + with open(HERE / "dinov3/__init__.py") as f: + result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) + if result: + return result.group(1) + raise RuntimeError("Can't get package version") + + + requirements, extra_indices = get_requirements() + version = get_package_version() + dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") + + + setup( + name=NAME, + version=version, + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages(), + package_data={ + "": ["*.yaml"], + }, + install_requires=requirements, + dependency_links=extra_indices, + extras_require={ + "dev": dev_requirements, + }, + include_package_data=True, + license="DINOv3 License", + license_files=("LICENSE.md",), + classifiers=[ + # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: Other/Proprietary License", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/depth_anything_v2_metric/depth_anything_v2/dinov3_adpther.py b/depth_anything_v2_metric/depth_anything_v2/dinov3_adpther.py new file mode 100644 index 0000000000000000000000000000000000000000..a77648f1ea2de23638ed6d1c40305aa1f03df3c1 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dinov3_adpther.py @@ -0,0 +1,86 @@ +# dinov3_adapter.py +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DINOv3Adapter(nn.Module): + """ + DINOv3 Adapter: exposes a local DINOv3 backbone through a DINOv2-style get_intermediate_layers interface. + """ + + MODEL_MAP = { + "vits": "dinov3_vits16", + "vitb": "dinov3_vitb16", + "vitl": "dinov3_vitl16", + "vitg": "dinov3_vitg14", + "vit7b": "dinov3_vit7b16", + } + + def __init__(self, model_name, repo_dir, arch=None, weight_path=None): + super().__init__() + + if arch is None: + if model_name not in self.MODEL_MAP: + raise ValueError(f"Unknown model_name={model_name}, must be one of {list(self.MODEL_MAP.keys())}") + arch = self.MODEL_MAP[model_name] + + self.model = torch.hub.load(repo_dir, arch, source="local", pretrained=False) + + self.embed_dim = getattr(self.model, "embed_dim", None) + if self.embed_dim is None: + raise AttributeError("DINOv3 model missing embed_dim") + + self.patch_size = getattr(self.model, "patch_size", None) + if self.patch_size is None: + pe = getattr(self.model, "patch_embed", None) + if pe is not None and hasattr(pe,
"patch_size"): + ps = pe.patch_size + self.patch_size = ps if isinstance(ps, int) else ps[0] + if self.patch_size is None: + raise AttributeError("DINOv3 model missing patch_size") + + self.blocks = getattr(self.model, "blocks", None) + if self.blocks is None: + raise AttributeError("DINOv3 model missing blocks") + + self.n_blocks = getattr(self.model, "n_blocks", len(self.blocks)) + self.depth = self.n_blocks + + self.norm = nn.LayerNorm(self.embed_dim) + + # @torch.no_grad() + def get_intermediate_layers( + self, x, n=1, return_class_token=False, norm=True + ): + outputs = self.model.get_intermediate_layers( + x, n=n, reshape=False, return_class_token=True, norm=norm + ) + + patch_maps, cls_tokens = [], [] + H, W = x.shape[-2], x.shape[-1] + h, w = H // self.patch_size, W // self.patch_size + + for (out_all, out_cls) in outputs: + if norm: + out_all = self.norm(out_all) + + out_patches = out_all[:, 1:, :] # [B, N, C] + B, N, C = out_patches.shape + sqrtN = int(N ** 0.5) + if sqrtN * sqrtN == N: + grid = out_patches.transpose(1, 2).reshape(B, C, sqrtN, sqrtN) + else: + grid = out_patches.transpose(1, 2).reshape(B, C, N, 1) + grid = F.interpolate(grid, size=(h * w, 1), mode="bilinear").squeeze(-1) + grid = grid.reshape(B, C, h, w) + + if grid.shape[-2:] != (h, w): + grid = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=False) + + patch_maps.append(grid.contiguous()) + cls_tokens.append(out_cls) + + if return_class_token: + return tuple(zip(patch_maps, cls_tokens)) + return tuple(patch_maps) diff --git a/depth_anything_v2_metric/depth_anything_v2/dpt.py b/depth_anything_v2_metric/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca589bc9c4357ef01d4bd9ee7847bd051e7d603 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dpt.py @@ -0,0 +1,233 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .dinov3_adpther import DINOv3Adapter +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + 
+ if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.Sigmoid() + ) + + def forward(self, out_features, patch_h, patch_w, patch_size=16): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) if x.dim() == 3 else None + if readout is not None: + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] if isinstance(x, (tuple, list)) else x + if x.dim() == 3: + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + elif x.dim() == 4: + if x.shape[-2] != patch_h or x.shape[-1] != patch_w: + x = F.interpolate(x, size=(patch_h, patch_w), mode="bilinear", align_corners=True) + else: + raise RuntimeError(f"Unexpected feature shape {x.shape}, expected 3D or 4D") + + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * patch_size), int(patch_w * patch_size)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + return out + + +class DepthAnythingV2(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False, + max_depth=20.0, + dinov3_repo_dir="", # path to your local DINOv3 repo + dinov3_arch="dinov3_vitl16", # e.g. 'dinov3_vitl16' + dinov3_weight="", + ): + super().__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.max_depth = max_depth + + self.encoder = encoder + self.pretrained = DINOv3Adapter( + model_name=encoder, + repo_dir=dinov3_repo_dir, + arch=dinov3_arch, + weight_path=dinov3_weight + ) + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + self.mask_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + self.patch_size = int(self.pretrained.patch_size) + + def forward(self, x): +
patch_size = getattr(self.pretrained, "patch_size", 16) # DINOv3=16 + patch_h, patch_w = x.shape[-2] // patch_size, x.shape[-1] // patch_size + + + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + depth = self.depth_head(features, patch_h, patch_w, patch_size) * self.max_depth + mask = self.mask_head(features, patch_h, patch_w, patch_size) + + return depth.squeeze(1), mask.squeeze(1) + + @torch.no_grad() + def infer_image(self, raw_image, input_size=518): + image, (h, w) = self.image2tensor(raw_image, input_size) + depth, _ = self.forward(image) # forward returns (depth, mask); keep only depth + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + + return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=16, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/depth_anything_v2_metric/depth_anything_v2/dpt_v2.py b/depth_anything_v2_metric/depth_anything_v2/dpt_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..604e78f4d6dd3c9a084493a69d85322ce0977ab3 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/dpt_v2.py @@ -0,0 +1,225 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .dinov3_adpther import DINOv3Adapter +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects =
nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.Sigmoid() + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DepthAnythingV2(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False, + max_depth=20.0 + ): + super(DepthAnythingV2, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.max_depth = max_depth + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + self.mask_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 + + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + + depth = self.depth_head(features, patch_h, patch_w) * self.max_depth + mask = self.mask_head(features, patch_h, patch_w) + + return depth.squeeze(1), mask.squeeze(1) + + @torch.no_grad() + def infer_image(self, raw_image, input_size=518): + image, (h, w) = self.image2tensor(raw_image, input_size) + depth, _ = self.forward(image) # forward returns (depth, mask); keep only depth + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] +
+ return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc b/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58028b0f26473685757f88a99c7c556676ac0407 Binary files /dev/null and b/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc differ diff --git a/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc b/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8765df05b99c01a374dfb61d7b81a5a6289fab4d Binary files /dev/null and b/depth_anything_v2_metric/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc differ diff --git a/depth_anything_v2_metric/depth_anything_v2/util/blocks.py b/depth_anything_v2_metric/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/depth_anything_v2_metric/depth_anything_v2/util/transform.py b/depth_anything_v2_metric/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73 --- /dev/null +++ b/depth_anything_v2_metric/depth_anything_v2/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. 
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+        # resize sample
+        sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+
+        if self.__resize_target:
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+            if "mask" in sample:
+                sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+        return sample
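+
+# Example (illustrative sketch, not from the original source): with
+# keep_aspect_ratio=True, resize_method="lower_bound" and ensure_multiple_of=14,
+# both output sides end up >= the requested size and divisible by 14:
+#
+#   resize = Resize(width=518, height=518, resize_target=False,
+#                   keep_aspect_ratio=True, ensure_multiple_of=14,
+#                   resize_method="lower_bound",
+#                   image_interpolation_method=cv2.INTER_CUBIC)
+#   out = resize({"image": np.zeros((480, 640, 3), np.float32)})
+#   # out["image"].shape == (518, 686, 3)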
+ """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/depth_anything_v2_metric/util/dist_helper.py b/depth_anything_v2_metric/util/dist_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7b6eb432b4988638ac9549a82fbaebf968fe9c61 --- /dev/null +++ b/depth_anything_v2_metric/util/dist_helper.py @@ -0,0 +1,41 @@ +import os +import subprocess + +import torch +import torch.distributed as dist + + +def setup_distributed(backend="nccl", port=None): + """AdaHessian Optimizer + Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py + Originally licensed MIT, Copyright (c) 2020 Wei Li + """ + num_gpus = torch.cuda.device_count() + + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "10685" + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + torch.cuda.set_device(rank % num_gpus) + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + return rank, world_size diff --git a/depth_anything_v2_metric/util/loss.py b/depth_anything_v2_metric/util/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae5b304effd46661e93ea23127d1115c36b5265 --- /dev/null +++ b/depth_anything_v2_metric/util/loss.py @@ -0,0 +1,16 @@ +import torch +from torch import nn + + +class SiLogLoss(nn.Module): + def __init__(self, lambd=0.5): + super().__init__() + self.lambd = lambd + + def forward(self, pred, target, valid_mask): + valid_mask = valid_mask.detach() + diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask]) + loss = torch.sqrt(torch.pow(diff_log, 2).mean() - + self.lambd * torch.pow(diff_log.mean(), 2)) + + return loss diff --git a/depth_anything_v2_metric/util/metric.py b/depth_anything_v2_metric/util/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..8638cf25875c753cb62c3977af1417c221237dce --- /dev/null +++ b/depth_anything_v2_metric/util/metric.py @@ -0,0 +1,26 @@ +import torch + + +def eval_depth(pred, target): + assert pred.shape == target.shape + + thresh = torch.max((target / pred), (pred / target)) + + d1 = torch.sum(thresh < 1.25).float() / len(thresh) + d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh) + d3 = torch.sum(thresh < 1.25 ** 3).float() / 
diff --git a/depth_anything_v2_metric/util/metric.py b/depth_anything_v2_metric/util/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..8638cf25875c753cb62c3977af1417c221237dce
--- /dev/null
+++ b/depth_anything_v2_metric/util/metric.py
@@ -0,0 +1,26 @@
+import torch
+
+
+def eval_depth(pred, target):
+    assert pred.shape == target.shape
+
+    thresh = torch.max((target / pred), (pred / target))
+
+    d1 = torch.sum(thresh < 1.25).float() / len(thresh)
+    d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
+    d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
+
+    diff = pred - target
+    diff_log = torch.log(pred) - torch.log(target)
+
+    abs_rel = torch.mean(torch.abs(diff) / target)
+    sq_rel = torch.mean(torch.pow(diff, 2) / target)
+
+    rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
+    rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))
+
+    log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
+    silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
+
+    return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 'sq_rel': sq_rel.item(),
+            'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 'log10': log10.item(), 'silog': silog.item()}
\ No newline at end of file
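+
+# Example (illustrative sketch, not from the original source): eval_depth
+# expects tensors of strictly positive depths with identical shapes, so
+# predictions and ground truth are typically masked and flattened first.
+# `pred` and `gt` are hypothetical depth tensors:
+#
+#   valid = gt > 1e-3
+#   metrics = eval_depth(pred[valid], gt[valid])
+#   print(metrics['abs_rel'], metrics['rmse'], metrics['d1'])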
diff --git a/depth_anything_v2_metric/util/utils.py b/depth_anything_v2_metric/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89b994538c5123075605fb6130022867f37c99b
--- /dev/null
+++ b/depth_anything_v2_metric/util/utils.py
@@ -0,0 +1,26 @@
+import os
+import re
+import numpy as np
+import logging
+
+logs = set()
+
+
+def init_log(name, level=logging.INFO):
+    if (name, level) in logs:
+        return
+    logs.add((name, level))
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    ch = logging.StreamHandler()
+    ch.setLevel(level)
+    if "SLURM_PROCID" in os.environ:
+        rank = int(os.environ["SLURM_PROCID"])
+        logger.addFilter(lambda record: rank == 0)
+    else:
+        rank = 0
+    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
+    formatter = logging.Formatter(format_str)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+    return logger
diff --git a/networks/__init__.py b/networks/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d089ad98ef9a3c9d80a347781526d0c8ba146070
--- /dev/null
+++ b/networks/__init__.py
@@ -0,0 +1 @@
+from .dap import DAP
\ No newline at end of file
diff --git a/networks/__pycache__/__init__.cpython-310.pyc b/networks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..429ad1e422ac0928e40c88c3d956b587d818f5a6
Binary files /dev/null and b/networks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/networks/__pycache__/dap.cpython-310.pyc b/networks/__pycache__/dap.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..820678c17b8b98236bbb1c177d70c67050b22226
Binary files /dev/null and b/networks/__pycache__/dap.cpython-310.pyc differ
diff --git a/networks/__pycache__/models.cpython-310.pyc b/networks/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff07477cf4687a267dbd3b2a2998d85958ac37c0
Binary files /dev/null and b/networks/__pycache__/models.cpython-310.pyc differ
diff --git a/networks/__pycache__/panda.cpython-310.pyc b/networks/__pycache__/panda.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa2afa64bb3e1cd71260937b5a13fbb1a6d08091
Binary files /dev/null and b/networks/__pycache__/panda.cpython-310.pyc differ
diff --git a/networks/__pycache__/utils.cpython-310.pyc b/networks/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e6af1945fc49b66e1a1ac38de0be527c6793757
Binary files /dev/null and b/networks/__pycache__/utils.cpython-310.pyc differ
diff --git a/networks/blocks.py b/networks/blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..38dbcfeffc0c38ef51bcb20dfd347e50b2a60616
--- /dev/null
+++ b/networks/blocks.py
@@ -0,0 +1,153 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    if len(in_shape) >= 4:
+        out_shape4 = out_shape
+
+    if expand:
+        out_shape1 = out_shape
+        out_shape2 = out_shape * 2
+        out_shape3 = out_shape * 4
+        if len(in_shape) >= 4:
+            out_shape4 = out_shape * 8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    if len(in_shape) >= 4:
+        scratch.layer4_rn = nn.Conv2d(
+            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+        )
+
+    return scratch
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups = 1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn == True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn == True:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn == True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups = 1
+
+        self.expand = expand
+        out_features = features
+        if self.expand == True:
+            out_features = features // 2
+
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+        self.size = size
+
+    def forward(self, *xs, size=None):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+
+        output = self.resConfUnit2(output)
+
+        if (size is None) and (self.size is None):
+            modifier = {"scale_factor": 2}
+        elif size is None:
+            modifier = {"size": self.size}
+        else:
+            modifier = {"size": size}
+
+        output = nn.functional.interpolate(
+            output, **modifier, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
diff --git a/networks/dap.py b/networks/dap.py
new file mode 100755
index 0000000000000000000000000000000000000000..efdc7f088ae02b4ab735292dd536c9eb048b129c
--- /dev/null
+++ b/networks/dap.py
@@ -0,0 +1,117 @@
+import torch
+import numpy as np
+from einops import rearrange
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+import cv2
+from depth_anything_v2_metric.depth_anything_v2.dpt import DepthAnythingV2
+from depth_anything_v2_metric.depth_anything_v2.dinov3_adpther import DINOv3Adapter
+from argparse import Namespace
+from .models import register
+from depth_anything_utils import Resize, NormalizeImage, PrepareForNet
+
+class DAP(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        midas_model_type = args.midas_model_type
+        fine_tune_type = args.fine_tune_type
+        min_depth = args.min_depth
+        self.max_depth = args.max_depth
+        train_decoder = args.train_decoder
+
+        # Pre-defined setting of the model
+        model_configs = {
+            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+        }
+
+        # Load the pretrained model of depth anything
+        dinov3_repo_dir = "./depth_anything_v2_metric/depth_anything_v2/dinov3"  # local DINOv3 repo
+        dinov3_arch = "dinov3_vitl16"
+        dinov3_weight = ""
+
+        depth_anything = DepthAnythingV2(
+            **{**model_configs[midas_model_type], 'max_depth': 1.0},
+            dinov3_repo_dir=dinov3_repo_dir,
+            dinov3_arch=dinov3_arch,
+            dinov3_weight=dinov3_weight
+        )
+
+        self.core = depth_anything
+        for param in self.core.parameters():
+            param.requires_grad = True
+
+    def forward(self, image):
+        if image.dim() == 3:
+            image = image.unsqueeze(0)
+
+        erp_pred, mask_pred = self.core(image)
+        erp_pred = erp_pred.unsqueeze(1)
+        erp_pred[erp_pred < 0] = 0
+        mask_pred = mask_pred.unsqueeze(1)
+        outputs = {}
+        outputs["pred_depth"] = erp_pred * self.max_depth
+        outputs["pred_mask"] = mask_pred
+
+        return outputs
+
+    def get_encoder_decoder_params(self):
+        encoder_params = list(self.core.pretrained.parameters())
+        decoder_params = list(self.core.depth_head.parameters())
+        mask_params = list(self.core.mask_head.parameters())
+
+        return encoder_params, decoder_params, mask_params
+
+    @torch.no_grad()
+    def infer_image(self, raw_image, input_size=518):
+        image, (h, w) = self.image2tensor(raw_image, input_size)
+
+        depth = self.forward(image)["pred_depth"]
+
+        depth = F.interpolate(depth, (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+        return depth.cpu().numpy()
+
+    def image2tensor(self, raw_image, input_size=518):
+        transform = Compose([
+            Resize(
+                width=input_size * 2,
+                height=input_size,
+                resize_target=False,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=self.core.patch_size,
+                # ensure_multiple_of=14,
+                resize_method='lower_bound',
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            PrepareForNet(),
+        ])
+
+        h, w = raw_image.shape[:2]
+
+        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
+
+        image = transform({'image': image})['image']
+        image = torch.from_numpy(image).unsqueeze(0)
+
+        DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+        image = image.to(DEVICE)
+
+        return image, (h, w)
+
+@register('dap')
+def make_model(midas_model_type='vitl', fine_tune_type='none', min_depth=0.001, max_depth=1.0, train_decoder=True):
+    args = Namespace()
+    args.midas_model_type = midas_model_type
+    args.fine_tune_type = fine_tune_type
+    args.min_depth = min_depth
+    args.max_depth = max_depth
+    args.train_decoder = train_decoder
+    return DAP(args)
\ No newline at end of file
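+
+# Example (illustrative sketch, not from the original source): infer_image
+# handles preprocessing internally and returns an H x W numpy depth map in
+# [0, max_depth] for a BGR uint8 input. Keep the model on the same device
+# image2tensor selects (cuda / mps / cpu); "pano.jpg" is a hypothetical path:
+#
+#   model = make_model().eval().cuda()
+#   bgr = cv2.imread("pano.jpg")
+#   depth = model.infer_image(bgr)   # numpy array, shape (H, W)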
diff --git a/networks/dpt.py b/networks/dpt.py
new file mode 100755
index 0000000000000000000000000000000000000000..862b4caaa45905732111b1c6b24bdc4b39446721
--- /dev/null
+++ b/networks/dpt.py
@@ -0,0 +1,202 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+from .blocks import FeatureFusionBlock, _make_scratch
+
+from argparse import Namespace
+from .models import register
+
+
+def _make_fusion_block(features, use_bn, size=None):
+    return FeatureFusionBlock(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+        size=size,
+    )
+
+
+class DPTHead(nn.Module):
+    def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
+        super(DPTHead, self).__init__()
+
+        self.nclass = nclass
+        self.use_clstoken = use_clstoken
+
+        self.projects = nn.ModuleList([
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channel,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ) for out_channel in out_channels
+        ])
+
+        self.resize_layers = nn.ModuleList([
+            nn.ConvTranspose2d(
+                in_channels=out_channels[0],
+                out_channels=out_channels[0],
+                kernel_size=4,
+                stride=4,
+                padding=0),
+            nn.ConvTranspose2d(
+                in_channels=out_channels[1],
+                out_channels=out_channels[1],
+                kernel_size=2,
+                stride=2,
+                padding=0),
+            nn.Identity(),
+            nn.Conv2d(
+                in_channels=out_channels[3],
+                out_channels=out_channels[3],
+                kernel_size=3,
+                stride=2,
+                padding=1)
+        ])
+
+        if use_clstoken:
+            self.readout_projects = nn.ModuleList()
+            for _ in range(len(self.projects)):
+                self.readout_projects.append(
+                    nn.Sequential(
+                        nn.Linear(2 * in_channels, in_channels),
+                        nn.GELU()))
+
+        self.scratch = _make_scratch(
+            out_channels,
+            features,
+            groups=1,
+            expand=False,
+        )
+
+        self.scratch.stem_transpose = None
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        head_features_1 = features
+        head_features_2 = 32
+
+        if nclass > 1:
+            self.scratch.output_conv = nn.Sequential(
+                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(True),
+                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
+            )
+        else:
+            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+
+            self.scratch.output_conv2 = nn.Sequential(
+                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(True),
+                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+                nn.ReLU(True),
+                nn.Identity(),
+            )
+
+    def forward(self, out_features, patch_h, patch_w):
+        out = []
+        for i, x in enumerate(out_features):
+            if self.use_clstoken:
+                x, cls_token = x[0], x[1]
+                readout = cls_token.unsqueeze(1).expand_as(x)
+                x = self.readout_projects[i](torch.cat((x, readout), -1))
+            else:
+                x = x[0]
+
+            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+            x = self.projects[i](x)
+            x = self.resize_layers[i](x)
+
+            out.append(x)
+
+        layer_1, layer_2, layer_3, layer_4 = out
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out_feats = [path_4, path_3, path_2, path_1]
+
+        out = self.scratch.output_conv1(path_1)
+        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+        # out_feats = out
+        out = self.scratch.output_conv2(out)
+
+        # return out, out_feats
+        return out
+
+
+class DPT_DINOv2(nn.Module):
+    def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
+        super(DPT_DINOv2, self).__init__()
+
+        assert encoder in ['vits', 'vitb', 'vitl']
+
+        # in case the Internet connection is not stable, please load the DINOv2 locally
+        if localhub:
+            self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
+        else:
+            self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
+
+        dim = self.pretrained.blocks[0].attn.qkv.in_features
+
+        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+        self.mask_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+
+        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
+
+        patch_h, patch_w = h // 14, w // 14
+
+        # depth, depth_feats = self.depth_head(features, patch_h, patch_w)
+        depth = self.depth_head(features, patch_h, patch_w)
+
+        mask = self.mask_head(features, patch_h, patch_w)
+
+        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
+        depth = F.relu(depth)
+
+        mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=True)
+        mask = F.relu(mask)
+        # return depth, depth_feats
+        return depth, mask
+
+
+class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super().__init__(**config)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--encoder",
+        default="vits",
+        type=str,
+        choices=["vits", "vitb", "vitl"],
+    )
+    args = parser.parse_args()
+
+    model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
+
+    print(model)
+
\ No newline at end of file
diff --git a/networks/models.py b/networks/models.py
new file mode 100755
index 0000000000000000000000000000000000000000..136cce9202106162c9e0d0816d89a9ff9b9ccb2f
--- /dev/null
+++ b/networks/models.py
@@ -0,0 +1,20 @@
+import copy
+
+models = {}
+
+def register(name):
+    def decorator(cls):
+        models[name] = cls
+        return cls
+    return decorator
+
+def make(model_spec, args=None, load_sd=False):
+    if args is not None:
+        model_args = copy.deepcopy(model_spec['args'])
+        model_args.update(args)
+    else:
+        model_args = model_spec['args']
+    model = models[model_spec['name']](**model_args)
+    if load_sd:
+        model.load_state_dict(model_spec['sd'])
+    return model
\ No newline at end of file
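+
+# Example (illustrative sketch, not from the original source): models register
+# themselves by name, and make() instantiates one from a spec dict such as the
+# `model` section of a YAML config. Registration runs when networks.dap is
+# imported; the args below mirror the defaults of make_model in networks/dap.py:
+#
+#   spec = {'name': 'dap', 'args': {'midas_model_type': 'vitl',
+#                                   'fine_tune_type': 'none',
+#                                   'min_depth': 0.001, 'max_depth': 1.0,
+#                                   'train_decoder': True}}
+#   model = make(spec)   # looks up models['dap'] and calls it with **args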
diff --git a/networks/projection_utils.py b/networks/projection_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..93b01190da8b061f7c1bc7dbf03a62e1f91e4289
--- /dev/null
+++ b/networks/projection_utils.py
@@ -0,0 +1,478 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.ndimage import map_coordinates
+import cv2
+import math
+from os import makedirs
+from os.path import join, exists
+
+# Based on https://github.com/sunset1995/py360convert
+class Equirec2Cube:
+    def __init__(self, equ_h, equ_w, face_w):
+        '''
+        equ_h: int, height of the equirectangular image
+        equ_w: int, width of the equirectangular image
+        face_w: int, the length of each face of the cubemap
+        '''
+
+        self.equ_h = equ_h
+        self.equ_w = equ_w
+        self.face_w = face_w
+
+        self._xyzcube()
+        self._xyz2coor()
+
+        # For converting R-distance to Z-depth for cubemaps
+        cosmap = 1 / np.sqrt((2 * self.grid[..., 0]) ** 2 + (2 * self.grid[..., 1]) ** 2 + 1)
+        self.cosmaps = np.concatenate(6 * [cosmap], axis=1)[..., np.newaxis]
+
+    def _xyzcube(self):
+        '''
+        Compute the xyz coordinates of the unit cube in [F R B L U D] format.
+        '''
+        self.xyz = np.zeros((self.face_w, self.face_w * 6, 3), np.float32)
+        rng = np.linspace(-0.5, 0.5, num=self.face_w, dtype=np.float32)
+        self.grid = np.stack(np.meshgrid(rng, -rng), -1)
+
+        # Front face (z = 0.5)
+        self.xyz[:, 0 * self.face_w:1 * self.face_w, [0, 1]] = self.grid
+        self.xyz[:, 0 * self.face_w:1 * self.face_w, 2] = 0.5
+
+        # Right face (x = 0.5)
+        self.xyz[:, 1 * self.face_w:2 * self.face_w, [2, 1]] = self.grid[:, ::-1]
+        self.xyz[:, 1 * self.face_w:2 * self.face_w, 0] = 0.5
+
+        # Back face (z = -0.5)
+        self.xyz[:, 2 * self.face_w:3 * self.face_w, [0, 1]] = self.grid[:, ::-1]
+        self.xyz[:, 2 * self.face_w:3 * self.face_w, 2] = -0.5
+
+        # Left face (x = -0.5)
+        self.xyz[:, 3 * self.face_w:4 * self.face_w, [2, 1]] = self.grid
+        self.xyz[:, 3 * self.face_w:4 * self.face_w, 0] = -0.5
+
+        # Up face (y = 0.5)
+        self.xyz[:, 4 * self.face_w:5 * self.face_w, [0, 2]] = self.grid[::-1, :]
+        self.xyz[:, 4 * self.face_w:5 * self.face_w, 1] = 0.5
+
+        # Down face (y = -0.5)
+        self.xyz[:, 5 * self.face_w:6 * self.face_w, [0, 2]] = self.grid
+        self.xyz[:, 5 * self.face_w:6 * self.face_w, 1] = -0.5
+
+    def _xyz2coor(self):
+
+        # x, y, z to longitude and latitude
+        x, y, z = np.split(self.xyz, 3, axis=-1)
+        lon = np.arctan2(x, z)
+        c = np.sqrt(x ** 2 + z ** 2)
+        lat = np.arctan2(y, c)
+
+        # longitude and latitude to equirectangular coordinate
+        self.coor_x = (lon / (2 * np.pi) + 0.5) * self.equ_w - 0.5
+        self.coor_y = (-lat / np.pi + 0.5) * self.equ_h - 0.5
+
+    def sample_equirec(self, e_img, order=0):
+        pad_u = np.roll(e_img[[0]], self.equ_w // 2, 1)
+        pad_d = np.roll(e_img[[-1]], self.equ_w // 2, 1)
+        e_img = np.concatenate([e_img, pad_d, pad_u], 0)
+        # pad_l = e_img[:, [0]]
+        # pad_r = e_img[:, [-1]]
+        # e_img = np.concatenate([e_img, pad_l, pad_r], 1)
+
+        return map_coordinates(e_img, [self.coor_y, self.coor_x],
+                               order=order, mode='wrap')[..., 0]
+
+    def run(self, equ_img, equ_dep=None):
+
+        h, w = equ_img.shape[:2]
+        if h != self.equ_h or w != self.equ_w:
+            equ_img = cv2.resize(equ_img, (self.equ_w, self.equ_h))
+            if equ_dep is not None:
+                equ_dep = cv2.resize(equ_dep, (self.equ_w, self.equ_h), interpolation=cv2.INTER_NEAREST)
+
+        cube_img = np.stack([self.sample_equirec(equ_img[..., i], order=1)
+                             for i in range(equ_img.shape[2])], axis=-1)
+
+        if equ_dep is not None:
+            cube_dep = np.stack([self.sample_equirec(equ_dep[..., i], order=0)
+                                 for i in range(equ_dep.shape[2])], axis=-1)
+            cube_dep = cube_dep * self.cosmaps
+
+        if equ_dep is not None:
+            return cube_img, cube_dep
+        else:
+            return cube_img
+
+# Based on https://github.com/sunset1995/py360convert
+class Cube2Equirec(nn.Module):
+    def __init__(self, face_w, equ_h, equ_w):
+        super(Cube2Equirec, self).__init__()
+        '''
+        face_w: int, the length of each face of the cubemap
+        equ_h: int, height of the equirectangular image
+        equ_w: int, width of the equirectangular image
+        '''
+
+        self.face_w = face_w
+        self.equ_h = equ_h
+        self.equ_w = equ_w
+
+        # Get face id to each pixel: 0F 1R 2B 3L 4U 5D
+        self._equirect_facetype()
+        self._equirect_faceuv()
+
+    def _equirect_facetype(self):
+        '''
+        0F 1R 2B 3L 4U 5D
+        '''
+        tp = np.roll(np.arange(4).repeat(self.equ_w // 4)[None, :].repeat(self.equ_h, 0), 3 * self.equ_w // 8, 1)
+
+        # Prepare ceil mask
+        mask = np.zeros((self.equ_h, self.equ_w // 4), bool)
+        idx = np.linspace(-np.pi, np.pi, self.equ_w // 4) / 4
+        idx = self.equ_h // 2 - np.round(np.arctan(np.cos(idx)) * self.equ_h / np.pi).astype(int)
+        for i, j in enumerate(idx):
+            mask[:j, i] = 1
+        mask = np.roll(np.concatenate([mask] *
4, 1), 3 * self.equ_w // 8, 1) + + tp[mask] = 4 + tp[np.flip(mask, 0)] = 5 + + self.tp = tp + self.mask = mask + + def _equirect_faceuv(self): + + lon = ((np.linspace(0, self.equ_w -1, num=self.equ_w, dtype=np.float32 ) +0.5 ) /self.equ_w - 0.5 ) * 2 *np.pi + lat = -((np.linspace(0, self.equ_h -1, num=self.equ_h, dtype=np.float32 ) +0.5 ) /self.equ_h -0.5) * np.pi + + lon, lat = np.meshgrid(lon, lat) + + coor_u = np.zeros((self.equ_h, self.equ_w), dtype=np.float32) + coor_v = np.zeros((self.equ_h, self.equ_w), dtype=np.float32) + + for i in range(4): + mask = (self.tp == i) + coor_u[mask] = 0.5 * np.tan(lon[mask] - np.pi * i / 2) + coor_v[mask] = -0.5 * np.tan(lat[mask]) / np.cos(lon[mask] - np.pi * i / 2) + + mask = (self.tp == 4) + c = 0.5 * np.tan(np.pi / 2 - lat[mask]) + coor_u[mask] = c * np.sin(lon[mask]) + coor_v[mask] = c * np.cos(lon[mask]) + + mask = (self.tp == 5) + c = 0.5 * np.tan(np.pi / 2 - np.abs(lat[mask])) + coor_u[mask] = c * np.sin(lon[mask]) + coor_v[mask] = -c * np.cos(lon[mask]) + + # Final renormalize + coor_u = (np.clip(coor_u, -0.5, 0.5)) * 2 + coor_v = (np.clip(coor_v, -0.5, 0.5)) * 2 + + # Convert to torch tensor + self.tp = torch.from_numpy(self.tp.astype(np.float32) / 2.5 - 1) + self.coor_u = torch.from_numpy(coor_u) + self.coor_v = torch.from_numpy(coor_v) + + sample_grid = torch.stack([self.coor_u, self.coor_v, self.tp], dim=-1).view(1, 1, self.equ_h, self.equ_w, 3) + self.sample_grid = nn.Parameter(sample_grid, requires_grad=False) + + def forward(self, cube_feat): + + bs, ch, h, w = cube_feat.shape + assert h == self.face_w and w // 6 == self.face_w + + cube_feat = cube_feat.view(bs, ch, 1, h, w) + cube_feat = torch.cat(torch.split(cube_feat, self.face_w, dim=-1), dim=2) + + cube_feat = cube_feat.view([bs, ch, 6, self.face_w, self.face_w]) + sample_grid = torch.cat(bs * [self.sample_grid], dim=0) + equi_feat = F.grid_sample(cube_feat, sample_grid, padding_mode="border", align_corners=True) + + return equi_feat.squeeze(2) + +# generate patches in a closed-form +# the transformation and equation is referred from http://blog.nitishmutha.com/equirectangular/360degree/2017/06/12/How-to-project-Equirectangular-image-to-rectilinear-view.html +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def uv2xyz(uv): + xyz = np.zeros((*uv.shape[:-1], 3), dtype = np.float32) + xyz[..., 0] = np.multiply(np.cos(uv[..., 1]), np.sin(uv[..., 0])) + xyz[..., 1] = np.multiply(np.cos(uv[..., 1]), np.cos(uv[..., 0])) + xyz[..., 2] = np.sin(uv[..., 1]) + return xyz + +def equi2pers(erp_img, fov, nrows, patch_size): + bs, _, erp_h, erp_w = erp_img.shape + height, width = pair(patch_size) + fov_h, fov_w = pair(fov) + FOV = torch.tensor([fov_w/360.0, fov_h/180.0], dtype=torch.float32) + + PI = math.pi + PI_2 = math.pi * 0.5 + PI2 = math.pi * 2 + yy, xx = torch.meshgrid(torch.linspace(0, 1, height), torch.linspace(0, 1, width)) + screen_points = torch.stack([xx.flatten(), yy.flatten()], -1) + + if nrows==4: + num_rows = 4 + num_cols = [3, 6, 6, 3] + phi_centers = [-67.5, -22.5, 22.5, 67.5] + if nrows==6: + num_rows = 6 + num_cols = [3, 8, 12, 12, 8, 3] + phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2] + if nrows==3: + num_rows = 3 + num_cols = [3, 4, 3] + phi_centers = [-60, 0, 60] + if nrows==5: + num_rows = 5 + num_cols = [3, 6, 8, 6, 3] + phi_centers = [-72.2, -36.1, 0, 36.1, 72.2] + + phi_interval = 180 // num_rows + all_combos = [] + erp_mask = [] + for i, n_cols in enumerate(num_cols): + for j in np.arange(n_cols): + theta_interval = 360 / n_cols + 
theta_center = j * theta_interval + theta_interval / 2 + + center = [theta_center, phi_centers[i]] + all_combos.append(center) + up = phi_centers[i] + phi_interval / 2 + down = phi_centers[i] - phi_interval / 2 + left = theta_center - theta_interval / 2 + right = theta_center + theta_interval / 2 + up = int((up + 90) / 180 * erp_h) + down = int((down + 90) / 180 * erp_h) + left = int(left / 360 * erp_w) + right = int(right / 360 * erp_w) + mask = np.zeros((erp_h, erp_w), dtype=int) + mask[down:up, left:right] = 1 + erp_mask.append(mask) + all_combos = np.vstack(all_combos) + shifts = np.arange(all_combos.shape[0]) * width + shifts = torch.from_numpy(shifts).float() + erp_mask = np.stack(erp_mask) + erp_mask = torch.from_numpy(erp_mask).float() + num_patch = all_combos.shape[0] + + center_point = torch.from_numpy(all_combos).float() # -180 to 180, -90 to 90 + center_point[:, 0] = (center_point[:, 0]) / 360 #0 to 1 + center_point[:, 1] = (center_point[:, 1] + 90) / 180 #0 to 1 + + cp = center_point * 2 - 1 + center_p = cp.clone() + cp[:, 0] = cp[:, 0] * PI + cp[:, 1] = cp[:, 1] * PI_2 + cp = cp.unsqueeze(1) + convertedCoord = screen_points * 2 - 1 + convertedCoord[:, 0] = convertedCoord[:, 0] * PI + convertedCoord[:, 1] = convertedCoord[:, 1] * PI_2 + convertedCoord = convertedCoord * (torch.ones(screen_points.shape, dtype=torch.float32) * FOV) + convertedCoord = convertedCoord.unsqueeze(0).repeat(cp.shape[0], 1, 1) + + x = convertedCoord[:, :, 0] + y = convertedCoord[:, :, 1] + + rou = torch.sqrt(x ** 2 + y ** 2) + c = torch.atan(rou) + sin_c = torch.sin(c) + cos_c = torch.cos(c) + lat = torch.asin(cos_c * torch.sin(cp[:, :, 1]) + (y * sin_c * torch.cos(cp[:, :, 1])) / rou) + lon = cp[:, :, 0] + torch.atan2(x * sin_c, rou * torch.cos(cp[:, :, 1]) * cos_c - y * torch.sin(cp[:, :, 1]) * sin_c) + lat_new = lat / PI_2 + lon_new = lon / PI + lon_new[lon_new > 1] -= 2 + lon_new[lon_new<-1] += 2 + + lon_new = lon_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch*width) + lat_new = lat_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch*width) + grid = torch.stack([lon_new, lat_new], -1) + grid = grid.unsqueeze(0).repeat(bs, 1, 1, 1).to(erp_img.device) + pers = F.grid_sample(erp_img, grid, mode='bilinear', padding_mode='border', align_corners=True) + pers = F.unfold(pers, kernel_size=(height, width), stride=(height, width)) + pers = pers.reshape(bs, -1, height, width, num_patch) + + grid_tmp = torch.stack([lon, lat], -1) + xyz = uv2xyz(grid_tmp) + xyz = xyz.reshape(num_patch, height, width, 3).transpose(0, 3, 1, 2) + xyz = torch.from_numpy(xyz).to(pers.device).contiguous() + + uv = grid[0, ...].reshape(height, width, num_patch, 2).permute(2, 3, 0, 1) + uv = uv.contiguous() + return pers, xyz, uv, center_p + +def pers2equi(pers_img, fov, nrows, patch_size, erp_size, layer_name): + bs = pers_img.shape[0] + channel = pers_img.shape[1] + device=pers_img.device + height, width = pair(patch_size) + fov_h, fov_w = pair(fov) + erp_h, erp_w = pair(erp_size) + n_patch = pers_img.shape[-1] + grid_dir = './grid' + if not exists(grid_dir): + makedirs(grid_dir) + grid_file = join(grid_dir, layer_name + '.pth') + + if not exists(grid_file): + FOV = torch.tensor([fov_w/360.0, fov_h/180.0], dtype=torch.float32) + + PI = math.pi + PI_2 = math.pi * 0.5 + PI2 = math.pi * 2 + + if nrows==4: + num_rows = 4 + num_cols = [3, 6, 6, 3] + phi_centers = [-67.5, -22.5, 22.5, 67.5] + if nrows==6: + num_rows = 6 + num_cols = [3, 8, 12, 12, 8, 
3] + phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2] + if nrows==3: + num_rows = 3 + num_cols = [3, 4, 3] + phi_centers = [-59.6, 0, 59.6] + if nrows==5: + num_rows = 5 + num_cols = [3, 6, 8, 6, 3] + phi_centers = [-72.2, -36.1, 0, 36.1, 72.2] + phi_interval = 180 // num_rows + all_combos = [] + + for i, n_cols in enumerate(num_cols): + for j in np.arange(n_cols): + theta_interval = 360 / n_cols + theta_center = j * theta_interval + theta_interval / 2 + + center = [theta_center, phi_centers[i]] + all_combos.append(center) + + + all_combos = np.vstack(all_combos) + n_patch = all_combos.shape[0] + + center_point = torch.from_numpy(all_combos).float() # -180 to 180, -90 to 90 + center_point[:, 0] = (center_point[:, 0]) / 360 #0 to 1 + center_point[:, 1] = (center_point[:, 1] + 90) / 180 #0 to 1 + + cp = center_point * 2 - 1 + cp[:, 0] = cp[:, 0] * PI + cp[:, 1] = cp[:, 1] * PI_2 + cp = cp.unsqueeze(1) + + lat_grid, lon_grid = torch.meshgrid(torch.linspace(-PI_2, PI_2, erp_h), torch.linspace(-PI, PI, erp_w)) + lon_grid = lon_grid.float().reshape(1, -1)#.repeat(num_rows*num_cols, 1) + lat_grid = lat_grid.float().reshape(1, -1)#.repeat(num_rows*num_cols, 1) + cos_c = torch.sin(cp[..., 1]) * torch.sin(lat_grid) + torch.cos(cp[..., 1]) * torch.cos(lat_grid) * torch.cos(lon_grid - cp[..., 0]) + new_x = (torch.cos(lat_grid) * torch.sin(lon_grid - cp[..., 0])) / cos_c + new_y = (torch.cos(cp[..., 1])*torch.sin(lat_grid) - torch.sin(cp[...,1])*torch.cos(lat_grid)*torch.cos(lon_grid-cp[...,0])) / cos_c + new_x = new_x / FOV[0] / PI # -1 to 1 + new_y = new_y / FOV[1] / PI_2 + cos_c_mask = cos_c.reshape(n_patch, erp_h, erp_w) + cos_c_mask = torch.where(cos_c_mask > 0, 1, 0) + + w_list = torch.zeros((n_patch, erp_h, erp_w, 4), dtype=torch.float32) + + new_x_patch = (new_x + 1) * 0.5 * height + new_y_patch = (new_y + 1) * 0.5 * width + new_x_patch = new_x_patch.reshape(n_patch, erp_h, erp_w) + new_y_patch = new_y_patch.reshape(n_patch, erp_h, erp_w) + mask = torch.where((new_x_patch < width) & (new_x_patch > 0) & (new_y_patch < height) & (new_y_patch > 0), 1, 0) + mask *= cos_c_mask + + x0 = torch.floor(new_x_patch).type(torch.int64) + x1 = x0 + 1 + y0 = torch.floor(new_y_patch).type(torch.int64) + y1 = y0 + 1 + + x0 = torch.clamp(x0, 0, width-1) + x1 = torch.clamp(x1, 0, width-1) + y0 = torch.clamp(y0, 0, height-1) + y1 = torch.clamp(y1, 0, height-1) + + wa = (x1.type(torch.float32)-new_x_patch) * (y1.type(torch.float32)-new_y_patch) + wb = (x1.type(torch.float32)-new_x_patch) * (new_y_patch-y0.type(torch.float32)) + wc = (new_x_patch-x0.type(torch.float32)) * (y1.type(torch.float32)-new_y_patch) + wd = (new_x_patch-x0.type(torch.float32)) * (new_y_patch-y0.type(torch.float32)) + + wa = wa * mask.expand_as(wa) + wb = wb * mask.expand_as(wb) + wc = wc * mask.expand_as(wc) + wd = wd * mask.expand_as(wd) + + w_list[..., 0] = wa + w_list[..., 1] = wb + w_list[..., 2] = wc + w_list[..., 3] = wd + + + save_file = {'x0':x0, 'y0':y0, 'x1':x1, 'y1':y1, 'w_list': w_list, 'mask':mask} + torch.save(save_file, grid_file) + else: + # the online merge really takes time + # pre-calculate the grid for once and use it during training + load_file = torch.load(grid_file) + #print('load_file') + x0 = load_file['x0'] + y0 = load_file['y0'] + x1 = load_file['x1'] + y1 = load_file['y1'] + w_list = load_file['w_list'] + mask = load_file['mask'] + + w_list = w_list.to(device) + mask = mask.to(device) + z = torch.arange(n_patch) + z = z.reshape(n_patch, 1, 1) + Ia = pers_img[:, :, y0, x0, z] + Ib = pers_img[:, :, y1, x0, 
z] + Ic = pers_img[:, :, y0, x1, z] + Id = pers_img[:, :, y1, x1, z] + output_a = Ia * mask.expand_as(Ia) + output_b = Ib * mask.expand_as(Ib) + output_c = Ic * mask.expand_as(Ic) + output_d = Id * mask.expand_as(Id) + + output_a = output_a.permute(0, 1, 3, 4, 2) + output_b = output_b.permute(0, 1, 3, 4, 2) + output_c = output_c.permute(0, 1, 3, 4, 2) + output_d = output_d.permute(0, 1, 3, 4, 2) + w_list = w_list.permute(1, 2, 0, 3) + w_list = w_list.flatten(2) + w_list *= torch.gt(w_list, 1e-5).type(torch.float32) + w_list = F.normalize(w_list, p=1, dim=-1).reshape(erp_h, erp_w, n_patch, 4) + w_list = w_list.unsqueeze(0).unsqueeze(0) + output = output_a * w_list[..., 0] + output_b * w_list[..., 1] + \ + output_c * w_list[..., 2] + output_d * w_list[..., 3] + img_erp = output.sum(-1) + + return img_erp + +def img2windows(img, H_sp, W_sp): + """ + img: B C H W + """ + B, C, H, W = img.shape + img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) + img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp, W_sp, C) + return img_perm + +def windows2img(img_splits_hw, H_sp, W_sp, H, W): + """ + img_splits_hw: B' H W C + """ + B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) + + img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) + img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return img \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..eaa14582d433910565297ee53ff628aa11c3fe96 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +gradio_imageslider +gradio>=4.0.0 +huggingface_hub>=0.20.0 +torch==2.7.1 +torchmetrics==1.8.2 +torchvision==0.22.1 +tornado==6.5.2 +opencv-python +matplotlib +einops +safetensors +open3d +tensorboardX +mmengine +pyexr \ No newline at end of file