File size: 561 Bytes
207a388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

class CatDogEfficientNetB0(nn.Module):
    """EfficientNet-B0 classifier with a freshly initialised output layer.

    Builds torchvision's ``efficientnet_b0``, optionally freezes every
    parameter it was created with, and then replaces ``classifier[1]`` with a
    new ``nn.Linear`` that outputs ``num_classes`` values. With the defaults
    (pretrained weights, frozen backbone, two outputs) this is the original
    cat-vs-dog feature-extraction model: only the new head is trainable.

    Args:
        num_classes: Number of outputs of the replacement linear layer.
            Defaults to 2.
        weights: Value forwarded to ``efficientnet_b0(weights=...)``.
            Defaults to ``EfficientNet_B0_Weights.DEFAULT``; pass ``None``
            to skip loading pretrained weights.
        freeze_backbone: If True (the default), set ``requires_grad=False``
            on all parameters of the loaded network before the head is
            replaced.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """

    def __init__(
        self,
        *,
        num_classes: int = 2,
        weights=EfficientNet_B0_Weights.DEFAULT,
        freeze_backbone: bool = True,
    ):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.base = efficientnet_b0(weights=weights)
        if freeze_backbone:
            # Freeze BEFORE swapping the head: the nn.Linear created below
            # has requires_grad=True, so it remains the only trainable layer.
            # NOTE(review): requires_grad=False does not stop BatchNorm
            # running statistics from updating while the module is in
            # train() mode; keep the backbone in eval() during training if
            # fully frozen features are intended.
            for param in self.base.parameters():
                param.requires_grad = False
        # Assumes classifier[1] is the final nn.Linear of torchvision's
        # EfficientNet-B0 head; its input width is reused for the new layer.
        old_head = self.base.classifier[1]
        self.base.classifier[1] = nn.Linear(old_head.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the (possibly frozen) backbone and the new head.

        Args:
            x: Input batch — presumably a float image tensor shaped
                (N, 3, H, W) and preprocessed like ``weights.transforms()``;
                TODO confirm against the caller.

        Returns:
            The output of ``self.base``; no softmax is applied here, so the
            values are unnormalised scores over ``num_classes`` classes.
        """
        return self.base(x)