Spaces:
Sleeping
Sleeping
File size: 3,323 Bytes
fcb838d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | """
Solver Agent — writes code solutions to problems.
Falls back to reference or brute-force solutions for testing.
"""
from typing import Optional, Dict, Any
from agents.prompts import (
SOLVER_SYSTEM,
SOLVER_USER_TEMPLATE,
REFERENCE_SOLUTIONS,
BRUTE_FORCE_SOLUTIONS,
)
class SolverAgent:
    """
    Solver agent wrapper.

    Produces Python source code for a problem dict. Uses the LLM when one is
    configured; otherwise (or when explicitly requested) falls back to the
    reference or brute-force solutions, which is useful for testing.

    Mode precedence in ``solve``: brute force, then reference, then LLM.
    """

    # Code returned when no reference / brute-force solution is known.
    _FALLBACK_CODE = "print(0)"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        use_reference: bool = False,
        use_brute_force: bool = False,
        max_new_tokens: int = 512,
        temperature: float = 0.8,
    ):
        """
        Args:
            model: Generation model; must expose ``.device`` and
                ``.generate(...)`` (presumably a Hugging Face causal LM —
                confirm with callers).
            tokenizer: Tokenizer paired with ``model``; must expose
                ``apply_chat_template``, ``decode`` and ``eos_token_id``.
            use_reference: Return reference solutions instead of calling the
                LLM. Forced on when ``model`` is None and brute force was not
                requested.
            use_brute_force: Return brute-force solutions; takes precedence
                over ``use_reference``.
            max_new_tokens: Maximum number of tokens the LLM may generate.
            temperature: Sampling temperature passed to ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Without a model there is nothing to generate with, so default to the
        # reference solutions unless brute force was explicitly requested.
        self.use_reference = use_reference or (model is None and not use_brute_force)
        self.use_brute_force = use_brute_force
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def solve(self, problem: Dict[str, Any]) -> str:
        """
        Given a problem dict, return Python code that attempts to solve it.
        """
        if self.use_brute_force:
            return self._brute_force(problem)
        if self.use_reference:
            return self._reference_solution(problem)
        return self._llm_solve(problem)

    @staticmethod
    def _lookup_key(problem: Dict[str, Any]) -> tuple:
        """
        Key into the solution tables: ``(archetype, task_id)``.

        Missing fields default to ``"array"`` and ``0`` respectively.
        """
        return (problem.get("archetype", "array"), problem.get("task_id", 0))

    def _reference_solution(self, problem: Dict[str, Any]) -> str:
        """
        Return the problem's own reference solution if it carries a non-empty
        one, else the entry from ``REFERENCE_SOLUTIONS``, else ``print(0)``.
        """
        if problem.get("reference_solution"):
            return problem["reference_solution"]
        return REFERENCE_SOLUTIONS.get(self._lookup_key(problem), self._FALLBACK_CODE)

    def _brute_force(self, problem: Dict[str, Any]) -> str:
        """
        Return the problem's own brute-force solution if it carries a non-empty
        one, else the entry from ``BRUTE_FORCE_SOLUTIONS``, else ``print(0)``.
        """
        if problem.get("brute_force_solution"):
            return problem["brute_force_solution"]
        return BRUTE_FORCE_SOLUTIONS.get(self._lookup_key(problem), self._FALLBACK_CODE)

    @staticmethod
    def _user_prompt(problem: Dict[str, Any]) -> str:
        """
        Format the user prompt for ``problem``.

        Raises:
            KeyError: if ``problem`` has no ``"description"``.
        """
        return SOLVER_USER_TEMPLATE.format(description=problem["description"])

    def _llm_solve(self, problem: Dict[str, Any]) -> str:
        """
        Generate a solution with the LLM and return the extracted code.

        Raises:
            RuntimeError: if the agent has no model or no tokenizer.
            KeyError: if ``problem`` has no ``"description"``.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "SolverAgent needs both a model and a tokenizer for LLM solving"
            )
        messages = [
            {"role": "system", "content": SOLVER_SYSTEM},
            {"role": "user", "content": self._user_prompt(problem)},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation; decode only the new tokens.
        prompt_len = input_ids.shape[1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )
        return self._clean_code(generated)

    @staticmethod
    def _clean_code(raw: str) -> str:
        """
        Extract the code portion of a raw model completion.

        If the completion contains Markdown code fences (lines starting with
        three backticks), only the lines inside fences are kept, so prose the
        model writes before or after the code is dropped. A fence that is never
        closed (e.g. a completion cut off by the token budget) keeps every line
        after it. Without fences, or if the fences enclose only whitespace, the
        whole completion is returned with any fence lines removed.
        """
        inside = []      # lines between an opening and closing fence
        everything = []  # all non-fence lines, in original order
        in_fence = False
        for line in raw.strip().split("\n"):
            if line.strip().startswith("```"):
                in_fence = not in_fence
                continue
            everything.append(line)
            if in_fence:
                inside.append(line)
        code = "\n".join(inside).strip()
        return code or "\n".join(everything).strip()

    def build_prompt_text(self, problem: Dict[str, Any]) -> str:
        """
        Return the system prompt and user prompt joined by a blank line, as a
        single plain-text string.
        """
        return SOLVER_SYSTEM + "\n\n" + self._user_prompt(problem)
|