gui-shift / scripts /train.sh
luanns's picture
Upload scripts/train.sh
0ec5e1f verified
#!/usr/bin/env bash
# GUI-Shift GRPO Training Script
#
# Usage:
#   train.sh [--model_name_or_path NAME] [--data_dir DIR] [--output_dir DIR]
#            [--k K] [--num_gpus N]
#
# Strict mode: abort on any failing command (-e), on use of an unset
# variable (-u, catches typos and missing option values), and on a failure
# in any stage of a pipeline (pipefail).
set -euo pipefail

# Default values; each can be overridden by the matching command-line option.
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
DATA_DIR="./data/gui_transition/filtered"
OUTPUT_DIR="./checkpoints/gui-shift"
K=1
NUM_GPUS=8
# Parse command-line options, overwriting the default values.
# Every supported option takes exactly one value. An unknown option, or an
# option whose value is missing (given last, empty, or immediately followed
# by another --option), aborts with exit status 1 and a message on stderr.
parse_args() {
  local opt
  while [[ $# -gt 0 ]]; do
    opt=$1
    case "$opt" in
      --model_name_or_path | --data_dir | --output_dir | --k | --num_gpus)
        # Without this check `--k` as the last argument would make
        # `shift 2` fail silently, and `--k --num_gpus 2` would store
        # "--num_gpus" as K.
        if [[ $# -lt 2 || -z "$2" || "$2" == --* ]]; then
          echo "Error: option $opt requires a value" >&2
          exit 1
        fi
        ;;
      *)
        echo "Unknown option: $opt" >&2
        exit 1
        ;;
    esac
    case "$opt" in
      --model_name_or_path) MODEL_NAME="$2" ;;
      --data_dir) DATA_DIR="$2" ;;
      --output_dir) OUTPUT_DIR="$2" ;;
      --k) K="$2" ;;
      --num_gpus) NUM_GPUS="$2" ;;
    esac
    shift 2
  done
}
parse_args "$@"
# Show the effective configuration (after option parsing), followed by a
# blank separator line.
printf '%s\n' \
  "=== GUI-Shift GRPO Training ===" \
  "Model: ${MODEL_NAME}" \
  "Data: ${DATA_DIR}" \
  "Output: ${OUTPUT_DIR}" \
  "K value: ${K}" \
  "GPUs: ${NUM_GPUS}" \
  ""
# Pick the JSONL training file for the requested K: prefer the filtered
# variant, fall back to the unfiltered one, and abort if neither exists.
DATA_FILE=""
for candidate in "${DATA_DIR}/k${K}_transition_filtered.jsonl" \
                 "${DATA_DIR}/k${K}_transition.jsonl"; do
  if [[ -f "$candidate" ]]; then
    DATA_FILE=$candidate
    break
  fi
done
if [[ -z "$DATA_FILE" ]]; then
  {
    echo "Error: Could not find data file for k=$K in $DATA_DIR"
    echo "  looked for: k${K}_transition_filtered.jsonl, k${K}_transition.jsonl"
  } >&2
  exit 1
fi

# Image folder: expected at <data_dir>/images.
IMAGE_FOLDER="${DATA_DIR}/images"
# Only warn: surface a likely misconfiguration early, but leave the final
# decision to the trainer.
if [[ ! -d "$IMAGE_FOLDER" ]]; then
  echo "Warning: image folder not found: $IMAGE_FOLDER" >&2
fi
echo "Using data file: $DATA_FILE"
echo "Using image folder: $IMAGE_FOLDER"
echo ""
# GRPO training hyperparameters (from paper Appendix A):
#   learning_rate 1e-6, temperature 0.9, num_generations 8,
#   num_train_epochs 4, max_prompt_length 1024, max_completion_length 256,
#   per_device_train_batch_size 2, gradient_accumulation_steps 8,
#   epsilon 0.2, beta 0.04
# NOTE(review): temperature, max_prompt_length and epsilon are not passed on
# this script's torchrun command line, so the trainer's own defaults apply to
# them. Confirm those defaults match the values above.
#
# Experiment name (passed to the trainer as --run_name); exported so child
# processes see it.
EXP_NAME="gui-shift-k${K}"
export EXP_NAME
# NOTE(review): nothing in this script writes to this directory; presumably
# the trainer or an outer wrapper does. Confirm.
log_dir="runs/${EXP_NAME}/log"
mkdir -p -- "$log_dir"
# Launch single-node GRPO training with torchrun.
# Note: This requires the VLM-R1 framework to be installed
# See: https://github.com/om-ai-lab/VLM-R1
#
# Trainer arguments live in an array rather than a backslash-continued line.
# This way a stray trailing `\` cannot splice the next command (e.g. an echo)
# into the trainer's argument list, and every value stays a single word.
train_args=(
  --model_name_or_path "$MODEL_NAME"
  --data_file_paths "$DATA_FILE"
  --image_folders "$IMAGE_FOLDER"
  --output_dir "$OUTPUT_DIR"
  --per_device_train_batch_size 2
  --gradient_accumulation_steps 8
  --gradient_checkpointing true
  --logging_steps 1
  --num_train_epochs 4
  --max_steps -1
  --bf16
  --attn_implementation flash_attention_2
  --run_name "$EXP_NAME"
  --save_steps 400
  --num_generations 8
  --max_completion_length 256
  --reward_funcs format accuracy
  --beta 0.04
  --learning_rate 1.0e-6
  --lr_scheduler_type cosine
  --warmup_ratio 0.1
  --report_to wandb
  --dataset-name not_used
  --freeze_vision_modules true
)
# The rendezvous port is hard-coded; concurrent runs on one host would need
# distinct ports.
torchrun --nproc_per_node="$NUM_GPUS" \
  --nnodes="1" \
  --node_rank="0" \
  --master_addr="127.0.0.1" \
  --master_port="12349" \
  src/training/gui_grpo_trainer.py \
  "${train_args[@]}"
echo ""
echo "Training complete! Model saved to: $OUTPUT_DIR"