Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision | |
| from torch import nn | |
| from transformers import PreTrainedModel | |
| from configuration import EffNetPlantDiseaseConfig | |
| # create model class | |
class EffNetPlantDiseaseClassification(PreTrainedModel):
    """EfficientNetV2-S image classifier exposed as a Hugging Face ``PreTrainedModel``.

    The backbone comes from ``torchvision.models.efficientnet_v2_s()``, called
    with no arguments. Its stock classifier head is replaced by
    ``Dropout(config.dropout_rate) -> Linear(..., config.num_classes)``.
    Cross-entropy loss is computed in ``forward`` when targets are supplied.
    """

    config_class = EffNetPlantDiseaseConfig

    def __init__(self, config):
        """Build the backbone and resize its head from ``config``.

        Args:
            config: An ``EffNetPlantDiseaseConfig``. Only ``dropout_rate`` and
                ``num_classes`` are read here.
        """
        super().__init__(config)
        self.model = torchvision.models.efficientnet_v2_s()

        # Capture the width of the stock final Linear before the head is
        # swapped out; the new projection must accept the same feature size.
        feature_dim = self.model.classifier[-1].in_features
        dropout = nn.Dropout(p=config.dropout_rate, inplace=True)
        projection = nn.Linear(in_features=feature_dim,
                               out_features=config.num_classes)
        # Positional Sequential keeps the submodule keys "0" (dropout) and
        # "1" (linear), so saved checkpoints continue to load.
        self.model.classifier = nn.Sequential(dropout, projection)

        self.loss_fn = nn.CrossEntropyLoss()
        self.num_classes = config.num_classes

    def forward(self, image, label=None):
        """Classify ``image`` and optionally score the prediction against ``label``.

        Args:
            image: Batch passed straight to the EfficientNet backbone.
                Presumably a float image tensor of shape (batch, 3, H, W);
                TODO confirm against the preprocessing pipeline.
            label: Optional targets accepted by ``nn.CrossEntropyLoss``. When
                ``None``, no loss is computed.

        Returns:
            A dict with keys ``"logits"`` and ``"loss"``. The logits' last
            dimension is ``config.num_classes``. ``"loss"`` is ``None`` when
            ``label`` is ``None``.
        """
        logits = self.model(image)
        loss = None if label is None else self.loss_fn(logits, label)
        # NOTE(review): Hugging Face ModelOutput objects omit None fields,
        # whereas this plain dict always carries a "loss" key. Confirm that
        # downstream consumers (e.g. label-free Trainer.predict) tolerate a
        # None entry.
        return {"logits": logits, "loss": loss}