# w2/vision_encoder.py
# Uploaded by toilaluan via huggingface_hub (commit 8a6e75e, verified; 3.03 kB).
import torch
import torch.nn as nn
import timm
class TimmVisionEncoder(nn.Module):
def __init__(self, pretrained_encoder_name: str, load_pretrained: bool = True):
super().__init__()
self.model = timm.create_model(
pretrained_encoder_name,
pretrained=load_pretrained,
dynamic_img_size=True,
)
self.hidden_size = self.model.embed_dim
self.model.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False)
self.num_prefix_tokens = self.model.num_prefix_tokens
patch_size = self.model.patch_embed.patch_size
self.patch_size = int(
patch_size[0] if isinstance(patch_size, tuple) else patch_size
)
if self.patch_size == 14:
input_size = 224
else:
input_size = 256
self.model.set_input_size(input_size, patch_size)
data_config = timm.data.resolve_model_data_config(self.model)
self.register_buffer(
"pixel_mean",
torch.tensor(data_config["mean"])[None, :, None, None],
persistent=True,
)
self.register_buffer(
"pixel_std",
torch.tensor(data_config["std"])[None, :, None, None],
persistent=True,
)
def forward(
self,
pixel_values: torch.Tensor,
return_prefix_tokens: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x = pixel_values.float() / 255.0
weight_dtype = self.model.patch_embed.proj.weight.dtype
x = ((x - self.pixel_mean) / self.pixel_std).to(dtype=weight_dtype)
output = self.model.forward_features(x)
if isinstance(output, dict):
output = output["x"]
prefix_tokens = output[:, : self.num_prefix_tokens, :]
patch_tokens = output[:, self.num_prefix_tokens :, :]
if return_prefix_tokens:
return patch_tokens, prefix_tokens
return patch_tokens
class DinoV2Encoder(TimmVisionEncoder):
    """DINOv2 backbone; behaves exactly like ``TimmVisionEncoder`` (distinct type only)."""
class DinoV3Encoder(TimmVisionEncoder):
    """DINOv3 backbone; behaves exactly like ``TimmVisionEncoder`` (distinct type only)."""
# DINOv3 entries: patch-16 timm models with the `lvd1689m` pretrained tag.
_DINOV3_ENCODERS = {
    "dinov3_small": (DinoV3Encoder, "vit_small_patch16_dinov3.lvd1689m"),
    "dinov3_base": (DinoV3Encoder, "vit_base_patch16_dinov3.lvd1689m"),
    "dinov3_large": (DinoV3Encoder, "vit_large_patch16_dinov3.lvd1689m"),
}
# DINOv2 entries: patch-14 timm models with 4 register tokens, `lvd142m` tag.
_DINOV2_ENCODERS = {
    "dinov2_small_reg": (DinoV2Encoder, "vit_small_patch14_reg4_dinov2.lvd142m"),
    "dinov2_base_reg": (DinoV2Encoder, "vit_base_patch14_reg4_dinov2.lvd142m"),
    "dinov2_large_reg": (DinoV2Encoder, "vit_large_patch14_reg4_dinov2.lvd142m"),
}
# Registry used by build_encoder: short name -> (encoder class, timm model id).
# Insertion order (DINOv3 first, then DINOv2) is what error messages list.
ALL_ENCODERS: dict[str, tuple[type[TimmVisionEncoder], str]] = {
    **_DINOV3_ENCODERS,
    **_DINOV2_ENCODERS,
}
def build_encoder(encoder_name: str, pretrained: bool = False):
    """Instantiate a registered encoder from its short name.

    Args:
        encoder_name: A key of ``ALL_ENCODERS``.
        pretrained: Forwarded to the encoder class as ``load_pretrained``.

    Returns:
        An instance of the registered encoder class for the mapped timm id.

    Raises:
        ValueError: If ``encoder_name`` is not a key of ``ALL_ENCODERS``.
    """
    entry = ALL_ENCODERS.get(encoder_name)
    if entry is None:
        raise ValueError(
            f"Unknown encoder {encoder_name!r}. Available: {list(ALL_ENCODERS)}"
        )
    encoder_cls, timm_id = entry
    return encoder_cls(timm_id, load_pretrained=pretrained)
# Public names exported by `from ... import *` (ASCII-sorted).
__all__ = [
    "ALL_ENCODERS", "DinoV2Encoder", "DinoV3Encoder",
    "TimmVisionEncoder", "build_encoder",
]