import timm
import torch
import torch.nn as nn
from PIL import Image
from timm.models.eva import Eva


class PerceptionEncoder(nn.Module):
    """Wrap a pretrained timm backbone with its classification head removed.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=True``, its ``head`` is replaced by ``nn.Identity`` and the
    matching evaluation-time transform is built from the model's data config.

    Attributes:
        backbone: The timm model with ``head`` set to ``nn.Identity``.
        features_dim: Copy of ``backbone.num_features``.
    """

    def __init__(
        self,
        model_name="vit_pe_core_large_patch14_336",
        img_size: None | int = None,
    ):
        """Create the backbone and its preprocessing transform.

        Args:
            model_name: Model identifier forwarded to ``timm.create_model``.
            img_size: Optional square input resolution. When given, the
                backbone is created with ``dynamic_img_size=True`` and the
                transform's ``input_size`` is overridden to
                ``(3, img_size, img_size)``. When ``None``, the data config
                resolved from the model is used unchanged.
        """
        super().__init__()

        # Bind the flag unconditionally. It used to be assigned only when
        # img_size was provided, so the default call (img_size=None) failed
        # with UnboundLocalError before the model was even created.
        dynamic_img_size = img_size is not None

        self.backbone: Eva = timm.create_model(
            model_name,
            pretrained=True,
            dynamic_img_size=dynamic_img_size,
        )

        # Get model specific transforms (normalization, resize)
        data_config = timm.data.resolve_model_data_config(self.backbone)
        if img_size is not None:
            # Make the transform produce images at the requested resolution
            # instead of the resolution stored in the model's config.
            data_config["input_size"] = (3, img_size, img_size)
        self._preprocess = timm.data.create_transform(
            **data_config,
            is_training=False,
        )

        # Remove head
        self.backbone.head = nn.Identity()
        self.features_dim = self.backbone.num_features

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform built in ``__init__`` to a single PIL image."""
        return self._preprocess(image)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run ``inputs`` through the backbone (head is ``nn.Identity``)."""
        return self.backbone(inputs)

    def get_features_dim(self) -> int:
        """Return ``features_dim`` as recorded from ``backbone.num_features``."""
        return self.features_dim


if __name__ == "__main__":
    import autorootcwd  # noqa: F401

    from src.config import Backbone
    from src.encoders._common import inference

    # Manual check: build one configured variant at 224 px and hand it to the
    # project's inference helper.
    model = PerceptionEncoder(Backbone.PerceptionEncoder_B_p16_224.value, img_size=224)
    inference(model)