ianalin123 committed
Commit c7c13a1 · 1 Parent(s): 91deb66

refactor(train): update train.py for openenv flow

Files changed (1)
  1. train.py +16 -9
train.py CHANGED
@@ -10,13 +10,10 @@ Usage:
     python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
     python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
     python train.py --dry_run # test rewards without GPU
+    python train.py --no_semantic # use coordinate-based prompts instead of semantic
 """
 import argparse
-import json
-import copy
 import random
-from pathlib import Path
-from typing import Optional
 
 # VL (vision-language) model identifiers — use FastVisionModel for these
 _VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
@@ -40,16 +37,21 @@ def parse_args():
     parser.add_argument('--max_folds', type=int, default=8)
     parser.add_argument('--output', default='origami-grpo')
     parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
+    parser.add_argument('--no_semantic', action='store_true',
+                        help='Disable semantic prompts; use coordinate-based target (default: semantic)')
     parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
     return parser.parse_args()
 
 
-def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
+def build_dataset(env, level: int = 1, max_folds: int = 8, semantic: bool = True) -> list[dict]:
     """
     Build a training dataset of prompts from available targets.
     Each item: {'prompt': str, 'target_name': str}
     Repeats each target multiple times to give enough training steps.
+    When semantic=True, uses get_semantic_description for task descriptions.
     """
+    from env.prompts import get_semantic_description, code_as_policy_prompt
+
     all_names = env.available_targets()
 
     # Filter by level; fall back to all targets if none match
@@ -62,8 +64,12 @@ def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
 
     items = []
     for name in level_names:
-        obs = env.reset(target_name=name)
-        prompt = obs['prompt']
+        target = env._targets[name]
+        if semantic:
+            desc = get_semantic_description(name, target)
+            prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=desc)
+        else:
+            prompt = code_as_policy_prompt(target, max_folds=max_folds, semantic_description=None)
         items.append({'prompt': prompt, 'target_name': name})
 
     # Repeat each target 10x; ensure at least 50 examples
@@ -143,8 +149,9 @@ def main():
     env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)
 
     # Build dataset
-    dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds)
-    print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets")
+    use_semantic = not args.no_semantic
+    dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds, semantic=use_semantic)
+    print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets (semantic={'on' if use_semantic else 'off'})")
     print(f"Targets: {env.available_targets()}")
 
     # Dry run: test reward function without loading model
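
For orientation (not part of the commit): a minimal sketch of how the reworked build_dataset could be exercised without a GPU. The `from train import` line and the OrigamiEnvironment module path are assumptions; only env.prompts and the call sites shown in the patch are confirmed. The new flow builds prompts directly from env._targets via code_as_policy_prompt instead of resetting the environment once per target.

# Sketch only, not from this commit: compare semantic vs. coordinate-based prompts.
# Assumes train.py is importable; the OrigamiEnvironment import path below is a guess.
from train import build_dataset
from env.origami import OrigamiEnvironment  # hypothetical module path

env = OrigamiEnvironment(mode='code_as_policy', max_steps=8)

semantic_items = build_dataset(env, level=1, max_folds=8, semantic=True)   # default behaviour
coord_items = build_dataset(env, level=1, max_folds=8, semantic=False)     # --no_semantic behaviour

print(f"semantic: {len(semantic_items)} items; coordinate: {len(coord_items)} items")
print(semantic_items[0]['prompt'][:300])  # natural-language target description
print(coord_items[0]['prompt'][:300])     # coordinate-based target description

From the command line, the existing --dry_run flag combined with --no_semantic should exercise the same path inside train.py itself, since the dataset is built before the dry-run branch.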