import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel


class EyeDiseaseEfficientNetConfig(PretrainedConfig):
    """Configuration for :class:`EyeDiseaseEfficientNet`.

    This derives from ``PretrainedConfig`` rather than ``AutoConfig``.
    ``AutoConfig`` is a factory that is only meant to be used through
    ``AutoConfig.from_pretrained`` and its constructor raises, so a subclass
    of it could never be instantiated.

    Args:
        num_labels: Number of outputs produced by the final linear layer of
            the model's classification head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "EyeDiseaseEfficientNet"

    def __init__(self, num_labels: int = 8, **kwargs):
        super().__init__(**kwargs)
        self.num_labels = num_labels


class EyeDiseaseEfficientNet(PreTrainedModel):
    """EfficientNet-B4 image backbone fused with a two-value tabular input.

    The image branch is a torchvision EfficientNet-B4 whose classifier is
    replaced by ``nn.Identity``. The tabular branch maps a 2-feature input
    (named ``x_age_sex``) to 64 features. Both are concatenated and passed
    through a small MLP that ends in ``config.num_labels`` outputs.
    """

    config_class = EyeDiseaseEfficientNetConfig

    def __init__(self, config):
        """Build the backbone, the tabular branch and the fused head.

        Args:
            config: Model configuration; only ``config.num_labels`` is read
                here.
        """
        super().__init__(config)

        # Widths used by the fused head below.
        image_features = 1792    # image feature width the head is sized for
        tabular_features = 64    # output width of ``fc_age_sex``

        # NOTE(review): ``pretrained=True`` is the legacy torchvision argument
        # (newer releases prefer ``weights=...``). It is kept so the loaded
        # weights and the supported torchvision versions stay the same.
        self.efficientnet = models.efficientnet_b4(pretrained=True)
        # Drop the stock classifier so the backbone returns its features
        # instead of class scores.
        self.efficientnet.classifier = nn.Identity()

        # NOTE(review): parameters are trainable by default and nothing is
        # frozen before this loop, so it currently has no effect. Presumably
        # the backbone was meant to be frozen first with only the last two
        # feature stages left trainable -- confirm before changing, since
        # that would alter which parameters are trained.
        for param in self.efficientnet.features[-2:].parameters():
            param.requires_grad = True

        # Tabular branch: 2 input features -> 64 features.
        self.fc_age_sex = nn.Sequential(
            nn.Linear(2, tabular_features),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Fused head over the concatenated image and tabular features.
        self.fc_combined = nn.Sequential(
            nn.Linear(image_features + tabular_features, 512),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(512, config.num_labels),
        )

    def forward(self, x_img, x_age_sex):
        """Run both branches and return the output of the fused head.

        Args:
            x_img: Image batch fed to the EfficientNet backbone
                (presumably ``(batch, 3, H, W)`` -- confirm against caller).
            x_age_sex: Tabular batch whose last dimension is 2; it must be
                2-D so it can be concatenated with the image features along
                ``dim=1``.

        Returns:
            The tensor produced by ``fc_combined``; its last dimension is
            ``config.num_labels``. No activation is applied to it here.
        """
        x_img = self.efficientnet(x_img)
        x_age_sex = self.fc_age_sex(x_age_sex)
        x = torch.cat((x_img, x_age_sex), dim=1)
        x = self.fc_combined(x)
        return x