# -*- coding: utf-8 -*-
'''
    Model implementation.
    We'll be using a small convolutional network (two conv layers followed
    by three fully-connected layers) for image classification here.

    2022 Benjamin Kellenberger
'''

import glob
import os
from os.path import exists, split, splitext

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# File name under which the best-performing checkpoint is stored.
_BEST_NAME = 'best.pt'


class SmallModel(nn.Module):
    '''
        Small CNN classifier: 1 input channel, 10 output logits.

        The first fully-connected layer expects a 32 x 5 x 5 feature map,
        so the input must shrink to 5 x 5 after the two conv/pool stages
        (e.g. a 32 x 32 input image).
    '''

    @classmethod
    def load(cls, cfg):
        '''
            Create a model and, if checkpoints exist, resume from the latest.

            Reads from "cfg" (accessed via ".get"):
                'log':     logger-like object providing ".info(msg)"
                'output':  directory that is searched for "*.pt" checkpoints
                'device':  device the model is moved to before returning

            Returns a tuple (model, epoch, best_loss). "epoch" is parsed from
            the checkpoint file name (0 if there is none or the name is not
            an integer); "best_loss" is the 'loss_val' entry stored in
            "best.pt" (np.inf if unavailable).
        '''
        log = cfg.get('log')
        output = cfg.get('output')

        net = cls()
        epoch = 0
        best_loss = np.inf

        # Lexicographic order equals epoch order thanks to the zero-padded
        # file names written by save().
        checkpoints = sorted(glob.glob(f'{output}/*.pt'))
        if len(checkpoints) > 1:
            # Prefer numbered checkpoints; "best.pt" is only resumed from if
            # it is the sole file available. Compare the file name only, so
            # a directory name containing "best.pt" cannot hide checkpoints.
            checkpoints = [
                path for path in checkpoints if split(path)[1] != _BEST_NAME
            ]

        if checkpoints:
            latest = checkpoints[-1]
            log.info(f'Resuming from {latest}')
            # NOTE: torch.load unpickles data; only load trusted checkpoints.
            # Passing the path lets torch open and close the file itself.
            state = torch.load(latest, map_location='cpu')
            net.load_state_dict(state['model'])

            stem = splitext(split(latest)[1])[0]
            try:
                epoch = int(stem)
            except ValueError:
                # Non-numeric name (e.g. "best"): keep counting from zero.
                pass

            best_path = f'{output}/{_BEST_NAME}'
            if exists(best_path):
                best_state = torch.load(best_path, map_location='cpu')
                # Tolerate a best checkpoint saved without validation loss.
                best_loss = best_state.get('loss_val', np.inf)
        else:
            log.info('Starting new network model')

        net.to(cfg.get('device'))
        return net, epoch, best_loss

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        '''
            Return the 10 class logits for a batch of images "x".
        '''
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def save(self, cfg, epoch, stats, best=False):
        '''
            Write a checkpoint to "<output>/<epoch:04d>.pt".

            The checkpoint holds all entries of "stats" plus the model's
            state dict under the key 'model'. With "best=True" the same
            content is also written to "<output>/best.pt". The output
            directory (cfg 'output') is created if needed. "stats" itself
            is left unmodified.
        '''
        output = cfg.get('output')
        os.makedirs(output, exist_ok=True)

        # Shallow copy so the caller's dict does not gain a 'model' entry.
        state = dict(stats)
        state['model'] = self.state_dict()

        torch.save(state, f'{output}/{epoch:04d}.pt')
        if best:
            torch.save(state, f'{output}/{_BEST_NAME}')


def load(cfg):
    '''
        Module-level convenience wrapper around SmallModel.load.
    '''
    return SmallModel.load(cfg)