vibe-detection-backend-api / api /models /model_definitions.py
vaibhav07112004's picture
Upload 3 files
b002e2f verified
import torch.nn as nn
from torchvision import models
import timm
from antialiased_cnns import resnet18
import torch
# Note: `resnet18` imported from antialiased_cnns above is the backbone for the AA-DCN model
# For Vibe Detection (HybridResNetViT)
class HybridResNetViT(nn.Module):
    """Two-branch image classifier fusing ResNet-50 and ViT-B/16 features.

    The same input batch is passed through both backbones. Each backbone's
    output is projected to 512 dimensions, the two projections are
    concatenated into a 1024-dimensional vector, dropout (p=0.5) is applied,
    and a final linear layer maps the result to ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        """Build both backbones and the fusion head.

        Args:
            num_classes: Output size of the final linear classifier.
        """
        super().__init__()

        # CNN branch: ResNet-50 loaded with the 'IMAGENET1K_V1' weights,
        # keeping every child module except the last one.
        backbone = models.resnet50(weights='IMAGENET1K_V1')
        trunk_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk_layers)

        # Transformer branch: pretrained ViT whose head is replaced by an
        # identity so the branch returns features instead of class scores.
        transformer = timm.create_model('vit_base_patch16_224', pretrained=True)
        transformer.head = nn.Identity()
        self.vit = transformer

        # Per-branch projections into a shared 512-dimensional space.
        self.resnet_fc = nn.Linear(2048, 512)
        self.vit_fc = nn.Linear(768, 512)

        # Fusion head operating on the concatenated (512 + 512) vector.
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``.

        Args:
            x: Input tensor fed unchanged to both backbones; its first
                dimension is treated as the batch dimension.

        Returns:
            The output of the final linear layer applied to the fused
            features.
        """
        batch_size = x.size(0)

        # CNN branch: collapse everything after the batch dim, then project.
        cnn_maps = self.resnet(x)
        cnn_vec = self.resnet_fc(cnn_maps.view(batch_size, -1))

        # Transformer branch: features straight into the projection.
        vit_vec = self.vit_fc(self.vit(x))

        # Concatenate along the feature dimension and classify.
        merged = self.dropout(torch.cat([cnn_vec, vit_vec], dim=1))
        return self.classifier(merged)
# For Face Mood Detection (AA-DCN/ResNet18 from antialiased_cnns)
def create_aadcn_model(num_classes=8):
    """Create the antialiased ResNet-18 with a resized classification layer.

    Loads ``resnet18`` from ``antialiased_cnns`` with ``pretrained=True`` and
    replaces its ``fc`` layer with a new linear layer that keeps the same
    input width but emits ``num_classes`` outputs.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.

    Returns:
        The modified model.
    """
    network = resnet18(pretrained=True)
    # Reuse the existing head's input width so only the output size changes.
    feature_dim = network.fc.in_features
    network.fc = nn.Linear(feature_dim, num_classes)
    return network