Add GRPO training notebook + Dockerfile for cloud training

#1
training/Dockerfile.train ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
4
+
5
+ WORKDIR /app
6
+
7
+ # Install PyTorch with CUDA support + training stack
8
+ RUN pip install --no-cache-dir \
9
+ torch --index-url https://download.pytorch.org/whl/cu124 && \
10
+ pip install --no-cache-dir \
11
+ "trl>=0.7" \
12
+ "datasets>=2.14" \
13
+ "transformers>=4.40" \
14
+ "accelerate>=0.30" \
15
+ "peft>=0.10" \
16
+ "bitsandbytes>=0.43" \
17
+ numpy scipy pydantic
18
+
19
+ # Copy the full repo
20
+ COPY . /app
21
+
22
+ # Default: run training script
23
+ # Override TASK, MODEL, MAX_STEPS etc. via env vars on Northflank
24
+ ENV TASK="triangle"
25
+ ENV MODEL="Qwen/Qwen2.5-3B-Instruct"
26
+ ENV MAX_STEPS="600"
27
+ ENV NUM_GENERATIONS="4"
28
+ ENV LR="2e-4"
29
+
30
+ CMD ["sh", "-c", "python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS --num_generations $NUM_GENERATIONS --lr $LR"]
training/train_grpo.py CHANGED
@@ -5,11 +5,15 @@ Follows the 2048 OpenEnv + Unsloth pattern:
5
  - Two reward functions: valid_fold + shape_match
6
  - GRPOTrainer from TRL handles the RL loop
7
 
8
- Usage (Colab):
9
- python -m origami_env.training.train_grpo --task triangle --max_steps 600
 
 
 
10
  """
11
 
12
  import argparse
 
13
 
14
  PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
15
  that, when folded, produces the target shape described below.
@@ -46,18 +50,24 @@ def main():
46
  parser.add_argument("--task", default="triangle", help="Task name")
47
  parser.add_argument("--max_steps", type=int, default=600)
48
  parser.add_argument("--num_generations", type=int, default=4)
49
- parser.add_argument("--model", default="unsloth/gpt-oss-20b")
50
  parser.add_argument("--lr", type=float, default=2e-4)
51
  args = parser.parse_args()
52
 
53
  # --- These imports are heavy, only load when actually training ---
54
  from datasets import Dataset
55
  from trl import GRPOConfig, GRPOTrainer
56
- from unsloth import FastLanguageModel
57
 
58
  from origami_server.tasks import get_task
59
  from training.reward import shape_match, valid_fold
60
 
 
 
 
 
 
 
 
61
  task = get_task(args.task)
62
  prompt_text = build_prompt(task)
63
 
@@ -73,22 +83,48 @@ def main():
73
  )
74
 
75
  # Load model with LoRA
76
- model, tokenizer = FastLanguageModel.from_pretrained(
77
- model_name=args.model,
78
- load_in_4bit=True,
79
- max_seq_length=2048,
80
- )
81
-
82
- model = FastLanguageModel.get_peft_model(
83
- model,
84
- r=8,
85
- target_modules=[
86
- "q_proj", "k_proj", "v_proj", "o_proj",
87
- "gate_proj", "up_proj", "down_proj",
88
- ],
89
- lora_alpha=16,
90
- use_gradient_checkpointing="unsloth",
91
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Wrap shape_match to inject task_name
94
  def shape_match_reward(completions, **kwargs):
@@ -110,7 +146,7 @@ def main():
110
  max_completion_length=1024,
111
  max_steps=args.max_steps,
112
  save_steps=100,
113
- output_dir="outputs",
114
  )
115
 
116
  trainer = GRPOTrainer(
@@ -123,6 +159,15 @@ def main():
123
 
124
  trainer.train()
125
 
 
 
 
 
 
 
 
 
 
126
 
127
  if __name__ == "__main__":
128
  main()
 
5
  - Two reward functions: valid_fold + shape_match
6
  - GRPOTrainer from TRL handles the RL loop
7
 
8
+ Usage (local/Colab):
9
+ python -m training.train_grpo --task triangle --max_steps 600
10
+
11
+ Usage (Northflank — env vars set in Dockerfile.train):
12
+ python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS
13
  """
14
 
15
  import argparse
16
+ import os
17
 
18
  PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
19
  that, when folded, produces the target shape described below.
 
50
  parser.add_argument("--task", default="triangle", help="Task name")
51
  parser.add_argument("--max_steps", type=int, default=600)
52
  parser.add_argument("--num_generations", type=int, default=4)
53
+ parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
54
  parser.add_argument("--lr", type=float, default=2e-4)
55
  args = parser.parse_args()
56
 
57
  # --- These imports are heavy, only load when actually training ---
58
  from datasets import Dataset
59
  from trl import GRPOConfig, GRPOTrainer
 
60
 
61
  from origami_server.tasks import get_task
62
  from training.reward import shape_match, valid_fold
63
 
64
+ # Try Unsloth first (CUDA), fall back to HF+PEFT
65
+ try:
66
+ from unsloth import FastLanguageModel
67
+ USE_UNSLOTH = True
68
+ except ImportError:
69
+ USE_UNSLOTH = False
70
+
71
  task = get_task(args.task)
72
  prompt_text = build_prompt(task)
73
 
 
83
  )
84
 
85
  # Load model with LoRA
86
+ if USE_UNSLOTH:
87
+ model, tokenizer = FastLanguageModel.from_pretrained(
88
+ model_name=args.model,
89
+ load_in_4bit=True,
90
+ max_seq_length=2048,
91
+ )
92
+ model = FastLanguageModel.get_peft_model(
93
+ model,
94
+ r=8,
95
+ target_modules=[
96
+ "q_proj", "k_proj", "v_proj", "o_proj",
97
+ "gate_proj", "up_proj", "down_proj",
98
+ ],
99
+ lora_alpha=16,
100
+ use_gradient_checkpointing="unsloth",
101
+ )
102
+ else:
103
+ import torch
104
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
105
+ from peft import LoraConfig, get_peft_model
106
+
107
+ bnb_config = BitsAndBytesConfig(
108
+ load_in_4bit=True,
109
+ bnb_4bit_quant_type="nf4",
110
+ bnb_4bit_compute_dtype=torch.bfloat16,
111
+ ) if torch.cuda.is_available() else None
112
+
113
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
114
+ model = AutoModelForCausalLM.from_pretrained(
115
+ args.model,
116
+ quantization_config=bnb_config,
117
+ device_map="auto" if torch.cuda.is_available() else "cpu",
118
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
119
+ )
120
+ model = get_peft_model(model, LoraConfig(
121
+ r=8, lora_alpha=16, task_type="CAUSAL_LM",
122
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
123
+ "gate_proj", "up_proj", "down_proj"],
124
+ ))
125
+
126
+ if tokenizer.pad_token is None:
127
+ tokenizer.pad_token = tokenizer.eos_token
128
 
129
  # Wrap shape_match to inject task_name
130
  def shape_match_reward(completions, **kwargs):
 
146
  max_completion_length=1024,
147
  max_steps=args.max_steps,
148
  save_steps=100,
149
+ output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
150
  )
151
 
152
  trainer = GRPOTrainer(
 
159
 
160
  trainer.train()
161
 
162
+ # Save the LoRA adapter
163
+ save_path = os.path.join(
164
+ os.environ.get("OUTPUT_DIR", "outputs"),
165
+ f"origami-{args.task}-lora-final",
166
+ )
167
+ model.save_pretrained(save_path)
168
+ tokenizer.save_pretrained(save_path)
169
+ print(f"Model saved to {save_path}")
170
+
171
 
172
  if __name__ == "__main__":
173
  main()
training/train_origami.ipynb ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "p8uwc5bkc4n",
6
+ "source": "# Origami RL — GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically\n4. Reward = shape similarity (chamfer distance) to target × 20\n\n**Reward functions:**\n- `valid_fold`: +1.0 valid FOLD JSON, −0.5 parseable but invalid, −2.0 unparseable\n- `shape_match`: similarity × 20.0 (0–20), −1.0 sim fails, −2.0 invalid FOLD\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
7
+ "metadata": {}
8
+ },
9
+ {
10
+ "cell_type": "markdown",
11
+ "id": "xxp4krkl6w",
12
+ "source": "## 1. Install Dependencies",
13
+ "metadata": {}
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "id": "ulhu8a5p5ti",
18
+ "source": "# Run this cell once to install all dependencies\n# For Colab: unsloth has a specific install process\nimport sys\nIN_COLAB = \"google.colab\" in sys.modules\n\nif IN_COLAB:\n # Unsloth's recommended Colab install\n !pip install --no-deps \"unsloth[colab-new]\"\n !pip install --no-deps trl datasets peft accelerate bitsandbytes xformers\nelse:\n !pip install -q \"trl>=0.7\" \"datasets>=2.14\" unsloth torch transformers accelerate bitsandbytes\n\n# Core origami env deps (numpy, scipy, pydantic)\n!pip install -q numpy scipy pydantic",
19
+ "metadata": {},
20
+ "execution_count": null,
21
+ "outputs": []
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "id": "qcetkmcq1hf",
26
+ "source": "## 2. Setup Python Path & Imports",
27
+ "metadata": {}
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "id": "3hr273dhqiv",
32
+ "source": "import os\nimport sys\nimport json\n\n# Add the repo root to Python path so origami_server and training modules are importable\nREPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\nif REPO_ROOT not in sys.path:\n sys.path.insert(0, REPO_ROOT)\n\nprint(f\"Repo root: {REPO_ROOT}\")\nprint(f\"Python: {sys.version}\")",
33
+ "metadata": {},
34
+ "execution_count": null,
35
+ "outputs": []
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "id": "bnm2w57r3lc",
40
+ "source": "import numpy as np\n\n# Verify origami env modules load correctly\nfrom origami_server.tasks import TASKS, get_task, list_tasks\nfrom origami_server.engine.fold_parser import validate_fold, parse_fold\nfrom origami_server.engine.simulate import simulate\nfrom origami_server.engine.shape_match import compute_shape_match\nfrom training.reward import valid_fold, shape_match, extract_fold_json\n\nprint(f\"Available tasks: {list_tasks()}\")\nprint(\"All origami modules loaded successfully.\")",
41
+ "metadata": {},
42
+ "execution_count": null,
43
+ "outputs": []
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "id": "lcaus7mtuj",
48
+ "source": "## 3. Explore the Environment\n\nSanity-check the simulator and reward functions before training.",
49
+ "metadata": {}
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "id": "hlqp4y30m87",
54
+ "source": "# Print all tasks with their details\nfor name, task in TASKS.items():\n print(f\"\\n{'='*50}\")\n print(f\"Task: {task['name']}\")\n print(f\"Description: {task['description']}\")\n print(f\"Difficulty: {task['difficulty']}\")\n print(f\"Paper: {task['paper']}\")\n fold = task[\"target_fold\"]\n n_verts = len(fold[\"vertices_coords\"])\n n_edges = len(fold[\"edges_vertices\"])\n n_folds = sum(1 for a in fold[\"edges_assignment\"] if a in (\"M\", \"V\"))\n print(f\"Vertices: {n_verts}, Edges: {n_edges}, Fold creases: {n_folds}\")",
55
+ "metadata": {},
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "id": "dwqqus8mhlj",
62
+ "source": "# Test the simulator on each task\nfor name in list_tasks():\n task = get_task(name)\n target_fold = task[\"target_fold\"]\n \n # Simulate flat (0%), half (50%), and fully folded (100%)\n r_flat = simulate(target_fold, crease_percent=0.0)\n r_half = simulate(target_fold, crease_percent=0.5)\n r_full = simulate(target_fold, crease_percent=1.0)\n \n z_half = r_half.positions[:, 2].max() - r_half.positions[:, 2].min()\n \n # Shape match: target vs itself should be 1.0\n self_sim = compute_shape_match(r_full.positions, r_full.positions)\n \n print(f\"{name:15s} | converged={r_full.converged} | strain={r_full.max_strain:.6f} | \"\n f\"z_range@50%={z_half:.3f} | self_similarity={self_sim:.3f}\")",
63
+ "metadata": {},
64
+ "execution_count": null,
65
+ "outputs": []
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "id": "p1weq9kv5q",
70
+ "source": "# Test reward functions with mock LLM outputs\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\n\n# Simulate what the reward functions see during training:\n# completions = list of [{\"content\": \"...LLM response...\"}]\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}], # correct answer in fenced block\n [{\"content\": bad_json}], # garbage\n [{\"content\": invalid_fold}], # parseable but invalid FOLD\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
71
+ "metadata": {},
72
+ "execution_count": null,
73
+ "outputs": []
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "45l0n1hgvr",
78
+ "source": "## 4. Visualize Tasks\n\n2D crease patterns for each task (matplotlib).",
79
+ "metadata": {}
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "id": "fkopb9lgg7i",
84
+ "source": "import matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nfig, axes = plt.subplots(2, 4, figsize=(16, 8))\n\nfor idx, (name, task) in enumerate(TASKS.items()):\n fold = task[\"target_fold\"]\n verts = np.array(fold[\"vertices_coords\"])\n \n # Row 1: 2D crease pattern\n ax = axes[0, idx]\n ax.set_title(f\"{name}\\n{task['description']}\", fontsize=9)\n ax.set_aspect(\"equal\")\n ax.set_xlim(-0.1, 1.1)\n ax.set_ylim(-0.1, 1.1)\n ax.grid(True, alpha=0.2)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n lw = 2.5 if a == \"B\" else 1.8\n ax.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=lw)\n \n ax.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=15, zorder=5)\n \n # Row 2: 3D folded shape\n ax3 = fig.add_subplot(2, 4, idx + 5, projection=\"3d\")\n result = simulate(fold, crease_percent=1.0)\n pos = result.positions\n \n if \"faces_vertices\" in fold:\n for face in fold[\"faces_vertices\"]:\n tri_verts = [pos[vi] for vi in face]\n poly = Poly3DCollection([tri_verts], alpha=0.3, facecolor=\"lightskyblue\", edgecolor=\"steelblue\")\n ax3.add_collection3d(poly)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.2)\n \n ax3.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=10, zorder=5)\n ax3.set_title(f\"Folded (3D)\", fontsize=9)\n ax3.set_xlim(-0.2, 1.2)\n ax3.set_ylim(-0.2, 1.2)\n ax3.set_zlim(-0.6, 0.6)\n \n # Remove the empty 2D subplot that was in row 2\n axes[1, 
idx].remove()\n\nplt.tight_layout()\nplt.show()",
85
+ "metadata": {},
86
+ "execution_count": null,
87
+ "outputs": []
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "id": "a14w2fkoewq",
92
+ "source": "## 5. Training Configuration",
93
+ "metadata": {}
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "id": "2phdejbobq3",
98
+ "source": "# ============================================================\n# Training hyperparameters — edit these before launching\n# ============================================================\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\nMODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\" # Change to your preferred model\nMAX_STEPS = 600 # Total GRPO training steps\nNUM_GENERATIONS = 4 # Completions per prompt per step\nLEARNING_RATE = 2e-4\nLORA_R = 8 # LoRA rank\nLORA_ALPHA = 16 # LoRA alpha\nMAX_PROMPT_LENGTH = 1024\nMAX_COMPLETION_LENGTH = 1024\nDATASET_SIZE = 1000 # Number of prompt copies (same prompt repeated)\nOUTPUT_DIR = \"outputs\"\nSAVE_STEPS = 100",
99
+ "metadata": {},
100
+ "execution_count": null,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "id": "feal20fr8j5",
106
+ "source": "## 6. Build the Prompt & Dataset",
107
+ "metadata": {}
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "id": "uo7zh1dwp6r",
112
+ "source": "from training.train_grpo import PROMPT_TEMPLATE, build_prompt\n\ntask = get_task(TASK_NAME)\nprompt_text = build_prompt(task)\n\nprint(\"=\"*60)\nprint(\"PROMPT THAT THE LLM WILL SEE:\")\nprint(\"=\"*60)\nprint(prompt_text)",
113
+ "metadata": {},
114
+ "execution_count": null,
115
+ "outputs": []
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "id": "900vyqwb8g",
120
+ "source": "from datasets import Dataset\n\n# GRPO pattern: same prompt repeated many times, the RL loop generates\n# multiple completions per prompt and uses relative rewards to update policy\ndataset = Dataset.from_list(\n [\n {\n \"prompt\": [{\"role\": \"user\", \"content\": prompt_text}],\n \"answer\": 0, # placeholder, not used by GRPO\n }\n ]\n * DATASET_SIZE\n)\n\nprint(f\"Dataset size: {len(dataset)}\")\nprint(f\"Sample prompt (first 100 chars): {dataset[0]['prompt'][0]['content'][:100]}...\")",
121
+ "metadata": {},
122
+ "execution_count": null,
123
+ "outputs": []
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "xn6n1hpx2aa",
128
+ "source": "## 7. Load Model + LoRA\n\nUses Unsloth for fast 4-bit LoRA fine-tuning. Falls back to standard HuggingFace if Unsloth isn't available.",
129
+ "metadata": {}
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "id": "vkfaeuu9dq",
134
+ "source": "import torch\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\nelif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n print(\"Apple MPS (Metal) available — note: Unsloth requires CUDA, will use HF fallback\")\nelse:\n print(\"No GPU detected — training will be very slow\")",
135
+ "metadata": {},
136
+ "execution_count": null,
137
+ "outputs": []
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "id": "xwlkfw3xxoo",
142
+ "source": "USE_UNSLOTH = False\n\ntry:\n from unsloth import FastLanguageModel\n USE_UNSLOTH = True\n print(\"Using Unsloth for fast LoRA loading\")\nexcept ImportError:\n print(\"Unsloth not available, using standard HuggingFace + PEFT\")\n\nif USE_UNSLOTH:\n model, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n load_in_4bit=True,\n max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,\n )\n model = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=LORA_ALPHA,\n use_gradient_checkpointing=\"unsloth\",\n )\nelse:\n from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n from peft import LoraConfig, get_peft_model\n\n bnb_config = BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_compute_dtype=torch.bfloat16,\n ) if torch.cuda.is_available() else None\n\n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n quantization_config=bnb_config,\n device_map=\"auto\" if torch.cuda.is_available() else \"cpu\",\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n )\n\n lora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n task_type=\"CAUSAL_LM\",\n )\n model = get_peft_model(model, lora_config)\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nmodel.print_trainable_parameters()",
143
+ "metadata": {},
144
+ "execution_count": null,
145
+ "outputs": []
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "3f7ritml396",
150
+ "source": "## 8. Setup GRPO Trainer",
151
+ "metadata": {}
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "id": "4dqsw30e9nq",
156
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\n# Wrap shape_match to inject the task name\ndef shape_match_reward(completions, **kwargs):\n return shape_match(completions, task_name=TASK_NAME, **kwargs)\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\" if torch.cuda.is_available() else \"adamw_torch\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=1,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=MAX_PROMPT_LENGTH,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n save_steps=SAVE_STEPS,\n output_dir=OUTPUT_DIR,\n report_to=\"none\", # Set to \"wandb\" if you want W&B logging\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[valid_fold, shape_match_reward],\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer ready. Task: {TASK_NAME}, Model: {MODEL_NAME}\")\nprint(f\"Max steps: {MAX_STEPS}, Generations per step: {NUM_GENERATIONS}\")\nprint(f\"Reward functions: valid_fold + shape_match\")",
157
+ "metadata": {},
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "id": "62lvkfoyu1p",
164
+ "source": "## 9. Train!",
165
+ "metadata": {}
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "id": "eohisxhna96",
170
+ "source": "trainer.train()",
171
+ "metadata": {},
172
+ "execution_count": null,
173
+ "outputs": []
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "id": "dd8b2pk2s7",
178
+ "source": "## 10. Save the Trained Model",
179
+ "metadata": {}
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "id": "t3d4tu6o5mc",
184
+ "source": "SAVE_PATH = f\"origami-{TASK_NAME}-lora\"\n\n# Save LoRA adapter\nmodel.save_pretrained(SAVE_PATH)\ntokenizer.save_pretrained(SAVE_PATH)\nprint(f\"LoRA adapter saved to {SAVE_PATH}/\")\n\n# Optional: merge LoRA into base model and save full model\n# merged_path = f\"origami-{TASK_NAME}-merged\"\n# if USE_UNSLOTH:\n# model.save_pretrained_merged(merged_path, tokenizer)\n# else:\n# merged_model = model.merge_and_unload()\n# merged_model.save_pretrained(merged_path)\n# tokenizer.save_pretrained(merged_path)\n# print(f\"Merged model saved to {merged_path}/\")",
185
+ "metadata": {},
186
+ "execution_count": null,
187
+ "outputs": []
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "id": "q18eizy1ok",
192
+ "source": "## 11. Evaluate — Generate & Score Completions\n\nTest the trained model by generating crease patterns and scoring them.",
193
+ "metadata": {}
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "id": "on56augj41",
198
+ "source": "# Put model in inference mode\nif USE_UNSLOTH:\n FastLanguageModel.for_inference(model)\n\nNUM_EVAL_SAMPLES = 8\n\n# Build chat messages\nmessages = [{\"role\": \"user\", \"content\": prompt_text}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL_SAMPLES} completions...\")\nprint(f\"Input length: {input_ids.shape[1]} tokens\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL_SAMPLES):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.7,\n top_p=0.9,\n do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n \n # Quick score\n fold_data = extract_fold_json(response)\n if fold_data is None:\n status = \"UNPARSEABLE\"\n else:\n is_valid, err = validate_fold(fold_data)\n if not is_valid:\n status = f\"INVALID: {err}\"\n else:\n try:\n result = simulate(fold_data, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n sim = compute_shape_match(result.positions, target_result.positions)\n status = f\"similarity={sim:.3f} (reward={sim * 20:.1f})\"\n except Exception as e:\n status = f\"SIM ERROR: {e}\"\n \n print(f\" Sample {i+1}: {status}\")\n\n# Compute aggregate reward scores\nprint(f\"\\nAggregate rewards:\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions, task_name=TASK_NAME)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
199
+ "metadata": {},
200
+ "execution_count": null,
201
+ "outputs": []
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "id": "tb1y8hszrk",
206
+ "source": "## 12. Visualize a Generated Fold\n\nPick the best completion and visualize its crease pattern + 3D fold vs the target.",
207
+ "metadata": {}
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "id": "0zo3krbkiqej",
212
+ "source": "# Find the best valid completion\nbest_idx = int(np.argmax(sm_scores))\nbest_response = eval_completions[best_idx][0][\"content\"]\nbest_fold = extract_fold_json(best_response)\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n is_valid, _ = validate_fold(best_fold)\n if not is_valid:\n print(\"Best completion has invalid FOLD structure.\")\n else:\n pred_result = simulate(best_fold, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n \n fig = plt.figure(figsize=(14, 5))\n \n # 1) Generated 2D crease pattern\n ax1 = fig.add_subplot(131)\n ax1.set_title(f\"Generated Crease Pattern\\n(sample {best_idx+1})\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n \n # 2) Generated 3D fold\n ax2 = fig.add_subplot(132, projection=\"3d\")\n ax2.set_title(f\"Generated 3D Fold\\nsimilarity={sm_scores[best_idx]/20:.3f}\", fontsize=10)\n pos = pred_result.positions\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax2.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax2.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=15, zorder=5)\n \n # 3) Target 3D fold\n ax3 = fig.add_subplot(133, projection=\"3d\")\n ax3.set_title(\"Target 3D Fold\", fontsize=10)\n tpos = target_result.positions\n tfold = task[\"target_fold\"]\n for i, (e, a) in enumerate(zip(tfold[\"edges_vertices\"], 
tfold[\"edges_assignment\"])):\n p1, p2 = tpos[e[0]], tpos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax3.scatter(tpos[:, 0], tpos[:, 1], tpos[:, 2], c=\"black\", s=15, zorder=5)\n \n plt.tight_layout()\n plt.show()\n \n print(f\"\\nBest generated FOLD JSON:\")\n print(json.dumps(best_fold, indent=2))",
213
+ "metadata": {},
214
+ "execution_count": null,
215
+ "outputs": []
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "qlakksqmoe",
220
+ "source": "## 13. Plot Training Logs",
221
+ "metadata": {}
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "id": "6nivbx4wgx9",
226
+ "source": "# Extract training logs from the trainer\nlogs = trainer.state.log_history\n\n# Parse out loss and reward metrics\nsteps, losses, rewards = [], [], []\nfor entry in logs:\n if \"loss\" in entry:\n steps.append(entry.get(\"step\", 0))\n losses.append(entry[\"loss\"])\n if \"reward\" in entry:\n rewards.append(entry[\"reward\"])\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n\nif losses:\n ax1.plot(steps[:len(losses)], losses, color=\"steelblue\", linewidth=1, alpha=0.7)\n # Smoothed line\n if len(losses) > 10:\n window = min(20, len(losses) // 5)\n smoothed = np.convolve(losses, np.ones(window)/window, mode=\"valid\")\n ax1.plot(steps[window-1:len(smoothed)+window-1], smoothed, color=\"navy\", linewidth=2)\n ax1.set_xlabel(\"Step\")\n ax1.set_ylabel(\"Loss\")\n ax1.set_title(\"Training Loss\")\n ax1.grid(True, alpha=0.3)\nelse:\n ax1.text(0.5, 0.5, \"No loss data\", ha=\"center\", va=\"center\", transform=ax1.transAxes)\n\nif rewards:\n ax2.plot(range(len(rewards)), rewards, color=\"coral\", linewidth=1, alpha=0.7)\n if len(rewards) > 10:\n window = min(20, len(rewards) // 5)\n smoothed = np.convolve(rewards, np.ones(window)/window, mode=\"valid\")\n ax2.plot(range(window-1, len(smoothed)+window-1), smoothed, color=\"darkred\", linewidth=2)\n ax2.set_xlabel(\"Step\")\n ax2.set_ylabel(\"Reward\")\n ax2.set_title(\"Mean Reward\")\n ax2.grid(True, alpha=0.3)\nelse:\n ax2.text(0.5, 0.5, \"No reward data\", ha=\"center\", va=\"center\", transform=ax2.transAxes)\n\nplt.tight_layout()\nplt.show()",
227
+ "metadata": {},
228
+ "execution_count": null,
229
+ "outputs": []
230
+ }
231
+ ],
232
+ "metadata": {
233
+ "kernelspec": {
234
+ "display_name": "Python 3",
235
+ "language": "python",
236
+ "name": "python3"
237
+ },
238
+ "language_info": {
239
+ "name": "python",
240
+ "version": "3.11.0"
241
+ }
242
+ },
243
+ "nbformat": 4,
244
+ "nbformat_minor": 5
245
+ }