File size: 12,143 Bytes
9ab70a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""inference_smolagent.py - Run model with smolagents CodeAgent and LocalPythonExecutor"""
import ast
import operator
import os
import re
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from smolagents import CodeAgent, Tool
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import ChatMessage, MessageRole, Model

# Set DEBUG=1 in the environment to print message traces and prompt lengths.
DEBUG = int(os.environ.get("DEBUG", 0))

# Model's special tokens (from training)
START_TOOL_CALL = "<|start_tool_call|>"
END_TOOL_CALL = "<|end_tool_call|>"
START_TOOL_RESPONSE = "<|start_tool_response|>"
END_TOOL_RESPONSE = "<|end_tool_response|>"

# Smolagents expected tokens
# (CodeAgent parses these tags to find executable code in the model's reply).
SMOLAGENT_CODE_START = "<code>"
SMOLAGENT_CODE_END = "</code>"


class LocalCodeModel(Model):
    """
    Local model wrapper compatible with smolagents.

    Handles translation between smolagents format and model's training format:
    the checkpoint was trained with <|start_tool_call|>/<|end_tool_call|> style
    markers, while smolagents' CodeAgent expects <code>/</code> blocks.
    """

    def __init__(self, model_id: str, device: str | None = None) -> None:
        """Load tokenizer and weights from *model_id*; auto-select CUDA/CPU when *device* is None."""
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): fix_mistral_regex looks tokenizer-specific — confirm AutoTokenizer accepts it.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()

        # Cache special token IDs for stopping
        # (the last sub-token of END_TOOL_CALL doubles as an extra EOS id in generate()).
        self._end_tool_id = self.tokenizer.encode(END_TOOL_CALL, add_special_tokens=False)[-1]

    def _convert_prompt_to_model_format(self, prompt: str) -> str:
        """Convert smolagents prompt format to model's training format."""
        # Replace smolagents code markers with model's markers
        prompt = prompt.replace(SMOLAGENT_CODE_START, START_TOOL_CALL)
        prompt = prompt.replace(SMOLAGENT_CODE_END, END_TOOL_CALL)
        return prompt

    def _convert_response_to_smolagent_format(self, response: str) -> str:
        """Convert model's output format to smolagents expected format.

        Besides straight marker substitution, this repairs malformed output
        (orphan closing tags, markdown fences, bare code lines) so that the
        agent receives a parseable <code>...</code> block whenever possible.
        """
        # Replace model's markers with smolagents markers
        response = response.replace(START_TOOL_CALL, SMOLAGENT_CODE_START)
        response = response.replace(END_TOOL_CALL, SMOLAGENT_CODE_END)
        response = response.replace(START_TOOL_RESPONSE, "")
        response = response.replace(END_TOOL_RESPONSE, "")

        # Clean up: remove orphan closing tags at start
        response = re.sub(r'^\s*</code>\s*', '', response)

        # Check if we have valid <code>...</code> block
        has_open = SMOLAGENT_CODE_START in response
        has_close = SMOLAGENT_CODE_END in response

        # If only closing tag, remove it
        if has_close and not has_open:
            response = response.replace(SMOLAGENT_CODE_END, "")

        # If no code markers, try to extract and wrap code
        if SMOLAGENT_CODE_START not in response:
            # Look for python code patterns in markdown
            code_match = re.search(r'```(?:python)?\s*(.*?)\s*```', response, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
                if code:
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
            else:
                # Look for any code-like content
                # (heuristic keyword scan: keep only lines that resemble Python statements)
                lines = response.strip().split('\n')
                code_lines = [l for l in lines if any(kw in l for kw in ['def ', 'print(', 'return ', '= ', 'import ', 'for ', 'if ', 'while '])]
                if code_lines:
                    code = '\n'.join(code_lines)
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                else:
                    # Fallback: wrap entire response as code if it looks like code
                    # (a placeholder print keeps the agent loop moving instead of erroring out)
                    clean = response.strip()
                    if clean and not clean.startswith("Thoughts"):
                        response = f"Thoughts: Attempting execution\n{SMOLAGENT_CODE_START}\nprint('No valid code generated')\n{SMOLAGENT_CODE_END}"

        # Ensure closing tag exists if opening exists
        if SMOLAGENT_CODE_START in response and SMOLAGENT_CODE_END not in response:
            response = response + f"\n{SMOLAGENT_CODE_END}"

        return response

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Generate response for message history (required by smolagents Model).

        Pipeline: translate message content to the model's marker format,
        render via the tokenizer's chat template, sample a completion, then
        translate the completion back to smolagents' <code> format.
        """
        # Debug: show what messages are passed (including executor output)
        if DEBUG:
            print("\n[DEBUG] Messages received by model:")
            for i, msg in enumerate(messages):
                role = msg.role.value if hasattr(msg.role, "value") else msg.role
                content = str(msg.content)[:200] if msg.content else "<empty>"
                print(f"  [{i}] {role}: {content}...")
            print()

        # Convert ChatMessage objects to dicts for chat template
        messages_dicts = []
        for msg in messages:
            if hasattr(msg, "role") and hasattr(msg, "content"):
                role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
                # Convert prompt format in content
                content = self._convert_prompt_to_model_format(content)
                # Wrap observations (executor output) in tool response tokens
                if "Observation:" in content or "Out:" in content:
                    # Extract the observation content
                    obs_match = re.search(r'(?:Observation:|Out:)\s*(.*)', content, re.DOTALL)
                    if obs_match:
                        obs_content = obs_match.group(1).strip()
                        content = f"{START_TOOL_RESPONSE}\n{obs_content}\n{END_TOOL_RESPONSE}"
                messages_dicts.append({"role": role, "content": content})
            else:
                # Anything that does not look like a ChatMessage is passed through untouched.
                messages_dicts.append(msg)

        # Convert messages to prompt using chat template
        prompt = self.tokenizer.apply_chat_template(
            messages_dicts,
            add_generation_prompt=True,
            tokenize=False
        )

        # Check prompt length
        if DEBUG:
            full_tokens = self.tokenizer(prompt, return_tensors="pt")
            print(f"[DEBUG] Prompt length: {full_tokens['input_ids'].shape[1]} tokens (max: 2048)")

        # Truncate to fit model's context window (2048 tokens, leave room for generation)
        max_input_tokens = 1536  # Leave 512 for generation
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=[self.tokenizer.eos_token_id, self._end_tool_id],
            )

        # Decode only the newly generated suffix; special tokens are kept so
        # the marker-translation step below can see them.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)

        # Handle stop sequences
        # (truncate at each stop sequence that appears in the output)
        if stop_sequences:
            for seq in stop_sequences:
                if seq in response:
                    response = response.split(seq)[0]

        # Convert response format for smolagents
        response = self._convert_response_to_smolagent_format(response)

        return ChatMessage(role=MessageRole.ASSISTANT, content=response)


# Example tools
class CalculatorTool(Tool):
    """Tool that safely evaluates arithmetic expressions by walking the AST.

    Replaces the previous character-whitelist + eval() approach: eval() could
    still be abused within the whitelist (e.g. pathological exponent chains)
    and leaked SyntaxError on malformed input. This version only ever executes
    numeric literals and basic arithmetic, and raises ValueError for anything
    else — same interface and same error message as before.
    """

    name = "calculator"
    description = "Evaluates a mathematical expression and returns the result."
    inputs = {
        "expression": {
            "type": "string",
            "description": "The mathematical expression to evaluate (e.g., '2 + 2 * 3')"
        }
    }
    output_type = "number"

    # Arithmetic operators permitted in an expression; anything else is rejected.
    _BINARY_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def forward(self, expression: str) -> float:
        """Evaluate *expression* and return the numeric result.

        Raises:
            ValueError: for empty, malformed, or non-arithmetic input.
        """
        # Backward compatibility: '^' was historically accepted as exponentiation.
        expression = expression.replace("^", "**").strip()
        if not expression:
            raise ValueError("Invalid characters in expression")
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as err:
            # Normalize parse failures to the error type callers already handle.
            raise ValueError("Invalid characters in expression") from err
        return self._eval_node(tree.body)

    def _eval_node(self, node):
        """Recursively evaluate one AST node, allowing only numbers and arithmetic."""
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in self._BINARY_OPS:
            left = self._eval_node(node.left)
            right = self._eval_node(node.right)
            return self._BINARY_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in self._UNARY_OPS:
            return self._UNARY_OPS[type(node.op)](self._eval_node(node.operand))
        # Names, calls, attributes, comparisons, etc. are all rejected.
        raise ValueError("Invalid characters in expression")


class FibonacciTool(Tool):
    """Tool that computes Fibonacci numbers iteratively (F(0)=0, F(1)=1)."""

    name = "fibonacci"
    description = "Calculate the nth Fibonacci number."
    inputs = {
        "n": {
            "type": "integer",
            "description": "The position in Fibonacci sequence (0-indexed)"
        }
    }
    output_type = "integer"

    def forward(self, n: int) -> int:
        """Return the nth Fibonacci number; raises ValueError for negative n."""
        if n < 0:
            raise ValueError("n must be non-negative")
        # Advance the (F(k), F(k+1)) pair n times; `prev` then holds F(n).
        prev, curr = 0, 1
        for _ in range(n):
            prev, curr = curr, prev + curr
        return prev


# Compact prompt templates for small-context checkpoints: the stock smolagents
# system prompt would not fit this model's 2048-token window, so only a minimal
# system prompt is kept and every optional template section is blanked out.
SHORT_PROMPT_TEMPLATES = {
    "system_prompt": """You solve tasks by writing Python code.

Rules:
- Write code inside <code> and </code> tags
- Use print() to show results
- Use final_answer(result) when done

Format:
Thoughts: your reasoning
<code>
# your code
</code>""",
    # Planning, managed-agent, and final-answer templates are deliberately
    # empty strings to keep the rendered prompt short.
    "planning": {
        "initial_plan": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {
        "task": "",
        "report": "",
    },
    "final_answer": {
        "pre_messages": "",
        "post_messages": "",
    },
}


def create_agent(
    model_id: str = "AutomatedScientist/pynb-73m-base",
    tools: list[Tool] | None = None,
    additional_authorized_imports: list[str] | None = None,
    max_steps: int = 5,
    use_short_prompt: bool = True,
) -> CodeAgent:
    """
    Build a CodeAgent backed by a LocalCodeModel and a sandboxed
    LocalPythonExecutor.

    Args:
        model_id: HuggingFace model ID or local path
        tools: List of tools to provide to the agent
        additional_authorized_imports: Extra imports to allow in executor
        max_steps: Maximum agent steps before stopping
        use_short_prompt: Use shorter system prompt for small context models

    Returns:
        Configured CodeAgent instance
    """
    model = LocalCodeModel(model_id)

    # Stdlib modules the sandboxed executor may import, plus any caller extras.
    authorized_imports = [
        "math", "statistics", "random", "datetime",
        "collections", "itertools", "re", "json",
        "functools", "operator"
    ]
    authorized_imports.extend(additional_authorized_imports or [])

    # Sandboxed local executor that will run the agent's generated code.
    executor = LocalPythonExecutor(
        additional_authorized_imports=authorized_imports,
        max_print_outputs_length=10000,
    )

    # Small-context models get the trimmed prompt templates.
    template_override = {"prompt_templates": SHORT_PROMPT_TEMPLATES} if use_short_prompt else {}

    return CodeAgent(
        tools=tools or [],
        model=model,
        executor=executor,
        max_steps=max_steps,
        verbosity_level=1,
        **template_override,
    )


def run_task(agent: "CodeAgent", task: str) -> Any:
    """
    Run a task through the agent, printing banners around task and result.

    Args:
        agent: CodeAgent instance
        task: Natural language task description

    Returns:
        Agent output (whatever ``agent.run`` returns)
    """
    # Fix: the original annotation was ``-> any``, which references the builtin
    # ``any()`` function rather than a type; ``typing.Any`` is the intended one.
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Task: {task}")
    print(f"{banner}\n")

    result = agent.run(task)

    print(f"\n{banner}")
    print(f"Result: {result}")
    print(f"{banner}\n")

    return result


if __name__ == "__main__":
    import sys

    # Use local checkpoint if available, otherwise HuggingFace
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"

    # Demo agent wired with the two example tools defined above.
    agent = create_agent(
        model_id=model_id,
        tools=[CalculatorTool(), FibonacciTool()],
        max_steps=8,
    )

    # Run example task
    # (task comes from the first CLI argument when given; otherwise a small arithmetic demo)
    task = sys.argv[1] if len(sys.argv) > 1 else "Calculate 15 * 7 + 23"
    try:
        result = run_task(agent, task)
    except Exception as e:
        # Broad catch so a demo failure prints a message instead of a traceback.
        print(f"Error: {e}")