Spaces:
Running
Running
Commit ·
c7c13a1
1
Parent(s): 91deb66
refactor(train): update train.py for openenv flow
Browse files
train.py
CHANGED
|
@@ -10,13 +10,10 @@ Usage:
|
|
| 10 |
python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
|
| 11 |
python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
|
| 12 |
python train.py --dry_run # test rewards without GPU
|
|
|
|
| 13 |
"""
|
| 14 |
import argparse
|
| 15 |
-
import json
|
| 16 |
-
import copy
|
| 17 |
import random
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
from typing import Optional
|
| 20 |
|
| 21 |
# VL (vision-language) model identifiers — use FastVisionModel for these
|
| 22 |
_VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
|
|
@@ -40,16 +37,21 @@ def parse_args():
|
|
| 40 |
parser.add_argument('--max_folds', type=int, default=8)
|
| 41 |
parser.add_argument('--output', default='origami-grpo')
|
| 42 |
parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
|
|
|
|
|
|
|
| 43 |
parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
|
| 44 |
return parser.parse_args()
|
| 45 |
|
| 46 |
|
| 47 |
-
def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
|
| 48 |
"""
|
| 49 |
Build a training dataset of prompts from available targets.
|
| 50 |
Each item: {'prompt': str, 'target_name': str}
|
| 51 |
Repeats each target multiple times to give enough training steps.
|
|
|
|
| 52 |
"""
|
|
|
|
|
|
|
| 53 |
all_names = env.available_targets()
|
| 54 |
|
| 55 |
# Filter by level; fall back to all targets if none match
|
|
@@ -62,8 +64,12 @@ def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
|
|
| 62 |
|
| 63 |
items = []
|
| 64 |
for name in level_names:
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
items.append({'prompt': prompt, 'target_name': name})
|
| 68 |
|
| 69 |
# Repeat each target 10x; ensure at least 50 examples
|
|
@@ -143,8 +149,9 @@ def main():
|
|
| 143 |
env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)
|
| 144 |
|
| 145 |
# Build dataset
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
print(f"Targets: {env.available_targets()}")
|
| 149 |
|
| 150 |
# Dry run: test reward function without loading model
|
|
|
|
| 10 |
python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
|
| 11 |
python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
|
| 12 |
python train.py --dry_run # test rewards without GPU
|
| 13 |
+
python train.py --no_semantic # use coordinate-based prompts instead of semantic
|
| 14 |
"""
|
| 15 |
import argparse
|
|
|
|
|
|
|
| 16 |
import random
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# VL (vision-language) model identifiers — use FastVisionModel for these
|
| 19 |
_VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
|
|
|
|
| 37 |
parser.add_argument('--max_folds', type=int, default=8)
|
| 38 |
parser.add_argument('--output', default='origami-grpo')
|
| 39 |
parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
|
| 40 |
+
parser.add_argument('--no_semantic', action='store_true',
|
| 41 |
+
help='Disable semantic prompts; use coordinate-based target (default: semantic)')
|
| 42 |
parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
|
| 43 |
return parser.parse_args()
|
| 44 |
|
| 45 |
|
| 46 |
+
def build_dataset(env, level: int = 1, max_folds: int = 8, semantic: bool = True) -> list[dict]:
|
| 47 |
"""
|
| 48 |
Build a training dataset of prompts from available targets.
|
| 49 |
Each item: {'prompt': str, 'target_name': str}
|
| 50 |
Repeats each target multiple times to give enough training steps.
|
| 51 |
+
When semantic=True, uses get_semantic_description for task descriptions.
|
| 52 |
"""
|
| 53 |
+
from env.prompts import get_semantic_description, code_as_policy_prompt
|
| 54 |
+
|
| 55 |
all_names = env.available_targets()
|
| 56 |
|
| 57 |
# Filter by level; fall back to all targets if none match
|
|
|
|
| 64 |
|
| 65 |
items = []
|
| 66 |
for name in level_names:
|
| 67 |
+
target = env._targets[name]
|
| 68 |
+
if semantic:
|
| 69 |
+
desc = get_semantic_description(name, target)
|
| 70 |
+
prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=desc)
|
| 71 |
+
else:
|
| 72 |
+
prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=None)
|
| 73 |
items.append({'prompt': prompt, 'target_name': name})
|
| 74 |
|
| 75 |
# Repeat each target 10x; ensure at least 50 examples
|
|
|
|
| 149 |
env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)
|
| 150 |
|
| 151 |
# Build dataset
|
| 152 |
+
use_semantic = not args.no_semantic
|
| 153 |
+
dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds, semantic=use_semantic)
|
| 154 |
+
print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets (semantic={'on' if use_semantic else 'off'})")
|
| 155 |
print(f"Targets: {env.available_targets()}")
|
| 156 |
|
| 157 |
# Dry run: test reward function without loading model
|