| import torch |
| import torch.nn as nn |
| import timm |
|
|
|
|
| class TimmVisionEncoder(nn.Module): |
| def __init__(self, pretrained_encoder_name: str, load_pretrained: bool = True): |
| super().__init__() |
| self.model = timm.create_model( |
| pretrained_encoder_name, |
| pretrained=load_pretrained, |
| dynamic_img_size=True, |
| ) |
| self.hidden_size = self.model.embed_dim |
| self.model.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False) |
| self.num_prefix_tokens = self.model.num_prefix_tokens |
| patch_size = self.model.patch_embed.patch_size |
| self.patch_size = int( |
| patch_size[0] if isinstance(patch_size, tuple) else patch_size |
| ) |
| if self.patch_size == 14: |
| input_size = 224 |
| else: |
| input_size = 256 |
| self.model.set_input_size(input_size, patch_size) |
|
|
| data_config = timm.data.resolve_model_data_config(self.model) |
| self.register_buffer( |
| "pixel_mean", |
| torch.tensor(data_config["mean"])[None, :, None, None], |
| persistent=True, |
| ) |
| self.register_buffer( |
| "pixel_std", |
| torch.tensor(data_config["std"])[None, :, None, None], |
| persistent=True, |
| ) |
|
|
| def forward( |
| self, |
| pixel_values: torch.Tensor, |
| return_prefix_tokens: bool = False, |
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: |
| x = pixel_values.float() / 255.0 |
| weight_dtype = self.model.patch_embed.proj.weight.dtype |
| x = ((x - self.pixel_mean) / self.pixel_std).to(dtype=weight_dtype) |
| output = self.model.forward_features(x) |
| if isinstance(output, dict): |
| output = output["x"] |
| prefix_tokens = output[:, : self.num_prefix_tokens, :] |
| patch_tokens = output[:, self.num_prefix_tokens :, :] |
| if return_prefix_tokens: |
| return patch_tokens, prefix_tokens |
| return patch_tokens |
|
|
|
|
class DinoV2Encoder(TimmVisionEncoder):
    """Subclass of ``TimmVisionEncoder`` that adds no behavior of its own.

    ``ALL_ENCODERS`` pairs this class with the timm DINOv2 model identifiers.
    """
|
|
|
|
class DinoV3Encoder(TimmVisionEncoder):
    """Subclass of ``TimmVisionEncoder`` that adds no behavior of its own.

    ``ALL_ENCODERS`` pairs this class with the timm DINOv3 model identifiers.
    """
|
|
|
|
# Registry of supported encoders: short name -> (wrapper class, timm model id).
ALL_ENCODERS: dict[str, tuple[type[TimmVisionEncoder], str]] = {
    # DINOv3 model ids (patch16 variants).
    "dinov3_small": (DinoV3Encoder, "vit_small_patch16_dinov3.lvd1689m"),
    "dinov3_base": (DinoV3Encoder, "vit_base_patch16_dinov3.lvd1689m"),
    "dinov3_large": (DinoV3Encoder, "vit_large_patch16_dinov3.lvd1689m"),
    # DINOv2 model ids (patch14, reg4 variants).
    "dinov2_small_reg": (DinoV2Encoder, "vit_small_patch14_reg4_dinov2.lvd142m"),
    "dinov2_base_reg": (DinoV2Encoder, "vit_base_patch14_reg4_dinov2.lvd142m"),
    "dinov2_large_reg": (DinoV2Encoder, "vit_large_patch14_reg4_dinov2.lvd142m"),
}
|
|
|
|
def build_encoder(encoder_name: str, pretrained: bool = False):
    """Instantiate the encoder registered under ``encoder_name``.

    Args:
        encoder_name: Key into ``ALL_ENCODERS``.
        pretrained: Passed to the encoder class as ``load_pretrained``.

    Returns:
        An instance of the class registered for ``encoder_name``, constructed
        with the registered model id.

    Raises:
        ValueError: If ``encoder_name`` is not a key of ``ALL_ENCODERS``.
    """
    entry = ALL_ENCODERS.get(encoder_name)
    if entry is None:
        raise ValueError(
            f"Unknown encoder {encoder_name!r}. Available: {list(ALL_ENCODERS)}"
        )
    encoder_cls, timm_model_id = entry
    return encoder_cls(timm_model_id, load_pretrained=pretrained)
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["ALL_ENCODERS", "DinoV2Encoder", "DinoV3Encoder", "TimmVisionEncoder", "build_encoder"]
|
|