import torch
import torch.nn as nn
import timm


class TimmVisionEncoder(nn.Module):
    """Wrap a ``timm`` model and expose its token features.

    The wrapped model is created with ``dynamic_img_size=True``, its final
    ``norm`` layer is replaced by a ``LayerNorm`` without learnable affine
    parameters, and the normalization statistics resolved from the model's
    timm data config are stored as buffers so that ``forward`` can normalize
    its input itself.

    Attributes:
        model: The underlying timm model.
        hidden_size: ``embed_dim`` reported by the timm model.
        num_prefix_tokens: Number of leading tokens that ``forward`` splits
            off from the patch tokens.
        patch_size: Patch size as a single ``int`` (first element when timm
            reports a tuple).
    """

    def __init__(self, pretrained_encoder_name: str, load_pretrained: bool = True):
        """Create the timm backbone and register normalization buffers.

        Args:
            pretrained_encoder_name: Model name handed to
                ``timm.create_model``.
            load_pretrained: Forwarded as ``pretrained=`` to
                ``timm.create_model``.
        """
        super().__init__()

        backbone = timm.create_model(
            pretrained_encoder_name,
            pretrained=load_pretrained,
            dynamic_img_size=True,
        )
        self.model = backbone

        self.hidden_size = backbone.embed_dim
        # Swap the model's final norm for one without learnable scale/shift.
        backbone.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False)
        self.num_prefix_tokens = backbone.num_prefix_tokens

        raw_patch_size = backbone.patch_embed.patch_size
        if isinstance(raw_patch_size, tuple):
            patch_side = raw_patch_size[0]
        else:
            patch_side = raw_patch_size
        self.patch_size = int(patch_side)

        # Patch-14 models are configured for 224 input, everything else 256.
        input_size = 224 if self.patch_size == 14 else 256
        backbone.set_input_size(input_size, raw_patch_size)

        data_config = timm.data.resolve_model_data_config(backbone)
        for buffer_name, config_key in (("pixel_mean", "mean"), ("pixel_std", "std")):
            # Add singleton dims around the per-channel values so the buffer
            # broadcasts against the input in ``forward``.
            stats = torch.tensor(data_config[config_key])[None, :, None, None]
            self.register_buffer(buffer_name, stats, persistent=True)

    def forward(
        self,
        pixel_values: torch.Tensor,
        return_prefix_tokens: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize the input and return the backbone's patch tokens.

        Args:
            pixel_values: Input tensor. It is cast to float and divided by
                255 before mean/std normalization, so it is presumably in
                the 0-255 range with channels on dim 1 -- TODO confirm
                against callers.
            return_prefix_tokens: When true, also return the leading
                ``num_prefix_tokens`` tokens.

        Returns:
            The tokens after the first ``num_prefix_tokens`` positions along
            dim 1, or the pair ``(patch_tokens, prefix_tokens)`` when
            ``return_prefix_tokens`` is true.
        """
        scaled = pixel_values.float() / 255.0
        normalized = (scaled - self.pixel_mean) / self.pixel_std

        # Feed the backbone in the dtype of its patch-embedding weights.
        backbone_dtype = self.model.patch_embed.proj.weight.dtype
        features = self.model.forward_features(normalized.to(dtype=backbone_dtype))
        if isinstance(features, dict):
            features = features["x"]

        split_at = self.num_prefix_tokens
        patch_tokens = features[:, split_at:, :]
        if not return_prefix_tokens:
            return patch_tokens
        return patch_tokens, features[:, :split_at, :]


class DinoV2Encoder(TimmVisionEncoder):
    """Encoder class used by the ``dinov2_*`` registry entries; adds no behavior."""


class DinoV3Encoder(TimmVisionEncoder):
    """Encoder class used by the ``dinov3_*`` registry entries; adds no behavior."""


# Registry: short encoder name -> (encoder class, timm model name).
ALL_ENCODERS = {
    "dinov3_small": (DinoV3Encoder, "vit_small_patch16_dinov3.lvd1689m"),
    "dinov3_base": (DinoV3Encoder, "vit_base_patch16_dinov3.lvd1689m"),
    "dinov3_large": (DinoV3Encoder, "vit_large_patch16_dinov3.lvd1689m"),
    "dinov2_small_reg": (DinoV2Encoder, "vit_small_patch14_reg4_dinov2.lvd142m"),
    "dinov2_base_reg": (DinoV2Encoder, "vit_base_patch14_reg4_dinov2.lvd142m"),
    "dinov2_large_reg": (DinoV2Encoder, "vit_large_patch14_reg4_dinov2.lvd142m"),
}


def build_encoder(encoder_name: str, pretrained: bool = False):
    """Instantiate the encoder registered under ``encoder_name``.

    Args:
        encoder_name: A key of ``ALL_ENCODERS``.
        pretrained: Forwarded to the encoder as ``load_pretrained``.

    Returns:
        An instance of the encoder class registered for ``encoder_name``.

    Raises:
        ValueError: If ``encoder_name`` is not a key of ``ALL_ENCODERS``.
    """
    entry = ALL_ENCODERS.get(encoder_name)
    if entry is None:
        raise ValueError(
            f"Unknown encoder {encoder_name!r}. Available: {list(ALL_ENCODERS)}"
        )
    encoder_cls, timm_model_name = entry
    return encoder_cls(timm_model_name, load_pretrained=pretrained)


__all__ = [
    "ALL_ENCODERS",
    "DinoV2Encoder",
    "DinoV3Encoder",
    "TimmVisionEncoder",
    "build_encoder",
]