File size: 6,964 Bytes
cacd58c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
GRPO Training Script for gpt-oss-120b using OpenEnv and TRL.
Adapted from the openenv-course repository architecture.

Requirements:
pip install "trl>=0.17.0" openenv-core transformers datasets accelerate vllm
"""
import os
import sys
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer

from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action

# TRL imports
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions

# 1. Configuration
# Hub id of the policy model. It is also used below to load the tokenizer and is
# passed to GRPOTrainer in main().
MODEL_NAME = "openai/gpt-oss-120b"
# Used as GRPOConfig.output_dir and as the target of trainer.save_model() in main().
OUTPUT_DIR = "code-debug-grpo-120b"
# Environment server URL; override with the OPENENV_URL environment variable.
ENV_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")

# 2. Setup Persistent Environment Connection
# NOTE: this executes at import time. The resulting module-level `sync_env` is
# shared by every rollout (rollout_func reads it as a global) and is closed in
# main(). `env.sync()` presumably returns a blocking client wrapper around the
# CodeDebugEnv client; confirm against CodeDebugEnv.
print(f"Connecting to env: {ENV_URL}")
env = CodeDebugEnv(base_url=ENV_URL)
sync_env = env.sync()
sync_env.connect()

# 3. Setup Tokenizer
# Also loaded at import time. rollout_func passes this global to rollout_once,
# and main() hands it to GRPOTrainer as processing_class.
print(f"Loading tokenizer for {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Use EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4. System Prompt Definition
# System message sent on every turn. The model is asked for a JSON object with
# "patch", "task_id" and "think". The rollout code reads only "patch" and
# "think" from it and takes task_id from the environment observation instead.
SYSTEM_PROMPT = """You are an expert Python debugger and RL agent.
Your task is to fix the buggy code provided to you.

Provide ONLY a valid JSON object matching this schema:
{
  "patch": "The FULL python function as a string, with the bugs fixed",
  "task_id": "the task requested",
  "think": "Your chain-of-thought reasoning before patching (important for rewards!)"
}
"""

def make_user_prompt(observation):
    """Render an environment observation as the user turn of the chat.

    The text has three sections separated by blank lines: the task
    description, the buggy code in a ```python fence, and the current
    pass count (``passed`` out of ``total`` tests).

    Args:
        observation: Object exposing ``task_description``, ``buggy_code``,
            ``passed`` and ``total`` attributes.

    Returns:
        The formatted user prompt string.
    """
    sections = [
        f"Task Description: {observation.task_description}",
        f"Buggy Code:\n```python\n{observation.buggy_code}\n```",
        f"Passed {observation.passed} out of {observation.total} tests.",
    ]
    return "\n\n".join(sections)

# 5. Rollout Function
def rollout_once(trainer, sync_env, tokenizer, dataset_prompt, system_prompt, max_turns):
    """Run one environment episode and collect token-level data for GRPO.

    The environment is reset. Then, for at most ``max_turns`` turns (stopping
    early once the env reports ``done``), the current observation is rendered
    into a chat prompt. A completion is generated with
    ``generate_rollout_completions`` and parsed into an ``Action``, and the
    environment is stepped with it. Token ids and logprobs from every turn
    are concatenated in turn order.

    Args:
        trainer: GRPOTrainer instance, forwarded to ``generate_rollout_completions``.
        sync_env: Connected environment client providing ``reset()`` and ``step()``.
        tokenizer: Tokenizer used to render the chat template and, when the
            rollout output has no ``'text'``, to decode the completion ids.
        dataset_prompt: Prompt from the training dataset. Not used: the model
            input is built from the environment observation instead.
        system_prompt: System message prepended on every turn.
        max_turns: Maximum number of generate/step turns.

    Returns:
        dict with flattened ``'prompt_ids'``, ``'completion_ids'`` and
        ``'logprobs'`` lists, plus ``'env_reward'``. That reward is the
        ``observation.score`` after the last step, or 0.0 if no step ran.
    """
    result = sync_env.reset()
    observation = result.observation

    prompt_ids = []
    completion_ids = []
    logprobs = []
    composite_rewards = []

    for _turn in range(max_turns):
        if result.done:
            break

        user_prompt = make_user_prompt(observation)
        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt},
        ]

        # NOTE(review): `enable_thinking` is only honoured by chat templates that
        # reference it. gpt-oss templates presumably use `reasoning_effort`
        # instead, so this kwarg may be ignored; confirm.
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )

        # A single prompt is submitted, so take the single result.
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.append(rollout_outputs['prompt_ids'])
        completion_ids.append(rollout_outputs['completion_ids'])
        logprobs.append(rollout_outputs['logprobs'])

        # Prefer the generator's own decoded text; otherwise decode the ids here.
        completion_text = rollout_outputs.get('text') or tokenizer.decode(
            rollout_outputs['completion_ids'], skip_special_tokens=True
        )

        action = _parse_action(completion_text, observation)

        # Step the environment
        result = sync_env.step(action)
        observation = result.observation

        # Assumes observation.score is already the environment's composite reward
        # (0.0 to 1.0, combining correctness, format, CoT bonus and efficiency,
        # per the environment), so no extra shaping is done here.
        composite_rewards.append(observation.score)

    # NOTE(review): with max_turns > 1 each turn uses a fresh system+user prompt,
    # so concatenating the turns yields [p1, p2, ...] + [c1, c2, ...], which is
    # not one contiguous conversation. Confirm this is acceptable for training.
    return {
        'prompt_ids': [pid for sub in prompt_ids for pid in sub],  # flatten
        'completion_ids': [cid for sub in completion_ids for cid in sub],
        'logprobs': [lp for sub in logprobs for lp in sub],
        'env_reward': composite_rewards[-1] if composite_rewards else 0.0,
    }


def _parse_action(completion_text, observation):
    """Build an Action from the JSON object embedded in a model completion.

    The span from the first ``{`` to the last ``}`` is parsed as JSON; its
    ``"patch"`` and optional ``"think"`` fields become the Action, with
    ``task_id`` taken from the observation. If no such span exists, the JSON
    is invalid, ``"patch"`` is missing, or Action rejects the values, the
    unchanged buggy code is resubmitted with an empty ``think``, letting the
    environment score the failed attempt.
    """
    start = completion_text.find("{")
    end = completion_text.rfind("}")
    # Require a `}` after the `{`; rfind returns -1 when there is none.
    if start != -1 and end > start:
        try:
            data = json.loads(completion_text[start:end + 1])
            return Action(patch=data["patch"], task_id=observation.task_id, think=data.get("think", ""))
        except (ValueError, KeyError, TypeError):
            # JSONDecodeError and pydantic validation errors subclass ValueError.
            pass
    return Action(patch=observation.buggy_code, task_id=observation.task_id, think="")


def rollout_func(prompts, trainer=None):
    """Produce one environment episode per prompt for GRPOTrainer.

    Each prompt triggers a fresh episode on the shared module-level
    ``sync_env``, using the module-level ``tokenizer`` and ``SYSTEM_PROMPT``.
    The prompt text is forwarded to ``rollout_once`` as ``dataset_prompt``.

    Args:
        prompts: Iterable of prompt strings supplied by the trainer.
        trainer: The GRPOTrainer, passed through for generation.

    Returns:
        dict mapping ``'prompt_ids'``, ``'completion_ids'``, ``'logprobs'``
        and ``'env_reward'`` to lists with one entry per prompt, in order.
    """
    fields = ('prompt_ids', 'completion_ids', 'logprobs', 'env_reward')
    batch = {field: [] for field in fields}

    for text in prompts:
        episode = rollout_once(
            trainer=trainer,
            sync_env=sync_env,
            tokenizer=tokenizer,
            dataset_prompt=text,
            system_prompt=SYSTEM_PROMPT,
            max_turns=3,  # few turns: every turn is a full generation on a 120B model
        )
        for field in fields:
            batch[field].append(episode[field])

    return batch

# 6. Reward Functions (Mapped from rollout_func keys)
def composite_env_reward(completions, **kwargs):
    """Return the per-episode environment rewards as floats.

    Expects the ``env_reward`` list produced by ``rollout_func`` to arrive as a
    keyword argument. When it is missing or empty, every completion scores 0.0.

    Args:
        completions: The batch of completions (only its length is used).
        **kwargs: Extra batch fields; ``env_reward`` is read if present.

    Returns:
        list[float] of rewards.
    """
    env_rewards = kwargs.get("env_reward")
    if not env_rewards:
        return [0.0] * len(completions)
    return list(map(float, env_rewards))


# 7. Create Dataset & Config
def main():
    """Build the placeholder dataset and GRPO config, train, and save the model.

    The shared environment connection (``sync_env``) is closed even when
    trainer construction or training raises. The model is saved to
    ``OUTPUT_DIR`` only after training completes successfully.
    """
    print("Preparing dataset...")
    # Placeholder prompts only set how many rollouts are requested; the model
    # input is built from the environment observation inside the rollout.
    dataset = Dataset.from_dict({"prompt": ["Fix the buggy Python code."] * 500})

    # NOTE: this config does not itself enable MXFP4 quantization or tensor
    # parallelism; those would have to be configured separately.
    grpo_config = GRPOConfig(
        num_train_epochs=1,
        learning_rate=1e-6,  # lower LR for 120B
        gradient_accumulation_steps=128,
        per_device_train_batch_size=1,
        warmup_steps=10,
        num_generations=2,
        max_completion_length=512,
        max_prompt_length=1500,
        use_vllm=True,
        vllm_mode="colocate",
        # NOTE(review): in "colocate" mode vLLM shares each GPU with the training
        # model, so reserving 0.9 of GPU memory for vLLM leaves little for the
        # policy weights and optimizer state; confirm this fits.
        vllm_gpu_memory_utilization=0.9,
        output_dir=OUTPUT_DIR,
        logging_steps=1,
        save_steps=50,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )

    print(f"Initializing GRPOTrainer for {MODEL_NAME}...")
    try:
        trainer = GRPOTrainer(
            model=MODEL_NAME,
            processing_class=tokenizer,
            reward_funcs=[composite_env_reward],
            train_dataset=dataset,
            args=grpo_config,
            rollout_func=rollout_func,
        )

        print("Starting training...")
        trainer.train()
    finally:
        # Release the environment connection opened at import time, even on failure.
        sync_env.close()

    trainer.save_model(OUTPUT_DIR)
    print("Training complete! Model saved.")


if __name__ == "__main__":
    main()