Spaces:
Sleeping
Sleeping
File size: 24,209 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 |
"""
The main entry point for training policies.
Args:
config (str): path to a config json that will be used to override the default settings.
If omitted, default settings are used. This is the preferred way to run experiments.
algo (str): name of the algorithm to run. Only needs to be provided if @config is not
provided.
name (str): if provided, override the experiment name defined in the config
dataset (str): if provided, override the dataset path defined in the config
debug (bool): set this flag to run a quick training run for debugging purposes
"""
import argparse
import json
import numpy as np
import time
import os
import shutil
import psutil
import sys
import socket
import traceback
from collections import OrderedDict
import torch
from torch.utils.data import DataLoader
import robomimic
import robomimic.macros as Macros
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
def _create_rollout_envs(config, env_meta, shape_meta):
    """
    Create the environments used for evaluation rollouts.

    The environment named by env_meta["env_name"] is created first, followed by every name listed
    in config.experiment.additional_envs (if any). Each environment is wrapped according to
    @config and printed.

    Args:
        config (Config): experiment config
        env_meta (dict): environment metadata loaded from the dataset (possibly updated)
        shape_meta (dict): dataset shape metadata; its "use_images" / "use_depths" entries are
            forwarded as use_image_obs / use_depth_obs

    Returns:
        envs (OrderedDict): maps each wrapped environment's name to the environment
    """
    envs = OrderedDict()
    env_names = [env_meta["env_name"]]
    if config.experiment.additional_envs is not None:
        env_names.extend(config.experiment.additional_envs)

    for env_name in env_names:
        env = EnvUtils.create_env_from_metadata(
            env_meta=env_meta,
            env_name=env_name,
            render=config.experiment.render,
            render_offscreen=config.experiment.render_video,
            use_image_obs=shape_meta["use_images"],
            use_depth_obs=shape_meta["use_depths"],
        )
        env = EnvUtils.wrap_env_from_config(env, config=config)  # apply environment wrapper, if applicable
        envs[env.name] = env
        print(envs[env.name])
    return envs


def _record_step_log(data_logger, step_log, epoch, split):
    """
    Record every entry of an epoch's step log with @data_logger.

    Keys starting with "Time_" are recorded as "Timing_Stats/<split>_<key without 'Time_'>";
    all other keys are recorded as "<split>/<key>".

    Args:
        data_logger (DataLogger): logger to record into
        step_log (dict): per-epoch results returned by TrainUtils.run_epoch
        epoch (int): epoch number used as the logging step
        split (str): label for the data split, e.g. "Train" or "Valid"
    """
    for k, v in step_log.items():
        if k.startswith("Time_"):
            data_logger.record("Timing_Stats/{}_{}".format(split, k[5:]), v, epoch)
        else:
            data_logger.record("{}/{}".format(split, k), v, epoch)


def _collect_important_stats(data_logger):
    """
    Summarize the rollout statistics recorded in @data_logger.

    For every key "Rollout/Success_Rate/<suffix>", adds "<suffix>-max" and "<suffix>-mean".
    For every key "Rollout/Exception_Rate/<suffix>", adds "<suffix>-exception-rate-max" and
    "<suffix>-exception-rate-mean". Values come from data_logger.get_stats(key).

    Args:
        data_logger (DataLogger): logger whose recorded data is summarized

    Returns:
        important_stats (dict): the summary, in iteration order of data_logger._data
    """
    important_stats = dict()
    prefix = "Rollout/Success_Rate/"
    exception_prefix = "Rollout/Exception_Rate/"
    for k in data_logger._data:
        if k.startswith(prefix):
            suffix = k[len(prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-max".format(suffix)] = stats["max"]
            important_stats["{}-mean".format(suffix)] = stats["mean"]
        elif k.startswith(exception_prefix):
            suffix = k[len(exception_prefix):]
            stats = data_logger.get_stats(k)
            important_stats["{}-exception-rate-max".format(suffix)] = stats["max"]
            important_stats["{}-exception-rate-mean".format(suffix)] = stats["mean"]
    return important_stats


def train(config, device, auto_remove_exp=False):
    """
    Train a model using the algorithm specified in @config.

    Seeds numpy / torch, sets up logging directories, loads dataset metadata and the training
    (and optional validation) data, optionally creates rollout environments, and then runs
    config.train.num_epochs epochs. Each epoch may run validation, evaluation rollouts, checkpoint
    and video saving, and a sync of results to Macros.RESULTS_SYNC_PATH_ABS when that is set.

    Args:
        config (Config): experiment config
        device: device the model is created on (the value returned by
            TorchUtils.get_torch_device in main)
        auto_remove_exp (bool): forwarded as auto_remove_exp_dir to TrainUtils.get_exp_dir

    Returns:
        important_stats (dict): max / mean rollout success rate and exception rate per
            environment, plus "time spent (hrs)" as a formatted string. The same dict is also
            written to <log_dir>/important_stats.json.

    Raises:
        Exception: if the dataset path from the config does not exist
    """
    # time this run
    start_time = time.time()

    # first set seeds
    np.random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    torch.set_num_threads(2)

    print("\n============= New Training Run with Config =============")
    print(config)
    print("")
    log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config, auto_remove_exp_dir=auto_remove_exp)

    if config.experiment.logging.terminal_output_to_txt:
        # log stdout and stderr to a text file
        # NOTE: the redirection is never undone in this function
        logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
        sys.stdout = logger
        sys.stderr = logger

    # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
    ObsUtils.initialize_obs_utils_with_config(config)

    # make sure the dataset exists; when config.train.data is a list, metadata is read from the first entry
    if isinstance(config.train.data, str):
        dataset_path = os.path.expandvars(os.path.expanduser(config.train.data))
    else:
        eval_dataset_cfg = config.train.data[0]
        dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
    ds_format = config.train.data_format
    if not os.path.exists(dataset_path):
        raise Exception("Dataset at provided path {} not found!".format(dataset_path))

    # load basic metadata from training file
    print("\n============= Loaded Environment Metadata =============")
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format)

    # update env meta if applicable
    from robomimic.utils.script_utils import deep_update
    deep_update(env_meta, config.experiment.env_meta_update_dict)

    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=dataset_path,
        action_keys=config.train.action_keys,
        all_obs_keys=config.all_obs_keys,
        ds_format=ds_format,
        verbose=True
    )

    if config.experiment.env is not None:
        env_meta["env_name"] = config.experiment.env
        print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)

    # create environments for validation rollouts (stays empty when rollouts are disabled)
    envs = OrderedDict()
    if config.experiment.rollout.enabled:
        envs = _create_rollout_envs(config, env_meta, shape_meta)
    print("")

    # setup for a new training run
    data_logger = DataLogger(
        log_dir,
        config,
        log_tb=config.experiment.logging.log_tb,
        log_wandb=config.experiment.logging.log_wandb,
    )
    model = algo_factory(
        algo_name=config.algo_name,
        config=config,
        obs_key_shapes=shape_meta["all_shapes"],
        ac_dim=shape_meta["ac_dim"],
        device=device,
    )

    # save the config as a json file
    with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    print("\n============= Model Summary =============")
    print(model)  # print model summary
    print("")

    # load training data
    trainset, validset = TrainUtils.load_data_for_training(
        config, obs_keys=shape_meta["all_obs_keys"])
    train_sampler = trainset.get_dataset_sampler()
    print("\n============= Training Dataset =============")
    print(trainset)
    print("")
    if validset is not None:
        print("\n============= Validation Dataset =============")
        print(validset)
        print("")

    # maybe retrieve statistics for normalizing observations
    obs_normalization_stats = None
    if config.train.hdf5_normalize_obs:
        obs_normalization_stats = trainset.get_obs_normalization_stats()

    # retrieve statistics for normalizing actions
    action_normalization_stats = trainset.get_action_normalization_stats()

    # initialize data loaders (DataLoader forbids shuffle=True together with a sampler)
    train_loader = DataLoader(
        dataset=trainset,
        sampler=train_sampler,
        batch_size=config.train.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.train.num_data_workers,
        drop_last=True
    )

    if config.experiment.validate:
        # cap num workers for validation dataset at 1
        num_workers = min(config.train.num_data_workers, 1)
        valid_sampler = validset.get_dataset_sampler()
        valid_loader = DataLoader(
            dataset=validset,
            sampler=valid_sampler,
            batch_size=config.train.batch_size,
            shuffle=(valid_sampler is None),
            num_workers=num_workers,
            drop_last=True
        )
    else:
        valid_loader = None

    # print all warnings before training begins
    print("*" * 50)
    print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
    flush_warnings()
    print("*" * 50)
    print("")

    # main training loop
    best_valid_loss = None
    best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
    best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
    last_ckpt_time = time.time()

    need_sync_results = (Macros.RESULTS_SYNC_PATH_ABS is not None)
    if need_sync_results:
        # these paths will be updated after each evaluation
        best_ckpt_path_synced = None
        best_video_path_synced = None
        last_ckpt_path_synced = None
        last_video_path_synced = None
        log_dir_path_synced = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "logs")

    # number of learning steps per epoch (defaults to a full dataset pass)
    train_num_steps = config.experiment.epoch_every_n_steps
    valid_num_steps = config.experiment.validation_epoch_every_n_steps

    for epoch in range(1, config.train.num_epochs + 1):  # epoch numbers start at 1
        step_log = TrainUtils.run_epoch(
            model=model,
            data_loader=train_loader,
            epoch=epoch,
            num_steps=train_num_steps,
            obs_normalization_stats=obs_normalization_stats,
        )
        model.on_epoch_end(epoch)

        # setup checkpoint path
        epoch_ckpt_name = "model_epoch_{}".format(epoch)

        # check for recurring checkpoint saving conditions
        should_save_ckpt = False
        if config.experiment.save.enabled:
            time_check = (config.experiment.save.every_n_seconds is not None) and \
                (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
            epoch_check = (config.experiment.save.every_n_epochs is not None) and \
                (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
            epoch_list_check = (epoch in config.experiment.save.epochs)
            should_save_ckpt = (time_check or epoch_check or epoch_list_check)
        ckpt_reason = None
        if should_save_ckpt:
            last_ckpt_time = time.time()
            ckpt_reason = "time"

        print("Train Epoch {}".format(epoch))
        print(json.dumps(step_log, sort_keys=True, indent=4))
        _record_step_log(data_logger, step_log, epoch, "Train")

        # Evaluate the model on validation set
        if config.experiment.validate:
            with torch.no_grad():
                # use the same observation normalization as training so the validation loss is
                # computed on inputs the model was actually trained on
                step_log = TrainUtils.run_epoch(
                    model=model,
                    data_loader=valid_loader,
                    epoch=epoch,
                    validate=True,
                    num_steps=valid_num_steps,
                    obs_normalization_stats=obs_normalization_stats,
                )
            _record_step_log(data_logger, step_log, epoch, "Valid")

            print("Validation Epoch {}".format(epoch))
            print(json.dumps(step_log, sort_keys=True, indent=4))

            # save checkpoint if achieve new best validation loss
            valid_check = "Loss" in step_log
            if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
                best_valid_loss = step_log["Loss"]
                if config.experiment.save.enabled and config.experiment.save.on_best_validation:
                    epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
                    should_save_ckpt = True
                    ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason

        # Evaluate the model by running rollouts

        # do rollouts at fixed rate or if it's time to save a new ckpt
        video_paths = None
        rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
        did_rollouts = False
        if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:

            # wrap model as a RolloutPolicy to prepare for rollouts
            rollout_model = RolloutPolicy(
                model,
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

            num_episodes = config.experiment.rollout.n
            all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
                policy=rollout_model,
                envs=envs,
                horizon=config.experiment.rollout.horizon,
                use_goals=config.use_goals,
                num_episodes=num_episodes,
                render=False,
                video_dir=video_dir if config.experiment.render_video else None,
                epoch=epoch,
                video_skip=config.experiment.get("video_skip", 5),
                terminate_on_success=config.experiment.rollout.terminate_on_success,
            )

            # summarize results from rollouts to tensorboard and terminal
            for env_name in all_rollout_logs:
                rollout_logs = all_rollout_logs[env_name]
                for k, v in rollout_logs.items():
                    if k.startswith("Time_"):
                        data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
                    else:
                        data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)

                print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
                print('Env: {}'.format(env_name))
                print(json.dumps(rollout_logs, sort_keys=True, indent=4))

            # checkpoint and video saving logic
            updated_stats = TrainUtils.should_save_from_rollout_logs(
                all_rollout_logs=all_rollout_logs,
                best_return=best_return,
                best_success_rate=best_success_rate,
                epoch_ckpt_name=epoch_ckpt_name,
                save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
                save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
            )
            best_return = updated_stats["best_return"]
            best_success_rate = updated_stats["best_success_rate"]
            epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
            should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
            if updated_stats["ckpt_reason"] is not None:
                ckpt_reason = updated_stats["ckpt_reason"]
            did_rollouts = True

        # Only keep saved videos if the ckpt should be saved (but not because of validation score)
        should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
        if video_paths is not None and not should_save_video:
            for env_name in video_paths:
                os.remove(video_paths[env_name])

        # Save model checkpoints based on conditions (success rate, validation loss, etc)
        if should_save_ckpt:
            TrainUtils.save_model(
                model=model,
                config=config,
                env_meta=env_meta,
                shape_meta=shape_meta,
                ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
                obs_normalization_stats=obs_normalization_stats,
                action_normalization_stats=action_normalization_stats,
            )

        # maybe sync some results back to scratch space (only if rollouts happened)
        if did_rollouts and need_sync_results:
            print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))

            # get best and latest model checkpoints and videos (the epoch numbers are not needed)
            best_ckpt_path_to_sync, best_video_path_to_sync, _ = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                best=True,
            )
            last_ckpt_path_to_sync, last_video_path_to_sync, _ = TrainUtils.get_model_from_output_folder(
                models_path=ckpt_dir,
                videos_path=video_dir if config.experiment.render_video else None,
                last=True,
            )

            # clear last files that we synced over
            if best_ckpt_path_synced is not None:
                os.remove(best_ckpt_path_synced)
            if last_ckpt_path_synced is not None:
                os.remove(last_ckpt_path_synced)
            if best_video_path_synced is not None:
                os.remove(best_video_path_synced)
            if last_video_path_synced is not None:
                os.remove(last_video_path_synced)
            if os.path.exists(log_dir_path_synced):
                shutil.rmtree(log_dir_path_synced)

            # set write paths and sync new files over
            # NOTE(review): assumes the best checkpoint filename ends in "success_<rate>.pth" — confirm
            best_success_rate_for_sync = float(best_ckpt_path_to_sync.split("success_")[-1][:-4])
            best_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(best_ckpt_path_to_sync)[:-4] + "_best.pth",
            )
            shutil.copyfile(best_ckpt_path_to_sync, best_ckpt_path_synced)
            last_ckpt_path_synced = os.path.join(
                Macros.RESULTS_SYNC_PATH_ABS,
                os.path.basename(last_ckpt_path_to_sync)[:-4] + "_last.pth",
            )
            shutil.copyfile(last_ckpt_path_to_sync, last_ckpt_path_synced)
            if config.experiment.render_video:
                best_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(best_video_path_to_sync)[:-4] + "_best_{}.mp4".format(best_success_rate_for_sync),
                )
                shutil.copyfile(best_video_path_to_sync, best_video_path_synced)
                last_video_path_synced = os.path.join(
                    Macros.RESULTS_SYNC_PATH_ABS,
                    os.path.basename(last_video_path_to_sync)[:-4] + "_last.mp4",
                )
                shutil.copyfile(last_video_path_to_sync, last_video_path_synced)

            # sync logs dir
            shutil.copytree(log_dir, log_dir_path_synced)
            # sync config json
            shutil.copyfile(
                os.path.join(log_dir, '..', 'config.json'),
                os.path.join(Macros.RESULTS_SYNC_PATH_ABS, 'config.json')
            )

        # Finally, log memory usage in MB
        process = psutil.Process(os.getpid())
        mem_usage = int(process.memory_info().rss / 1000000)
        data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
        print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))

    # terminate logging
    data_logger.close()

    # sync logs after closing data logger to make sure everything was transferred
    if need_sync_results:
        print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
        # sync logs dir
        if os.path.exists(log_dir_path_synced):
            shutil.rmtree(log_dir_path_synced)
        shutil.copytree(log_dir, log_dir_path_synced)

    # collect important statistics
    important_stats = _collect_important_stats(data_logger)

    # add in time taken
    important_stats["time spent (hrs)"] = "{:.2f}".format((time.time() - start_time) / 3600.)

    # write stats to disk
    json_file_path = os.path.join(log_dir, "important_stats.json")
    with open(json_file_path, 'w') as f:
        # preserve original key ordering
        json.dump(important_stats, f, sort_keys=False, indent=4)

    return important_stats
def main(args):
    """
    Build the experiment config from command-line arguments and run training.

    The config is loaded from the json at args.config, layered onto the default config for the
    algorithm that json names. If no json is given, the default config of args.algo is used.
    The --dataset, --name and --output arguments then override the matching config entries, and
    --debug shrinks the run to a short smoke test that writes to /tmp/tmp_trained_models.

    Any exception raised by train() is caught. Its message and traceback are printed and, when
    Macros.SLACK_TOKEN is set, sent as a Slack notification instead of propagating.

    Args:
        args (argparse.Namespace): parsed command-line arguments with attributes config, algo,
            name, dataset, output, auto_remove_exp and debug
    """
    if args.config is not None:
        # close the config file as soon as it has been parsed
        with open(args.config, 'r') as config_file:
            ext_cfg = json.load(config_file)
        config = config_factory(ext_cfg["algo_name"])
        # update config with external json - this will throw errors if
        # the external config has keys not present in the base algo config
        with config.values_unlocked():
            config.update(ext_cfg)
    else:
        config = config_factory(args.algo)

    if args.dataset is not None:
        config.train.data = [dict(path=args.dataset)]

    if args.name is not None:
        config.experiment.name = args.name

    if args.output is not None:
        config.train.output_dir = args.output

    # get torch device
    device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)

    # maybe modify config for debugging purposes
    if args.debug:
        Macros.DEBUG = True

        # shrink length of training to test whether this run is likely to crash
        config.unlock()
        config.lock_keys()

        # train and validate (if enabled) for 3 gradient steps, for 2 epochs
        config.experiment.epoch_every_n_steps = 3
        config.experiment.validation_epoch_every_n_steps = 3
        config.train.num_epochs = 2

        # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps
        config.experiment.rollout.rate = 1
        config.experiment.rollout.n = 2
        config.experiment.rollout.horizon = 10

        # send output to a temporary directory
        config.train.output_dir = "/tmp/tmp_trained_models"

    # lock config to prevent further modifications and ensure missing keys raise errors
    config.lock()

    # catch error during training and print it
    res_str = "finished run successfully!"
    important_stats = None
    try:
        important_stats = train(config, device=device, auto_remove_exp=args.auto_remove_exp)
    except Exception as e:
        res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
    print(res_str)

    # keep the stats dict and its pretty-printed form separate: the string is for terminal / Slack
    # output, while the dict is what gets written to disk (dumping the string would double-encode it)
    important_stats_str = None
    if important_stats is not None:
        important_stats_str = json.dumps(important_stats, indent=4)
        print("\nRollout Success Rate Stats")
        print(important_stats_str)

        # maybe sync important stats back
        if Macros.RESULTS_SYNC_PATH_ABS is not None:
            json_file_path = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "important_stats.json")
            with open(json_file_path, 'w') as f:
                # preserve original key ordering
                json.dump(important_stats, f, sort_keys=False, indent=4)

    # maybe give slack notification
    if Macros.SLACK_TOKEN is not None:
        from robomimic.scripts.give_slack_notification import give_slack_notif
        msg = "Completed the following training run!\nHostname: {}\nExperiment Name: {}\n".format(socket.gethostname(), config.experiment.name)
        msg += "```{}```".format(res_str)
        if important_stats_str is not None:
            msg += "\nRollout Success Rate Stats"
            msg += "\n```{}```".format(important_stats_str)
        give_slack_notif(msg)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # String-valued options, registered in this order (which is also their --help order):
    #   --config  : external config json that overwrites the default config
    #   --algo    : algorithm name, only needed when --config is absent
    #   --name    : experiment name (for tensorboard, saving models, etc.)
    #   --dataset : dataset path, overriding the one in the config
    #   --output  : output folder path, overriding the one in the config
    string_options = (
        (
            "--config",
            "(optional) path to a config json that will be used to override the default settings. "
            "If omitted, default settings are used. This is the preferred way to run experiments.",
        ),
        (
            "--algo",
            "(optional) name of algorithm to run. Only needs to be provided if --config is not provided",
        ),
        (
            "--name",
            "(optional) if provided, override the experiment name defined in the config",
        ),
        (
            "--dataset",
            "(optional) if provided, override the dataset path defined in the config",
        ),
        (
            "--output",
            "(optional) if provided, override the output folder path defined in the config",
        ),
    )
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    # Boolean switches (False unless passed):
    #   --auto-remove-exp : force delete the experiment folder if it exists
    #   --debug           : run a quick training run for debugging purposes
    switch_options = (
        ("--auto-remove-exp", "force delete the experiment folder if it exists"),
        ("--debug", "set this flag to run a quick training run for debugging purposes"),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action='store_true', help=help_text)

    main(parser.parse_args())
|