Spaces:
Running
Running
Add GRPO training notebook + Dockerfile for cloud training
#1
by
sissississi - opened
- training/Dockerfile.train +37 -0
- training/train_grpo.py +66 -21
- training/train_origami.ipynb +245 -0
training/Dockerfile.train
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dockerfile.train — GPU training image for the origami GRPO run.
# Defect fixed: the original added file accidentally duplicated its tail —
# a stray invalid line `="triangle"` followed by a repeated ENV/CMD block,
# which breaks `docker build` (unknown instruction). Duplicate removed.
FROM python:3.11-slim

RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install PyTorch with CUDA support + training stack
RUN pip install --no-cache-dir \
        torch --index-url https://download.pytorch.org/whl/cu124 && \
    pip install --no-cache-dir \
        "trl>=0.7" \
        "datasets>=2.14" \
        "transformers>=4.40" \
        "accelerate>=0.30" \
        "peft>=0.10" \
        "bitsandbytes>=0.43" \
        numpy scipy pydantic

# Copy the full repo
COPY . /app

# Default: run training script
# Override TASK, MODEL, MAX_STEPS etc. via env vars on Northflank
ENV TASK="triangle"
ENV MODEL="Qwen/Qwen2.5-3B-Instruct"
ENV MAX_STEPS="600"
ENV NUM_GENERATIONS="4"
ENV LR="2e-4"

# sh -c is required so the $VAR references are expanded at container start
CMD ["sh", "-c", "python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS --num_generations $NUM_GENERATIONS --lr $LR"]
|
training/train_grpo.py
CHANGED
|
@@ -5,11 +5,15 @@ Follows the 2048 OpenEnv + Unsloth pattern:
|
|
| 5 |
- Two reward functions: valid_fold + shape_match
|
| 6 |
- GRPOTrainer from TRL handles the RL loop
|
| 7 |
|
| 8 |
-
Usage (Colab):
|
| 9 |
-
python -m
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import argparse
|
|
|
|
| 13 |
|
| 14 |
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
|
| 15 |
that, when folded, produces the target shape described below.
|
|
@@ -46,18 +50,24 @@ def main():
|
|
| 46 |
parser.add_argument("--task", default="triangle", help="Task name")
|
| 47 |
parser.add_argument("--max_steps", type=int, default=600)
|
| 48 |
parser.add_argument("--num_generations", type=int, default=4)
|
| 49 |
-
parser.add_argument("--model", default="
|
| 50 |
parser.add_argument("--lr", type=float, default=2e-4)
|
| 51 |
args = parser.parse_args()
|
| 52 |
|
| 53 |
# --- These imports are heavy, only load when actually training ---
|
| 54 |
from datasets import Dataset
|
| 55 |
from trl import GRPOConfig, GRPOTrainer
|
| 56 |
-
from unsloth import FastLanguageModel
|
| 57 |
|
| 58 |
from origami_server.tasks import get_task
|
| 59 |
from training.reward import shape_match, valid_fold
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
task = get_task(args.task)
|
| 62 |
prompt_text = build_prompt(task)
|
| 63 |
|
|
@@ -73,22 +83,48 @@ def main():
|
|
| 73 |
)
|
| 74 |
|
| 75 |
# Load model with LoRA
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Wrap shape_match to inject task_name
|
| 94 |
def shape_match_reward(completions, **kwargs):
|
|
@@ -110,7 +146,7 @@ def main():
|
|
| 110 |
max_completion_length=1024,
|
| 111 |
max_steps=args.max_steps,
|
| 112 |
save_steps=100,
|
| 113 |
-
output_dir="outputs",
|
| 114 |
)
|
| 115 |
|
| 116 |
trainer = GRPOTrainer(
|
|
@@ -123,6 +159,15 @@ def main():
|
|
| 123 |
|
| 124 |
trainer.train()
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
if __name__ == "__main__":
|
| 128 |
main()
|
|
|
|
| 5 |
- Two reward functions: valid_fold + shape_match
|
| 6 |
- GRPOTrainer from TRL handles the RL loop
|
| 7 |
|
| 8 |
+
Usage (local/Colab):
|
| 9 |
+
python -m training.train_grpo --task triangle --max_steps 600
|
| 10 |
+
|
| 11 |
+
Usage (Northflank — env vars set in Dockerfile.train):
|
| 12 |
+
python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS
|
| 13 |
"""
|
| 14 |
|
| 15 |
import argparse
|
| 16 |
+
import os
|
| 17 |
|
| 18 |
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
|
| 19 |
that, when folded, produces the target shape described below.
|
|
|
|
| 50 |
parser.add_argument("--task", default="triangle", help="Task name")
|
| 51 |
parser.add_argument("--max_steps", type=int, default=600)
|
| 52 |
parser.add_argument("--num_generations", type=int, default=4)
|
| 53 |
+
parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
|
| 54 |
parser.add_argument("--lr", type=float, default=2e-4)
|
| 55 |
args = parser.parse_args()
|
| 56 |
|
| 57 |
# --- These imports are heavy, only load when actually training ---
|
| 58 |
from datasets import Dataset
|
| 59 |
from trl import GRPOConfig, GRPOTrainer
|
|
|
|
| 60 |
|
| 61 |
from origami_server.tasks import get_task
|
| 62 |
from training.reward import shape_match, valid_fold
|
| 63 |
|
| 64 |
+
# Try Unsloth first (CUDA), fall back to HF+PEFT
|
| 65 |
+
try:
|
| 66 |
+
from unsloth import FastLanguageModel
|
| 67 |
+
USE_UNSLOTH = True
|
| 68 |
+
except ImportError:
|
| 69 |
+
USE_UNSLOTH = False
|
| 70 |
+
|
| 71 |
task = get_task(args.task)
|
| 72 |
prompt_text = build_prompt(task)
|
| 73 |
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
# Load model with LoRA
|
| 86 |
+
if USE_UNSLOTH:
|
| 87 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 88 |
+
model_name=args.model,
|
| 89 |
+
load_in_4bit=True,
|
| 90 |
+
max_seq_length=2048,
|
| 91 |
+
)
|
| 92 |
+
model = FastLanguageModel.get_peft_model(
|
| 93 |
+
model,
|
| 94 |
+
r=8,
|
| 95 |
+
target_modules=[
|
| 96 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 97 |
+
"gate_proj", "up_proj", "down_proj",
|
| 98 |
+
],
|
| 99 |
+
lora_alpha=16,
|
| 100 |
+
use_gradient_checkpointing="unsloth",
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
import torch
|
| 104 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 105 |
+
from peft import LoraConfig, get_peft_model
|
| 106 |
+
|
| 107 |
+
bnb_config = BitsAndBytesConfig(
|
| 108 |
+
load_in_4bit=True,
|
| 109 |
+
bnb_4bit_quant_type="nf4",
|
| 110 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 111 |
+
) if torch.cuda.is_available() else None
|
| 112 |
+
|
| 113 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
| 114 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 115 |
+
args.model,
|
| 116 |
+
quantization_config=bnb_config,
|
| 117 |
+
device_map="auto" if torch.cuda.is_available() else "cpu",
|
| 118 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 119 |
+
)
|
| 120 |
+
model = get_peft_model(model, LoraConfig(
|
| 121 |
+
r=8, lora_alpha=16, task_type="CAUSAL_LM",
|
| 122 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 123 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 124 |
+
))
|
| 125 |
+
|
| 126 |
+
if tokenizer.pad_token is None:
|
| 127 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 128 |
|
| 129 |
# Wrap shape_match to inject task_name
|
| 130 |
def shape_match_reward(completions, **kwargs):
|
|
|
|
| 146 |
max_completion_length=1024,
|
| 147 |
max_steps=args.max_steps,
|
| 148 |
save_steps=100,
|
| 149 |
+
output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
|
| 150 |
)
|
| 151 |
|
| 152 |
trainer = GRPOTrainer(
|
|
|
|
| 159 |
|
| 160 |
trainer.train()
|
| 161 |
|
| 162 |
+
# Save the LoRA adapter
|
| 163 |
+
save_path = os.path.join(
|
| 164 |
+
os.environ.get("OUTPUT_DIR", "outputs"),
|
| 165 |
+
f"origami-{args.task}-lora-final",
|
| 166 |
+
)
|
| 167 |
+
model.save_pretrained(save_path)
|
| 168 |
+
tokenizer.save_pretrained(save_path)
|
| 169 |
+
print(f"Model saved to {save_path}")
|
| 170 |
+
|
| 171 |
|
| 172 |
if __name__ == "__main__":
|
| 173 |
main()
|
training/train_origami.ipynb
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "p8uwc5bkc4n",
|
| 6 |
+
"source": "# Origami RL — GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically\n4. Reward = shape similarity (chamfer distance) to target × 20\n\n**Reward functions:**\n- `valid_fold`: +1.0 valid FOLD JSON, −0.5 parseable but invalid, −2.0 unparseable\n- `shape_match`: similarity × 20.0 (0–20), −1.0 sim fails, −2.0 invalid FOLD\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
|
| 7 |
+
"metadata": {}
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"cell_type": "markdown",
|
| 11 |
+
"id": "xxp4krkl6w",
|
| 12 |
+
"source": "## 1. Install Dependencies",
|
| 13 |
+
"metadata": {}
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"id": "ulhu8a5p5ti",
|
| 18 |
+
"source": "# Run this cell once to install all dependencies\n# For Colab: unsloth has a specific install process\nimport sys\nIN_COLAB = \"google.colab\" in sys.modules\n\nif IN_COLAB:\n # Unsloth's recommended Colab install\n !pip install --no-deps \"unsloth[colab-new]\"\n !pip install --no-deps trl datasets peft accelerate bitsandbytes xformers\nelse:\n !pip install -q \"trl>=0.7\" \"datasets>=2.14\" unsloth torch transformers accelerate bitsandbytes\n\n# Core origami env deps (numpy, scipy, pydantic)\n!pip install -q numpy scipy pydantic",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"outputs": []
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "markdown",
|
| 25 |
+
"id": "qcetkmcq1hf",
|
| 26 |
+
"source": "## 2. Setup Python Path & Imports",
|
| 27 |
+
"metadata": {}
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"id": "3hr273dhqiv",
|
| 32 |
+
"source": "import os\nimport sys\nimport json\n\n# Add the repo root to Python path so origami_server and training modules are importable\nREPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), \"..\"))\nif REPO_ROOT not in sys.path:\n sys.path.insert(0, REPO_ROOT)\n\nprint(f\"Repo root: {REPO_ROOT}\")\nprint(f\"Python: {sys.version}\")",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"outputs": []
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"id": "bnm2w57r3lc",
|
| 40 |
+
"source": "import numpy as np\n\n# Verify origami env modules load correctly\nfrom origami_server.tasks import TASKS, get_task, list_tasks\nfrom origami_server.engine.fold_parser import validate_fold, parse_fold\nfrom origami_server.engine.simulate import simulate\nfrom origami_server.engine.shape_match import compute_shape_match\nfrom training.reward import valid_fold, shape_match, extract_fold_json\n\nprint(f\"Available tasks: {list_tasks()}\")\nprint(\"All origami modules loaded successfully.\")",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"outputs": []
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "markdown",
|
| 47 |
+
"id": "lcaus7mtuj",
|
| 48 |
+
"source": "## 3. Explore the Environment\n\nSanity-check the simulator and reward functions before training.",
|
| 49 |
+
"metadata": {}
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"id": "hlqp4y30m87",
|
| 54 |
+
"source": "# Print all tasks with their details\nfor name, task in TASKS.items():\n print(f\"\\n{'='*50}\")\n print(f\"Task: {task['name']}\")\n print(f\"Description: {task['description']}\")\n print(f\"Difficulty: {task['difficulty']}\")\n print(f\"Paper: {task['paper']}\")\n fold = task[\"target_fold\"]\n n_verts = len(fold[\"vertices_coords\"])\n n_edges = len(fold[\"edges_vertices\"])\n n_folds = sum(1 for a in fold[\"edges_assignment\"] if a in (\"M\", \"V\"))\n print(f\"Vertices: {n_verts}, Edges: {n_edges}, Fold creases: {n_folds}\")",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"outputs": []
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"id": "dwqqus8mhlj",
|
| 62 |
+
"source": "# Test the simulator on each task\nfor name in list_tasks():\n task = get_task(name)\n target_fold = task[\"target_fold\"]\n \n # Simulate flat (0%), half (50%), and fully folded (100%)\n r_flat = simulate(target_fold, crease_percent=0.0)\n r_half = simulate(target_fold, crease_percent=0.5)\n r_full = simulate(target_fold, crease_percent=1.0)\n \n z_half = r_half.positions[:, 2].max() - r_half.positions[:, 2].min()\n \n # Shape match: target vs itself should be 1.0\n self_sim = compute_shape_match(r_full.positions, r_full.positions)\n \n print(f\"{name:15s} | converged={r_full.converged} | strain={r_full.max_strain:.6f} | \"\n f\"z_range@50%={z_half:.3f} | self_similarity={self_sim:.3f}\")",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"execution_count": null,
|
| 65 |
+
"outputs": []
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"id": "p1weq9kv5q",
|
| 70 |
+
"source": "# Test reward functions with mock LLM outputs\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\n\n# Simulate what the reward functions see during training:\n# completions = list of [{\"content\": \"...LLM response...\"}]\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}], # correct answer in fenced block\n [{\"content\": bad_json}], # garbage\n [{\"content\": invalid_fold}], # parseable but invalid FOLD\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"outputs": []
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"id": "45l0n1hgvr",
|
| 78 |
+
"source": "## 4. Visualize Tasks\n\n2D crease patterns for each task (matplotlib).",
|
| 79 |
+
"metadata": {}
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"id": "fkopb9lgg7i",
|
| 84 |
+
"source": "import matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nfig, axes = plt.subplots(2, 4, figsize=(16, 8))\n\nfor idx, (name, task) in enumerate(TASKS.items()):\n fold = task[\"target_fold\"]\n verts = np.array(fold[\"vertices_coords\"])\n \n # Row 1: 2D crease pattern\n ax = axes[0, idx]\n ax.set_title(f\"{name}\\n{task['description']}\", fontsize=9)\n ax.set_aspect(\"equal\")\n ax.set_xlim(-0.1, 1.1)\n ax.set_ylim(-0.1, 1.1)\n ax.grid(True, alpha=0.2)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n lw = 2.5 if a == \"B\" else 1.8\n ax.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=lw)\n \n ax.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=15, zorder=5)\n \n # Row 2: 3D folded shape\n ax3 = fig.add_subplot(2, 4, idx + 5, projection=\"3d\")\n result = simulate(fold, crease_percent=1.0)\n pos = result.positions\n \n if \"faces_vertices\" in fold:\n for face in fold[\"faces_vertices\"]:\n tri_verts = [pos[vi] for vi in face]\n poly = Poly3DCollection([tri_verts], alpha=0.3, facecolor=\"lightskyblue\", edgecolor=\"steelblue\")\n ax3.add_collection3d(poly)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.2)\n \n ax3.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=10, zorder=5)\n ax3.set_title(f\"Folded (3D)\", fontsize=9)\n ax3.set_xlim(-0.2, 1.2)\n ax3.set_ylim(-0.2, 1.2)\n ax3.set_zlim(-0.6, 0.6)\n \n # Remove the empty 2D subplot that was in row 2\n axes[1, 
idx].remove()\n\nplt.tight_layout()\nplt.show()",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"execution_count": null,
|
| 87 |
+
"outputs": []
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "markdown",
|
| 91 |
+
"id": "a14w2fkoewq",
|
| 92 |
+
"source": "## 5. Training Configuration",
|
| 93 |
+
"metadata": {}
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"id": "2phdejbobq3",
|
| 98 |
+
"source": "# ============================================================\n# Training hyperparameters — edit these before launching\n# ============================================================\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\nMODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\" # Change to your preferred model\nMAX_STEPS = 600 # Total GRPO training steps\nNUM_GENERATIONS = 4 # Completions per prompt per step\nLEARNING_RATE = 2e-4\nLORA_R = 8 # LoRA rank\nLORA_ALPHA = 16 # LoRA alpha\nMAX_PROMPT_LENGTH = 1024\nMAX_COMPLETION_LENGTH = 1024\nDATASET_SIZE = 1000 # Number of prompt copies (same prompt repeated)\nOUTPUT_DIR = \"outputs\"\nSAVE_STEPS = 100",
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"outputs": []
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"cell_type": "markdown",
|
| 105 |
+
"id": "feal20fr8j5",
|
| 106 |
+
"source": "## 6. Build the Prompt & Dataset",
|
| 107 |
+
"metadata": {}
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"id": "uo7zh1dwp6r",
|
| 112 |
+
"source": "from training.train_grpo import PROMPT_TEMPLATE, build_prompt\n\ntask = get_task(TASK_NAME)\nprompt_text = build_prompt(task)\n\nprint(\"=\"*60)\nprint(\"PROMPT THAT THE LLM WILL SEE:\")\nprint(\"=\"*60)\nprint(prompt_text)",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"execution_count": null,
|
| 115 |
+
"outputs": []
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"id": "900vyqwb8g",
|
| 120 |
+
"source": "from datasets import Dataset\n\n# GRPO pattern: same prompt repeated many times, the RL loop generates\n# multiple completions per prompt and uses relative rewards to update policy\ndataset = Dataset.from_list(\n [\n {\n \"prompt\": [{\"role\": \"user\", \"content\": prompt_text}],\n \"answer\": 0, # placeholder, not used by GRPO\n }\n ]\n * DATASET_SIZE\n)\n\nprint(f\"Dataset size: {len(dataset)}\")\nprint(f\"Sample prompt (first 100 chars): {dataset[0]['prompt'][0]['content'][:100]}...\")",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"execution_count": null,
|
| 123 |
+
"outputs": []
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "markdown",
|
| 127 |
+
"id": "xn6n1hpx2aa",
|
| 128 |
+
"source": "## 7. Load Model + LoRA\n\nUses Unsloth for fast 4-bit LoRA fine-tuning. Falls back to standard HuggingFace if Unsloth isn't available.",
|
| 129 |
+
"metadata": {}
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"id": "vkfaeuu9dq",
|
| 134 |
+
"source": "import torch\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\nelif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n print(\"Apple MPS (Metal) available — note: Unsloth requires CUDA, will use HF fallback\")\nelse:\n print(\"No GPU detected — training will be very slow\")",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"execution_count": null,
|
| 137 |
+
"outputs": []
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"id": "xwlkfw3xxoo",
|
| 142 |
+
"source": "USE_UNSLOTH = False\n\ntry:\n from unsloth import FastLanguageModel\n USE_UNSLOTH = True\n print(\"Using Unsloth for fast LoRA loading\")\nexcept ImportError:\n print(\"Unsloth not available, using standard HuggingFace + PEFT\")\n\nif USE_UNSLOTH:\n model, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n load_in_4bit=True,\n max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,\n )\n model = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=LORA_ALPHA,\n use_gradient_checkpointing=\"unsloth\",\n )\nelse:\n from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n from peft import LoraConfig, get_peft_model\n\n bnb_config = BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_compute_dtype=torch.bfloat16,\n ) if torch.cuda.is_available() else None\n\n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n quantization_config=bnb_config,\n device_map=\"auto\" if torch.cuda.is_available() else \"cpu\",\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n )\n\n lora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n task_type=\"CAUSAL_LM\",\n )\n model = get_peft_model(model, lora_config)\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nmodel.print_trainable_parameters()",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"execution_count": null,
|
| 145 |
+
"outputs": []
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "markdown",
|
| 149 |
+
"id": "3f7ritml396",
|
| 150 |
+
"source": "## 8. Setup GRPO Trainer",
|
| 151 |
+
"metadata": {}
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"id": "4dqsw30e9nq",
|
| 156 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# Wrap shape_match to inject the task name\ndef shape_match_reward(completions, **kwargs):\n return shape_match(completions, task_name=TASK_NAME, **kwargs)\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\" if torch.cuda.is_available() else \"adamw_torch\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=1,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=MAX_PROMPT_LENGTH,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n save_steps=SAVE_STEPS,\n output_dir=OUTPUT_DIR,\n report_to=\"none\", # Set to \"wandb\" if you want W&B logging\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[valid_fold, shape_match_reward],\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer ready. Task: {TASK_NAME}, Model: {MODEL_NAME}\")\nprint(f\"Max steps: {MAX_STEPS}, Generations per step: {NUM_GENERATIONS}\")\nprint(f\"Reward functions: valid_fold + shape_match\")",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"execution_count": null,
|
| 159 |
+
"outputs": []
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "markdown",
|
| 163 |
+
"id": "62lvkfoyu1p",
|
| 164 |
+
"source": "## 9. Train!",
|
| 165 |
+
"metadata": {}
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "code",
|
| 169 |
+
"id": "eohisxhna96",
|
| 170 |
+
"source": "trainer.train()",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"execution_count": null,
|
| 173 |
+
"outputs": []
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "markdown",
|
| 177 |
+
"id": "dd8b2pk2s7",
|
| 178 |
+
"source": "## 10. Save the Trained Model",
|
| 179 |
+
"metadata": {}
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "code",
|
| 183 |
+
"id": "t3d4tu6o5mc",
|
| 184 |
+
"source": "SAVE_PATH = f\"origami-{TASK_NAME}-lora\"\n\n# Save LoRA adapter\nmodel.save_pretrained(SAVE_PATH)\ntokenizer.save_pretrained(SAVE_PATH)\nprint(f\"LoRA adapter saved to {SAVE_PATH}/\")\n\n# Optional: merge LoRA into base model and save full model\n# merged_path = f\"origami-{TASK_NAME}-merged\"\n# if USE_UNSLOTH:\n# model.save_pretrained_merged(merged_path, tokenizer)\n# else:\n# merged_model = model.merge_and_unload()\n# merged_model.save_pretrained(merged_path)\n# tokenizer.save_pretrained(merged_path)\n# print(f\"Merged model saved to {merged_path}/\")",
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"execution_count": null,
|
| 187 |
+
"outputs": []
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "markdown",
|
| 191 |
+
"id": "q18eizy1ok",
|
| 192 |
+
"source": "## 11. Evaluate — Generate & Score Completions\n\nTest the trained model by generating crease patterns and scoring them.",
|
| 193 |
+
"metadata": {}
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "code",
|
| 197 |
+
"id": "on56augj41",
|
| 198 |
+
"source": "# Put model in inference mode\nif USE_UNSLOTH:\n FastLanguageModel.for_inference(model)\n\nNUM_EVAL_SAMPLES = 8\n\n# Build chat messages\nmessages = [{\"role\": \"user\", \"content\": prompt_text}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL_SAMPLES} completions...\")\nprint(f\"Input length: {input_ids.shape[1]} tokens\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL_SAMPLES):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.7,\n top_p=0.9,\n do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n \n # Quick score\n fold_data = extract_fold_json(response)\n if fold_data is None:\n status = \"UNPARSEABLE\"\n else:\n is_valid, err = validate_fold(fold_data)\n if not is_valid:\n status = f\"INVALID: {err}\"\n else:\n try:\n result = simulate(fold_data, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n sim = compute_shape_match(result.positions, target_result.positions)\n status = f\"similarity={sim:.3f} (reward={sim * 20:.1f})\"\n except Exception as e:\n status = f\"SIM ERROR: {e}\"\n \n print(f\" Sample {i+1}: {status}\")\n\n# Compute aggregate reward scores\nprint(f\"\\nAggregate rewards:\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions, task_name=TASK_NAME)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
|
| 199 |
+
"metadata": {},
|
| 200 |
+
"execution_count": null,
|
| 201 |
+
"outputs": []
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"cell_type": "markdown",
|
| 205 |
+
"id": "tb1y8hszrk",
|
| 206 |
+
"source": "## 12. Visualize a Generated Fold\n\nPick the best completion and visualize its crease pattern + 3D fold vs the target.",
|
| 207 |
+
"metadata": {}
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"cell_type": "code",
|
| 211 |
+
"id": "0zo3krbkiqej",
|
| 212 |
+
"source": "# Find the best valid completion\nbest_idx = int(np.argmax(sm_scores))\nbest_response = eval_completions[best_idx][0][\"content\"]\nbest_fold = extract_fold_json(best_response)\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n is_valid, _ = validate_fold(best_fold)\n if not is_valid:\n print(\"Best completion has invalid FOLD structure.\")\n else:\n pred_result = simulate(best_fold, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n \n fig = plt.figure(figsize=(14, 5))\n \n # 1) Generated 2D crease pattern\n ax1 = fig.add_subplot(131)\n ax1.set_title(f\"Generated Crease Pattern\\n(sample {best_idx+1})\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n \n # 2) Generated 3D fold\n ax2 = fig.add_subplot(132, projection=\"3d\")\n ax2.set_title(f\"Generated 3D Fold\\nsimilarity={sm_scores[best_idx]/20:.3f}\", fontsize=10)\n pos = pred_result.positions\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax2.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax2.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=15, zorder=5)\n \n # 3) Target 3D fold\n ax3 = fig.add_subplot(133, projection=\"3d\")\n ax3.set_title(\"Target 3D Fold\", fontsize=10)\n tpos = target_result.positions\n tfold = task[\"target_fold\"]\n for i, (e, a) in enumerate(zip(tfold[\"edges_vertices\"], 
tfold[\"edges_assignment\"])):\n p1, p2 = tpos[e[0]], tpos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax3.scatter(tpos[:, 0], tpos[:, 1], tpos[:, 2], c=\"black\", s=15, zorder=5)\n \n plt.tight_layout()\n plt.show()\n \n print(f\"\\nBest generated FOLD JSON:\")\n print(json.dumps(best_fold, indent=2))",
|
| 213 |
+
"metadata": {},
|
| 214 |
+
"execution_count": null,
|
| 215 |
+
"outputs": []
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"cell_type": "markdown",
|
| 219 |
+
"id": "qlakksqmoe",
|
| 220 |
+
"source": "## 13. Plot Training Logs",
|
| 221 |
+
"metadata": {}
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "code",
|
| 225 |
+
"id": "6nivbx4wgx9",
|
| 226 |
+
"source": "# Extract training logs from the trainer\nlogs = trainer.state.log_history\n\n# Parse out loss and reward metrics\nsteps, losses, rewards = [], [], []\nfor entry in logs:\n if \"loss\" in entry:\n steps.append(entry.get(\"step\", 0))\n losses.append(entry[\"loss\"])\n if \"reward\" in entry:\n rewards.append(entry[\"reward\"])\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n\nif losses:\n ax1.plot(steps[:len(losses)], losses, color=\"steelblue\", linewidth=1, alpha=0.7)\n # Smoothed line\n if len(losses) > 10:\n window = min(20, len(losses) // 5)\n smoothed = np.convolve(losses, np.ones(window)/window, mode=\"valid\")\n ax1.plot(steps[window-1:len(smoothed)+window-1], smoothed, color=\"navy\", linewidth=2)\n ax1.set_xlabel(\"Step\")\n ax1.set_ylabel(\"Loss\")\n ax1.set_title(\"Training Loss\")\n ax1.grid(True, alpha=0.3)\nelse:\n ax1.text(0.5, 0.5, \"No loss data\", ha=\"center\", va=\"center\", transform=ax1.transAxes)\n\nif rewards:\n ax2.plot(range(len(rewards)), rewards, color=\"coral\", linewidth=1, alpha=0.7)\n if len(rewards) > 10:\n window = min(20, len(rewards) // 5)\n smoothed = np.convolve(rewards, np.ones(window)/window, mode=\"valid\")\n ax2.plot(range(window-1, len(smoothed)+window-1), smoothed, color=\"darkred\", linewidth=2)\n ax2.set_xlabel(\"Step\")\n ax2.set_ylabel(\"Reward\")\n ax2.set_title(\"Mean Reward\")\n ax2.grid(True, alpha=0.3)\nelse:\n ax2.text(0.5, 0.5, \"No reward data\", ha=\"center\", va=\"center\", transform=ax2.transAxes)\n\nplt.tight_layout()\nplt.show()",
|
| 227 |
+
"metadata": {},
|
| 228 |
+
"execution_count": null,
|
| 229 |
+
"outputs": []
|
| 230 |
+
}
|
| 231 |
+
],
|
| 232 |
+
"metadata": {
|
| 233 |
+
"kernelspec": {
|
| 234 |
+
"display_name": "Python 3",
|
| 235 |
+
"language": "python",
|
| 236 |
+
"name": "python3"
|
| 237 |
+
},
|
| 238 |
+
"language_info": {
|
| 239 |
+
"name": "python",
|
| 240 |
+
"version": "3.11.0"
|
| 241 |
+
}
|
| 242 |
+
},
|
| 243 |
+
"nbformat": 4,
|
| 244 |
+
"nbformat_minor": 5
|
| 245 |
+
}
|