Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import torchvision | |
class Resnet50Flower102(nn.Module):
    """ResNet-50 whose final layer is replaced by a 102-way MLP head.

    The backbone comes from ``torchvision.models.resnet50``. Its ``fc`` layer
    is swapped for two Linear -> BatchNorm1d -> ReLU -> Dropout(0.2) stages
    (1024 then 512 units) followed by a Linear layer with 102 outputs.
    The class name suggests the Flowers-102 dataset -- TODO confirm against
    the training code.

    Args:
        device: Device the wrapped model is moved to; also kept as
            ``self.device``.
        pretrained: Load the ``IMAGENET1K_V1`` weights when true, otherwise
            build the backbone with ``weights=None``.
        freeze_backbone: When true, every backbone parameter gets
            ``requires_grad=False`` so only the new head is trainable.
    """

    def __init__(
        self,
        device,
        pretrained: bool = True,
        freeze_backbone: bool = True,
    ) -> None:
        super().__init__()
        self.device = device

        weights = (
            torchvision.models.ResNet50_Weights.IMAGENET1K_V1
            if pretrained
            else None
        )
        self.model = torchvision.models.resnet50(weights=weights)

        # Freeze BEFORE attaching the head: layers created below keep the
        # default requires_grad=True and therefore stay trainable.
        # NOTE: this only disables gradients; BatchNorm layers in the
        # backbone still update running statistics while in train mode.
        if freeze_backbone:
            self._freeze_parameters(self.model)

        # Width of the features feeding the original classifier
        # (2048 for torchvision's ResNet-50).
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 102),
        )
        self.model.to(device)

    @staticmethod
    def _freeze_parameters(module) -> None:
        """Disable gradient tracking for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the wrapped model's output for input ``x``."""
        return self.model(x)