"""
The main entry point for training policies.
Args:
config (str): path to a config json that will be used to override the default settings.
If omitted, default settings are used. This is the preferred way to run experiments.
algo (str): name of the algorithm to run. Only needs to be provided if @config is not
provided.
name (str): if provided, override the experiment name defined in the config
dataset (str): if provided, override the dataset path defined in the config
    output (str): if provided, override the output folder path defined in the config
    auto_remove_exp (bool): set this flag to force delete the experiment folder if it
        already exists
    debug (bool): set this flag to run a quick training run for debugging purposes
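
Example usage (paths and names are illustrative):
    python train.py --config /path/to/config.json --name my_run
    python train.py --algo bc --dataset /path/to/demo.hdf5 --debug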
"""
import argparse
import json
import numpy as np
import time
import os
import shutil
import psutil
import sys
import socket
import traceback
from collections import OrderedDict
import torch
from torch.utils.data import DataLoader
import robomimic
import robomimic.macros as Macros
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
def train(config, device, auto_remove_exp=False):
"""
Train a model using the algorithm.
"""
# time this run
start_time = time.time()
# first set seeds
np.random.seed(config.train.seed)
torch.manual_seed(config.train.seed)
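    # cap the number of CPU threads torch uses for intra-op parallelism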
torch.set_num_threads(2)
print("\n============= New Training Run with Config =============")
print(config)
print("")
log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config, auto_remove_exp_dir=auto_remove_exp)
if config.experiment.logging.terminal_output_to_txt:
# log stdout and stderr to a text file
logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
sys.stdout = logger
sys.stderr = logger
# read config to set up metadata for observation modalities (e.g. detecting rgb observations)
ObsUtils.initialize_obs_utils_with_config(config)
# make sure the dataset exists
if isinstance(config.train.data, str):
dataset_path = os.path.expandvars(os.path.expanduser(config.train.data))
else:
eval_dataset_cfg = config.train.data[0]
dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
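        # here config.train.data is a list of dataset dicts, e.g.
        # [dict(path="/path/to/demo.hdf5"), ...] -- the first entry supplies the metadata below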
ds_format = config.train.data_format
if not os.path.exists(dataset_path):
raise Exception("Dataset at provided path {} not found!".format(dataset_path))
# load basic metadata from training file
print("\n============= Loaded Environment Metadata =============")
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format)
# update env meta if applicable
from robomimic.utils.script_utils import deep_update
deep_update(env_meta, config.experiment.env_meta_update_dict)
shape_meta = FileUtils.get_shape_metadata_from_dataset(
dataset_path=dataset_path,
action_keys=config.train.action_keys,
all_obs_keys=config.all_obs_keys,
ds_format=ds_format,
verbose=True
)
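    # shape_meta collects the entries consumed below: "all_shapes", "all_obs_keys",
    # "ac_dim", "use_images", and "use_depths"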
if config.experiment.env is not None:
env_meta["env_name"] = config.experiment.env
print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
    # create environments
envs = OrderedDict()
if config.experiment.rollout.enabled:
# create environments for validation runs
env_names = [env_meta["env_name"]]
if config.experiment.additional_envs is not None:
for name in config.experiment.additional_envs:
env_names.append(name)
for env_name in env_names:
env = EnvUtils.create_env_from_metadata(
env_meta=env_meta,
env_name=env_name,
render=config.experiment.render,
render_offscreen=config.experiment.render_video,
use_image_obs=shape_meta["use_images"],
use_depth_obs=shape_meta["use_depths"],
)
            env = EnvUtils.wrap_env_from_config(env, config=config)  # apply environment wrapper, if applicable
envs[env.name] = env
print(envs[env.name])
print("")
# setup for a new training run
data_logger = DataLogger(
log_dir,
config,
log_tb=config.experiment.logging.log_tb,
log_wandb=config.experiment.logging.log_wandb,
)
model = algo_factory(
algo_name=config.algo_name,
config=config,
obs_key_shapes=shape_meta["all_shapes"],
ac_dim=shape_meta["ac_dim"],
device=device,
)
# save the config as a json file
with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
json.dump(config, outfile, indent=4)
print("\n============= Model Summary =============")
print(model) # print model summary
print("")
# load training data
trainset, validset = TrainUtils.load_data_for_training(
config, obs_keys=shape_meta["all_obs_keys"])
train_sampler = trainset.get_dataset_sampler()
print("\n============= Training Dataset =============")
print(trainset)
print("")
if validset is not None:
print("\n============= Validation Dataset =============")
print(validset)
print("")
    # maybe retrieve statistics for normalizing observations
obs_normalization_stats = None
if config.train.hdf5_normalize_obs:
obs_normalization_stats = trainset.get_obs_normalization_stats()
    # maybe retrieve statistics for normalizing actions
action_normalization_stats = trainset.get_action_normalization_stats()
# initialize data loaders
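    # note: PyTorch disallows combining a custom sampler with shuffle=True, so shuffling
    # is enabled only when the dataset provides no sampler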
train_loader = DataLoader(
dataset=trainset,
sampler=train_sampler,
batch_size=config.train.batch_size,
shuffle=(train_sampler is None),
num_workers=config.train.num_data_workers,
drop_last=True
)
if config.experiment.validate:
# cap num workers for validation dataset at 1
num_workers = min(config.train.num_data_workers, 1)
valid_sampler = validset.get_dataset_sampler()
valid_loader = DataLoader(
dataset=validset,
sampler=valid_sampler,
batch_size=config.train.batch_size,
shuffle=(valid_sampler is None),
num_workers=num_workers,
drop_last=True
)
else:
valid_loader = None
# print all warnings before training begins
print("*" * 50)
print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
flush_warnings()
print("*" * 50)
print("")
# main training loop
best_valid_loss = None
best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
last_ckpt_time = time.time()
need_sync_results = (Macros.RESULTS_SYNC_PATH_ABS is not None)
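    # when the RESULTS_SYNC_PATH_ABS macro is set, the best / latest checkpoints, videos,
    # and logs are mirrored to that path after each round of rollouts (see the sync block below)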
if need_sync_results:
# these paths will be updated after each evaluation
best_ckpt_path_synced = None
best_video_path_synced = None
last_ckpt_path_synced = None
last_video_path_synced = None
log_dir_path_synced = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "logs")
# number of learning steps per epoch (defaults to a full dataset pass)
train_num_steps = config.experiment.epoch_every_n_steps
valid_num_steps = config.experiment.validation_epoch_every_n_steps
for epoch in range(1, config.train.num_epochs + 1): # epoch numbers start at 1
step_log = TrainUtils.run_epoch(
model=model,
data_loader=train_loader,
epoch=epoch,
num_steps=train_num_steps,
obs_normalization_stats=obs_normalization_stats,
)
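        # step_log maps metric names to scalars for this epoch (e.g. "Loss", plus "Time_"-prefixed timing entries)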
model.on_epoch_end(epoch)
# setup checkpoint path
epoch_ckpt_name = "model_epoch_{}".format(epoch)
# check for recurring checkpoint saving conditions
should_save_ckpt = False
if config.experiment.save.enabled:
time_check = (config.experiment.save.every_n_seconds is not None) and \
(time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
epoch_check = (config.experiment.save.every_n_epochs is not None) and \
(epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
epoch_list_check = (epoch in config.experiment.save.epochs)
should_save_ckpt = (time_check or epoch_check or epoch_list_check)
ckpt_reason = None
if should_save_ckpt:
last_ckpt_time = time.time()
ckpt_reason = "time"
print("Train Epoch {}".format(epoch))
print(json.dumps(step_log, sort_keys=True, indent=4))
for k, v in step_log.items():
if k.startswith("Time_"):
data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch)
else:
data_logger.record("Train/{}".format(k), v, epoch)
        # Evaluate the model on the validation set
if config.experiment.validate:
with torch.no_grad():
step_log = TrainUtils.run_epoch(model=model, data_loader=valid_loader, epoch=epoch, validate=True, num_steps=valid_num_steps)
for k, v in step_log.items():
if k.startswith("Time_"):
data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch)
else:
data_logger.record("Valid/{}".format(k), v, epoch)
print("Validation Epoch {}".format(epoch))
print(json.dumps(step_log, sort_keys=True, indent=4))
            # save a checkpoint if we achieve a new best validation loss
valid_check = "Loss" in step_log
if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
best_valid_loss = step_log["Loss"]
if config.experiment.save.enabled and config.experiment.save.on_best_validation:
epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
should_save_ckpt = True
ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason
        # Evaluate the model by running rollouts
        # do rollouts at a fixed rate, or when it is time to save a new checkpoint
video_paths = None
rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
did_rollouts = False
if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:
# wrap model as a RolloutPolicy to prepare for rollouts
rollout_model = RolloutPolicy(
model,
obs_normalization_stats=obs_normalization_stats,
action_normalization_stats=action_normalization_stats,
)
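            # the RolloutPolicy wrapper applies these stats at inference time, normalizing
            # observations and un-normalizing predicted actions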
num_episodes = config.experiment.rollout.n
all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
policy=rollout_model,
envs=envs,
horizon=config.experiment.rollout.horizon,
use_goals=config.use_goals,
num_episodes=num_episodes,
render=False,
video_dir=video_dir if config.experiment.render_video else None,
epoch=epoch,
video_skip=config.experiment.get("video_skip", 5),
terminate_on_success=config.experiment.rollout.terminate_on_success,
)
# summarize results from rollouts to tensorboard and terminal
for env_name in all_rollout_logs:
rollout_logs = all_rollout_logs[env_name]
for k, v in rollout_logs.items():
if k.startswith("Time_"):
data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
else:
data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)
print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
print('Env: {}'.format(env_name))
print(json.dumps(rollout_logs, sort_keys=True, indent=4))
# checkpoint and video saving logic
updated_stats = TrainUtils.should_save_from_rollout_logs(
all_rollout_logs=all_rollout_logs,
best_return=best_return,
best_success_rate=best_success_rate,
epoch_ckpt_name=epoch_ckpt_name,
save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
)
best_return = updated_stats["best_return"]
best_success_rate = updated_stats["best_success_rate"]
epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
if updated_stats["ckpt_reason"] is not None:
ckpt_reason = updated_stats["ckpt_reason"]
did_rollouts = True
# Only keep saved videos if the ckpt should be saved (but not because of validation score)
should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
if video_paths is not None and not should_save_video:
for env_name in video_paths:
os.remove(video_paths[env_name])
# Save model checkpoints based on conditions (success rate, validation loss, etc)
if should_save_ckpt:
TrainUtils.save_model(
model=model,
config=config,
env_meta=env_meta,
shape_meta=shape_meta,
ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
obs_normalization_stats=obs_normalization_stats,
action_normalization_stats=action_normalization_stats,
)
# maybe sync some results back to scratch space (only if rollouts happened)
if did_rollouts and need_sync_results:
print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
# get best and latest model checkpoints and videos
best_ckpt_path_to_sync, best_video_path_to_sync, best_epoch_to_sync = TrainUtils.get_model_from_output_folder(
models_path=ckpt_dir,
videos_path=video_dir if config.experiment.render_video else None,
best=True,
)
last_ckpt_path_to_sync, last_video_path_to_sync, last_epoch_to_sync = TrainUtils.get_model_from_output_folder(
models_path=ckpt_dir,
videos_path=video_dir if config.experiment.render_video else None,
last=True,
)
            # remove the files that were synced over on the previous iteration
if best_ckpt_path_synced is not None:
os.remove(best_ckpt_path_synced)
if last_ckpt_path_synced is not None:
os.remove(last_ckpt_path_synced)
if best_video_path_synced is not None:
os.remove(best_video_path_synced)
if last_video_path_synced is not None:
os.remove(last_video_path_synced)
if os.path.exists(log_dir_path_synced):
shutil.rmtree(log_dir_path_synced)
# set write paths and sync new files over
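            # assumption: the best checkpoint filename ends in "success_<rate>.pth", so the rate can be parsed back out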
best_success_rate_for_sync = float(best_ckpt_path_to_sync.split("success_")[-1][:-4])
best_ckpt_path_synced = os.path.join(
Macros.RESULTS_SYNC_PATH_ABS,
os.path.basename(best_ckpt_path_to_sync)[:-4] + "_best.pth",
)
shutil.copyfile(best_ckpt_path_to_sync, best_ckpt_path_synced)
last_ckpt_path_synced = os.path.join(
Macros.RESULTS_SYNC_PATH_ABS,
os.path.basename(last_ckpt_path_to_sync)[:-4] + "_last.pth",
)
shutil.copyfile(last_ckpt_path_to_sync, last_ckpt_path_synced)
if config.experiment.render_video:
best_video_path_synced = os.path.join(
Macros.RESULTS_SYNC_PATH_ABS,
os.path.basename(best_video_path_to_sync)[:-4] + "_best_{}.mp4".format(best_success_rate_for_sync),
)
shutil.copyfile(best_video_path_to_sync, best_video_path_synced)
last_video_path_synced = os.path.join(
Macros.RESULTS_SYNC_PATH_ABS,
os.path.basename(last_video_path_to_sync)[:-4] + "_last.mp4",
)
shutil.copyfile(last_video_path_to_sync, last_video_path_synced)
# sync logs dir
shutil.copytree(log_dir, log_dir_path_synced)
# sync config json
shutil.copyfile(
os.path.join(log_dir, '..', 'config.json'),
os.path.join(Macros.RESULTS_SYNC_PATH_ABS, 'config.json')
)
# Finally, log memory usage in MB
process = psutil.Process(os.getpid())
mem_usage = int(process.memory_info().rss / 1000000)
data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))
# terminate logging
data_logger.close()
# sync logs after closing data logger to make sure everything was transferred
if need_sync_results:
print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
# sync logs dir
if os.path.exists(log_dir_path_synced):
shutil.rmtree(log_dir_path_synced)
shutil.copytree(log_dir, log_dir_path_synced)
# collect important statistics
important_stats = dict()
prefix = "Rollout/Success_Rate/"
exception_prefix = "Rollout/Exception_Rate/"
for k in data_logger._data:
if k.startswith(prefix):
suffix = k[len(prefix):]
stats = data_logger.get_stats(k)
important_stats["{}-max".format(suffix)] = stats["max"]
important_stats["{}-mean".format(suffix)] = stats["mean"]
elif k.startswith(exception_prefix):
suffix = k[len(exception_prefix):]
stats = data_logger.get_stats(k)
important_stats["{}-exception-rate-max".format(suffix)] = stats["max"]
important_stats["{}-exception-rate-mean".format(suffix)] = stats["mean"]
# add in time taken
important_stats["time spent (hrs)"] = "{:.2f}".format((time.time() - start_time) / 3600.)
# write stats to disk
json_file_path = os.path.join(log_dir, "important_stats.json")
with open(json_file_path, 'w') as f:
# preserve original key ordering
json.dump(important_stats, f, sort_keys=False, indent=4)
return important_stats
def main(args):
if args.config is not None:
        with open(args.config, 'r') as f:
            ext_cfg = json.load(f)
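        # ext_cfg might look like (illustrative): {"algo_name": "bc", "train": {"data": "/path/to/demo.hdf5"}}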
config = config_factory(ext_cfg["algo_name"])
# update config with external json - this will throw errors if
# the external config has keys not present in the base algo config
with config.values_unlocked():
config.update(ext_cfg)
else:
config = config_factory(args.algo)
if args.dataset is not None:
config.train.data = [dict(path=args.dataset)]
if args.name is not None:
config.experiment.name = args.name
if args.output is not None:
config.train.output_dir = args.output
# get torch device
device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
# maybe modify config for debugging purposes
if args.debug:
Macros.DEBUG = True
# shrink length of training to test whether this run is likely to crash
config.unlock()
config.lock_keys()
        # train and validate (if enabled) for only 3 gradient steps per epoch, for 2 epochs
config.experiment.epoch_every_n_steps = 3
config.experiment.validation_epoch_every_n_steps = 3
config.train.num_epochs = 2
# if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps
config.experiment.rollout.rate = 1
config.experiment.rollout.n = 2
config.experiment.rollout.horizon = 10
# send output to a temporary directory
config.train.output_dir = "/tmp/tmp_trained_models"
# lock config to prevent further modifications and ensure missing keys raise errors
config.lock()
    # catch any error during training and print it
res_str = "finished run successfully!"
important_stats = None
try:
important_stats = train(config, device=device, auto_remove_exp=args.auto_remove_exp)
except Exception as e:
res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
print(res_str)
    if important_stats is not None:
        # keep important_stats as a dict so it can be serialized again below;
        # pretty-print a JSON string for the terminal and the slack message
        important_stats_str = json.dumps(important_stats, indent=4)
        print("\nRollout Success Rate Stats")
        print(important_stats_str)
# maybe sync important stats back
if Macros.RESULTS_SYNC_PATH_ABS is not None:
json_file_path = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "important_stats.json")
with open(json_file_path, 'w') as f:
# preserve original key ordering
json.dump(important_stats, f, sort_keys=False, indent=4)
# maybe give slack notification
if Macros.SLACK_TOKEN is not None:
from robomimic.scripts.give_slack_notification import give_slack_notif
msg = "Completed the following training run!\nHostname: {}\nExperiment Name: {}\n".format(socket.gethostname(), config.experiment.name)
msg += "```{}```".format(res_str)
        if important_stats is not None:
            msg += "\nRollout Success Rate Stats"
            msg += "\n```{}```".format(important_stats_str)
give_slack_notif(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    # External config file that overrides the default config
parser.add_argument(
"--config",
type=str,
default=None,
help="(optional) path to a config json that will be used to override the default settings. \
If omitted, default settings are used. This is the preferred way to run experiments.",
)
# Algorithm Name
parser.add_argument(
"--algo",
type=str,
help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided",
)
# Experiment Name (for tensorboard, saving models, etc.)
parser.add_argument(
"--name",
type=str,
default=None,
help="(optional) if provided, override the experiment name defined in the config",
)
# Dataset path, to override the one in the config
parser.add_argument(
"--dataset",
type=str,
default=None,
help="(optional) if provided, override the dataset path defined in the config",
)
# Output path, to override the one in the config
parser.add_argument(
"--output",
type=str,
default=None,
help="(optional) if provided, override the output folder path defined in the config",
)
# force delete the experiment folder if it exists
parser.add_argument(
"--auto-remove-exp",
action='store_true',
help="force delete the experiment folder if it exists"
)
# debug mode
parser.add_argument(
"--debug",
action='store_true',
help="set this flag to run a quick training run for debugging purposes"
)
args = parser.parse_args()
main(args)