Rewrite notebook: connect to OpenEnv server via WebSocket

#3
by sissississi - opened
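The change in a nutshell: instead of importing the simulator locally, every scored fold now makes one WebSocket round-trip to the origami_env Space. A `reset` message selects the task, a `step` message carries the candidate FOLD JSON, and the reply's `data` payload holds `reward` and `shape_similarity`. A minimal standalone sketch of that round-trip, mirroring the notebook's `call_env_server()` (assumes the Space is running and `requests` plus `websocket-client` are installed):

```python
# Sketch of the reset/step round-trip this PR adopts, using the same
# message envelope as the notebook's call_env_server() helper.
import json
import requests
import websocket  # pip install websocket-client

ENV_URL = "https://openenv-community-origami-env.hf.space"

# Grab a known-good pattern: the triangle task's own target fold.
fold = requests.get(f"{ENV_URL}/tasks", timeout=10).json()["triangle"]["target_fold"]

ws = websocket.create_connection(ENV_URL.replace("https://", "wss://") + "/ws", timeout=30)
try:
    # One episode: reset to pick the task, then step with the candidate fold.
    ws.send(json.dumps({"type": "reset", "data": {"task_name": "triangle"}}))
    ws.recv()  # reset acknowledgement

    ws.send(json.dumps({"type": "step", "data": {"fold_data": fold}}))
    obs = json.loads(ws.recv()).get("data", {})
    # Scoring the target fold against itself should give similarity 1.0, reward 20.0.
    print(obs.get("reward"), obs.get("shape_similarity"), obs.get("done"))
finally:
    ws.close()
```

Opening a fresh connection per scored completion keeps the reward function stateless; if server round-trips become the training bottleneck, reusing one socket across a generation batch would amortize the handshake.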
Files changed (1)
  1. training/train_origami.ipynb +54 -82
training/train_origami.ipynb CHANGED
@@ -2,99 +2,91 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
- "id": "p8uwc5bkc4n",
6
- "source": "# Origami RL β€” GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically\n4. Reward = shape similarity (chamfer distance) to target Γ— 20\n\n**Reward functions:**\n- `valid_fold`: +1.0 valid FOLD JSON, βˆ’0.5 parseable but invalid, βˆ’2.0 unparseable\n- `shape_match`: similarity Γ— 20.0 (0–20), βˆ’1.0 sim fails, βˆ’2.0 invalid FOLD\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
7
  "metadata": {}
8
  },
9
  {
10
  "cell_type": "markdown",
11
- "id": "xxp4krkl6w",
12
  "source": "## 1. Install Dependencies",
13
  "metadata": {}
14
  },
15
  {
16
  "cell_type": "code",
17
- "id": "ulhu8a5p5ti",
18
- "source": "# Run this cell once to install all dependencies\n# For Colab: unsloth has a specific install process\nimport sys\nIN_COLAB = \"google.colab\" in sys.modules\n\nif IN_COLAB:\n # Unsloth's recommended Colab install\n !pip install --no-deps \"unsloth[colab-new]\"\n !pip install --no-deps trl datasets peft accelerate bitsandbytes xformers\nelse:\n !pip install -q \"trl>=0.7\" \"datasets>=2.14\" unsloth torch transformers accelerate bitsandbytes\n\n# Core origami env deps (numpy, scipy, pydantic)\n!pip install -q numpy scipy pydantic",
19
  "metadata": {},
20
  "execution_count": null,
21
  "outputs": []
22
  },
23
  {
24
  "cell_type": "markdown",
25
- "id": "qcetkmcq1hf",
26
- "source": "## 2. Setup Python Path & Imports",
27
  "metadata": {}
28
  },
29
  {
30
  "cell_type": "code",
31
- "id": "3hr273dhqiv",
32
- "source": "import os\nimport sys\nimport json\n\n# Add the repo root to Python path so origami_server and training modules are importable\nREPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), \"..\"))\nif REPO_ROOT not in sys.path:\n sys.path.insert(0, REPO_ROOT)\n\nprint(f\"Repo root: {REPO_ROOT}\")\nprint(f\"Python: {sys.version}\")",
33
  "metadata": {},
34
  "execution_count": null,
35
  "outputs": []
36
  },
37
  {
38
  "cell_type": "code",
39
- "id": "bnm2w57r3lc",
40
- "source": "import numpy as np\n\n# Verify origami env modules load correctly\nfrom origami_server.tasks import TASKS, get_task, list_tasks\nfrom origami_server.engine.fold_parser import validate_fold, parse_fold\nfrom origami_server.engine.simulate import simulate\nfrom origami_server.engine.shape_match import compute_shape_match\nfrom training.reward import valid_fold, shape_match, extract_fold_json\n\nprint(f\"Available tasks: {list_tasks()}\")\nprint(\"All origami modules loaded successfully.\")",
41
  "metadata": {},
42
  "execution_count": null,
43
  "outputs": []
44
  },
45
  {
46
  "cell_type": "markdown",
47
- "id": "lcaus7mtuj",
48
- "source": "## 3. Explore the Environment\n\nSanity-check the simulator and reward functions before training.",
49
  "metadata": {}
50
  },
51
  {
52
  "cell_type": "code",
53
- "id": "hlqp4y30m87",
54
- "source": "# Print all tasks with their details\nfor name, task in TASKS.items():\n print(f\"\\n{'='*50}\")\n print(f\"Task: {task['name']}\")\n print(f\"Description: {task['description']}\")\n print(f\"Difficulty: {task['difficulty']}\")\n print(f\"Paper: {task['paper']}\")\n fold = task[\"target_fold\"]\n n_verts = len(fold[\"vertices_coords\"])\n n_edges = len(fold[\"edges_vertices\"])\n n_folds = sum(1 for a in fold[\"edges_assignment\"] if a in (\"M\", \"V\"))\n print(f\"Vertices: {n_verts}, Edges: {n_edges}, Fold creases: {n_folds}\")",
55
  "metadata": {},
56
  "execution_count": null,
57
  "outputs": []
58
  },
59
  {
60
- "cell_type": "code",
61
- "id": "dwqqus8mhlj",
62
- "source": "# Test the simulator on each task\nfor name in list_tasks():\n task = get_task(name)\n target_fold = task[\"target_fold\"]\n \n # Simulate flat (0%), half (50%), and fully folded (100%)\n r_flat = simulate(target_fold, crease_percent=0.0)\n r_half = simulate(target_fold, crease_percent=0.5)\n r_full = simulate(target_fold, crease_percent=1.0)\n \n z_half = r_half.positions[:, 2].max() - r_half.positions[:, 2].min()\n \n # Shape match: target vs itself should be 1.0\n self_sim = compute_shape_match(r_full.positions, r_full.positions)\n \n print(f\"{name:15s} | converged={r_full.converged} | strain={r_full.max_strain:.6f} | \"\n f\"z_range@50%={z_half:.3f} | self_similarity={self_sim:.3f}\")",
63
- "metadata": {},
64
- "execution_count": null,
65
- "outputs": []
66
  },
67
  {
68
  "cell_type": "code",
69
- "id": "p1weq9kv5q",
70
- "source": "# Test reward functions with mock LLM outputs\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\n\n# Simulate what the reward functions see during training:\n# completions = list of [{\"content\": \"...LLM response...\"}]\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}], # correct answer in fenced block\n [{\"content\": bad_json}], # garbage\n [{\"content\": invalid_fold}], # parseable but invalid FOLD\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
71
  "metadata": {},
72
  "execution_count": null,
73
  "outputs": []
74
  },
75
- {
76
- "cell_type": "markdown",
77
- "id": "45l0n1hgvr",
78
- "source": "## 4. Visualize Tasks\n\n2D crease patterns for each task (matplotlib).",
79
- "metadata": {}
80
- },
81
  {
82
  "cell_type": "code",
83
- "id": "fkopb9lgg7i",
84
- "source": "import matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nfig, axes = plt.subplots(2, 4, figsize=(16, 8))\n\nfor idx, (name, task) in enumerate(TASKS.items()):\n fold = task[\"target_fold\"]\n verts = np.array(fold[\"vertices_coords\"])\n \n # Row 1: 2D crease pattern\n ax = axes[0, idx]\n ax.set_title(f\"{name}\\n{task['description']}\", fontsize=9)\n ax.set_aspect(\"equal\")\n ax.set_xlim(-0.1, 1.1)\n ax.set_ylim(-0.1, 1.1)\n ax.grid(True, alpha=0.2)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n lw = 2.5 if a == \"B\" else 1.8\n ax.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=lw)\n \n ax.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=15, zorder=5)\n \n # Row 2: 3D folded shape\n ax3 = fig.add_subplot(2, 4, idx + 5, projection=\"3d\")\n result = simulate(fold, crease_percent=1.0)\n pos = result.positions\n \n if \"faces_vertices\" in fold:\n for face in fold[\"faces_vertices\"]:\n tri_verts = [pos[vi] for vi in face]\n poly = Poly3DCollection([tri_verts], alpha=0.3, facecolor=\"lightskyblue\", edgecolor=\"steelblue\")\n ax3.add_collection3d(poly)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.2)\n \n ax3.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=10, zorder=5)\n ax3.set_title(f\"Folded (3D)\", fontsize=9)\n ax3.set_xlim(-0.2, 1.2)\n ax3.set_ylim(-0.2, 1.2)\n ax3.set_zlim(-0.6, 0.6)\n \n # Remove the empty 2D subplot that was in row 2\n axes[1, idx].remove()\n\nplt.tight_layout()\nplt.show()",
85
  "metadata": {},
86
  "execution_count": null,
87
  "outputs": []
88
  },
89
  {
90
  "cell_type": "markdown",
91
- "id": "a14w2fkoewq",
92
  "source": "## 5. Training Configuration",
93
  "metadata": {}
94
  },
95
  {
96
  "cell_type": "code",
97
- "id": "2phdejbobq3",
98
  "source": "# ============================================================\n# Training hyperparameters β€” edit these before launching\n# ============================================================\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\nMODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\" # Change to your preferred model\nMAX_STEPS = 600 # Total GRPO training steps\nNUM_GENERATIONS = 4 # Completions per prompt per step\nLEARNING_RATE = 2e-4\nLORA_R = 8 # LoRA rank\nLORA_ALPHA = 16 # LoRA alpha\nMAX_PROMPT_LENGTH = 1024\nMAX_COMPLETION_LENGTH = 1024\nDATASET_SIZE = 1000 # Number of prompt copies (same prompt repeated)\nOUTPUT_DIR = \"outputs\"\nSAVE_STEPS = 100",
99
  "metadata": {},
100
  "execution_count": null,
@@ -102,71 +94,65 @@
102
  },
103
  {
104
  "cell_type": "markdown",
105
- "id": "feal20fr8j5",
106
- "source": "## 6. Build the Prompt & Dataset",
107
  "metadata": {}
108
  },
109
  {
110
  "cell_type": "code",
111
- "id": "uo7zh1dwp6r",
112
- "source": "from training.train_grpo import PROMPT_TEMPLATE, build_prompt\n\ntask = get_task(TASK_NAME)\nprompt_text = build_prompt(task)\n\nprint(\"=\"*60)\nprint(\"PROMPT THAT THE LLM WILL SEE:\")\nprint(\"=\"*60)\nprint(prompt_text)",
113
  "metadata": {},
114
  "execution_count": null,
115
  "outputs": []
116
  },
117
  {
118
  "cell_type": "code",
119
- "id": "900vyqwb8g",
120
- "source": "from datasets import Dataset\n\n# GRPO pattern: same prompt repeated many times, the RL loop generates\n# multiple completions per prompt and uses relative rewards to update policy\ndataset = Dataset.from_list(\n [\n {\n \"prompt\": [{\"role\": \"user\", \"content\": prompt_text}],\n \"answer\": 0, # placeholder, not used by GRPO\n }\n ]\n * DATASET_SIZE\n)\n\nprint(f\"Dataset size: {len(dataset)}\")\nprint(f\"Sample prompt (first 100 chars): {dataset[0]['prompt'][0]['content'][:100]}...\")",
121
  "metadata": {},
122
  "execution_count": null,
123
  "outputs": []
124
  },
125
  {
126
  "cell_type": "markdown",
127
- "id": "xn6n1hpx2aa",
128
- "source": "## 7. Load Model + LoRA\n\nUses Unsloth for fast 4-bit LoRA fine-tuning. Falls back to standard HuggingFace if Unsloth isn't available.",
129
  "metadata": {}
130
  },
131
  {
132
  "cell_type": "code",
133
- "id": "vkfaeuu9dq",
134
- "source": "import torch\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\nelif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n print(\"Apple MPS (Metal) available β€” note: Unsloth requires CUDA, will use HF fallback\")\nelse:\n print(\"No GPU detected β€” training will be very slow\")",
135
  "metadata": {},
136
  "execution_count": null,
137
  "outputs": []
138
  },
139
  {
140
  "cell_type": "code",
141
- "id": "xwlkfw3xxoo",
142
- "source": "USE_UNSLOTH = False\n\ntry:\n from unsloth import FastLanguageModel\n USE_UNSLOTH = True\n print(\"Using Unsloth for fast LoRA loading\")\nexcept ImportError:\n print(\"Unsloth not available, using standard HuggingFace + PEFT\")\n\nif USE_UNSLOTH:\n model, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n load_in_4bit=True,\n max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,\n )\n model = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=LORA_ALPHA,\n use_gradient_checkpointing=\"unsloth\",\n )\nelse:\n from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n from peft import LoraConfig, get_peft_model\n\n bnb_config = BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_compute_dtype=torch.bfloat16,\n ) if torch.cuda.is_available() else None\n\n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n quantization_config=bnb_config,\n device_map=\"auto\" if torch.cuda.is_available() else \"cpu\",\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n )\n\n lora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n task_type=\"CAUSAL_LM\",\n )\n model = get_peft_model(model, lora_config)\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nmodel.print_trainable_parameters()",
143
  "metadata": {},
144
  "execution_count": null,
145
  "outputs": []
146
  },
147
  {
148
  "cell_type": "markdown",
149
- "id": "3f7ritml396",
150
- "source": "## 8. Setup GRPO Trainer",
151
  "metadata": {}
152
  },
153
  {
154
  "cell_type": "code",
155
- "id": "4dqsw30e9nq",
156
- "source": "from trl import GRPOConfig, GRPOTrainer\n\n# Wrap shape_match to inject the task name\ndef shape_match_reward(completions, **kwargs):\n return shape_match(completions, task_name=TASK_NAME, **kwargs)\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\" if torch.cuda.is_available() else \"adamw_torch\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=1,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=MAX_PROMPT_LENGTH,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n save_steps=SAVE_STEPS,\n output_dir=OUTPUT_DIR,\n report_to=\"none\", # Set to \"wandb\" if you want W&B logging\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[valid_fold, shape_match_reward],\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer ready. Task: {TASK_NAME}, Model: {MODEL_NAME}\")\nprint(f\"Max steps: {MAX_STEPS}, Generations per step: {NUM_GENERATIONS}\")\nprint(f\"Reward functions: valid_fold + shape_match\")",
157
  "metadata": {},
158
  "execution_count": null,
159
  "outputs": []
160
  },
161
- {
162
- "cell_type": "markdown",
163
- "id": "62lvkfoyu1p",
164
- "source": "## 9. Train!",
165
- "metadata": {}
166
- },
167
  {
168
  "cell_type": "code",
169
- "id": "eohisxhna96",
170
  "source": "trainer.train()",
171
  "metadata": {},
172
  "execution_count": null,
@@ -174,56 +160,42 @@
174
  },
175
  {
176
  "cell_type": "markdown",
177
- "id": "dd8b2pk2s7",
178
- "source": "## 10. Save the Trained Model",
179
- "metadata": {}
180
- },
181
- {
182
- "cell_type": "code",
183
- "id": "t3d4tu6o5mc",
184
- "source": "SAVE_PATH = f\"origami-{TASK_NAME}-lora\"\n\n# Save LoRA adapter\nmodel.save_pretrained(SAVE_PATH)\ntokenizer.save_pretrained(SAVE_PATH)\nprint(f\"LoRA adapter saved to {SAVE_PATH}/\")\n\n# Optional: merge LoRA into base model and save full model\n# merged_path = f\"origami-{TASK_NAME}-merged\"\n# if USE_UNSLOTH:\n# model.save_pretrained_merged(merged_path, tokenizer)\n# else:\n# merged_model = model.merge_and_unload()\n# merged_model.save_pretrained(merged_path)\n# tokenizer.save_pretrained(merged_path)\n# print(f\"Merged model saved to {merged_path}/\")",
185
- "metadata": {},
186
- "execution_count": null,
187
- "outputs": []
188
- },
189
- {
190
- "cell_type": "markdown",
191
- "id": "q18eizy1ok",
192
- "source": "## 11. Evaluate β€” Generate & Score Completions\n\nTest the trained model by generating crease patterns and scoring them.",
193
  "metadata": {}
194
  },
195
  {
196
  "cell_type": "code",
197
- "id": "on56augj41",
198
- "source": "# Put model in inference mode\nif USE_UNSLOTH:\n FastLanguageModel.for_inference(model)\n\nNUM_EVAL_SAMPLES = 8\n\n# Build chat messages\nmessages = [{\"role\": \"user\", \"content\": prompt_text}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL_SAMPLES} completions...\")\nprint(f\"Input length: {input_ids.shape[1]} tokens\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL_SAMPLES):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.7,\n top_p=0.9,\n do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n \n # Quick score\n fold_data = extract_fold_json(response)\n if fold_data is None:\n status = \"UNPARSEABLE\"\n else:\n is_valid, err = validate_fold(fold_data)\n if not is_valid:\n status = f\"INVALID: {err}\"\n else:\n try:\n result = simulate(fold_data, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n sim = compute_shape_match(result.positions, target_result.positions)\n status = f\"similarity={sim:.3f} (reward={sim * 20:.1f})\"\n except Exception as e:\n status = f\"SIM ERROR: {e}\"\n \n print(f\" Sample {i+1}: {status}\")\n\n# Compute aggregate reward scores\nprint(f\"\\nAggregate rewards:\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions, task_name=TASK_NAME)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
199
  "metadata": {},
200
  "execution_count": null,
201
  "outputs": []
202
  },
203
  {
204
  "cell_type": "markdown",
205
- "id": "tb1y8hszrk",
206
- "source": "## 12. Visualize a Generated Fold\n\nPick the best completion and visualize its crease pattern + 3D fold vs the target.",
207
  "metadata": {}
208
  },
209
  {
210
  "cell_type": "code",
211
- "id": "0zo3krbkiqej",
212
- "source": "# Find the best valid completion\nbest_idx = int(np.argmax(sm_scores))\nbest_response = eval_completions[best_idx][0][\"content\"]\nbest_fold = extract_fold_json(best_response)\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n is_valid, _ = validate_fold(best_fold)\n if not is_valid:\n print(\"Best completion has invalid FOLD structure.\")\n else:\n pred_result = simulate(best_fold, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n \n fig = plt.figure(figsize=(14, 5))\n \n # 1) Generated 2D crease pattern\n ax1 = fig.add_subplot(131)\n ax1.set_title(f\"Generated Crease Pattern\\n(sample {best_idx+1})\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n \n # 2) Generated 3D fold\n ax2 = fig.add_subplot(132, projection=\"3d\")\n ax2.set_title(f\"Generated 3D Fold\\nsimilarity={sm_scores[best_idx]/20:.3f}\", fontsize=10)\n pos = pred_result.positions\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax2.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax2.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=15, zorder=5)\n \n # 3) Target 3D fold\n ax3 = fig.add_subplot(133, projection=\"3d\")\n ax3.set_title(\"Target 3D Fold\", fontsize=10)\n tpos = target_result.positions\n tfold = task[\"target_fold\"]\n for i, (e, a) in enumerate(zip(tfold[\"edges_vertices\"], tfold[\"edges_assignment\"])):\n p1, p2 = tpos[e[0]], tpos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax3.scatter(tpos[:, 0], tpos[:, 1], tpos[:, 2], c=\"black\", s=15, zorder=5)\n \n plt.tight_layout()\n plt.show()\n \n print(f\"\\nBest generated FOLD JSON:\")\n print(json.dumps(best_fold, indent=2))",
213
  "metadata": {},
214
  "execution_count": null,
215
  "outputs": []
216
  },
217
  {
218
  "cell_type": "markdown",
219
- "id": "qlakksqmoe",
220
- "source": "## 13. Plot Training Logs",
221
  "metadata": {}
222
  },
223
  {
224
  "cell_type": "code",
225
- "id": "6nivbx4wgx9",
226
- "source": "# Extract training logs from the trainer\nlogs = trainer.state.log_history\n\n# Parse out loss and reward metrics\nsteps, losses, rewards = [], [], []\nfor entry in logs:\n if \"loss\" in entry:\n steps.append(entry.get(\"step\", 0))\n losses.append(entry[\"loss\"])\n if \"reward\" in entry:\n rewards.append(entry[\"reward\"])\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n\nif losses:\n ax1.plot(steps[:len(losses)], losses, color=\"steelblue\", linewidth=1, alpha=0.7)\n # Smoothed line\n if len(losses) > 10:\n window = min(20, len(losses) // 5)\n smoothed = np.convolve(losses, np.ones(window)/window, mode=\"valid\")\n ax1.plot(steps[window-1:len(smoothed)+window-1], smoothed, color=\"navy\", linewidth=2)\n ax1.set_xlabel(\"Step\")\n ax1.set_ylabel(\"Loss\")\n ax1.set_title(\"Training Loss\")\n ax1.grid(True, alpha=0.3)\nelse:\n ax1.text(0.5, 0.5, \"No loss data\", ha=\"center\", va=\"center\", transform=ax1.transAxes)\n\nif rewards:\n ax2.plot(range(len(rewards)), rewards, color=\"coral\", linewidth=1, alpha=0.7)\n if len(rewards) > 10:\n window = min(20, len(rewards) // 5)\n smoothed = np.convolve(rewards, np.ones(window)/window, mode=\"valid\")\n ax2.plot(range(window-1, len(smoothed)+window-1), smoothed, color=\"darkred\", linewidth=2)\n ax2.set_xlabel(\"Step\")\n ax2.set_ylabel(\"Reward\")\n ax2.set_title(\"Mean Reward\")\n ax2.grid(True, alpha=0.3)\nelse:\n ax2.text(0.5, 0.5, \"No reward data\", ha=\"center\", va=\"center\", transform=ax2.transAxes)\n\nplt.tight_layout()\nplt.show()",
227
  "metadata": {},
228
  "execution_count": null,
229
  "outputs": []
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "1g9gapmu4yc",
6
+ "source": "# Origami RL β€” GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Architecture:** Connects to the origami_env OpenEnv server (HuggingFace Space) for simulation and reward computation.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape\n2. LLM generates a FOLD JSON crease pattern\n3. Pattern is sent to the OpenEnv server for physics simulation\n4. Server returns reward = shape similarity to target x 20\n\n**Reward functions:**\n- `valid_fold`: Local JSON validation (+1.0 / -0.5 / -2.0)\n- `shape_match`: Calls OpenEnv server to simulate and score (0-20 / -1.0 / -2.0)\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
7
  "metadata": {}
8
  },
9
  {
10
  "cell_type": "markdown",
11
+ "id": "kux2o8t14m",
12
  "source": "## 1. Install Dependencies",
13
  "metadata": {}
14
  },
15
  {
16
  "cell_type": "code",
17
+ "id": "hlz4gjiqcr",
18
+ "source": "import sys\nIN_COLAB = \"google.colab\" in sys.modules\n\nif IN_COLAB:\n !pip install -q \"unsloth[colab-new]\"\n !pip install -q trl datasets peft accelerate bitsandbytes xformers\nelse:\n !pip install -q \"trl>=0.7\" \"datasets>=2.14\" torch transformers accelerate bitsandbytes peft\n\n!pip install -q requests websocket-client numpy",
19
  "metadata": {},
20
  "execution_count": null,
21
  "outputs": []
22
  },
23
  {
24
  "cell_type": "markdown",
25
+ "id": "s0yrntws1x",
26
+ "source": "## 2. Connect to OpenEnv Server\n\nConnect to the origami_env HuggingFace Space. The server handles all physics simulation and shape matching.",
27
  "metadata": {}
28
  },
29
  {
30
  "cell_type": "code",
31
+ "id": "a261mxlne8n",
32
+ "source": "import json\nimport re\nimport requests\nimport numpy as np\nfrom typing import Any\n\n# ============================================================\n# OpenEnv Server URL β€” the origami_env HuggingFace Space\n# ============================================================\nENV_URL = \"https://openenv-community-origami-env.hf.space\"\n\n# Verify server is up\ntry:\n r = requests.get(f\"{ENV_URL}/health\", timeout=10)\n r.raise_for_status()\n print(f\"Server connected: {ENV_URL}\")\n print(f\"Health: {r.json()}\")\nexcept Exception as e:\n print(f\"ERROR: Cannot reach server at {ENV_URL}\")\n print(f\" {e}\")\n print(f\"\\nMake sure the HF Space is running!\")\n print(f\" https://huggingface.co/spaces/openenv-community/origami_env\")",
33
  "metadata": {},
34
  "execution_count": null,
35
  "outputs": []
36
  },
37
  {
38
  "cell_type": "code",
39
+ "id": "o48y627ctwd",
40
+ "source": "# Fetch available tasks from the server\ntasks_resp = requests.get(f\"{ENV_URL}/tasks\")\ntasks_resp.raise_for_status()\nTASKS = tasks_resp.json()\n\nprint(\"Available tasks from server:\")\nfor name, task in TASKS.items():\n fold = task[\"target_fold\"]\n n_folds = sum(1 for a in fold[\"edges_assignment\"] if a in (\"M\", \"V\"))\n print(f\" {name:15s} | {task['description']} | difficulty={task['difficulty']} | folds={n_folds}\")",
41
  "metadata": {},
42
  "execution_count": null,
43
  "outputs": []
44
  },
45
  {
46
  "cell_type": "markdown",
47
+ "id": "qppn5im1e6i",
48
+ "source": "## 3. Define Reward Functions (using OpenEnv server)",
49
  "metadata": {}
50
  },
51
  {
52
  "cell_type": "code",
53
+ "id": "b7fa2jl3dxu",
54
+ "source": "import websocket\nimport threading\n\ndef extract_fold_json(response: str) -> dict | None:\n \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n # Try fenced code block first\n match = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n if match:\n try:\n return json.loads(match.group(1))\n except json.JSONDecodeError:\n pass\n # Try raw JSON with vertices_coords\n match = re.search(r\"\\{[^{}]*\\\"vertices_coords\\\"[^{}]*\\}\", response, re.DOTALL)\n if match:\n try:\n return json.loads(match.group(0))\n except json.JSONDecodeError:\n pass\n # Try whole response\n try:\n data = json.loads(response.strip())\n if isinstance(data, dict) and \"vertices_coords\" in data:\n return data\n except (json.JSONDecodeError, ValueError):\n pass\n return None\n\n\ndef validate_fold_local(fold_data: dict) -> tuple[bool, str]:\n \"\"\"Local validation of FOLD JSON structure (no server needed).\"\"\"\n for key in (\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"):\n if key not in fold_data:\n return False, f\"Missing: {key}\"\n verts = fold_data[\"vertices_coords\"]\n edges = fold_data[\"edges_vertices\"]\n assignments = fold_data[\"edges_assignment\"]\n if len(verts) < 3:\n return False, f\"Need >= 3 vertices, got {len(verts)}\"\n if len(edges) < 3:\n return False, f\"Need >= 3 edges, got {len(edges)}\"\n if len(edges) != len(assignments):\n return False, \"edges/assignments length mismatch\"\n num_verts = len(verts)\n for i, e in enumerate(edges):\n if not isinstance(e, (list, tuple)) or len(e) != 2:\n return False, f\"Edge {i} invalid\"\n if e[0] < 0 or e[0] >= num_verts or e[1] < 0 or e[1] >= num_verts:\n return False, f\"Edge {i} bad vertex index\"\n if e[0] == e[1]:\n return False, f\"Edge {i} degenerate\"\n valid_a = {\"M\", \"V\", \"B\", \"F\", \"U\", \"C\"}\n for i, a in enumerate(assignments):\n if a not in valid_a:\n return False, f\"Edge {i} bad assignment '{a}'\"\n if not any(a in (\"M\", \"V\") for a in assignments):\n return False, \"No fold creases\"\n if not any(a == \"B\" for a in assignments):\n return False, \"No boundary edges\"\n return True, \"\"\n\n\ndef call_env_server(task_name: str, fold_data: dict) -> dict:\n \"\"\"Call the OpenEnv server via WebSocket: reset + step, return observation.\"\"\"\n ws_url = ENV_URL.replace(\"https://\", \"wss://\").replace(\"http://\", \"ws://\") + \"/ws\"\n ws = websocket.create_connection(ws_url, timeout=30)\n try:\n # Reset\n ws.send(json.dumps({\"type\": \"reset\", \"data\": {\"task_name\": task_name}}))\n reset_resp = json.loads(ws.recv())\n\n # Step\n ws.send(json.dumps({\"type\": \"step\", \"data\": {\"fold_data\": fold_data}}))\n step_resp = json.loads(ws.recv())\n\n return step_resp.get(\"data\", {})\n finally:\n ws.close()\n\n\ndef valid_fold(completions: list, **kwargs: Any) -> list[float]:\n \"\"\"Reward 1: Does the LLM output parse as valid FOLD JSON?\n +1.0 valid, -0.5 parseable but invalid, -2.0 unparseable\n \"\"\"\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(-2.0)\n continue\n is_valid, _ = validate_fold_local(fold_data)\n scores.append(1.0 if is_valid else -0.5)\n return scores\n\n\ndef shape_match(completions: list, task_name: str = \"triangle\", **kwargs: Any) -> list[float]:\n \"\"\"Reward 2: Send fold to OpenEnv server, get shape similarity reward.\n Score = server reward (similarity x 20, range 0-20)\n -1.0 if simulation fails, -2.0 if FOLD data 
invalid\n \"\"\"\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(-2.0)\n continue\n is_valid, _ = validate_fold_local(fold_data)\n if not is_valid:\n scores.append(-1.0)\n continue\n try:\n obs = call_env_server(task_name, fold_data)\n reward = obs.get(\"reward\", -1.0)\n scores.append(reward)\n except Exception as e:\n scores.append(-1.0)\n return scores\n\n\nprint(\"Reward functions defined (using OpenEnv server for shape_match).\")",
55
  "metadata": {},
56
  "execution_count": null,
57
  "outputs": []
58
  },
59
  {
60
+ "cell_type": "markdown",
61
+ "id": "d1y2o6fjk0f",
62
+ "source": "## 4. Test the Server Connection",
63
+ "metadata": {}
64
  },
65
  {
66
  "cell_type": "code",
67
+ "id": "af9wy807tre",
68
+ "source": "# Test the full pipeline: send the target fold to the server, expect reward=20.0\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\nobs = call_env_server(\"triangle\", triangle_fold)\nprint(f\"Server test β€” triangle target fold:\")\nprint(f\" reward: {obs.get('reward')}\")\nprint(f\" shape_similarity: {obs.get('shape_similarity')}\")\nprint(f\" done: {obs.get('done')}\")\nprint(f\" error: {obs.get('error')}\")\nassert obs.get(\"reward\") == 20.0, f\"Expected reward=20.0, got {obs.get('reward')}\"\nprint(\"\\nServer connection verified!\")",
69
  "metadata": {},
70
  "execution_count": null,
71
  "outputs": []
72
  },
73
  {
74
  "cell_type": "code",
75
+ "id": "skocmbwjmyf",
76
+ "source": "# Test reward functions with mock LLM outputs\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}],\n [{\"content\": bad_json}],\n [{\"content\": invalid_fold}],\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
77
  "metadata": {},
78
  "execution_count": null,
79
  "outputs": []
80
  },
81
  {
82
  "cell_type": "markdown",
83
+ "id": "ufqu2bw7a8h",
84
  "source": "## 5. Training Configuration",
85
  "metadata": {}
86
  },
87
  {
88
  "cell_type": "code",
89
+ "id": "ybwejb0fqh",
90
  "source": "# ============================================================\n# Training hyperparameters β€” edit these before launching\n# ============================================================\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\nMODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\" # Change to your preferred model\nMAX_STEPS = 600 # Total GRPO training steps\nNUM_GENERATIONS = 4 # Completions per prompt per step\nLEARNING_RATE = 2e-4\nLORA_R = 8 # LoRA rank\nLORA_ALPHA = 16 # LoRA alpha\nMAX_PROMPT_LENGTH = 1024\nMAX_COMPLETION_LENGTH = 1024\nDATASET_SIZE = 1000 # Number of prompt copies (same prompt repeated)\nOUTPUT_DIR = \"outputs\"\nSAVE_STEPS = 100",
91
  "metadata": {},
92
  "execution_count": null,
94
  },
95
  {
96
  "cell_type": "markdown",
97
+ "id": "j2sfyperba",
98
+ "source": "## 6. Build Prompt & Dataset",
99
  "metadata": {}
100
  },
101
  {
102
  "cell_type": "code",
103
+ "id": "00xlaqwlrdfuq",
104
+ "source": "PROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\nthat, when folded, produces the target shape described below.\n\nTarget: {description}\nPaper size: {width} x {height}\n\nOutput a JSON object with these exact fields:\n- vertices_coords: [[x, y], ...] β€” 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)\n- edges_vertices: [[v1, v2], ...] β€” pairs of vertex indices forming edges\n- edges_assignment: [\"B\"|\"M\"|\"V\", ...] β€” B=boundary, M=mountain fold, V=valley fold\n- edges_foldAngle: [angle, ...] β€” fold angles in degrees (M: negative like -180, V: positive like 180, B: 0)\n\nRules:\n- Boundary edges (B) must outline the paper rectangle\n- At least one fold crease (M or V) must exist\n- Mountain fold angles are negative (-180 to 0)\n- Valley fold angles are positive (0 to 180)\n- All vertex indices in edges must be valid (0 to N-1)\n\nOutput ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n\n# Get task info from server\ntask = TASKS[TASK_NAME]\nprompt_text = PROMPT_TEMPLATE.format(\n description=task[\"description\"],\n width=task[\"paper\"][\"width\"],\n height=task[\"paper\"][\"height\"],\n)\n\nprint(\"=\"*60)\nprint(\"PROMPT THAT THE LLM WILL SEE:\")\nprint(\"=\"*60)\nprint(prompt_text)",
105
  "metadata": {},
106
  "execution_count": null,
107
  "outputs": []
108
  },
109
  {
110
  "cell_type": "code",
111
+ "id": "uvcwwjiid3",
112
+ "source": "from datasets import Dataset\n\ndataset = Dataset.from_list(\n [{\"prompt\": [{\"role\": \"user\", \"content\": prompt_text}], \"answer\": 0}] * DATASET_SIZE\n)\nprint(f\"Dataset size: {len(dataset)}\")",
113
  "metadata": {},
114
  "execution_count": null,
115
  "outputs": []
116
  },
117
  {
118
  "cell_type": "markdown",
119
+ "id": "6q32ftes9y",
120
+ "source": "## 7. Load Model + LoRA",
121
  "metadata": {}
122
  },
123
  {
124
  "cell_type": "code",
125
+ "id": "hllux7qgxjf",
126
+ "source": "import torch\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")",
127
  "metadata": {},
128
  "execution_count": null,
129
  "outputs": []
130
  },
131
  {
132
  "cell_type": "code",
133
+ "id": "g33tcdik4v",
134
+ "source": "USE_UNSLOTH = False\ntry:\n from unsloth import FastLanguageModel\n USE_UNSLOTH = True\n print(\"Using Unsloth\")\nexcept ImportError:\n print(\"Using HuggingFace + PEFT\")\n\nif USE_UNSLOTH:\n model, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME, load_in_4bit=True,\n max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,\n )\n model = FastLanguageModel.get_peft_model(\n model, r=LORA_R, lora_alpha=LORA_ALPHA,\n target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n use_gradient_checkpointing=\"unsloth\",\n )\nelse:\n from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n from peft import LoraConfig, get_peft_model\n\n bnb_config = BitsAndBytesConfig(\n load_in_4bit=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16,\n ) if torch.cuda.is_available() else None\n\n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME, quantization_config=bnb_config,\n device_map=\"auto\" if torch.cuda.is_available() else \"cpu\",\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n )\n model = get_peft_model(model, LoraConfig(\n r=LORA_R, lora_alpha=LORA_ALPHA, task_type=\"CAUSAL_LM\",\n target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n ))\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nmodel.print_trainable_parameters()",
135
  "metadata": {},
136
  "execution_count": null,
137
  "outputs": []
138
  },
139
  {
140
  "cell_type": "markdown",
141
+ "id": "0eur29ylvjw",
142
+ "source": "## 8. Setup GRPO Trainer & Train",
143
  "metadata": {}
144
  },
145
  {
146
  "cell_type": "code",
147
+ "id": "7t5e1almadd",
148
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\ndef shape_match_reward(completions, **kwargs):\n return shape_match(completions, task_name=TASK_NAME, **kwargs)\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\" if torch.cuda.is_available() else \"adamw_torch\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=1,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=MAX_PROMPT_LENGTH,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n save_steps=SAVE_STEPS,\n output_dir=OUTPUT_DIR,\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[valid_fold, shape_match_reward],\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer ready. Task: {TASK_NAME}, Model: {MODEL_NAME}\")\nprint(f\"Reward: valid_fold (local) + shape_match (via OpenEnv server at {ENV_URL})\")",
149
  "metadata": {},
150
  "execution_count": null,
151
  "outputs": []
152
  },
153
  {
154
  "cell_type": "code",
155
+ "id": "pu90vgkj4mk",
156
  "source": "trainer.train()",
157
  "metadata": {},
158
  "execution_count": null,
160
  },
161
  {
162
  "cell_type": "markdown",
163
+ "id": "jhw5kwlznif",
164
+ "source": "## 9. Save Model",
165
  "metadata": {}
166
  },
167
  {
168
  "cell_type": "code",
169
+ "id": "hf58ngofjsf",
170
+ "source": "SAVE_PATH = f\"origami-{TASK_NAME}-lora\"\nmodel.save_pretrained(SAVE_PATH)\ntokenizer.save_pretrained(SAVE_PATH)\nprint(f\"LoRA adapter saved to {SAVE_PATH}/\")",
171
  "metadata": {},
172
  "execution_count": null,
173
  "outputs": []
174
  },
175
  {
176
  "cell_type": "markdown",
177
+ "id": "fszxqs5edxt",
178
+ "source": "## 10. Evaluate β€” Generate & Score via Server",
179
  "metadata": {}
180
  },
181
  {
182
  "cell_type": "code",
183
+ "id": "exo1kdngzxc",
184
+ "source": "if USE_UNSLOTH:\n FastLanguageModel.for_inference(model)\n\nNUM_EVAL_SAMPLES = 8\nmessages = [{\"role\": \"user\", \"content\": prompt_text}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL_SAMPLES} completions...\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL_SAMPLES):\n with torch.no_grad():\n output = model.generate(\n input_ids, max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.7, top_p=0.9, do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n\n # Score via server\n fold_data = extract_fold_json(response)\n if fold_data is None:\n status = \"UNPARSEABLE\"\n else:\n is_valid, err = validate_fold_local(fold_data)\n if not is_valid:\n status = f\"INVALID: {err}\"\n else:\n try:\n obs = call_env_server(TASK_NAME, fold_data)\n sim = obs.get(\"shape_similarity\", 0)\n reward = obs.get(\"reward\", 0)\n status = f\"similarity={sim:.3f} (reward={reward:.1f})\"\n except Exception as e:\n status = f\"SERVER ERROR: {e}\"\n\n print(f\" Sample {i+1}: {status}\")\n\nprint(f\"\\nAggregate rewards:\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions, task_name=TASK_NAME)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
185
  "metadata": {},
186
  "execution_count": null,
187
  "outputs": []
188
  },
189
  {
190
  "cell_type": "markdown",
191
+ "id": "zqal21b7rtr",
192
+ "source": "## 11. Plot Training Logs",
193
  "metadata": {}
194
  },
195
  {
196
  "cell_type": "code",
197
+ "id": "o5mhg6s1tcj",
198
+ "source": "import matplotlib.pyplot as plt\n\nlogs = trainer.state.log_history\nsteps, losses, rewards = [], [], []\nfor entry in logs:\n if \"loss\" in entry:\n steps.append(entry.get(\"step\", 0))\n losses.append(entry[\"loss\"])\n if \"reward\" in entry:\n rewards.append(entry[\"reward\"])\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n\nif losses:\n ax1.plot(steps[:len(losses)], losses, color=\"steelblue\", linewidth=1, alpha=0.7)\n if len(losses) > 10:\n w = min(20, len(losses) // 5)\n smoothed = np.convolve(losses, np.ones(w)/w, mode=\"valid\")\n ax1.plot(steps[w-1:len(smoothed)+w-1], smoothed, color=\"navy\", linewidth=2)\n ax1.set_xlabel(\"Step\"); ax1.set_ylabel(\"Loss\"); ax1.set_title(\"Training Loss\"); ax1.grid(True, alpha=0.3)\n\nif rewards:\n ax2.plot(range(len(rewards)), rewards, color=\"coral\", linewidth=1, alpha=0.7)\n if len(rewards) > 10:\n w = min(20, len(rewards) // 5)\n smoothed = np.convolve(rewards, np.ones(w)/w, mode=\"valid\")\n ax2.plot(range(w-1, len(smoothed)+w-1), smoothed, color=\"darkred\", linewidth=2)\n ax2.set_xlabel(\"Step\"); ax2.set_ylabel(\"Reward\"); ax2.set_title(\"Mean Reward\"); ax2.grid(True, alpha=0.3)\n\nplt.tight_layout()\nplt.show()",
199
  "metadata": {},
200
  "execution_count": null,
201
  "outputs": []