import torch.nn as nn
from torchvision import models


def create_model(num_classes, dropout=0.5, *, pretrained=True):
    """Build a ResNet-18 whose final layer is replaced by a dropout + linear head.

    Args:
        num_classes: Number of output units of the new classification layer.
        dropout: Dropout probability applied before the new linear layer
            (passed straight to ``nn.Dropout``, which validates it).
        pretrained: If True (the default, matching the original behaviour),
            load the torchvision ImageNet weights for ResNet-18; if False,
            build the architecture with randomly initialised weights and
            no download.

    Returns:
        The ``torchvision`` ResNet-18 model with ``model.fc`` replaced by
        ``nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))``.
    """
    # torchvision >= 0.13 exposes weight enums and deprecates `pretrained=`;
    # older releases only understand `pretrained=`. Support both.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if not pretrained:
        # No-argument call means "no pretrained weights" in every torchvision version.
        model = models.resnet18()
    elif weights_enum is not None:
        # DEFAULT for resnet18 is the same ImageNet checkpoint `pretrained=True` loaded.
        model = models.resnet18(weights=weights_enum.DEFAULT)
    else:
        model = models.resnet18(pretrained=True)

    # Reuse the backbone's feature width so the new head matches it exactly.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features, num_classes),
    )
    return model