Spaces: Runtime error

Upload folder using huggingface_hub

- Dockerfile +37 -0
- Hackathon.md +108 -0
- README.md +134 -7
- RLM_Forge_Project_Overview.md +608 -0
- __init__.py +2 -0
- client.py +2 -0
- main.py +6 -0
- models.py +2 -0
- openenv.yaml +20 -0
- pyproject.toml +23 -0
- rlm_forge/__init__.py +5 -0
- rlm_forge/client.py +26 -0
- rlm_forge/models.py +65 -0
- rlm_forge/server/__init__.py +0 -0
- rlm_forge/server/app.py +30 -0
- rlm_forge/server/environment.py +192 -0
- rlm_forge/server/feature_extractor.py +310 -0
- rlm_forge/server/repo_manager.py +106 -0
- rlm_forge/server/reward.py +169 -0
- rlm_forge/server/sandbox.py +213 -0
- rlm_forge_training.ipynb +802 -0
- rlm_forge_training.py +470 -0
- server/__init__.py +1 -0
- server/app.py +23 -0
- uv.lock +0 -0
Dockerfile
ADDED
@@ -0,0 +1,37 @@

```dockerfile
FROM python:3.11-slim

WORKDIR /app

# System deps: git for cloning repos, grep for search
RUN apt-get update && apt-get install -y \
    git \
    grep \
    && rm -rf /var/lib/apt/lists/*

# Copy project files
COPY pyproject.toml .
COPY openenv.yaml .
COPY rlm_forge/ rlm_forge/

# Install Python deps
RUN pip install --no-cache-dir -e .

# Pre-install common test dependencies for target repos
RUN pip install --no-cache-dir pytest text-unidecode freezegun

# AMENDMENT 2: Pre-clone curated repos to avoid network I/O on every reset()
RUN mkdir -p /app/repos && \
    git clone --depth=1 https://github.com/un33k/python-slugify /app/repos/python-slugify && \
    git clone --depth=1 https://github.com/python-humanize/humanize /app/repos/humanize

# Install curated repo dependencies
RUN pip install --no-cache-dir -e /app/repos/python-slugify || true
RUN pip install --no-cache-dir -e /app/repos/humanize || true

EXPOSE 8000

ENV PYTHONUNBUFFERED=1
ENV RLM_FORGE_PRE_CLONED_DIR=/app/repos

ENV ENABLE_WEB_INTERFACE=true
CMD ["python", "-m", "uvicorn", "rlm_forge.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
```
Hackathon.md
ADDED
@@ -0,0 +1,108 @@

## Rules

- Your project **must** use OpenEnv (stable release 0.2.1) deployed on HF Spaces
- You must show a minimal training script for your environment using Unsloth or HF TRL in Colab.
- You must upload a **one minute** demo video to YouTube talking about your submission.

## Hackathon Problem Statements

Your project must address at least **one of the five** required problem statements.

- Some problem statements include **optional partner-sponsored sub-problem statements**, which are additional focus areas related to the main theme.
- Your project may align with **multiple partner sub-problem statements**, but you can only be **judged for a maximum of two**. Please **select up to two** when submitting.
- Projects that match these partner sub-problem statements are eligible for **extra partner prizes**, judged separately from the main track winners.
- Each partner sub-problem statement carries a prize of **$10,000 USD**.

**Statement 1: Multi-Agent Interactions**

Environments for this theme involve cooperation, competition, negotiation, and coalition formation. Learning from these environments will enable agents to model the beliefs and incentives of others in partially observable settings. This drives theory-of-mind reasoning and emergent strategic behavior.

- **Expected Outcome:** an environment that can be used to train multi-agent task handling in an LLM
- **Example Environments:** Market simulations, compute-allocation negotiations, collaborative puzzle worlds, mixed cooperative/competitive strategy games.
- **Partner Sub-Themes:**
  - **Fleet AI:** Scalable Oversight: Environments that train oversight agents to monitor, analyze, and explain the behavior of other AI agents operating in complex, multi-agent settings.
  - **Halluminate:** Multi-Actor Environments: Build a realistic environment where an agent interacts with and manages multiple actors (agents) to discover and achieve the task.

**Statement 2: (Super) Long-Horizon Planning & Instruction Following**

You will build environments that require deep, multi-step reasoning with sparse or delayed rewards. The goal is to enable agents to decompose goals, track state over extended trajectories, and recover from early mistakes. The aim is to push beyond shallow next-token reasoning toward structured planning and durable internal representations.

- **Expected Outcome:** an environment that can capture and improve LLM behaviour on challenging long-horizon tasks that need long-running sessions beyond context memory limits.
- **Example Environments:** Research-planning simulators, large-scale codebase refactoring tasks, strategic resource management worlds, long-horizon logistics optimization, extremely complicated long-horizon instruction following (e.g., 300 instructions scattered around).
- **Partner Sub-Themes:**
  - **Mercor:** Make an environment with capped/uncapped rewards where frontier model rewards scale with token output.
  - **Scale AI:** Environments for long-horizon workflows for non-code use cases within a business setting, focusing on Sales, Project Management, or HR & IT.

**Statement 3: World Modeling**

- **Statement 3.1: Professional Tasks:** Here you will develop environments that require real interaction with tools, APIs, or dynamic systems, where the model is expected to do real, hard work instead of exploiting shortcuts to arrive at the desired outcome. Learning from these environments will enable agents to maintain consistent internal state, update beliefs based on outcomes, and orchestrate multi-step workflows. The goal is to strengthen causal reasoning and persistent world models.

  - **Expected Outcome:** an environment capturing the nuances of a defined partially observable world and improving LLM interaction with it
  - **Example Environments:** Dynamic browser/API ecosystems, enterprise applications, scientific workflow loops (papers → code → experiments), economic simulations with feedback, tool-discovery benchmarks.
  - **Partner Sub-Theme:**
    - **Scaler AI Labs:** Multi-App RL Environment for Enterprise Workflows: Create RL environments that demonstrate complex workflows, business-rule nuances, etc. in a large enterprise.

- **Statement 3.2: Personalized Tasks:** Here you will develop an environment that offers real personalized task handling: imagine replying to personal messages, resolving dinner conflicts caused by work commitments, or replying to tough emails. Think of any personal-assistant task.

  - **Expected Outcome:** An environment that gives the model a realistic simulation of handling personal tasks and conflicts, and managing them as delegations
  - **Example Environments:** Executive-assistant meeting planner, dinner and drive planning, email and message replying, etc.
  - **Partner Sub-Theme:**
    - **Patronus AI:** Consumer Workflows with Schema Drift: Multi-step consumer workflow environments where the underlying data schemas, API contracts, and t&cs/policies/rules change.

**Statement 4: Self-Improvement**

The focus here is to create environments where agents can learn to generate new challenges, escalate difficulty, and improve through self-play or adaptive curricula. Rather than optimizing fixed tasks, the goal is for agents to learn to drive their own capability growth. The objective is recursive skill amplification.

- **Expected Outcome:** an environment for improving self-play of an LLM over a defined set of tasks
- **Example Environments:** Self-play negotiation arenas, auto-generated math/proof tasks, evolving coding competitions, adaptive RL curricula.
- **Partner Sub-Theme:**
  - **Snorkel AI:** Simulated Experts-in-the-Loop: An environment that simulates interactions with real subject-matter experts, with changing requirements/preferences.

**Statement 5: Wild Card - Impress Us!**

We do not want to limit your focus if your idea doesn't fit the boxes above. We want, and WILL reward, out-of-the-box tasks. Please be creative, but remember to submit work that meaningfully adds value to LLM training on a specific task.

**Judging Criteria**

- **Environment Innovation (40%) -** Is the environment novel, creative, or challenging? Does it meaningfully test the agent's behavior?
- **Storytelling (30%) -** Does the team clearly explain the problem, environment, and agent behavior? Is the demo engaging and easy to follow?
- **Training Script Showing Improvement in Rewards (20%) -** Does the demo provide observable evidence of training progress (reward curves, metrics, or before/after behavior)?
- **Reward and Training Pipeline Setup (10%) -** Is the reward logic coherent, and does the pipeline produce meaningful improvement in the agent's inference (how it acts in the environment)?

**Judging Process**

Judging proceeds in two rounds:

- Hackers will be assigned groups of judges; ~3 minutes to pitch followed by 1-2 minutes of Q&A
- The top **six** teams in the ranking will demo on stage to a panel of judges; ~3 minutes to pitch followed by 2-3 minutes of Q&A.

## Prizes

- **1st Place:** $15,000 USD Cash
- **2nd Place:** $9,000 USD Cash
- **3rd Place:** $6,000 USD Cash
README.md
CHANGED
@@ -1,10 +1,137 @@

---
title: RLM-Forge
emoji: 🚀
colorFrom: blue
colorTo: indigo
base_path: /web
---

# RLM-Forge

**Recursive Language Model training environment for AI coding agents.**

RLM-Forge is an [OpenEnv](https://github.com/meta-pytorch/OpenEnv) environment that trains language models to solve coding tasks on real Python repositories using Recursive Language Model (RLM) patterns.

## How It Works

1. **Clone** a real Python repo (e.g., python-slugify, humanize)
2. **Extract** a source file and replace it with a broken stub (correct signatures, wrong implementations)
3. **Agent** explores the repo via a sandboxed multi-step REPL with built-in tools
4. **Reward** = test pass rate (55%) + structural validity (15%) + efficiency (30%)
5. **Train** with GRPO to improve the agent's coding ability over time

### The REPL Tools

The agent has access to these functions in the sandbox:

| Function | Description |
|----------|-------------|
| `read_file(path)` | Read a file from the repo |
| `list_dir(path='.')` | List directory contents |
| `search(pattern, path='.')` | Grep for a pattern |
| `write_file(path, content)` | Write/create a file |
| `run_tests(test_path=None)` | Run pytest |
| `spawn_agent(scope, mission)` | Explore a directory scope |
| `FINAL()` | Signal implementation is complete |

## Project Structure

```
rlm_forge/
├── __init__.py           # Package exports
├── models.py             # Pydantic models (Action, Observation, State)
├── client.py             # EnvClient for remote connections
└── server/
    ├── app.py                # FastAPI server (create_app)
    ├── environment.py        # Core Environment (reset/step)
    ├── sandbox.py            # Sandboxed Python REPL
    ├── repo_manager.py       # Repo cloning & dependency management
    ├── feature_extractor.py  # Source file extraction & stub generation
    └── reward.py             # Composite reward computation
```

## Quick Start

### Install

```bash
uv sync
```

### Run the Server

```bash
uv run uvicorn rlm_forge.server.app:app --host 0.0.0.0 --port 8000
```

### Use the Environment Directly

```python
from rlm_forge.server.environment import RLMForgeEnvironment
from rlm_forge.models import RLMForgeAction

env = RLMForgeEnvironment()
obs = env.reset(seed=1)
print(obs.task_description)

# Agent takes actions
obs = env.step(RLMForgeAction(code="print(read_file('test.py'))"))
obs = env.step(RLMForgeAction(code="write_file('slugify/slugify.py', '...')"))
obs = env.step(RLMForgeAction(code="FINAL()"))
print(f"Reward: {obs.reward}")
```

### Connect via Client

```python
from rlm_forge.client import RLMForgeClient
from rlm_forge.models import RLMForgeAction

client = RLMForgeClient(base_url="http://localhost:8000")
client.connect()

result = client.reset(seed=1)
result = client.step(RLMForgeAction(code="print(list_dir())"))
result = client.step(RLMForgeAction(code="FINAL()"))
print(f"Reward: {result.reward}")
```

## Training

See `rlm_forge_training.ipynb` for the full GRPO training notebook. Designed for Google Colab with an H100 GPU.

Key training approach:
- **Multi-step trajectory concatenation**: The full episode (all code actions) is treated as one GRPO "completion"
- **Group Relative Policy Optimization**: Multiple completions per task; advantages computed relative to the group mean
- **LoRA fine-tuning**: 4-bit quantized Qwen2.5-Coder-32B with a LoRA adapter

## Reward Breakdown

| Component | Weight | Description |
|-----------|--------|-------------|
| Test Pass Rate | 55% | Fraction of tests passing |
| Structural Validity | 15% | AST parse check + import check |
| Efficiency | 30% | Tiered by iteration budget used |
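These weights combine as a weighted sum. A minimal sketch, assuming each component score is normalized to [0, 1]; the actual logic lives in `rlm_forge/server/reward.py` and may differ in detail:

```python
# Hypothetical sketch of the composite reward; the real implementation
# in rlm_forge/server/reward.py may differ (e.g., tiered efficiency scoring).
WEIGHTS = {"tests": 0.55, "structure": 0.15, "efficiency": 0.30}

def composite_reward(pass_rate: float, structural: float, efficiency: float) -> float:
    """Each component is a score in [0, 1]; the result is their weighted sum."""
    return (WEIGHTS["tests"] * pass_rate
            + WEIGHTS["structure"] * structural
            + WEIGHTS["efficiency"] * efficiency)

# e.g. all tests pass, valid AST/imports, but the full iteration budget was used
score = composite_reward(pass_rate=1.0, structural=1.0, efficiency=0.0)
print(round(score, 3))  # 0.7
```

Because test pass rate dominates, an agent cannot score well by producing fast but broken code.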
## Curated Repos

| Repo | Source File | Tests | Difficulty |
|------|-------------|-------|------------|
| python-slugify | `slugify/slugify.py` | 82 | Easy |
| humanize (number) | `src/humanize/number.py` | 219 | Medium |
| humanize (time) | `src/humanize/time.py` | varies | Medium |

## Docker

```bash
docker build -t rlm-forge .
docker run -p 8000:8000 rlm-forge
```

The Dockerfile pre-clones the curated repos to avoid network I/O on each `reset()`.

## Deploy to HF Spaces

```bash
openenv push -r your-username/rlm-forge
```
RLM_Forge_Project_Overview.md
ADDED
@@ -0,0 +1,608 @@

# RLM-Forge: A Recursive Language Model Training Environment for AI Coding Agents

## Project Overview

RLM-Forge is an OpenEnv environment designed to train small language models to use the Recursive Language Model (RLM) framework for solving complex coding tasks on large repositories. It is inspired by the research paper "Recursive Language Models" (Zhang, Kraska, & Khattab, MIT CSAIL, December 2025), which demonstrated that LLMs can process inputs orders of magnitude beyond their context windows by treating prompts as external environment variables and interacting with them through code execution in a REPL.

The core innovation of RLM-Forge is combining the RLM paradigm with depth-limited sub-agents for repository exploration, creating an environment where a root agent can orchestrate multiple sub-agents, each with its own scoped REPL and file-system tools, to understand and modify codebases far too large for any single model's context window.

The environment is self-supervised: it clones open-source repositories, programmatically removes a file or module that has associated test coverage, and tasks the agent with rebuilding that feature using only the surrounding codebase. The removed feature's test suite serves as an automatic, objective reward signal.

---

## Motivation & Research Background

### The Problem

Modern AI coding agents (Claude Code, Cursor, Codex CLI) struggle with very large repositories because a single agent must somehow fit enough context to understand the entire system. Context windows are finite, and even within those limits, model quality degrades as context grows longer, a phenomenon known as "context rot."

### The RLM Insight

The Recursive Language Models paper (arXiv:2512.24601) proposes a paradigm shift: instead of feeding long prompts directly into the neural network, treat the prompt as part of an external environment. The model interacts with the context through code (slicing, searching, chunking) and only pulls small pieces into its context window at a time. Crucially, the model can programmatically invoke sub-LM calls on constructed snippets, enabling recursive decomposition.

Key findings from the paper:
- RLMs handle inputs up to 10M+ tokens (two orders of magnitude beyond context windows)
- On information-dense tasks, RLMs outperform base models by 28-58% absolute
- The approach is model-agnostic and works with both closed and open-source models
- Costs remain comparable to base-model calls at the median
- Emergent strategies appear without explicit training: regex filtering, intelligent chunking, answer verification, variable-based output stitching

### The Gap We Fill

The paper's "Future Work" section explicitly identifies the opportunity we are pursuing:

> "Explicitly training models to be used as RLMs (e.g. as root or sub-LMs) could provide additional performance improvements... We hypothesize that RLM trajectories can be viewed as a form of reasoning, which can be trained by bootstrapping existing frontier models."

We plan to allow a recursion depth of 1 (possibly 2), so that a root agent can spawn sub-agents, and those sub-agents have access to their own REPL and file-system tools, but they cannot spawn sub-agents of their own.

This lets the model be trained as both a root agent and a sub-agent, which is key to the success of the RLM-Forge environment.

### Why Coding Tasks?

Coding is the ideal domain for RLM training because:
1. **Natural structure**: Repositories have files, modules, imports, and tests, providing clear decomposition targets
2. **Objective evaluation**: Test suites provide automatic, binary reward signals
3. **Unlimited data**: Every well-tested open-source repository is a potential training example
4. **Real-world impact**: Improved coding agents have immediate practical value
5. **Complexity scaling**: Repositories naturally range from simple (100 LOC) to enormous (1M+ LOC), providing a natural curriculum

---

## Architecture Design

### Environment Type

RLM-Forge is an **OpenEnv environment** built on the OpenEnv 0.2.1 framework. It follows the standard OpenEnv pattern:

```
rlm_forge/
├── __init__.py
├── README.md
├── models.py        # Action, Observation, State (Pydantic models)
├── client.py        # HTTPEnvClient subclass
├── openenv.yaml     # Environment manifest
├── pyproject.toml
├── uv.lock
└── server/
    ├── __init__.py
    ├── app.py                # FastAPI server using create_app()
    ├── environment.py        # Core Environment implementation
    ├── repo_manager.py       # Repository cloning, feature extraction, test discovery
    ├── sandbox.py            # Sandboxed code execution (REPL)
    ├── sub_agent.py          # Sub-agent lifecycle management
    ├── reward.py             # Composite reward computation
    ├── feature_extractor.py  # Module/file removal and test mapping
    └── Dockerfile
```

### Core Concepts

#### The Root Agent

The root agent operates in an iterative REPL loop. It receives a task description and a high-level manifest of the repository (directory tree, file sizes, README excerpt). It does NOT see the actual source code in its context window. Instead, it writes Python code to:
- Explore the repository structure
- Read specific files
- Search for patterns (grep, regex, AST parsing)
- Spawn sub-agents to explore specific directories or modules
- Write implementation code
- Save files to rebuild the removed feature

#### Sub-Agents (Depth = 1)

Sub-agents are scoped explorers. When the root agent spawns a sub-agent, it specifies:
- A target scope (directory path or set of files)
- A mission (what to look for, what to report back)
- A budget (maximum iterations)

The sub-agent gets its own sandboxed REPL with:
- Read-only access to its scoped portion of the repository
- The ability to execute Python code (read files, parse ASTs, search, analyze)
- An `llm_query()` function for semantic understanding of code snippets
- NO ability to spawn further sub-agents (depth limit = 1)

The sub-agent runs its own iteration loop and returns a structured report to the root agent's REPL environment as a variable.

**Important distinction from the RLM paper**: In the paper, sub-calls are stateless LM calls: simple prompt-in, text-out. In RLM-Forge, sub-agents have their own REPL state, their own iteration loop, and their own tool access. They are mini-RLMs, not plain LM calls. This is the "depth-1 recursive RLM with tools" architecture. Sub-agents CANNOT spawn their own sub-agents.
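The depth limit described above can be sketched as follows. This is a hypothetical illustration, not the actual `sub_agent.py`: the `AgentContext` class and the report dictionary shape are assumptions made for the sketch.

```python
# Hypothetical sketch of depth-limited sub-agent spawning; the real logic
# lives in rlm_forge/server/sub_agent.py and may differ.
from dataclasses import dataclass

@dataclass
class AgentContext:
    depth: int = 0       # 0 = root agent, 1 = sub-agent
    max_depth: int = 1   # sub-agents may not spawn further sub-agents

def spawn_agent(ctx: AgentContext, scope: str, mission: str, budget: int = 8) -> dict:
    """Spawn a scoped sub-agent and return its structured report."""
    if ctx.depth >= ctx.max_depth:
        raise RuntimeError("depth limit reached: sub-agents cannot spawn sub-agents")
    child = AgentContext(depth=ctx.depth + 1, max_depth=ctx.max_depth)
    # A real implementation would run the sub-agent's REPL loop here,
    # restricted to `scope`, for at most `budget` iterations.
    return {"scope": scope, "mission": mission, "budget": budget,
            "depth": child.depth, "findings": []}

root = AgentContext()
report = spawn_agent(root, scope="src/", mission="map public API")
```

A sub-agent created this way carries `depth=1`, so any `spawn_agent()` call it makes is rejected, which is exactly the depth-1 guarantee.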
#### The REPL Environment

Both root and sub-agents operate within sandboxed Python REPL environments. Key properties:
- **Persistent state**: Variables persist across iterations within an episode
- **Sandboxed execution**: Code runs in an isolated environment with controlled file-system access
- **Truncated output**: stdout/stderr is truncated to prevent context overflow (configurable limit)
- **Iteration tracking**: The environment tracks iteration count against a configurable maximum
- **Built-in functions**:
  - `llm_query(prompt: str) -> str` — Invoke a sub-LM for semantic understanding
  - `spawn_agent(scope: str, mission: str, budget: int) -> dict` — Spawn a sub-agent (root only)
  - `read_file(path: str) -> str` — Read a file from the repository
  - `list_dir(path: str) -> list` — List directory contents
  - `search(pattern: str, path: str) -> list` — Grep/regex search
  - `write_file(path: str, content: str)` — Write implementation files (root only)
  - `run_tests(test_path: str) -> dict` — Run specific test files and get results
  - `FINAL()` — Signal episode completion
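The persistent-state and truncated-output properties above can be sketched in a few lines. This `MiniRepl` is an assumption-laden toy, not the real `sandbox.py` (which adds isolation and the full tool set); it only shows how a shared namespace and an output cap fit together.

```python
# Hypothetical sketch of a persistent, output-truncating REPL; the real
# sandbox in rlm_forge/server/sandbox.py adds isolation and more tooling.
import io
import contextlib

class MiniRepl:
    def __init__(self, builtins: dict, max_output: int = 2000):
        self.namespace = dict(builtins)   # tool functions injected here
        self.max_output = max_output      # stdout truncation limit
        self.iteration = 0

    def execute(self, code: str) -> dict:
        self.iteration += 1
        buf = io.StringIO()
        try:
            with contextlib.redirect_stdout(buf):
                exec(code, self.namespace)  # variables persist across calls
            ok = True
        except Exception as exc:
            buf.write(f"{type(exc).__name__}: {exc}")
            ok = False
        out = buf.getvalue()
        if len(out) > self.max_output:
            out = out[: self.max_output] + "\n...[truncated]"
        return {"stdout": out, "success": ok, "iteration": self.iteration}

repl = MiniRepl({"read_file": lambda p: f"<contents of {p}>"})
repl.execute("x = 1 + 1")          # state persists...
result = repl.execute("print(x)")  # ...into later iterations
```

Because both iterations share one `namespace` dict, the agent can build up variables (file contents, search results, sub-agent reports) over many steps without re-reading anything.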
---
|
| 125 |
+
|
| 126 |
+
## Episode Lifecycle
|
| 127 |
+
|
| 128 |
+
### Phase 1: Environment Setup (on `reset()`)
|
| 129 |
+
|
| 130 |
+
1. **Repository selection**: The environment selects a repository from its configured dataset (a list of Git repository URLs or local paths)
|
| 131 |
+
2. **Clone and baseline**: Clone the repository. Run the full test suite to establish a baseline (all tests should pass)
|
| 132 |
+
3. **Feature extraction**: Select a target file or module for removal:
|
| 133 |
+
- Identify files/modules that have dedicated test files with clear mappings (e.g., `src/auth.py` → `tests/test_auth.py`)
|
| 134 |
+
- Prefer modules with moderate complexity (configurable LOC range)
|
| 135 |
+
- Record which tests are associated with the target
|
| 136 |
+
- Record the original content of the target (this is the ground truth, never shown to the agent)
|
| 137 |
+
4. **Feature removal**: Delete the target file(s) from the repository working copy
|
| 138 |
+
5. **Manifest generation**: Create a high-level manifest for the agent:
|
| 139 |
+
- Directory tree structure
|
| 140 |
+
- File sizes and languages
|
| 141 |
+
- README excerpt (first N characters)
|
| 142 |
+
- List of failing tests (names and file paths)
|
| 143 |
+
- Task description: "The following module has been removed: `[path]`. N tests in `[test_path]` are now failing. Your task is to implement the missing module so that all tests pass."
|
| 144 |
+
6. **REPL initialization**: Set up the root agent's REPL environment with the repository loaded and built-in functions available
|
| 145 |
+
7. **Return initial observation**: The observation includes the manifest, the task description, the failing test list, and REPL environment metadata (available variables, available functions)
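Step 5's manifest generation can be sketched as a pure function, assuming the repository has already been scanned into a path-to-size mapping. The helper name and exact field names are illustrative:

```python
def build_manifest(files: dict, readme: str, removed: str,
                   failing: list, readme_chars: int = 200) -> dict:
    """Assemble the high-level manifest shown to the agent on reset()."""
    return {
        "tree": sorted(files),                    # directory tree (flattened here)
        "sizes": files,                           # path -> size in bytes
        "readme_excerpt": readme[:readme_chars],  # first N characters of README
        "failing_tests": failing,
        "task_description": (
            f"The following module has been removed: `{removed}`. "
            f"{len(failing)} tests in `tests/` are now failing. Your task is to "
            "implement the missing module so that all tests pass."
        ),
    }
```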

### Phase 2: Agent Interaction (the `step()` loop)

At each step, the agent submits an action containing Python code to execute. The environment:

1. **Extracts code blocks** from the agent's response
2. **Executes each code block** in the sandboxed REPL
3. **Captures output** (stdout, stderr, success/failure, any variables set)
4. **Checks for sub-agent spawns**: If the code calls `spawn_agent()`, the environment:
   - Creates a new scoped REPL for the sub-agent
   - Runs the sub-agent's iteration loop (the sub-agent is driven by an `llm_query()` call internally, or by a policy if training the sub-agent)
   - Returns the sub-agent's report as a variable in the root agent's REPL
5. **Checks for termination**: The episode ends if:
   - The agent calls `FINAL()` — voluntary completion
   - Maximum iterations reached — forced termination
   - Maximum wall-clock time exceeded — timeout
6. **Returns observation**: stdout/stderr (truncated), success boolean, iteration count, list of available variables, any sub-agent reports
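Step 1's code-block extraction can be sketched with a regex over fenced blocks. This is a simplification; the real parser may also accept bare code:

```python
import re

# Match ``` or ```python fenced blocks, capturing the body lazily.
FENCE_RE = re.compile(r"```(?:python)?\n(.*?)```", re.DOTALL)

def extract_code_blocks(response: str) -> list[str]:
    """Pull every fenced Python block out of an agent response."""
    return [m.strip() for m in FENCE_RE.findall(response)]
```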

### Phase 3: Evaluation (on episode completion)

When the episode ends (either through `FINAL()` or the iteration limit):

1. **Collect implementation**: Gather all files the agent wrote via `write_file()`
2. **Run target tests**: Execute the test files associated with the removed feature
3. **Run regression tests**: Execute the full test suite to check for regressions
4. **Compute composite reward** (see Reward Function below)
5. **Return final observation** with `done=True`, the reward, and detailed test results

---

## Reward Function

The reward is a weighted composite of three components. Weights are configurable via environment parameters, with these defaults:

### Test Pass Rate (Default: 55% of total reward)

```
test_pass_reward = num_target_tests_passed / num_target_tests_total
```

This is the primary signal. The agent is rewarded proportionally to how many of the removed feature's tests it gets passing. Partial credit is given — passing 7 out of 10 tests yields 0.70 on this component.

### Structural Validity (Default: 15% of total reward)

```
structural_reward = weighted_average(
    parse_success,    # Does the code parse without syntax errors? (weight: 0.3)
    import_success,   # Do imports resolve correctly? (weight: 0.3)
    no_regressions,   # Do previously-passing tests still pass? (weight: 0.4)
)
```

This penalizes agents that produce invalid code or hack solutions that break the rest of the codebase. The regression check is particularly important — it prevents the agent from modifying shared utilities in ways that pass the target tests but break everything else.

### Efficiency Bonus (Default: 30% of total reward)

```
if iterations_used <= budget * 0.5:
    efficiency_reward = 1.0   # Full bonus for fast solutions
elif iterations_used <= budget * 0.75:
    efficiency_reward = 0.75  # Reduced bonus
elif iterations_used <= budget:
    efficiency_reward = 0.5   # Minimal bonus for using the full budget
else:
    efficiency_reward = 0.0   # No bonus if forced termination

# Sub-agent efficiency modifier
sub_agent_penalty = max(0, 1.0 - (num_sub_agents_spawned / max_reasonable_sub_agents))
efficiency_reward *= (0.7 + 0.3 * sub_agent_penalty)
```

This encourages the agent to learn efficient exploration and decomposition strategies. It rewards agents that solve problems quickly and use sub-agents judiciously rather than spawning one for every directory.

### Total Reward Computation

```
total_reward = (
    test_weight * test_pass_reward
    + structural_weight * structural_reward
    + efficiency_weight * efficiency_reward
)
```

Where `test_weight`, `structural_weight`, and `efficiency_weight` are configurable and default to 0.55, 0.15, and 0.30 respectively.
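The three pseudocode components above can be combined into one runnable function. This sketch uses the default weights and omits the sub-agent modifier for brevity; argument names mirror the pseudocode:

```python
def compute_reward(
    passed: int, total: int,
    parse_ok: bool, imports_ok: bool, no_regressions: bool,
    iterations_used: int, budget: int,
    test_weight: float = 0.55,
    structural_weight: float = 0.15,
    efficiency_weight: float = 0.30,
) -> float:
    """Composite reward: test pass rate + structural validity + efficiency bonus."""
    test_pass = passed / total if total else 0.0
    # Weighted average of the three structural checks (bools coerce to 0/1).
    structural = 0.3 * parse_ok + 0.3 * imports_ok + 0.4 * no_regressions
    if iterations_used <= budget * 0.5:
        efficiency = 1.0
    elif iterations_used <= budget * 0.75:
        efficiency = 0.75
    elif iterations_used <= budget:
        efficiency = 0.5
    else:
        efficiency = 0.0
    return (test_weight * test_pass
            + structural_weight * structural
            + efficiency_weight * efficiency)
```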

---

## Data Models (Pydantic Schemas)

### Action

```python
class RLMForgeAction(Action):
    """Agent's action: Python code to execute in the REPL."""

    code: str = Field(..., description="Python code to execute in the REPL environment")
    action_type: str = Field(
        default="execute",
        description="Type of action: 'execute' for code, 'final' to submit solution",
    )
```

### Observation

```python
class RLMForgeObservation(Observation):
    """What the agent sees after each step."""

    # REPL execution results
    stdout: str = Field(default="", description="Truncated stdout from code execution")
    stderr: str = Field(default="", description="Truncated stderr from code execution")
    success: bool = Field(default=True, description="Whether code executed without errors")

    # Episode tracking
    iteration: int = Field(default=0, description="Current iteration number")
    max_iterations: int = Field(default=50, description="Maximum allowed iterations")

    # Repository context (provided on reset, may be refreshed)
    repo_manifest: Optional[dict] = Field(default=None, description="Repository structure manifest")
    task_description: Optional[str] = Field(default=None, description="The coding task to complete")
    failing_tests: Optional[list[str]] = Field(default=None, description="List of currently failing test names")

    # REPL state
    available_variables: list[str] = Field(default_factory=list, description="Variables currently in REPL scope")
    available_functions: list[str] = Field(default_factory=list, description="Built-in functions available")

    # Sub-agent reports (populated when sub-agents complete)
    sub_agent_reports: list[dict] = Field(default_factory=list, description="Reports from completed sub-agents")

    # Test results (populated on final evaluation)
    test_results: Optional[dict] = Field(default=None, description="Detailed test results on completion")
```

### State

```python
class RLMForgeState(State):
    """Internal environment state, not directly sent to the agent."""

    episode_id: Optional[str] = None
    step_count: int = 0

    # Repository info
    repo_url: str = ""
    repo_local_path: str = ""
    removed_feature_path: str = ""
    removed_feature_content: dict[str, str] = Field(default_factory=dict)  # filename -> original content
    target_test_files: list[str] = Field(default_factory=list)
    baseline_test_count: int = 0

    # Agent progress
    files_written: dict[str, str] = Field(default_factory=dict)  # filename -> content written by agent
    sub_agents_spawned: int = 0
    total_llm_queries: int = 0

    # Evaluation
    final_reward: Optional[float] = None
    test_pass_rate: Optional[float] = None
    has_regressions: Optional[bool] = None
```

---

## Feature Extraction Pipeline

The feature extraction pipeline is responsible for selecting what to remove from a repository and mapping it to tests. This is a critical component that must work reliably.

### Strategy: File and Module Level Extraction

The pipeline operates in two modes:

#### Single-File Mode

1. Scan the repository for Python/Rust/TS/Julia source files
2. For each source file, look for a corresponding test file using common patterns:
   - `src/foo.py` → `tests/test_foo.py`
   - `src/foo.py` → `tests/foo_test.py`
   - `src/foo/bar.py` → `tests/test_bar.py`
   - `lib/foo.rs` → `tests/foo.rs` or `tests/test_foo.rs`
   - `src/foo.ts` → `__tests__/foo.test.ts` or `tests/foo.spec.ts`
3. Verify the test file actually imports from / tests the source file
4. Run the test file in isolation to confirm it passes
5. Score candidates by:
   - Number of tests (prefer 5-30 tests; too few = trivial, too many = too complex)
   - Source file LOC (prefer 50-500 lines for hackathon scope)
   - Import complexity (prefer files that are imported by few other files, to minimize cascade)
+
|
| 329 |
+
#### Module Mode
|
| 330 |
+
1. Scan for directories that represent modules (contain `__init__.py` or are listed in package config)
|
| 331 |
+
2. Find test directories or files that correspond to the module
|
| 332 |
+
3. Same scoring criteria but at the module (directory) level
|
| 333 |
+
4. Prefer small, self-contained modules (2-8 files)
|
| 334 |
+
|
| 335 |
+
### Output of Feature Extraction
|
| 336 |
+
|
| 337 |
+
```python
|
| 338 |
+
@dataclass
|
| 339 |
+
class ExtractedFeature:
|
| 340 |
+
"""Represents a feature to be removed for training."""
|
| 341 |
+
source_paths: list[str] # Files to remove
|
| 342 |
+
test_paths: list[str] # Test files that exercise this feature
|
| 343 |
+
original_content: dict[str, str] # Map of path -> original file content
|
| 344 |
+
num_tests: int # Number of individual test cases
|
| 345 |
+
estimated_complexity: str # "easy", "medium", "hard"
|
| 346 |
+
import_dependents: list[str] # Files that import from the removed feature
|
| 347 |
+
task_description: str # Auto-generated task description for the agent
|
| 348 |
+
```

---

## Sub-Agent Mechanism

### Spawning a Sub-Agent

From the root agent's REPL:

```python
report = spawn_agent(
    scope="/src/database/",
    mission=(
        "Explore the database module. Report: 1) What ORM or database library "
        "is used, 2) What models/tables exist, 3) What patterns are used for "
        "queries, 4) The public API of this module"
    ),
    budget=10,  # max iterations for the sub-agent
)
# `report` is now a dict variable in the root agent's REPL
print(report["summary"])
print(report["files_examined"])
```
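A minimal sketch of how the environment could drive a sub-agent loop and hand the report back. The scripted `policy` callable stands in for the internal `llm_query`-driven iteration, and all names and report fields here are illustrative, not the environment's actual internals:

```python
def run_sub_agent(scope: str, mission: str, budget: int, policy) -> dict:
    """Run a sub-agent for up to `budget` iterations and return its report."""
    report = {"scope": scope, "mission": mission, "summary": "", "files_examined": []}
    for i in range(budget):
        step = policy(i, report)  # stands in for one llm_query-driven iteration
        if step.get("final"):
            report["summary"] = step.get("summary", "")
            break
        report["files_examined"].append(step["file"])
    return report
```

The returned dict is what `spawn_agent()` would inject as a variable into the root agent's REPL.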

### Sub-Agent Lifecycle

1. **Initialization**: A new sandboxed REPL is created with read-only access to the specified scope
2. **Mission prompt**: The sub-agent receives a system prompt with:
   - Its scoped directory listing
   - The mission description from the root agent
   - Available built-in functions (`read_file`, `list_dir`, `search`, `llm_query`)
   - Its iteration budget
3. **Iteration loop**: The sub-agent iterates (driven by `llm_query` internally):
   - Writes code to explore its scope
   - Executes code, observes results
   - Refines its understanding
   - Calls `FINAL(report)` when done or when its budget is exhausted
4. **Report return**: The sub-agent's final report (a structured dict) is injected as a variable into the root agent's REPL

### Sub-Agent Constraints

- **Read-only file access**: Sub-agents can read files within their scope but cannot write files
- **No sub-agent spawning**: Sub-agents cannot spawn their own sub-agents (depth = 1)
- **Scoped access**: Sub-agents can only access files within their assigned directory scope
- **Budget limited**: Each sub-agent has a maximum iteration count
- **Concurrent limit**: The root agent can spawn at most N sub-agents per episode (configurable, default 10)
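The scoped-access constraint can be enforced with a path check before every read. A minimal sketch; the real sandbox may additionally rely on bind mounts or chroot as noted later:

```python
from pathlib import PurePosixPath

def in_scope(path: str, scope: str) -> bool:
    """True if `path` resolves inside the sub-agent's assigned scope."""
    p = PurePosixPath(path)
    s = PurePosixPath(scope)
    # Reject traversal components outright; '..' could escape the scope.
    if ".." in p.parts:
        return False
    return p == s or s in p.parents

def scoped_read(path: str, scope: str, read_fn) -> str:
    """Gate a read through the scope check."""
    if not in_scope(path, scope):
        raise PermissionError(f"{path} is outside scope {scope}")
    return read_fn(path)
```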

---

## Repository Dataset

### Requirements for Training Repositories

Each repository used as a training dataset must have:

1. **Strong test coverage** with test files that clearly map to source modules
2. **Modular architecture** where individual files/modules can be removed without collapsing the entire project
3. **Medium-large size** (10,000-150,000 LOC)
4. **Active maintenance** (commits within the last 3 months)
5. **Permissive license** (MIT, Apache 2.0, BSD)
6. **80%+ in one of**: Python, Rust, TypeScript, or Julia

### Repository Configuration

```yaml
# repos.yaml - Dataset configuration
repositories:
  - url: "https://github.com/org/repo1"
    language: "python"
    difficulty: "medium"
    test_command: "pytest"
    source_dir: "src/"
    test_dir: "tests/"

  - url: "https://github.com/org/repo2"
    language: "rust"
    difficulty: "hard"
    test_command: "cargo test"
    source_dir: "src/"
    test_dir: "tests/"

  # ... more repositories

settings:
  max_file_loc: 500             # Max LOC for single-file extraction
  max_module_files: 8           # Max files for module extraction
  min_tests: 3                  # Minimum tests for a valid feature
  max_tests: 50                 # Maximum tests (avoid overly complex features)
  preferred_test_range: [5, 30] # Sweet spot for test count
```

---

## Hackathon Problem Statement Alignment

RLM-Forge addresses multiple hackathon problem statements:

### Primary: Statement 2 — Long-Horizon Planning & Instruction Following

The environment requires deep, multi-step reasoning with delayed rewards. The agent must:
- Decompose the goal of rebuilding a feature into exploration sub-tasks
- Track state across an extended REPL trajectory (potentially dozens of iterations)
- Recover from wrong turns (exploring irrelevant code, writing buggy implementations)
- Plan sub-agent deployments strategically

### Secondary: Statement 3.1 — World Modeling (Professional Tasks)

The environment involves real interaction with tools and dynamic systems:
- File system exploration with real code
- Test execution with real pass/fail results
- Code execution in a sandboxed REPL
- Multi-step workflows: explore → understand → plan → implement → verify

### Partner Sub-Theme: Mercor (Statement 2)

"Make an environment with capped/uncapped rewards where frontier model rewards scale with token output." — RLM-Forge naturally fits this: longer, more sophisticated RLM trajectories that correctly process more of the codebase should earn higher rewards, as they'll pass more tests.

---

## Implementation Plan

### Phase 1: Core Environment Scaffold

1. Set up the OpenEnv project structure using `openenv init`
2. Define all Pydantic models (Action, Observation, State)
3. Implement the basic `Environment` class with `reset()` and `step()` stubs
4. Implement the sandboxed REPL (code execution with safety restrictions)
5. Implement the `app.py` FastAPI server and `client.py`
6. Verify the environment scaffold works with `openenv validate`

### Phase 2: Repository & Feature Pipeline

1. Implement `repo_manager.py` — repository cloning, caching, test suite discovery
2. Implement `feature_extractor.py` — file/module selection, test mapping, feature removal
3. Build the manifest generator (directory tree, file metadata, task description)
4. Test the pipeline end-to-end on 2-3 repositories
5. Handle multi-language support (Python pytest, Rust cargo test, TS jest/vitest)

### Phase 3: Sub-Agent System

1. Implement `sub_agent.py` — sub-agent REPL creation, scoping, lifecycle
2. Implement `spawn_agent()` as a built-in REPL function
3. Implement the sub-agent iteration loop with `llm_query()` integration
4. Implement the sub-agent report format and injection into the root REPL
5. Add sub-agent budget tracking and concurrency limits
6. Test sub-agent spawning and report aggregation

### Phase 4: Reward & Evaluation

1. Implement `reward.py` — test execution, pass rate calculation, regression detection
2. Implement structural validity checks (parsing, import resolution)
3. Implement efficiency scoring
4. Implement the composite reward computation with configurable weights
5. Test reward computation on sample episodes

### Phase 5: Integration, Docker & HF Spaces

1. Full integration testing — run complete episodes end-to-end
2. Build the Dockerfile with all dependencies (git, language runtimes, test frameworks)
3. Configure the Gradio web UI for the HF Space
4. Deploy to HF Spaces using `openenv push`
5. Verify the deployed environment works remotely

### Phase 6: Minimal Training Demo

1. Create a Google Colab notebook
2. Set up Unsloth + a small model (Qwen2.5-1.5B or similar)
3. Connect to the deployed environment
4. Implement a GRPO training loop with the environment's reward function
5. Run a few training steps to demonstrate the pipeline works
6. Save results and training curves

### Phase 7: Demo Video & Submission

1. Record a 1-minute YouTube demo video
2. Final testing and bug fixes
3. Submit to the hackathon

---

## Key Technical Resources

### OpenEnv Framework
- OpenEnv GitHub: `https://github.com/meta-pytorch/OpenEnv`
- OpenEnv 0.2.1 stable release
- Environment builder guide: `docs/source/getting_started/environment-builder.md`
- Existing REPL environment: `src/envs/repl_env/` (study this closely as a reference)
- Existing coding environment: `src/envs/coding_env/` (another key reference)
- 2048 RL training tutorial: `docs/source/tutorials/rl-training-2048.md`

### RLM Paper
- arXiv: `https://arxiv.org/abs/2512.24601`
- Key sections: §2 (methods), §3.1 (emergent patterns), §5 (limitations/future work)
- System prompts: Appendix D (pages 24-28)
- Example trajectories: Appendix B (pages 13-20)

### Training Stack
- Unsloth: Memory-efficient fine-tuning with LoRA
- HuggingFace TRL: GRPO (Group Relative Policy Optimization)
- Google Colab: Free T4 GPU for the training demo

### Sandboxing
- Docker isolation (primary — OpenEnv already uses this)
- RestrictedPython or similar for additional code execution safety
- Filesystem scoping via chroot or bind mounts

---

## Configuration & Defaults

All key parameters should be configurable through the environment's reset kwargs or `openenv.yaml`:

```yaml
# openenv.yaml
name: rlm_forge
version: "0.1.0"
description: "RLM training environment for AI coding agents"

defaults:
  # Episode parameters
  max_iterations: 50
  max_wall_clock_seconds: 600
  max_sub_agents: 10
  sub_agent_budget: 15
  output_truncation_chars: 5000

  # Reward weights
  test_pass_weight: 0.55
  structural_validity_weight: 0.15
  efficiency_weight: 0.30

  # Feature extraction
  extraction_mode: "mixed"  # "file", "module", or "mixed"
  min_source_loc: 50
  max_source_loc: 500
  min_tests: 3
  max_tests: 50

  # Sub-agent configuration
  sub_agent_max_iterations: 15
  sub_agent_output_truncation: 3000
  sub_agent_read_only: true
  sub_agent_depth_limit: 1
```
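Reset kwargs overlay these YAML defaults. A minimal merge sketch, assuming the defaults have already been parsed into a dict (the helper name is illustrative):

```python
def resolve_config(defaults: dict, overrides: dict) -> dict:
    """Overlay reset() kwargs on the openenv.yaml defaults, rejecting unknown keys."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return {**defaults, **overrides}
```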

---

## Success Criteria

### For the Hackathon

1. **Working environment** deployed on HF Spaces that accepts reset/step/state API calls
2. **Feature extraction** working on at least 2-3 demonstration repositories
3. **Sub-agent spawning** functional with scoped REPL access
4. **Reward computation** returning meaningful composite scores
5. **Minimal training notebook** in Colab showing a GRPO training loop connecting to the environment
6. **1-minute demo video** explaining the concept and showing the environment in action

### For Long-Term Value

1. The environment generalizes across programming languages and repository structures
2. The reward signal is informative enough for models to learn meaningful exploration strategies
3. Sub-agent reports genuinely improve root agent performance vs. no sub-agents
4. Trained models show transfer to unseen repositories
5. The environment can serve as a benchmark for comparing coding agent architectures

__init__.py
ADDED
@@ -0,0 +1,2 @@
+# OpenEnv environment root package.
+from rlm_forge import *  # noqa: F401, F403
client.py
ADDED
@@ -0,0 +1,2 @@
+# Re-export client for OpenEnv standard layout.
+from rlm_forge.client import *  # noqa: F401, F403
main.py
ADDED
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from rlm-forge!")
+
+
+if __name__ == "__main__":
+    main()
models.py
ADDED
@@ -0,0 +1,2 @@
+# Re-export models for OpenEnv standard layout.
+from rlm_forge.models import *  # noqa: F401, F403
openenv.yaml
ADDED
@@ -0,0 +1,20 @@
+name: rlm_forge
+version: "0.1.0"
+description: "RLM-Forge: Recursive Language Model training environment for AI coding agents"
+
+defaults:
+  max_iterations: 10
+  max_sub_agents: 10
+  output_truncation_chars: 5000
+
+  # Reward weights
+  test_pass_weight: 0.55
+  structural_validity_weight: 0.15
+  efficiency_weight: 0.30
+
+  # Feature extraction
+  extraction_mode: "curated"
+  min_source_loc: 50
+  max_source_loc: 500
+  min_tests: 3
+  max_tests: 50
pyproject.toml
ADDED
@@ -0,0 +1,23 @@
+[project]
+name = "rlm-forge"
+version = "0.1.0"
+description = "RLM-Forge: Recursive Language Model training environment for AI coding agents"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "fastapi>=0.135.1",
+    "freezegun>=1.5.5",
+    "gitpython>=3.1.0",
+    "openenv-core[core]>=0.2.0",
+    "pydantic>=2.0.0",
+    "pytest>=9.0.2",
+    "requests>=2.31.0",
+    "text-unidecode>=1.3",
+    "uvicorn[standard]>=0.24.0",
+]
+
+[project.scripts]
+server = "server.app:main"
+
+[tool.setuptools.packages.find]
+include = ["rlm_forge*"]
rlm_forge/__init__.py
ADDED
@@ -0,0 +1,5 @@
+"""RLM-Forge: Recursive Language Model Training Environment for AI Coding Agents."""
+
+from .models import RLMForgeAction, RLMForgeObservation, RLMForgeState
+
+__all__ = ["RLMForgeAction", "RLMForgeObservation", "RLMForgeState"]
rlm_forge/client.py
ADDED
@@ -0,0 +1,26 @@
+"""Client for connecting to a remote RLM-Forge environment."""
+
+from typing import Any, Dict
+
+from openenv.core import EnvClient
+from openenv.core.env_client import StepResult
+
+from .models import RLMForgeAction, RLMForgeObservation, RLMForgeState
+
+
+class RLMForgeClient(EnvClient[RLMForgeAction, RLMForgeObservation, RLMForgeState]):
+    """Client for the RLM-Forge environment."""
+
+    def _step_payload(self, action: RLMForgeAction) -> Dict[str, Any]:
+        return {"code": action.code, "action_type": action.action_type}
+
+    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[RLMForgeObservation]:
+        obs = RLMForgeObservation(**payload["observation"])
+        return StepResult(
+            observation=obs,
+            reward=payload.get("reward"),
+            done=payload.get("done", False),
+        )
+
+    def _parse_state(self, payload: Dict[str, Any]) -> RLMForgeState:
+        return RLMForgeState(**payload)
rlm_forge/models.py
ADDED
@@ -0,0 +1,65 @@
+"""Pydantic models for RLM-Forge environment."""
+
+from typing import Optional
+
+from pydantic import Field
+
+from openenv.core.env_server.types import Action, Observation, State
+
+
+class RLMForgeAction(Action):
+    """Agent submits Python code to execute in the REPL."""
+
+    code: str = Field(..., description="Python code to execute in the REPL environment")
+    action_type: str = Field(
+        default="execute",
+        description="Type of action: 'execute' for code, 'final' to submit solution",
+    )
+
+
+class RLMForgeObservation(Observation):
+    """What the agent sees after each step.
+
+    Inherits from Observation base:
+        done: bool = False
+        reward: Optional[float] = None
+        metadata: Dict[str, Any] = {}
+    """
+
+    stdout: str = Field(default="", description="Truncated stdout from code execution")
+    stderr: str = Field(default="", description="Truncated stderr from code execution")
+    success: bool = Field(default=True, description="Whether code executed without errors")
+    iteration: int = Field(default=0, description="Current iteration number")
+    max_iterations: int = Field(default=10, description="Maximum allowed iterations")
+    repo_manifest: Optional[dict] = Field(
+        default=None, description="Repository structure manifest"
+    )
+    task_description: Optional[str] = Field(
+        default=None, description="The coding task to complete"
+    )
+    failing_tests: Optional[list[str]] = Field(
+        default=None, description="List of currently failing test names"
+    )
+    available_functions: list[str] = Field(
+        default_factory=list, description="Built-in functions available in the REPL"
+    )
+    test_results: Optional[dict] = Field(
+        default=None, description="Detailed test results on completion"
+    )
+
+
+class RLMForgeState(State):
+    """Internal environment state, not directly sent to agent.
+
+    Inherits from State base:
+        episode_id: Optional[str] = None
+        step_count: int = 0
+    """
+
+    repo_url: str = ""
+    repo_local_path: str = ""
+    removed_feature_path: str = ""
+    removed_feature_content: str = ""
+    target_test_files: list[str] = Field(default_factory=list)
+    baseline_test_count: int = 0
+    files_written: dict[str, str] = Field(default_factory=dict)
+    sub_agents_spawned: int = 0
+    final_reward: Optional[float] = None
rlm_forge/server/__init__.py
ADDED
File without changes
rlm_forge/server/app.py
ADDED
@@ -0,0 +1,30 @@
+"""FastAPI server for RLM-Forge environment."""
+
+from openenv.core.env_server import create_app
+
+from ..models import RLMForgeAction, RLMForgeObservation
+from .environment import RLMForgeEnvironment
+
+# OpenEnv's HTTP server calls the factory per-request.
+# Use a singleton so reset/step share the same environment instance.
+_singleton_env = None
+
+
+def _env_factory():
+    global _singleton_env
+    if _singleton_env is None:
+        _singleton_env = RLMForgeEnvironment()
+    return _singleton_env
+
+
+app = create_app(
+    _env_factory,
+    RLMForgeAction,
+    RLMForgeObservation,
+    env_name="rlm_forge",
+)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
rlm_forge/server/environment.py
ADDED
@@ -0,0 +1,192 @@
+"""Core RLM-Forge Environment implementation."""
+
+import os
+import random
+import uuid
+from typing import Any, Optional
+
+from openenv.core.env_server import Environment
+
+from ..models import RLMForgeAction, RLMForgeObservation, RLMForgeState
+from .feature_extractor import CURATED_PAIRS, FeatureExtractor
+from .repo_manager import RepoManager
+from .reward import RewardComputer
+from .sandbox import REPLSandbox
+
+
+class RLMForgeEnvironment(
+    Environment[RLMForgeAction, RLMForgeObservation, RLMForgeState]
+):
+    """RLM-Forge: Recursive Language Model training environment for coding agents.
+
+    Clones a Python repo, removes a source file with test coverage, and provides
+    a multi-step REPL for the agent to explore and rebuild the feature.
+    """
+
+    SUPPORTS_CONCURRENT_SESSIONS = False
+
+    def __init__(self):
+        super().__init__()
+        self.repo_manager = RepoManager()
+        self.feature_extractor = FeatureExtractor()
+        self.reward_computer = RewardComputer()
+        self._state = RLMForgeState()
+        self._sandbox: Optional[REPLSandbox] = None
+        self._max_iterations = 10
+
+    def reset(
+        self,
+        seed: Optional[int] = None,
+        episode_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> RLMForgeObservation:
+        """Clone repo, remove feature, return initial observation."""
+        # Clean up previous episode
+        if self._state.repo_local_path:
+            self.repo_manager.cleanup(self._state.repo_local_path)
+
+        if seed is not None:
+            random.seed(seed)
+
+        # Select a curated pair
+        pair = random.choice(CURATED_PAIRS)
+
+        # AMENDMENT 2: Use pre-cloned repos if available, else clone from network
+        pre_cloned_dir = os.environ.get("RLM_FORGE_PRE_CLONED_DIR", "")
+        repo_name = pair["repo_url"].rstrip("/").split("/")[-1]
+        pre_cloned_path = os.path.join(pre_cloned_dir, repo_name) if pre_cloned_dir else ""
+
+        if pre_cloned_path and os.path.isdir(pre_cloned_path):
+            repo_path = self.repo_manager.copy_pre_cloned(pre_cloned_path)
+        else:
+            repo_path = self.repo_manager.clone_repo(pair["repo_url"])
+
+        # Install dependencies (best-effort)
+        self.repo_manager.install_dependencies(repo_path)
+
+        # Extract feature (remove source file)
+        feature = self.feature_extractor.extract_feature(
+            repo_path, pair["source_file"], pair["test_file"]
+        )
+
+        # Generate manifest
+        manifest = self.repo_manager.generate_manifest(repo_path)
+
+        # Create sandbox
+        self._sandbox = REPLSandbox(repo_path)
+
+        # Get initial failing test info
+        initial_test_result = self._sandbox._run_tests(pair["test_file"])
+        failing_tests = [
+            f"FAILING: {pair['test_file']} "
+            f"({initial_test_result.get('failed', '?')} failures, "
+            f"{initial_test_result.get('errors', '?')} errors)"
+        ]
+
+        # Initialize state
+        self._state = RLMForgeState(
+            episode_id=episode_id or str(uuid.uuid4()),
+            step_count=0,
+            repo_url=pair["repo_url"],
+            repo_local_path=repo_path,
+            removed_feature_path=pair["source_file"],
+            removed_feature_content=feature.original_content,
+            target_test_files=[pair["test_file"]],
+            baseline_test_count=feature.num_tests,
+        )
+
+        return RLMForgeObservation(
+            stdout="Environment initialized. Repository cloned and feature removed.",
+            stderr="",
+            success=True,
+            iteration=0,
+            max_iterations=self._max_iterations,
+            repo_manifest=manifest,
+            task_description=feature.task_description,
+            failing_tests=failing_tests,
+            available_functions=self._sandbox.available_functions,
+            done=False,
+            reward=None,
+        )
+
+    def step(
+        self,
+        action: RLMForgeAction,
+        timeout_s: Optional[float] = None,
+        **kwargs: Any,
+    ) -> RLMForgeObservation:
+        """Execute code in REPL, check for termination, compute reward if done."""
+        if self._sandbox is None:
+            raise RuntimeError("Environment not initialized. Call reset() first.")
+
+        self._state.step_count += 1
+
+        # Check for explicit final action or iteration limit
+        if action.action_type == "final":
+            return self._finalize_episode()
+
+        if self._state.step_count >= self._max_iterations:
+            return self._finalize_episode()
+
+        # Execute code in sandbox
+        result = self._sandbox.execute(action.code)
+
+        # Check if FINAL() was called in the code
+        if result["final_called"]:
+            return self._finalize_episode()
+
+        return RLMForgeObservation(
+            stdout=result["stdout"],
+            stderr=result["stderr"],
+            success=result["success"],
+            iteration=self._state.step_count,
+            max_iterations=self._max_iterations,
+            available_functions=self._sandbox.available_functions,
+            done=False,
+            reward=None,
+        )
+
+    def _finalize_episode(self) -> RLMForgeObservation:
+        """Compute reward and return final observation."""
+        assert self._sandbox is not None
+
+        reward_result = self.reward_computer.compute(
+            repo_path=self._state.repo_local_path,
+            target_test=self._state.target_test_files[0],
+            files_written=self._sandbox.files_written,
+            max_iterations=self._max_iterations,
+            iterations_used=self._state.step_count,
+            baseline_test_count=self._state.baseline_test_count,
+        )
+
+        self._state.final_reward = reward_result["total_reward"]
+        self._state.files_written = self._sandbox.files_written
+        self._state.sub_agents_spawned = self._sandbox._sub_agents_spawned
+
+        return RLMForgeObservation(
+            stdout=f"Episode complete. Reward: {reward_result['total_reward']:.3f}",
+            stderr="",
+            success=True,
+            iteration=self._state.step_count,
+            max_iterations=self._max_iterations,
+            test_results=reward_result,
+            done=True,
+            reward=reward_result["total_reward"],
+        )
+
+    @property
+    def state(self) -> RLMForgeState:
+        return self._state
+
+    def close(self):
+        """No-op for HTTP singleton. Use cleanup() for explicit teardown."""
+        # OpenEnv HTTP server calls close() after each request handler.
+        # For singleton mode, we must NOT destroy state here.
+        # Actual cleanup happens in reset() (previous episode) or explicit cleanup().
+        pass
+
+    def cleanup(self):
+        """Explicit teardown: remove cloned repo."""
+        if self._state.repo_local_path:
+            self.repo_manager.cleanup(self._state.repo_local_path)
+            self._state.repo_local_path = ""
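An episode in `step()` can end three ways: the agent sends an explicit "final" action, the iteration budget is exhausted, or the sandboxed code calls FINAL(). A hypothetical helper (illustration only, not part of the repo) condenses the checks in the same order they run:

```python
# Mirrors the termination logic of RLMForgeEnvironment.step(), in order:
# 1) explicit "final" action, 2) iteration budget hit, 3) FINAL() called in code.
def should_finalize(action_type, step_count, max_iterations, final_called):
    if action_type == "final":
        return True
    if step_count >= max_iterations:
        return True
    return final_called


print(should_finalize("run", 3, 10, False))    # False: episode continues
print(should_finalize("final", 3, 10, False))  # True: explicit final action
print(should_finalize("run", 10, 10, False))   # True: budget exhausted
print(should_finalize("run", 4, 10, True))     # True: FINAL() in sandbox code
```

Note that the budget check fires before the action's code is executed, so a tenth-step action that would have fixed the tests is never run; agents should call FINAL() before the limit.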
rlm_forge/server/feature_extractor.py
ADDED
@@ -0,0 +1,310 @@
+"""Semi-automatic feature extraction: discovers (source, test) pairs and removes features."""
+
+import ast
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class ExtractedFeature:
+    """Represents a feature removed from a repo for training."""
+
+    source_path: str
+    test_path: str
+    original_content: str
+    num_tests: int
+    difficulty: str
+    task_description: str
+
+
+# Curated fallback pairs — known-good (repo, source, test) triples
+# AMENDMENT 1: python-slugify test file is test.py at root, NOT test/test_slugify.py
+CURATED_PAIRS = [
+    {
+        "repo_url": "https://github.com/un33k/python-slugify",
+        "source_file": "slugify/slugify.py",
+        "test_file": "test.py",
+        "test_command": "pytest test.py -v",
+        "difficulty": "easy",
+    },
+    {
+        "repo_url": "https://github.com/python-humanize/humanize",
+        "source_file": "src/humanize/number.py",
+        "test_file": "tests/test_number.py",
+        "test_command": "pytest tests/test_number.py -v",
+        "difficulty": "medium",
+    },
+    {
+        "repo_url": "https://github.com/python-humanize/humanize",
+        "source_file": "src/humanize/time.py",
+        "test_file": "tests/test_time.py",
+        "test_command": "pytest tests/test_time.py -v",
+        "difficulty": "medium",
+    },
+]
+
+
+class FeatureExtractor:
+    """Discovers and extracts (source, test) pairs from Python repos."""
+
+    def discover_pairs(self, repo_path: str) -> list[dict]:
+        """Auto-discover (source, test) pairs via filename pattern matching."""
+        pairs = []
+        test_files = self._find_test_files(repo_path)
+
+        for test_file in test_files:
+            source_file = self._match_source_file(repo_path, test_file)
+            if source_file and self._verify_import(repo_path, test_file, source_file):
+                num_tests = self._count_tests(os.path.join(repo_path, test_file))
+                source_loc = self._count_lines(os.path.join(repo_path, source_file))
+
+                # Filter by complexity sweet spot
+                if 3 <= num_tests <= 50 and 30 <= source_loc <= 500:
+                    pairs.append(
+                        {
+                            "source_path": source_file,
+                            "test_path": test_file,
+                            "num_tests": num_tests,
+                            "source_loc": source_loc,
+                        }
+                    )
+
+        # Sort by best fit (prefer 5-20 tests, 50-300 LOC)
+        pairs.sort(
+            key=lambda p: abs(p["num_tests"] - 12) + abs(p["source_loc"] - 150)
+        )
+        return pairs
+
+    def _find_test_files(self, repo_path: str) -> list[str]:
+        """Find all test files in the repo."""
+        test_files = []
+        for root, dirs, files in os.walk(repo_path):
+            dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
+            for f in files:
+                if f.endswith(".py") and (
+                    f.startswith("test_") or f.endswith("_test.py")
+                ):
+                    rel = os.path.relpath(os.path.join(root, f), repo_path)
+                    test_files.append(rel)
+        return test_files
+
+    def _match_source_file(
+        self, repo_path: str, test_file: str
+    ) -> Optional[str]:
+        """Given test_foo.py, find foo.py in common source locations."""
+        test_basename = os.path.basename(test_file)
+
+        if test_basename.startswith("test_"):
+            source_name = test_basename[5:]  # Remove "test_" prefix
+        elif test_basename.endswith("_test.py"):
+            source_name = test_basename[:-8] + ".py"
+        else:
+            return None
+
+        # Search common source locations
+        search_dirs = ["src", "lib", "."]
+
+        # Also try package directories (dirs with __init__.py)
+        try:
+            for item in os.listdir(repo_path):
+                item_path = os.path.join(repo_path, item)
+                if os.path.isdir(item_path) and os.path.exists(
+                    os.path.join(item_path, "__init__.py")
+                ):
+                    search_dirs.append(item)
+        except Exception:
+            pass
+
+        for search_dir in search_dirs:
+            if search_dir == ".":
+                candidate = source_name
+            else:
+                candidate = os.path.join(search_dir, source_name)
+
+            if os.path.exists(os.path.join(repo_path, candidate)):
+                return candidate
+
+            # Also search subdirectories of src/
+            src_dir = os.path.join(repo_path, search_dir)
+            if os.path.isdir(src_dir):
+                for sub in os.listdir(src_dir):
+                    sub_candidate = os.path.join(search_dir, sub, source_name)
+                    if os.path.exists(os.path.join(repo_path, sub_candidate)):
+                        return sub_candidate
+
+        return None
+
+    def _verify_import(
+        self, repo_path: str, test_file: str, source_file: str
+    ) -> bool:
+        """Check if test_file likely imports from source_file (basic heuristic)."""
+        try:
+            base_name = os.path.splitext(os.path.basename(source_file))[0]
+            test_content = open(os.path.join(repo_path, test_file)).read()
+            return base_name in test_content
+        except Exception:
+            return False
+
+    def _count_tests(self, test_file_path: str) -> int:
+        """Count test functions/methods in a test file using AST."""
+        try:
+            with open(test_file_path) as f:
+                tree = ast.parse(f.read())
+            count = 0
+            for node in ast.walk(tree):
+                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                    if node.name.startswith("test_"):
+                        count += 1
+            return count
+        except Exception:
+            return 0
+
+    def _generate_stub(self, original_content: str) -> str:
+        """Generate a stub module with correct function/class signatures but broken implementations.
+
+        Parses the original source with AST to extract all top-level function
+        and class definitions, then generates a stub that:
+        - Has the same imports (so dependencies resolve)
+        - Has the same function/class names with correct signatures
+        - Returns None/raises NotImplementedError for all functions
+        """
+        try:
+            tree = ast.parse(original_content)
+        except SyntaxError:
+            return "# Stub: original file could not be parsed\n"
+
+        lines = ["# STUB: This file needs to be reimplemented.\n"]
+        lines.append("# All functions return None — tests will fail.\n\n")
+
+        # Preserve imports from the original
+        for node in ast.iter_child_nodes(tree):
+            if isinstance(node, (ast.Import, ast.ImportFrom)):
+                lines.append(ast.get_source_segment(original_content, node) + "\n")
+
+        lines.append("\n")
+
+        # Generate stub functions/classes
+        for node in ast.iter_child_nodes(tree):
+            if isinstance(node, ast.FunctionDef):
+                # Extract the full signature from source using body start line
+                func_lines = original_content.splitlines()
+                # Signature spans from the def line to the line before the body
+                body_start = node.body[0].lineno  # 1-indexed
+                sig_lines = func_lines[node.lineno - 1 : body_start - 1]
+                signature = "\n".join(sig_lines)
+                if not signature.rstrip().endswith(":"):
+                    signature = signature.rstrip() + ":"
+                lines.append(f"{signature}\n")
+                lines.append("    return None\n\n")
+
+            elif isinstance(node, ast.ClassDef):
+                lines.append(f"class {node.name}:\n")
+                lines.append("    pass\n\n")
+
+            elif isinstance(node, ast.Assign):
+                # Preserve top-level variable assignments
+                segment = ast.get_source_segment(original_content, node)
+                if segment:
+                    lines.append(segment + "\n")
+
+        return "".join(lines)
+
+    def _patch_init_files(self, repo_path: str, removed_source: str) -> None:
+        """Remove imports of the deleted module from __init__.py files.
+
+        When a module like `package/number.py` is removed, the package's
+        `__init__.py` may do `from package.number import ...` which would
+        crash the entire package import. We comment out those lines.
+        """
+        module_base = os.path.splitext(os.path.basename(removed_source))[0]
+        source_dir = os.path.dirname(removed_source)
+
+        # Check __init__.py in the same directory as the removed file
+        init_path = os.path.join(repo_path, source_dir, "__init__.py")
+        if not os.path.exists(init_path):
+            return
+
+        try:
+            with open(init_path, "r") as f:
+                lines = f.readlines()
+
+            patched = []
+            in_multiline_import = False
+            for line in lines:
+                # Detect imports referencing the removed module
+                if in_multiline_import:
+                    patched.append(f"# [RLM-FORGE REMOVED] {line}")
+                    if ")" in line:
+                        in_multiline_import = False
+                elif f".{module_base}" in line and ("import" in line or "from" in line):
+                    patched.append(f"# [RLM-FORGE REMOVED] {line}")
+                    if "(" in line and ")" not in line:
+                        in_multiline_import = True
+                elif f'"{module_base}"' in line or f"'{module_base}'" in line:
+                    # Catch __all__ references
+                    patched.append(line)
+                else:
+                    patched.append(line)
+
+            with open(init_path, "w") as f:
+                f.writelines(patched)
+        except Exception:
+            pass
+
+    def _count_lines(self, file_path: str) -> int:
+        try:
+            with open(file_path) as f:
+                return sum(1 for _ in f)
+        except Exception:
+            return 0
+
+    def extract_feature(
+        self, repo_path: str, source_path: str, test_path: str
+    ) -> ExtractedFeature:
+        """Remove source file and create the ExtractedFeature."""
+        full_source = os.path.join(repo_path, source_path)
+        full_test = os.path.join(repo_path, test_path)
+
+        # Save original content
+        with open(full_source, "r") as f:
+            original_content = f.read()
+
+        # Count tests
+        num_tests = self._count_tests(full_test)
+
+        # Replace the source file with a stub that has correct signatures
+        # but wrong implementations. This ensures:
+        # - Other modules can still import from it (no cascading ImportErrors)
+        # - Tests FAIL (not ERROR), giving a better reward signal
+        # - The agent's job is to write the correct implementation
+        stub = self._generate_stub(original_content)
+        with open(full_source, "w") as f:
+            f.write(stub)
+
+        # Generate task description
+        task_description = (
+            f"The file `{source_path}` has been replaced with a broken stub. "
+            f"{num_tests} tests in `{test_path}` are now failing. "
+            f"Your task is to explore the repository, understand the expected behavior "
+            f"from the tests and other code, and rewrite `{source_path}` with a correct "
+            f"implementation so that all tests pass.\n\n"
+            f"Available tools:\n"
+            f"  read_file(path) - Read a file from the repo\n"
+            f"  list_dir(path='.') - List directory contents\n"
+            f"  search(pattern, path='.') - Grep for a pattern\n"
+            f"  write_file(path, content) - Write/create a file\n"
+            f"  run_tests(test_path=None) - Run pytest on a test file\n"
+            f"  spawn_agent(scope, mission, budget=5) - Explore a directory scope\n"
+            f"  FINAL() - Signal that your implementation is complete\n\n"
+            f"Call FINAL() when you believe your implementation is complete."
+        )
+
+        return ExtractedFeature(
+            source_path=source_path,
+            test_path=test_path,
+            original_content=original_content,
+            num_tests=num_tests,
+            difficulty="medium",
+            task_description=task_description,
+        )
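The AST-based counting in `_count_tests` can be exercised standalone. Note that `ast.walk` descends into class bodies, so `test_` methods inside test classes are counted alongside module-level functions, and async tests are caught via `ast.AsyncFunctionDef`:

```python
import ast

# A toy test module: two module-level tests (one async), one helper,
# and one test method inside a class.
source = """
def test_lower(): pass
async def test_async(): pass
def helper(): pass
class TestFoo:
    def test_method(self): pass
"""

tree = ast.parse(source)
count = sum(
    1
    for node in ast.walk(tree)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    and node.name.startswith("test_")
)
print(count)  # 3: test_lower, test_async, test_method
```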
rlm_forge/server/repo_manager.py
ADDED
@@ -0,0 +1,106 @@
+"""Repository cloning, dependency installation, and manifest generation."""
+
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+
+
+class RepoManager:
+    """Manages repository cloning and lifecycle."""
+
+    def __init__(self, cache_dir: str = "/tmp/rlm_forge_repos"):
+        self.cache_dir = cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+    def clone_repo(self, repo_url: str) -> str:
+        """Clone repo to a unique temp directory. Returns path."""
+        work_dir = tempfile.mkdtemp(dir=self.cache_dir, prefix="rlm_")
+        subprocess.run(
+            ["git", "clone", "--depth=1", repo_url, work_dir],
+            check=True,
+            capture_output=True,
+            timeout=120,
+        )
+        return work_dir
+
+    def copy_pre_cloned(self, pre_cloned_path: str) -> str:
+        """Copy a pre-cloned repo directory for a fresh episode. Returns new path."""
+        work_dir = tempfile.mkdtemp(dir=self.cache_dir, prefix="rlm_")
+        # Remove the empty temp dir first, then copy
+        shutil.rmtree(work_dir)
+        shutil.copytree(pre_cloned_path, work_dir)
+        return work_dir
+
+    def install_dependencies(self, repo_path: str) -> bool:
+        """Best-effort dependency installation using uv pip (falls back to pip)."""
+        uv_path = shutil.which("uv")
+
+        # Build install command: prefer uv pip, fall back to sys.executable -m pip
+        def _pip_install(args: list[str]) -> bool:
+            if uv_path:
+                cmd = [uv_path, "pip", "install"] + args
+            else:
+                cmd = [sys.executable, "-m", "pip", "install"] + args
+            try:
+                subprocess.run(
+                    cmd, capture_output=True, timeout=120, check=True
+                )
+                return True
+            except Exception:
+                return False
+
+        # Try pyproject.toml / setup.py first
+        has_pyproject = os.path.exists(os.path.join(repo_path, "pyproject.toml"))
+        has_setup = os.path.exists(os.path.join(repo_path, "setup.py"))
+        if has_pyproject or has_setup:
+            if _pip_install(["-e", repo_path]):
+                return True
+
+        # Try requirements.txt
+        req_file = os.path.join(repo_path, "requirements.txt")
+        if os.path.exists(req_file):
+            if _pip_install(["-r", req_file]):
+                return True
+
+        return False
+
+    def generate_manifest(self, repo_path: str) -> dict:
+        """Generate a high-level manifest of the repo structure."""
+        manifest: dict = {"files": [], "total_files": 0, "total_loc": 0}
+
+        for root, dirs, files in os.walk(repo_path):
+            dirs[:] = [
+                d for d in dirs if not d.startswith(".") and d != "__pycache__"
+            ]
+            for f in files:
+                if f.endswith(".py"):
+                    full_path = os.path.join(root, f)
+                    rel_path = os.path.relpath(full_path, repo_path)
+                    try:
+                        with open(full_path) as fh:
+                            loc = sum(1 for _ in fh)
+                    except Exception:
+                        loc = 0
+                    manifest["files"].append({"path": rel_path, "loc": loc})
+                    manifest["total_files"] += 1
+                    manifest["total_loc"] += loc
+
+        # Read README excerpt if available
+        for readme_name in ["README.md", "README.rst", "README.txt", "README"]:
+            readme_path = os.path.join(repo_path, readme_name)
+            if os.path.exists(readme_path):
+                try:
+                    with open(readme_path) as f:
+                        manifest["readme_excerpt"] = f.read()[:2000]
+                except Exception:
+                    pass
+                break
+
+    return manifest
+
+    def cleanup(self, repo_path: str):
+        """Remove cloned repo directory."""
+        if repo_path and repo_path.startswith(self.cache_dir):
+            shutil.rmtree(repo_path, ignore_errors=True)
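The `generate_manifest` walk can be checked against a throwaway layout. This standalone sketch mirrors its filtering (skip dotdirs and `__pycache__`, count only `.py` files) and its line-count accumulation:

```python
import os
import tempfile

# Build a tiny fake repo: one package with a 2-line module and an empty __init__.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "pkg"))
with open(os.path.join(root, "pkg", "mod.py"), "w") as f:
    f.write("x = 1\ny = 2\n")
with open(os.path.join(root, "pkg", "__init__.py"), "w") as f:
    f.write("")

manifest = {"files": [], "total_files": 0, "total_loc": 0}
for r, dirs, files in os.walk(root):
    # Same pruning as RepoManager.generate_manifest
    dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
    for name in files:
        if name.endswith(".py"):
            path = os.path.join(r, name)
            with open(path) as fh:
                loc = sum(1 for _ in fh)
            manifest["files"].append({"path": os.path.relpath(path, root), "loc": loc})
            manifest["total_files"] += 1
            manifest["total_loc"] += loc

print(manifest["total_files"], manifest["total_loc"])  # 2 2
```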
rlm_forge/server/reward.py
ADDED
@@ -0,0 +1,169 @@
+"""Composite reward computation for RLM-Forge episodes."""
+
+import ast
+import os
+import re
+import subprocess
+
+
+class RewardComputer:
+    """Computes composite reward: test pass rate + structural validity + efficiency."""
+
+    def __init__(
+        self,
+        test_weight: float = 0.55,
+        structural_weight: float = 0.15,
+        efficiency_weight: float = 0.30,
+    ):
+        self.test_weight = test_weight
+        self.structural_weight = structural_weight
+        self.efficiency_weight = efficiency_weight
+
+    def compute(
+        self,
+        repo_path: str,
+        target_test: str,
+        files_written: dict[str, str],
+        max_iterations: int,
+        iterations_used: int,
+        baseline_test_count: int,
+    ) -> dict:
+        """Compute composite reward. Returns detailed breakdown."""
+        # 1. Test pass rate (55%)
+        test_result = self._run_target_tests(repo_path, target_test)
+        total_tests = max(test_result["total"], baseline_test_count, 1)
+        test_pass_rate = test_result["passed"] / total_tests
+
+        # 2. Structural validity (15%)
+        structural_score = self._compute_structural(repo_path, files_written)
+
+        # 3. Efficiency (30%)
+        efficiency_score = self._compute_efficiency(iterations_used, max_iterations)
+
+        # Composite
+        total = (
+            self.test_weight * test_pass_rate
+            + self.structural_weight * structural_score
+            + self.efficiency_weight * efficiency_score
+        )
+
+        return {
+            "total_reward": round(total, 4),
+            "test_pass_rate": round(test_pass_rate, 4),
+            "tests_passed": test_result["passed"],
+            "tests_failed": test_result["failed"],
+            "tests_total": test_result["total"],
+            "structural_score": round(structural_score, 4),
+            "efficiency_score": round(efficiency_score, 4),
+            "breakdown": {
+                "test_component": round(self.test_weight * test_pass_rate, 4),
+                "structural_component": round(
+                    self.structural_weight * structural_score, 4
+                ),
+                "efficiency_component": round(
+                    self.efficiency_weight * efficiency_score, 4
+                ),
+            },
+            "test_output": test_result.get("output", "")[:2000],
+        }
+
+    def _run_target_tests(self, repo_path: str, test_path: str) -> dict:
+        """Run the target test file and parse results."""
+        import sys
+
+        cmd = [sys.executable, "-m", "pytest", "-v", "--tb=short", "--no-header"]
+        cmd.append(os.path.join(repo_path, test_path))
+
+        try:
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                timeout=60,
+                cwd=repo_path,
+            )
+            raw_output = result.stdout + result.stderr
+            # Strip ANSI color codes for reliable parsing
+            output = re.sub(r"\x1b\[[0-9;]*m", "", raw_output)
+            passed = len(re.findall(r" PASSED", output))
+            failed = len(re.findall(r" FAILED", output))
+            errors = len(re.findall(r" ERROR", output))
+
+            return {
+                "passed": passed,
+                "failed": failed,
+                "errors": errors,
+                "total": passed + failed + errors,
+                "output": output[:3000],
+                "returncode": result.returncode,
+            }
+        except subprocess.TimeoutExpired:
+            return {
+                "passed": 0,
+                "failed": 0,
+                "errors": 1,
+                "total": 1,
+                "output": "Test execution timed out",
+                "returncode": -1,
+            }
+
+    def _compute_structural(
+        self, repo_path: str, files_written: dict[str, str]
+    ) -> float:
+        """Check structural validity of written files."""
+
if not files_written:
|
| 115 |
+
return 0.0
|
| 116 |
+
|
| 117 |
+
file_scores = []
|
| 118 |
+
for path, content in files_written.items():
|
| 119 |
+
# Parse check (weight 0.3)
|
| 120 |
+
try:
|
| 121 |
+
ast.parse(content)
|
| 122 |
+
parse_ok = 1.0
|
| 123 |
+
except SyntaxError:
|
| 124 |
+
parse_ok = 0.0
|
| 125 |
+
|
| 126 |
+
# Import check (weight 0.3)
|
| 127 |
+
module_name = path.replace("/", ".").replace(".py", "")
|
| 128 |
+
try:
|
| 129 |
+
import sys
|
| 130 |
+
|
| 131 |
+
result = subprocess.run(
|
| 132 |
+
[
|
| 133 |
+
sys.executable,
|
| 134 |
+
"-c",
|
| 135 |
+
f"import importlib; importlib.import_module('{module_name}')",
|
| 136 |
+
],
|
| 137 |
+
capture_output=True,
|
| 138 |
+
timeout=10,
|
| 139 |
+
cwd=repo_path,
|
| 140 |
+
)
|
| 141 |
+
import_ok = 1.0 if result.returncode == 0 else 0.0
|
| 142 |
+
except Exception:
|
| 143 |
+
import_ok = 0.0
|
| 144 |
+
|
| 145 |
+
file_scores.append(0.3 * parse_ok + 0.3 * import_ok)
|
| 146 |
+
|
| 147 |
+
avg_file_score = sum(file_scores) / len(file_scores)
|
| 148 |
+
|
| 149 |
+
# Regression check (weight 0.4)
|
| 150 |
+
# For hackathon: assume no regressions since we only modify the removed file
|
| 151 |
+
regression_score = 0.4
|
| 152 |
+
|
| 153 |
+
return avg_file_score + regression_score
|
| 154 |
+
|
| 155 |
+
def _compute_efficiency(
|
| 156 |
+
self, iterations_used: int, max_iterations: int
|
| 157 |
+
) -> float:
|
| 158 |
+
"""Tiered efficiency score."""
|
| 159 |
+
if max_iterations <= 0:
|
| 160 |
+
return 0.0
|
| 161 |
+
ratio = iterations_used / max_iterations
|
| 162 |
+
if ratio <= 0.5:
|
| 163 |
+
return 1.0
|
| 164 |
+
elif ratio <= 0.75:
|
| 165 |
+
return 0.75
|
| 166 |
+
elif ratio <= 1.0:
|
| 167 |
+
return 0.5
|
| 168 |
+
else:
|
| 169 |
+
return 0.0
|
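The 55/15/30 weighting in RewardComputer can be sanity-checked with a small standalone sketch. The numbers below are hypothetical (a made-up episode, not from a real run); the helper mirrors only the weighted-sum arithmetic, not the test execution:

```python
# Minimal sketch of RewardComputer's composite arithmetic.
# All episode numbers here are hypothetical, chosen to illustrate the split.

TEST_WEIGHT = 0.55
STRUCTURAL_WEIGHT = 0.15
EFFICIENCY_WEIGHT = 0.30


def composite_reward(test_pass_rate: float, structural: float, efficiency: float) -> float:
    """Weighted sum of the three components, rounded like compute()."""
    return round(
        TEST_WEIGHT * test_pass_rate
        + STRUCTURAL_WEIGHT * structural
        + EFFICIENCY_WEIGHT * efficiency,
        4,
    )


# Hypothetical episode: 60 of 82 tests pass; the written file both parses
# and imports (0.3 + 0.3) plus the fixed regression credit (0.4) -> 1.0;
# the agent used half its iteration budget -> efficiency tier 1.0.
print(composite_reward(60 / 82, 1.0, 1.0))  # 0.8524
```

A perfect episode (all tests pass, valid file, top efficiency tier) scores exactly 1.0, so the components form a proper convex combination.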
rlm_forge/server/sandbox.py
ADDED
@@ -0,0 +1,213 @@
"""Sandboxed Python REPL using exec() with persistent globals."""

import contextlib
import io
import os
import re
import subprocess


class REPLSandbox:
    """Sandboxed Python REPL with built-in tool functions for repo exploration."""

    def __init__(self, repo_path: str, max_output_chars: int = 5000):
        self.repo_path = os.path.realpath(repo_path)
        self.max_output_chars = max_output_chars
        self.files_written: dict[str, str] = {}
        self._final_called = False
        self._sub_agents_spawned = 0

        self.globals_dict: dict = {"__builtins__": __builtins__}
        self.globals_dict.update(
            {
                "read_file": self._read_file,
                "list_dir": self._list_dir,
                "search": self._search,
                "write_file": self._write_file,
                "run_tests": self._run_tests,
                "spawn_agent": self._spawn_agent,
                "FINAL": self._final,
            }
        )

    def execute(self, code: str) -> dict:
        """Execute code in the sandbox, return stdout/stderr/success."""
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        try:
            with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(
                stderr_capture
            ):
                exec(code, self.globals_dict)
            success = True
        except Exception as e:
            stderr_capture.write(f"{type(e).__name__}: {e}\n")
            success = False

        stdout = stdout_capture.getvalue()[: self.max_output_chars]
        stderr = stderr_capture.getvalue()[: self.max_output_chars]

        return {
            "stdout": stdout,
            "stderr": stderr,
            "success": success,
            "final_called": self._final_called,
        }

    def _validate_path(self, path: str) -> str:
        """Ensure path stays within repo. Returns the real absolute path."""
        full_path = os.path.join(self.repo_path, path)
        real_path = os.path.realpath(full_path)
        if not real_path.startswith(self.repo_path):
            raise PermissionError(f"Access denied: {path}")
        return real_path

    def _read_file(self, path: str) -> str:
        """Read a file from the repo. Path relative to repo root."""
        real_path = self._validate_path(path)
        with open(real_path, "r") as f:
            content = f.read()
        if len(content) > 10000:
            content = content[:10000] + "\n... [truncated]"
        return content

    def _list_dir(self, path: str = ".") -> list[str]:
        """List directory contents relative to repo root."""
        real_path = self._validate_path(path)
        entries = os.listdir(real_path)
        result = []
        for e in sorted(entries):
            full = os.path.join(real_path, e)
            suffix = "/" if os.path.isdir(full) else ""
            result.append(e + suffix)
        return result

    def _search(self, pattern: str, path: str = ".") -> list[str]:
        """Grep for pattern in repo files. Returns list of matches."""
        real_path = self._validate_path(path)
        results = []
        try:
            output = subprocess.run(
                ["grep", "-rn", "--include=*.py", pattern, real_path],
                capture_output=True,
                text=True,
                timeout=10,
            )
            for line in output.stdout.strip().split("\n")[:50]:
                if line:
                    results.append(line.replace(self.repo_path + "/", ""))
        except (subprocess.TimeoutExpired, Exception):
            pass
        return results

    def _write_file(self, path: str, content: str) -> str:
        """Write a file to the repo. Records it for evaluation."""
        real_path = self._validate_path(path)
        os.makedirs(os.path.dirname(real_path), exist_ok=True)
        with open(real_path, "w") as f:
            f.write(content)
        self.files_written[path] = content
        return f"Written {len(content)} chars to {path}"

    def _run_tests(self, test_path: str | None = None) -> dict:
        """Run pytest on specified test file(s). Returns pass/fail summary."""
        import sys

        cmd = [sys.executable, "-m", "pytest", "-v", "--tb=short", "--no-header"]
        if test_path:
            cmd.append(os.path.join(self.repo_path, test_path))
        else:
            cmd.append(self.repo_path)

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=60,
                cwd=self.repo_path,
            )
            raw_output = result.stdout + result.stderr
            # Strip ANSI color codes for reliable parsing
            output = re.sub(r"\x1b\[[0-9;]*m", "", raw_output)
            passed = len(re.findall(r" PASSED", output))
            failed = len(re.findall(r" FAILED", output))
            errors = len(re.findall(r" ERROR", output))
            output_truncated = output[: self.max_output_chars]

            return {
                "passed": passed,
                "failed": failed,
                "errors": errors,
                "total": passed + failed + errors,
                "output": output_truncated,
                "returncode": result.returncode,
            }
        except subprocess.TimeoutExpired:
            return {
                "passed": 0,
                "failed": 0,
                "errors": 1,
                "total": 1,
                "output": "Test execution timed out (60s limit)",
                "returncode": -1,
            }

    def _spawn_agent(self, scope: str, mission: str, budget: int = 5) -> dict:
        """Stateless sub-LM call. Gathers scoped context and returns structured report."""
        self._sub_agents_spawned += 1
        scope_path = os.path.join(self.repo_path, scope)

        if not os.path.exists(scope_path):
            return {
                "error": f"Scope path not found: {scope}",
                "summary": "",
                "files_examined": [],
            }

        # Build file listing for the scope
        files = []
        for root, dirs, filenames in os.walk(scope_path):
            dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
            for f in filenames:
                if f.endswith(".py"):
                    rel = os.path.relpath(os.path.join(root, f), self.repo_path)
                    files.append(rel)

        # Read first few files to build context
        context_parts = []
        for fpath in files[:5]:
            try:
                content = self._read_file(fpath)
                context_parts.append(f"--- {fpath} ---\n{content[:2000]}")
            except Exception:
                pass

        report = {
            "summary": (
                f"Explored scope '{scope}' for mission: {mission}. "
                f"Found {len(files)} Python files."
            ),
            "files_examined": files[:10],
            "file_contents_preview": context_parts[:3],
            "mission": mission,
        }
        return report

    def _final(self) -> str:
        """Signal episode completion."""
        self._final_called = True
        return "Episode marked as complete. Evaluating..."

    @property
    def available_functions(self) -> list[str]:
        return [
            "read_file(path)",
            "list_dir(path='.')",
            "search(pattern, path='.')",
            "write_file(path, content)",
            "run_tests(test_path=None)",
            "spawn_agent(scope, mission, budget=5)",
            "FINAL()",
        ]
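One caveat with the sandbox's `_validate_path` style of check: a plain `startswith()` prefix test compares raw characters, so for repo `/app/repo` it would also accept a sibling path like `/app/repo-evil`. A minimal sketch of a stricter containment test (hypothetical helper name, not part of the sandbox) uses `os.path.commonpath`, which compares whole path components:

```python
import os
import tempfile


def is_within(repo_path: str, candidate: str) -> bool:
    """Return True if `candidate` (relative to the repo) resolves to a
    location inside the repo. commonpath compares whole components, so
    sibling directories like '/tmp/repo-evil' are rejected for '/tmp/repo',
    unlike a character-level startswith() prefix test."""
    repo_real = os.path.realpath(repo_path)
    cand_real = os.path.realpath(os.path.join(repo_real, candidate))
    return os.path.commonpath([repo_real, cand_real]) == repo_real


repo = tempfile.mkdtemp()
print(is_within(repo, "src/module.py"))   # True
print(is_within(repo, "../outside.txt"))  # False
```

Since both arguments are resolved with `os.path.realpath` first, symlinks that point outside the repo are also caught before the comparison.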
rlm_forge_training.ipynb
ADDED
@@ -0,0 +1,802 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RLM-Forge: Training LLMs with GRPO on Coding Tasks\n",
    "\n",
    "**RLM-Forge** is an OpenEnv environment that trains language models to solve coding tasks using Recursive Language Model (RLM) patterns. The environment:\n",
    "\n",
    "1. Clones a Python repository\n",
    "2. Replaces a source file with a broken stub (correct signatures, wrong implementations)\n",
    "3. Provides a sandboxed REPL with tools (read_file, list_dir, search, write_file, run_tests)\n",
    "4. The agent must explore the repo, understand the tests, and rewrite the source file\n",
    "5. Reward = test pass rate (55%) + structural validity (15%) + efficiency (30%)\n",
    "\n",
    "This notebook trains a model using **GRPO (Group Relative Policy Optimization)** with multi-step trajectory concatenation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Install dependencies\n",
    "!pip install -q \"openenv-core[core]>=0.2.0\" trl transformers accelerate bitsandbytes peft datasets\n",
    "!pip install -q text-unidecode freezegun pytest vllm\n",
    "\n",
    "# Clone RLM-Forge repo\n",
    "!git clone https://github.com/kking112/rlm-forge.git content/rlm-forge 2>/dev/null || true\n",
    "# Or upload files manually — adjust path as needed\n",
    "# import sys\n",
    "# sys.path.insert(0, \"content/rlm-forge\")\n",
    "\n",
    "# Install RLM-Forge\n",
    "!pip install -q -e content/rlm-forge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition\n",
      "PyTorch: 2.10.0+cu128\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import json\n",
    "import re\n",
    "import random\n",
    "from typing import Optional\n",
    "from dataclasses import dataclass\n",
    "\n",
    "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "# print(f\"VRAM: {torch.cuda.get_device_properties(0). / 1e9:.1f} GB\")\n",
    "print(f\"PyTorch: {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Environment Smoke Test\n",
    "\n",
    "Verify the environment works before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task: The file `slugify/slugify.py` has been replaced with a broken stub. 82 tests in `test.py` are now failing. Your task is to explore the repository, understand the expected behavior from the tests and o...\n",
      "Available tools: ['read_file(path)', \"list_dir(path='.')\", \"search(pattern, path='.')\", 'write_file(path, content)', 'run_tests(test_path=None)', 'spawn_agent(scope, mission, budget=5)', 'FINAL()']\n",
      "\n",
      "Step 1 stdout: ['.git/', '.github/', '.gitignore', '.pytest_cache/', '.vscode/', 'CHANGELOG.md', 'LICENSE', 'MANIFEST.in', 'README.md', '__pycache__/', 'dev.requirements.txt', 'format.sh', 'pyproject.toml', 'python_\n",
      "\n",
      "Baseline reward (no implementation): 0.3939\n",
      "Test results: {'total_reward': 0.3939, 'test_pass_rate': 0.1707, 'tests_passed': 14, 'tests_failed': 68, 'tests_total': 82, 'structural_score': 0.0, 'efficiency_score': 1.0, 'breakdown': {'test_component': 0.0939, 'structural_component': 0.0, 'efficiency_component': 0.3}, 'test_output': '============================= test session starts ==============================\\ncollecting ... collected 82 items\\n\\ntest.py::TestSlugify::test_accented_text FAILED [ 1%]\\ntest.py::TestSlugify::test_accented_text_with_non_word_characters FAILED [ 2%]\\ntest.py::TestSlugify::test_contains_numbers FAILED [ 3%]\\ntest.py::TestSlugify::test_custom_separator FAILED [ 4%]\\ntest.py::TestSlugify::test_cyrillic_text FAILED [ 6%]\\ntest.py::TestSlugify::test_differently_cased_stopword_match FAILED [ 7%]\\ntest.py::TestSlugify::test_ends_with_number FAILED [ 8%]\\ntest.py::TestSlugify::test_extraneous_seperators FAILED [ 9%]\\ntest.py::TestSlugify::test_html_decimal_off FAILED [ 10%]\\ntest.py::TestSlugify::test_html_decimal_on FAILED [ 12%]\\ntest.py::TestSlugify::test_html_entities_off FAILED [ 13%]\\ntest.py::TestSlugify::test_html_entities_on FAILED [ 14%]\\ntest.py::TestSlugify::test_html_hexadecimal_off FAILED [ 15%]\\ntest.py::TestSlugify::test_html_hexadecimal_on FAILED [ 17%]\\ntest.py::TestSlugify::test_max_length FAILED [ 18%]\\ntest.py::TestSlugify::test_max_length_cutoff_not_required FAILED [ 19%]\\ntest.py::TestSlugify::test_multi_character_separator FAILED [ 20%]\\ntest.py::TestSlugify::test_multiple_stopword_occurances FAILED [ 21%]\\ntest.py::TestSlugify::test_multiple_stopwords FAILED [ 23%]\\ntest.py::TestSlugify::test_non_word_characters FAILED [ 24%]\\ntest.py::TestSlugify::test_numbers_and_symbols FAILED [ 25%]\\ntest.py::TestSlugify::test_numbers_only FAILED [ 26%]\\ntest.py::TestSlugify::test_phonetic_conversion_of_eastern_scripts FAILED [ 28%]\\ntest.py::TestSlugify::test_pre_translation P'}\n"
     ]
    }
   ],
   "source": [
    "from rlm_forge.server.environment import RLMForgeEnvironment\n",
    "from rlm_forge.models import RLMForgeAction\n",
    "\n",
    "env = RLMForgeEnvironment()\n",
    "\n",
    "# Run a quick episode\n",
    "obs = env.reset(seed=1)\n",
    "print(f\"Task: {obs.task_description[:200]}...\")\n",
    "print(f\"Available tools: {obs.available_functions}\")\n",
    "\n",
    "# Take a step — list files\n",
    "obs2 = env.step(RLMForgeAction(code=\"print(list_dir())\"))\n",
    "print(f\"\\nStep 1 stdout: {obs2.stdout[:200]}\")\n",
    "\n",
    "# Finalize and get reward\n",
    "obs3 = env.step(RLMForgeAction(code=\"FINAL()\"))\n",
    "print(f\"\\nBaseline reward (no implementation): {obs3.reward:.4f}\")\n",
    "print(f\"Test results: {obs3.test_results}\")\n",
    "\n",
    "env.cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load Model\n",
    "\n",
    "We use Qwen2.5-Coder-32B-Instruct with 4-bit quantization for inference, and train a LoRA adapter with GRPO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model config — adjust based on available VRAM\n",
    "# MODEL_ID = \"Qwen/Qwen2.5-Coder-32B-Instruct\"  # 32B for H100\n",
    "MODEL_ID = \"Qwen/Qwen2.5-Coder-7B-Instruct\"  # Fallback for smaller GPUs\n",
    "HF_TOKEN = ''  #! Fill in HF TOKEN HERE!\n",
    "MAX_STEPS_PER_EPISODE = 6  # Max REPL interactions per episode\n",
    "NUM_EPISODES_PER_PROMPT = 4  # GRPO group size (completions per prompt)\n",
    "NUM_TRAINING_PROMPTS = 16  # Total unique prompts (episodes) for training\n",
    "GRPO_EPOCHS = 2  # Training epochs over collected data\n",
    "BATCH_SIZE = 2\n",
    "GRAD_ACCUM = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "# 4-bit quantization for 32B model on H100\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n",
      "/home/neo/Desktop/Projects/OpenEnv_Hackathon_SF/V1/.venv/lib/python3.13/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c8654adb1804c6e944f84026e38a81b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac2529cdddcf450ea1d5a50f2cea7814",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00003-of-00004.safetensors: 0%| | 0.00/4.33G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd5ac4ddfc7b46e2a1ca9515752aa745",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00004.safetensors: 0%| | 0.00/4.88G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a354cc2372c6467eb49761d5fc153940",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00004-of-00004.safetensors: 0%| | 0.00/1.09G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecca5855d3f647c8b3e49f43214bca5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00004.safetensors: 0%| | 0.00/4.93G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    quantization_config=bnb_config,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    trust_remote_code=True,\n",
    "    # attn_implementation=\"flash_attention_2\",\n",
    "    token=HF_TOKEN\n",
    ")\n",
    "\n",
    "# LoRA config for efficient training\n",
    "lora_config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.05,\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    task_type=\"CAUSAL_LM\",\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, lora_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Trajectory Collection\n",
    "\n",
    "The key idea: treat the full multi-step episode as one \"completion\" for GRPO.\n",
    "\n",
    "**Prompt** = system message + task description + initial observation\n",
    "**Completion** = sequence of all code actions (with observation feedback between them)\n",
    "**Reward** = final composite reward from the environment\n",
    "\n",
    "We roll out multiple episodes per prompt (GRPO group) and use relative rewards within each group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = \"\"\"You are an expert Python developer. You are given a repository where a source file has been replaced with a broken stub. Your task is to explore the repository, understand the expected behavior from the tests, and rewrite the source file so all tests pass.\n",
    "\n",
    "You interact via a Python REPL. Available functions:\n",
    "- read_file(path) — Read a file from the repo\n",
    "- list_dir(path='.') — List directory contents\n",
    "- search(pattern, path='.') — Grep for a pattern\n",
    "- write_file(path, content) — Write/create a file\n",
    "- run_tests(test_path=None) — Run pytest on a test file\n",
    "- FINAL() — Signal that your implementation is complete\n",
    "\n",
    "Strategy:\n",
    "1. Read the failing test file to understand expected behavior\n",
    "2. Read other source files for context (imports, dependencies)\n",
    "3. Write the implementation\n",
    "4. Run tests to verify\n",
    "5. Fix any failures\n",
    "6. Call FINAL() when done\n",
    "\n",
    "Output ONLY valid Python code. No markdown, no explanations — just code to execute.\"\"\"\n",
    "\n",
    "\n",
    "def build_prompt(task_description: str, failing_tests: list[str]) -> list[dict]:\n",
    "    \"\"\"Build the chat prompt for the initial observation.\"\"\"\n",
    "    user_msg = f\"{task_description}\\n\\nFailing tests:\\n\" + \"\\n\".join(failing_tests)\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": user_msg},\n",
    "    ]\n",
    "\n",
    "\n",
    "def extract_code_from_response(response: str) -> str:\n",
|
| 337 |
+
" \"\"\"Extract executable Python code from model response.\"\"\"\n",
|
| 338 |
+
" # Try to find code blocks first\n",
|
| 339 |
+
" code_blocks = re.findall(r\"```(?:python)?\\n(.*?)```\", response, re.DOTALL)\n",
|
| 340 |
+
" if code_blocks:\n",
|
| 341 |
+
" return \"\\n\".join(code_blocks)\n",
|
| 342 |
+
" # Otherwise treat the whole response as code\n",
|
| 343 |
+
" lines = response.strip().split(\"\\n\")\n",
|
| 344 |
+
" code_lines = []\n",
|
| 345 |
+
" for line in lines:\n",
|
| 346 |
+
" stripped = line.strip()\n",
|
| 347 |
+
" if stripped and not stripped.startswith(\"#\") and any(c in stripped for c in \"=()[]{}:\"):\n",
|
| 348 |
+
" code_lines.append(line)\n",
|
| 349 |
+
" elif stripped.startswith(\"#\") or stripped.startswith(\"import\") or stripped.startswith(\"from\"):\n",
|
| 350 |
+
" code_lines.append(line)\n",
|
| 351 |
+
" elif not stripped:\n",
|
| 352 |
+
" code_lines.append(line)\n",
|
| 353 |
+
" else:\n",
|
| 354 |
+
" code_lines.append(f\"# {line}\")\n",
|
| 355 |
+
" return \"\\n\".join(code_lines)\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"print(\"Prompt builder ready.\")"
|
| 359 |
+
]
|
| 360 |
+
},
|
| 361 |
+
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Trajectory:\n",
"    \"\"\"A full multi-step episode trajectory for GRPO training.\"\"\"\n",
"    prompt_text: str      # Rendered chat prompt (system + task)\n",
"    completion_text: str  # All model outputs concatenated\n",
"    reward: float         # Final episode reward\n",
"    steps: int            # Number of steps taken\n",
"    seed: int             # Environment seed (for reproducibility)\n",
"    tests_passed: int\n",
"    tests_total: int\n",
"\n",
"\n",
"def run_episode(\n",
"    model,\n",
"    tokenizer,\n",
"    env: RLMForgeEnvironment,\n",
"    seed: int,\n",
"    max_steps: int = MAX_STEPS_PER_EPISODE,\n",
"    temperature: float = 0.7,\n",
"    max_new_tokens: int = 2048,\n",
") -> Trajectory:\n",
"    \"\"\"Run a single episode: generate code actions, execute them, collect trajectory.\"\"\"\n",
"    obs = env.reset(seed=seed)\n",
"\n",
"    messages = build_prompt(obs.task_description, obs.failing_tests or [])\n",
"    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"\n",
"    all_completions = []  # All model outputs for this episode\n",
"\n",
"    for step_i in range(max_steps):\n",
"        # Build the full conversation so far for the model\n",
"        if step_i > 0:\n",
"            # Feed the previous REPL observation back as a user message\n",
"            messages.append({\"role\": \"user\", \"content\": f\"REPL output:\\n{obs.stdout}\\n{obs.stderr}\"})\n",
"\n",
"        # Generate next action\n",
"        full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"        inputs = tokenizer(full_text, return_tensors=\"pt\", truncation=True, max_length=8192).to(model.device)\n",
"\n",
"        with torch.no_grad():\n",
"            outputs = model.generate(\n",
"                **inputs,\n",
"                max_new_tokens=max_new_tokens,\n",
"                temperature=temperature,\n",
"                top_p=0.95,\n",
"                do_sample=True,\n",
"                pad_token_id=tokenizer.pad_token_id,\n",
"            )\n",
"\n",
"        # Decode only the new tokens\n",
"        new_tokens = outputs[0][inputs[\"input_ids\"].shape[1]:]\n",
"        response = tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
"        all_completions.append(response)\n",
"\n",
"        # Add to conversation history\n",
"        messages.append({\"role\": \"assistant\", \"content\": response})\n",
"\n",
"        # Extract and execute code\n",
"        code = extract_code_from_response(response)\n",
"\n",
"        # Check if model wants to finalize\n",
"        if \"FINAL()\" in code:\n",
"            obs = env.step(RLMForgeAction(code=code))\n",
"            break\n",
"        else:\n",
"            obs = env.step(RLMForgeAction(code=code))\n",
"\n",
"        if obs.done:\n",
"            break\n",
"\n",
"    # If we exhausted steps without FINAL, force finalize\n",
"    if not obs.done:\n",
"        obs = env.step(RLMForgeAction(code=\"FINAL()\"))\n",
"\n",
"    # Build the full completion text (all model outputs joined)\n",
"    completion_text = \"\\n<|step|>\\n\".join(all_completions)\n",
"\n",
"    reward = obs.reward or 0.0\n",
"    test_results = obs.test_results or {}\n",
"\n",
"    return Trajectory(\n",
"        prompt_text=prompt_text,\n",
"        completion_text=completion_text,\n",
"        reward=reward,\n",
"        steps=step_i + 1,\n",
"        seed=seed,\n",
"        tests_passed=test_results.get(\"tests_passed\", 0),\n",
"        tests_total=test_results.get(\"tests_total\", 0),\n",
"    )\n",
"\n",
"\n",
"print(\"Episode runner ready.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Collect Baseline Trajectories\n",
"\n",
"Run episodes to collect (prompt, completion, reward) tuples before training. This establishes the pre-training baseline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def collect_trajectories(\n",
"    model,\n",
"    tokenizer,\n",
"    num_prompts: int = NUM_TRAINING_PROMPTS,\n",
"    episodes_per_prompt: int = NUM_EPISODES_PER_PROMPT,\n",
"    temperature: float = 0.7,\n",
") -> list[list[Trajectory]]:\n",
"    \"\"\"Collect GRPO groups: multiple trajectories per unique prompt/seed.\"\"\"\n",
"    env = RLMForgeEnvironment()\n",
"    all_groups = []\n",
"\n",
"    for prompt_idx in range(num_prompts):\n",
"        seed = prompt_idx * 100  # Deterministic seeds\n",
"        group = []\n",
"\n",
"        for ep_idx in range(episodes_per_prompt):\n",
"            print(f\"  Prompt {prompt_idx+1}/{num_prompts}, Episode {ep_idx+1}/{episodes_per_prompt}...\", end=\" \")\n",
"            traj = run_episode(\n",
"                model, tokenizer, env,\n",
"                seed=seed,  # Same seed = same task for GRPO group\n",
"                temperature=temperature + 0.1 * ep_idx,  # Vary temperature for diversity\n",
"            )\n",
"            group.append(traj)\n",
"            print(f\"reward={traj.reward:.3f}, steps={traj.steps}, \"\n",
"                  f\"tests={traj.tests_passed}/{traj.tests_total}\")\n",
"\n",
"        all_groups.append(group)\n",
"\n",
"    env.cleanup()\n",
"    return all_groups\n",
"\n",
"\n",
"# Collect pre-training baseline\n",
"print(\"=\" * 60)\n",
"print(\"COLLECTING BASELINE TRAJECTORIES\")\n",
"print(\"=\" * 60)\n",
"baseline_groups = collect_trajectories(model, tokenizer)\n",
"\n",
"# Summary stats\n",
"all_rewards = [t.reward for g in baseline_groups for t in g]\n",
"print(f\"\\nBaseline: mean_reward={sum(all_rewards)/len(all_rewards):.4f}, \"\n",
"      f\"min={min(all_rewards):.4f}, max={max(all_rewards):.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. GRPO Training\n",
"\n",
"Train with Group Relative Policy Optimization. For each group of trajectories (same prompt, different completions), compute advantages relative to the group mean reward, then update the policy to increase the probability of higher-reward trajectories."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"\n",
"def trajectories_to_dataset(groups: list[list[Trajectory]]) -> Dataset:\n",
"    \"\"\"Convert trajectory groups into a HuggingFace Dataset for GRPO training.\"\"\"\n",
"    records = []\n",
"    for group in groups:\n",
"        prompt = group[0].prompt_text\n",
"        for traj in group:\n",
"            records.append({\n",
"                \"prompt\": prompt,\n",
"                \"completion\": traj.completion_text,\n",
"                \"reward\": traj.reward,\n",
"            })\n",
"    return Dataset.from_list(records)\n",
"\n",
"\n",
"def build_reward_fn(groups: list[list[Trajectory]]):\n",
"    \"\"\"Build a reward function from pre-collected trajectories.\"\"\"\n",
"    reward_map = {}\n",
"    for group in groups:\n",
"        for traj in group:\n",
"            key = traj.completion_text[:200]\n",
"            reward_map[key] = traj.reward\n",
"\n",
"    def reward_fn(completions: list[str], **kwargs) -> list[float]:\n",
"        rewards = []\n",
"        for c in completions:\n",
"            key = c[:200]\n",
"            rewards.append(reward_map.get(key, 0.0))\n",
"        return rewards\n",
"\n",
"    return reward_fn\n",
"\n",
"\n",
"# Build dataset from baseline trajectories\n",
"train_dataset = trajectories_to_dataset(baseline_groups)\n",
"print(f\"Training dataset: {len(train_dataset)} examples\")\n",
"print(f\"Sample prompt length: {len(train_dataset[0]['prompt'])} chars\")\n",
"print(f\"Sample completion length: {len(train_dataset[0]['completion'])} chars\")\n",
"print(f\"Sample reward: {train_dataset[0]['reward']:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GRPO Training configuration\n",
"grpo_config = GRPOConfig(\n",
"    output_dir=\"./rlm_forge_grpo_output\",\n",
"    num_train_epochs=GRPO_EPOCHS,\n",
"    per_device_train_batch_size=BATCH_SIZE,\n",
"    gradient_accumulation_steps=GRAD_ACCUM,\n",
"    learning_rate=1e-5,\n",
"    warmup_ratio=0.1,\n",
"    max_completion_length=4096,\n",
"    # max_prompt_length=4096,\n",
"    num_generations=NUM_EPISODES_PER_PROMPT,  # GRPO group size\n",
"    logging_steps=1,\n",
"    save_strategy=\"epoch\",\n",
"    bf16=True,\n",
"    gradient_checkpointing=True,\n",
"    # GRPO-specific\n",
"    beta=0.1,  # KL penalty coefficient\n",
"    report_to=\"none\",\n",
")\n",
"\n",
"# Build reward function from collected trajectories\n",
"reward_fn = build_reward_fn(baseline_groups)\n",
"\n",
"# Prepare prompts dataset (unique prompts only, GRPO generates completions)\n",
"prompt_dataset = Dataset.from_list([\n",
"    {\"prompt\": group[0].prompt_text}\n",
"    for group in baseline_groups\n",
"])\n",
"\n",
"# Initialize GRPO trainer\n",
"trainer = GRPOTrainer(\n",
"    model=model,\n",
"    args=grpo_config,\n",
"    train_dataset=prompt_dataset,\n",
"    reward_funcs=reward_fn,\n",
"    processing_class=tokenizer,\n",
")\n",
"\n",
"print(\"GRPO Trainer initialized. Starting training...\")\n",
"trainer.train()\n",
"print(\"Training complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Post-Training Evaluation\n",
"\n",
"Collect new trajectories with the trained model and compare rewards to the baseline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Collect post-training trajectories with the same seeds\n",
"print(\"=\" * 60)\n",
"print(\"COLLECTING POST-TRAINING TRAJECTORIES\")\n",
"print(\"=\" * 60)\n",
"post_groups = collect_trajectories(model, tokenizer, temperature=0.5)\n",
"\n",
"post_rewards = [t.reward for g in post_groups for t in g]\n",
"baseline_rewards = [t.reward for g in baseline_groups for t in g]\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(f\"RESULTS COMPARISON\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"Baseline: mean={sum(baseline_rewards)/len(baseline_rewards):.4f}, \"\n",
"      f\"max={max(baseline_rewards):.4f}\")\n",
"print(f\"Trained:  mean={sum(post_rewards)/len(post_rewards):.4f}, \"\n",
"      f\"max={max(post_rewards):.4f}\")\n",
"print(f\"Improvement: {(sum(post_rewards)/len(post_rewards) - sum(baseline_rewards)/len(baseline_rewards)):.4f}\")\n",
"\n",
"# Per-task comparison\n",
"print(f\"\\nPer-task breakdown:\")\n",
"for i, (bg, pg) in enumerate(zip(baseline_groups, post_groups)):\n",
"    b_mean = sum(t.reward for t in bg) / len(bg)\n",
"    p_mean = sum(t.reward for t in pg) / len(pg)\n",
"    delta = p_mean - b_mean\n",
"    arrow = \"\\u2191\" if delta > 0 else \"\\u2193\" if delta < 0 else \"\\u2192\"\n",
"    print(f\"  Task {i}: baseline={b_mean:.3f} \\u2192 trained={p_mean:.3f} ({arrow} {abs(delta):.3f})\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Visualize Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
"\n",
"# 1. Reward distribution: baseline vs trained\n",
"ax1 = axes[0]\n",
"ax1.hist(baseline_rewards, bins=20, alpha=0.6, label=\"Baseline\", color=\"steelblue\")\n",
"ax1.hist(post_rewards, bins=20, alpha=0.6, label=\"After GRPO\", color=\"coral\")\n",
"ax1.set_xlabel(\"Episode Reward\")\n",
"ax1.set_ylabel(\"Count\")\n",
"ax1.set_title(\"Reward Distribution\")\n",
"ax1.legend()\n",
"ax1.axvline(np.mean(baseline_rewards), color=\"steelblue\", linestyle=\"--\", alpha=0.8)\n",
"ax1.axvline(np.mean(post_rewards), color=\"coral\", linestyle=\"--\", alpha=0.8)\n",
"\n",
"# 2. Per-task mean reward comparison\n",
"ax2 = axes[1]\n",
"task_ids = list(range(len(baseline_groups)))\n",
"b_means = [np.mean([t.reward for t in g]) for g in baseline_groups]\n",
"p_means = [np.mean([t.reward for t in g]) for g in post_groups]\n",
"x = np.arange(len(task_ids))\n",
"width = 0.35\n",
"ax2.bar(x - width/2, b_means, width, label=\"Baseline\", color=\"steelblue\", alpha=0.8)\n",
"ax2.bar(x + width/2, p_means, width, label=\"After GRPO\", color=\"coral\", alpha=0.8)\n",
"ax2.set_xlabel(\"Task ID\")\n",
"ax2.set_ylabel(\"Mean Reward\")\n",
"ax2.set_title(\"Per-Task Reward Improvement\")\n",
"ax2.legend()\n",
"ax2.set_xticks(x)\n",
"\n",
"# 3. Test pass rate improvement\n",
"ax3 = axes[2]\n",
"b_pass_rates = [np.mean([t.tests_passed / max(t.tests_total, 1) for t in g]) for g in baseline_groups]\n",
"p_pass_rates = [np.mean([t.tests_passed / max(t.tests_total, 1) for t in g]) for g in post_groups]\n",
"ax3.bar(x - width/2, b_pass_rates, width, label=\"Baseline\", color=\"steelblue\", alpha=0.8)\n",
"ax3.bar(x + width/2, p_pass_rates, width, label=\"After GRPO\", color=\"coral\", alpha=0.8)\n",
"ax3.set_xlabel(\"Task ID\")\n",
"ax3.set_ylabel(\"Test Pass Rate\")\n",
"ax3.set_title(\"Test Pass Rate Improvement\")\n",
"ax3.legend()\n",
"ax3.set_xticks(x)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"rlm_forge_results.png\", dpi=150, bbox_inches=\"tight\")\n",
"plt.show()\n",
"\n",
"print(f\"\\nOverall test pass rate:\")\n",
"print(f\"  Baseline: {np.mean(b_pass_rates):.1%}\")\n",
"print(f\"  Trained:  {np.mean(p_pass_rates):.1%}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Save Model & Training Log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained LoRA adapter\n",
"model.save_pretrained(\"./rlm_forge_lora_adapter\")\n",
"tokenizer.save_pretrained(\"./rlm_forge_lora_adapter\")\n",
"\n",
"# Save training log\n",
"training_log = {\n",
"    \"model_id\": MODEL_ID,\n",
"    \"num_prompts\": NUM_TRAINING_PROMPTS,\n",
"    \"episodes_per_prompt\": NUM_EPISODES_PER_PROMPT,\n",
"    \"max_steps_per_episode\": MAX_STEPS_PER_EPISODE,\n",
"    \"grpo_epochs\": GRPO_EPOCHS,\n",
"    \"baseline_mean_reward\": float(np.mean(baseline_rewards)),\n",
"    \"baseline_max_reward\": float(max(baseline_rewards)),\n",
"    \"trained_mean_reward\": float(np.mean(post_rewards)),\n",
"    \"trained_max_reward\": float(max(post_rewards)),\n",
"    \"improvement\": float(np.mean(post_rewards) - np.mean(baseline_rewards)),\n",
"    \"baseline_test_pass_rate\": float(np.mean(b_pass_rates)),\n",
"    \"trained_test_pass_rate\": float(np.mean(p_pass_rates)),\n",
"}\n",
"\n",
"with open(\"training_log.json\", \"w\") as f:\n",
"    json.dump(training_log, f, indent=2)\n",
"\n",
"print(\"Saved LoRA adapter to ./rlm_forge_lora_adapter\")\n",
"print(\"Saved training log to training_log.json\")\n",
"print(f\"\\nFinal summary:\")\n",
"print(json.dumps(training_log, indent=2))"
]
}
],
"metadata": {
"accelerator": "GPU",
"gpuClass": "premium",
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
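The GRPO section of the notebook above normalizes rewards within each group of completions sampled for the same prompt. TRL's `GRPOTrainer` performs this normalization internally; as a standalone sketch (illustrative only, not the notebook's code), the group-relative advantage used to weight each completion can be computed like this:

```python
def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Sketch of GRPO's group-relative advantage: standardize each episode's
    reward against the mean and std of its own group, so only *relative*
    quality within the group drives the policy update."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# One GRPO group: four completions of the same task with different rewards
advs = group_relative_advantages([0.2, 0.8, 0.5, 0.5])
```

This is why `num_generations` in `GRPOConfig` must match the number of episodes collected per prompt: with a single completion per group, every advantage would be zero and no learning signal would remain.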
rlm_forge_training.py
ADDED
|
@@ -0,0 +1,470 @@
import torch
import json
import re
import random
from typing import Optional
from dataclasses import dataclass
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

print(f"GPU: {torch.cuda.get_device_name(0)}")
# print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"PyTorch: {torch.__version__}")

from rlm_forge.server.environment import RLMForgeEnvironment
from rlm_forge.models import RLMForgeAction

env = RLMForgeEnvironment()

# Run a quick episode
obs = env.reset(seed=1)
print(f"Task: {obs.task_description[:200]}...")
print(f"Available tools: {obs.available_functions}")

# Take a step — list files
obs2 = env.step(RLMForgeAction(code="print(list_dir())"))
print(f"\nStep 1 stdout: {obs2.stdout[:200]}")

# Finalize and get reward
obs3 = env.step(RLMForgeAction(code="FINAL()"))
print(f"\nBaseline reward (no implementation): {obs3.reward:.4f}")
print(f"Test results: {obs3.test_results}")

env.cleanup()


# Model config — adjust based on available VRAM
# MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"  # 32B for H100
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # Fallback for smaller GPUs
HF_TOKEN = ''
MAX_STEPS_PER_EPISODE = 6    # Max REPL interactions per episode
NUM_EPISODES_PER_PROMPT = 2  # GRPO group size (completions per prompt)
NUM_TRAINING_PROMPTS = 8     # 16  # Total unique prompts (episodes) for training
GRPO_EPOCHS = 2              # Training epochs over collected data
BATCH_SIZE = 2
GRAD_ACCUM = 4

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit quantization for 32B model on H100
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # attn_implementation="flash_attention_2",
    token=HF_TOKEN,
)

# LoRA config for efficient training
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

SYSTEM_PROMPT = """You are an expert Python developer. You are given a repository where a source file has been replaced with a broken stub. Your task is to explore the repository, understand the expected behavior from the tests, and rewrite the source file so all tests pass.

You interact via a Python REPL. Available functions:
- read_file(path) — Read a file from the repo
- list_dir(path='.') — List directory contents
- search(pattern, path='.') — Grep for a pattern
- write_file(path, content) — Write/create a file
- run_tests(test_path=None) — Run pytest on a test file
- FINAL() — Signal that your implementation is complete

Strategy:
1. Read the failing test file to understand expected behavior
2. Read other source files for context (imports, dependencies)
3. Write the implementation
4. Run tests to verify
5. Fix any failures
6. Call FINAL() when done

Output ONLY valid Python code. No markdown, no explanations — just code to execute."""


def build_prompt(task_description: str, failing_tests: list[str]) -> list[dict]:
    """Build the chat prompt for the initial observation."""
    user_msg = f"{task_description}\n\nFailing tests:\n" + "\n".join(failing_tests)
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_msg},
    ]


def extract_code_from_response(response: str) -> str:
    """Extract executable Python code from model response."""
    # Try to find code blocks first
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response, re.DOTALL)
    if code_blocks:
        return "\n".join(code_blocks)
    # Otherwise treat the whole response as code
    lines = response.strip().split("\n")
    code_lines = []
    for line in lines:
        stripped = line.strip()
        if stripped and not stripped.startswith("#") and any(c in stripped for c in "=()[]{}:"):
            code_lines.append(line)
        elif stripped.startswith("#") or stripped.startswith("import") or stripped.startswith("from"):
            code_lines.append(line)
        elif not stripped:
            code_lines.append(line)
        else:
            code_lines.append(f"# {line}")
    return "\n".join(code_lines)


print("Prompt builder ready.")

@dataclass
|
| 141 |
+
class Trajectory:
|
| 142 |
+
"""A full multi-step episode trajectory for GRPO training."""
|
| 143 |
+
prompt_text: str # Tokenized prompt (system + task)
|
| 144 |
+
completion_text: str # All model outputs concatenated
|
| 145 |
+
reward: float # Final episode reward
|
| 146 |
+
steps: int # Number of steps taken
|
| 147 |
+
seed: int # Environment seed (for reproducibility)
|
| 148 |
+
tests_passed: int
|
| 149 |
+
tests_total: int
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
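Why collect several episodes per seed? GRPO scores each completion against its own group: the advantage is the episode reward minus the group's mean reward (TRL additionally normalizes by the group's std, which this sketch omits):

```python
def group_advantages(rewards: list[float]) -> list[float]:
    # GRPO-style advantage: deviation from the group's mean reward.
    mean_r = sum(rewards) / len(rewards)
    return [r - mean_r for r in rewards]

# Three episodes of the same task (same seed), different sampled actions:
print(group_advantages([1.0, 0.0, 0.5]))  # → [0.5, -0.5, 0.0]
```

This is why `collect_trajectories` below reuses one seed per group: advantages are only meaningful when all episodes in a group faced the same task.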
def run_episode(
    model,
    tokenizer,
    env: RLMForgeEnvironment,
    seed: int,
    max_steps: int = MAX_STEPS_PER_EPISODE,
    temperature: float = 0.7,
    max_new_tokens: int = 2048,
) -> Trajectory:
    """Run a single episode: generate code actions, execute them, collect trajectory."""
    obs = env.reset(seed=seed)

    messages = build_prompt(obs.task_description, obs.failing_tests or [])
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    all_completions = []  # All model outputs for this episode

    for step_i in range(max_steps):
        if step_i > 0:
            # Feed the previous step's REPL output back as a user turn
            messages.append({"role": "user", "content": f"REPL output:\n{obs.stdout}\n{obs.stderr}"})

        # Generate next action
        full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(full_text, return_tensors="pt", truncation=True, max_length=8192).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the new tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        all_completions.append(response)

        # Add to conversation history
        messages.append({"role": "assistant", "content": response})

        # Extract and execute code; stop once the model finalizes or the env ends the episode
        code = extract_code_from_response(response)
        obs = env.step(RLMForgeAction(code=code))
        if "FINAL()" in code or obs.done:
            break

    # If we exhausted steps without FINAL, force finalize
    if not obs.done:
        obs = env.step(RLMForgeAction(code="FINAL()"))

    # Build the full completion text (all model outputs joined)
    completion_text = "\n<|step|>\n".join(all_completions)

    reward = obs.reward or 0.0
    test_results = obs.test_results or {}

    return Trajectory(
        prompt_text=prompt_text,
        completion_text=completion_text,
        reward=reward,
        steps=step_i + 1,
        seed=seed,
        tests_passed=test_results.get("tests_passed", 0),
        tests_total=test_results.get("tests_total", 0),
    )


print("Episode runner ready.")

def collect_trajectories(
    model,
    tokenizer,
    num_prompts: int = NUM_TRAINING_PROMPTS,
    episodes_per_prompt: int = NUM_EPISODES_PER_PROMPT,
    temperature: float = 0.7,
) -> list[list[Trajectory]]:
    """Collect GRPO groups: multiple trajectories per unique prompt/seed."""
    env = RLMForgeEnvironment()
    all_groups = []

    for prompt_idx in range(num_prompts):
        seed = prompt_idx * 100  # Deterministic seeds
        group = []

        for ep_idx in range(episodes_per_prompt):
            print(f"  Prompt {prompt_idx+1}/{num_prompts}, Episode {ep_idx+1}/{episodes_per_prompt}...", end=" ")
            traj = run_episode(
                model, tokenizer, env,
                seed=seed,  # Same seed = same task for GRPO group
                temperature=temperature + 0.1 * ep_idx,  # Vary temperature for diversity
            )
            group.append(traj)
            print(f"reward={traj.reward:.3f}, steps={traj.steps}, "
                  f"tests={traj.tests_passed}/{traj.tests_total}")

        all_groups.append(group)

    env.cleanup()
    return all_groups

# GRPO training configuration
grpo_config = GRPOConfig(
    output_dir="./rlm_forge_grpo_output",
    num_train_epochs=GRPO_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    max_completion_length=4096,
    # max_prompt_length=4096,
    num_generations=NUM_EPISODES_PER_PROMPT,  # GRPO group size
    logging_steps=1,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    # GRPO-specific
    beta=0.1,  # KL penalty coefficient
    report_to="none",
)

# Collect pre-training baseline
print("=" * 60)
print("COLLECTING BASELINE TRAJECTORIES")
print("=" * 60)
baseline_groups = collect_trajectories(model, tokenizer)

# Summary stats
all_rewards = [t.reward for g in baseline_groups for t in g]
print(f"\nBaseline: mean_reward={sum(all_rewards)/len(all_rewards):.4f}, "
      f"min={min(all_rewards):.4f}, max={max(all_rewards):.4f}")

def trajectories_to_dataset(groups: list[list[Trajectory]]) -> Dataset:
    """Convert trajectory groups into a HuggingFace Dataset for GRPO training."""
    records = []
    for group in groups:
        prompt = group[0].prompt_text
        for traj in group:
            records.append({
                "prompt": prompt,
                "completion": traj.completion_text,
                "reward": traj.reward,
            })
    return Dataset.from_list(records)


def build_reward_fn(groups: list[list[Trajectory]]):
    """Build a reward function from pre-collected trajectories."""
    reward_map = {}
    for group in groups:
        for traj in group:
            key = traj.completion_text[:200]
            reward_map[key] = traj.reward

    def reward_fn(completions: list[str], **kwargs) -> list[float]:
        rewards = []
        for c in completions:
            key = c[:200]
            rewards.append(reward_map.get(key, 0.0))
        return rewards

    return reward_fn

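Note that `build_reward_fn` keys each completion by its first 200 characters, so any unseen completion falls back to 0.0 and two trajectories sharing a 200-character prefix would collide. A minimal standalone sketch of the same lookup (function and data names here are illustrative):

```python
def build_lookup_reward_fn(pairs: list[tuple[str, float]]):
    # Key by the first 200 chars, as build_reward_fn does; unseen completions -> 0.0.
    reward_map = {text[:200]: r for text, r in pairs}

    def reward_fn(completions: list[str], **kwargs) -> list[float]:
        return [reward_map.get(c[:200], 0.0) for c in completions]

    return reward_fn

fn = build_lookup_reward_fn([("print('fix A')", 0.8), ("print('fix B')", 0.2)])
print(fn(["print('fix A')", "something unseen"]))  # → [0.8, 0.0]
```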
# Build dataset from baseline trajectories
train_dataset = trajectories_to_dataset(baseline_groups)
print(f"Training dataset: {len(train_dataset)} examples")
print(f"Sample prompt length: {len(train_dataset[0]['prompt'])} chars")
print(f"Sample completion length: {len(train_dataset[0]['completion'])} chars")
print(f"Sample reward: {train_dataset[0]['reward']:.4f}")

# Build reward function from collected trajectories
reward_fn = build_reward_fn(baseline_groups)

# Prepare prompts dataset (unique prompts only; GRPO generates completions)
prompt_dataset = Dataset.from_list([
    {"prompt": group[0].prompt_text}
    for group in baseline_groups
])

# Initialize GRPO trainer
trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=prompt_dataset,
    reward_funcs=reward_fn,
    processing_class=tokenizer,
)

print("GRPO Trainer initialized. Starting training...")
trainer.train()
print("Training complete!")

# Collect post-training trajectories with the same seeds
print("=" * 60)
print("COLLECTING POST-TRAINING TRAJECTORIES")
print("=" * 60)
post_groups = collect_trajectories(model, tokenizer, temperature=0.5)

post_rewards = [t.reward for g in post_groups for t in g]
baseline_rewards = [t.reward for g in baseline_groups for t in g]

print(f"\n{'='*60}")
print("RESULTS COMPARISON")
print(f"{'='*60}")
print(f"Baseline: mean={sum(baseline_rewards)/len(baseline_rewards):.4f}, "
      f"max={max(baseline_rewards):.4f}")
print(f"Trained: mean={sum(post_rewards)/len(post_rewards):.4f}, "
      f"max={max(post_rewards):.4f}")
print(f"Improvement: {(sum(post_rewards)/len(post_rewards) - sum(baseline_rewards)/len(baseline_rewards)):.4f}")

# Per-task comparison
print("\nPer-task breakdown:")
for i, (bg, pg) in enumerate(zip(baseline_groups, post_groups)):
    b_mean = sum(t.reward for t in bg) / len(bg)
    p_mean = sum(t.reward for t in pg) / len(pg)
    delta = p_mean - b_mean
    arrow = "\u2191" if delta > 0 else "\u2193" if delta < 0 else "\u2192"
    print(f"  Task {i}: baseline={b_mean:.3f} \u2192 trained={p_mean:.3f} ({arrow} {abs(delta):.3f})")

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 1. Reward distribution: baseline vs trained
ax1 = axes[0]
ax1.hist(baseline_rewards, bins=20, alpha=0.6, label="Baseline", color="steelblue")
ax1.hist(post_rewards, bins=20, alpha=0.6, label="After GRPO", color="coral")
ax1.set_xlabel("Episode Reward")
ax1.set_ylabel("Count")
ax1.set_title("Reward Distribution")
ax1.legend()
ax1.axvline(np.mean(baseline_rewards), color="steelblue", linestyle="--", alpha=0.8)
ax1.axvline(np.mean(post_rewards), color="coral", linestyle="--", alpha=0.8)

# 2. Per-task mean reward comparison
ax2 = axes[1]
task_ids = list(range(len(baseline_groups)))
b_means = [np.mean([t.reward for t in g]) for g in baseline_groups]
p_means = [np.mean([t.reward for t in g]) for g in post_groups]
x = np.arange(len(task_ids))
width = 0.35
ax2.bar(x - width/2, b_means, width, label="Baseline", color="steelblue", alpha=0.8)
ax2.bar(x + width/2, p_means, width, label="After GRPO", color="coral", alpha=0.8)
ax2.set_xlabel("Task ID")
ax2.set_ylabel("Mean Reward")
ax2.set_title("Per-Task Reward Improvement")
ax2.legend()
ax2.set_xticks(x)

# 3. Test pass rate improvement
ax3 = axes[2]
b_pass_rates = [np.mean([t.tests_passed / max(t.tests_total, 1) for t in g]) for g in baseline_groups]
p_pass_rates = [np.mean([t.tests_passed / max(t.tests_total, 1) for t in g]) for g in post_groups]
ax3.bar(x - width/2, b_pass_rates, width, label="Baseline", color="steelblue", alpha=0.8)
ax3.bar(x + width/2, p_pass_rates, width, label="After GRPO", color="coral", alpha=0.8)
ax3.set_xlabel("Task ID")
ax3.set_ylabel("Test Pass Rate")
ax3.set_title("Test Pass Rate Improvement")
ax3.legend()
ax3.set_xticks(x)

plt.tight_layout()
plt.savefig("rlm_forge_results.png", dpi=150, bbox_inches="tight")
plt.show()

print("\nOverall test pass rate:")
print(f"  Baseline: {np.mean(b_pass_rates):.1%}")
print(f"  Trained: {np.mean(p_pass_rates):.1%}")

# Save the trained LoRA adapter
model.save_pretrained("./rlm_forge_lora_adapter")
tokenizer.save_pretrained("./rlm_forge_lora_adapter")

# Save training log
training_log = {
    "model_id": MODEL_ID,
    "num_prompts": NUM_TRAINING_PROMPTS,
    "episodes_per_prompt": NUM_EPISODES_PER_PROMPT,
    "max_steps_per_episode": MAX_STEPS_PER_EPISODE,
    "grpo_epochs": GRPO_EPOCHS,
    "baseline_mean_reward": float(np.mean(baseline_rewards)),
    "baseline_max_reward": float(max(baseline_rewards)),
    "trained_mean_reward": float(np.mean(post_rewards)),
    "trained_max_reward": float(max(post_rewards)),
    "improvement": float(np.mean(post_rewards) - np.mean(baseline_rewards)),
    "baseline_test_pass_rate": float(np.mean(b_pass_rates)),
    "trained_test_pass_rate": float(np.mean(p_pass_rates)),
}

with open("training_log.json", "w") as f:
    json.dump(training_log, f, indent=2)

print("Saved LoRA adapter to ./rlm_forge_lora_adapter")
print("Saved training log to training_log.json")
print("\nFinal summary:")
print(json.dumps(training_log, indent=2))
server/__init__.py
ADDED
@@ -0,0 +1 @@
# Server package for OpenEnv deployment.
server/app.py
ADDED
@@ -0,0 +1,23 @@
"""FastAPI server entry point for RLM-Forge environment.

This module provides the standardized OpenEnv server entry point.
It wraps the rlm_forge.server.app module for multi-mode deployment.

Usage:
    uv run server
    python -m server.app
    uvicorn server.app:app --host 0.0.0.0 --port 8000
"""

from rlm_forge.server.app import app  # noqa: F401


def main(host: str = "0.0.0.0", port: int = 8000):
    """Entry point for direct execution via uv run or python -m."""
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
uv.lock
ADDED
The diff for this file is too large to render. See raw diff.