# codearena-rl / ollama_rl_rollout.py
# havinashpatil
# Finalizing CodeArena RL Benchmark: frontend improvements, GRPO training scripts, and cleaned environment
# 03a7eb9
import argparse
import csv
import json
import re
from datetime import datetime
from pathlib import Path

import httpx
# System instruction for the code-repair model: fix the given code and answer
# with bare Python source only.
SYSTEM_PROMPT = "".join(
    [
        "You are an expert Python code repair agent. ",
        "Fix the buggy Python code and return ONLY raw Python code.",
    ]
)
# First Markdown fenced block whose closing fence starts its own line, e.g.
# "```python\n<code>\n```". Group 1 is the code between the fences.
_FENCED_BLOCK_RE = re.compile(r"```[\w+.-]*[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL)
# A fence line that carries nothing but an (optional) language tag.
_FENCE_TAG_RE = re.compile(r"[\w+.-]*")


def clean_code(text: str) -> str:
    """Extract bare Python source from a model response.

    Handles the usual ways chat models wrap code despite being told not to:

    * a fenced block surrounded by prose ("Here is the fix:\\n```python ...```");
    * an opening fence with any language tag (```python, ```py, ```Python3, bare ```);
    * an unterminated opening fence (e.g. output cut off by the token limit);
    * a single-line reply such as "```python print(1)```".

    Args:
        text: Raw model output; ``None`` or empty is treated as "".

    Returns:
        The code with fences and surrounding whitespace removed. Text without
        fences is returned stripped but otherwise unchanged.
    """
    text = (text or "").strip()

    # Prefer an explicit fenced block anywhere in the reply so that leading or
    # trailing commentary never reaches the environment as "code".
    match = _FENCED_BLOCK_RE.search(text)
    if match:
        return match.group(1).strip()

    if text.startswith("```"):
        text = text[3:]
        fence_line, newline, rest = text.partition("\n")
        if newline and _FENCE_TAG_RE.fullmatch(fence_line.strip()):
            # The opening fence line holds only a language tag: drop the whole line
            # (the original code left e.g. "py" behind as the first line of code).
            text = rest
        elif text.startswith("python"):
            # Single-line reply: the tag is directly followed by code.
            text = text[len("python"):]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()
def ollama_generate(
    client: httpx.Client,
    model: str,
    prompt: str,
    base_url: str,
    *,
    timeout: float = 90.0,
) -> str:
    """Ask an Ollama server for a code fix and return the cleaned code.

    Tries ``/api/generate`` first. Falls back to ``/api/chat`` when generate
    answers HTTP 404/405 or yields no code. Both requests carry SYSTEM_PROMPT and
    the same sampling options.

    Args:
        client: HTTP client used for both requests.
        model: Ollama model name, e.g. ``"llama3.2:latest"``.
        prompt: User prompt describing the buggy code and its failures.
        base_url: Ollama server root, with or without a trailing slash.
        timeout: Per-request timeout in seconds.

    Returns:
        Non-empty model output passed through ``clean_code``.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (other than a 404/405 from
            /api/generate, which triggers the chat fallback).
        RuntimeError: If neither endpoint produced any code.
    """
    root = base_url.rstrip("/")
    # Ollama takes the output-length cap from "num_predict". The previous
    # "max_tokens" key is not an Ollama option, so chat replies were uncapped.
    options = {"temperature": 0.2, "top_p": 0.9, "num_predict": 512}

    def try_generate() -> str:
        payload = {
            "model": model,
            # Without this the system instruction only reached the model on the
            # chat fallback, i.e. almost never.
            "system": SYSTEM_PROMPT,
            "prompt": prompt,
            "stream": False,
            "options": options,
        }
        resp = client.post(f"{root}/api/generate", json=payload, timeout=timeout)
        if resp.status_code in (404, 405):
            # Treated as "generate unavailable": let the caller try /api/chat.
            return ""
        resp.raise_for_status()
        data = resp.json()
        return clean_code(data.get("response") or data.get("text") or "")

    def try_chat() -> str:
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "stream": False,
            "options": options,
        }
        resp = client.post(f"{root}/api/chat", json=payload, timeout=timeout)
        resp.raise_for_status()
        data = resp.json()
        # Guard against an explicit null "message" as well as a missing one.
        message = data.get("message") or {}
        return clean_code(message.get("content") or "")

    code = try_generate() or try_chat()
    if not code:
        raise RuntimeError("Ollama returned no valid code from /api/generate or /api/chat.")
    return code
def run_episode(
    env_client: httpx.Client,
    ollama_client: httpx.Client,
    model: str,
    task_id: str,
    max_steps: int,
    env_url: str,
    ollama_url: str,
):
    """Roll out one repair episode against the CodeArena environment.

    Resets the environment to ``task_id``. Then, for at most ``max_steps`` turns,
    it does the following:

    * builds a prompt from the current observation's buggy_code, error_log and
      test_results;
    * asks Ollama for a fix;
    * submits the fix to ``/step`` and records the reward, clamped to
      [0.001, 0.999].

    Stops early once the environment reports ``done``.

    If generation fails, the unchanged buggy code (or ``"pass"``) is submitted so
    the episode still advances. The failure is recorded in the step's
    ``generation_error`` field.

    Returns:
        Dict with ``episode_reward_mean``, ``episode_reward_best`` and
        ``episode_reward_last`` (each 0.001 when no step ran) plus ``steps``, the
        list of per-step records.

    Raises:
        httpx.HTTPStatusError: If the environment's /reset or /step call fails.
    """
    reset = env_client.post(f"{env_url}/reset", json={"task_id": task_id}, timeout=60.0)
    reset.raise_for_status()
    obs_json = reset.json()

    steps = []
    rewards = []
    for step in range(1, max_steps + 1):
        obs = obs_json.get("observation") or {}
        buggy_code = obs.get("buggy_code", "")
        error_log = obs.get("error_log", "")
        test_results = obs.get("test_results", "")
        user_prompt = (
            f"Fix this buggy Python code:\n\n{buggy_code}\n\n"
            f"Error log:\n{error_log}\n\n"
            f"Test results:\n{test_results}\n"
        )

        generation_error = None
        try:
            proposed_fix = ollama_generate(ollama_client, model, user_prompt, ollama_url)
        except Exception as exc:  # deliberate best-effort: never abort the rollout here
            # Keep the reason instead of swallowing it, so failed generations are
            # distinguishable from genuine "no change" answers in the trajectory.
            generation_error = f"{type(exc).__name__}: {exc}"
            proposed_fix = buggy_code or "pass"

        step_resp = env_client.post(
            f"{env_url}/step",
            json={"proposed_fix": proposed_fix},
            timeout=90.0,
        )
        step_resp.raise_for_status()
        step_data = step_resp.json()
        info = step_data.get("info") or {}

        # A missing or null reward counts as the floor value.
        raw_reward = step_data.get("reward")
        reward = float(raw_reward) if raw_reward is not None else 0.001
        reward = max(0.001, min(0.999, reward))
        done = bool(step_data.get("done", False))

        steps.append(
            {
                "step": step,
                "prompt": user_prompt,
                "proposed_fix": proposed_fix,
                "reward": reward,
                "done": done,
                "task_id": info.get("task_id", task_id),
                "reward_components": info.get("reward_components", {}),
                "generation_error": generation_error,
            }
        )
        rewards.append(reward)
        obs_json = step_data
        if done:
            break

    return {
        "episode_reward_mean": sum(rewards) / len(rewards) if rewards else 0.001,
        "episode_reward_best": max(rewards) if rewards else 0.001,
        "episode_reward_last": rewards[-1] if rewards else 0.001,
        "steps": steps,
    }
def _parse_args():
    """Build and parse the command line, rejecting run sizes below one."""
    parser = argparse.ArgumentParser(
        description="Roll out an Ollama model against the CodeArena environment."
    )
    parser.add_argument("--model", default="llama3.2:latest")
    parser.add_argument("--ollama-url", default="http://127.0.0.1:11434")
    parser.add_argument("--env-url", default="http://127.0.0.1:7860")
    parser.add_argument("--episodes", type=int, default=30)
    parser.add_argument("--max-steps", type=int, default=5)
    parser.add_argument("--output-dir", default="ollama_rl_out")
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["easy", "medium", "hard", "type_errors-1", "security_bugs-1"],
        help="Task ids cycled through in order, one per episode.",
    )
    args = parser.parse_args()
    # Zero episodes previously crashed at the end with IndexError on an empty list.
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")
    if args.max_steps < 1:
        parser.error("--max-steps must be at least 1")
    return args


def _print_report(mean_rewards, traj_path, summary_path):
    """Print headline statistics over the per-episode mean rewards."""
    print(f"episodes={len(mean_rewards)}")
    if mean_rewards:
        print(f"start_mean_reward={mean_rewards[0]:.4f}")
        print(f"end_mean_reward={mean_rewards[-1]:.4f}")
        print(f"best_mean_reward={max(mean_rewards):.4f}")
        print(f"avg_mean_reward={(sum(mean_rewards) / len(mean_rewards)):.4f}")
    print(f"trajectories={traj_path}")
    print(f"summary={summary_path}")


def main():
    """Run the rollouts and save trajectories (JSONL) plus a per-episode CSV summary.

    Episodes cycle through ``--tasks``. Both output files are written and flushed
    after every episode, so an interrupted run keeps all finished episodes.
    Previously nothing was written until the last episode finished.
    """
    args = _parse_args()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    traj_path = out_dir / f"trajectories_{ts}.jsonl"
    summary_path = out_dir / f"summary_{ts}.csv"

    mean_rewards = []
    with traj_path.open("w", encoding="utf-8") as traj_file, summary_path.open(
        "w", newline="", encoding="utf-8"
    ) as summary_file:
        writer = csv.writer(summary_file)
        writer.writerow(["episode", "task_seed", "mean_reward", "best_reward", "last_reward"])
        with httpx.Client() as env_client, httpx.Client() as ollama_client:
            for idx in range(args.episodes):
                task = args.tasks[idx % len(args.tasks)]
                ep = run_episode(
                    env_client,
                    ollama_client,
                    args.model,
                    task,
                    args.max_steps,
                    args.env_url,
                    args.ollama_url,
                )
                ep["episode_idx"] = idx + 1
                ep["task_seed"] = task

                traj_file.write(json.dumps(ep, ensure_ascii=True) + "\n")
                writer.writerow(
                    [
                        ep["episode_idx"],
                        ep["task_seed"],
                        ep["episode_reward_mean"],
                        ep["episode_reward_best"],
                        ep["episode_reward_last"],
                    ]
                )
                # Flush so a crash or Ctrl-C later in the run keeps this episode.
                traj_file.flush()
                summary_file.flush()
                mean_rewards.append(ep["episode_reward_mean"])

    _print_report(mean_rewards, traj_path, summary_path)


if __name__ == "__main__":
    main()