"""Hugging Face ``PreTrainedModel`` wrapper around the ViTv2 image backbone."""

from typing import Union

import torch
from transformers import PreTrainedModel

from configuration_vitv2 import ViTv2Config
from hf_src.model.image.vitv2.transformer import ViTv2


class ViTv2PretrainedModel(PreTrainedModel):
    """Expose a ``ViTv2`` backbone through the transformers model API.

    Architecture hyper-parameters are read from a ``ViTv2Config``. The wrapped
    network is stored as ``self.backbone``, and ``forward`` delegates to it.
    """

    # Tells transformers' loading/saving machinery which config class this
    # model is paired with.
    config_class = ViTv2Config

    def __init__(self, config: ViTv2Config):
        super().__init__(config)

        # These config attributes are passed verbatim to the ViTv2
        # constructor as same-named keyword arguments.
        backbone_fields = (
            "img_size",
            "patch_size",
            "embed_dim",
            "depth",
            "num_heads",
            "mlp_ratio",
            "init_values",
            "num_register_tokens",
        )
        backbone_kwargs = {field: getattr(config, field) for field in backbone_fields}
        self.backbone = ViTv2(**backbone_kwargs)

        # Hand control back to transformers for its post-construction setup
        # (e.g. weight initialisation) now that every submodule exists.
        self.post_init()

    def forward(self, *args, **kwargs) -> dict[str, Union[torch.Tensor, None]]:
        """Run the backbone, passing every argument through unchanged.

        Returns:
            Whatever ``self.backbone`` returns. The annotation declares a
            ``dict`` of tensors / ``None`` values, presumably mirroring
            ``ViTv2.forward``; confirm against that implementation.
        """
        return self.backbone(*args, **kwargs)