| import torch.nn as nn
|
| from torchvision import models
|
| import timm
|
| from antialiased_cnns import resnet18
|
| import torch
|
|
|
|
|
|
|
class HybridResNetViT(nn.Module):
    """Late-fusion image classifier combining ResNet-50 and ViT-B/16 features.

    Both backbones are built with pretrained weights and receive the same
    input batch. Each backbone's feature vector is linearly projected to
    ``hidden_dim`` dimensions. The two projections are concatenated, passed
    through dropout, and fed to a linear classifier.

    Args:
        num_classes: Number of output logits.
        hidden_dim: Width of each per-backbone projection (default 512). The
            classifier therefore sees ``2 * hidden_dim`` features.
        dropout: Dropout probability applied to the fused features
            (default 0.5).
    """

    def __init__(self, num_classes, hidden_dim=512, dropout=0.5):
        super().__init__()

        # ResNet-50 with its final fc layer removed. Read the feature width
        # from that fc layer before discarding it, rather than hard-coding it,
        # the same way create_aadcn_model sizes its head.
        resnet = models.resnet50(weights='IMAGENET1K_V1')
        resnet_dim = resnet.fc.in_features
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # ViT with its classification head replaced by identity, so the model
        # returns the pre-head feature vector. timm exposes that width as
        # ``num_features``.
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.vit.head = nn.Identity()
        vit_dim = self.vit.num_features

        # Attribute names and registration order match the original layout,
        # so previously saved state_dicts load unchanged.
        self.resnet_fc = nn.Linear(resnet_dim, hidden_dim)
        self.vit_fc = nn.Linear(vit_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch, passed unchanged to both backbones.
                NOTE(review): presumably (N, 3, 224, 224), as the ViT variant
                name suggests. Confirm against the data pipeline.

        Returns:
            Logits of shape (N, num_classes).
        """
        # Flatten the pooled ResNet output to (N, C) for the linear projection.
        resnet_feat = self.resnet_fc(torch.flatten(self.resnet(x), 1))
        vit_feat = self.vit_fc(self.vit(x))

        # Late fusion: concatenate the two projected feature vectors.
        fused = torch.cat([resnet_feat, vit_feat], dim=1)
        return self.classifier(self.dropout(fused))
|
|
|
|
|
def create_aadcn_model(num_classes=8):
    """Build a pretrained anti-aliased ResNet-18 with a new classification head.

    The backbone comes from ``antialiased_cnns.resnet18`` with
    ``pretrained=True``. Its final fully connected layer is replaced by a
    freshly initialised ``nn.Linear`` that outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output classes (default 8).

    Returns:
        The modified ResNet-18 module.
    """
    net = resnet18(pretrained=True)
    # Keep the backbone's feature width; only the output size changes.
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net
|
|
|