Plant_Disease_Detection_App / src /models /resnet18_finetune.py
JAMM032's picture
Upload github repo files
97fcc90 verified
raw
history blame contribute delete
662 Bytes
import torch.nn as nn
from torchvision import models
def make_resnet18(num_classes=39, *, pretrained=True, freeze_backbone=True):
    """
    Construct a ResNet18 model with a custom classification head for PlantVillage.

    By default the backbone is initialised from torchvision's
    ``ResNet18_Weights.IMAGENET1K_V1`` and all of its parameters are frozen,
    so only the newly created classification head is trained (transfer
    learning). Both choices can be overridden via keyword-only arguments.

    Args:
        num_classes: Number of output units of the new linear head.
        pretrained: If True (default), load the IMAGENET1K_V1 weights
            (torchvision may need to download them on first use). If False,
            the backbone is randomly initialised and no weights are fetched.
        freeze_backbone: If True (default), set ``requires_grad = False`` on
            every parameter of the loaded model before the head is replaced.
            Pass False to fine-tune the whole network. Note that freezing a
            randomly initialised backbone (``pretrained=False``) is allowed
            but rarely useful.

    Returns:
        The torchvision ResNet18 model whose ``fc`` layer has been replaced by
        ``nn.Linear(in_features, num_classes)``, where ``in_features`` is
        taken from the original ``fc`` layer.
    """
    # Same weights as used in the lab (IMAGENET1K_V1); None means random init.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)

    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # The head is replaced *after* freezing: the new layer's parameters are
    # created with requires_grad=True, so the head is always trainable.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model