|
|
|
|
|
|
|
|
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models import resnet50, ResNet50_Weights

from utils import *

from data_setup import classes
|
|
|
|
|
|
|
|
class ImageClassificationBase(nn.Module):
    """Shared training/validation helpers for image-classification modules.

    Subclasses implement ``forward`` so that ``self(images)`` returns scores
    that can be passed directly to ``F.cross_entropy``.
    """

    def training_step(self, batch):
        """Run one training batch and return ``(loss, acc)``.

        Args:
            batch: ``(images, labels)`` pair; ``labels`` are passed straight
                to ``F.cross_entropy`` together with the model output.

        Returns:
            ``loss`` (still attached to the autograd graph, ready for
            ``backward()``) and ``acc`` as returned by ``accuracy``.
        """
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        # ``accuracy`` is provided by ``utils`` via the star import.
        acc = accuracy(out, labels)
        return loss, acc

    def validation_step(self, batch):
        """Run one validation batch and return its loss and accuracy.

        Args:
            batch: ``(images, labels)`` pair, same layout as in
                ``training_step``.

        Returns:
            ``{'val_loss': <detached loss>, 'val_acc': <accuracy>}``. Both
            values must be tensors, because ``validation_end_epoch`` combines
            them with ``torch.stack``.
        """
        images, labels = batch
        # NOTE(review): no ``torch.no_grad()`` here; presumably the caller's
        # evaluation loop disables gradient tracking — confirm.
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        # Detach so accumulated per-batch results do not keep graphs alive.
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_end_epoch(self, results):
        """Average per-batch validation results into epoch-level floats.

        Args:
            results: non-empty list of dicts produced by ``validation_step``
                (``torch.stack`` raises on an empty list).

        Returns:
            ``{'val_loss': float, 'val_acc': float}``.
        """
        # Each batch counts equally, so a smaller final batch is weighted the
        # same as a full one (this is a mean of batch means).
        epoch_loss = torch.stack([r['val_loss'] for r in results]).mean()
        epoch_acc = torch.stack([r['val_acc'] for r in results]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for ``epoch`` (0-based, shown 1-based).

        Args:
            epoch: zero-based epoch index.
            result: mapping with ``'train_losses'``, ``'train_acc'``,
                ``'val_loss'`` and ``'val_acc'`` values that support the
                ``.4f`` format spec.
        """
        # Implicit concatenation keeps the message on one line; a backslash
        # continuation inside the f-string would embed the next source line's
        # indentation in the printed text.
        print(
            f"Epoch {epoch + 1}: "
            f"train_loss: {result['train_losses']:.4f}, "
            f"train_acc: {result['train_acc']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, "
            f"val_acc: {result['val_acc']:.4f}"
        )
|
|
|
|
|
|
|
|
class FlowerClassificationModel(ImageClassificationBase):
    """torchvision ResNet-50 with its final ``fc`` layer replaced.

    Args:
        num_classes: number of outputs of the new linear head.
        pretrained: if True, load ``ResNet50_Weights.DEFAULT`` and freeze
            every backbone parameter so only the new head is trainable.
            If False, build a randomly initialised ResNet-50 with all
            parameters trainable.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # The module-level ``ResNet50_Weights`` import already requires a
        # torchvision that supports the ``weights=`` API, so no runtime
        # version check is needed; ``weights=None`` gives random init.
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.network = resnet50(weights=weights)

        if pretrained:
            # Transfer learning: freeze the pretrained backbone. The head
            # assigned below is a new module, so it remains trainable.
            for param in self.network.parameters():
                param.requires_grad = False

        # Replace the ImageNet classifier; take the feature width from the
        # existing layer rather than hard-coding it.
        self.network.fc = nn.Linear(
            in_features=self.network.fc.in_features,
            out_features=num_classes,
            bias=True,
        )

    def forward(self, xb):
        """Return the unnormalised class scores of the network for ``xb``."""
        return self.network(xb)