#!/usr/bin/env python3
"""
GRPO training script for gpt-oss-120b using OpenEnv and TRL.
Adapted from the openenv-course repository architecture.

Requirements:
    pip install "trl>=0.17.0" openenv-core transformers datasets accelerate vllm
"""
import os
import sys
import json
from pathlib import Path

import torch
from datasets import Dataset
from transformers import AutoTokenizer

# Make the sibling code_debug_env package importable when running this file directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action

# TRL imports
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
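# NOTE: generate_rollout_completions lives under trl.experimental, so its
# import path and signature may shift between TRL releases; if this import
# breaks, pin TRL to the version from the requirements above.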

# 1. Configuration
MODEL_NAME = "openai/gpt-oss-120b"
OUTPUT_DIR = "code-debug-grpo-120b"
ENV_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")

# 2. Setup persistent environment connection
print(f"Connecting to env: {ENV_URL}")
env = CodeDebugEnv(base_url=ENV_URL)
sync_env = env.sync()
sync_env.connect()
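# How this connection is used: CodeDebugEnv talks to an OpenEnv HTTP server
# (presumably started separately at OPENENV_URL), and the blocking .sync()
# wrapper is what the rollout loop below calls reset()/step() on. A single
# persistent connection is reused across all episodes.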

# 3. Setup tokenizer
print(f"Loading tokenizer for {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4. System prompt definition
SYSTEM_PROMPT = """You are an expert Python debugger and RL agent.
Your task is to fix the buggy code provided to you.
Provide ONLY a valid JSON object matching this schema:
{
    "patch": "The FULL python function as a string, with the bugs fixed",
    "task_id": "the task requested",
    "think": "Your chain-of-thought reasoning before patching (important for rewards!)"
}
"""

def make_user_prompt(observation):
    """Render the environment observation into a user message."""
    return (
        f"Task Description: {observation.task_description}\n\n"
        f"Buggy Code:\n```python\n{observation.buggy_code}\n```\n\n"
        f"Passed {observation.passed} out of {observation.total} tests."
    )

# 5. Rollout function
def rollout_once(trainer, sync_env, tokenizer, dataset_prompt, system_prompt, max_turns):
    """Execute one full episode and gather the trajectory in GRPO format.

    `dataset_prompt` is unused: the environment observation, not the dataset
    row, determines what the model actually sees each turn.
    """
    result = sync_env.reset()
    observation = result.observation

    prompt_ids = []
    completion_ids = []
    logprobs = []
    composite_rewards = []

    for _turn in range(max_turns):
        if result.done:
            break

        user_prompt = make_user_prompt(observation)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )

        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.append(rollout_outputs["prompt_ids"])
        completion_ids.append(rollout_outputs["completion_ids"])
        logprobs.append(rollout_outputs["logprobs"])

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True
        )

        # Parse the JSON object the system prompt asked for.
        try:
            # Simple extraction: take the outermost braces.
            start = completion_text.find("{")
            end = completion_text.rfind("}") + 1
            if start != -1 and end > start:
                data = json.loads(completion_text[start:end])
                action = Action(
                    patch=data["patch"],
                    task_id=observation.task_id,
                    think=data.get("think", ""),
                )
            else:
                raise ValueError("No JSON found")
        except (ValueError, KeyError, TypeError):
            # Fallback if parsing fails: resubmit the buggy code unchanged.
            action = Action(patch=observation.buggy_code, task_id=observation.task_id, think="")

        # Step the environment.
        result = sync_env.step(action)
        observation = result.observation

        # The environment already computes the composite reward (0.0 to 1.0):
        # correctness, format, CoT bonus, and efficiency are all baked in.
        composite_rewards.append(observation.score)

    return {
        # Flatten per-turn token lists into one sequence per episode.
        "prompt_ids": [pid for turn in prompt_ids for pid in turn],
        "completion_ids": [cid for turn in completion_ids for cid in turn],
        "logprobs": [lp for turn in logprobs for lp in turn],
        "env_reward": composite_rewards[-1] if composite_rewards else 0.0,
    }
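# Note on the flattening above: concatenating multi-turn ids treats the whole
# episode as one long (prompt, completion) pair for GRPO, and the final turn's
# composite score stands in for the episode reward. With max_turns > 1,
# intermediate prompts are folded into prompt_ids rather than kept per turn.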

def rollout_func(prompts, trainer=None):
    """Rollout function called by GRPOTrainer."""
    episode_prompt_ids = []
    episode_completion_ids = []
    episode_logprobs = []
    rewards = []

    for prompt_text in prompts:
        episode = rollout_once(
            trainer=trainer,
            sync_env=sync_env,
            tokenizer=tokenizer,
            dataset_prompt=prompt_text,
            system_prompt=SYSTEM_PROMPT,
            max_turns=3,  # keep turns low for heavy models like 120B
        )
        episode_prompt_ids.append(episode["prompt_ids"])
        episode_completion_ids.append(episode["completion_ids"])
        episode_logprobs.append(episode["logprobs"])
        rewards.append(episode["env_reward"])

    return {
        "prompt_ids": episode_prompt_ids,
        "completion_ids": episode_completion_ids,
        "logprobs": episode_logprobs,
        "env_reward": rewards,
    }
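# In the experimental TRL rollout API, extra keys returned here beyond
# prompt_ids / completion_ids / logprobs (here: "env_reward") are expected to
# be forwarded to the reward functions via **kwargs, which is what
# composite_env_reward below relies on.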

# 6. Reward functions (mapped from rollout_func keys)
def composite_env_reward(completions, **kwargs):
    rewards = kwargs.get("env_reward")
    return [float(r) for r in rewards] if rewards else [0.0] * len(completions)

# 7. Create dataset & config
def main():
    print("Preparing dataset...")
    # Dummy prompts to kick off the rollout loop (the actual env state overrides this).
    dataset = Dataset.from_dict({"prompt": ["Fix the buggy Python code."] * 500})

    # Settings tuned for a 120B model (gpt-oss-120b ships MXFP4-quantized MoE
    # weights; add tensor parallelism if the hardware allows): low learning
    # rate, heavy gradient accumulation, and colocated vLLM generation.
    grpo_config = GRPOConfig(
        num_train_epochs=1,
        learning_rate=1e-6,  # lower LR for 120B
        gradient_accumulation_steps=128,
        per_device_train_batch_size=1,
        warmup_steps=10,
        num_generations=2,
        max_completion_length=512,
        max_prompt_length=1500,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.9,  # maximize for 120B
        output_dir=OUTPUT_DIR,
        logging_steps=1,
        save_steps=50,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )
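    # Effective batch arithmetic: per_device_train_batch_size (1) x
    # gradient_accumulation_steps (128) = 128 completions per optimizer step
    # per GPU; with num_generations=2 that is roughly 64 distinct prompts,
    # since GRPO compares generations within each group of 2 to form its
    # baseline. Multiply by world size under a multi-GPU launch.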
| print(f"Initializing GRPOTrainer for {MODEL_NAME}...") | |
| trainer = GRPOTrainer( | |
| model=MODEL_NAME, | |
| processing_class=tokenizer, | |
| reward_funcs=[composite_env_reward], | |
| train_dataset=dataset, | |
| args=grpo_config, | |
| rollout_func=rollout_func, | |
| ) | |
| print("Starting training...") | |
| trainer.train() | |
| sync_env.close() | |
| trainer.save_model(OUTPUT_DIR) | |
| print("Training complete! Model saved.") | |
| if __name__ == "__main__": | |
| main() | |
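
# A minimal launch sketch (script name hypothetical), assuming the OpenEnv
# server is already running at OPENENV_URL:
#
#   OPENENV_URL=http://127.0.0.1:8000 accelerate launch train_grpo_120b.py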