| import ray |
|
|
| from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models |
| from slime.utils.arguments import parse_args |
| from slime.utils.logging_utils import configure_logger, init_tracking |
| from slime.utils.misc import should_run_periodic_action |
|
|
|
|
def train(args):
    """Run the rollout -> train -> weight-sync loop.

    Args:
        args: Parsed configuration namespace (as produced by ``parse_args()``).
            Fields read here: ``offload_rollout``, ``offload_train``,
            ``use_critic``, ``num_critic_only_steps``,
            ``check_weight_update_equal``, ``num_rollout``,
            ``start_rollout_id``, ``eval_interval``, ``skip_eval_before_train``,
            ``save_interval``, ``rollout_global_dataset`` and, optionally,
            ``evolving_gym`` (treated as False when absent).
    """
    configure_logger()

    pgs = create_placement_groups(args)
    init_tracking(args)

    # The rollout manager is created before the training models (which receive
    # it). It also returns the number of rollouts per epoch, which feeds the
    # periodic save/eval schedule below.
    rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"])

    actor_model, critic_model = create_training_models(args, pgs, rollout_manager)

    def sync_weights_to_rollout(check_equal=False):
        """Call ``actor_model.update_weights()``, bracketed by rollout onloads.

        When ``args.offload_rollout`` is set, the rollout manager's weights are
        onloaded before the update and its KV cache after it. When
        ``check_equal`` is truthy, the rollout manager is asked to compare
        weights right after the update (before the KV onload).
        """
        if args.offload_rollout:
            ray.get(rollout_manager.onload_weights.remote())
        actor_model.update_weights()
        if check_equal:
            ray.get(rollout_manager.check_weights.remote(action="compare"))
        if args.offload_rollout:
            ray.get(rollout_manager.onload_kv.remote())

    def offload_train(rollout_id):
        """Offload (or clear) the training models after the step for ``rollout_id``.

        ``rollout_id`` is passed explicitly instead of being read from the
        enclosing loop variable via a late-bound closure.
        """
        if not args.offload_train:
            actor_model.clear_memory()
            return
        if args.use_critic:
            critic_model.offload()
            # The actor is only trained once the critic-only warm-up is over.
            if rollout_id >= args.num_critic_only_steps:
                actor_model.offload()
        else:
            actor_model.offload()

    def save(rollout_id):
        """Save model (and, if configured, rollout-manager) state for ``rollout_id``.

        The final rollout (``args.num_rollout - 1``) is saved with
        ``force_sync=True``.
        """
        force_sync = rollout_id == args.num_rollout - 1
        # The actor is not trained during critic-only steps, so skip saving it then.
        if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
            actor_model.save_model(rollout_id, force_sync=force_sync)
        if args.use_critic:
            critic_model.save_model(rollout_id, force_sync=force_sync)
        if args.rollout_global_dataset or getattr(args, "evolving_gym", False):
            ray.get(rollout_manager.save.remote(rollout_id))

    # Sync weights once before anything else, optionally verifying the result.
    sync_weights_to_rollout(check_equal=args.check_weight_update_equal)

    # Eval-only run: there are no rollouts to train on, so evaluate once.
    if args.num_rollout == 0 and args.eval_interval is not None:
        ray.get(rollout_manager.eval.remote(rollout_id=0))

    for rollout_id in range(args.start_rollout_id, args.num_rollout):
        if args.eval_interval is not None and rollout_id == 0 and not args.skip_eval_before_train:
            ray.get(rollout_manager.eval.remote(rollout_id))

        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))

        if args.offload_rollout:
            ray.get(rollout_manager.offload.remote())

        if args.use_critic:
            # Launch critic training without waiting on it, so it can run while
            # the actor trains; both are awaited before moving on.
            critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
            if rollout_id >= args.num_critic_only_steps:
                ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
            ray.get(critic_train_handle)
        else:
            ray.get(actor_model.async_train(rollout_id, rollout_data_ref))

        if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
            save(rollout_id)

        offload_train(rollout_id)
        sync_weights_to_rollout()

        # NOTE(review): unlike the save check above, args.num_rollout is not
        # passed here — confirm this difference in the eval schedule is intended.
        if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
            ray.get(rollout_manager.eval.remote(rollout_id))

    ray.get(rollout_manager.dispose.remote())
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the configuration from the command line and train.
    train(parse_args())
|
|