#!/usr/bin/env bash
#
# Launch GUI-Shift GRPO training with torchrun on a single node.
#
# Usage:
#   bash <script> [--model_name_or_path NAME] [--data_dir DIR]
#                 [--output_dir DIR] [--k N] [--num_gpus N]
#
# Expects <data_dir>/k<K>_transition_filtered.jsonl (falling back to
# <data_dir>/k<K>_transition.jsonl) and an images/ folder under <data_dir>.
# Paths such as src/training/gui_grpo_trainer.py are relative, so run this
# from the directory that contains src/.

# -u: fail on unset/mistyped variables instead of expanding them to "".
# pipefail: a pipeline fails if any stage fails, not just the last one.
set -euo pipefail

# Defaults; each can be overridden by the matching command-line option.
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
DATA_DIR="./data/gui_transition/filtered"
OUTPUT_DIR="./checkpoints/gui-shift"
K=1
NUM_GPUS=8

# Parse command-line overrides for the defaults. Every option takes exactly
# one value, passed as the following argument.
while [[ $# -gt 0 ]]; do
  case $1 in
    -h|--help)
      printf 'Usage: %s [--model_name_or_path NAME] [--data_dir DIR] [--output_dir DIR] [--k N] [--num_gpus N]\n' "${0##*/}"
      exit 0
      ;;
    --model_name_or_path|--data_dir|--output_dir|--k|--num_gpus)
      # Without this check a trailing option would assign "" and then
      # `shift 2` would fail (looping forever, or exiting with no message).
      if [[ $# -lt 2 ]]; then
        echo "Error: option $1 requires a value" >&2
        exit 1
      fi
      case $1 in
        --model_name_or_path) MODEL_NAME="$2" ;;
        --data_dir) DATA_DIR="$2" ;;
        --output_dir) OUTPUT_DIR="$2" ;;
        --k) K="$2" ;;
        --num_gpus) NUM_GPUS="$2" ;;
      esac
      shift 2
      ;;
    *)
      echo "Unknown option: $1" >&2
      exit 1
      ;;
  esac
done

# Show the effective configuration (after any command-line overrides),
# followed by an empty separator line.
printf '%s\n' \
  "=== GUI-Shift GRPO Training ===" \
  "Model: $MODEL_NAME" \
  "Data: $DATA_DIR" \
  "Output: $OUTPUT_DIR" \
  "K value: $K" \
  "GPUs: $NUM_GPUS" \
  ""

# Resolve the training data file for the requested K. The filtered variant
# is preferred; the unfiltered file is used only when no filtered one exists.
DATA_FILE=""
for candidate in \
  "${DATA_DIR}/k${K}_transition_filtered.jsonl" \
  "${DATA_DIR}/k${K}_transition.jsonl"; do
  if [[ -f "$candidate" ]]; then
    DATA_FILE="$candidate"
    break
  fi
done
if [[ -z "$DATA_FILE" ]]; then
  # Diagnostics go to stderr so they do not mix with captured stdout.
  echo "Error: Could not find data file for k=$K in $DATA_DIR" >&2
  echo "  (looked for k${K}_transition_filtered.jsonl and k${K}_transition.jsonl)" >&2
  exit 1
fi

# Images are taken from an images/ subfolder of the data directory.
IMAGE_FOLDER="${DATA_DIR}/images"
# Warn up front; otherwise a missing folder would presumably only surface once
# the trainer starts loading images. Kept non-fatal because the data file's
# image paths may not depend on it.
# NOTE(review): confirm whether a missing folder should be a hard error.
if [[ ! -d "$IMAGE_FOLDER" ]]; then
  echo "Warning: image folder not found: $IMAGE_FOLDER" >&2
fi

echo "Using data file: $DATA_FILE"
echo "Using image folder: $IMAGE_FOLDER"
echo ""

# Experiment name: passed to the trainer as --run_name and also exported,
# presumably so the trainer or its loggers can read it from the
# environment (TODO confirm).
export EXP_NAME="gui-shift-k${K}"
# Relative to the current directory, like the trainer path below.
mkdir -p "runs/${EXP_NAME}/log"

# Rendezvous port for the single-node launch. Overridable via the
# environment so two runs on the same host do not collide on a fixed port.
MASTER_PORT="${MASTER_PORT:-12349}"

# Trainer arguments, kept in an array so every value stays one word and
# there is no trailing line-continuation to break when lines are added.
train_args=(
  # Model and data
  --model_name_or_path "$MODEL_NAME"
  --data_file_paths "$DATA_FILE"
  --image_folders "$IMAGE_FOLDER"
  --output_dir "$OUTPUT_DIR"
  # Batching and schedule
  --per_device_train_batch_size 2
  --gradient_accumulation_steps 8
  --gradient_checkpointing true
  --logging_steps 1
  --num_train_epochs 4
  --max_steps -1
  --bf16
  --attn_implementation flash_attention_2
  --run_name "$EXP_NAME"
  --save_steps 400
  # GRPO sampling and rewards
  --num_generations 8
  --max_completion_length 256
  --reward_funcs format accuracy
  --beta 0.04
  # Optimizer
  --learning_rate 1.0e-6
  --lr_scheduler_type cosine
  --warmup_ratio 0.1
  --report_to wandb
  # Looks like a required placeholder; the data itself comes from
  # --data_file_paths above (TODO confirm).
  --dataset-name not_used
  --freeze_vision_modules true
)

torchrun \
  --nproc_per_node="$NUM_GPUS" \
  --nnodes="1" \
  --node_rank="0" \
  --master_addr="127.0.0.1" \
  --master_port="$MASTER_PORT" \
  src/training/gui_grpo_trainer.py \
  "${train_args[@]}"

echo ""
echo "Training complete! Model saved to: $OUTPUT_DIR"
