| import torch
|
| import torch.nn.functional as F
|
| import torch.nn as nn
|
|
|
class Base(nn.Module):
    """Shared training/validation helpers for classification models.

    Subclasses provide ``forward``. Each batch handed to these helpers is
    unpacked as an ``(images, labels)`` pair and scored with cross-entropy.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of the model on one training batch."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def validation_step(self, batch):
        """Score one validation batch.

        Returns:
            A dict with the detached cross-entropy loss under ``'val_loss'``
            and the result of the module-level ``accuracy`` helper under
            ``'val_acc'``.
        """
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        batch_acc = accuracy(logits, targets)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average per-batch validation results into Python floats.

        Args:
            outputs: Sequence of dicts as produced by ``validation_step``.
                ``torch.stack`` raises if this sequence is empty.

        Returns:
            ``{'val_loss': float, 'val_acc': float}`` holding the batch means.
        """
        def mean_of(key):
            # Stack the per-batch scalar tensors, average them, unwrap to float.
            return torch.stack([step[key] for step in outputs]).mean().item()

        return {key: mean_of(key) for key in ('val_loss', 'val_acc')}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch's training/validation metrics."""
        template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        metrics = (result['train_loss'], result['val_loss'], result['val_acc'])
        print(template.format(epoch, *metrics))
|
|
|
|
|
|
|
def accuracy(outputs, labels):
    """Return the fraction of predictions that match ``labels``.

    The predicted class is the arg-max of ``outputs`` over dimension 1. This
    is the class dimension both for ``(N, C)`` logits and for dense
    ``(N, C, d1, ...)`` outputs. Every predicted element is counted, so the
    result always lies in [0, 1].

    Args:
        outputs: Score tensor with classes along dimension 1.
        labels: Integer class labels shaped like ``outputs`` without dim 1.

    Returns:
        A new 0-dim float tensor built from a Python float with
        ``torch.tensor``. It is therefore placed on the default device,
        independent of ``outputs.device``.

    Raises:
        ZeroDivisionError: if there are no predictions (empty batch).
    """
    preds = torch.argmax(outputs, dim=1)
    correct = (preds == labels).sum().item()
    # numel() rather than len(): len() counts only the batch dimension and
    # would overstate accuracy whenever preds has more than one dimension.
    return torch.tensor(correct / preds.numel())
|
|
|
|
|
class PotatoDiseaseDetectionModel(Base):
    """Convolutional image classifier built on the ``Base`` training helpers.

    Architecture:
        * Three conv stages with 64, 128 and 256 channels. Each stage has two
          3x3 Conv2d + BatchNorm2d + ReLU layers, followed by a 2x2/stride-2
          MaxPool2d.
        * Flatten, then Linear -> BatchNorm1d(128) -> ReLU -> Dropout(0.5)
          -> Linear(num_classes).

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        image_size: Spatial size of the inputs, either an int (square images)
            or a ``(height, width)`` tuple. The default of 224 reproduces the
            original hard-coded ``256 * 28 * 28`` classifier input. Inputs of
            any other size will not fit the first Linear layer.

    Raises:
        ValueError: if ``image_size`` is smaller than 8 in either dimension,
            which would leave no features after the three pooling layers.

    NOTE: in training mode, the BatchNorm1d layer in the classifier needs
    more than one sample per batch.
    """

    def __init__(self, in_channels=3, num_classes=3, image_size=224):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # Each of the three MaxPool2d(kernel_size=2, stride=2) layers floors the
        # spatial size by 2, so the final feature map is (H // 8) x (W // 8).
        feat_h, feat_w = height // 8, width // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image_size must be at least 8 in each dimension, got {image_size!r}"
            )

        self.network = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=256 * feat_h * feat_w, out_features=128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features=128, out_features=num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        ``x`` is expected to be a batched image tensor of shape
        ``(N, in_channels, H, W)``, with ``(H, W)`` matching ``image_size``.
        """
        features = self.network(x)
        return self.classifier(features)
|
|
|
|
|
# Module-level classifier instance (3-channel input, 3 output classes),
# constructed when this module is executed/imported.
model = PotatoDiseaseDetectionModel(in_channels=3, num_classes=3)