"""Feature-extraction wrapper around the ``microsoft/rad-dino`` checkpoint."""

import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch import nn
from transformers import AutoImageProcessor
from transformers import AutoModel
from transformers.feature_extraction_utils import BatchFeature

__version__ = "0.1.0"

# Shape-annotated tensor aliases (jaxtyping). They are used as annotations only;
# no runtime shape checker is applied in this module.
# NOTE(review): "(height width)" mirrors einops grouping notation; jaxtyping may
# not interpret it as a grouped axis — confirm before enabling runtime checks.
TypeClsToken = Float[Tensor, "batch_size embed_dim"]
TypePatchTokensFlat = Float[Tensor, "batch_size (height width) embed_dim"]
TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
TypeInputImages = Image.Image | list[Image.Image]


class RadDino(nn.Module):
    """``nn.Module`` wrapper around the pretrained RAD-DINO encoder.

    Loads the model and its image processor with ``from_pretrained`` and exposes
    helpers that return the CLS embedding and the patch embeddings, either as a
    flat token sequence or arranged on a spatial grid.
    """

    _REPO = "microsoft/rad-dino"

    def __init__(self):
        super().__init__()
        # ``.eval()`` returns the module itself, so the stored model is the
        # pretrained encoder already switched to evaluation mode.
        encoder = AutoModel.from_pretrained(self._REPO)
        self.model = encoder.eval()
        # The non-fast processor implementation is requested explicitly.
        self.processor = AutoImageProcessor.from_pretrained(
            self._REPO,
            use_fast=False,
        )

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model.

        Only the first parameter is inspected, so this assumes every parameter
        lives on the same device.
        """
        first_param = next(self.model.parameters())
        return first_param.device

    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
        """Run the image processor on one image or a list of images.

        Returns:
            The processor's ``BatchFeature`` holding PyTorch tensors.
        """
        processed = self.processor(image_or_images, return_tensors="pt")
        return processed

    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
        """Feed preprocessed inputs through the encoder.

        Unlike ``extract_features``, this method is not wrapped in
        ``torch.inference_mode``, so autograd state follows the caller.

        Args:
            inputs: Output of ``preprocess`` (already on the model's device).

        Returns:
            ``(cls_token, patch_tokens)``: token 0 of the last hidden state, and
            every remaining token in sequence order.
            NOTE(review): treating all non-CLS tokens as patches assumes the
            model emits no extra (e.g. register) tokens — confirm in its config.
        """
        hidden_states = self.model(**inputs).last_hidden_state
        return hidden_states[:, 0], hidden_states[:, 1:]

    def reshape_patch_tokens(
        self,
        patch_tokens_flat: TypePatchTokensFlat,
    ) -> TypePatchTokens:
        """Arrange flat patch tokens on a channels-first spatial grid.

        The grid height is the processor's crop height divided (floor) by the
        model's patch size. Only the height is passed to ``rearrange``; einops
        infers the width from the token count, so a square crop gives a square
        grid.
        """
        crop_height = self.processor.crop_size["height"]
        grid_height = crop_height // self.model.config.patch_size
        return rearrange(
            patch_tokens_flat,
            "batch (height width) embed_dim -> batch embed_dim height width",
            height=grid_height,
        )

    @torch.inference_mode()
    def extract_features(
        self,
        image_or_images: TypeInputImages,
    ) -> tuple[TypeClsToken, TypePatchTokens]:
        """Preprocess, encode and reshape in a single inference-mode call.

        Inputs are moved to ``self.device`` before encoding.

        Returns:
            ``(cls_token, patch_tokens)`` with the patch tokens on a spatial grid.
        """
        batch = self.preprocess(image_or_images).to(self.device)
        cls_token, flat_patches = self.encode(batch)
        return cls_token, self.reshape_patch_tokens(flat_patches)

    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
        """Return only the CLS embedding computed by ``extract_features``."""
        features = self.extract_features(image_or_images)
        return features[0]

    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
        """Return only the grid-shaped patch embeddings from ``extract_features``."""
        features = self.extract_features(image_or_images)
        return features[1]

    def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
        """Delegate ``module(...)`` calls to ``extract_features``."""
        return self.extract_features(*args)