myshafazil commited on
Commit
af8092c
·
verified ·
1 Parent(s): 6e2f98e

Upload 7 files

Browse files
Files changed (7) hide show
  1. __init__.py +0 -0
  2. evaluate.py +13 -0
  3. feature_extractor.py +54 -0
  4. model.py +109 -0
  5. original_model.py +111 -0
  6. test.py +80 -0
  7. train.py +206 -0
__init__.py ADDED
File without changes
evaluate.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ features = torch.load("features.pth")
4
+ qf = features["qf"]
5
+ ql = features["ql"]
6
+ gf = features["gf"]
7
+ gl = features["gl"]
8
+
9
+ scores = qf.mm(gf.t())
10
+ res = scores.topk(5, dim=1)[1][:, 0]
11
+ top1correct = gl[res].eq(ql).sum().item()
12
+
13
+ print("Acc top1:{:.3f}".format(top1correct / ql.size(0)))
feature_extractor.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import numpy as np
4
+ import cv2
5
+ import logging
6
+
7
+ from .model import Net
8
+
9
+
10
+ class Extractor(object):
11
+ def __init__(self, model_path, use_cuda=True):
12
+ self.net = Net(reid=True)
13
+ self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
14
+ state_dict = torch.load(model_path, map_location=torch.device(self.device))[
15
+ 'net_dict']
16
+ self.net.load_state_dict(state_dict)
17
+ logger = logging.getLogger("root.tracker")
18
+ logger.info("Loading weights from {}... Done!".format(model_path))
19
+ self.net.to(self.device)
20
+ self.size = (64, 128)
21
+ self.norm = transforms.Compose([
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
+ ])
25
+
26
+ def _preprocess(self, im_crops):
27
+ """
28
+ TODO:
29
+ 1. to float with scale from 0 to 1
30
+ 2. resize to (64, 128) as Market1501 dataset did
31
+ 3. concatenate to a numpy array
32
+ 3. to torch Tensor
33
+ 4. normalize
34
+ """
35
+ def _resize(im, size):
36
+ return cv2.resize(im.astype(np.float32)/255., size)
37
+
38
+ im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(
39
+ 0) for im in im_crops], dim=0).float()
40
+ return im_batch
41
+
42
+ def __call__(self, im_crops):
43
+ im_batch = self._preprocess(im_crops)
44
+ with torch.no_grad():
45
+ im_batch = im_batch.to(self.device)
46
+ features = self.net(im_batch)
47
+ return features.cpu().numpy()
48
+
49
+
50
+ if __name__ == '__main__':
51
+ img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
52
+ extr = Extractor("checkpoint/ckpt.t7")
53
+ feature = extr(img)
54
+ print(feature.shape)
model.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicBlock(nn.Module):
7
+ def __init__(self, c_in, c_out, is_downsample=False):
8
+ super(BasicBlock, self).__init__()
9
+ self.is_downsample = is_downsample
10
+ if is_downsample:
11
+ self.conv1 = nn.Conv2d(
12
+ c_in, c_out, 3, stride=2, padding=1, bias=False)
13
+ else:
14
+ self.conv1 = nn.Conv2d(
15
+ c_in, c_out, 3, stride=1, padding=1, bias=False)
16
+ self.bn1 = nn.BatchNorm2d(c_out)
17
+ self.relu = nn.ReLU(True)
18
+ self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
19
+ padding=1, bias=False)
20
+ self.bn2 = nn.BatchNorm2d(c_out)
21
+ if is_downsample:
22
+ self.downsample = nn.Sequential(
23
+ nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
24
+ nn.BatchNorm2d(c_out)
25
+ )
26
+ elif c_in != c_out:
27
+ self.downsample = nn.Sequential(
28
+ nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
29
+ nn.BatchNorm2d(c_out)
30
+ )
31
+ self.is_downsample = True
32
+
33
+ def forward(self, x):
34
+ y = self.conv1(x)
35
+ y = self.bn1(y)
36
+ y = self.relu(y)
37
+ y = self.conv2(y)
38
+ y = self.bn2(y)
39
+ if self.is_downsample:
40
+ x = self.downsample(x)
41
+ return F.relu(x.add(y), True)
42
+
43
+
44
+ def make_layers(c_in, c_out, repeat_times, is_downsample=False):
45
+ blocks = []
46
+ for i in range(repeat_times):
47
+ if i == 0:
48
+ blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
49
+ else:
50
+ blocks += [BasicBlock(c_out, c_out), ]
51
+ return nn.Sequential(*blocks)
52
+
53
+
54
+ class Net(nn.Module):
55
+ def __init__(self, num_classes=751, reid=False):
56
+ super(Net, self).__init__()
57
+ # 3 128 64
58
+ self.conv = nn.Sequential(
59
+ nn.Conv2d(3, 64, 3, stride=1, padding=1),
60
+ nn.BatchNorm2d(64),
61
+ nn.ReLU(inplace=True),
62
+ # nn.Conv2d(32,32,3,stride=1,padding=1),
63
+ # nn.BatchNorm2d(32),
64
+ # nn.ReLU(inplace=True),
65
+ nn.MaxPool2d(3, 2, padding=1),
66
+ )
67
+ # 32 64 32
68
+ self.layer1 = make_layers(64, 64, 2, False)
69
+ # 32 64 32
70
+ self.layer2 = make_layers(64, 128, 2, True)
71
+ # 64 32 16
72
+ self.layer3 = make_layers(128, 256, 2, True)
73
+ # 128 16 8
74
+ self.layer4 = make_layers(256, 512, 2, True)
75
+ # 256 8 4
76
+ self.avgpool = nn.AvgPool2d((8, 4), 1)
77
+ # 256 1 1
78
+ self.reid = reid
79
+ self.classifier = nn.Sequential(
80
+ nn.Linear(512, 256),
81
+ nn.BatchNorm1d(256),
82
+ nn.ReLU(inplace=True),
83
+ nn.Dropout(),
84
+ nn.Linear(256, num_classes),
85
+ )
86
+
87
+ def forward(self, x):
88
+ x = self.conv(x)
89
+ x = self.layer1(x)
90
+ x = self.layer2(x)
91
+ x = self.layer3(x)
92
+ x = self.layer4(x)
93
+ x = self.avgpool(x)
94
+ x = x.view(x.size(0), -1)
95
+ # B x 128
96
+ if self.reid:
97
+ x = x.div(x.norm(p=2, dim=1, keepdim=True))
98
+ return x
99
+ # classifier
100
+ x = self.classifier(x)
101
+ return x
102
+
103
+
104
+ if __name__ == '__main__':
105
+ net = Net()
106
+ x = torch.randn(4, 3, 128, 64)
107
+ y = net(x)
108
+ import ipdb
109
+ ipdb.set_trace()
original_model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicBlock(nn.Module):
7
+ def __init__(self, c_in, c_out, is_downsample=False):
8
+ super(BasicBlock, self).__init__()
9
+ self.is_downsample = is_downsample
10
+ if is_downsample:
11
+ self.conv1 = nn.Conv2d(
12
+ c_in, c_out, 3, stride=2, padding=1, bias=False)
13
+ else:
14
+ self.conv1 = nn.Conv2d(
15
+ c_in, c_out, 3, stride=1, padding=1, bias=False)
16
+ self.bn1 = nn.BatchNorm2d(c_out)
17
+ self.relu = nn.ReLU(True)
18
+ self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
19
+ padding=1, bias=False)
20
+ self.bn2 = nn.BatchNorm2d(c_out)
21
+ if is_downsample:
22
+ self.downsample = nn.Sequential(
23
+ nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
24
+ nn.BatchNorm2d(c_out)
25
+ )
26
+ elif c_in != c_out:
27
+ self.downsample = nn.Sequential(
28
+ nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
29
+ nn.BatchNorm2d(c_out)
30
+ )
31
+ self.is_downsample = True
32
+
33
+ def forward(self, x):
34
+ y = self.conv1(x)
35
+ y = self.bn1(y)
36
+ y = self.relu(y)
37
+ y = self.conv2(y)
38
+ y = self.bn2(y)
39
+ if self.is_downsample:
40
+ x = self.downsample(x)
41
+ return F.relu(x.add(y), True)
42
+
43
+
44
+ def make_layers(c_in, c_out, repeat_times, is_downsample=False):
45
+ blocks = []
46
+ for i in range(repeat_times):
47
+ if i == 0:
48
+ blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
49
+ else:
50
+ blocks += [BasicBlock(c_out, c_out), ]
51
+ return nn.Sequential(*blocks)
52
+
53
+
54
+ class Net(nn.Module):
55
+ def __init__(self, num_classes=625, reid=False):
56
+ super(Net, self).__init__()
57
+ # 3 128 64
58
+ self.conv = nn.Sequential(
59
+ nn.Conv2d(3, 32, 3, stride=1, padding=1),
60
+ nn.BatchNorm2d(32),
61
+ nn.ELU(inplace=True),
62
+ nn.Conv2d(32, 32, 3, stride=1, padding=1),
63
+ nn.BatchNorm2d(32),
64
+ nn.ELU(inplace=True),
65
+ nn.MaxPool2d(3, 2, padding=1),
66
+ )
67
+ # 32 64 32
68
+ self.layer1 = make_layers(32, 32, 2, False)
69
+ # 32 64 32
70
+ self.layer2 = make_layers(32, 64, 2, True)
71
+ # 64 32 16
72
+ self.layer3 = make_layers(64, 128, 2, True)
73
+ # 128 16 8
74
+ self.dense = nn.Sequential(
75
+ nn.Dropout(p=0.6),
76
+ nn.Linear(128*16*8, 128),
77
+ nn.BatchNorm1d(128),
78
+ nn.ELU(inplace=True)
79
+ )
80
+ # 256 1 1
81
+ self.reid = reid
82
+ self.batch_norm = nn.BatchNorm1d(128)
83
+ self.classifier = nn.Sequential(
84
+ nn.Linear(128, num_classes),
85
+ )
86
+
87
+ def forward(self, x):
88
+ x = self.conv(x)
89
+ x = self.layer1(x)
90
+ x = self.layer2(x)
91
+ x = self.layer3(x)
92
+
93
+ x = x.view(x.size(0), -1)
94
+ if self.reid:
95
+ x = self.dense[0](x)
96
+ x = self.dense[1](x)
97
+ x = x.div(x.norm(p=2, dim=1, keepdim=True))
98
+ return x
99
+ x = self.dense(x)
100
+ # B x 128
101
+ # classifier
102
+ x = self.classifier(x)
103
+ return x
104
+
105
+
106
+ if __name__ == '__main__':
107
+ net = Net(reid=True)
108
+ x = torch.randn(4, 3, 128, 64)
109
+ y = net(x)
110
+ import ipdb
111
+ ipdb.set_trace()
test.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.backends.cudnn as cudnn
3
+ import torchvision
4
+
5
+ import argparse
6
+ import os
7
+
8
+ from model import Net
9
+
10
+ parser = argparse.ArgumentParser(description="Train on market1501")
11
+ parser.add_argument("--data-dir", default='data', type=str)
12
+ parser.add_argument("--no-cuda", action="store_true")
13
+ parser.add_argument("--gpu-id", default=0, type=int)
14
+ args = parser.parse_args()
15
+
16
+ # device
17
+ device = "cuda:{}".format(
18
+ args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
19
+ if torch.cuda.is_available() and not args.no_cuda:
20
+ cudnn.benchmark = True
21
+
22
+ # data loader
23
+ root = args.data_dir
24
+ query_dir = os.path.join(root, "query")
25
+ gallery_dir = os.path.join(root, "gallery")
26
+ transform = torchvision.transforms.Compose([
27
+ torchvision.transforms.Resize((128, 64)),
28
+ torchvision.transforms.ToTensor(),
29
+ torchvision.transforms.Normalize(
30
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31
+ ])
32
+ queryloader = torch.utils.data.DataLoader(
33
+ torchvision.datasets.ImageFolder(query_dir, transform=transform),
34
+ batch_size=64, shuffle=False
35
+ )
36
+ galleryloader = torch.utils.data.DataLoader(
37
+ torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
38
+ batch_size=64, shuffle=False
39
+ )
40
+
41
+ # net definition
42
+ net = Net(reid=True)
43
+ assert os.path.isfile(
44
+ "./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
45
+ print('Loading from checkpoint/ckpt.t7')
46
+ checkpoint = torch.load("./checkpoint/ckpt.t7")
47
+ net_dict = checkpoint['net_dict']
48
+ net.load_state_dict(net_dict, strict=False)
49
+ net.eval()
50
+ net.to(device)
51
+
52
+ # compute features
53
+ query_features = torch.tensor([]).float()
54
+ query_labels = torch.tensor([]).long()
55
+ gallery_features = torch.tensor([]).float()
56
+ gallery_labels = torch.tensor([]).long()
57
+
58
+ with torch.no_grad():
59
+ for idx, (inputs, labels) in enumerate(queryloader):
60
+ inputs = inputs.to(device)
61
+ features = net(inputs).cpu()
62
+ query_features = torch.cat((query_features, features), dim=0)
63
+ query_labels = torch.cat((query_labels, labels))
64
+
65
+ for idx, (inputs, labels) in enumerate(galleryloader):
66
+ inputs = inputs.to(device)
67
+ features = net(inputs).cpu()
68
+ gallery_features = torch.cat((gallery_features, features), dim=0)
69
+ gallery_labels = torch.cat((gallery_labels, labels))
70
+
71
+ gallery_labels -= 2
72
+
73
+ # save features
74
+ features = {
75
+ "qf": query_features,
76
+ "ql": query_labels,
77
+ "gf": gallery_features,
78
+ "gl": gallery_labels
79
+ }
80
+ torch.save(features, "features.pth")
train.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import torchvision
10
+
11
+ from model import Net
12
+
13
+ parser = argparse.ArgumentParser(description="Train on market1501")
14
+ parser.add_argument("--data-dir", default='data', type=str)
15
+ parser.add_argument("--no-cuda", action="store_true")
16
+ parser.add_argument("--gpu-id", default=0, type=int)
17
+ parser.add_argument("--lr", default=0.1, type=float)
18
+ parser.add_argument("--interval", '-i', default=20, type=int)
19
+ parser.add_argument('--resume', '-r', action='store_true')
20
+ args = parser.parse_args()
21
+
22
+ # device
23
+ device = "cuda:{}".format(
24
+ args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
25
+ if torch.cuda.is_available() and not args.no_cuda:
26
+ cudnn.benchmark = True
27
+
28
+ # data loading
29
+ root = args.data_dir
30
+ train_dir = os.path.join(root, "train")
31
+ test_dir = os.path.join(root, "test")
32
+ transform_train = torchvision.transforms.Compose([
33
+ torchvision.transforms.RandomCrop((128, 64), padding=4),
34
+ torchvision.transforms.RandomHorizontalFlip(),
35
+ torchvision.transforms.ToTensor(),
36
+ torchvision.transforms.Normalize(
37
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
38
+ ])
39
+ transform_test = torchvision.transforms.Compose([
40
+ torchvision.transforms.Resize((128, 64)),
41
+ torchvision.transforms.ToTensor(),
42
+ torchvision.transforms.Normalize(
43
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
44
+ ])
45
+ trainloader = torch.utils.data.DataLoader(
46
+ torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
47
+ batch_size=64, shuffle=True
48
+ )
49
+ testloader = torch.utils.data.DataLoader(
50
+ torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
51
+ batch_size=64, shuffle=True
52
+ )
53
+ num_classes = max(len(trainloader.dataset.classes),
54
+ len(testloader.dataset.classes))
55
+
56
+ # net definition
57
+ start_epoch = 0
58
+ net = Net(num_classes=num_classes)
59
+ if args.resume:
60
+ assert os.path.isfile(
61
+ "./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
62
+ print('Loading from checkpoint/ckpt.t7')
63
+ checkpoint = torch.load("./checkpoint/ckpt.t7")
64
+ # import ipdb; ipdb.set_trace()
65
+ net_dict = checkpoint['net_dict']
66
+ net.load_state_dict(net_dict)
67
+ best_acc = checkpoint['acc']
68
+ start_epoch = checkpoint['epoch']
69
+ net.to(device)
70
+
71
+ # loss and optimizer
72
+ criterion = torch.nn.CrossEntropyLoss()
73
+ optimizer = torch.optim.SGD(
74
+ net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
75
+ best_acc = 0.
76
+
77
+ # train function for each epoch
78
+
79
+
80
+ def train(epoch):
81
+ print("\nEpoch : %d" % (epoch+1))
82
+ net.train()
83
+ training_loss = 0.
84
+ train_loss = 0.
85
+ correct = 0
86
+ total = 0
87
+ interval = args.interval
88
+ start = time.time()
89
+ for idx, (inputs, labels) in enumerate(trainloader):
90
+ # forward
91
+ inputs, labels = inputs.to(device), labels.to(device)
92
+ outputs = net(inputs)
93
+ loss = criterion(outputs, labels)
94
+
95
+ # backward
96
+ optimizer.zero_grad()
97
+ loss.backward()
98
+ optimizer.step()
99
+
100
+ # accumurating
101
+ training_loss += loss.item()
102
+ train_loss += loss.item()
103
+ correct += outputs.max(dim=1)[1].eq(labels).sum().item()
104
+ total += labels.size(0)
105
+
106
+ # print
107
+ if (idx+1) % interval == 0:
108
+ end = time.time()
109
+ print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
110
+ 100.*(idx+1)/len(trainloader), end-start, training_loss /
111
+ interval, correct, total, 100.*correct/total
112
+ ))
113
+ training_loss = 0.
114
+ start = time.time()
115
+
116
+ return train_loss/len(trainloader), 1. - correct/total
117
+
118
+
119
+ def test(epoch):
120
+ global best_acc
121
+ net.eval()
122
+ test_loss = 0.
123
+ correct = 0
124
+ total = 0
125
+ start = time.time()
126
+ with torch.no_grad():
127
+ for idx, (inputs, labels) in enumerate(testloader):
128
+ inputs, labels = inputs.to(device), labels.to(device)
129
+ outputs = net(inputs)
130
+ loss = criterion(outputs, labels)
131
+
132
+ test_loss += loss.item()
133
+ correct += outputs.max(dim=1)[1].eq(labels).sum().item()
134
+ total += labels.size(0)
135
+
136
+ print("Testing ...")
137
+ end = time.time()
138
+ print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
139
+ 100.*(idx+1)/len(testloader), end-start, test_loss /
140
+ len(testloader), correct, total, 100.*correct/total
141
+ ))
142
+
143
+ # saving checkpoint
144
+ acc = 100.*correct/total
145
+ if acc > best_acc:
146
+ best_acc = acc
147
+ print("Saving parameters to checkpoint/ckpt.t7")
148
+ checkpoint = {
149
+ 'net_dict': net.state_dict(),
150
+ 'acc': acc,
151
+ 'epoch': epoch,
152
+ }
153
+ if not os.path.isdir('checkpoint'):
154
+ os.mkdir('checkpoint')
155
+ torch.save(checkpoint, './checkpoint/ckpt.t7')
156
+
157
+ return test_loss/len(testloader), 1. - correct/total
158
+
159
+
160
+ # plot figure
161
+ x_epoch = []
162
+ record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
163
+ fig = plt.figure()
164
+ ax0 = fig.add_subplot(121, title="loss")
165
+ ax1 = fig.add_subplot(122, title="top1err")
166
+
167
+
168
+ def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
169
+ global record
170
+ record['train_loss'].append(train_loss)
171
+ record['train_err'].append(train_err)
172
+ record['test_loss'].append(test_loss)
173
+ record['test_err'].append(test_err)
174
+
175
+ x_epoch.append(epoch)
176
+ ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
177
+ ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
178
+ ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
179
+ ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
180
+ if epoch == 0:
181
+ ax0.legend()
182
+ ax1.legend()
183
+ fig.savefig("train.jpg")
184
+
185
+ # lr decay
186
+
187
+
188
+ def lr_decay():
189
+ global optimizer
190
+ for params in optimizer.param_groups:
191
+ params['lr'] *= 0.1
192
+ lr = params['lr']
193
+ print("Learning rate adjusted to {}".format(lr))
194
+
195
+
196
+ def main():
197
+ for epoch in range(start_epoch, start_epoch+40):
198
+ train_loss, train_err = train(epoch)
199
+ test_loss, test_err = test(epoch)
200
+ draw_curve(epoch, train_loss, train_err, test_loss, test_err)
201
+ if (epoch+1) % 20 == 0:
202
+ lr_decay()
203
+
204
+
205
+ if __name__ == '__main__':
206
+ main()