| from typing import Union | |
| import torch | |
| from transformers import PreTrainedModel | |
| from configuration_vitv2 import ViTv2Config | |
| from hf_src.model.image.vitv2.transformer import ViTv2 | |
class ViTv2PretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``ViTv2`` backbone.

    The backbone is constructed from hyper-parameters stored on a
    ``ViTv2Config``; ``forward`` delegates straight to it.
    """

    # PretrainedConfig subclass that transformers pairs with this model.
    config_class = ViTv2Config

    def __init__(self, config: ViTv2Config):
        """Build the ViTv2 backbone described by *config*.

        Args:
            config: Must expose ``img_size``, ``patch_size``, ``embed_dim``,
                ``depth``, ``num_heads``, ``mlp_ratio``, ``init_values`` and
                ``num_register_tokens``; each is passed unchanged as the
                same-named keyword argument to ``ViTv2``.
        """
        super().__init__(config)
        # Config attributes mirrored 1:1 onto ViTv2 keyword arguments,
        # read in this order.
        backbone_fields = (
            "img_size",
            "patch_size",
            "embed_dim",
            "depth",
            "num_heads",
            "mlp_ratio",
            "init_values",
            "num_register_tokens",
        )
        backbone_kwargs = {name: getattr(config, name) for name in backbone_fields}
        self.backbone = ViTv2(**backbone_kwargs)
        # Let PreTrainedModel finish its initialization now that every
        # submodule has been registered.
        self.post_init()

    def forward(self, *args, **kwargs) -> dict[str, Union[torch.Tensor, None]]:
        """Run the backbone, forwarding all positional and keyword arguments.

        Returns:
            Exactly what ``self.backbone`` returns. The annotation declares a
            dict of str to tensor-or-None; NOTE(review): confirm this matches
            ``ViTv2.forward``.
        """
        backbone = self.backbone
        return backbone(*args, **kwargs)