Spaces:
Sleeping
Sleeping
havinashpatil
Finalizing CodeArena RL Benchmark: frontend improvements, GRPO training scripts, and cleaned environment
03a7eb9 | import argparse | |
| import csv | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| import httpx | |
# System-role instruction given to the model: it tells the model to act as a
# code repair agent and to answer with raw Python only.
SYSTEM_PROMPT = " ".join(
    [
        "You are an expert Python code repair agent.",
        "Fix the buggy Python code and return ONLY raw Python code.",
    ]
)
# Language tags (lower-cased) that may follow an opening Markdown code fence.
_FENCE_LANG_TAGS = frozenset({"python", "python3", "py", "py3"})


def clean_code(text: str) -> str:
    """Strip a surrounding Markdown code fence from a model completion.

    An opening fence at the very start of the text is removed together with a
    Python language tag on the same line (``python``, ``python3``, ``py``,
    ``py3``; matched case-insensitively).  A closing fence at the very end is
    removed as well.  Text without fences is returned stripped of surrounding
    whitespace.

    Args:
        text: Raw completion text; ``None`` or an empty string is accepted.

    Returns:
        The completion with fence markers and surrounding whitespace removed.
    """
    text = (text or "").strip()
    if text.startswith("```"):
        text = text[3:]
        first_line, newline, remainder = text.partition("\n")
        if newline and first_line.strip().lower() in _FENCE_LANG_TAGS:
            # The whole fence line is just a language tag: drop that line.
            text = remainder
        elif text.startswith("python"):
            # Tag glued to the code on one line (e.g. "```python x = 1```").
            text = text[len("python"):]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()
def ollama_generate(client: httpx.Client, model: str, prompt: str, base_url: str) -> str:
    """Request a code fix from an Ollama server and return the cleaned code.

    ``/api/generate`` is tried first.  If it produces no code (including a
    404/405 response), ``/api/chat`` is used as a fallback.  Both requests
    carry ``SYSTEM_PROMPT`` and cap the completion at 512 tokens.

    Args:
        client: HTTP client used to send both requests.
        model: Name of the Ollama model to query.
        prompt: User prompt describing the code to repair.
        base_url: Root URL of the Ollama server; a trailing slash is tolerated.

    Returns:
        The model output after ``clean_code`` has removed code fences.

    Raises:
        httpx.HTTPStatusError: If a request returns an error status other
            than the 404/405 tolerated on ``/api/generate``.
        RuntimeError: If neither endpoint returned any code.
    """
    root = base_url.rstrip("/")

    def try_chat() -> str:
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "stream": False,
            "options": {
                "temperature": 0.2,
                # Ollama's completion-length option is "num_predict";
                # "max_tokens" is not an Ollama option.
                "num_predict": 512,
                "top_p": 0.9,
            },
        }
        resp = client.post(f"{root}/api/chat", json=payload, timeout=90.0)
        resp.raise_for_status()
        data = resp.json()
        # "message" may be missing or null; treat both as an empty reply.
        message = data.get("message") or {}
        return clean_code(message.get("content", ""))

    def try_generate() -> str:
        payload = {
            "model": model,
            # Send the same system instruction the chat endpoint receives.
            "system": SYSTEM_PROMPT,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.2,
                "num_predict": 512,
            },
        }
        resp = client.post(f"{root}/api/generate", json=payload, timeout=90.0)
        if resp.status_code in (404, 405):
            # Endpoint unavailable on this server: let the caller fall back.
            return ""
        resp.raise_for_status()
        data = resp.json()
        return clean_code(data.get("response", "") or data.get("text", ""))

    code = try_generate() or try_chat()
    if not code:
        raise RuntimeError("Ollama returned no valid code from /api/generate or /api/chat.")
    return code
def run_episode(env_client: httpx.Client, ollama_client: httpx.Client, model: str, task_id: str, max_steps: int, env_url: str, ollama_url: str):
    """Run one repair episode against the environment and collect its steps.

    The environment is reset for ``task_id``.  Then, for up to ``max_steps``
    steps, the current observation is turned into a prompt, a fix is requested
    from Ollama, and the fix is submitted to ``/step``.  The loop ends early
    once the environment reports ``done``.

    Args:
        env_client: HTTP client used for the environment's ``/reset`` and
            ``/step`` endpoints.
        ollama_client: HTTP client passed to ``ollama_generate``.
        model: Ollama model name.
        task_id: Task identifier sent to ``/reset``.
        max_steps: Maximum number of ``/step`` calls for this episode.
        env_url: Root URL of the environment server.
        ollama_url: Root URL of the Ollama server.

    Returns:
        A dict with ``episode_reward_mean``, ``episode_reward_best``,
        ``episode_reward_last`` (each ``0.001`` when no step ran) and
        ``steps``, the list of per-step records.

    Raises:
        httpx.HTTPStatusError: If ``/reset`` or ``/step`` returns an error
            status.
    """
    # Function-scope import so the module's import block stays untouched.
    import logging

    logger = logging.getLogger(__name__)

    reset = env_client.post(f"{env_url}/reset", json={"task_id": task_id}, timeout=60.0)
    reset.raise_for_status()
    obs_json = reset.json()

    steps = []
    rewards = []
    for step in range(1, max_steps + 1):
        # A missing or null "observation" is treated as an empty observation.
        obs = obs_json.get("observation") or {}
        buggy_code = obs.get("buggy_code", "")
        error_log = obs.get("error_log", "")
        test_results = obs.get("test_results", "")
        user_prompt = (
            f"Fix this buggy Python code:\n\n{buggy_code}\n\n"
            f"Error log:\n{error_log}\n\n"
            f"Test results:\n{test_results}\n"
        )

        try:
            proposed_fix = ollama_generate(ollama_client, model, user_prompt, ollama_url)
        except Exception as exc:
            # Best-effort fallback: resubmit the buggy code so the episode can
            # continue, but log the failure instead of hiding it.
            logger.warning(
                "Generation failed for task %r at step %d (%s: %s); "
                "submitting the unmodified code instead.",
                task_id,
                step,
                type(exc).__name__,
                exc,
            )
            proposed_fix = buggy_code or "pass"

        step_resp = env_client.post(
            f"{env_url}/step",
            json={"proposed_fix": proposed_fix},
            timeout=90.0,
        )
        step_resp.raise_for_status()
        step_data = step_resp.json()

        # A missing or null reward falls back to the 0.001 floor; every
        # reward is clamped to the range [0.001, 0.999].
        raw_reward = step_data.get("reward")
        reward = 0.001 if raw_reward is None else float(raw_reward)
        reward = max(0.001, min(0.999, reward))
        done = bool(step_data.get("done", False))
        info = step_data.get("info") or {}

        steps.append(
            {
                "step": step,
                "prompt": user_prompt,
                "proposed_fix": proposed_fix,
                "reward": reward,
                "done": done,
                "task_id": info.get("task_id", task_id),
                "reward_components": info.get("reward_components", {}),
            }
        )
        rewards.append(reward)
        # The step response supplies the observation for the next prompt.
        obs_json = step_data
        if done:
            break

    return {
        "episode_reward_mean": sum(rewards) / len(rewards) if rewards else 0.001,
        "episode_reward_best": max(rewards) if rewards else 0.001,
        "episode_reward_last": rewards[-1] if rewards else 0.001,
        "steps": steps,
    }
def main():
    """Run the configured number of episodes and write trajectory/summary files.

    Command-line options select the Ollama model, the Ollama and environment
    URLs, the number of episodes, the per-episode step limit and the output
    directory.  Episodes cycle through a fixed list of task ids.

    Outputs, both timestamped and placed in the output directory:
        * ``trajectories_<ts>.jsonl`` - one JSON object per episode, written
          as soon as the episode finishes.
        * ``summary_<ts>.csv`` - mean/best/last reward per episode, written
          even if a later episode raised.

    Aggregate reward statistics are printed when at least one episode ran.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="llama3.2:latest")
    parser.add_argument("--ollama-url", default="http://127.0.0.1:11434")
    parser.add_argument("--env-url", default="http://127.0.0.1:7860")
    parser.add_argument("--episodes", type=int, default=30)
    parser.add_argument("--max-steps", type=int, default=5)
    parser.add_argument("--output-dir", default="ollama_rl_out")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    traj_path = out_dir / f"trajectories_{ts}.jsonl"
    summary_path = out_dir / f"summary_{ts}.csv"

    tasks = ["easy", "medium", "hard", "type_errors-1", "security_bugs-1"]
    episodes = []

    try:
        with traj_path.open("w", encoding="utf-8") as traj_file:
            with httpx.Client() as env_client, httpx.Client() as ollama_client:
                for idx in range(args.episodes):
                    task = tasks[idx % len(tasks)]
                    ep = run_episode(
                        env_client,
                        ollama_client,
                        args.model,
                        task,
                        args.max_steps,
                        args.env_url,
                        args.ollama_url,
                    )
                    ep["episode_idx"] = idx + 1
                    ep["task_seed"] = task
                    episodes.append(ep)
                    # Persist each episode right away so a failure in a later
                    # episode does not discard what was already collected.
                    traj_file.write(json.dumps(ep, ensure_ascii=True) + "\n")
                    traj_file.flush()
    finally:
        # Summarise whatever completed, even when an episode raised.
        with summary_path.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["episode", "task_seed", "mean_reward", "best_reward", "last_reward"])
            for ep in episodes:
                writer.writerow(
                    [
                        ep["episode_idx"],
                        ep["task_seed"],
                        ep["episode_reward_mean"],
                        ep["episode_reward_best"],
                        ep["episode_reward_last"],
                    ]
                )

    all_mean = [e["episode_reward_mean"] for e in episodes]
    print(f"episodes={len(episodes)}")
    if all_mean:
        # Skipped when no episode ran (e.g. --episodes 0): indexing and
        # averaging an empty list would raise.
        print(f"start_mean_reward={all_mean[0]:.4f}")
        print(f"end_mean_reward={all_mean[-1]:.4f}")
        print(f"best_mean_reward={max(all_mean):.4f}")
        print(f"avg_mean_reward={(sum(all_mean)/len(all_mean)):.4f}")
    print(f"trajectories={traj_path}")
    print(f"summary={summary_path}")
# Entry point: run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()