# codecourt/agents/solver.py
# ayussssssiiii: Initial HF Space snapshot (fcb838d)
"""
Solver Agent — writes code solutions to problems.
Falls back to reference or brute-force solutions for testing.
"""
from typing import Optional, Dict, Any
from agents.prompts import (
SOLVER_SYSTEM,
SOLVER_USER_TEMPLATE,
REFERENCE_SOLUTIONS,
BRUTE_FORCE_SOLUTIONS,
)
class SolverAgent:
    """
    Solver agent wrapper.

    Produces Python source code for a problem in one of three modes:

    * brute force -- return a stored brute-force solution,
    * reference   -- return a stored reference solution,
    * LLM         -- generate a solution with ``model`` / ``tokenizer``.

    Brute force takes priority over reference, which takes priority over
    the LLM. When no model is supplied and brute force is not requested,
    the agent falls back to reference mode.
    """

    # Placeholder program returned when no stored solution matches a problem.
    _FALLBACK_CODE = "print(0)"

    def __init__(
        self,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        use_reference: bool = False,
        use_brute_force: bool = False,
    ) -> None:
        """
        Args:
            model: Generation model exposing ``generate`` and ``device``;
                ``None`` disables the LLM path.
            tokenizer: Tokenizer exposing ``apply_chat_template``,
                ``decode`` and ``eos_token_id``.
            use_reference: Force reference-solution mode.
            use_brute_force: Force brute-force mode (wins over reference).
        """
        self.model = model
        self.tokenizer = tokenizer
        # Without a model the only usable non-brute-force mode is reference.
        self.use_reference = use_reference or (
            model is None and not use_brute_force
        )
        self.use_brute_force = use_brute_force

    def solve(self, problem: Dict[str, Any]) -> str:
        """
        Given a problem dict, return Python code that attempts to solve it.

        Dispatches on the configured mode: brute force first, then
        reference, then the LLM.
        """
        if self.use_brute_force:
            return self._brute_force(problem)
        if self.use_reference:
            return self._reference_solution(problem)
        return self._llm_solve(problem)

    @staticmethod
    def _solution_key(problem: Dict[str, Any]) -> tuple:
        """Return the ``(archetype, task_id)`` key used by the solution tables."""
        return (problem.get("archetype", "array"), problem.get("task_id", 0))

    @staticmethod
    def _format_user_prompt(problem: Dict[str, Any]) -> str:
        """Fill the user prompt template with the problem's description.

        Raises:
            KeyError: if ``problem`` has no ``"description"`` entry.
        """
        return SOLVER_USER_TEMPLATE.format(description=problem["description"])

    def _reference_solution(self, problem: Dict[str, Any]) -> str:
        """
        Return the reference solution for ``problem``.

        A non-empty ``problem["reference_solution"]`` wins; otherwise the
        ``REFERENCE_SOLUTIONS`` table is consulted, and the placeholder
        program is returned when nothing matches.
        """
        inline = problem.get("reference_solution")
        if inline:
            return inline
        return REFERENCE_SOLUTIONS.get(
            self._solution_key(problem), self._FALLBACK_CODE
        )

    def _brute_force(self, problem: Dict[str, Any]) -> str:
        """
        Return the brute-force solution for ``problem``.

        A non-empty ``problem["brute_force_solution"]`` wins; otherwise the
        ``BRUTE_FORCE_SOLUTIONS`` table is consulted, and the placeholder
        program is returned when nothing matches.
        """
        inline = problem.get("brute_force_solution")
        if inline:
            return inline
        return BRUTE_FORCE_SOLUTIONS.get(
            self._solution_key(problem), self._FALLBACK_CODE
        )

    def _llm_solve(self, problem: Dict[str, Any]) -> str:
        """
        Generate a solution with the configured model and tokenizer.

        Raises:
            RuntimeError: if the model or the tokenizer is missing.
            KeyError: if ``problem`` has no ``"description"`` entry.
        """
        if self.model is None or self.tokenizer is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                "SolverAgent needs both a model and a tokenizer for LLM "
                "solving; pass them in or enable use_reference/use_brute_force."
            )

        messages = [
            {"role": "system", "content": SOLVER_SYSTEM},
            {"role": "user", "content": self._format_user_prompt(problem)},
        ]
        # NOTE(review): the slicing below assumes apply_chat_template returns
        # a 2-D tensor of token ids (not a dict) -- confirm for the installed
        # tokenizer/library version.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        outputs = self.model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.8,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Drop the prompt tokens so only the newly generated text is decoded.
        generated = self.tokenizer.decode(
            outputs[0][inputs.shape[1]:],
            skip_special_tokens=True,
        )
        return self._clean_code(generated)

    @staticmethod
    def _clean_code(raw: str) -> str:
        """
        Extract plain code from raw model output.

        Lines starting with a Markdown fence marker are always dropped.
        If the output contains non-blank fenced content, only the lines
        inside fences are returned, so surrounding prose is discarded. An
        unterminated fence keeps everything after its opening marker.
        Output without fenced content is returned whole.
        """
        every_line = []
        fenced_lines = []
        in_fence = False
        for line in raw.strip().split("\n"):
            if line.strip().startswith("```"):
                in_fence = not in_fence
                continue
            every_line.append(line)
            if in_fence:
                fenced_lines.append(line)

        fenced_code = "\n".join(fenced_lines).strip()
        if fenced_code:
            return fenced_code
        # No usable fenced block: fall back to all non-fence lines.
        return "\n".join(every_line).strip()

    def build_prompt_text(self, problem: Dict[str, Any]) -> str:
        """Return the system prompt and the filled user prompt as one string,
        separated by a blank line."""
        return SOLVER_SYSTEM + "\n\n" + self._format_user_prompt(problem)