from distutils.version import LooseVersion
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner

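# Feature-detect PyTorch's fused scaled-dot-product attention kernel (added in
# PyTorch 2.0); older versions fall back to the models' stock attention below.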
_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')


class PaliGemmaWrapper(nn.Module):
    def __init__(self, vis_model: nn.Module, embed_dim: int):
        super().__init__()

        self.vis_model = vis_model
        # Stored as a plain attribute; a same-named read-only property on this
        # class would make this assignment raise AttributeError.
        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        return self.vis_model.encoder.layers

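    # Runs the vision tower and returns a (summary, features) pair: the
    # mean-pooled patch tokens as the summary vector, plus all patch tokens.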
    def forward(self, x: torch.Tensor):
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )

        features = outputs[0].to(torch.float32)

        summary = features.mean(dim=1)

        return summary, features

    def forward_features(self, x: torch.Tensor):
        return self(x)


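# Loads a PaliGemma checkpoint from the HF hub and exposes only its vision tower.
# The dtype name doubles as the hub revision (e.g. torch.bfloat16 -> the
# 'bfloat16' branch), which assumes the repo publishes dtype-named revisions.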
def _get_paligemma_model(repo: str, embed_dim: Optional[int] = None, dtype: torch.dtype = torch.bfloat16):
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    if LooseVersion(tx_version) > LooseVersion('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is newer than 4.44.2, and PaliGemma may be broken on it.')

    extra_args = dict()

    if dtype is not None:
        extra_args['torch_dtype'] = dtype
        rev = str(dtype).split('.')[-1]
        extra_args['revision'] = rev

    model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)

    vis_model = model.vision_tower.vision_model

    vis_model = PaliGemmaWrapper(vis_model, embed_dim)

    return vis_model


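# Registering with timm makes the model constructible by name, e.g.
# timm.create_model('paligemma_896_student'), once this module has been imported.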
@register_model
def paligemma_896_student(**kwargs):
    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)

    return model


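# Drop-in replacement for the DINOv2 attention module's forward() that routes the
# computation through F.scaled_dot_product_attention so PyTorch can dispatch to a
# fused (e.g. flash-attention) kernel. Bound onto each `.attn` module below.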
def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

    q, k, v = qkv[0], qkv[1], qkv[2]
    x = F.scaled_dot_product_attention(
        q, k, v,
        is_causal=False,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,
    )
    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    if cache_dir:
        torch.hub.set_dir(cache_dir)
    model: nn.Module = torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
    )

    if _has_torch_sdpa:
        # Rebind every attention block's forward to the SDPA implementation above.
        for n, m in model.named_modules():
            if n.endswith('.attn'):
                m.forward = MethodType(dv2_sdpa, m)

    return model


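# Adapts a raw DINOv2 backbone to the (summary, features) interface used by the
# rest of this package.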
class DinoWrapper(nn.Module):
    def __init__(self, dino_model: nn.Module):
        super().__init__()

        self.inner = dino_model
        # Re-wrap the block list as nn.Sequential so the whole stack can be
        # invoked with a single call in forward_features().
        dino_model.blocks = nn.Sequential(*dino_model.blocks)

    @property
    def embed_dim(self):
        return self.inner.embed_dim

    @property
    def patch_size(self):
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        parts = self.inner.forward_features(*args, **kwargs)

        cls_token = parts['x_norm_clstoken']
        features = parts['x_norm_patchtokens']

        return cls_token, features

    def forward_features(self, x: torch.Tensor):
        x = self.inner.prepare_tokens_with_masks(x)
        x = self.inner.blocks(x)
        x_norm = self.inner.norm(x)

        # Token layout is [cls, register tokens..., patch tokens...]: token 0 is
        # the summary, and patch features begin after all summary tokens.
        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner.prepare_tokens_with_masks(x)

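    # Delegates to the shared forward_intermediates() helper, passing the DINOv2
    # tokenizer and token counts so per-block activations can be collected.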
    def forward_intermediates(
        self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=self.inner.norm if norm else lambda y: y,
            x=x,
            **kwargs,
        )


def _dino_student(arch: str, **kwargs):
    from . import dinov2_arch

    factory = getattr(dinov2_arch, arch)
    model = factory()

    model = DinoWrapper(model)

    # DINOv2 expects ImageNet-normalized inputs.
    conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )

    model.input_conditioner = conditioner

    return model


@register_model
def dino_v2_l_student(**kwargs):
    return _dino_student('dinov2_vitl14_reg', **kwargs)


@register_model
def dino_v2_g_student(**kwargs):
    return _dino_student('dinov2_vitg14_reg', **kwargs)
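
# Minimal usage sketch (hypothetical; assumes this module has been imported so
# the timm registrations above have run, and that the input size is a multiple
# of the 14px patch size):
#
#   import timm, torch
#   model = timm.create_model('dino_v2_l_student')
#   x = model.input_conditioner(torch.randn(1, 3, 224, 224))
#   summary, features = model(x)  # (B, C) and (B, N, C)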