KOKKKOKK commited on
Commit
0ad9ab3
·
verified ·
1 Parent(s): 939e3a5

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. grpo.sh +43 -0
  3. plugin.py +166 -0
  4. prompt.txt +45 -0
  5. test_am.jsonl +0 -0
  6. train_am.jsonl +3 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ train_am.jsonl filter=lfs diff=lfs merge=lfs -text
grpo.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
2
+ export WANDB_API_KEY=47cad383a2e1b48737c9d92694709d6de14309a0
3
+ export MASTER_ADDR=127.0.0.1
4
+ export MASTER_PORT=29500
5
+ export SWIFT_DISTRIBUTED_BACKEND=nccl
6
+ export GLOO_SOCKET_IFNAME=lo
7
+ export MAX_PIXELS=65536
8
+ export NPROC_PER_NODE=8
9
+ export OMP_NUM_THREADS=1
10
+ export WANDB_BASE_URL=https://api.bandw.top
11
+
12
+ swift rlhf \
13
+ --rlhf_type grpo \
14
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
15
+ --reward_funcs external_r1v_acc external_r1v_format format\
16
+ --reward_weights 1 0.1 0.1 \
17
+ --torch_dtype bfloat16 \
18
+ --dataset train_am.jsonl \
19
+ --external_plugins plugin.py \
20
+ --max_completion_length 2048 \
21
+ --num_train_epochs 1 \
22
+ --per_device_train_batch_size 1 \
23
+ --per_device_eval_batch_size 1 \
24
+ --learning_rate 1e-6 \
25
+ --gradient_accumulation_steps 16 \
26
+ --max_steps 100000 \
27
+ --eval_steps 100 \
28
+ --save_steps 100 \
29
+ --save_total_limit 2 \
30
+ --logging_steps 5 \
31
+ --max_length 8192 \
32
+ --output_dir GRPO_MAZE \
33
+ --warmup_ratio 0.05 \
34
+ --dataloader_num_workers 4 \
35
+ --dataset_num_proc 4 \
36
+ --num_generations 8 \
37
+ --temperature 1. \
38
+ --repetition_penalty 1.1 \
39
+ --system 'prompt.txt' \
40
+ --deepspeed zero3 \
41
+ --log_completions false \
42
+ --train_type full \
43
+ --report_to wandb \
plugin.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ torch.cuda.empty_cache()
4
+ from typing import List
5
+ from copy import deepcopy
6
+
7
+ from swift.plugin import ORM, orms
8
+ from swift.utils import get_logger
9
+
10
+ logger = get_logger()
11
+ """
12
+ Step 1: Define a Reward Class
13
+ Implement your custom reward calculation logic within the __call__ method.
14
+ The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.
15
+
16
+ Step 2: Register the Reward Class in orms
17
+ For example:
18
+ python orms['external_math_acc'] = MathAccuracy
19
+
20
+ Step 3: Configure the Arguments
21
+ Use the following arguments when running the script:
22
+ bash --plugin /path/to/plugin.py --reward_funcs external_math_acc
23
+ """
24
+
25
+ def count_xml(text) -> float:
26
+ """
27
+ Count XML tags in response.
28
+
29
+ Args:
30
+ text: Input text
31
+
32
+ Returns:
33
+ Score based on XML tag presence
34
+ """
35
+ count = 0.0
36
+ if text.count("<think>") == 1:
37
+ count += 0.5
38
+ if text.count("</think>") == 1:
39
+ count += 0.5
40
+ return count
41
+
42
+ def extract_xml_answer(text: str) -> str:
43
+ """
44
+ Extract answer from XML-formatted text.
45
+
46
+ Args:
47
+ text: Input text with XML tags
48
+
49
+ Returns:
50
+ Extracted answer text
51
+ """
52
+ try:
53
+ answer = text.split("</think>")[1]
54
+ return answer.strip()
55
+ except:
56
+ return ""
57
+
58
+ def xmlcount_reward_func(completions, **kwargs) -> List[float]:
59
+ """
60
+ Reward function based on proper XML tag usage.
61
+
62
+ Args:
63
+ completions: Model completions
64
+
65
+ Returns:
66
+ List of reward scores
67
+ """
68
+ # contents = [completion[0]["content"] for completion in completions]
69
+ contents = completions
70
+ return [count_xml(c) for c in contents]
71
+
72
+ def int_reward_func(completions, **kwargs) -> List[float]:
73
+ """
74
+ Reward function that checks if responses contain valid direction tokens.
75
+
76
+ Args:
77
+ completions: Model completions
78
+
79
+ Returns:
80
+ List of reward scores
81
+ """
82
+ allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}
83
+
84
+ # responses = [completion[0]['content'] for completion in completions]
85
+ responses = completions
86
+ extracted_responses = [extract_xml_answer(r) for r in responses]
87
+
88
+ def is_valid_sequence(seq):
89
+
90
+ seq_no_whitespace = re.sub(r'\s+', '', seq)
91
+ if not seq_no_whitespace:
92
+ return False
93
+ found_tokens = re.findall(r'<\|(?:up|down|right|left)\|>', seq_no_whitespace)
94
+ reconstructed = ''.join(found_tokens)
95
+ if reconstructed != seq_no_whitespace:
96
+ return False
97
+ return all(token in allowed_tokens for token in found_tokens)
98
+
99
+ return [1.0 if is_valid_sequence(r) else 0.0 for r in extracted_responses]
100
+
101
+ def count_turns(steps):
102
+ moves = re.findall(r"<\|(.*?)\|>", steps)
103
+ turns = sum(1 for i in range(1, len(moves)) if moves[i] != moves[i - 1])
104
+ return moves, turns
105
+
106
+ def correctness_reward_func(completions, answer, **kwargs) -> List[float]:
107
+ """
108
+ Reward function that checks correctness of answers.
109
+
110
+ Args:
111
+ prompts: Input prompts
112
+ completions: Model completions
113
+ answer: Ground truth answers
114
+
115
+ Returns:
116
+ List of reward scores
117
+ """
118
+ rewards = []
119
+ responses = completions
120
+ extracted_responses = [extract_xml_answer(r) for r in responses]
121
+ logger.debug('-'*20)
122
+ # logger.debug(f"Question:\n{q}")
123
+ logger.debug(f"\nAnswer:\n{answer[0]}")
124
+ logger.debug(f"\nResponse:\n{responses[0]}")
125
+ logger.debug(f"\nExtracted:\n{extracted_responses[0]}")
126
+ for r, a in zip(extracted_responses, answer):
127
+ r_steps, r_turns = count_turns(r)
128
+ a_steps, a_turns = count_turns(a)
129
+ if r == a:
130
+ reward = len(r_steps) * 2 * (r_turns + 1)
131
+ else:
132
+ k = 0
133
+ for r_s, a_s in zip(r_steps, a_steps):
134
+ if r_s == a_s:
135
+ k += 1
136
+ else:
137
+ break
138
+ prefix = r_steps[:k]
139
+ turns = count_turns("".join(prefix))[1]
140
+ reward = k * 1 * (turns + 1)
141
+ rewards.append(reward)
142
+ return rewards
143
+
144
+ class MazeReward(ORM):
145
+
146
+ def __call__(self, completions, solution, **kwargs) -> List[float]:
147
+ # print(completions)
148
+ rewards = correctness_reward_func(completions, solution)
149
+ return rewards
150
+
151
+ class MazeFormat(ORM):
152
+
153
+ def __call__(self, completions, solution, **kwargs) -> List[float]:
154
+ # print(completions)
155
+ rewards = int_reward_func(completions)
156
+ return rewards
157
+
158
+ class Format(ORM):
159
+
160
+ def __call__(self, completions, **kwargs) -> List[float]:
161
+ rewards = xmlcount_reward_func(completions)
162
+ return rewards
163
+
164
+ orms['external_r1v_acc'] = MazeReward
165
+ orms['external_r1v_format'] = MazeFormat
166
+ orms['format'] = Format
prompt.txt ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a navigation assistant to solve visual pathfinding tasks.
2
+ Your goal is to **infer a valid path** from a visually marked starting point (green cell labeled 'O') to a visually marked target (red cell labeled 'T') by analyzing the maze image.
3
+
4
+ *Rules*:
5
+ - The maze is composed of open paths and impassable black walls.
6
+ - Movement is only allowed through open paths, not through walls.
7
+ - You can move one step at a time in the four cardinal directions: <|up|>, <|down|>, <|left|>, <|right|>.
8
+
9
+ *Output Format*:
10
+ Think through each step inside <think> and </think> tags.
11
+ At each step:
12
+ 1. Describe your current position based on visual layout and structure (e.g., "in a corridor", "facing a wall", "at a crossroad", "turning a corner").
13
+ 2. Decide the next move, and explain your reasoning
14
+ 3. Move and continue the path.
15
+ After your full reasoning, output only the final movement sequence using the allowed tokens:<|up|><|down|><|left|><|right|>
16
+
17
+ *Example Output 1*:
18
+ <think>
19
+ Step 1: I am at the green starting point 'O'. There is a wall below, but open paths to the left and right. Therefore, I move <|right|>.
20
+ Step 2: I am in a straight horizontal corridor. The path continues to the right. Therefore, I move <|right|>.
21
+ Step 3: I reach a corner where the path turns upward. The path is open upward. Therefore, I move <|up|>.
22
+ Step 4: I am in a vertical passage continuing upward. The path is open upward. Therefore, I move <|up|>.
23
+ Step 5: I am adjacent to the red target cell 'T' on my left. The left path is open. Therefore, I move <|left|> to reach the goal.
24
+ </think>
25
+ <|right|><|right|><|up|><|up|><|left|>
26
+
27
+ *Example Output 2*:
28
+ <think>
29
+ Step 1: I am at the green starting point 'O'. The path is open below. Therefore, I move <|down|>.
30
+ Step 2: I am in a dead end. The surrounding structure is closed on three sides, only the path above is open. Therefore, I move <|up|> to backtrack.
31
+ Step 3: I am back at the starting point. Now I try the left path. The structure is open to the left. Therefore, I move <|left|>.
32
+ Step 4: I am in a horizontal corridor. The left side remains open. Therefore, I move <|left|>.
33
+ Step 5: I can now see the red target directly ahead. Therefore, I move <|left|> to reach it.
34
+ </think>
35
+ <|down|><|up|><|left|><|left|><|left|>
36
+
37
+ *Example Output 3*:
38
+ <think>
39
+ Step 1: I am at the green starting point 'O'. The surrounding structure is open only to the left. Therefore, I move <|left|>.
40
+ Step 2: I am at the top of a vertical passage. The structure is open above. Therefore, I move <|up|>.
41
+ Step 3: I reach a horizontal corridor. The left is open. Therefore, I move <|left|>.
42
+ Step 4: I enter a vertical junction with a upward path. Therefore, I move <|up|>.
43
+ Step 5: I see an opening to the right toward the red goal. Therefore, I move <|right|>.
44
+ </think>
45
+ <|left|><|up|><|left|><|up|><|right|>
test_am.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
train_am.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37153a03df4af0f603266f21b227a9383fc9782a738ca1fb23234db0d625dd52
3
+ size 22634746