#!/usr/bin/env python3
"""
Improved CodeArena RL Agent with better prompting and debugging strategy.
"""
import os
import re
import time
from typing import Dict, List, Optional, Tuple

import requests
class CodeArenaAgent:
    """LLM-backed agent that proposes fixes for CodeArena debugging tasks.

    Fixes come from a local Ollama server (``backend="ollama"``) or, for any
    other backend name, from an OpenAI-compatible chat endpoint via the
    ``openai`` client. When no usable model reply is available, a small
    rule-based repair (``_fallback_fix``) is returned instead.
    """

    # Per-step guidance injected into the prompt; later steps get a generic line.
    STEP_INSTRUCTIONS = {
        1: "Focus on fixing syntax errors and basic compilation issues first.",
        2: "Now address logic errors and test failures from the previous attempt.",
        3: "Optimize the solution and ensure all edge cases are handled.",
        4: "Final attempt: ensure the solution is robust and handles all test cases.",
        5: "Last chance: fix any remaining issues with a completely different approach.",
    }

    # First fenced block in a reply, e.g. "```python\n...\n```" (language tag optional).
    _FENCED_BLOCK = re.compile(r"```[\w+-]*[^\S\n]*\n(.*?)```", re.DOTALL)
    # Bare ``length(`` calls only; ``obj.length(`` and ``max_length(`` are not matched.
    _LENGTH_CALL = re.compile(r"(?<![\w.])length\(")

    def __init__(self, backend: str = "ollama", model: str = "llama3.2:latest", *,
                 api_key: Optional[str] = None, base_url: Optional[str] = None,
                 timeout: float = 30.0, temperature: float = 0.3,
                 max_tokens: int = 1000):
        """Configure which model answers ``generate_fix``.

        Args:
            backend: ``"ollama"`` for the local Ollama HTTP API; any other
                value is served through the OpenAI-compatible client.
            model: Model name sent to the backend.
            api_key: Key for the OpenAI-compatible client. Ollama ignores it;
                without it, non-Ollama backends go straight to the fallback.
            base_url: Endpoint override for the OpenAI-compatible client;
                ``None`` leaves the client's own default in place.
            timeout: Request timeout in seconds, applied to both backends.
            temperature: Sampling temperature for both backends.
            max_tokens: Generation cap (Ollama ``num_predict`` / OpenAI
                ``max_tokens``).
        """
        self.backend = backend
        self.model = model
        self.api_base = "http://localhost:11434"  # Ollama HTTP API root
        self.api_key = api_key
        self.base_url = base_url  # read by the OpenAI-compatible branch
        self.timeout = timeout
        self.temperature = temperature
        self.max_tokens = max_tokens

    def generate_fix(self, buggy_code: str, error_log: str, test_results: str,
                     previous_attempts: List[str], step_count: int) -> str:
        """Ask the configured model for a corrected version of ``buggy_code``.

        Args:
            buggy_code: Source code to repair.
            error_log: Current error output shown to the model.
            test_results: Current test report shown to the model.
            previous_attempts: Earlier proposed fixes; the last two are quoted
                (truncated to 100 chars) so the model avoids repeating them.
            step_count: 1-based step number; selects the step instruction.

        Returns:
            The proposed code with any markdown fence removed. Falls back to
            ``_fallback_fix`` when there is no API key for a non-Ollama
            backend, when the request fails, or when the reply is empty.
        """
        prompt = self._build_prompt(buggy_code, error_log, test_results,
                                    previous_attempts, step_count)

        if self.backend != "ollama" and not self.api_key:
            # The OpenAI-compatible branch needs a key; skip the doomed request.
            return self._fallback_fix(buggy_code, step_count)

        try:
            if self.backend == "ollama":
                raw = self._query_ollama(prompt)
            else:
                raw = self._query_openai(prompt)
        except Exception as e:  # best-effort: any backend failure -> rule-based fix
            print(f"API Error: {e}")
            return self._fallback_fix(buggy_code, step_count)

        fix = self._strip_markdown_fences(raw)
        if not fix:
            print("API Error: empty response from model")
            return self._fallback_fix(buggy_code, step_count)
        return fix

    def _build_prompt(self, buggy_code: str, error_log: str, test_results: str,
                      previous_attempts: List[str], step_count: int) -> str:
        """Assemble the debugging prompt, quoting the last two failed attempts."""
        context = ""
        if previous_attempts:
            recent = previous_attempts[-2:]
            # Number attempts by their position in the full history.
            first_number = len(previous_attempts) - len(recent) + 1
            context = "\nPrevious attempts that failed:\n" + "".join(
                f"Attempt {number}: {attempt[:100]}...\n"
                for number, attempt in enumerate(recent, start=first_number)
            )
        instruction = self.STEP_INSTRUCTIONS.get(step_count, "Fix all remaining issues.")
        return "\n".join([
            "You are an expert Python debugger. Fix the buggy code below.",
            "BUGGY CODE:",
            str(buggy_code),
            "CURRENT ERRORS:",
            str(error_log),
            "TEST RESULTS:",
            str(test_results),
            f"STEP {step_count} INSTRUCTIONS:",
            instruction,
            context,
            "REQUIREMENTS:",
            "1. The code must compile without syntax errors",
            "2. All tests must pass",
            "3. Fix the ROOT CAUSE, not just symptoms",
            "4. Do NOT repeat previous failed approaches",
            "5. Ensure proper Python syntax and indentation",
            "6. Return ONLY the corrected code, no explanations",
            "Output the complete corrected Python code:",
        ])

    def _query_ollama(self, prompt: str) -> str:
        """POST ``prompt`` to Ollama's ``/api/generate`` and return the reply text."""
        response = requests.post(
            f"{self.api_base}/api/generate",
            json={
                "model": self.model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.temperature,
                    "num_predict": self.max_tokens,
                },
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        # A missing or null "response" field is treated as an empty reply.
        return response.json().get("response") or ""

    def _query_openai(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message to an OpenAI-compatible endpoint."""
        import openai  # optional dependency, only needed for this backend

        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url,
                               timeout=self.timeout)
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )
        # message.content may be None (e.g. refusals); normalise to "".
        return response.choices[0].message.content or ""

    @staticmethod
    def _strip_markdown_fences(text: str) -> str:
        """Return the code inside the first markdown fence of ``text``.

        If there is no complete fenced block, only a leading opening fence and
        a trailing closing fence are removed (covers replies truncated by the
        token limit). The result is whitespace-stripped.
        """
        text = text.strip()
        match = CodeArenaAgent._FENCED_BLOCK.search(text)
        if match:
            return match.group(1).strip()
        text = re.sub(r"^```[\w+-]*", "", text)
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @staticmethod
    def _add_missing_def_colon(line: str) -> str:
        """Append a colon to a one-line ``def`` header that lacks one.

        Returned unchanged: non-``def`` lines, headers whose parameter list
        continues onto a later line, and headers that already have a colon
        after the closing parenthesis (e.g. ``def f(): return g(1)``). A
        trailing ``#`` comment is kept after the inserted colon.
        """
        stripped = line.strip()
        if not stripped.startswith('def ') or stripped.endswith(':'):
            return line

        # Locate the parenthesis that closes the parameter list.
        depth = 0
        close = -1
        for index, char in enumerate(line):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    close = index
                    break
        if depth > 0:
            return line  # signature continues on the next line

        head, tail = line[:close + 1], line[close + 1:]
        code, hash_mark, comment = tail.partition('#')
        if ':' in code:
            return line
        suffix = f"  {hash_mark}{comment}" if hash_mark else ""
        return f"{head}{code.rstrip()}:{suffix}"

    def _fallback_fix(self, buggy_code: str, step_count: int) -> str:
        """Apply rule-based repairs when no model answer is available.

        Only two patterns are handled: ``def`` headers missing their colon,
        and bare ``length(...)`` calls rewritten to ``len(...)``.
        ``step_count`` mirrors ``generate_fix``'s signature and is unused.
        """
        print(f"[DEBUG] Fallback input code ({len(buggy_code)} chars): {repr(buggy_code[:100])}")

        lines = buggy_code.split('\n')
        for i, line in enumerate(lines):
            repaired = self._add_missing_def_colon(line)
            if repaired != line:
                lines[i] = repaired
                print(f"[DEBUG] Added colon to line {i+1}")
        fixed_code = '\n'.join(lines)

        fixed_code, replaced = self._LENGTH_CALL.subn("len(", fixed_code)
        if replaced:
            print("[DEBUG] Replaced length() with len()")

        print(f"[DEBUG] Fallback output code ({len(fixed_code)} chars): {repr(fixed_code[:100])}")
        return fixed_code
def run_episode(task_id: str = "easy-1", max_steps: int = 5, *,
                backend: str = "ollama", model: str = "llama3.2:latest",
                env_url: str = "http://localhost:7860") -> Dict:
    """Run one debugging episode against the CodeArena environment server.

    Resets the environment for ``task_id`` via ``POST {env_url}/reset``, then
    repeatedly asks a ``CodeArenaAgent`` for a fix and submits it via
    ``POST {env_url}/step`` until the server reports ``done``, a step request
    fails, or ``max_steps`` is reached.

    Args:
        task_id: Task identifier sent to ``/reset``.
        max_steps: Upper bound on the number of fix attempts.
        backend: Backend name passed to ``CodeArenaAgent``.
        model: Model name passed to ``CodeArenaAgent``.
        env_url: Base URL of the environment server.

    Returns:
        Dict with ``success`` (last reward > 0.5), ``steps``, ``final_reward``
        and ``rewards``. If the reset fails, ``success`` is False, ``error``
        holds the message, and ``steps``/``final_reward``/``rewards`` are
        0/0/[].
    """
    # NOTE(review): several emoji in the messages below appear to be mojibake
    # in this copy of the file; restore them from the upstream source.
    agent = CodeArenaAgent(backend=backend, model=model)
    print(f"\nπ― Starting episode: {task_id}")

    # Reset
    try:
        response = requests.post(f"{env_url}/reset", json={"task_id": task_id}, timeout=10)
        response.raise_for_status()
        obs = response.json()
        print(f"✅ Reset successful - task: {obs.get('task_id')}")
    except Exception as e:
        print(f"β Reset failed: {e}")
        # Include the summary keys so callers can treat every result uniformly.
        return {"success": False, "error": str(e), "steps": 0,
                "final_reward": 0, "rewards": []}

    rewards: List[float] = []
    previous_attempts: List[str] = []
    done = False
    step_count = 0

    while not done and step_count < max_steps:
        step_count += 1
        fix = agent.generate_fix(
            buggy_code=obs.get('buggy_code', ''),
            error_log=obs.get('error_log', ''),
            test_results=obs.get('test_results', ''),
            previous_attempts=previous_attempts,
            step_count=step_count,
        )
        print(f"\nπ§ Step {step_count}: Generated fix ({len(fix)} chars)")

        try:
            response = requests.post(f"{env_url}/step",
                                     json={"proposed_fix": fix},
                                     timeout=20)
            response.raise_for_status()
            result = response.json()
            reward = float(result.get('reward') or 0)
            done = result.get('done', False)
            info = result.get('info', {})
            rewards.append(reward)
            previous_attempts.append(fix)

            print(f" Reward: {reward:.3f}")
            print(f" Tests: {info.get('test_results', 'unknown')}")
            print(f" Done: {done}")
            if reward > 0.5:
                print("π Good reward! Continuing...")
            elif reward < 0.1:
                print("β οΈ Low reward - check debug logs")

            # Keep the last observation if the server omits a new one, so the
            # next prompt still contains the buggy code.
            obs = result.get('observation') or obs
        except Exception as e:
            print(f"β Step failed: {e}")
            break

    final_reward = rewards[-1] if rewards else 0
    success = final_reward > 0.5
    print(f"\nπ Episode complete!")
    print(f" Steps: {step_count}")
    print(f" Final reward: {final_reward:.3f}")
    print(f" Success: {success}")
    return {
        "success": success,
        "steps": step_count,
        "final_reward": final_reward,
        "rewards": rewards,
    }


def main():
    """Parse CLI options, run the requested episodes and log one CSV row each.

    Rows are appended to ``rewards_log.csv`` in the working directory; the
    header is written only when that file is empty. The per-component reward
    columns are 0.0 placeholders because ``run_episode`` does not report them.
    """
    import argparse
    import csv

    parser = argparse.ArgumentParser(description="Improved CodeArena RL Agent")
    parser.add_argument("--task", default="easy-1", help="Task ID to run")
    parser.add_argument("--episodes", type=int, default=1, help="Number of episodes")
    parser.add_argument("--backend", default="ollama", choices=["ollama", "openai", "hf"], help="Backend to use")
    parser.add_argument("--model", default="llama3.2:latest", help="Model name")
    parser.add_argument("--max-steps", type=int, default=5, help="Maximum fix attempts per episode")
    parser.add_argument("--env-url", default="http://localhost:7860", help="Base URL of the CodeArena environment server")
    args = parser.parse_args()

    print("π€ Improved CodeArena Agent")
    print("=" * 50)
    print(f"Task: {args.task}")
    print(f"Episodes: {args.episodes}")
    print(f"Backend: {args.backend}")
    print(f"Model: {args.model}")

    results = []
    for i in range(args.episodes):
        print(f"\nπ Episode {i+1}/{args.episodes}")
        # Forward the CLI backend/model so the agent actually uses them.
        result = run_episode(args.task, args.max_steps, backend=args.backend,
                             model=args.model, env_url=args.env_url)
        results.append(result)

        with open("rewards_log.csv", "a", newline="") as f:
            writer = csv.writer(f)
            if os.path.getsize("rewards_log.csv") == 0:  # new/empty file: header once
                writer.writerow(["timestamp", "task_id", "step", "reward", "compile_score", "test_ratio", "efficiency_score"])
            writer.writerow([
                time.strftime("%Y-%m-%d %H:%M:%S"),
                args.task,
                # .get(): a failed reset may not report these keys.
                result.get("steps", 0),
                result.get("final_reward", 0),
                0.0, 0.0, 0.0,  # component breakdown not available here
            ])

    if not results:
        print("\nNo episodes were run (--episodes must be positive).")
        return

    successes = sum(1 for r in results if r.get("success"))
    avg_reward = sum(r.get("final_reward", 0) for r in results) / len(results)
    print(f"\nπ Summary:")
    print(f" Success rate: {successes}/{len(results)} ({successes/len(results)*100:.1f}%)")
    print(f" Average reward: {avg_reward:.3f}")
    if successes > 0:
        print("π Some episodes succeeded! Check rewards_log.csv and run plot_rewards.py")
    else:
        print("β οΈ All episodes failed. Check debug output and fix issues.")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()