"""Train a person re-identification classifier on the Market-1501 dataset."""
| import argparse | |
| import os | |
| import time | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| import torchvision | |
| from model import Net | |
# Command-line options for the training run.
parser = argparse.ArgumentParser(description="Train on market1501")
_CLI_OPTIONS = (
    (("--data-dir",), {"default": 'data', "type": str}),
    (("--no-cuda",), {"action": "store_true"}),
    (("--gpu-id",), {"default": 0, "type": int}),
    (("--lr",), {"default": 0.1, "type": float}),
    (("--interval", '-i'), {"default": 20, "type": int}),
    (('--resume', '-r'), {"action": 'store_true'}),
)
for _flags, _kwargs in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_kwargs)
args = parser.parse_args()
# --- device selection ---------------------------------------------------
# Use the requested GPU when CUDA is available and not explicitly disabled;
# otherwise fall back to CPU.
use_cuda = torch.cuda.is_available() and not args.no_cuda
device = "cuda:{}".format(args.gpu_id) if use_cuda else "cpu"
if use_cuda:
    # Let cuDNN auto-tune convolution algorithms for the fixed input size.
    cudnn.benchmark = True

# --- data loading -------------------------------------------------------
root = args.data_dir
train_dir = os.path.join(root, "train")
test_dir = os.path.join(root, "test")

# Augmented pipeline for training; normalization uses the ImageNet
# channel statistics.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop((128, 64), padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Deterministic pipeline for evaluation (resize only, same normalization).
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
    batch_size=64,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
    batch_size=64,
    shuffle=True,
)

# The classifier head must cover every identity present in either split.
num_classes = max(len(trainloader.dataset.classes),
                  len(testloader.dataset.classes))
# --- net definition -----------------------------------------------------
# `best_acc` must be initialised *before* the resume branch: the original
# code set `best_acc = 0.` after loading the checkpoint, which clobbered
# the restored value and caused the first post-resume epoch to overwrite
# the best checkpoint even when it was worse.
start_epoch = 0
best_acc = 0.
net = Net(num_classes=num_classes)
if args.resume:
    assert os.path.isfile(
        "./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
    print('Loading from checkpoint/ckpt.t7')
    checkpoint = torch.load("./checkpoint/ckpt.t7")
    net_dict = checkpoint['net_dict']
    net.load_state_dict(net_dict)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
net.to(device)

# --- loss and optimizer -------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
| # train function for each epoch | |
def train(epoch):
    """Run one training epoch over ``trainloader``.

    Prints a progress line every ``args.interval`` batches and returns
    ``(mean_loss, top1_error)`` for the whole epoch.
    """
    print("\nEpoch : %d" % (epoch+1))
    net.train()
    interval_loss = 0.   # loss accumulated since the last progress report
    epoch_loss = 0.      # loss accumulated over the entire epoch
    n_correct = 0
    n_seen = 0
    report_every = args.interval
    tick = time.time()
    for step, (batch, targets) in enumerate(trainloader):
        # Forward pass.
        batch, targets = batch.to(device), targets.to(device)
        logits = net(batch)
        batch_loss = criterion(logits, targets)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Accumulate running statistics.
        interval_loss += batch_loss.item()
        epoch_loss += batch_loss.item()
        n_correct += logits.max(dim=1)[1].eq(targets).sum().item()
        n_seen += targets.size(0)
        if (step+1) % report_every == 0:
            tock = time.time()
            print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
                100.*(step+1)/len(trainloader), tock-tick,
                interval_loss/report_every, n_correct, n_seen,
                100.*n_correct/n_seen
            ))
            interval_loss = 0.
            tick = time.time()
    return epoch_loss/len(trainloader), 1. - n_correct/n_seen
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint on improvement.

    Updates the module-level ``best_acc`` and saves the model to
    ``./checkpoint/ckpt.t7`` when accuracy beats the best seen so far.
    Returns ``(mean_loss, top1_error)``.
    """
    global best_acc
    net.eval()
    loss_sum = 0.
    n_correct = 0
    n_seen = 0
    tick = time.time()
    with torch.no_grad():
        for step, (batch, targets) in enumerate(testloader):
            batch, targets = batch.to(device), targets.to(device)
            logits = net(batch)
            loss_sum += criterion(logits, targets).item()
            n_correct += logits.max(dim=1)[1].eq(targets).sum().item()
            n_seen += targets.size(0)

        print("Testing ...")
        tock = time.time()
        print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
            100.*(step+1)/len(testloader), tock-tick,
            loss_sum/len(testloader), n_correct, n_seen,
            100.*n_correct/n_seen
        ))

    # Save a checkpoint whenever the accuracy improves on the best so far.
    acc = 100.*n_correct/n_seen
    if acc > best_acc:
        best_acc = acc
        print("Saving parameters to checkpoint/ckpt.t7")
        state = {
            'net_dict': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
    return loss_sum/len(testloader), 1. - n_correct/n_seen
| # plot figure | |
# --- plotting state -----------------------------------------------------
# One figure with two panels: losses on the left, top-1 error on the right.
x_epoch = []
record = {key: [] for key in ('train_loss', 'train_err',
                              'test_loss', 'test_err')}
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
    """Append one epoch's metrics and redraw/save the training curves.

    Records the four metrics, re-plots both panels, and writes the
    figure to ``train.jpg`` after every epoch.
    """
    global record
    record['train_loss'].append(train_loss)
    record['train_err'].append(train_err)
    record['test_loss'].append(test_loss)
    record['test_err'].append(test_err)
    x_epoch.append(epoch)
    ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
    ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
    ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
    ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
    # Add the legend on the first call. The original gated this on
    # `epoch == 0`, which never fires when training resumes from a
    # checkpoint (start_epoch > 0), leaving the plot without a legend.
    if ax0.get_legend() is None:
        ax0.legend()
        ax1.legend()
    fig.savefig("train.jpg")
| # lr decay | |
def lr_decay():
    """Multiply the learning rate of every optimizer param group by 0.1."""
    global optimizer
    for group in optimizer.param_groups:
        group['lr'] *= 0.1
        lr = group['lr']
    # Report the (last group's) adjusted rate once, after the loop.
    print("Learning rate adjusted to {}".format(lr))
def main():
    """Train for 40 epochs from ``start_epoch``, evaluating and plotting
    after each epoch; decay the learning rate every 20 epochs."""
    for epoch in range(start_epoch, start_epoch + 40):
        tr_loss, tr_err = train(epoch)
        te_loss, te_err = test(epoch)
        draw_curve(epoch, tr_loss, tr_err, te_loss, te_err)
        if (epoch + 1) % 20 == 0:
            lr_decay()


if __name__ == '__main__':
    main()