"""
The main entry point for training policies.

Args:
    config (str): path to a config json that will be used to override the default settings.
        If omitted, default settings are used. This is the preferred way to run experiments.

    algo (str): name of the algorithm to run. Only needs to be provided if @config is not provided.

    name (str): if provided, override the experiment name defined in the config

    dataset (str): if provided, override the dataset path defined in the config

    output (str): if provided, override the output folder path defined in the config

    auto-remove-exp (bool): force delete the experiment folder if it exists

    debug (bool): set this flag to run a quick training run for debugging purposes
"""

import argparse
import json
import numpy as np
import time
import os
import shutil
import psutil
import sys
import socket
import traceback

from collections import OrderedDict

import torch
from torch.utils.data import DataLoader

import robomimic
import robomimic.macros as Macros
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings


def train(config, device, auto_remove_exp=False):
    """
    Train a model using the algorithm.

    Sets seeds, creates the experiment directories, loads dataset / environment metadata,
    optionally creates rollout environments, builds the model, and then runs the main
    training loop (training epoch, optional validation epoch, optional rollouts,
    checkpoint / video saving, optional syncing of results to Macros.RESULTS_SYNC_PATH_ABS).

    Args:
        config (Config): the (locked) experiment config
        device: torch device passed to algo_factory for model placement
        auto_remove_exp (bool): forwarded to TrainUtils.get_exp_dir as @auto_remove_exp_dir

    Returns:
        important_stats (dict): max / mean rollout success rate (and exception rate) per
            environment, plus "time spent (hrs)". Also written to
            <log_dir>/important_stats.json.
    """

    # time this run
    start_time = time.time()

    # first set seeds
    np.random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)

    torch.set_num_threads(2)

    print("\n============= New Training Run with Config =============")
    print(config)
    print("")
    log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config, auto_remove_exp_dir=auto_remove_exp)

    if config.experiment.logging.terminal_output_to_txt:
        # log stdout and stderr to a text file
        logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
        sys.stdout = logger
        sys.stderr = logger

    # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
    ObsUtils.initialize_obs_utils_with_config(config)

    # make sure the dataset exists (when several datasets are configured, the first one
    # is used to read env / shape metadata)
    if isinstance(config.train.data, str):
        dataset_path = os.path.expandvars(os.path.expanduser(config.train.data))
    else:
        eval_dataset_cfg = config.train.data[0]
        dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
    ds_format = config.train.data_format
    if not os.path.exists(dataset_path):
        raise Exception("Dataset at provided path {} not found!".format(dataset_path))

    # load basic metadata from training file
    print("\n============= Loaded Environment Metadata =============")
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format)

    # update env meta if applicable
    from robomimic.utils.script_utils import deep_update
    deep_update(env_meta, config.experiment.env_meta_update_dict)

    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=dataset_path,
        action_keys=config.train.action_keys,
        all_obs_keys=config.all_obs_keys,
        ds_format=ds_format,
        verbose=True
    )

    if config.experiment.env is not None:
        env_meta["env_name"] = config.experiment.env
        print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)

    # create environment
    envs = OrderedDict()
    if config.experiment.rollout.enabled:
        # create environments for validation runs
        env_names = [env_meta["env_name"]]

        if config.experiment.additional_envs is not None:
            for name in config.experiment.additional_envs:
                env_names.append(name)

        for env_name in env_names:
            env = EnvUtils.create_env_from_metadata(
                env_meta=env_meta,
                env_name=env_name,
                render=config.experiment.render,
                render_offscreen=config.experiment.render_video,
                use_image_obs=shape_meta["use_images"],
                use_depth_obs=shape_meta["use_depths"],
            )
            env = EnvUtils.wrap_env_from_config(env, config=config)  # apply environment wrapper, if applicable
            envs[env.name] = env
            print(envs[env.name])

    print("")

    # setup for a new training run
    data_logger = DataLogger(
        log_dir,
        config,
        log_tb=config.experiment.logging.log_tb,
        log_wandb=config.experiment.logging.log_wandb,
    )
    model = algo_factory(
        algo_name=config.algo_name,
        config=config,
        obs_key_shapes=shape_meta["all_shapes"],
        ac_dim=shape_meta["ac_dim"],
        device=device,
    )

    # save the config as a json file
    with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    print("\n============= Model Summary =============")
    print(model)  # print model summary
    print("")

    # load training data
    trainset, validset = TrainUtils.load_data_for_training(
        config, obs_keys=shape_meta["all_obs_keys"])
    train_sampler = trainset.get_dataset_sampler()
    print("\n============= Training Dataset =============")
    print(trainset)
    print("")
    if validset is not None:
        print("\n============= Validation Dataset =============")
        print(validset)
        print("")

    # maybe retrieve statistics for normalizing observations
    obs_normalization_stats = None
    if config.train.hdf5_normalize_obs:
        obs_normalization_stats = trainset.get_obs_normalization_stats()

    # maybe retrieve statistics for normalizing actions
    action_normalization_stats = trainset.get_action_normalization_stats()

    # initialize data loaders (DataLoader forbids shuffle=True together with a sampler)
    train_loader = DataLoader(
        dataset=trainset,
        sampler=train_sampler,
        batch_size=config.train.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.train.num_data_workers,
        drop_last=True
    )

    if config.experiment.validate:
        # cap num workers for validation dataset at 1
        num_workers = min(config.train.num_data_workers, 1)
        valid_sampler = validset.get_dataset_sampler()
        valid_loader = DataLoader(
            dataset=validset,
            sampler=valid_sampler,
            batch_size=config.train.batch_size,
            shuffle=(valid_sampler is None),
            num_workers=num_workers,
            drop_last=True
        )
    else:
        valid_loader = None

    # print all warnings before training begins
    print("*" * 50)
    print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
    flush_warnings()
    print("*" * 50)
    print("")

    # main training loop
    best_valid_loss = None
    best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
    best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
    last_ckpt_time = time.time()

    need_sync_results = (Macros.RESULTS_SYNC_PATH_ABS is not None)
    if need_sync_results:
        # these paths will be updated after each evaluation
        best_ckpt_path_synced = None
        best_video_path_synced = None
        last_ckpt_path_synced = None
        last_video_path_synced = None
        log_dir_path_synced = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "logs")

    # number of learning steps per epoch (defaults to a full dataset pass)
    train_num_steps = config.experiment.epoch_every_n_steps
    valid_num_steps = config.experiment.validation_epoch_every_n_steps

    for epoch in range(1, config.train.num_epochs + 1):  # epoch numbers start at 1
        step_log = TrainUtils.run_epoch(
            model=model,
            data_loader=train_loader,
            epoch=epoch,
            num_steps=train_num_steps,
            obs_normalization_stats=obs_normalization_stats,
        )
        model.on_epoch_end(epoch)

        # setup checkpoint path
        epoch_ckpt_name = "model_epoch_{}".format(epoch)

        # check for recurring checkpoint saving conditions
        should_save_ckpt = False
        if config.experiment.save.enabled:
            time_check = (config.experiment.save.every_n_seconds is not None) and \
                (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
            epoch_check = (config.experiment.save.every_n_epochs is not None) and \
                (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
            epoch_list_check = (epoch in config.experiment.save.epochs)
            should_save_ckpt = (time_check or epoch_check or epoch_list_check)
        ckpt_reason = None
        if should_save_ckpt:
            last_ckpt_time = time.time()
            # any recurring save is tagged "time" so that it also triggers rollouts below
            ckpt_reason = "time"

        print("Train Epoch {}".format(epoch))
        print(json.dumps(step_log, sort_keys=True, indent=4))
        for k, v in step_log.items():
            if k.startswith("Time_"):
                data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch)
            else:
                data_logger.record("Train/{}".format(k), v, epoch)

        # Evaluate the model on validation set
        if config.experiment.validate:
            with torch.no_grad():
                step_log = TrainUtils.run_epoch(
                    model=model,
                    data_loader=valid_loader,
                    epoch=epoch,
                    validate=True,
                    num_steps=valid_num_steps,
                    # normalize validation inputs with the same (training-set) statistics
                    # used during training, otherwise the validation loss is computed on
                    # differently-scaled observations
                    obs_normalization_stats=obs_normalization_stats,
                )
            for k, v in step_log.items():
                if k.startswith("Time_"):
                    data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch)
                else:
                    data_logger.record("Valid/{}".format(k), v, epoch)

            print("Validation Epoch {}".format(epoch))
            print(json.dumps(step_log, sort_keys=True, indent=4))

            # save checkpoint if achieve new best validation loss
            valid_check = "Loss" in step_log
            if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
                best_valid_loss = step_log["Loss"]
                if config.experiment.save.enabled and config.experiment.save.on_best_validation:
                    epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
                    should_save_ckpt = True
                    ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason

        # Evaluate the model by running rollouts

        # do rollouts at fixed rate or if it's time to save a new ckpt
        video_paths = None
        rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
        did_rollouts = False
        if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:

            # wrap model as a RolloutPolicy to prepare for rollouts
            rollout_model = RolloutPolicy(
                model,
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

            num_episodes = config.experiment.rollout.n
            all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
                policy=rollout_model,
                envs=envs,
                horizon=config.experiment.rollout.horizon,
                use_goals=config.use_goals,
                num_episodes=num_episodes,
                render=False,
                video_dir=video_dir if config.experiment.render_video else None,
                epoch=epoch,
                video_skip=config.experiment.get("video_skip", 5),
                terminate_on_success=config.experiment.rollout.terminate_on_success,
            )

            # summarize results from rollouts to tensorboard and terminal
            for env_name in all_rollout_logs:
                rollout_logs = all_rollout_logs[env_name]
                for k, v in rollout_logs.items():
                    if k.startswith("Time_"):
                        data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
                    else:
                        data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)

                print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
                print('Env: {}'.format(env_name))
                print(json.dumps(rollout_logs, sort_keys=True, indent=4))

            # checkpoint and video saving logic
            updated_stats = TrainUtils.should_save_from_rollout_logs(
                all_rollout_logs=all_rollout_logs,
                best_return=best_return,
                best_success_rate=best_success_rate,
                epoch_ckpt_name=epoch_ckpt_name,
                save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
                save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
            )
            best_return = updated_stats["best_return"]
            best_success_rate = updated_stats["best_success_rate"]
            epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
            should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
            if updated_stats["ckpt_reason"] is not None:
                ckpt_reason = updated_stats["ckpt_reason"]

            did_rollouts = True

        # Only keep saved videos if the ckpt should be saved (but not because of validation score)
        should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
        if video_paths is not None and not should_save_video:
            for env_name in video_paths:
                os.remove(video_paths[env_name])

        # Save model checkpoints based on conditions (success rate, validation loss, etc)
        if should_save_ckpt:
            TrainUtils.save_model(
                model=model,
                config=config,
                env_meta=env_meta,
                shape_meta=shape_meta,
                ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

        # maybe sync some results back to scratch space (only if rollouts happened)
        if did_rollouts and need_sync_results:
            print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))

            # get best and latest model checkpoints and videos
            best_ckpt_path_to_sync, best_video_path_to_sync, best_epoch_to_sync = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                best=True,
            )
            last_ckpt_path_to_sync, last_video_path_to_sync, last_epoch_to_sync = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                last=True,
            )

            # clear last files that we synced over
            if best_ckpt_path_synced is not None:
                os.remove(best_ckpt_path_synced)
            if last_ckpt_path_synced is not None:
                os.remove(last_ckpt_path_synced)
            if best_video_path_synced is not None:
                os.remove(best_video_path_synced)
            if last_video_path_synced is not None:
                os.remove(last_video_path_synced)
            if os.path.exists(log_dir_path_synced):
                shutil.rmtree(log_dir_path_synced)

            # set write paths and sync new files over
            # NOTE(review): assumes the best checkpoint filename ends in "success_<rate>.pth";
            # this parse fails if the best checkpoint was saved for another reason - confirm
            best_success_rate_for_sync = float(best_ckpt_path_to_sync.split("success_")[-1][:-4])
            best_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(best_ckpt_path_to_sync)[:-4] + "_best.pth",
            )
            shutil.copyfile(best_ckpt_path_to_sync, best_ckpt_path_synced)
            last_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(last_ckpt_path_to_sync)[:-4] + "_last.pth",
            )
            shutil.copyfile(last_ckpt_path_to_sync, last_ckpt_path_synced)
            if config.experiment.render_video:
                best_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(best_video_path_to_sync)[:-4] + "_best_{}.mp4".format(best_success_rate_for_sync),
                )
                shutil.copyfile(best_video_path_to_sync, best_video_path_synced)
                last_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(last_video_path_to_sync)[:-4] + "_last.mp4",
                )
                shutil.copyfile(last_video_path_to_sync, last_video_path_synced)
            # sync logs dir
            shutil.copytree(log_dir, log_dir_path_synced)
            # sync config json
            shutil.copyfile(
                os.path.join(log_dir, '..', 'config.json'),
                os.path.join(Macros.RESULTS_SYNC_PATH_ABS, 'config.json')
            )

        # Finally, log memory usage in MB
        process = psutil.Process(os.getpid())
        mem_usage = int(process.memory_info().rss / 1000000)
        data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
        print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))

    # terminate logging
    data_logger.close()

    # sync logs after closing data logger to make sure everything was transferred
    if need_sync_results:
        print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
        # sync logs dir
        if os.path.exists(log_dir_path_synced):
            shutil.rmtree(log_dir_path_synced)
        shutil.copytree(log_dir, log_dir_path_synced)

    # collect important statistics
    important_stats = dict()
    prefix = "Rollout/Success_Rate/"
    exception_prefix = "Rollout/Exception_Rate/"
    for k in data_logger._data:
        if k.startswith(prefix):
            suffix = k[len(prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-max".format(suffix)] = stats["max"]
            important_stats["{}-mean".format(suffix)] = stats["mean"]
        elif k.startswith(exception_prefix):
            suffix = k[len(exception_prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-exception-rate-max".format(suffix)] = stats["max"]
            important_stats["{}-exception-rate-mean".format(suffix)] = stats["mean"]

    # add in time taken
    important_stats["time spent (hrs)"] = "{:.2f}".format((time.time() - start_time) / 3600.)

    # write stats to disk
    json_file_path = os.path.join(log_dir, "important_stats.json")
    with open(json_file_path, 'w') as f:
        # preserve original key ordering
        json.dump(important_stats, f, sort_keys=False, indent=4)

    return important_stats


def main(args):
    """
    Build the experiment config from command-line arguments, run training, and report results.

    Any exception raised by train() is caught and reported (stdout and, if configured,
    Slack) instead of propagating.

    Args:
        args (argparse.Namespace): parsed command-line arguments (see the parser below)
    """

    if args.config is not None:
        with open(args.config, 'r') as f:
            ext_cfg = json.load(f)
        config = config_factory(ext_cfg["algo_name"])
        # update config with external json - this will throw errors if
        # the external config has keys not present in the base algo config
        with config.values_unlocked():
            config.update(ext_cfg)
    else:
        config = config_factory(args.algo)

    if args.dataset is not None:
        config.train.data = [dict(path=args.dataset)]

    if args.name is not None:
        config.experiment.name = args.name

    if args.output is not None:
        config.train.output_dir = args.output

    # get torch device
    device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)

    # maybe modify config for debugging purposes
    if args.debug:
        Macros.DEBUG = True

        # shrink length of training to test whether this run is likely to crash
        config.unlock()
        config.lock_keys()

        # train and validate (if enabled) for 3 gradient steps, for 2 epochs
        config.experiment.epoch_every_n_steps = 3
        config.experiment.validation_epoch_every_n_steps = 3
        config.train.num_epochs = 2

        # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps
        config.experiment.rollout.rate = 1
        config.experiment.rollout.n = 2
        config.experiment.rollout.horizon = 10

        # send output to a temporary directory
        config.train.output_dir = "/tmp/tmp_trained_models"

    # lock config to prevent further modifications and ensure missing keys raise errors
    config.lock()

    # catch error during training and print it
    res_str = "finished run successfully!"
    important_stats = None
    important_stats_str = None
    try:
        important_stats = train(config, device=device, auto_remove_exp=args.auto_remove_exp)
    except Exception as e:
        res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
    print(res_str)
    if important_stats is not None:
        # keep the dict intact so the synced json file contains an object, not an
        # encoded string; use a separate string for printing / notifications
        important_stats_str = json.dumps(important_stats, indent=4)
        print("\nRollout Success Rate Stats")
        print(important_stats_str)

        # maybe sync important stats back
        if Macros.RESULTS_SYNC_PATH_ABS is not None:
            json_file_path = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "important_stats.json")
            with open(json_file_path, 'w') as f:
                # preserve original key ordering
                json.dump(important_stats, f, sort_keys=False, indent=4)

    # maybe give slack notification
    if Macros.SLACK_TOKEN is not None:
        from robomimic.scripts.give_slack_notification import give_slack_notif
        msg = "Completed the following training run!\nHostname: {}\nExperiment Name: {}\n".format(socket.gethostname(), config.experiment.name)
        msg += "```{}```".format(res_str)
        if important_stats_str is not None:
            msg += "\nRollout Success Rate Stats"
            msg += "\n```{}```".format(important_stats_str)
        give_slack_notif(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # External config file that overwrites default config
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="(optional) path to a config json that will be used to override the default settings. "
             "If omitted, default settings are used. This is the preferred way to run experiments.",
    )

    # Algorithm Name
    parser.add_argument(
        "--algo",
        type=str,
        help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided",
    )

    # Experiment Name (for tensorboard, saving models, etc.)
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="(optional) if provided, override the experiment name defined in the config",
    )

    # Dataset path, to override the one in the config
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="(optional) if provided, override the dataset path defined in the config",
    )

    # Output path, to override the one in the config
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="(optional) if provided, override the output folder path defined in the config",
    )

    # force delete the experiment folder if it exists
    parser.add_argument(
        "--auto-remove-exp",
        action='store_true',
        help="force delete the experiment folder if it exists"
    )

    # debug mode
    parser.add_argument(
        "--debug",
        action='store_true',
        help="set this flag to run a quick training run for debugging purposes"
    )

    args = parser.parse_args()
    main(args)