tipsv2-b14-vision-module / modeling_tips.py
nebulette's picture
Upload 6 files
28d6428 verified
raw
history blame
1.85 kB
"""TIPSv2 image encoder for HuggingFace."""
from dataclasses import dataclass
import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel
from .configuration_tips import TIPSv2ImageConfig
from .image_encoder import (
VisionTransformer,
vit_base,
vit_giant2,
vit_large,
vit_small,
vit_so400m,
)
# Registry of supported backbones: maps each accepted ``config.model_variant``
# string to the ``image_encoder`` factory that builds that VisionTransformer.
MODEL_INIT_FUNCTIONS = dict(
    vit_small=vit_small,
    vit_base=vit_base,
    vit_large=vit_large,
    vit_so400m=vit_so400m,
    vit_giant2=vit_giant2,
)
@dataclass
class TIPSv2ImageOutput:
    """Output container for ``TIPSv2ImageModel.forward``.

    The three fields are filled, in order, from the 3-tuple returned by the
    ``VisionTransformer`` backbone. NOTE(review): shapes are determined by the
    backbone and are not visible in this file — presumably (batch, dim),
    (batch, num_registers, dim) and (batch, num_patches, dim); confirm against
    ``image_encoder``.
    """

    # First backbone output: the CLS token.
    cls_token: torch.Tensor
    # Second backbone output: the register tokens.
    register_tokens: torch.Tensor
    # Third backbone output: the patch tokens.
    patch_tokens: torch.Tensor
class TIPSv2ImageModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around a TIPSv2 image encoder.

    The backbone is chosen by ``config.model_variant`` (a key of
    ``MODEL_INIT_FUNCTIONS``) and built from the config's ``image_size``,
    ``patch_size``, ``ffn_layer`` and ``init_values``. It is stored as
    ``self.model``, which matches ``base_model_prefix``.
    """

    config_class = TIPSv2ImageConfig
    base_model_prefix = "model"
    # This is a vision model, so the primary forward input is ``pixel_values``.
    # The PreTrainedModel default is "input_ids". HF utilities that look up
    # the main input by name read this attribute.
    main_input_name = "pixel_values"
    # Empty mapping: this model declares no tied weights. It is a class-level
    # (shared) dict, so do not mutate it in place. NOTE(review): presumably
    # defined here because newer transformers versions read this attribute
    # and ``post_init()`` (which would normally populate it) is never called
    # in ``__init__`` below — confirm before adding a ``post_init()`` call.
    all_tied_weights_keys = {}

    def __init__(self, config: TIPSv2ImageConfig):
        """Build the backbone selected by ``config.model_variant``.

        Args:
            config: Model configuration. ``model_variant`` selects the
                factory; ``image_size``, ``patch_size``, ``ffn_layer`` and
                ``init_values`` are forwarded to it.

        Raises:
            ValueError: If ``config.model_variant`` is not a key of
                ``MODEL_INIT_FUNCTIONS``.
        """
        super().__init__(config)
        try:
            build_fn = MODEL_INIT_FUNCTIONS[config.model_variant]
        except KeyError:
            # Re-raise as ValueError listing the valid choices. ``from None``
            # hides the uninformative KeyError from the traceback.
            raise ValueError(
                f"Unknown model_variant={config.model_variant!r}. "
                f"Expected one of {list(MODEL_INIT_FUNCTIONS)}."
            ) from None
        self.model: VisionTransformer = build_fn(
            image_size=config.image_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            init_values=config.init_values,
        )

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode a batch of images with the backbone.

        Args:
            pixel_values: Image tensor passed unchanged to the backbone.
                Presumably ``(batch, channels, height, width)`` — TODO confirm
                against ``VisionTransformer.forward``.

        Returns:
            TIPSv2ImageOutput holding the backbone's three outputs in order:
            CLS token, register tokens, patch tokens.
        """
        cls_token, register_tokens, patch_tokens = self.model(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )
# Import-time side effect: register the "tipsv2" model type with the Auto*
# factories so that AutoConfig resolves it to TIPSv2ImageConfig and AutoModel
# maps that config class to TIPSv2ImageModel. exist_ok=True tolerates an
# existing registration under the same key (e.g. if this module is executed
# more than once).
AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True)
AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)