tipsv2-b14-vision / modeling_tips.py
toilaluan's picture
update
d1941eb
"""TIPSv2 image encoder for HuggingFace."""
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import AutoConfig, AutoModel, PreTrainedModel
from .configuration_tips import TIPSv2ImageConfig
from .image_encoder import (
VisionTransformer,
vit_base,
vit_giant2,
vit_large,
vit_small,
vit_so400m,
)
# Lookup table from a ``config.model_variant`` string to the factory function
# (imported from ``.image_encoder``) that is called to build the backbone.
# ``TIPSv2ImageModel.__init__`` rejects any variant that is not a key here.
MODEL_INIT_FUNCTIONS = {
    "vit_small": vit_small,
    "vit_base": vit_base,
    "vit_large": vit_large,
    "vit_so400m": vit_so400m,
    "vit_giant2": vit_giant2,
}
@dataclass
class TIPSv2ImageOutput:
    """Container for the four tensors returned by ``TIPSv2ImageModel.forward``.

    Each field stores one of the values produced by the wrapped encoder.
    Tensor shapes and dtypes are determined by ``.image_encoder`` and are not
    visible from this module.
    """

    # Fourth value returned by the wrapped encoder.
    # NOTE(review): presumably the full output token sequence — confirm.
    last_hidden_state: torch.Tensor
    # First value returned by the wrapped encoder.
    cls_token: torch.Tensor
    # Second value returned by the wrapped encoder.
    register_tokens: torch.Tensor
    # Third value returned by the wrapped encoder.
    patch_tokens: torch.Tensor
class TIPSv2ImageModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a TIPSv2 ``VisionTransformer``.

    The backbone is created by the factory registered in
    ``MODEL_INIT_FUNCTIONS`` under ``config.model_variant`` and is stored on
    the ``model`` attribute (which is also the ``base_model_prefix``).
    """

    config_class = TIPSv2ImageConfig
    base_model_prefix = "model"
    # Empty mapping: this wrapper declares no tied weights.
    # NOTE(review): presumably consumed by transformers' weight-tying logic —
    # confirm against the installed transformers version.
    all_tied_weights_keys = {}

    def __init__(self, config: TIPSv2ImageConfig):
        """Build the backbone selected by ``config.model_variant``.

        Args:
            config: Model configuration. Its ``model_variant`` picks the
                factory; ``image_size``, ``patch_size``, ``ffn_layer``,
                ``init_values`` and ``num_register_tokens`` are passed to it.

        Raises:
            ValueError: If ``config.model_variant`` is not a key of
                ``MODEL_INIT_FUNCTIONS``.
        """
        super().__init__(config)

        factory = MODEL_INIT_FUNCTIONS.get(config.model_variant)
        if factory is None:
            raise ValueError(
                f"Unknown model_variant={config.model_variant!r}. "
                f"Expected one of {list(MODEL_INIT_FUNCTIONS)}."
            )

        # Collect the constructor arguments first so the factory call itself
        # stays a one-liner.
        encoder_kwargs = {
            "image_size": config.image_size,
            "patch_size": config.patch_size,
            "ffn_layer": config.ffn_layer,
            "init_values": config.init_values,
            "num_register_tokens": config.num_register_tokens,
        }
        self.model: VisionTransformer = factory(**encoder_kwargs)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        grid_sizes: torch.Tensor,
        document_ids: torch.Tensor,
        block_mask: Optional[BlockMask] = None,
        **kwargs: Any,
    ) -> TIPSv2ImageOutput:
        """Run the wrapped encoder and package its outputs.

        All named arguments are forwarded unchanged, by keyword, to
        ``self.model``; their exact semantics are defined by the encoder in
        ``.image_encoder``.

        Args:
            pixel_values: Forwarded to the encoder.
            input_ids: Forwarded to the encoder.
            position_ids: Forwarded to the encoder.
            grid_sizes: Forwarded to the encoder.
            document_ids: Forwarded to the encoder.
            block_mask: Optional flex-attention block mask, forwarded to the
                encoder; defaults to ``None``.
            **kwargs: Accepted for call-site compatibility and discarded.

        Returns:
            A ``TIPSv2ImageOutput`` holding the encoder's four outputs.
        """
        # Extra keyword arguments are intentionally ignored.
        del kwargs

        encoder_outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
            grid_sizes=grid_sizes,
            document_ids=document_ids,
            block_mask=block_mask,
        )
        # The encoder yields exactly four values in this fixed order.
        cls_tok, reg_toks, patch_toks, hidden = encoder_outputs

        return TIPSv2ImageOutput(
            last_hidden_state=hidden,
            cls_token=cls_tok,
            register_tokens=reg_toks,
            patch_tokens=patch_toks,
        )
# Import-time side effect: register the config under the "tipsv2" model type
# and map that config class to the model class, so the transformers
# ``AutoConfig`` / ``AutoModel`` factories can resolve them.
# ``exist_ok=True`` keeps a repeated registration from raising.
AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True)
AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)