luanns committed on
Commit
0ec5e1f
·
verified ·
1 Parent(s): 020c401

Upload scripts/train.sh

Browse files
Files changed (1) hide show
  1. scripts/train.sh +112 -0
scripts/train.sh ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash
# GUI-Shift GRPO Training Script
#
# Launches GUI-Shift GRPO training through torchrun on a single node.
#
# Usage:
#   scripts/train.sh [--model_name_or_path NAME] [--data_dir DIR]
#                    [--output_dir DIR] [--k K] [--num_gpus N]
#
# Every option takes exactly one value and overrides the matching default
# defined below.

# Abort on a failing command, on expansion of an unset variable, and when any
# stage of a pipeline fails.
set -euo pipefail

# Default values (each can be overridden by the option named on its line)
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"  # --model_name_or_path
DATA_DIR="./data/gui_transition/filtered"  # --data_dir
OUTPUT_DIR="./checkpoints/gui-shift"       # --output_dir
K=1                                        # --k
NUM_GPUS=8                                 # --num_gpus

# Parse arguments

# usage: print the command-line synopsis on stdout.
usage() {
  printf 'Usage: %s [OPTION VALUE]...\n' "${0##*/}"
  printf '  %-27s %s\n' \
    '--model_name_or_path NAME' 'model handed to the trainer (MODEL_NAME)' \
    '--data_dir DIR' 'directory holding the k<K> transition files (DATA_DIR)' \
    '--output_dir DIR' 'trainer output directory (OUTPUT_DIR)' \
    '--k K' 'K used to pick the data file and name the run (K)' \
    '--num_gpus N' 'processes started by torchrun on this node (NUM_GPUS)' \
    '-h, --help' 'show this help and exit'
}

# parse_args [OPTION VALUE]...
# Overrides MODEL_NAME, DATA_DIR, OUTPUT_DIR, K and NUM_GPUS from the given
# option/value pairs.
# Exits 1 with a message on stderr for an unknown option or an option that is
# missing its value; exits 0 after printing the usage for -h/--help.
parse_args() {
  local flag
  while [[ $# -gt 0 ]]; do
    flag=$1
    case "$flag" in
      -h | --help)
        usage
        exit 0
        ;;
      --model_name_or_path | --data_dir | --output_dir | --k | --num_gpus)
        # Check for the value up front: `shift 2` with a single argument left
        # fails without printing anything, so under `set -e` the script would
        # just stop with no explanation.
        if [[ $# -lt 2 ]]; then
          echo "Error: option $flag requires a value" >&2
          exit 1
        fi
        ;;
      *)
        echo "Unknown option: $flag" >&2
        exit 1
        ;;
    esac

    case "$flag" in
      --model_name_or_path) MODEL_NAME=$2 ;;
      --data_dir) DATA_DIR=$2 ;;
      --output_dir) OUTPUT_DIR=$2 ;;
      --k) K=$2 ;;
      --num_gpus) NUM_GPUS=$2 ;;
    esac
    shift 2
  done
}

parse_args "$@"

# print_config: write the effective settings to stdout, one per line,
# followed by a blank separator line.
# Reads MODEL_NAME, DATA_DIR, OUTPUT_DIR, K and NUM_GPUS.
print_config() {
  printf '%s\n' \
    "=== GUI-Shift GRPO Training ===" \
    "Model: $MODEL_NAME" \
    "Data: $DATA_DIR" \
    "Output: $OUTPUT_DIR" \
    "K value: $K" \
    "GPUs: $NUM_GPUS" \
    ""
}

print_config

# Find data file for the specified K.
# Candidates are tried in order: the filtered file is preferred and the
# unfiltered one is the fallback.
DATA_FILE=""
for data_candidate in \
  "${DATA_DIR}/k${K}_transition_filtered.jsonl" \
  "${DATA_DIR}/k${K}_transition.jsonl"; do
  if [[ -f "$data_candidate" ]]; then
    DATA_FILE="$data_candidate"
    break
  fi
done
unset data_candidate

if [[ -z "$DATA_FILE" ]]; then
  echo "Error: Could not find data file for k=$K in $DATA_DIR" >&2
  exit 1
fi

# Image folder (relative to trajectory data)
IMAGE_FOLDER="${DATA_DIR}/images"
if [[ ! -d "$IMAGE_FOLDER" ]]; then
  # Only a warning: whether the trainer needs this directory depends on how
  # the data file refers to its images, which this script cannot see.
  echo "Warning: image folder not found: $IMAGE_FOLDER" >&2
fi

echo "Using data file: $DATA_FILE"
echo "Using image folder: $IMAGE_FOLDER"
echo ""

# GRPO training hyperparameters (from paper Appendix A)
#   learning_rate: 1e-6, temperature: 0.9, num_generations: 8
#   num_train_epochs: 4, max_prompt_length: 1024, max_completion_length: 256
#   per_device_train_batch_size: 2, gradient_accumulation_steps: 8
#   epsilon: 0.2, beta: 0.04
# NOTE(review): temperature, max_prompt_length and epsilon are not passed on
# the command line below, so the trainer's own defaults apply to them --
# confirm those defaults match the values listed here.

export EXP_NAME="gui-shift-k${K}"
# NOTE(review): nothing in this script writes into this directory; presumably
# the trainer does -- confirm.
mkdir -p "runs/${EXP_NAME}/log"

# Note: This requires the VLM-R1 framework to be installed
# See: https://github.com/om-ai-lab/VLM-R1
if ! command -v torchrun >/dev/null 2>&1; then
  echo "Error: torchrun not found in PATH" >&2
  exit 1
fi

# Arguments are collected in arrays so that no line depends on a trailing
# backslash; a dangling continuation after the last flag would swallow
# whatever command happened to follow it.

# Options for torchrun itself: one node, NUM_GPUS processes on it.
launcher_args=(
  --nproc_per_node="$NUM_GPUS"
  --nnodes="1"
  --node_rank="0"
  --master_addr="127.0.0.1"
  --master_port="12349"
)

# Options forwarded to the training entry point.
trainer_args=(
  --model_name_or_path "$MODEL_NAME"
  --data_file_paths "$DATA_FILE"
  --image_folders "$IMAGE_FOLDER"
  --output_dir "$OUTPUT_DIR"
  --per_device_train_batch_size 2
  --gradient_accumulation_steps 8
  --gradient_checkpointing true
  --logging_steps 1
  --num_train_epochs 4
  --max_steps -1
  --bf16
  --attn_implementation flash_attention_2
  --run_name "$EXP_NAME"
  --save_steps 400
  --num_generations 8
  --max_completion_length 256
  --reward_funcs format accuracy
  --beta 0.04
  --learning_rate 1.0e-6
  --lr_scheduler_type cosine
  --warmup_ratio 0.1
  --report_to wandb
  # NOTE(review): hyphenated, unlike every other flag here -- confirm the
  # trainer's argument parser accepts this spelling.
  --dataset-name not_used
  --freeze_vision_modules true
)

# Capture the status explicitly so a failure is reported with a message and
# the launcher's own exit code, and the success line is never printed after a
# failed run.
train_status=0
torchrun "${launcher_args[@]}" \
  src/training/gui_grpo_trainer.py \
  "${trainer_args[@]}" || train_status=$?

if [[ "$train_status" -ne 0 ]]; then
  echo "Error: training failed (torchrun exit status $train_status)" >&2
  exit "$train_status"
fi

echo ""
echo "Training complete! Model saved to: $OUTPUT_DIR"