codeareana / server /ai_fixer.py
havinashpatil
Finalizing CodeArena RL Benchmark: frontend improvements, GRPO training scripts, and cleaned environment
03a7eb9
"""
CodeArena Built-in AI Code Fixer
Works WITHOUT Ollama. Uses AST analysis + pattern-based repair.
Also supports Ollama if available (graceful fallback).
"""
import ast
import re
import textwrap
import subprocess
import sys
from typing import Optional
from server.algorithm_detector import (
detect_problem_type, detect_complexity, needs_optimization,
get_optimization_hint, build_adaptive_prompt_suffix, ALGO_HINTS
)
from server.memory import store_success, retrieve_memory, log_complexity_reward
# ─── Pattern-Based Fixes ─────────────────────────────────────────────────────
def fix_syntax_errors(code: str) -> str:
    """Append a missing trailing colon to compound-statement headers.

    Source that already parses is returned untouched, so valid constructs
    that merely look like headers (for example the ``for`` clause of a
    multi-line comprehension) are never damaged.  Otherwise every line is
    right-stripped and a ``:`` is appended when the line is a
    ``def``/``class``/``if``/``elif``/``else``/``for``/``while``/``with``/
    ``try``/``except``/``finally`` header that is missing its colon.

    Args:
        code: Python source text, possibly syntactically invalid.

    Returns:
        The repaired source text, or ``code`` itself when it already parses.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        pass  # Broken source: fall through to the line-by-line repair.
    else:
        return code

    # Keywords that take an expression must be followed by whitespace; the
    # bare keywords must stand alone, so identifiers such as ``try_count``
    # or ``elsewhere`` are not mistaken for headers.
    header_re = re.compile(
        r'^\s*(?:(?:def|class|if|elif|for|while|with|except)\s+\S.*'
        r'|else|try|except|finally)$'
    )
    fixed = []
    for line in code.split('\n'):
        stripped = line.rstrip()
        if (
            header_re.match(stripped)
            # A trailing backslash or comma means the statement continues.
            and not stripped.endswith((':', '\\', ','))
            and _colon_completes_header(stripped.strip())
        ):
            stripped += ':'
        fixed.append(stripped)
    return '\n'.join(fixed)


def _colon_completes_header(header: str) -> bool:
    """Return True when ``header + ':'`` forms a valid compound-statement header.

    This rejects lines such as ``if x: return 1`` (already complete) or
    ``def f(x`` (broken in a way a colon cannot repair).
    """
    keyword = re.match(r'[a-z]+', header).group(0)
    candidate = f"{header}:\n    pass\n"
    # Clauses that cannot open a statement need a syntactic host to be parsed.
    if keyword in ('elif', 'else'):
        candidate = "if True:\n    pass\n" + candidate
    elif keyword in ('except', 'finally'):
        candidate = "try:\n    pass\n" + candidate
    elif keyword == 'try':
        candidate += "finally:\n    pass\n"
    try:
        ast.parse(candidate)
    except (SyntaxError, ValueError):
        return False
    return True
def fix_wrong_builtins(code: str) -> str:
    """Correct a fixed list of misspelled builtins, keywords and constants.

    Each misspelling is matched as a whole word anywhere in the text
    (including inside strings and comments) and replaced by its correction.

    Args:
        code: Python source text.

    Returns:
        The source text with the known misspellings replaced.
    """
    typo_fixes = (
        (r'\blenght\b', 'len'),
        (r'\bappned\b', 'append'),
        (r'\bpirnt\b', 'print'),
        (r'\bprnit\b', 'print'),
        (r'\bretrun\b', 'return'),
        (r'\bpas\b', 'pass'),
        (r'\bTreu\b', 'True'),
        (r'\bFlase\b', 'False'),
        (r'\bNoen\b', 'None'),
    )
    repaired = code
    for typo_pattern, correct_spelling in typo_fixes:
        repaired = re.sub(typo_pattern, correct_spelling, repaired)
    return repaired
def optimize_complexity(code: str) -> str:
    """Replace recognised brute-force snippets with canned efficient versions.

    Three text heuristics are tried in order.  The first one that fires
    returns a new function consisting of the original ``def`` header and a
    replacement body; if none fires, ``code`` is returned unchanged:

    - three ``for ... in range`` loops plus the words "max" and "sum" or
      "subarray" → Kadane's algorithm, O(N)
    - two ``range`` loops followed by an indexed ``if`` plus a swap marker
      → ``sorted()``, O(N log N)
    - "binary_search" or ``low``/``mid``/``high`` inside code that also has
      two ``for`` loops with ``range(n)`` and ``range(i`` → canonical binary
      search, O(log N)

    A rewrite only happens when the stripped ``code`` starts with a
    ``def name(...):`` header.  Parameter names used in the generated body
    are taken from that header (annotations and defaults removed).

    Args:
        code: Python source text of the candidate solution.

    Returns:
        The rewritten function text, or ``code`` unchanged.
    """
    lowered = code.lower()

    # Detect triple nested loop (O(N^3)) → max subarray → Kadane's
    if re.search(
        r'for\s+\w+\s+in\s+range.*:\s*\n.*for\s+\w+\s+in\s+range.*:\s*\n.*for\s+\w+\s+in\s+range',
        code,
        re.DOTALL,
    ):
        sig = _leading_signature(code)
        if sig and 'max' in lowered and ('sum' in lowered or 'subarray' in lowered):
            # Use the function's own first parameter instead of assuming ``arr``.
            seq = _param_at(_signature_params(sig), 0, 'arr')
            return '\n'.join([
                sig,
                "    # Optimized: Kadane's Algorithm O(N)",
                f"    if not {seq}:",
                "        return 0",
                f"    max_sum = {seq}[0]",
                f"    current_sum = {seq}[0]",
                f"    for num in {seq}[1:]:",
                "        current_sum = max(num, current_sum + num)",
                "        max_sum = max(max_sum, current_sum)",
                "    return max_sum",
            ])

    # Detect O(N^2) bubble sort → use sorted()
    if re.search(
        r'for\s+\w+.*range.*:\s*\n.*for\s+\w+.*range.*:\s*\n.*if\s+\w+\[',
        code,
        re.DOTALL,
    ):
        if 'swap' in lowered or ('arr[i]' in code and 'arr[j]' in code):
            sig = _leading_signature(code)
            if sig:
                seq = _param_at(_signature_params(sig), 0, 'arr')
                return '\n'.join([
                    sig,
                    "    # Optimized: Python built-in sort O(N log N)",
                    f"    return sorted({seq})",
                ])

    # Detect double nested loop with repeated computation
    if code.count('for ') >= 2 and 'range(n)' in code and 'range(i' in code:
        # Off-by-one fix for binary search.
        # NOTE(review): this check only runs when the loop condition above
        # holds, which a typical binary search will not satisfy - confirm
        # whether the nesting is intended.
        if 'binary_search' in lowered or ('mid' in code and 'low' in code and 'high' in code):
            sig = _leading_signature(code)
            if sig:
                names = _signature_params(sig)
                arr_p = _param_at(names, 0, 'arr')
                target_p = _param_at(names, 1, 'target')
                return '\n'.join([
                    sig,
                    "    # Fixed: Correct binary search O(log N)",
                    f"    low, high = 0, len({arr_p}) - 1",
                    "    while low <= high:",
                    "        mid = (low + high) // 2",
                    f"        if {arr_p}[mid] == {target_p}:",
                    "            return mid",
                    f"        elif {arr_p}[mid] < {target_p}:",
                    "            low = mid + 1",
                    "        else:",
                    "            high = mid - 1",
                    "    return -1",
                ])

    return code


def _leading_signature(code: str) -> Optional[str]:
    """Return the ``def ...:`` header that opens *code*, or None if there is none."""
    found = re.match(r'(def\s+\w+\([^)]*\)(?:\s*->\s*[^:\n]+)?:)', code.strip())
    return found.group(1) if found else None


def _signature_params(sig: str) -> list[str]:
    """Return the positional parameter names of a ``def`` header, without annotations or defaults."""
    try:
        func = ast.parse(f"{sig}\n    pass\n").body[0]
    except SyntaxError:
        return []
    return [arg.arg for arg in (*func.args.posonlyargs, *func.args.args)]


def _param_at(names: list[str], index: int, default: str) -> str:
    """Return ``names[index]`` or *default* when the header has too few parameters."""
    return names[index] if index < len(names) else default
def fix_logic_bugs(code: str) -> str:
    """Apply regex repairs for a few common logic slips.

    The repairs are idempotent: text that already carries a fix is left
    alone, so running the function twice gives the same result as once.

    - ``high = len(x)`` becomes ``high = len(x) - 1`` (inclusive upper bound
      for binary search), unless ``- 1`` is already present.
    - ``return total / n`` becomes ``return total / len(arr) if arr else 0``.
      NOTE(review): the replacement hard-codes the name ``arr``; confirm the
      inputs always use that name.
    - When the text mentions "average" or "mean", ``return a / len(b)`` gets
      an ``if b else 0`` guard against division by zero, unless a
      conditional already follows.

    Args:
        code: Python source text.

    Returns:
        The source text with the repairs applied.
    """
    # Off-by-one in binary search; the lookahead skips bounds already fixed.
    code = re.sub(
        r'\bhigh\s*=\s*len\((\w+)\)(?!\s*-\s*1\b)',
        r'high = len(\1) - 1',
        code,
    )

    # Fix average calculation: sum / n should use len()
    code = re.sub(r'return\s+total\s*/\s*n\b', 'return total / len(arr) if arr else 0', code)

    # Fix division by zero risk; the lookahead avoids stacking a second
    # guard onto a return that already has one (including the line above).
    lowered = code.lower()
    if 'average' in lowered or 'mean' in lowered:
        code = re.sub(
            r'return\s+(\w+)\s*/\s*len\((\w+)\)(?!\s*if\b)',
            r'return \1 / len(\2) if \2 else 0',
            code,
        )
    return code
def apply_all_fixes(code: str) -> str:
    """Run the built-in repair passes over *code* and return the result.

    The passes run in a fixed order: typo correction, missing-colon repair,
    logic-bug repair, then complexity rewriting. Each pass receives the
    output of the previous one.
    """
    repair_passes = (
        fix_wrong_builtins,
        fix_syntax_errors,
        fix_logic_bugs,
        optimize_complexity,
    )
    current = code
    for repair in repair_passes:
        current = repair(current)
    return current
# ─── Ollama Integration (optional) ───────────────────────────────────────────
def is_ollama_available(ollama_url: str = "http://localhost:11434", model: str = "llama3.2:latest") -> bool:
    """Report whether an Ollama server answers and lists a matching model.

    Queries ``{ollama_url}/api/tags`` with a 3 second timeout.  A model
    matches when the part of *model* before the first ``:`` occurs as a
    substring of a listed model name.

    Args:
        ollama_url: Base URL of the Ollama server.
        model: Model name, optionally with a ``:tag`` suffix.

    Returns:
        True if a matching model is listed; False if none matches or if
        anything at all goes wrong (connection, decoding, unexpected shape).
    """
    try:
        import json
        import urllib.request

        tags_request = urllib.request.Request(f"{ollama_url}/api/tags")
        with urllib.request.urlopen(tags_request, timeout=3) as response:
            listing = json.loads(response.read())
        # Collect every name first so a malformed entry fails the whole check.
        installed = [entry['name'] for entry in listing.get('models', [])]
        for installed_name in installed:
            if model.split(':')[0] in installed_name:
                return True
        return False
    except Exception:
        return False
def validate_code(code: str) -> bool:
    """Return True when *code* compiles as a Python module, False otherwise.

    Used as a safety layer so that code which cannot even be compiled is
    rejected before it is scored.  Any exception raised by ``compile`` is
    treated as "invalid".
    """
    try:
        compile(code, "<string>", "exec")
    except Exception:
        return False
    else:
        return True
def is_inefficient(code: str) -> bool:
    """Heuristically decide whether *code* still looks brute-force.

    The code is flagged when either of these holds:

    - it contains one of the literal markers ``O(n^2)``, ``O(n^3)``,
      ``O(N^2)`` or ``O(N^3)`` (anywhere, including comments), or
    - three ``for`` headers appear on consecutive lines, i.e. a triple
      nested loop.

    A plain double loop without such a marker is not flagged.

    Args:
        code: Python source text.

    Returns:
        True if the code looks inefficient, False otherwise.
    """
    complexity_markers = ('O(n^2)', 'O(n^3)', 'O(N^2)', 'O(N^3)')
    if any(marker in code for marker in complexity_markers):
        return True
    # Detect triple nested loop pattern (O(N^3)).
    triple_loop = re.search(
        r'for\s+\w+.*:\s*\n\s+for\s+\w+.*:\s*\n\s+for\s+\w+', code
    )
    return triple_loop is not None
def _call_ollama(prompt: str, model: str, ollama_url: str, num_predict: int = 1024) -> str | None:
"""Send a single prompt to Ollama and return raw text response."""
import urllib.request
import json
payload = json.dumps({
"model": model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.1, "num_predict": num_predict}
}).encode()
req = urllib.request.Request(
f"{ollama_url}/api/generate",
data=payload,
headers={"Content-Type": "application/json"},
method="POST"
)
with urllib.request.urlopen(req, timeout=60) as resp:
data = json.loads(resp.read())
return data.get("response", "").strip()
def _extract_code_and_explanation(result: str) -> tuple[str, str]:
"""Extract code block and explanation from model response."""
code_match = re.search(r'```python\n(.*?)\n```', result, re.DOTALL)
if not code_match:
code_match = re.search(r'```(.*?)```', result, re.DOTALL)
extracted_code = code_match.group(1).strip() if code_match else result.strip()
explanation = result.replace(code_match.group(0), '').strip() if code_match else "No reasoning provided."
return extracted_code, explanation
def _build_optimization_prompt(code: str, error_log: str) -> str:
"""
Build the Analysis β†’ Optimization β†’ Code 3-step prompt with pattern mapping.
"""
return f"""You are an expert Python algorithm engineer.
The current solution is inefficient or buggy.
Step 1: Identify why it is inefficient or incorrect (1 line only)
Step 2: Identify the optimal algorithm to solve this problem
Step 3: Rewrite the code using the optimal algorithm
Constraints:
- MUST improve time complexity
- DO NOT use brute force
- Target O(n) if possible
- If your solution is O(n^2) or worse, improve it
Common algorithm patterns:
- Maximum subarray β†’ Kadane's algorithm (O(n))
- Subarray sum β†’ prefix sum (O(n))
- Searching sorted array β†’ binary search (O(log n))
- Sorting β†’ use built-in sorted() (O(n log n))
- Sliding window β†’ two pointers (O(n))
First think step-by-step about how to optimize the algorithm.
Then output only the final code.
Do NOT stop at identifying the issue β€” you MUST produce optimized code.
Previous error:
{error_log or "No errors, but the solution is suboptimal."}
CURRENT CODE:
{code}
Output your 3-step reasoning, then wrap the final optimized code in a ```python ... ``` block."""
def _build_fix_prompt(code: str, error_log: str, reward: float = 0.0, task_id: str = "") -> str:
    """Build the debugging prompt used when the code has bugs or errors.

    The prompt embeds the hint from ``get_optimization_hint``, the suffix
    from ``build_adaptive_prompt_suffix(reward)`` and, when *task_id* has a
    remembered solution whose reward exceeds 0.7, that solution.

    Args:
        code: Buggy solution, embedded under ``BUGGY CODE:``.
        error_log: Error text from the previous attempt; when empty, a note
            that tests are failing is used instead.
        reward: Current reward, passed to the adaptive-suffix builder.
        task_id: Key for the memory lookup; no lookup happens when empty.

    Returns:
        The complete prompt text.
    """
    # Get algorithm hint from detector
    algo_hint = get_optimization_hint(code, error_log)
    # Get adaptive suffix based on current reward
    adaptive_suffix = build_adaptive_prompt_suffix(reward)
    # Retrieve memory for past success
    memory_note = ""
    if task_id:
        past = retrieve_memory(task_id)
        if past and past.get('reward', 0) > 0.7:
            memory_note = f"\nPrevious successful solution (reward={past['reward']}):\n{past['best_code']}\nImprove upon this."
    return f"""You are an expert Python debugging agent.
Follow this process and explain your reasoning:
Step 1: Identify bug type (syntax / logic / type / edge case)
Step 2: Locate exact line causing issue
Step 3: Fix only that issue and ensure tests pass
Step 4: Report the Time Complexity of your fixed code
Step 5: If complexity is O(n^2) or worse, optimize to O(n) if possible
Algorithm Detection: {algo_hint}
Common algorithm patterns:
- Maximum subarray → Kadane's algorithm (O(n))
- Subarray sum → prefix sum (O(n))
- Searching sorted array → binary search (O(log n))
- Sorting → use built-in sorted() (O(n log n))
Is your solution optimal? If not, improve it.
{adaptive_suffix}
{memory_note}
Previous attempt failed with:
{error_log or "No errors, but tests are failing."}
BUGGY CODE:
{code}
Output your step-by-step reasoning, then wrap ONLY the corrected Python code in a ```python ... ``` block."""
def fix_with_ollama(
    code: str,
    error_log: str = "",
    ollama_url: str = "http://localhost:11434",
    model: str = "llama3.2:latest",
    reward: float = 0.0,
    task_id: str = "",
) -> Optional[tuple[str, str]]:
    """Fix and optimize *code* with an Ollama model.

    Pipeline, run for at most two passes:

    1. Generate a candidate.  The first pass uses the debugging prompt when
       *error_log* is non-empty; every other case uses the optimization
       prompt on the best code so far.  Each pass makes up to three calls
       and accepts the first reply whose code compiles.
    2. Self-critique: if the accepted code still looks inefficient → send a
       targeted re-optimization prompt and adopt its code if that compiles.
    3. Stop as soon as the code no longer looks inefficient.

    Args:
        code: Source text to repair.
        error_log: Error output from the previous attempt, if any.
        ollama_url: Base URL of the Ollama server.
        model: Name of the Ollama model to use.
        reward: Current reward, forwarded to the prompt builders.
        task_id: Task key, forwarded to the debugging prompt builder.

    Returns:
        ``(code, explanation)``, or None when no compilable code was
        obtained or an unexpected error occurred (the error is printed).
    """
    try:
        best_code = None
        best_explanation = ""

        # Iterative refinement: up to 2 full optimization passes
        for iteration in range(2):
            # Debugging prompt only on the first pass and only with an error
            # log; otherwise go straight to the optimization prompt.
            if iteration == 0 and error_log:
                prompt = _build_fix_prompt(code, error_log, reward=reward, task_id=task_id)
            else:
                # Inject algorithm hint + adaptive suffix into optimization prompt
                algo_hint = get_optimization_hint(best_code or code, error_log)
                adaptive_suffix = build_adaptive_prompt_suffix(reward)
                base_opt_prompt = _build_optimization_prompt(best_code or code, error_log)
                prompt = base_opt_prompt + f"\n\nAlgorithm Detection: {algo_hint}{adaptive_suffix}"

            for attempt in range(3):  # 3 retries per iteration
                try:
                    result = _call_ollama(prompt, model, ollama_url)
                    if not result:
                        continue
                    extracted_code, explanation = _extract_code_and_explanation(result)
                    if extracted_code and validate_code(extracted_code):
                        best_code = extracted_code
                        best_explanation = explanation
                        break  # Valid code — move on
                    # Invalid syntax: tell model to fix it
                    prompt += "\n\nYour last generated code had a SyntaxError. Wrap ONLY valid Python code in ```python ... ``` blocks."
                except Exception as e:
                    print(f"[Ollama attempt {attempt+1} failed]: {e}")
                    continue

            if best_code is None:
                return None  # All retries failed

            # ── Self-Critique Loop ────────────────────────────────────────────
            # If the generated code is still brute-force, force a re-optimization pass
            if is_inefficient(best_code):
                print(f"[Self-Critique] Iteration {iteration+1}: Code still inefficient, re-optimizing...")
                # Build a targeted re-optimization prompt
                critique_prompt = f"""You are a Python performance expert.
The following solution is STILL using brute force and is too slow:
```python
{best_code}
```
This is unacceptable. You MUST rewrite it using an optimal algorithm.
Common patterns:
- Maximum subarray → Kadane's algorithm (O(n))
- Subarray sum → prefix sum (O(n))
- Searching → binary search (O(log n))
Output ONLY the O(n) optimized version inside a ```python ... ``` block. No explanation needed."""
                try:
                    critique_result = _call_ollama(critique_prompt, model, ollama_url)
                    if critique_result:
                        improved_code, improved_explanation = _extract_code_and_explanation(critique_result)
                        if improved_code and validate_code(improved_code):
                            best_code = improved_code
                            best_explanation = f"[Self-Critique Applied]\n{improved_explanation or best_explanation}"
                except Exception as e:
                    # A failed critique keeps the code from the generation step.
                    print(f"[Self-Critique] Failed: {e}")

            # If no longer inefficient after critique, stop early
            if not is_inefficient(best_code):
                break

        return (best_code, best_explanation) if best_code else None
    except Exception as e:
        print(f"Ollama fix failed: {e}")
        return None
def generate_fix(
    code: str,
    error_log: str = "",
    ollama_url: str = "http://localhost:11434",
    model: str = "llama3.2:latest",
    use_ollama: bool = True,
    reward: float = 0.0,
    task_id: str = "",
) -> dict:
    """Main entry point: fix *code* with Ollama, else with the built-in fixer.

    When *use_ollama* is true, ``fix_with_ollama`` is tried first.  If it is
    skipped or yields nothing, ``apply_all_fixes`` is used instead.  Either
    way the detected complexity is logged together with *reward* under
    ``task_id`` (or ``"sandbox"`` when *task_id* is empty).  An Ollama result
    is also stored in memory when *reward* is at least 0.8 and *task_id* is
    set.

    Returns:
        A dict with ``fixed_code``, ``method`` (``"ollama"`` or
        ``"builtin"``), ``success`` (always True), ``explanation``,
        ``complexity`` and ``algo_hint``; the built-in path additionally
        carries ``note``.
    """
    log_label = task_id or "sandbox"

    ollama_outcome = None
    if use_ollama:
        ollama_outcome = fix_with_ollama(
            code, error_log, ollama_url, model, reward=reward, task_id=task_id
        )

    if not ollama_outcome:
        # Fallback: built-in pattern fixer
        patched = apply_all_fixes(code)
        patched_complexity = detect_complexity(patched)
        log_complexity_reward(log_label, reward, patched_complexity, step=0, method="builtin")
        return {
            "fixed_code": patched,
            "method": "builtin",
            "success": True,
            "explanation": "Ollama unavailable. Used built-in pattern-based fixer.",
            "note": "Ollama unavailable. Used built-in pattern-based fixer.",
            "complexity": patched_complexity,
            "algo_hint": get_optimization_hint(patched),
        }

    fixed_code, explanation = ollama_outcome
    # Log complexity vs reward for research tracking
    complexity = detect_complexity(fixed_code)
    log_complexity_reward(log_label, reward, complexity, step=0, method="ollama")
    # Store in memory if good reward
    if reward >= 0.8 and task_id:
        store_success(task_id, fixed_code, reward)
    return {
        "fixed_code": fixed_code,
        "method": "ollama",
        "success": True,
        "explanation": explanation,
        "complexity": complexity,
        "algo_hint": get_optimization_hint(fixed_code, error_log),
    }