| import torch |
| from einops import rearrange |
| from jaxtyping import Float |
| from PIL import Image |
| from torch import Tensor |
| from torch import nn |
| from transformers import AutoImageProcessor |
| from transformers import AutoModel |
| from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
__version__ = "0.1.0"


# --- Type aliases used in the signatures below -------------------------------

# What callers may pass in: a single PIL image or a list of PIL images.
TypeInputImages = Image.Image | list[Image.Image]

# One embedding vector per input image.
TypeClsToken = Float[
    Tensor,
    "batch_size embed_dim"
]

# Patch embeddings laid out as a spatial grid, with the embedding dimension
# placed before the two spatial dimensions.
TypePatchTokens = Float[
    Tensor,
    "batch_size embed_dim height width"
]

# Patch embeddings as a flat sequence; the sequence axis is the flattened
# (height, width) grid.
TypePatchTokensFlat = Float[
    Tensor,
    "batch_size (height width) embed_dim"
]
|
|
|
|
class RadDino(nn.Module):
    """Feature extractor wrapping the ``microsoft/rad-dino`` checkpoint.

    The module bundles the pretrained model together with its image
    processor, so callers can pass PIL images directly and get back the
    CLS embedding and a spatial grid of patch embeddings.
    """

    # Hugging Face Hub repository that both the weights and the image
    # processor are loaded from.
    _REPO = "microsoft/rad-dino"

    def __init__(self):
        """Load the pretrained model (in eval mode) and its image processor."""
        super().__init__()
        backbone = AutoModel.from_pretrained(self._REPO)
        self.model = backbone.eval()
        # ``use_fast=False`` explicitly requests the non-fast processor.
        self.processor = AutoImageProcessor.from_pretrained(
            self._REPO,
            use_fast=False,
        )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the wrapped model."""
        first_parameter = next(self.model.parameters())
        return first_parameter.device

    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
        """Run the image processor and return its output as PyTorch tensors.

        Args:
            image_or_images: A single PIL image or a list of PIL images.

        Returns:
            The processor output, requested with ``return_tensors="pt"``.
        """
        batch = self.processor(image_or_images, return_tensors="pt")
        return batch

    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
        """Run the model and split its last hidden state into CLS and patches.

        Args:
            inputs: Preprocessed inputs, unpacked as keyword arguments to the
                model.

        Returns:
            A pair ``(cls_token, patch_tokens)``: the first token of the last
            hidden state, and all remaining tokens as a flat sequence.
        """
        hidden_states = self.model(**inputs).last_hidden_state
        # Token 0 is taken as the CLS token; everything after it is returned
        # as the patch tokens.
        return hidden_states[:, 0], hidden_states[:, 1:]

    def reshape_patch_tokens(
        self,
        patch_tokens_flat: TypePatchTokensFlat,
    ) -> TypePatchTokens:
        """Fold a flat sequence of patch tokens into a 2D grid.

        The grid height is the processor's crop height divided by the model's
        patch size; the grid width is inferred from the number of tokens.

        Args:
            patch_tokens_flat: Patch tokens with the spatial grid flattened
                into a single axis.

        Returns:
            The same tokens with the embedding dimension moved before the two
            spatial dimensions.
        """
        # NOTE(review): relies on the inputs having been cropped to the
        # processor's ``crop_size`` -- confirm if preprocessing is customized.
        grid_height = (
            self.processor.crop_size["height"] // self.model.config.patch_size
        )
        return rearrange(
            patch_tokens_flat,
            "batch (height width) embed_dim -> batch embed_dim height width",
            height=grid_height,
        )

    @torch.inference_mode()
    def extract_features(
        self,
        image_or_images: TypeInputImages,
    ) -> tuple[TypeClsToken, TypePatchTokens]:
        """Preprocess, encode and reshape in one call, under inference mode.

        Args:
            image_or_images: A single PIL image or a list of PIL images.

        Returns:
            A pair ``(cls_token, patch_tokens)`` where the patch tokens are
            arranged as a spatial grid.
        """
        batch = self.preprocess(image_or_images)
        # Move the preprocessed inputs to wherever the model currently lives.
        batch = batch.to(self.device)
        cls_token, flat_patches = self.encode(batch)
        return cls_token, self.reshape_patch_tokens(flat_patches)

    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
        """Return only the CLS token from :meth:`extract_features`."""
        features = self.extract_features(image_or_images)
        return features[0]

    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
        """Return only the patch-token grid from :meth:`extract_features`."""
        features = self.extract_features(image_or_images)
        return features[1]

    def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
        """Forward positional arguments unchanged to :meth:`extract_features`."""
        return self.extract_features(*args)
|
|