# PlantDiseaseDetection / modeling.py
# Last commit 4f09bb0 by BrandonFors: "uploading files to space"
import torch
import torchvision
from torch import nn
from transformers import PreTrainedModel
from configuration import EffNetPlantDiseaseConfig
# create model class
class EffNetPlantDiseaseClassification(PreTrainedModel):
    """EfficientNetV2-S image classifier exposed as a Hugging Face ``PreTrainedModel``.

    The torchvision EfficientNetV2-S architecture is instantiated and its
    classifier head is replaced by ``Dropout -> Linear`` emitting
    ``config.num_classes`` logits. The config fields read here are
    ``dropout_rate`` and ``num_classes``.
    """

    config_class = EffNetPlantDiseaseConfig

    def __init__(self, config):
        super().__init__(config)

        # Architecture only: no ``weights`` argument is passed, so parameters
        # start randomly initialised (presumably populated afterwards via
        # ``from_pretrained`` or training -- confirm).
        backbone = torchvision.models.efficientnet_v2_s()

        # Input width of the stock head's final Linear layer; the new head
        # consumes the same pooled feature vector.
        head_in_features = backbone.classifier[-1].in_features

        # Keep the Dropout-then-Linear layout (indices 0 and 1) so state-dict
        # keys remain ``model.classifier.1.weight`` / ``.bias``.
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=config.dropout_rate, inplace=True),
            nn.Linear(
                in_features=head_in_features,
                out_features=config.num_classes,
            ),
        )

        self.model = backbone
        self.num_classes = config.num_classes
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, image, label=None):
        """Compute class logits and, when targets are given, the training loss.

        Args:
            image: batch fed directly to the EfficientNet backbone (presumably
                float images shaped (N, 3, H, W) -- confirm against the
                preprocessing pipeline).
            label: optional targets for ``nn.CrossEntropyLoss`` (presumably
                integer class indices shaped (N,) -- confirm).

        Returns:
            dict with ``"logits"`` (last dimension ``num_classes``) and
            ``"loss"`` (mean cross-entropy as a scalar tensor, or ``None``
            when ``label`` is ``None``).
        """
        logits = self.model(image)
        loss = None if label is None else self.loss_fn(logits, label)
        return {"logits": logits, "loss": loss}