| | import numpy as np |
| | import torch |
| | import torch.optim |
| | import os |
| | import random |
| |
|
| | from methods import backbone |
| | from methods.backbone_multiblock import model_dict |
| | from data.datamgr import SimpleDataManager, SetDataManager |
| | from methods.StyleAdv_RN_GNN import StyleAdvGNN |
| |
|
| | from options import parse_args, get_resume_file, load_warmup_state |
| | from test_function_fwt_benchmark import test_bestmodel |
| | from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl |
| |
|
| |
|
| | def train(base_loader, val_loader, model, start_epoch, stop_epoch, params): |
| |
|
| | |
| | optimizer = torch.optim.Adam(model.parameters()) |
| | if not os.path.isdir(params.checkpoint_dir): |
| | os.makedirs(params.checkpoint_dir) |
| |
|
| | |
| | max_acc = 0 |
| | total_it = 0 |
| |
|
| | |
| | for epoch in range(start_epoch, stop_epoch): |
| | model.train() |
| | total_it = model.train_loop(epoch, base_loader, optimizer, total_it) |
| | model.eval() |
| |
|
| | acc = model.test_loop( val_loader) |
| | if acc > max_acc : |
| | print("best model! save...") |
| | max_acc = acc |
| | outfile = os.path.join(params.checkpoint_dir, 'best_model.tar') |
| | torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile) |
| | else: |
| | print("GG! best accuracy {:f}".format(max_acc)) |
| |
|
| | |
| | if(epoch == stop_epoch-1): |
| | outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch)) |
| | torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile) |
| |
|
| | return model |
| |
|
| |
|
def record_test_result(params):
    """Evaluate on the five FWT benchmark datasets and log to ``acc.txt``.

    Writes a header line to ``<params.checkpoint_dir>/acc.txt`` and then calls
    ``test_bestmodel`` once per dataset, passing it the open file so it can
    append its own results.

    Args:
        params: Namespace providing ``checkpoint_dir``, ``name``, ``n_shot``
            and ``method``.
    """
    acc_file_path = os.path.join(params.checkpoint_dir, 'acc.txt')
    # NOTE(review): -1 is forwarded unchanged to test_bestmodel; presumably it
    # means "use the best checkpoint" -- confirm against test_bestmodel.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method
    datasets = ('miniImagenet', 'cub', 'cars', 'places', 'plantae')

    # The context manager closes the file even if an evaluation call raises.
    with open(acc_file_path, 'w') as acc_file:
        print('epoch', epoch_id, 'miniImagenet:', 'cub:', 'cars:', 'places:', 'plantae:', file=acc_file)
        for dataset in datasets:
            test_bestmodel(acc_file, name, method, dataset, n_shot, epoch_id)
| |
|
| |
|
def record_test_result_bscdfsl(params):
    """Evaluate on the four BSCD-FSL datasets and log to ``acc_bscdfsl.txt``.

    Writes a header line to ``<params.checkpoint_dir>/acc_bscdfsl.txt`` and
    then calls ``test_bestmodel_bscdfsl`` once per dataset, passing it the open
    file so it can append its own results.

    Args:
        params: Namespace providing ``checkpoint_dir``, ``name``, ``n_shot``
            and ``method``.
    """
    print('hhhhhhh testing for bscdfsl')
    acc_file_path = os.path.join(params.checkpoint_dir, 'acc_bscdfsl.txt')
    # NOTE(review): -1 is forwarded unchanged to test_bestmodel_bscdfsl;
    # presumably it means "use the best checkpoint" -- confirm in that helper.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method
    datasets = ('ChestX', 'ISIC', 'EuroSAT', 'CropDisease')

    # The context manager closes the file even if an evaluation call raises.
    with open(acc_file_path, 'w') as acc_file:
        print('epoch', epoch_id, 'ChestX:', 'ISIC:', 'EuroSAT:', 'CropDisease', file=acc_file)
        for dataset in datasets:
            test_bestmodel_bscdfsl(acc_file, name, method, dataset, n_shot, epoch_id)
| |
|
| |
|
| | |
if __name__ == '__main__':
    import time

    # Seed every RNG in use and force deterministic cuDNN kernels so that
    # repeated runs are comparable.
    seed = 0
    print("set seed = %d" % seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Command-line options for the 'train' script.
    params = parse_args('train')

    # Output locations derived from the save directory and experiment name.
    params.tf_dir = '%s/log/%s' % (params.save_dir, params.name)
    params.checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    print('\n--- prepare dataloader ---')
    print(' train with single seen domain {}'.format(params.dataset))
    base_file = os.path.join(params.data_dir, params.dataset, 'base.json')
    val_file = os.path.join(params.data_dir, params.dataset, 'val.json')

    print('\n--- build model ---')
    image_size = 224

    # Scale the 16-query baseline by test_n_way / train_n_way, with a floor
    # of one query.
    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

    train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    model = StyleAdvGNN(model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
    model = model.cuda()

    # Either resume from an earlier checkpoint or initialise the feature
    # encoder from a warm-up (pre-trained) checkpoint.
    # NOTE(review): the `else` is paired with the outer `if params.resume`
    # check (warm-up only when not resuming) -- confirm against the intended
    # control flow.
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume != '':
        resume_file = get_resume_file(
            '%s/checkpoints/%s' % (params.save_dir, params.resume), params.resume_epoch)
        if resume_file is not None:
            # torch.load unpickles the file; only resume from trusted checkpoints.
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
            print(' resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
    else:
        if params.warmup == 'gg3b0':
            raise Exception('Must provide the pre-trained feature encoder file using --warmup option!')
        state = load_warmup_state('%s/checkpoints/%s' % (params.save_dir, params.warmup), params.method)
        model.feature.load_state_dict(state, strict=False)

    start = time.perf_counter()

    print('\n--- start the training ---')
    model = train(base_loader, val_loader, model, start_epoch, stop_epoch, params)

    end = time.perf_counter()
    elapsed = end - start
    # Average over the epochs actually run: after a resume that is
    # stop_epoch - start_epoch, not stop_epoch. The floor of 1 also prevents
    # a ZeroDivisionError when no epoch was executed.
    epochs_run = max(1, stop_epoch - start_epoch)
    print('Running time: %s Seconds: %s Min: %s Min per epoch' % (elapsed, elapsed / 60, elapsed / 60 / epochs_run))
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|