"""
The main entry point for training policies.

Args:
    config (str): path to a config json that will be used to override the default settings.
        If omitted, default settings are used. This is the preferred way to run experiments.

    algo (str): name of the algorithm to run. Only needs to be provided if @config is not
        provided.

    name (str): if provided, override the experiment name defined in the config

    dataset (str): if provided, override the dataset path defined in the config

    debug (bool): set this flag to run a quick training run for debugging purposes
"""
| import argparse | |
| import json | |
| import numpy as np | |
| import time | |
| import os | |
| import shutil | |
| import psutil | |
| import sys | |
| import socket | |
| import traceback | |
| from collections import OrderedDict | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import robomimic | |
| import robomimic.macros as Macros | |
| import robomimic.utils.train_utils as TrainUtils | |
| import robomimic.utils.torch_utils as TorchUtils | |
| import robomimic.utils.obs_utils as ObsUtils | |
| import robomimic.utils.env_utils as EnvUtils | |
| import robomimic.utils.file_utils as FileUtils | |
| from robomimic.config import config_factory | |
| from robomimic.algo import algo_factory, RolloutPolicy | |
| from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings | |
def train(config, device, auto_remove_exp=False):
    """
    Train a model using the algorithm.

    Args:
        config (Config): complete training config (experiment / train / algo sections)
        device (torch.device): device to construct the model on
        auto_remove_exp (bool): forwarded to TrainUtils.get_exp_dir as
            @auto_remove_exp_dir to control removal of a pre-existing
            experiment directory with the same name

    Returns:
        important_stats (dict): rollout success-rate / exception-rate summary
            statistics plus total wall-clock time; also written to
            important_stats.json in the experiment log directory
    """
    # time this run
    start_time = time.time()

    # first set seeds
    np.random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    # cap torch intra-op CPU threads at 2
    torch.set_num_threads(2)

    print("\n============= New Training Run with Config =============")
    print(config)
    print("")
    # set up experiment output folders (logs, model checkpoints, rollout videos)
    log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config, auto_remove_exp_dir=auto_remove_exp)

    if config.experiment.logging.terminal_output_to_txt:
        # log stdout and stderr to a text file
        logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
        sys.stdout = logger
        sys.stderr = logger

    # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
    ObsUtils.initialize_obs_utils_with_config(config)

    # make sure the dataset exists
    if isinstance(config.train.data, str):
        # single dataset path
        dataset_path = os.path.expandvars(os.path.expanduser(config.train.data))
    else:
        # list of dataset configs - metadata below is read from the first entry's path
        eval_dataset_cfg = config.train.data[0]
        dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
    ds_format = config.train.data_format
    if not os.path.exists(dataset_path):
        raise Exception("Dataset at provided path {} not found!".format(dataset_path))

    # load basic metadata from training file
    print("\n============= Loaded Environment Metadata =============")
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format)

    # update env meta if applicable
    from robomimic.utils.script_utils import deep_update
    deep_update(env_meta, config.experiment.env_meta_update_dict)

    # observation / action shape metadata needed to construct the model
    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=dataset_path,
        action_keys=config.train.action_keys,
        all_obs_keys=config.all_obs_keys,
        ds_format=ds_format,
        verbose=True
    )

    if config.experiment.env is not None:
        # config may override the env name recorded in the dataset
        env_meta["env_name"] = config.experiment.env
        print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)

    # create environment
    envs = OrderedDict()
    if config.experiment.rollout.enabled:
        # create environments for validation runs
        env_names = [env_meta["env_name"]]

        if config.experiment.additional_envs is not None:
            for name in config.experiment.additional_envs:
                env_names.append(name)

        for env_name in env_names:
            env = EnvUtils.create_env_from_metadata(
                env_meta=env_meta,
                env_name=env_name,
                render=config.experiment.render,
                render_offscreen=config.experiment.render_video,
                use_image_obs=shape_meta["use_images"],
                use_depth_obs=shape_meta["use_depths"],
            )
            env = EnvUtils.wrap_env_from_config(env, config=config)  # apply environment wrapper, if applicable
            envs[env.name] = env
            print(envs[env.name])

    print("")

    # setup for a new training run
    data_logger = DataLogger(
        log_dir,
        config,
        log_tb=config.experiment.logging.log_tb,
        log_wandb=config.experiment.logging.log_wandb,
    )
    model = algo_factory(
        algo_name=config.algo_name,
        config=config,
        obs_key_shapes=shape_meta["all_shapes"],
        ac_dim=shape_meta["ac_dim"],
        device=device,
    )

    # save the config as a json file (one level above log_dir, in the experiment folder)
    with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    print("\n============= Model Summary =============")
    print(model)  # print model summary
    print("")

    # load training data
    trainset, validset = TrainUtils.load_data_for_training(
        config, obs_keys=shape_meta["all_obs_keys"])
    train_sampler = trainset.get_dataset_sampler()
    print("\n============= Training Dataset =============")
    print(trainset)
    print("")
    if validset is not None:
        print("\n============= Validation Dataset =============")
        print(validset)
        print("")

    # maybe retrieve statistics for normalizing observations
    obs_normalization_stats = None
    if config.train.hdf5_normalize_obs:
        obs_normalization_stats = trainset.get_obs_normalization_stats()

    # maybe retrieve statistics for normalizing actions
    action_normalization_stats = trainset.get_action_normalization_stats()

    # initialize data loaders
    train_loader = DataLoader(
        dataset=trainset,
        sampler=train_sampler,
        batch_size=config.train.batch_size,
        shuffle=(train_sampler is None),  # a custom sampler takes precedence over shuffling
        num_workers=config.train.num_data_workers,
        drop_last=True
    )

    if config.experiment.validate:
        # cap num workers for validation dataset at 1
        num_workers = min(config.train.num_data_workers, 1)
        valid_sampler = validset.get_dataset_sampler()
        valid_loader = DataLoader(
            dataset=validset,
            sampler=valid_sampler,
            batch_size=config.train.batch_size,
            shuffle=(valid_sampler is None),
            num_workers=num_workers,
            drop_last=True
        )
    else:
        valid_loader = None

    # print all warnings before training begins
    print("*" * 50)
    print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
    flush_warnings()
    print("*" * 50)
    print("")

    # main training loop
    best_valid_loss = None
    # per-env best stats, only tracked when rollouts are enabled
    best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
    best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
    last_ckpt_time = time.time()

    # optionally mirror results to an external sync location (e.g. scratch space)
    need_sync_results = (Macros.RESULTS_SYNC_PATH_ABS is not None)
    if need_sync_results:
        # these paths will be updated after each evaluation
        best_ckpt_path_synced = None
        best_video_path_synced = None
        last_ckpt_path_synced = None
        last_video_path_synced = None
        log_dir_path_synced = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "logs")

    # number of learning steps per epoch (defaults to a full dataset pass)
    train_num_steps = config.experiment.epoch_every_n_steps
    valid_num_steps = config.experiment.validation_epoch_every_n_steps

    for epoch in range(1, config.train.num_epochs + 1):  # epoch numbers start at 1
        # run one training epoch and collect scalar logs
        step_log = TrainUtils.run_epoch(
            model=model,
            data_loader=train_loader,
            epoch=epoch,
            num_steps=train_num_steps,
            obs_normalization_stats=obs_normalization_stats,
        )
        model.on_epoch_end(epoch)

        # setup checkpoint path
        epoch_ckpt_name = "model_epoch_{}".format(epoch)

        # check for recurring checkpoint saving conditions
        should_save_ckpt = False
        if config.experiment.save.enabled:
            time_check = (config.experiment.save.every_n_seconds is not None) and \
                (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
            epoch_check = (config.experiment.save.every_n_epochs is not None) and \
                (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
            epoch_list_check = (epoch in config.experiment.save.epochs)
            should_save_ckpt = (time_check or epoch_check or epoch_list_check)
        ckpt_reason = None
        if should_save_ckpt:
            last_ckpt_time = time.time()
            # NOTE(review): reason is labeled "time" even for epoch-based saves — confirm intended
            ckpt_reason = "time"

        print("Train Epoch {}".format(epoch))
        print(json.dumps(step_log, sort_keys=True, indent=4))
        for k, v in step_log.items():
            if k.startswith("Time_"):
                data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch)
            else:
                data_logger.record("Train/{}".format(k), v, epoch)

        # Evaluate the model on validation set
        if config.experiment.validate:
            with torch.no_grad():
                step_log = TrainUtils.run_epoch(model=model, data_loader=valid_loader, epoch=epoch, validate=True, num_steps=valid_num_steps)
            for k, v in step_log.items():
                if k.startswith("Time_"):
                    data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch)
                else:
                    data_logger.record("Valid/{}".format(k), v, epoch)

            print("Validation Epoch {}".format(epoch))
            print(json.dumps(step_log, sort_keys=True, indent=4))

            # save checkpoint if achieve new best validation loss
            valid_check = "Loss" in step_log
            if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
                best_valid_loss = step_log["Loss"]
                if config.experiment.save.enabled and config.experiment.save.on_best_validation:
                    epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
                    should_save_ckpt = True
                    # keep an earlier (time/epoch-based) reason if one was already set
                    ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason

        # Evaluate the model by running rollouts
        # do rollouts at fixed rate or if it's time to save a new ckpt
        video_paths = None
        rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
        did_rollouts = False
        if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:
            # wrap model as a RolloutPolicy to prepare for rollouts
            rollout_model = RolloutPolicy(
                model,
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

            num_episodes = config.experiment.rollout.n
            all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
                policy=rollout_model,
                envs=envs,
                horizon=config.experiment.rollout.horizon,
                use_goals=config.use_goals,
                num_episodes=num_episodes,
                render=False,
                video_dir=video_dir if config.experiment.render_video else None,
                epoch=epoch,
                video_skip=config.experiment.get("video_skip", 5),
                terminate_on_success=config.experiment.rollout.terminate_on_success,
            )

            # summarize results from rollouts to tensorboard and terminal
            for env_name in all_rollout_logs:
                rollout_logs = all_rollout_logs[env_name]
                for k, v in rollout_logs.items():
                    if k.startswith("Time_"):
                        data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
                    else:
                        data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)
                print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
                print('Env: {}'.format(env_name))
                print(json.dumps(rollout_logs, sort_keys=True, indent=4))

            # checkpoint and video saving logic
            updated_stats = TrainUtils.should_save_from_rollout_logs(
                all_rollout_logs=all_rollout_logs,
                best_return=best_return,
                best_success_rate=best_success_rate,
                epoch_ckpt_name=epoch_ckpt_name,
                save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
                save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
            )
            best_return = updated_stats["best_return"]
            best_success_rate = updated_stats["best_success_rate"]
            epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
            should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
            if updated_stats["ckpt_reason"] is not None:
                ckpt_reason = updated_stats["ckpt_reason"]

            did_rollouts = True

        # Only keep saved videos if the ckpt should be saved (but not because of validation score)
        should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
        if video_paths is not None and not should_save_video:
            for env_name in video_paths:
                os.remove(video_paths[env_name])

        # Save model checkpoints based on conditions (success rate, validation loss, etc)
        if should_save_ckpt:
            TrainUtils.save_model(
                model=model,
                config=config,
                env_meta=env_meta,
                shape_meta=shape_meta,
                ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

        # maybe sync some results back to scratch space (only if rollouts happened)
        if did_rollouts and need_sync_results:
            print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
            # get best and latest model checkpoints and videos
            best_ckpt_path_to_sync, best_video_path_to_sync, best_epoch_to_sync = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                best=True,
            )
            last_ckpt_path_to_sync, last_video_path_to_sync, last_epoch_to_sync = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                last=True,
            )

            # clear last files that we synced over
            if best_ckpt_path_synced is not None:
                os.remove(best_ckpt_path_synced)
            if last_ckpt_path_synced is not None:
                os.remove(last_ckpt_path_synced)
            if best_video_path_synced is not None:
                os.remove(best_video_path_synced)
            if last_video_path_synced is not None:
                os.remove(last_video_path_synced)
            if os.path.exists(log_dir_path_synced):
                shutil.rmtree(log_dir_path_synced)

            # set write paths and sync new files over
            # assumes the best checkpoint filename embeds "...success_<rate>.pth" — TODO confirm
            best_success_rate_for_sync = float(best_ckpt_path_to_sync.split("success_")[-1][:-4])
            best_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(best_ckpt_path_to_sync)[:-4] + "_best.pth",
            )
            shutil.copyfile(best_ckpt_path_to_sync, best_ckpt_path_synced)
            last_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(last_ckpt_path_to_sync)[:-4] + "_last.pth",
            )
            shutil.copyfile(last_ckpt_path_to_sync, last_ckpt_path_synced)
            if config.experiment.render_video:
                best_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(best_video_path_to_sync)[:-4] + "_best_{}.mp4".format(best_success_rate_for_sync),
                )
                shutil.copyfile(best_video_path_to_sync, best_video_path_synced)
                last_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(last_video_path_to_sync)[:-4] + "_last.mp4",
                )
                shutil.copyfile(last_video_path_to_sync, last_video_path_synced)

            # sync logs dir
            shutil.copytree(log_dir, log_dir_path_synced)

            # sync config json
            shutil.copyfile(
                os.path.join(log_dir, '..', 'config.json'),
                os.path.join(Macros.RESULTS_SYNC_PATH_ABS, 'config.json')
            )

        # Finally, log memory usage in MB
        process = psutil.Process(os.getpid())
        mem_usage = int(process.memory_info().rss / 1000000)
        data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
        print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))

    # terminate logging
    data_logger.close()

    # sync logs after closing data logger to make sure everything was transferred
    if need_sync_results:
        print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
        # sync logs dir
        if os.path.exists(log_dir_path_synced):
            shutil.rmtree(log_dir_path_synced)
        shutil.copytree(log_dir, log_dir_path_synced)

    # collect important statistics
    important_stats = dict()
    prefix = "Rollout/Success_Rate/"
    exception_prefix = "Rollout/Exception_Rate/"
    for k in data_logger._data:
        if k.startswith(prefix):
            suffix = k[len(prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-max".format(suffix)] = stats["max"]
            important_stats["{}-mean".format(suffix)] = stats["mean"]
        elif k.startswith(exception_prefix):
            suffix = k[len(exception_prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-exception-rate-max".format(suffix)] = stats["max"]
            important_stats["{}-exception-rate-mean".format(suffix)] = stats["mean"]

    # add in time taken
    important_stats["time spent (hrs)"] = "{:.2f}".format((time.time() - start_time) / 3600.)

    # write stats to disk
    json_file_path = os.path.join(log_dir, "important_stats.json")
    with open(json_file_path, 'w') as f:
        # preserve original key ordering
        json.dump(important_stats, f, sort_keys=False, indent=4)

    return important_stats
def main(args):
    """
    Build a training config from command-line args, launch training, and
    report / sync / notify with the resulting rollout statistics.

    Args:
        args (argparse.Namespace): parsed command-line arguments (config, algo,
            name, dataset, output, auto_remove_exp, debug)
    """
    if args.config is not None:
        # fix: use a context manager so the config file handle is closed
        # (previously json.load(open(...)) leaked the handle)
        with open(args.config, 'r') as f:
            ext_cfg = json.load(f)
        config = config_factory(ext_cfg["algo_name"])
        # update config with external json - this will throw errors if
        # the external config has keys not present in the base algo config
        with config.values_unlocked():
            config.update(ext_cfg)
    else:
        config = config_factory(args.algo)

    # CLI overrides for dataset / experiment name / output folder
    if args.dataset is not None:
        config.train.data = [dict(path=args.dataset)]

    if args.name is not None:
        config.experiment.name = args.name

    if args.output is not None:
        config.train.output_dir = args.output

    # get torch device
    device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)

    # maybe modify config for debugging purposes
    if args.debug:
        Macros.DEBUG = True

        # shrink length of training to test whether this run is likely to crash
        config.unlock()
        config.lock_keys()

        # train and validate (if enabled) for 3 gradient steps, for 2 epochs
        config.experiment.epoch_every_n_steps = 3
        config.experiment.validation_epoch_every_n_steps = 3
        config.train.num_epochs = 2

        # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps
        config.experiment.rollout.rate = 1
        config.experiment.rollout.n = 2
        config.experiment.rollout.horizon = 10

        # send output to a temporary directory
        config.train.output_dir = "/tmp/tmp_trained_models"

    # lock config to prevent further modifications and ensure missing keys raise errors
    config.lock()

    # catch error during training and print it
    res_str = "finished run successfully!"
    important_stats = None
    try:
        important_stats = train(config, device=device, auto_remove_exp=args.auto_remove_exp)
    except Exception as e:
        res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
    print(res_str)

    important_stats_str = None
    if important_stats is not None:
        # keep the dict for file output; pretty-print a string for terminal / slack
        important_stats_str = json.dumps(important_stats, indent=4)
        print("\nRollout Success Rate Stats")
        print(important_stats_str)

        # maybe sync important stats back
        if Macros.RESULTS_SYNC_PATH_ABS is not None:
            json_file_path = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "important_stats.json")
            with open(json_file_path, 'w') as f:
                # fix: dump the dict itself - previously the already-serialized
                # string was dumped again, writing a doubly-encoded JSON file
                # preserve original key ordering
                json.dump(important_stats, f, sort_keys=False, indent=4)

    # maybe give slack notification
    if Macros.SLACK_TOKEN is not None:
        from robomimic.scripts.give_slack_notification import give_slack_notif
        msg = "Completed the following training run!\nHostname: {}\nExperiment Name: {}\n".format(socket.gethostname(), config.experiment.name)
        msg += "```{}```".format(res_str)
        if important_stats is not None:
            msg += "\nRollout Success Rate Stats"
            msg += "\n```{}```".format(important_stats_str)
        give_slack_notif(msg)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| # External config file that overwrites default config | |
| parser.add_argument( | |
| "--config", | |
| type=str, | |
| default=None, | |
| help="(optional) path to a config json that will be used to override the default settings. \ | |
| If omitted, default settings are used. This is the preferred way to run experiments.", | |
| ) | |
| # Algorithm Name | |
| parser.add_argument( | |
| "--algo", | |
| type=str, | |
| help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided", | |
| ) | |
| # Experiment Name (for tensorboard, saving models, etc.) | |
| parser.add_argument( | |
| "--name", | |
| type=str, | |
| default=None, | |
| help="(optional) if provided, override the experiment name defined in the config", | |
| ) | |
| # Dataset path, to override the one in the config | |
| parser.add_argument( | |
| "--dataset", | |
| type=str, | |
| default=None, | |
| help="(optional) if provided, override the dataset path defined in the config", | |
| ) | |
| # Output path, to override the one in the config | |
| parser.add_argument( | |
| "--output", | |
| type=str, | |
| default=None, | |
| help="(optional) if provided, override the output folder path defined in the config", | |
| ) | |
| # force delete the experiment folder if it exists | |
| parser.add_argument( | |
| "--auto-remove-exp", | |
| action='store_true', | |
| help="force delete the experiment folder if it exists" | |
| ) | |
| # debug mode | |
| parser.add_argument( | |
| "--debug", | |
| action='store_true', | |
| help="set this flag to run a quick training run for debugging purposes" | |
| ) | |
| args = parser.parse_args() | |
| main(args) | |