---
# SkyPilot task: multi-node verl PPO training (GSM8K, Qwen2.5-0.5B) on Kubernetes.
resources:
  infra: k8s
  accelerators: H100:1
  # SkyPilot "at least 128 GiB" syntax; the trailing + makes this a plain string.
  memory: 128+
  image_id: docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4
  # Expose the Ray dashboard.
  ports: 8265

num_nodes: 2

secrets:
  # null means: forward the value from the launching environment.
  WANDB_API_KEY: null
setup: |
  # Install verl from source with the vLLM extras, plus flashinfer.
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  pip3 install -v -e .[vllm]
  pip3 install flashinfer-python

  # Preprocess the GSM8K dataset into ~/data/gsm8k (train/test parquet files
  # consumed by the run section below).
  echo "Downloading GSM8K dataset..."
  mkdir -p ~/data/gsm8k
  # Use an absolute path so the check is unambiguous regardless of CWD quirks.
  if [ -f "$(pwd)/examples/data_preprocess/gsm8k.py" ]; then
    python3 "$(pwd)/examples/data_preprocess/gsm8k.py" --local_dir ~/data/gsm8k
  else
    echo "Warning: gsm8k.py script not found, skipping dataset download"
    # You might want to download the dataset manually or use a different approach
  fi
  echo "GSM8K dataset download completed"
run: |
  # Rank 0 starts the Ray head and launches PPO training; all other ranks
  # join the cluster as Ray workers.
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    echo "Starting Ray head node..."
    # Idempotent: only start the head if no Ray process on port 6379 is running.
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \
      --port=6379 \
      --dashboard-host=0.0.0.0 \
      --dashboard-port=8265

    echo "Waiting for all nodes to join Ray cluster..."
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      # grep -c already prints "0" when there is no match (it just exits 1),
      # so only swallow the exit status here. The previous `|| echo "0"`
      # produced "0<newline>0" whenever `ray status` failed, breaking the
      # integer comparison below on every retry.
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
      echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"

      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "All nodes connected to Ray cluster"
        break
      fi

      retry_count=$((retry_count+1))
      sleep 10
    done

    if [ $retry_count -eq $max_retries ]; then
      echo "WARNING: Not all nodes connected to Ray cluster after $max_retries attempts"
      echo "Current Ray status:"
      ray status
    fi

    # Launch PPO training. nnodes follows SKYPILOT_NUM_NODES so the command
    # stays consistent with the task's num_nodes setting.
    python3 -m verl.trainer.main_ppo \
      data.train_files=$HOME/data/gsm8k/train.parquet \
      data.val_files=$HOME/data/gsm8k/test.parquet \
      data.train_batch_size=256 \
      data.max_prompt_length=512 \
      data.max_response_length=256 \
      actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
      actor_rollout_ref.actor.optim.lr=1e-6 \
      actor_rollout_ref.actor.ppo_mini_batch_size=64 \
      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
      actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
      actor_rollout_ref.rollout.name=vllm \
      actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
      critic.optim.lr=1e-5 \
      critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
      critic.ppo_micro_batch_size_per_gpu=4 \
      algorithm.kl_ctrl.kl_coef=0.001 \
      'trainer.logger=[console,wandb]' \
      trainer.val_before_train=False \
      trainer.default_hdfs_dir=null \
      trainer.n_gpus_per_node=1 \
      trainer.nnodes=$NUM_NODES \
      trainer.save_freq=20 \
      trainer.test_freq=20 \
      trainer.total_epochs=2 \
      trainer.project_name=verl_examples \
      trainer.experiment_name=experiment_name_gsm8k

  else
    # Give the head node a head start before workers try to connect.
    sleep 15

    echo "Starting Ray worker node..."
    # Idempotent: only join if a worker for this head is not already running.
    ps aux | grep ray | grep "$HEAD_IP:6379" &> /dev/null || ray start --address "$HEAD_IP:6379" --disable-usage-stats
    sleep 10
  fi

  echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."