Spaces:
Running
Running
| import torch.nn as nn | |
| from transformers import AutoModel | |
class DiseaseModel(nn.Module):
    """Image classifier: a pretrained Hugging Face backbone plus a linear head.

    The backbone is loaded with ``AutoModel.from_pretrained``. ``forward``
    reduces its ``last_hidden_state`` to one feature vector per image and
    maps it to ``num_disease_classes`` unnormalized scores with a single
    ``nn.Linear`` layer.

    Args:
        backbone_id: Model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_disease_classes: Number of output scores produced by the head.

    Raises:
        ValueError: If the backbone config exposes neither ``hidden_size``
            nor a non-empty ``hidden_sizes`` attribute.
    """

    def __init__(self, backbone_id: str, num_disease_classes: int):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_id)
        hidden = self._feature_dim(self.backbone.config)
        self.head = nn.Linear(hidden, num_disease_classes)

    @staticmethod
    def _feature_dim(config):
        """Return the width of the pooled feature vector fed to the head."""
        # Transformer-style configs (e.g. ViT) expose a single hidden_size.
        hidden = getattr(config, "hidden_size", None)
        if hidden is not None:
            return hidden
        # Hierarchical CNN configs (e.g. ResNet, ConvNeXt) list per-stage
        # widths instead; the last stage is what forward() pools.
        stage_sizes = getattr(config, "hidden_sizes", None)
        if stage_sizes:
            return stage_sizes[-1]
        raise ValueError(
            f"cannot determine feature size from {type(config).__name__}: "
            "expected a 'hidden_size' or 'hidden_sizes' attribute"
        )

    def forward(self, pixel_values):
        """Return unnormalized class scores, one row per image.

        Args:
            pixel_values: Image batch forwarded unchanged to the backbone
                as the ``pixel_values`` keyword argument.

        Raises:
            ValueError: If the backbone's ``last_hidden_state`` is neither
                3-D nor 4-D.
        """
        feat = self.backbone(pixel_values=pixel_values).last_hidden_state
        if feat.dim() == 3:
            # (batch, seq, hidden): take the token at position 0, the
            # [CLS] token for ViT-style backbones (original behaviour).
            feat = feat[:, 0]
        elif feat.dim() == 4:
            # (batch, channels, height, width) feature map from a CNN
            # backbone: global-average-pool over the spatial dimensions
            # instead of slicing out a single channel plane.
            feat = feat.mean(dim=(2, 3))
        else:
            raise ValueError(
                "unsupported last_hidden_state rank "
                f"{feat.dim()}; expected 3 or 4"
            )
        return self.head(feat)