Spaces:
Sleeping
Sleeping
havinashpatil
Finalizing CodeArena RL Benchmark: frontend improvements, GRPO training scripts, and cleaned environment
03a7eb9 | import re | |
| import argparse | |
| from typing import Any | |
| import httpx | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import GRPOConfig, GRPOTrainer | |
# Base URL of the CodeArena environment server that the reward function posts
# to (its /reset and /step endpoints).
ENV_URL = "http://127.0.0.1:7860"

# Hugging Face model identifier loaded for both the policy model and tokenizer.
MODEL_NAME = "distilgpt2"
| def _extract_text(completion: Any) -> str: | |
| if isinstance(completion, str): | |
| return completion | |
| if isinstance(completion, list): | |
| chunks = [] | |
| for item in completion: | |
| if isinstance(item, dict) and "content" in item: | |
| chunks.append(str(item["content"])) | |
| else: | |
| chunks.append(str(item)) | |
| return "\n".join(chunks) | |
| if isinstance(completion, dict): | |
| return str(completion.get("content", "")) | |
| return str(completion) | |
| def _clean_fix(text: str) -> str: | |
| text = text.strip() | |
| text = re.sub(r"^```(?:python)?\s*", "", text) | |
| text = re.sub(r"\s*```$", "", text) | |
| return text.strip() or "pass" | |
# Rewards are clamped to this range; the lower bound is also the fallback
# reward used when the environment cannot be queried.
_MIN_REWARD = 0.001
_MAX_REWARD = 0.999


def _request_reward(client, proposed_fix, attempts=2):
    """Reset the environment, submit one fix, and return the raw reward.

    Each attempt posts to ``/reset`` with task ``"easy-1"`` and then posts
    the fix to ``/step``. Any failure is logged and the attempt is repeated,
    up to ``attempts`` times in total.

    Args:
        client: An open ``httpx.Client`` used for both requests.
        proposed_fix: The code to send as ``proposed_fix`` to ``/step``.
        attempts: Maximum number of reset/step round trips to try.

    Returns:
        The ``"reward"`` field of the ``/step`` JSON response as a float
        (``_MIN_REWARD`` when the field is absent), or ``_MIN_REWARD`` when
        every attempt failed.
    """
    import logging  # local import keeps this block self-contained

    log = logging.getLogger(__name__)
    for attempt in range(1, attempts + 1):
        try:
            reset = client.post(f"{ENV_URL}/reset", json={"task_id": "easy-1"})
            if reset.is_error:
                # Not fatal: the step is still attempted, but a failing reset
                # is worth surfacing because it may leave stale task state.
                log.warning(
                    "CodeArena /reset returned HTTP %s", reset.status_code
                )
            step = client.post(
                f"{ENV_URL}/step",
                json={"proposed_fix": proposed_fix},
            )
            # Treat HTTP error statuses as failures so they are retried
            # instead of being scored from an error payload.
            step.raise_for_status()
            return float(step.json().get("reward", _MIN_REWARD))
        except Exception as exc:
            # Deliberately broad: a failing environment must not abort a
            # training run, so the completion falls back to the floor reward.
            log.warning(
                "CodeArena reward request failed (attempt %d/%d): %s",
                attempt,
                attempts,
                exc,
            )
    return _MIN_REWARD


def codearena_reward_func(completions, prompts, **kwargs):
    """Score each completion by submitting it to the CodeArena environment.

    Args:
        completions: Completions to score; each is converted to text with
            ``_extract_text`` and unwrapped with ``_clean_fix``.
        prompts: Accepted for the reward-function call signature; unused.
        **kwargs: Accepted for the reward-function call signature; unused.

    Returns:
        One reward per completion, in order, each clamped to
        ``[0.001, 0.999]``. Completions whose request fails get ``0.001``.
    """
    import math  # local import keeps this block self-contained

    rewards = []
    with httpx.Client(timeout=60.0) as client:
        for completion in completions:
            proposed_fix = _clean_fix(_extract_text(completion))
            reward = _request_reward(client, proposed_fix)
            # min()/max() let NaN slip through as the *maximum* reward
            # (every comparison with NaN is False), so reject it explicitly.
            if math.isnan(reward):
                reward = _MIN_REWARD
            rewards.append(max(_MIN_REWARD, min(_MAX_REWARD, reward)))
    return rewards
def main():
    """Run a short GRPO training check against the CodeArena reward function.

    Parses ``--max-steps`` and ``--output-dir``, builds a six-prompt dataset,
    loads ``MODEL_NAME`` with its tokenizer, and trains with ``GRPOTrainer``
    using ``codearena_reward_func`` as the only reward. Prints
    ``"GRPO check finished."`` when training returns.
    """
    parser = argparse.ArgumentParser(
        description="Short GRPO training check against the CodeArena environment.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=3,
        help="number of training steps to run",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./grpo-check-output",
        help="directory passed to GRPOConfig as output_dir",
    )
    args = parser.parse_args()

    # NOTE(review): the "\\n" sequences in the first prompt are literal
    # backslash-n characters, not newlines - confirm this is intended.
    prompts = [
        "Fix this Python function: def average_list(numbers)\\n if length(numbers) == 0:\\n return 0\\n return sum(numbers) / length(numbers)",
        "Repair all root-cause issues in the function and keep readability high.",
        "Return a corrected Python function only. Ensure tests pass.",
        "Fix missing syntax and replace invalid APIs with valid Python APIs.",
        "Correct both compile and semantic issues in the provided function.",
        "Provide a secure, clean fix for average_list in Python.",
    ]
    train_dataset = Dataset.from_dict({"prompt": prompts})

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    # Left padding keeps each prompt adjacent to its generated completion.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=1e-5,
        max_steps=args.max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        logging_steps=1,
        num_generations=2,
        max_prompt_length=256,
        max_completion_length=96,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        shuffle_dataset=False,
        seed=42,
        bf16=False,
        fp16=False,
        report_to=[],
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=codearena_reward_func,
        args=training_args,
        train_dataset=train_dataset,
        # Hand over the tokenizer configured above; previously it was built
        # and then discarded, so its pad-token setup never reached the trainer.
        processing_class=tokenizer,
    )
    trainer.train()
    print("GRPO check finished.")
# Run the training check only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()