"""TIPSv2 image encoder for HuggingFace.""" from dataclasses import dataclass from typing import Any, Optional import torch from torch.nn.attention.flex_attention import BlockMask from transformers import AutoConfig, AutoModel, PreTrainedModel from .configuration_tips import TIPSv2ImageConfig from .image_encoder import ( VisionTransformer, vit_base, vit_giant2, vit_large, vit_small, vit_so400m, ) MODEL_INIT_FUNCTIONS = { "vit_small": vit_small, "vit_base": vit_base, "vit_large": vit_large, "vit_so400m": vit_so400m, "vit_giant2": vit_giant2, } @dataclass class TIPSv2ImageOutput: last_hidden_state: torch.Tensor cls_token: torch.Tensor register_tokens: torch.Tensor patch_tokens: torch.Tensor class TIPSv2ImageModel(PreTrainedModel): config_class = TIPSv2ImageConfig base_model_prefix = "model" all_tied_weights_keys = {} def __init__(self, config: TIPSv2ImageConfig): super().__init__(config) if config.model_variant not in MODEL_INIT_FUNCTIONS: raise ValueError( f"Unknown model_variant={config.model_variant!r}. " f"Expected one of {list(MODEL_INIT_FUNCTIONS)}." ) build_fn = MODEL_INIT_FUNCTIONS[config.model_variant] self.model: VisionTransformer = build_fn( image_size=config.image_size, patch_size=config.patch_size, ffn_layer=config.ffn_layer, init_values=config.init_values, num_register_tokens=config.num_register_tokens, ) def forward( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, position_ids: torch.Tensor, grid_sizes: torch.Tensor, document_ids: torch.Tensor, block_mask: Optional[BlockMask] = None, **kwargs: Any, ) -> TIPSv2ImageOutput: del kwargs cls_token, register_tokens, patch_tokens, last_hidden_state = self.model( pixel_values=pixel_values, input_ids=input_ids, position_ids=position_ids, grid_sizes=grid_sizes, document_ids=document_ids, block_mask=block_mask, ) return TIPSv2ImageOutput( last_hidden_state=last_hidden_state, cls_token=cls_token, register_tokens=register_tokens, patch_tokens=patch_tokens, ) AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True) AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)