File size: 3,323 Bytes
fcb838d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Solver Agent — writes code solutions to problems.
Falls back to reference or brute-force solutions for testing.
"""

from typing import Optional, Dict, Any
from agents.prompts import (
    SOLVER_SYSTEM,
    SOLVER_USER_TEMPLATE,
    REFERENCE_SOLUTIONS,
    BRUTE_FORCE_SOLUTIONS,
)


class SolverAgent:
    """
    Solver agent wrapper.

    Produces Python source code for a problem dict using one of three
    strategies, checked in this order by ``solve``:

    1. brute force  (``use_brute_force``)
    2. reference    (``use_reference``)
    3. LLM sampling (``model`` + ``tokenizer``)
    """

    # Program returned when no stored solution matches the problem.
    _FALLBACK_CODE = "print(0)"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        use_reference: bool = False,
        use_brute_force: bool = False,
    ):
        """
        Args:
            model: Generation model used by ``_llm_solve``; must expose
                ``generate`` and ``device``. May be None.
            tokenizer: Tokenizer paired with ``model``; must expose
                ``apply_chat_template``, ``decode`` and ``eos_token_id``.
            use_reference: Force the reference-solution strategy.
            use_brute_force: Force the brute-force strategy (takes priority
                over ``use_reference`` in ``solve``).
        """
        self.model = model
        self.tokenizer = tokenizer
        # Without a model there is nothing to sample from, so reference mode
        # is switched on automatically unless brute force was requested.
        self.use_reference = use_reference or (model is None and not use_brute_force)
        self.use_brute_force = use_brute_force

    def solve(self, problem: Dict[str, Any]) -> str:
        """
        Given a problem dict, return Python code that attempts to solve it.

        Brute force is tried first, then reference, then the LLM.
        """
        if self.use_brute_force:
            return self._brute_force(problem)

        if self.use_reference:
            return self._reference_solution(problem)

        return self._llm_solve(problem)

    @staticmethod
    def _problem_key(problem: Dict[str, Any]) -> tuple:
        """Return the ``(archetype, task_id)`` key used by the solution tables."""
        return (problem.get("archetype", "array"), problem.get("task_id", 0))

    @staticmethod
    def _user_prompt(problem: Dict[str, Any]) -> str:
        """Fill the user template with the problem's ``description``.

        Raises:
            KeyError: if ``problem`` has no ``"description"`` entry.
        """
        return SOLVER_USER_TEMPLATE.format(description=problem["description"])

    def _reference_solution(self, problem: Dict[str, Any]) -> str:
        """
        Return the problem's own ``reference_solution`` if it is truthy,
        otherwise look one up in ``REFERENCE_SOLUTIONS``; fall back to
        ``print(0)`` when nothing matches.
        """
        if problem.get("reference_solution"):
            return problem["reference_solution"]
        return REFERENCE_SOLUTIONS.get(
            self._problem_key(problem), self._FALLBACK_CODE
        )

    def _brute_force(self, problem: Dict[str, Any]) -> str:
        """
        Return the problem's own ``brute_force_solution`` if it is truthy,
        otherwise look one up in ``BRUTE_FORCE_SOLUTIONS``; fall back to
        ``print(0)`` when nothing matches.
        """
        if problem.get("brute_force_solution"):
            return problem["brute_force_solution"]
        return BRUTE_FORCE_SOLUTIONS.get(
            self._problem_key(problem), self._FALLBACK_CODE
        )

    def _llm_solve(self, problem: Dict[str, Any]) -> str:
        """
        Sample a solution from the model and return it as cleaned code.

        Builds a system + user chat, samples up to 512 new tokens at
        temperature 0.8, and passes the decoded continuation through
        ``_clean_code``.
        """
        messages = [
            {"role": "system", "content": SOLVER_SYSTEM},
            {"role": "user", "content": self._user_prompt(problem)},
        ]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)

        outputs = self.model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.8,
            do_sample=True,
            # EOS doubles as the pad token id for generation.
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Skip the first ``inputs.shape[1]`` positions (the prompt) so only
        # the continuation is decoded.
        generated = self.tokenizer.decode(
            outputs[0][inputs.shape[1]:],
            skip_special_tokens=True,
        )

        return self._clean_code(generated)

    @staticmethod
    def _clean_code(raw: str) -> str:
        """
        Extract code from raw model output.

        Lines starting with a Markdown fence marker are always dropped.
        When the output contains fenced content, only the lines inside
        fences are kept, so prose before, between or after code blocks does
        not leak into the returned program. An unclosed fence keeps
        everything after its opening marker.

        When there is no fence, or the fenced region is empty (e.g. a stray
        closing fence after the code), every non-fence line is kept.
        """
        inside = []
        outside = []
        in_fence = False
        for line in raw.strip().split('\n'):
            if line.strip().startswith('```'):
                in_fence = not in_fence
                continue
            (inside if in_fence else outside).append(line)

        fenced = '\n'.join(inside).strip()
        if fenced:
            return fenced
        return '\n'.join(outside).strip()

    def build_prompt_text(self, problem: Dict[str, Any]) -> str:
        """
        Return the system prompt and the formatted user prompt joined by a
        blank line, as a single plain-text string.
        """
        return SOLVER_SYSTEM + "\n\n" + self._user_prompt(problem)