File size: 3,788 Bytes
092fe0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import torch
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import time
import copy
import sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score
from scripts.model import train_model
import scripts.dataset as DATA
import scripts.config as config
import argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def main(args):
    """Build model, data pipeline and optimizer, then launch training.

    Args:
        args: argparse.Namespace with the fields declared in the
            ``__main__`` block (modelname, imagepath, architecture,
            train/valid/test fold names, learning_rate, batch_size,
            num_epochs, Class, modelload_path).

    Raises:
        ValueError: if ``args.architecture`` is not a supported backbone.
    """
    modelname = args.modelname
    imagepath = args.imagepath
    # config.THYROID() supplies the label column count (and subgroup count,
    # unused here) consumed by the dataset class.
    label_num, subgroup_num = config.THYROID()
    Datasets = DATA.Thyroid_Datasets

    if args.architecture == 'resnet':
        # ImageNet-pretrained backbone; swap the final FC for an
        # args.Class-way classification head (raw logits, no activation).
        net = models.resnet18(pretrained=True)
        features = net.fc.in_features
        net.fc = nn.Sequential(
            nn.Linear(features, args.Class))
    else:
        # BUG FIX: previously any other --architecture choice (the parser
        # advertises "densnet" and "efficientnet") left `net` unbound and
        # crashed later with an unrelated NameError. Fail fast instead.
        raise ValueError(
            "unsupported architecture: %r" % args.architecture)

    # exist_ok=True replaces the racy exists()-then-makedirs pattern.
    os.makedirs('./modelsaved/%s' % modelname, exist_ok=True)
    os.makedirs('./result/%s' % modelname, exist_ok=True)

    data_transforms = config.Transforms(modelname)

    print("%s Initializing Datasets and Dataloaders..." % modelname)

    # Test fold reuses the (non-augmenting) validation transform.
    transformed_datasets = {}
    transformed_datasets['train'] = Datasets(
        path_to_images=imagepath,
        fold=args.train_data,
        PRED_LABEL=label_num,
        transform=data_transforms['train'])
    transformed_datasets['valid'] = Datasets(
        path_to_images=imagepath,
        fold=args.valid_data,
        PRED_LABEL=label_num,
        transform=data_transforms['valid'])
    transformed_datasets['test'] = Datasets(
        path_to_images=imagepath,
        fold=args.test_data,
        PRED_LABEL=label_num,
        transform=data_transforms['valid'])

    # Only training shuffles; valid/test use batch_size=1 for per-sample
    # evaluation downstream.
    dataloaders = {}
    dataloaders['train'] = torch.utils.data.DataLoader(
        transformed_datasets['train'],
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=24)
    dataloaders['valid'] = torch.utils.data.DataLoader(
        transformed_datasets['valid'],
        batch_size=1,
        shuffle=False,
        num_workers=24)
    dataloaders['test'] = torch.utils.data.DataLoader(
        transformed_datasets['test'],
        batch_size=1,
        shuffle=False,
        num_workers=24)

    if args.modelload_path:
        # map_location keeps CPU-only machines working with GPU-saved
        # checkpoints; strict=False tolerates the replaced FC head.
        net.load_state_dict(
            torch.load(args.modelload_path,
                       map_location=lambda storage, loc: storage),
            strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = net.to(device)
    # Only parameters with requires_grad are optimized (relevant when a
    # loaded checkpoint froze part of the backbone).
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=args.learning_rate, betas=(0.9, 0.99), weight_decay=0.03)

    if args.Class > 2:
        # Multi-label branch.
        # NOTE(review): the head emits raw logits with no sigmoid; plain
        # BCELoss requires inputs in [0, 1]. Confirm train_model applies
        # a sigmoid before the loss, otherwise this should be
        # nn.BCEWithLogitsLoss().
        criterion = nn.BCELoss()
    else:
        criterion = nn.CrossEntropyLoss()

    train_model(net, dataloaders, criterion, optimizer,
                args.num_epochs, modelname, device)



if __name__ == '__main__':
    def _build_parser():
        """Declare the training script's command-line interface."""
        parser = argparse.ArgumentParser()
        # Model / checkpoint options.
        parser.add_argument("--modelname", type=str, default="Thyroid")
        parser.add_argument("--architecture", type=str,
                            choices=["resnet", "densnet", "efficientnet"],
                            default="resnet")
        parser.add_argument("--modelload_path", type=str, default=None)
        # Data locations and fold names.
        parser.add_argument("--imagepath", type=str, default="./dataset/")
        parser.add_argument("--train_data", type=str, default='thyroid_train')
        parser.add_argument("--valid_data", type=str, default='thyroid_valid')
        parser.add_argument("--test_data", type=str, default='thyroid_test')
        # Optimization hyper-parameters.
        parser.add_argument("--learning_rate", type=float, default=0.00005)
        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--num_epochs", type=int, default=100)
        parser.add_argument("--Class", type=int, default=2)
        return parser

    main(_build_parser().parse_args())