"""
OrigamiRL — GRPO Training Script
Code-as-policy: model generates complete fold sequence, gets terminal reward.

Base model: SpatialThinker (Qwen2.5-VL-7B fine-tuned for spatial reasoning)
or any Unsloth-compatible model.

Usage:
    python train.py
    python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
    python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
    python train.py --dry_run  # test rewards without GPU
    python train.py --no_semantic  # use coordinate-based prompts instead of semantic
"""
import argparse
import random

# VL (vision-language) model identifiers — use FastVisionModel for these
_VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']


def _is_vl_model(model_name: str) -> bool:
    return any(p in model_name for p in _VL_MODEL_PATTERNS)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='unsloth/Qwen2.5-VL-7B-Instruct',
                        help='Base model. Use unsloth/Qwen2.5-VL-7B-Instruct or '
                             'OX-PIXL/SpatialThinker-Qwen2.5-VL-7B for spatial reasoning')
    parser.add_argument('--max_seq_length', type=int, default=2048)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--grad_accum', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--n_generations', type=int, default=8)
    parser.add_argument('--max_folds', type=int, default=8)
    parser.add_argument('--output', default='origami-grpo')
    parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
    parser.add_argument('--no_semantic', action='store_true',
                        help='Disable semantic prompts; use coordinate-based target (default: semantic)')
    parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
    return parser.parse_args()


def build_dataset(env, level: int = 1, max_folds: int = 8, semantic: bool = True) -> list[dict]:
    """
    Build a training dataset of prompts from available targets.
    Each item: {'prompt': str, 'target_name': str}
    Repeats each target multiple times to give enough training steps.
    When semantic=True, uses get_semantic_description for task descriptions.
    """
    from env.prompts import get_semantic_description, code_as_policy_prompt

    all_names = env.available_targets()

    # Filter by level; fall back to all targets if none match
    level_names = [
        name for name in all_names
        if env._targets[name].get('level', 1) == level
    ]
    if not level_names:
        level_names = all_names

    items = []
    for name in level_names:
        target = env._targets[name]
        if semantic:
            desc = get_semantic_description(name, target)
            prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=desc)
        else:
            prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=None)
        items.append({'prompt': prompt, 'target_name': name})

    # Repeat each target at least 10x, and enough times to reach >= 50 examples
    repeat = max(10, (50 + len(items) - 1) // len(items))
    items = items * repeat

    random.shuffle(items)
    return items


def make_reward_fn(env_template, max_folds: int):
    """
    Returns a reward function compatible with trl GRPOTrainer.

    Signature: reward_fn(completions, prompts=None, **kwargs) -> list[float]

    For each completion:
    1. Clone the environment (fresh paper state)
    2. Reset to the target embedded in the prompt (use target_name from kwargs if available)
    3. Execute the completion as a fold sequence
    4. Return the total reward
    """
    def reward_fn(completions, prompts=None, **kwargs):
        rewards = []
        target_names = kwargs.get('target_names', [None] * len(completions))

        for completion, target_name in zip(completions, target_names):
            try:
                env = env_template.clone()
                env.reset(target_name=target_name)
                _, reward_dict, _, _ = env.step(completion)
                rewards.append(float(reward_dict['total']))
            except Exception:
                rewards.append(-0.1)

        return rewards

    return reward_fn
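
# Usage sketch for make_reward_fn ('crane' is a hypothetical target name and the
# reward value is illustrative, not a real output):
#   env = OrigamiEnvironment(mode='code_as_policy', max_steps=8)
#   reward_fn = make_reward_fn(env, max_folds=8)
#   reward_fn(['<folds>[...]</folds>'], target_names=['crane'])  # -> e.g. [0.35]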


def make_detailed_reward_fns(env_template, max_folds: int) -> list:
    """
    Returns a list of reward functions, one per reward component.
    Used to log each reward component separately (e.g., in W&B or trackio).
    Components: kawasaki, maekawa, blb, progress, economy, completion
    """
    components = ['kawasaki', 'maekawa', 'blb', 'progress', 'economy', 'completion']

    def make_component_fn(component: str):
        def component_fn(completions, prompts=None, **kwargs):
            rewards = []
            target_names = kwargs.get('target_names', [None] * len(completions))

            for completion, target_name in zip(completions, target_names):
                try:
                    env = env_template.clone()
                    env.reset(target_name=target_name)
                    _, reward_dict, _, _ = env.step(completion)
                    rewards.append(float(reward_dict.get(component, 0.0)))
                except Exception:
                    rewards.append(0.0)

            return rewards

        component_fn.__name__ = f'reward_{component}'
        return component_fn

    return [make_component_fn(c) for c in components]
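
# Sketch: wiring the per-component functions into GRPOTrainer purely for logging.
# Untested; it mirrors the wrapped_reward_fn adapter in main(), since GRPOTrainer
# passes the dataset's 'target_name' column rather than a 'target_names' kwarg.
# TRL sums rewards across reward_funcs, so if these should not affect optimization,
# zero out their weights via GRPOConfig(reward_weights=...):
#   def adapt(fn):
#       def wrapped(completions, target_name=None, **kwargs):
#           names = target_name if isinstance(target_name, list) else [target_name] * len(completions)
#           return fn(completions, target_names=names)
#       wrapped.__name__ = fn.__name__
#       return wrapped
#   component_fns = [adapt(f) for f in make_detailed_reward_fns(env, max_folds)]
#   # then: reward_funcs=[wrapped_reward_fn] + component_fns,
#   #       reward_weights=[1.0] + [0.0] * len(component_fns)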


def main():
    args = parse_args()

    # Import here to allow dry_run without GPU
    from env.environment import OrigamiEnvironment

    env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)

    # Build dataset
    use_semantic = not args.no_semantic
    dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds, semantic=use_semantic)
    print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets (semantic={'on' if use_semantic else 'off'})")
    print(f"Targets: {env.available_targets()}")

    # Dry run: test reward function without loading model
    if args.dry_run:
        reward_fn = make_reward_fn(env, args.max_folds)
        test_completions = [
            '<folds>[{"instruction": "Valley fold along horizontal center", "from": [0, 0.5], "to": [1, 0.5], "assignment": "V"}]</folds>',
            '<folds>[{"instruction": "Invalid fold", "from": [0.3, 0.3], "to": [0.7, 0.7], "assignment": "V"}]</folds>',
            'this is not valid JSON at all',
        ]
        target_names = [dataset_items[0]['target_name']] * 3
        rewards = reward_fn(test_completions, target_names=target_names)
        print(f"\nDry run rewards: {rewards}")
        print("Dry run complete — reward function works.")
        return

    # Load model via unsloth
    # VL models (SpatialThinker, Qwen2.5-VL) use FastVisionModel
    # Text-only models use FastLanguageModel
    is_vl = _is_vl_model(args.model)

    try:
        if is_vl:
            from unsloth import FastVisionModel as ModelLoader
            print(f"Loading VL model (vision-language): {args.model}")
        else:
            from unsloth import FastLanguageModel as ModelLoader
            print(f"Loading text model: {args.model}")
    except ImportError:
        print("ERROR: unsloth not installed. Run: pip install unsloth")
        print("Or run with --dry_run to test the reward function without a model.")
        return

    model, tokenizer = ModelLoader.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )

    model = ModelLoader.get_peft_model(
        model,
        r=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,
        lora_dropout=0,
        use_gradient_checkpointing="unsloth",
    )

    # Convert dataset to HuggingFace Dataset format
    from datasets import Dataset

    # GRPOTrainer expects 'prompt' column and optionally others.
    # We embed target_name in the dataset so the reward fn can use it.
    hf_dataset = Dataset.from_list(dataset_items)

    # Build main reward function
    reward_fn = make_reward_fn(env, args.max_folds)

    from trl import GRPOConfig, GRPOTrainer

    config = GRPOConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_completion_length=512,
        num_generations=args.n_generations,
        temperature=1.0,
        logging_steps=1,
        report_to="trackio",
        run_name="origami-grpo",
    )
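
    # Note: TRL's GRPO requires the generation batch to be divisible by
    # num_generations (in recent TRL versions this is per-device batch size x
    # gradient accumulation x number of processes). The defaults here give
    # 2 x 4 = 8 on one GPU, matching num_generations=8.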

    # GRPOTrainer passes all dataset columns as kwargs to reward_funcs.
    # The 'target_name' column arrives as a list (one per completion in the batch).
    def wrapped_reward_fn(completions, target_name=None, **kwargs):
        """Wrapper that extracts target_name from batch columns."""
        target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)
        return reward_fn(completions, target_names=target_names)

    trainer = GRPOTrainer(
        model=model,
        args=config,  # GRPOTrainer takes the GRPOConfig as `args`, not `config`
        train_dataset=hf_dataset,
        reward_funcs=[wrapped_reward_fn],
        processing_class=tokenizer,  # and the tokenizer as `processing_class`
    )

    print(f"\nStarting GRPO training...")
    print(f"  Model: {args.model}")
    print(f"  Level: {args.level} targets")
    print(f"  Epochs: {args.epochs}")
    print(f"  Generations per prompt: {args.n_generations}")
    print(f"  Output: {args.output}/")

    trainer.train()

    # Save
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    print(f"\nModel saved to {args.output}/")


if __name__ == '__main__':
    main()