# codearena-rl / temp_grpo_check.py
# havinashpatil
# Finalizing CodeArena RL Benchmark: frontend improvements, GRPO training scripts, and cleaned environment
# 03a7eb9
import argparse
import logging
import math
import re
from typing import Any

import httpx
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
# Base URL of the environment HTTP server; the reward function POSTs to
# its /reset and /step endpoints.
ENV_URL = "http://127.0.0.1:7860"
# Model identifier passed to from_pretrained() for both the causal LM and
# its tokenizer.
MODEL_NAME = "distilgpt2"
def _extract_text(completion: Any) -> str:
    """Flatten a completion into a plain string.

    Accepted shapes:

    * ``str`` -- returned unchanged.
    * ``dict`` -- its ``"content"`` value stringified, or ``""`` if absent.
    * ``list`` -- each element is rendered (``"content"`` of a dict that has
      one, otherwise ``str(element)``) and the pieces are joined with
      newlines.
    * anything else -- ``str(completion)``.
    """
    if isinstance(completion, str):
        return completion

    if isinstance(completion, dict):
        return str(completion.get("content", ""))

    if isinstance(completion, list):
        return "\n".join(
            str(part["content"])
            if isinstance(part, dict) and "content" in part
            else str(part)
            for part in completion
        )

    return str(completion)
# A Markdown fenced block whose opening fence ends its line: optional leading
# blanks, three backticks, an optional info string (language tag), a newline,
# then the body up to the next fence or the end of the text (completions can
# be truncated before the closing fence is emitted).
_FENCED_BLOCK_RE = re.compile(
    r"^[ \t]*```[^\n`]*\n(.*?)(?:```|\Z)",
    re.DOTALL | re.MULTILINE,
)


def _clean_fix(text: str) -> str:
    """Extract the code to submit from a raw model completion.

    If the completion contains a fenced Markdown block, only the body of the
    first such block is kept, whatever its language tag and regardless of any
    prose before or after it.  A fence written on a single line (for example
    ``` ```python print(1)``` ```) is handled by stripping the leading and
    trailing fence markers.  Unfenced text is returned stripped.

    Args:
        text: Raw completion text.

    Returns:
        The cleaned code, or ``"pass"`` when nothing is left so that the
        caller always has a non-empty fix to submit.
    """
    text = text.strip()
    fenced = _FENCED_BLOCK_RE.search(text)
    if fenced is not None:
        text = fenced.group(1)
    else:
        # No multi-line fenced block: strip fence markers at the very
        # start/end of the text, if any.
        text = re.sub(r"^```(?:python)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)
    return text.strip() or "pass"
_logger = logging.getLogger(__name__)

# Rewards are kept strictly inside (0, 1); the floor is also the fallback
# used when the environment cannot be reached or returns no reward.
_REWARD_FLOOR = 0.001
_REWARD_CEIL = 0.999
# Number of reset+step attempts made per completion before giving up.
_MAX_ATTEMPTS = 2
_TASK_ID = "easy-1"


def _clamp_reward(value: float) -> float:
    """Clamp *value* into [_REWARD_FLOOR, _REWARD_CEIL]; NaN maps to the floor."""
    # min()/max() comparisons with NaN are always False, so an unchecked NaN
    # would come out of the clamp as the ceiling.
    if math.isnan(value):
        return _REWARD_FLOOR
    return max(_REWARD_FLOOR, min(_REWARD_CEIL, value))


def _score_fix(client, proposed_fix: str) -> float:
    """Reset the environment, submit one fix, and return the raw reward."""
    reset_response = client.post(f"{ENV_URL}/reset", json={"task_id": _TASK_ID})
    # A failed reset must not be followed by a step against stale state.
    reset_response.raise_for_status()
    step_response = client.post(
        f"{ENV_URL}/step",
        json={"proposed_fix": proposed_fix},
    )
    step_response.raise_for_status()
    return float(step_response.json().get("reward", _REWARD_FLOOR))


def codearena_reward_func(completions, prompts, **kwargs):
    """Score completions by submitting them to the environment server.

    For each completion the text is extracted and cleaned, the environment is
    reset to the fixed task, and the fix is posted to ``/step``.  A failing
    attempt is logged and retried once; if every attempt fails the completion
    receives the floor reward.

    Args:
        completions: Completions to score, in any shape accepted by
            ``_extract_text``.
        prompts: Accepted for compatibility with the reward-function
            signature; not used.
        **kwargs: Ignored.

    Returns:
        One float per completion, clamped to ``[0.001, 0.999]``.
    """
    if not completions:
        # Nothing to score: avoid opening an HTTP client for no work.
        return []

    rewards = []
    with httpx.Client(timeout=60.0) as client:
        for completion in completions:
            proposed_fix = _clean_fix(_extract_text(completion))
            reward = _REWARD_FLOOR
            for attempt in range(1, _MAX_ATTEMPTS + 1):
                try:
                    reward = _score_fix(client, proposed_fix)
                except Exception:
                    # Best-effort boundary: a reward function must not crash
                    # training, but failures should be visible in the logs.
                    _logger.warning(
                        "CodeArena reward request failed (attempt %d/%d)",
                        attempt,
                        _MAX_ATTEMPTS,
                        exc_info=True,
                    )
                else:
                    break
            rewards.append(_clamp_reward(reward))
    return rewards
def main():
    """Run a very short GRPO training loop as an end-to-end sanity check.

    Command-line options:
        --max-steps: number of optimizer steps to run (default 3).
        --output-dir: directory handed to the trainer for its outputs
            (default ``./grpo-check-output``).

    Builds a six-prompt in-memory dataset, loads ``MODEL_NAME`` and its
    tokenizer, and trains with ``codearena_reward_func`` as the only reward.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-steps", type=int, default=3)
    parser.add_argument("--output-dir", type=str, default="./grpo-check-output")
    args = parser.parse_args()

    # NOTE(review): the first prompt contains literal backslash-n sequences
    # ("\\n"), not real newlines -- confirm this is intended.
    prompts = [
        "Fix this Python function: def average_list(numbers)\\n if length(numbers) == 0:\\n return 0\\n return sum(numbers) / length(numbers)",
        "Repair all root-cause issues in the function and keep readability high.",
        "Return a corrected Python function only. Ensure tests pass.",
        "Fix missing syntax and replace invalid APIs with valid Python APIs.",
        "Correct both compile and semantic issues in the provided function.",
        "Provide a secure, clean fix for average_list in Python.",
    ]
    train_dataset = Dataset.from_dict({"prompt": prompts})

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    # Left padding keeps prompts adjacent to the generated tokens, which is
    # what batched generation with a decoder-only model needs.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=1e-5,
        max_steps=args.max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        logging_steps=1,
        num_generations=2,
        max_prompt_length=256,
        max_completion_length=96,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        shuffle_dataset=False,
        seed=42,
        bf16=False,
        fp16=False,
        report_to=[],
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=codearena_reward_func,
        args=training_args,
        train_dataset=train_dataset,
        # Hand over the tokenizer configured above; previously it was built
        # and then never used.
        processing_class=tokenizer,
    )
    trainer.train()
    print("GRPO check finished.")
# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()