"""TIPSv2 image encoder for HuggingFace.""" from dataclasses import dataclass import torch from transformers import AutoConfig, AutoModel, PreTrainedModel from .configuration_tips import TIPSv2ImageConfig from .image_encoder import ( VisionTransformer, vit_base, vit_giant2, vit_large, vit_small, vit_so400m, ) MODEL_INIT_FUNCTIONS = { "vit_small": vit_small, "vit_base": vit_base, "vit_large": vit_large, "vit_so400m": vit_so400m, "vit_giant2": vit_giant2, } @dataclass class TIPSv2ImageOutput: cls_token: torch.Tensor register_tokens: torch.Tensor patch_tokens: torch.Tensor class TIPSv2ImageModel(PreTrainedModel): config_class = TIPSv2ImageConfig base_model_prefix = "model" all_tied_weights_keys = dict() def __init__(self, config: TIPSv2ImageConfig): super().__init__(config) if config.model_variant not in MODEL_INIT_FUNCTIONS: raise ValueError( f"Unknown model_variant={config.model_variant!r}. " f"Expected one of {list(MODEL_INIT_FUNCTIONS)}." ) build_fn = MODEL_INIT_FUNCTIONS[config.model_variant] self.model: VisionTransformer = build_fn( image_size=config.image_size, patch_size=config.patch_size, ffn_layer=config.ffn_layer, init_values=config.init_values, ) def forward(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput: cls_token, register_tokens, patch_tokens = self.model(pixel_values) return TIPSv2ImageOutput( cls_token=cls_token, register_tokens=register_tokens, patch_tokens=patch_tokens, ) AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True) AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)