import argparse
import json
import logging
import os

import caffe2.contrib.playground.AnyExp as AnyExp
import caffe2.contrib.playground.checkpoint as checkpoint

logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)


def runShardedTrainLoop(opts, myTrainFun):
    """Run ``myTrainFun`` for every shard of every epoch schedule, in order.

    Epochs advance from the starting epoch up to
    ``opts['epoch_iter']['num_epochs']`` in steps of
    ``opts['epoch_iter']['num_epochs_per_flow_schedule']``. For each step,
    the shards ``0 .. opts['distributed']['num_shards'] - 1`` are run one
    after another in this process; the per-run state is handed to the train
    function through ``opts['temp_var']``.

    Args:
        opts: Nested options dict. Read here: ``model_param.pretrained_model``,
            ``model_param.reset_epoch``, ``epoch_iter.num_epochs``,
            ``epoch_iter.num_epochs_per_flow_schedule`` and
            ``distributed.num_shards``. Written here (under ``temp_var``):
            ``shard_id``, ``pretrained_model``, ``checkpoint_model``,
            ``epoch``, ``start_epoch`` and ``metrics_output``.
        myTrainFun: Callable taking ``opts``. It must return either ``None``
            or a mapping with a ``'metrics'`` entry and, when more than one
            schedule is run, a ``'model'`` entry.

    Returns:
        The first non-``None`` shard result of the last schedule, or ``None``
        if no schedule ran or every shard returned ``None``.

    Raises:
        TypeError: A further schedule is due but the previous one produced no
            shard result, so there is no model to continue from.
    """
    start_epoch = 0
    pretrained_model = opts['model_param']['pretrained_model']
    if pretrained_model != '' and os.path.exists(pretrained_model):
        # Only want to get start_epoch; the other two returned values are
        # not used by this loop.
        start_epoch, _, _ = checkpoint.initialize_params_from_file(
            model=None,
            weights_file=pretrained_model,
            num_xpus=1,
            opts=opts,
            broadcast_computed_param=True,
            reset_epoch=opts['model_param']['reset_epoch'],
        )
    log.info('start epoch: {}'.format(start_epoch))

    # NOTE(review): the configured path is used above only to derive
    # start_epoch. The value handed to the train function for the first
    # schedule is always "" -- the original code computed a None-or-path
    # value and then immediately overwrote it with "". The overwrite is kept
    # so behaviour is unchanged; confirm whether the path was meant to be
    # forwarded instead.
    pretrained_model = ""
    ret = None

    for epoch in range(start_epoch,
                       opts['epoch_iter']['num_epochs'],
                       opts['epoch_iter']['num_epochs_per_flow_schedule']):
        # must support checkpoint or the multiple schedule will always
        # start from initial state
        if epoch == start_epoch:
            checkpoint_model = None
        elif ret is None:
            # Same exception class as the former ``None['model']`` failure,
            # but with a message that says what actually went wrong.
            raise TypeError(
                "no shard returned a result for the previous schedule; "
                "cannot continue at epoch {} without a 'model' to resume "
                "from".format(epoch)
            )
        else:
            checkpoint_model = ret['model']

        # Only the first schedule sees the pretrained-model value.
        if epoch > start_epoch:
            pretrained_model = None

        shard_results = []
        # with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
        for shard_id in range(opts['distributed']['num_shards']):
            # Re-written before every call, so each shard starts from the
            # same temp_var values apart from its own shard_id.
            opts['temp_var']['shard_id'] = shard_id
            opts['temp_var']['pretrained_model'] = pretrained_model
            opts['temp_var']['checkpoint_model'] = checkpoint_model
            opts['temp_var']['epoch'] = epoch
            opts['temp_var']['start_epoch'] = start_epoch
            shard_results.append(myTrainFun(opts))

        # Keep the first shard result that is not None (shards that return
        # None are skipped) and publish its metrics.
        ret = next(
            (result for result in shard_results if result is not None),
            None,
        )
        if ret is not None:
            opts['temp_var']['metrics_output'] = ret['metrics']
        log.info('ret is: {}'.format(str(ret)))

    return ret


def trainFun():
    """Return the default per-shard training function.

    The returned callable takes ``opts``, obtains a trainer class from
    ``AnyExp.createTrainerClass``, passes it through
    ``AnyExp.overrideAdditionalMethods``, instantiates it with ``opts`` and
    returns the result of its ``buildModelAndTrain(opts)`` call.
    """
    def simpleTrainFun(opts):
        trainerClass = AnyExp.createTrainerClass(opts)
        trainerClass = AnyExp.overrideAdditionalMethods(trainerClass, opts)
        trainer = trainerClass(opts)
        return trainer.buildModelAndTrain(opts)
    return simpleTrainFun


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Any Experiment training.')
    # required=True: without the flag, args.params would be None and the
    # lookup below would fail with an unhelpful traceback.
    parser.add_argument("--parameters-json", type=json.loads,
                        help='model options in json format', dest="params",
                        required=True)

    args = parser.parse_args()
    opts = args.params['opts']
    opts = AnyExp.initOpts(opts)
    log.info('opts is: {}'.format(str(opts)))

    AnyExp.initDefaultModuleMap()

    opts['input']['datasets'] = AnyExp.aquireDatasets(opts)

    # The training function is passed in as an argument so that trainFun()
    # can be replaced with some other customized training function.
    ret = runShardedTrainLoop(opts, trainFun())

    log.info('ret is: {}'.format(str(ret)))