| import numpy as np |
| import torch |
| import torch.optim |
| import os |
| import random |
|
|
| from methods import backbone |
| from methods.backbone_multiblock import model_dict |
| from data.datamgr import SimpleDataManager, SetDataManager |
| |
| from methods.CausalStyle_RN_GNN import CausalStyleGNN |
|
|
| from options import parse_args, get_resume_file, load_warmup_state |
| from test_function_fwt_benchmark import test_bestmodel |
| from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl |
|
|
|
|
| def train(base_loader, val_loader, model, start_epoch, stop_epoch, params): |
|
|
| |
| optimizer = torch.optim.Adam(model.parameters()) |
| if not os.path.isdir(params.checkpoint_dir): |
| os.makedirs(params.checkpoint_dir) |
|
|
| |
| max_acc = 0 |
| total_it = 0 |
|
|
| |
| for epoch in range(start_epoch, stop_epoch): |
| model.train() |
| total_it = model.train_loop(epoch, base_loader, optimizer, total_it) |
| model.eval() |
|
|
| acc = model.test_loop( val_loader) |
| if acc > max_acc : |
| print("best model! save...") |
| max_acc = acc |
| outfile = os.path.join(params.checkpoint_dir, 'best_model.tar') |
| torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile) |
| else: |
| print("GG! best accuracy {:f}".format(max_acc)) |
|
|
| |
| if(epoch == stop_epoch-1): |
| outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch)) |
| torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile) |
|
|
| return model |
|
|
|
|
def record_test_result(params):
    """Evaluate the saved best model on the FWT-benchmark target domains.

    Appends one accuracy line per unseen domain (cub, cars, places,
    plantae) to ``<checkpoint_dir>/acc.txt``.

    Args:
        params: parsed options; uses ``checkpoint_dir``, ``name``,
            ``n_shot`` and ``method``.
    """
    acc_file_path = os.path.join(params.checkpoint_dir, 'acc.txt')
    # Context manager so the file is closed/flushed even if a benchmark
    # call raises (the original plain open()/close() leaked on error).
    with open(acc_file_path, 'w') as acc_file:
        epoch_id = -1  # presumably "use best_model.tar" — passed through to the helper; verify there
        # NOTE(review): header also lists miniImagenet, but it is never
        # evaluated below — confirm whether that column is intentional.
        print('epoch', epoch_id, 'miniImagenet:', 'cub:', 'cars:', 'places:', 'plantae:', file=acc_file)
        name = params.name
        n_shot = params.n_shot
        method = params.method
        # Same evaluation helper for every unseen target domain.
        for dataset in ('cub', 'cars', 'places', 'plantae'):
            test_bestmodel(acc_file, name, method, dataset, n_shot, epoch_id)
    return
|
|
|
|
def record_test_result_bscdfsl(params):
    """Evaluate the saved best model on the BSCD-FSL benchmark domains.

    Appends one accuracy line per domain (ChestX, ISIC, EuroSAT,
    CropDisease) to ``<checkpoint_dir>/acc_bscdfsl.txt``.

    Args:
        params: parsed options; uses ``checkpoint_dir``, ``name``,
            ``n_shot`` and ``method``.
    """
    print('hhhhhhh testing for bscdfsl')
    acc_file_path = os.path.join(params.checkpoint_dir, 'acc_bscdfsl.txt')
    # Context manager so the file is closed/flushed even if a benchmark
    # call raises (the original plain open()/close() leaked on error).
    with open(acc_file_path, 'w') as acc_file:
        epoch_id = -1  # presumably "use best_model.tar" — passed through to the helper; verify there
        print('epoch', epoch_id, 'ChestX:', 'ISIC:', 'EuroSAT:', 'CropDisease', file=acc_file)
        name = params.name
        n_shot = params.n_shot
        method = params.method
        # Same evaluation helper for every BSCD-FSL target domain.
        for dataset in ('ChestX', 'ISIC', 'EuroSAT', 'CropDisease'):
            test_bestmodel_bscdfsl(acc_file, name, method, dataset, n_shot, epoch_id)
    return
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Script entry point: meta-train CausalStyleGNN on one seen source domain,
# then evaluate the best checkpoint on the FWT and BSCD-FSL benchmarks.
# ---------------------------------------------------------------------------
if __name__=='__main__':

    # Fix every RNG (Python, NumPy, Torch CPU and all GPUs) and force
    # deterministic cuDNN kernels so repeated runs are reproducible.
    seed = 0
    print("set seed = %d" % seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Command-line options for the 'train' phase.
    params = parse_args('train')

    # Derived output locations: tensorboard-style logs and model checkpoints.
    params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
    params.checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Episode list files for the single seen (source) domain.
    print('\n--- prepare dataloader ---')
    print(' train with single seen domain {}'.format(params.dataset))
    base_file = os.path.join(params.data_dir, params.dataset, 'base.json')
    val_file = os.path.join(params.data_dir, params.dataset, 'val.json')

    print('\n--- build model ---')
    image_size = 224

    # Queries per class in an episode: 16 scaled by the test/train way
    # ratio, floored at 1.
    n_query = max(1, int(16* params.test_n_way/params.train_n_way))

    # Episodic (N-way K-shot) loaders for meta-training and validation;
    # augmentation only on the training episodes (and only if requested).
    train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot)
    base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

    test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot)
    val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader( val_file, aug = False)

    model = CausalStyleGNN( model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
    model = model.cuda()

    # Either resume a full training run from a saved checkpoint, or
    # warm-start only the feature encoder from a pre-trained backbone
    # (--warmup is then mandatory).
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume != '':
        resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch']+1
            model.load_state_dict(tmp['state'])
            print(' resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
    else:
        # 'gg3b0' is compared as the "no warmup file supplied" default here;
        # presumably set in options.py — confirm there.
        if params.warmup == 'gg3b0':
            raise Exception('Must provide the pre-trained feature encoder file using --warmup option!')
        state = load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method)
        # strict=False: the warmup checkpoint may cover only part of the encoder.
        model.feature.load_state_dict(state, strict=False)

    import time

    # Wall-clock timing around the whole training loop.
    start =time.perf_counter()

    print('\n--- start the training ---')
    model = train(base_loader, val_loader, model, start_epoch, stop_epoch, params)

    end =time.perf_counter()
    # NOTE(review): per-epoch time divides by stop_epoch, not the number of
    # epochs actually run (stop_epoch - start_epoch) — misleading on resume.
    print('Running time: %s Seconds: %s Min: %s Min per epoch'%(end-start, (end-start)/60, (end-start)/60/params.stop_epoch))

    # Evaluate the best checkpoint on both benchmark suites.
    record_test_result(params)

    record_test_result_bscdfsl(params)
|
|
|
|