"""
OrigamiRL — GRPO Training Script
Code-as-policy: model generates complete fold sequence, gets terminal reward.
Base model: SpatialThinker (Qwen2.5-VL-7B fine-tuned for spatial reasoning)
or any Unsloth-compatible model.
Usage:
python train.py
python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
python train.py --dry_run # test rewards without GPU
"""
import argparse
import random

# VL (vision-language) model identifiers — use FastVisionModel for these
_VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
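# e.g. _is_vl_model('OX-PIXL/SpatialThinker-Qwen2.5-VL-7B') -> True (matches 'SpatialThinker' and 'VL'),
#      _is_vl_model('unsloth/Llama-3.1-8B-Instruct') -> False (model names illustrative)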
def _is_vl_model(model_name: str) -> bool:
    return any(p in model_name for p in _VL_MODEL_PATTERNS)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='unsloth/Qwen2.5-VL-7B-Instruct',
                        help='Base model. Use unsloth/Qwen2.5-VL-7B-Instruct or '
                             'OX-PIXL/SpatialThinker-Qwen2.5-VL-7B for spatial reasoning')
    parser.add_argument('--max_seq_length', type=int, default=2048)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--grad_accum', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--n_generations', type=int, default=8)
    parser.add_argument('--max_folds', type=int, default=8)
    parser.add_argument('--output', default='origami-grpo')
    parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
    parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
    return parser.parse_args()

def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
    """
    Build a training dataset of prompts from available targets.
    Each item: {'prompt': str, 'target_name': str}
    Repeats each target multiple times to give enough training steps.
    """
    all_names = env.available_targets()
    # Filter by level; fall back to all targets if none match
    level_names = [
        name for name in all_names
        if env._targets[name].get('level', 1) == level
    ]
    if not level_names:
        level_names = all_names
    items = []
    for name in level_names:
        obs = env.reset(target_name=name)
        prompt = obs['prompt']
        items.append({'prompt': prompt, 'target_name': name})
    # Repeat each target 10x; ensure at least 50 examples
    repeat = max(10, (50 + len(items) - 1) // len(items))
    items = items * repeat
    random.shuffle(items)
    return items
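
# Worked example of the repeat math: with 3 level-1 targets,
# repeat = max(10, ceil(50 / 3)) = max(10, 17) = 17, giving 51 shuffled examples.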
def make_reward_fn(env_template, max_folds: int):
    """
    Returns a reward function compatible with trl GRPOTrainer.
    Signature: reward_fn(completions, prompts=None, **kwargs) -> list[float]
    For each completion:
        1. Clone the environment (fresh paper state)
        2. Reset to the target embedded in the prompt (use target_name from kwargs if available)
        3. Execute the completion as a fold sequence
        4. Return the total reward
    """
    def reward_fn(completions, prompts=None, **kwargs):
        rewards = []
        target_names = kwargs.get('target_names', [None] * len(completions))
        for completion, target_name in zip(completions, target_names):
            try:
                env = env_template.clone()
                env.reset(target_name=target_name)
                _, reward_dict, _, _ = env.step(completion)
                rewards.append(float(reward_dict['total']))
            except Exception:
                rewards.append(-0.1)
        return rewards
    return reward_fn
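
# Standalone usage sketch (target name and reward value illustrative):
#   fn = make_reward_fn(env, max_folds=8)
#   fn(['<folds>[...]</folds>'], target_names=['crane'])  # -> list like [0.42]; unparsable output -> [-0.1]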
def make_detailed_reward_fns(env_template, max_folds: int) -> list:
    """
    Returns a list of reward functions, one per reward component.
    Used for logging each component separately (e.g. to Trackio or W&B).
    Components: kawasaki, maekawa, blb, progress, economy, completion
    """
    components = ['kawasaki', 'maekawa', 'blb', 'progress', 'economy', 'completion']

    def make_component_fn(component: str):
        def component_fn(completions, prompts=None, **kwargs):
            rewards = []
            target_names = kwargs.get('target_names', [None] * len(completions))
            for completion, target_name in zip(completions, target_names):
                try:
                    env = env_template.clone()
                    env.reset(target_name=target_name)
                    _, reward_dict, _, _ = env.step(completion)
                    rewards.append(float(reward_dict.get(component, 0.0)))
                except Exception:
                    rewards.append(0.0)
            return rewards
        component_fn.__name__ = f'reward_{component}'
        return component_fn

    return [make_component_fn(c) for c in components]
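
# Note: these are not wired into the trainer below. To log each component, pass them
# alongside the main reward function (each would need the same target_name -> target_names
# wrapping as wrapped_reward_fn in main()); recent trl versions sum multiple reward
# functions, weighted via GRPOConfig.reward_weights (equal weights by default).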
def main():
    args = parse_args()
    # Import here to allow dry_run without GPU
    from env.environment import OrigamiEnvironment
    env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)

    # Build dataset
    dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds)
    print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets")
    print(f"Targets: {env.available_targets()}")

    # Dry run: test reward function without loading model
    if args.dry_run:
        reward_fn = make_reward_fn(env, args.max_folds)
        test_completions = [
            '<folds>[{"instruction": "Valley fold along horizontal center", "from": [0, 0.5], "to": [1, 0.5], "assignment": "V"}]</folds>',
            '<folds>[{"instruction": "Invalid fold", "from": [0.3, 0.3], "to": [0.7, 0.7], "assignment": "V"}]</folds>',
            'this is not valid JSON at all',
        ]
        target_names = [dataset_items[0]['target_name']] * 3
        rewards = reward_fn(test_completions, target_names=target_names)
        print(f"\nDry run rewards: {rewards}")
        print("Dry run complete — reward function works.")
        return

    # Load model via unsloth:
    # VL models (SpatialThinker, Qwen2.5-VL) use FastVisionModel;
    # text-only models use FastLanguageModel.
    is_vl = _is_vl_model(args.model)
    try:
        if is_vl:
            from unsloth import FastVisionModel as ModelLoader
            print(f"Loading VL model (vision-language): {args.model}")
        else:
            from unsloth import FastLanguageModel as ModelLoader
            print(f"Loading text model: {args.model}")
    except ImportError:
        print("ERROR: unsloth not installed. Run: pip install unsloth")
        print("Or run with --dry_run to test the reward function without a model.")
        return

    model, tokenizer = ModelLoader.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )
    model = ModelLoader.get_peft_model(
        model,
        r=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,
        lora_dropout=0,
        use_gradient_checkpointing="unsloth",
    )
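    # If LoRA on the vision tower is also wanted, recent unsloth releases expose
    # finetune_vision_layers / finetune_language_layers flags on
    # FastVisionModel.get_peft_model (left unset here; check your unsloth version).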

    # Convert dataset to HuggingFace Dataset format.
    # GRPOTrainer expects a 'prompt' column and optionally others;
    # we embed target_name in the dataset so the reward fn can use it.
    from datasets import Dataset
    hf_dataset = Dataset.from_list(dataset_items)

    # Build main reward function
    reward_fn = make_reward_fn(env, args.max_folds)

    from trl import GRPOConfig, GRPOTrainer
    config = GRPOConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_completion_length=512,
        num_generations=args.n_generations,
        temperature=1.0,
        logging_steps=1,
        report_to="trackio",
        run_name="origami-grpo",
    )
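    # Note: recent trl versions require the global train batch size to be evenly
    # divisible by num_generations; with the defaults here (batch_size=2,
    # n_generations=8), a single-GPU run may need a larger --batch_size or a
    # smaller --n_generations.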

    # GRPOTrainer passes all dataset columns as kwargs to reward_funcs.
    # The 'target_name' column arrives as a list (one per completion in the batch).
    def wrapped_reward_fn(completions, target_name=None, **kwargs):
        """Wrapper that extracts target_name from the batch columns."""
        target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)
        return reward_fn(completions, target_names=target_names)

    trainer = GRPOTrainer(
        model=model,
        args=config,  # GRPOTrainer takes its config via `args`, not `config`
        train_dataset=hf_dataset,
        reward_funcs=[wrapped_reward_fn],
        processing_class=tokenizer,  # recent trl uses `processing_class` instead of `tokenizer`
    )
print(f"\nStarting GRPO training...")
print(f" Model: {args.model}")
print(f" Level: {args.level} targets")
print(f" Epochs: {args.epochs}")
print(f" Generations per prompt: {args.n_generations}")
print(f" Output: {args.output}/")
trainer.train()
# Save
model.save_pretrained(args.output)
tokenizer.save_pretrained(args.output)
print(f"\nModel saved to {args.output}/")

if __name__ == '__main__':
    main()