ViT-Patch-PCA-Visualisation / modelling_vitv2.py
Tenbatsu24
add: missing files
a10ce46
Raw
History Blame Contribute Delete
902 Bytes
from typing import Union
import torch
from transformers import PreTrainedModel
from configuration_vitv2 import ViTv2Config
from hf_src.model.image.vitv2.transformer import ViTv2
class ViTv2PretrainedModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes a ``ViTv2`` backbone.

    The backbone is built from the fields of a ``ViTv2Config`` and stored
    on ``self.backbone``; ``forward`` hands every argument straight to it.
    """

    # Configuration class associated with this model.
    config_class = ViTv2Config

    def __init__(self, config: ViTv2Config):
        """Build the ``ViTv2`` backbone from *config*.

        Args:
            config: Supplies ``img_size``, ``patch_size``, ``embed_dim``,
                ``depth``, ``num_heads``, ``mlp_ratio``, ``init_values``
                and ``num_register_tokens``, each forwarded to ``ViTv2``
                under the same keyword name.
        """
        super().__init__(config)

        # Config attributes that map one-to-one onto ViTv2 keyword arguments.
        backbone_fields = (
            "img_size",
            "patch_size",
            "embed_dim",
            "depth",
            "num_heads",
            "mlp_ratio",
            "init_values",
            "num_register_tokens",
        )
        backbone_kwargs = {
            field: getattr(config, field) for field in backbone_fields
        }
        self.backbone = ViTv2(**backbone_kwargs)

        # Run the base class's post-construction hook once the backbone exists.
        self.post_init()

    def forward(self, *args, **kwargs) -> dict[str, Union[torch.Tensor, None]]:
        """Delegate to the backbone and return its output unchanged.

        All positional and keyword arguments are passed through as-is; the
        accepted arguments are whatever ``ViTv2`` itself accepts.
        """
        backbone_output = self.backbone(*args, **kwargs)
        return backbone_output