import timm
import torch
from transformers import PreTrainedModel

from .configuration_vitmodel import ViTConfig
class VitMemModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a timm ViT backbone.

    The wrapped network is timm's ``vit_base_patch16_224_miil`` built with a
    single-output head (``num_classes=1``). ``forward`` passes that output
    through ``torch.sigmoid``, so every returned value lies in [0, 1].
    """

    config_class = ViTConfig

    def __init__(self, config: ViTConfig):
        # The config is only handed to the PreTrainedModel base class; none of
        # its fields are read here.
        super().__init__(config)
        # pretrained=False: timm does not fetch its own weights at build time.
        # NOTE(review): presumably trained weights are loaded afterwards via
        # ``from_pretrained`` — confirm with the caller.
        # The submodule attribute name ``model`` becomes the state_dict key
        # prefix, so renaming it would break existing checkpoints.
        self.model = timm.create_model(
            "vit_base_patch16_224_miil",
            pretrained=False,
            num_classes=1,
        )

    def forward(self, tensor, labels=None):
        """Run the backbone and return sigmoid-activated scores.

        Args:
            tensor: Input batch handed unchanged to the timm model. Presumably
                images shaped (batch, 3, 224, 224) to match the model name —
                TODO confirm.
            labels: Accepted but never read by this method; no loss is
                computed here.

        Returns:
            ``torch.sigmoid`` of the backbone output, with values in [0, 1].
            With ``num_classes=1`` this is presumably shaped (batch, 1) —
            verify against timm's head.
        """
        raw_score = self.model(tensor)
        return torch.sigmoid(raw_score)