PoxNet / model.py
ronithsharmila's picture
Update model.py
e6b3256 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class CustomModel(nn.Module):
    """EfficientNetV2-S backbone followed by a three-layer MLP classifier head.

    The torchvision classifier is removed and replaced by
    ``Linear(num_features, 512) -> ReLU -> Dropout(0.5) ->
    Linear(512, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes)``.

    Submodule names (``efficientnet``, ``gap``, ``fc1`` ... ``fc3``) are part
    of the checkpoint format: renaming or reordering them changes the
    ``state_dict`` keys and breaks loading of previously saved weights.
    """

    def __init__(self, num_classes: int = 4, *, pretrained: bool = True) -> None:
        """Build the network.

        Args:
            num_classes: Size of the final output layer.
            pretrained: When True (the default, matching the historical
                behaviour) the backbone is initialised with torchvision's
                ``IMAGENET1K_V1`` weights. Pass False to skip fetching them,
                e.g. when a full checkpoint is loaded right afterwards.
        """
        super().__init__()

        backbone_weights = (
            models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        )
        backbone = models.efficientnet_v2_s(weights=backbone_weights)

        # Width of the features fed into the stock classifier's Linear layer;
        # read it before the classifier is discarded below.
        num_features = backbone.classifier[1].in_features

        # Keep every child except the last one (the classifier).
        self.efficientnet = nn.Sequential(*list(backbone.children())[:-1])

        # NOTE(review): only the last child is dropped above. If torchvision's
        # EfficientNet children are (features, avgpool, classifier), the
        # backbone output is already pooled to 1x1 and this layer is a no-op
        # -- confirm. It is kept so attribute names stay stable for callers.
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.fc1 = nn.Linear(num_features, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and classifier head.

        Args:
            x: Batched input for the backbone (presumably a
                ``(batch, 3, H, W)`` image tensor -- confirm against the
                preprocessing used at training time).

        Returns:
            Unnormalised class scores of shape ``(batch, num_classes)``;
            no softmax is applied here.
        """
        pooled = self.gap(self.efficientnet(x))
        # Collapse everything after the batch dimension into one feature axis.
        features = torch.flatten(pooled, 1)

        hidden = self.dropout1(F.relu(self.fc1(features)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)
def load_model(model_path, num_classes: int = 4) -> "CustomModel":
    """Load a saved ``CustomModel`` checkpoint for CPU inference.

    Args:
        model_path: Location of a file produced by saving the model's
            ``state_dict``; passed straight to ``torch.load``.
        num_classes: Output size the checkpoint was trained with. Defaults
            to 4, the value previously hard-coded here, and must match the
            saved ``fc3`` layer or ``load_state_dict`` will raise.

    Returns:
        The model with the checkpoint weights applied, switched to
        evaluation mode (dropout disabled).
    """
    model = CustomModel(num_classes=num_classes)

    # SECURITY: torch.load unpickles the file, so a checkpoint from an
    # untrusted source can execute arbitrary code. Only load trusted files
    # (or pass weights_only=True on PyTorch versions that support it).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    model.load_state_dict(state_dict)
    model.eval()
    return model