Upload 5 files
Browse files- .gitattributes +1 -0
- grpo.sh +43 -0
- plugin.py +166 -0
- prompt.txt +45 -0
- test_am.jsonl +0 -0
- train_am.jsonl +3 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
ckpt/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
ckpt/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
train_am.jsonl filter=lfs diff=lfs merge=lfs -text
|
grpo.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# GRPO training launcher: Qwen2.5-VL-7B-Instruct on the maze-navigation
# dataset, using the custom reward functions registered in plugin.py.

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# SECURITY FIX: the WANDB API key was hard-coded here (a leaked credential).
# Require it from the caller's environment instead and fail fast if unset.
: "${WANDB_API_KEY:?set WANDB_API_KEY in the environment before running}"
export WANDB_API_KEY
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
export SWIFT_DISTRIBUTED_BACKEND=nccl
export GLOO_SOCKET_IFNAME=lo
export MAX_PIXELS=65536            # cap vision input size for the VL model
export NPROC_PER_NODE=8
export OMP_NUM_THREADS=1
export WANDB_BASE_URL=https://api.bandw.top

# Reward names must match the orms[...] registrations in plugin.py.
# FIX: added the missing space before the line continuation after "format".
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --reward_funcs external_r1v_acc external_r1v_format format \
    --reward_weights 1 0.1 0.1 \
    --torch_dtype bfloat16 \
    --dataset train_am.jsonl \
    --external_plugins plugin.py \
    --max_completion_length 2048 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 16 \
    --max_steps 100000 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 8192 \
    --output_dir GRPO_MAZE \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 8 \
    --temperature 1. \
    --repetition_penalty 1.1 \
    --system 'prompt.txt' \
    --deepspeed zero3 \
    --log_completions false \
    --train_type full \
    --report_to wandb
|
plugin.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import torch
|
| 3 |
+
torch.cuda.empty_cache()
|
| 4 |
+
from typing import List
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
|
| 7 |
+
from swift.plugin import ORM, orms
|
| 8 |
+
from swift.utils import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger()
|
| 11 |
+
"""
|
| 12 |
+
Step 1: Define a Reward Class
|
| 13 |
+
Implement your custom reward calculation logic within the __call__ method.
|
| 14 |
+
The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.
|
| 15 |
+
|
| 16 |
+
Step 2: Register the Reward Class in orms
|
| 17 |
+
For example:
|
| 18 |
+
python orms['external_math_acc'] = MathAccuracy
|
| 19 |
+
|
| 20 |
+
Step 3: Configure the Arguments
|
| 21 |
+
Use the following arguments when running the script:
|
| 22 |
+
bash --plugin /path/to/plugin.py --reward_funcs external_math_acc
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def count_xml(text) -> float:
    """Score the presence of exactly one <think> ... </think> tag pair.

    Args:
        text: Model response text.

    Returns:
        0.0, 0.5, or 1.0 — 0.5 for each tag that occurs exactly once.
    """
    score = 0.0
    for tag in ("<think>", "</think>"):
        if text.count(tag) == 1:
            score += 0.5
    return score
|
| 41 |
+
|
| 42 |
+
def extract_xml_answer(text: str) -> str:
    """Extract the answer portion that follows the closing </think> tag.

    Args:
        text: Response text expected to contain a </think> tag.

    Returns:
        The stripped text after the first </think>, or "" when the tag
        is absent.
    """
    try:
        return text.split("</think>")[1].strip()
    except IndexError:
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt
        # and masked real bugs; only the missing-tag case can raise here.
        return ""
|
| 57 |
+
|
| 58 |
+
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """Reward proper <think>/</think> tag usage in each completion.

    Args:
        completions: List of completion strings.  (Completions arrive as
            plain strings here, not chat-message dicts.)

    Returns:
        One score per completion, as computed by count_xml.
    """
    return [count_xml(text) for text in completions]
|
| 71 |
+
|
| 72 |
+
def int_reward_func(completions, **kwargs) -> List[float]:
    """Reward responses whose final answer is purely direction tokens.

    A completion earns 1.0 when, after whitespace removal, the text
    extracted after </think> is a non-empty concatenation of
    <|up|>/<|down|>/<|left|>/<|right|> tokens and nothing else; 0.0 otherwise.

    Args:
        completions: List of completion strings.

    Returns:
        List of 0.0 / 1.0 scores.
    """
    allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}
    token_pattern = re.compile(r'<\|(?:up|down|right|left)\|>')

    def _is_pure_token_sequence(answer: str) -> bool:
        compact = re.sub(r'\s+', '', answer)
        if not compact:
            return False
        tokens = token_pattern.findall(compact)
        # Every character of the answer must belong to a direction token.
        if ''.join(tokens) != compact:
            return False
        return all(tok in allowed_tokens for tok in tokens)

    answers = [extract_xml_answer(text) for text in completions]
    return [1.0 if _is_pure_token_sequence(ans) else 0.0 for ans in answers]
|
| 100 |
+
|
| 101 |
+
def count_turns(steps):
    """Parse a token string into moves and count direction changes.

    Args:
        steps: String of <|direction|> tokens, e.g. "<|up|><|left|>".

    Returns:
        Tuple (moves, turns): the list of direction names in order, and
        the number of positions where the direction differs from the
        previous move.
    """
    moves = re.findall(r"<\|(.*?)\|>", steps)
    turns = 0
    for prev, cur in zip(moves, moves[1:]):
        if cur != prev:
            turns += 1
    return moves, turns
|
| 105 |
+
|
| 106 |
+
def correctness_reward_func(completions, answer, **kwargs) -> List[float]:
    """Reward correctness of predicted move sequences against ground truth.

    Exact match earns ``len(moves) * 2 * (turns + 1)``.  Otherwise the
    reward is based on the longest common move prefix with the answer:
    ``k * 1 * (prefix_turns + 1)`` where k is the prefix length.

    Args:
        completions: Model completion strings.
        answer: Ground-truth answer strings (one per completion).

    Returns:
        List of reward scores.
    """
    rewards = []
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]
    logger.debug('-' * 20)
    logger.debug(f"\nAnswer:\n{answer[0]}")
    logger.debug(f"\nResponse:\n{responses[0]}")
    logger.debug(f"\nExtracted:\n{extracted_responses[0]}")
    for r, a in zip(extracted_responses, answer):
        r_steps, r_turns = count_turns(r)
        a_steps, a_turns = count_turns(a)
        if r == a:
            reward = len(r_steps) * 2 * (r_turns + 1)
        else:
            # Length of the common move prefix.
            k = 0
            for r_s, a_s in zip(r_steps, a_steps):
                if r_s != a_s:
                    break
                k += 1
            prefix = r_steps[:k]
            # FIX: `prefix` holds bare move names ("up", "down"), so the old
            # count_turns("".join(prefix)) produced a string with no <|...|>
            # markers and always reported 0 turns, silently disabling the
            # prefix-turn bonus.  Count direction changes directly instead.
            turns = sum(
                1 for i in range(1, len(prefix)) if prefix[i] != prefix[i - 1]
            )
            reward = k * 1 * (turns + 1)
        rewards.append(reward)
    return rewards
|
| 143 |
+
|
| 144 |
+
class MazeReward(ORM):
    """Outcome reward: correctness of the predicted maze path."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # Delegate to the module-level correctness reward.
        return correctness_reward_func(completions, solution)
|
| 150 |
+
|
| 151 |
+
class MazeFormat(ORM):
    """Format reward: answer must be a pure direction-token sequence."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # `solution` is accepted for interface parity but unused here.
        return int_reward_func(completions)
|
| 157 |
+
|
| 158 |
+
class Format(ORM):
    """Format reward: presence of a single <think>...</think> tag pair."""

    def __call__(self, completions, **kwargs) -> List[float]:
        return xmlcount_reward_func(completions)
|
| 163 |
+
|
| 164 |
+
# Register the reward classes under the names referenced by the launcher's
# --reward_funcs arguments (external_r1v_acc, external_r1v_format, format).
orms['external_r1v_acc'] = MazeReward
orms['external_r1v_format'] = MazeFormat
orms['format'] = Format
|
prompt.txt
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a navigation assistant to solve visual pathfinding tasks.
|
| 2 |
+
Your goal is to **infer a valid path** from a visually marked starting point (green cell labeled 'O') to a visually marked target (red cell labeled 'T') by analyzing the maze image.
|
| 3 |
+
|
| 4 |
+
*Rules*:
|
| 5 |
+
- The maze is composed of open paths and impassable black walls.
|
| 6 |
+
- Movement is only allowed through open paths, not through walls.
|
| 7 |
+
- You can move one step at a time in the four cardinal directions: <|up|>, <|down|>, <|left|>, <|right|>.
|
| 8 |
+
|
| 9 |
+
*Output Format*:
|
| 10 |
+
Think through each step inside <think> and </think> tags.
|
| 11 |
+
At each step:
|
| 12 |
+
1. Describe your current position based on visual layout and structure (e.g., "in a corridor", "facing a wall", "at a crossroad", "turning a corner").
|
| 13 |
+
2. Decide the next move, and explain your reasoning
|
| 14 |
+
3. Move and continue the path.
|
| 15 |
+
After your full reasoning, output only the final movement sequence using the allowed tokens: <|up|><|down|><|left|><|right|>
|
| 16 |
+
|
| 17 |
+
*Example Output 1*:
|
| 18 |
+
<think>
|
| 19 |
+
Step 1: I am at the green starting point 'O'. There is a wall below, but open paths to the left and right. Therefore, I move <|right|>.
|
| 20 |
+
Step 2: I am in a straight horizontal corridor. The path continues to the right. Therefore, I move <|right|>.
|
| 21 |
+
Step 3: I reach a corner where the path turns upward. The path is open upward. Therefore, I move <|up|>.
|
| 22 |
+
Step 4: I am in a vertical passage continuing upward. The path is open upward. Therefore, I move <|up|>.
|
| 23 |
+
Step 5: I am adjacent to the red target cell 'T' on my left. The left path is open. Therefore, I move <|left|> to reach the goal.
|
| 24 |
+
</think>
|
| 25 |
+
<|right|><|right|><|up|><|up|><|left|>
|
| 26 |
+
|
| 27 |
+
*Example Output 2*:
|
| 28 |
+
<think>
|
| 29 |
+
Step 1: I am at the green starting point 'O'. The path is open below. Therefore, I move <|down|>.
|
| 30 |
+
Step 2: I am in a dead end. The surrounding structure is closed on three sides, only the path above is open. Therefore, I move <|up|> to backtrack.
|
| 31 |
+
Step 3: I am back at the starting point. Now I try the left path. The structure is open to the left. Therefore, I move <|left|>.
|
| 32 |
+
Step 4: I am in a horizontal corridor. The left side remains open. Therefore, I move <|left|>.
|
| 33 |
+
Step 5: I can now see the red target directly ahead. Therefore, I move <|left|> to reach it.
|
| 34 |
+
</think>
|
| 35 |
+
<|down|><|up|><|left|><|left|><|left|>
|
| 36 |
+
|
| 37 |
+
*Example Output 3*:
|
| 38 |
+
<think>
|
| 39 |
+
Step 1: I am at the green starting point 'O'. The surrounding structure is open only to the left. Therefore, I move <|left|>.
|
| 40 |
+
Step 2: I am at the bottom of a vertical passage. The structure is open above. Therefore, I move <|up|>.
|
| 41 |
+
Step 3: I reach a horizontal corridor. The left is open. Therefore, I move <|left|>.
|
| 42 |
+
Step 4: I enter a vertical junction with an upward path. Therefore, I move <|up|>.
|
| 43 |
+
Step 5: I see an opening to the right toward the red goal. Therefore, I move <|right|>.
|
| 44 |
+
</think>
|
| 45 |
+
<|left|><|up|><|left|><|up|><|right|>
|
test_am.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_am.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37153a03df4af0f603266f21b227a9383fc9782a738ca1fb23234db0d625dd52
|
| 3 |
+
size 22634746
|