#!/usr/bin/env python3
"""
GRPO Training Script for gpt-oss-120b using OpenEnv and TRL.

Adapted from the openenv-course repository architecture.

Requirements:
    pip install "trl>=0.17.0" openenv-core transformers datasets accelerate vllm
"""

import json
import os
import sys
from pathlib import Path

import torch
from datasets import Dataset
from transformers import AutoTokenizer

# Make the sibling `code_debug_env` package importable when running as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action

# TRL imports
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions

# 1. Configuration
MODEL_NAME = "openai/gpt-oss-120b"
OUTPUT_DIR = "code-debug-grpo-120b"
ENV_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")

# 2. Set up a persistent environment connection
print(f"Connecting to env: {ENV_URL}")
env = CodeDebugEnv(base_url=ENV_URL)
sync_env = env.sync()
sync_env.connect()

# 3. Set up the tokenizer
print(f"Loading tokenizer for {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4. System prompt definition
SYSTEM_PROMPT = """You are an expert Python debugger and RL agent.
Your task is to fix the buggy code provided to you.

Provide ONLY a valid JSON object matching this schema:
{
    "patch": "the FULL Python function as a string, with the bugs fixed",
    "task_id": "the task_id you were given",
    "think": "your chain-of-thought reasoning before patching (important for rewards!)"
}
"""


def make_user_prompt(observation):
    return (
        f"Task Description: {observation.task_description}\n\n"
        f"Buggy Code:\n```python\n{observation.buggy_code}\n```\n\n"
        f"Passed {observation.passed} out of {observation.total} tests."
    )


# 5. Rollout function
def rollout_once(trainer, sync_env, tokenizer, dataset_prompt, system_prompt, max_turns):
    """Execute one full episode and gather the trajectory for GRPO.

    `dataset_prompt` is accepted but unused: the environment's observation,
    not the dataset row, drives the actual prompt.
    """
    result = sync_env.reset()
    observation = result.observation

    prompt_ids = []
    completion_ids = []
    logprobs = []
    composite_rewards = []

    for _turn in range(max_turns):
        if result.done:
            break

        user_prompt = make_user_prompt(observation)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,  # template-dependent kwarg; ignored by templates without it
        )

        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.append(rollout_outputs["prompt_ids"])
        completion_ids.append(rollout_outputs["completion_ids"])
        logprobs.append(rollout_outputs["logprobs"])

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True
        )

        # Parse the JSON object from the model output. The prompt dictates JSON,
        # so a simple brace-delimited extraction is enough here.
        try:
            start = completion_text.find("{")
            end = completion_text.rfind("}") + 1
            if start != -1 and end > start:  # a failed rfind() makes end == 0, not -1
                data = json.loads(completion_text[start:end])
                action = Action(
                    patch=data["patch"],
                    task_id=observation.task_id,
                    think=data.get("think", ""),
                )
            else:
                raise ValueError("No JSON found")
        except (ValueError, KeyError):  # json.JSONDecodeError is a ValueError subclass
            # Fallback action if parsing fails: resubmit the buggy code unchanged.
            action = Action(patch=observation.buggy_code, task_id=observation.task_id, think="")

        # Step the environment
        result = sync_env.step(action)
        observation = result.observation

        # The environment already calculates the composite reward (0.0 to 1.0):
        # correctness, format, CoT bonus, and efficiency are all baked in.
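        # Purely illustrative sketch of what such a composite score might look
        # like (hypothetical weights; the real scoring lives server-side in the
        # environment, not in this script):
        #   score = 0.7 * fraction_of_tests_passed    # correctness
        #         + 0.1 * patch_was_valid_json        # format
        #         + 0.1 * think_was_non_empty         # CoT bonus
        #         + 0.1 * (1 - turns_used/max_turns)  # efficiency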
        composite_rewards.append(observation.score)

    return {
        "prompt_ids": [pid for sub in prompt_ids for pid in sub],  # flatten across turns
        "completion_ids": [cid for sub in completion_ids for cid in sub],
        "logprobs": [lp for sub in logprobs for lp in sub],
        "env_reward": composite_rewards[-1] if composite_rewards else 0.0,
    }


def rollout_func(prompts, trainer=None):
    """Rollout function called by GRPOTrainer; the trainer handle is bound in main()."""
    episode_prompt_ids = []
    episode_completion_ids = []
    episode_logprobs = []
    rewards = []

    for prompt_text in prompts:
        episode = rollout_once(
            trainer=trainer,
            sync_env=sync_env,
            tokenizer=tokenizer,
            dataset_prompt=prompt_text,
            system_prompt=SYSTEM_PROMPT,
            max_turns=3,  # keep turns low for heavy models like 120B
        )
        episode_prompt_ids.append(episode["prompt_ids"])
        episode_completion_ids.append(episode["completion_ids"])
        episode_logprobs.append(episode["logprobs"])
        rewards.append(episode["env_reward"])

    return {
        "prompt_ids": episode_prompt_ids,
        "completion_ids": episode_completion_ids,
        "logprobs": episode_logprobs,
        "env_reward": rewards,
    }


# 6. Reward functions (extra keys returned by rollout_func, like "env_reward",
# are forwarded to reward functions as kwargs)
def composite_env_reward(completions, **kwargs):
    rewards = kwargs.get("env_reward")
    return [float(r) for r in rewards] if rewards else [0.0] * len(completions)


# 7. Create dataset & config
def main():
    print("Preparing dataset...")
    # Dummy prompts to kick off the rollout loop (the actual env state overrides this)
    dataset = Dataset.from_dict({"prompt": ["Fix the buggy Python code."] * 500})

    # Conservative settings for a 120B model: low learning rate, heavy gradient
    # accumulation, gradient checkpointing, and a colocated vLLM engine for generation.
    grpo_config = GRPOConfig(
        num_train_epochs=1,
        learning_rate=1e-6,  # lower LR for 120B
        gradient_accumulation_steps=128,
        per_device_train_batch_size=1,
        warmup_steps=10,
        num_generations=2,
        max_completion_length=512,
        max_prompt_length=1500,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.9,  # maximize for 120B
        output_dir=OUTPUT_DIR,
        logging_steps=1,
        save_steps=50,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )

    print(f"Initializing GRPOTrainer for {MODEL_NAME}...")
    trainer = GRPOTrainer(
        model=MODEL_NAME,
        processing_class=tokenizer,
        reward_funcs=[composite_env_reward],
        train_dataset=dataset,
        args=grpo_config,
        # Bind the trainer handle into the rollout function: the lambda looks up
        # `trainer` at call time, after this assignment has completed, so
        # generate_rollout_completions() receives a live trainer instead of None.
        rollout_func=lambda prompts: rollout_func(prompts, trainer=trainer),
    )

    print("Starting training...")
    trainer.train()

    sync_env.close()
    trainer.save_model(OUTPUT_DIR)
    print("Training complete! Model saved.")


if __name__ == "__main__":
    main()
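
# ---------------------------------------------------------------------------
# Usage sketch. The server module path and script name below are illustrative
# assumptions, not verified against this repo; adjust them to your layout.
#
#   1. Start the environment server in a separate terminal:
#        uvicorn code_debug_env.server:app --host 0.0.0.0 --port 8000
#   2. Launch training (a 120B model needs multiple GPUs even with the
#      colocated vLLM engine):
#        OPENENV_URL=http://127.0.0.1:8000 accelerate launch train_grpo.py
# ---------------------------------------------------------------------------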