Update notebook for Colab: clone repo, explain local sim vs server
#2 by sissississi, opened
training/Dockerfile.train
CHANGED
@@ -28,10 +28,3 @@ ENV NUM_GENERATIONS="4"
 ENV LR="2e-4"
 
 CMD ["sh", "-c", "python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS --num_generations $NUM_GENERATIONS --lr $LR"]
-="triangle"
-ENV MODEL="Qwen/Qwen2.5-3B-Instruct"
-ENV MAX_STEPS="600"
-ENV NUM_GENERATIONS="4"
-ENV LR="2e-4"
-
-CMD ["sh", "-c", "python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS --num_generations $NUM_GENERATIONS --lr $LR"]
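The Dockerfile's `ENV` defaults feed straight into the `CMD` via `$TASK`, `$MODEL`, `$MAX_STEPS`, `$NUM_GENERATIONS`, and `$LR`, so any of them can be overridden at run time with `-e` flags instead of rebuilding the image. A hedged usage sketch — the image tag `origami-train` and the `--gpus` flag are assumptions, not part of this diff:

```shell
# Hypothetical invocation; image name and GPU flag are illustrative.
docker build -f training/Dockerfile.train -t origami-train .

# Baked-in defaults apply unless overridden per run:
docker run --gpus all \
  -e TASK=triangle \
  -e MAX_STEPS=1200 \
  -e LR=1e-4 \
  origami-train
```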
training/train_origami.ipynb
CHANGED
@@ -3,19 +3,19 @@
 {
 "cell_type": "markdown",
 "id": "p8uwc5bkc4n",
-"source": "# Origami RL – GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically\n4. Reward = shape similarity (chamfer distance) to target
+"source": "# Origami RL – GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically (runs locally – no server needed)\n4. Reward = shape similarity (chamfer distance) to target x 20\n\n**Why local simulation instead of the OpenEnv server?**\n> The reward function is called thousands of times during GRPO training (every generation, every step). HTTP roundtrips to a remote server would be way too slow. The simulation is pure numpy/scipy and runs in milliseconds locally. The OpenEnv client/server pattern is for inference and evaluation, not training loops.\n\n**Reward functions:**\n- `valid_fold`: +1.0 valid FOLD JSON, -0.5 parseable but invalid, -2.0 unparseable\n- `shape_match`: similarity x 20.0 (0-20), -1.0 sim fails, -2.0 invalid FOLD\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
 "metadata": {}
 },
 {
 "cell_type": "markdown",
 "id": "xxp4krkl6w",
-"source": "## 1. Install Dependencies",
+"source": "## 1. Clone Repo & Install Dependencies",
 "metadata": {}
 },
 {
 "cell_type": "code",
 "id": "ulhu8a5p5ti",
-"source": "#
+"source": "import os, sys\n\n# Clone the origami_env repo (contains the simulator + reward functions)\nif not os.path.exists(\"origami_env\"):\n    !git clone https://huggingface.co/spaces/openenv-community/origami_env\n    print(\"Repo cloned.\")\nelse:\n    print(\"Repo already exists.\")\n\n# Install training deps\nIN_COLAB = \"google.colab\" in sys.modules\nif IN_COLAB:\n    !pip install -q \"unsloth[colab-new]\"\n    !pip install -q trl datasets peft accelerate bitsandbytes xformers\nelse:\n    !pip install -q \"trl>=0.7\" \"datasets>=2.14\" torch transformers accelerate bitsandbytes peft\n\n# Simulation deps (lightweight)\n!pip install -q numpy scipy pydantic",
 "metadata": {},
 "execution_count": null,
 "outputs": []

@@ -29,7 +29,7 @@
 {
 "cell_type": "code",
 "id": "3hr273dhqiv",
-"source": "import os\nimport sys\nimport json\n\n# Add the repo
+"source": "import os\nimport sys\nimport json\n\n# Add the cloned repo to Python path\nREPO_ROOT = os.path.abspath(\"origami_env\")\nif REPO_ROOT not in sys.path:\n    sys.path.insert(0, REPO_ROOT)\n\nprint(f\"Repo root: {REPO_ROOT}\")\nprint(f\"Python: {sys.version}\")",
 "metadata": {},
 "execution_count": null,
 "outputs": []
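The notebook markdown describes two reward functions (`valid_fold`: +1.0/-0.5/-2.0 by parse validity; `shape_match`: chamfer-based similarity x 20.0) computed in pure numpy/scipy. A minimal sketch of that scheme — this is hypothetical, not the repo's actual implementation in origami_env; the FOLD fields checked and the `1/(1 + chamfer)` similarity mapping are assumptions:

```python
import json
import numpy as np
from scipy.spatial import cKDTree

def valid_fold_reward(completion: str) -> float:
    """+1.0 valid FOLD JSON, -0.5 parseable but invalid, -2.0 unparseable."""
    try:
        fold = json.loads(completion)
    except (ValueError, TypeError):
        return -2.0
    # Assumed validity check: the two core FOLD-spec fields are present.
    if isinstance(fold, dict) and {"vertices_coords", "edges_vertices"} <= fold.keys():
        return 1.0
    return -0.5

def chamfer_similarity(pred: np.ndarray, target: np.ndarray) -> float:
    """Symmetric chamfer distance mapped into (0, 1]; 1.0 for identical point sets."""
    d_pt = cKDTree(target).query(pred)[0].mean()   # each pred point -> nearest target
    d_tp = cKDTree(pred).query(target)[0].mean()   # each target point -> nearest pred
    return 1.0 / (1.0 + d_pt + d_tp)               # assumed mapping, not the repo's

def shape_match_reward(pred: np.ndarray, target: np.ndarray) -> float:
    """similarity x 20.0, i.e. a reward in (0, 20]."""
    return 20.0 * chamfer_similarity(pred, target)

square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
print(valid_fold_reward('{"vertices_coords": [], "edges_vertices": []}'))  # 1.0
print(valid_fold_reward("not json"))                                       # -2.0
print(shape_match_reward(square, square))                                  # 20.0
```

In a GRPO training loop these would be wrapped to map each batch of generated completions to a list of floats, which is why keeping them as millisecond-scale local calls (rather than HTTP roundtrips) matters.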