File size: 309 Bytes
ee1b207
 
 
 
0154ec1
ee1b207
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch.nn as nn
from torchvision import models

def create_model(num_classes, dropout=0.5, pretrained=True):
    """Build a ResNet-18 classifier with a dropout-regularised linear head.

    The backbone's final fully connected layer (``model.fc``) is replaced by
    ``nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))``.
    ``in_features`` is read from the original head, so the new layer always
    matches the backbone's output width.

    Args:
        num_classes: Number of output logits produced by the new head; must be
            at least 1.
        dropout: Drop probability applied before the linear layer; must lie in
            the closed interval [0, 1].
        pretrained: If True (the default, matching the original behaviour),
            load torchvision's ImageNet-1K ResNet-18 weights. If False, the
            backbone is randomly initialised and nothing is downloaded.

    Returns:
        The torchvision ResNet-18 model with its ``fc`` head replaced.

    Raises:
        ValueError: If ``num_classes`` is less than 1 or ``dropout`` is outside
            [0, 1]. Both are checked before any weights are loaded.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")

    if hasattr(models, "ResNet18_Weights"):
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        # Older torchvision only understands the boolean flag.
        model = models.resnet18(pretrained=pretrained)

    # Reuse the width of the existing head so the replacement layer fits.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features, num_classes),
    )
    return model