| """TIPSv2 image encoder for HuggingFace.""" |
|
|
| from dataclasses import dataclass |
| from typing import Any, Optional |
|
|
| import torch |
| from torch.nn.attention.flex_attention import BlockMask |
| from transformers import AutoConfig, AutoModel, PreTrainedModel |
|
|
| from .configuration_tips import TIPSv2ImageConfig |
| from .image_encoder import ( |
| VisionTransformer, |
| vit_base, |
| vit_giant2, |
| vit_large, |
| vit_small, |
| vit_so400m, |
| ) |
|
|
|
|
# Registry mapping a config's ``model_variant`` string to the builder function
# (imported from ``.image_encoder``) used to construct the backbone. Insertion
# order is preserved, so listing this mapping yields the variants in the order
# written here.
MODEL_INIT_FUNCTIONS = dict(
    vit_small=vit_small,
    vit_base=vit_base,
    vit_large=vit_large,
    vit_so400m=vit_so400m,
    vit_giant2=vit_giant2,
)
|
|
|
|
@dataclass()
class TIPSv2ImageOutput:
    """Container for the four tensors produced by the TIPSv2 image encoder.

    Fields are stored as given; no shape or dtype validation is performed.
    """

    # Full encoder output sequence (presumably [batch, tokens, dim] — TODO confirm).
    last_hidden_state: torch.Tensor
    # Classification-token embedding(s).
    cls_token: torch.Tensor
    # Register-token embeddings.
    register_tokens: torch.Tensor
    # Per-patch token embeddings.
    patch_tokens: torch.Tensor
|
|
|
|
class TIPSv2ImageModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the TIPSv2 image encoder.

    The wrapped ``VisionTransformer`` lives under ``self.model`` (matching
    ``base_model_prefix``) and is created by the ``MODEL_INIT_FUNCTIONS``
    entry selected by ``config.model_variant``.
    """

    config_class = TIPSv2ImageConfig
    base_model_prefix = "model"
    # NOTE(review): declared empty at class level, presumably so transformers
    # code that reads this attribute sees "no tied weights" even though
    # ``post_init()`` is never called in this class — confirm against the
    # pinned transformers version.
    all_tied_weights_keys = {}

    def __init__(self, config: TIPSv2ImageConfig):
        """Build the encoder backbone described by ``config``.

        Args:
            config: Supplies ``model_variant`` (selects the builder) and the
                ``image_size``, ``patch_size``, ``ffn_layer``, ``init_values``
                and ``num_register_tokens`` values passed to that builder.

        Raises:
            ValueError: If ``config.model_variant`` is not a key of
                ``MODEL_INIT_FUNCTIONS``.
        """
        super().__init__(config)

        variant = config.model_variant
        build_fn = MODEL_INIT_FUNCTIONS.get(variant)
        if build_fn is None:
            raise ValueError(
                f"Unknown model_variant={variant!r}. "
                f"Expected one of {list(MODEL_INIT_FUNCTIONS)}."
            )

        builder_kwargs = {
            "image_size": config.image_size,
            "patch_size": config.patch_size,
            "ffn_layer": config.ffn_layer,
            "init_values": config.init_values,
            "num_register_tokens": config.num_register_tokens,
        }
        self.model: VisionTransformer = build_fn(**builder_kwargs)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
        **kwargs: Any,
    ) -> TIPSv2ImageOutput:
        """Run the image encoder and package its outputs.

        Every named argument is forwarded by keyword, unchanged, to
        ``self.model``; shapes and dtypes are whatever ``VisionTransformer``
        expects and are not checked here.

        Args:
            pixel_values: Image input tensor.
            input_ids: Forwarded to the encoder.
            position_ids: Forwarded to the encoder.
            grid_sizes: Forwarded to the encoder.
            document_ids: Forwarded to the encoder (presumably packing
                metadata — TODO confirm against ``VisionTransformer``).
            block_mask: Optional FlexAttention ``BlockMask``; ``None`` by default.
            **kwargs: Accepted and ignored.

        Returns:
            A ``TIPSv2ImageOutput``. ``self.model`` must return an iterable of
            exactly four items ordered (cls_token, register_tokens,
            patch_tokens, last_hidden_state).
        """
        del kwargs  # extra keyword arguments are intentionally discarded

        encoder_outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            document_ids=document_ids,
            block_mask=block_mask,
        )
        cls_token, register_tokens, patch_tokens, last_hidden_state = encoder_outputs

        return TIPSv2ImageOutput(
            last_hidden_state=last_hidden_state,
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )
|
|
|
|
# Expose the config and model through the Auto* factories under the "tipsv2"
# model type. ``exist_ok=True`` keeps a repeated registration of the same
# name from raising.
AutoConfig.register(model_type="tipsv2", config=TIPSv2ImageConfig, exist_ok=True)
AutoModel.register(
    config_class=TIPSv2ImageConfig,
    model_class=TIPSv2ImageModel,
    exist_ok=True,
)
|
|