File size: 2,378 Bytes
2192664
8ee14ff
 
 
 
 
 
 
2192664
 
 
 
 
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2192664
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
2192664
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2192664
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
'''
    Model implementation.
    We'll be using a small convolutional network (two convolutional layers
    followed by three fully-connected layers) for image classification here.

    2022 Benjamin Kellenberger
'''

import glob
import os
from os.path import exists, split, splitext

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallModel(nn.Module):
    '''
        Small CNN classifier: two 5x5 convolutions (each followed by ReLU and
        2x2 max pooling), then three fully-connected layers producing 10
        output logits.

        The conv stack takes single-channel input, and fc1 expects
        32 * 5 * 5 flattened features. The spatial input size must therefore
        reduce to 5x5, e.g. 1x32x32 (32 -> 28 -> 14 -> 10 -> 5).
    '''

    @classmethod
    def load(cls, cfg):
        '''
            Build a model and resume from the latest checkpoint in
            ``cfg['output']``, if any.

            Args:
                cfg: mapping with ``'output'`` (checkpoint directory),
                     ``'device'`` (passed to ``nn.Module.to``) and optionally
                     ``'log'`` (a logger; this module's logger is used when
                     it is missing).

            Returns:
                (net, epoch, best_loss):
                    net: the model, moved to ``cfg['device']``.
                    epoch: the number parsed from the checkpoint file name.
                        It is 0 when starting fresh or when the name is not
                        numeric.
                    best_loss: the ``'loss_val'`` stored in
                        ``<output>/best.pt``, or ``np.inf`` if that file does
                        not exist.
        '''
        import logging

        log = cfg.get('log')
        if log is None:
            log = logging.getLogger(__name__)

        net = cls()
        epoch = 0
        best_loss = np.inf

        output = cfg.get('output')
        filepaths = sorted(glob.glob(f'{output}/*.pt'))

        # Prefer the numbered epoch checkpoints. best.pt is only resumed from
        # when it is the sole checkpoint. Compare the file name rather than
        # the whole path, so a directory name containing "best.pt" does not
        # filter out every checkpoint.
        if len(filepaths) > 1:
            filepaths = [p for p in filepaths
                         if os.path.basename(p) != 'best.pt']

        if filepaths:
            filepath = filepaths[-1]

            log.info(f'Resuming from {filepath}')

            # Pass the path rather than an open handle so torch closes it.
            state = torch.load(filepath, map_location='cpu')
            net.load_state_dict(state['model'])

            # save() names checkpoints "<epoch:04d>.pt". Any other name
            # (e.g. best.pt) leaves epoch at 0.
            try:
                epoch = int(splitext(split(filepath)[1])[0])
            except ValueError:
                pass

            best_path = f'{output}/best.pt'
            if exists(best_path):
                best_state = torch.load(best_path, map_location='cpu')
                best_loss = best_state['loss_val']
        else:
            log.info('Starting new network model')

        net.to(cfg.get('device'))

        return net, epoch, best_loss

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        '''
            Return raw (unnormalized) class logits of shape (N, 10).
        '''
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def save(self, cfg, epoch, stats, best=False):
        '''
            Write a checkpoint to ``<output>/<epoch:04d>.pt``. When ``best``
            is true, also write it to ``<output>/best.pt``.

            The checkpoint is a copy of ``stats`` with the model's
            ``state_dict`` added under ``'model'``. The caller's ``stats``
            is not modified.
        '''
        output = cfg.get('output')

        os.makedirs(output, exist_ok=True)

        # Copy so the caller's dict does not end up holding the state dict.
        checkpoint = dict(stats)
        checkpoint['model'] = self.state_dict()

        # Pass paths so torch opens, flushes and closes the files itself.
        torch.save(checkpoint, f'{output}/{epoch:04d}.pt')

        if best:
            torch.save(checkpoint, f'{output}/best.pt')

def load(cfg):
    '''
        Module-level entry point: delegate to ``SmallModel.load``.

        Returns whatever ``SmallModel.load(cfg)`` returns,
        i.e. ``(net, epoch, best_loss)``.
    '''
    loader = SmallModel.load
    return loader(cfg)