# origami_env/training/train_grpo.py
"""GRPO training script for origami RL.
Follows the 2048 OpenEnv + Unsloth pattern:
- LLM generates FOLD JSON crease patterns
- Two reward functions: valid_fold + shape_match
- GRPOTrainer from TRL handles the RL loop
Usage (local/Colab):
python -m training.train_grpo --task triangle --max_steps 600
Usage (Northflank — env vars set in Dockerfile.train):
python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS
"""
import argparse
import os

PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
that, when folded, produces the target shape described below.

Target: {description}
Paper size: {width} x {height}

Output a JSON object with these exact fields:
- vertices_coords: [[x, y], ...] — 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)
- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges
- edges_assignment: ["B"|"M"|"V", ...] — B=boundary, M=mountain fold, V=valley fold
- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative like -180, V: positive like 180, B: 0)

Rules:
- Boundary edges (B) must outline the paper rectangle
- At least one fold crease (M or V) must exist
- Mountain fold angles are negative (-180 to 0)
- Valley fold angles are positive (0 to 180)
- All vertex indices in edges must be valid (0 to N-1)

Output ONLY the JSON object wrapped in ```json ... ``` markers."""
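
# Illustrative sketch (not part of the prompt itself): for a 1 x 1 square, a
# minimal valid response with a single valley fold along the diagonal would be:
#   {
#     "vertices_coords": [[0, 0], [1, 0], [1, 1], [0, 1]],
#     "edges_vertices": [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2]],
#     "edges_assignment": ["B", "B", "B", "B", "V"],
#     "edges_foldAngle": [0, 0, 0, 0, 180]
#   }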


def build_prompt(task: dict) -> str:
    return PROMPT_TEMPLATE.format(
        description=task["description"],
        width=task["paper"]["width"],
        height=task["paper"]["height"],
    )


def main():
    parser = argparse.ArgumentParser(description="GRPO training for origami RL")
    parser.add_argument("--task", default="triangle", help="Task name")
    parser.add_argument("--max_steps", type=int, default=600)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--lr", type=float, default=2e-4)
    args = parser.parse_args()

    # --- These imports are heavy, only load when actually training ---
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    from origami_server.tasks import get_task
    from training.reward import shape_match, valid_fold

    # Try Unsloth first (CUDA), fall back to HF+PEFT
    try:
        from unsloth import FastLanguageModel

        USE_UNSLOTH = True
    except ImportError:
        USE_UNSLOTH = False

    task = get_task(args.task)
    prompt_text = build_prompt(task)

    # Build dataset (1000 copies of the same prompt, like 2048)
    dataset = Dataset.from_list(
        [
            {
                "prompt": [{"role": "user", "content": prompt_text}],
                "answer": 0,
            }
        ]
        * 1000
    )
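    # GRPO draws num_generations completions per prompt and computes advantages
    # within that group, so a pool of identical prompts is enough: all of the
    # learning signal comes from the reward functions, not from labels.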

    # Load model with LoRA
    if USE_UNSLOTH:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model,
            load_in_4bit=True,
            max_seq_length=2048,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=8,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=16,
            use_gradient_checkpointing="unsloth",
        )
    else:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ) if torch.cuda.is_available() else None
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            quantization_config=bnb_config,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        model = get_peft_model(model, LoraConfig(
            r=8, lora_alpha=16, task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
        ))
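
    # Many chat models ship without a dedicated pad token; reusing EOS keeps
    # batched tokenization and generation working.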
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Wrap shape_match to inject task_name
    def shape_match_reward(completions, **kwargs):
        return shape_match(completions, task_name=args.task, **kwargs)
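    # TRL calls each reward function with the batch of completions (plus any
    # extra dataset columns as kwargs) and expects one float score back per
    # completion; valid_fold and this wrapper both follow that contract.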

    # GRPO config
    training_args = GRPOConfig(
        temperature=1.0,
        learning_rate=args.lr,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=1024,
        max_completion_length=1024,
        max_steps=args.max_steps,
        save_steps=100,
        output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
    )
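    # NOTE: some TRL versions require the effective batch size
    # (num_processes * per_device_train_batch_size) to be divisible by
    # num_generations; if GRPOTrainer raises at startup, raise the batch size
    # or lower --num_generations accordingly.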

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[valid_fold, shape_match_reward],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # Save the LoRA adapter
    save_path = os.path.join(
        os.environ.get("OUTPUT_DIR", "outputs"),
        f"origami-{args.task}-lora-final",
    )
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    main()