File size: 12,143 Bytes
9ab70a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
"""inference_smolagent.py - Run model with smolagents CodeAgent and LocalPythonExecutor"""
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from smolagents import CodeAgent, Tool
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import ChatMessage, MessageRole, Model
# DEBUG=1 in the environment enables verbose tracing of messages/prompt sizes.
DEBUG = int(os.environ.get("DEBUG", 0))
# Model's special tokens (from training)
START_TOOL_CALL = "<|start_tool_call|>"
END_TOOL_CALL = "<|end_tool_call|>"
START_TOOL_RESPONSE = "<|start_tool_response|>"
END_TOOL_RESPONSE = "<|end_tool_response|>"
# Smolagents expected tokens
SMOLAGENT_CODE_START = "<code>"
SMOLAGENT_CODE_END = "</code>"
class LocalCodeModel(Model):
    """
    Local model wrapper compatible with smolagents.
    Handles translation between smolagents format and model's training format.

    The checkpoint was trained to emit code between <|start_tool_call|> /
    <|end_tool_call|> and to read executor output between
    <|start_tool_response|> / <|end_tool_response|>; smolagents instead expects
    <code>...</code> blocks, so text is rewritten in both directions.
    """
    def __init__(self, model_id: str, device: str | None = None):
        # model_id: HuggingFace model ID or local path.
        # device: explicit torch device string; defaults to CUDA when available.
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): fix_mistral_regex is forwarded into the tokenizer
        # config -- presumably a workaround for this checkpoint's tokenizer;
        # confirm it is actually consumed.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()
        # Cache special token IDs for stopping: the last sub-token of
        # END_TOOL_CALL is treated as an additional EOS id in generate().
        self._end_tool_id = self.tokenizer.encode(END_TOOL_CALL, add_special_tokens=False)[-1]
    def _convert_prompt_to_model_format(self, prompt: str) -> str:
        """Convert smolagents prompt format to model's training format."""
        # Replace smolagents code markers with model's markers
        prompt = prompt.replace(SMOLAGENT_CODE_START, START_TOOL_CALL)
        prompt = prompt.replace(SMOLAGENT_CODE_END, END_TOOL_CALL)
        return prompt
    def _convert_response_to_smolagent_format(self, response: str) -> str:
        """Convert model's output format to smolagents expected format.

        Order matters here: marker substitution first, then orphan-tag
        cleanup, then progressively weaker fallbacks that try to recover a
        usable <code> block from malformed model output.
        """
        # Replace model's markers with smolagents markers
        response = response.replace(START_TOOL_CALL, SMOLAGENT_CODE_START)
        response = response.replace(END_TOOL_CALL, SMOLAGENT_CODE_END)
        response = response.replace(START_TOOL_RESPONSE, "")
        response = response.replace(END_TOOL_RESPONSE, "")
        # Clean up: remove orphan closing tags at start
        response = re.sub(r'^\s*</code>\s*', '', response)
        # Check if we have valid <code>...</code> block
        has_open = SMOLAGENT_CODE_START in response
        has_close = SMOLAGENT_CODE_END in response
        # If only closing tag, remove it
        if has_close and not has_open:
            response = response.replace(SMOLAGENT_CODE_END, "")
        # If no code markers, try to extract and wrap code
        if SMOLAGENT_CODE_START not in response:
            # Fallback 1: look for python code patterns in markdown fences
            code_match = re.search(r'```(?:python)?\s*(.*?)\s*```', response, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
                if code:
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
            else:
                # Fallback 2: keep only lines that look like Python statements
                lines = response.strip().split('\n')
                code_lines = [l for l in lines if any(kw in l for kw in ['def ', 'print(', 'return ', '= ', 'import ', 'for ', 'if ', 'while '])]
                if code_lines:
                    code = '\n'.join(code_lines)
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                else:
                    # Fallback 3: emit a placeholder block so the agent loop
                    # still receives well-formed output.
                    clean = response.strip()
                    if clean and not clean.startswith("Thoughts"):
                        response = f"Thoughts: Attempting execution\n{SMOLAGENT_CODE_START}\nprint('No valid code generated')\n{SMOLAGENT_CODE_END}"
        # Ensure closing tag exists if opening exists
        if SMOLAGENT_CODE_START in response and SMOLAGENT_CODE_END not in response:
            response = response + f"\n{SMOLAGENT_CODE_END}"
        return response
    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Generate response for message history (required by smolagents Model).

        ``grammar`` and ``tools_to_call_from`` are accepted for interface
        compatibility with smolagents but are not used here.
        """
        # Debug: show what messages are passed (including executor output)
        if DEBUG:
            print("\n[DEBUG] Messages received by model:")
            for i, msg in enumerate(messages):
                role = msg.role.value if hasattr(msg.role, "value") else msg.role
                content = str(msg.content)[:200] if msg.content else "<empty>"
                print(f" [{i}] {role}: {content}...")
            print()
        # Convert ChatMessage objects to dicts for chat template
        messages_dicts = []
        for msg in messages:
            if hasattr(msg, "role") and hasattr(msg, "content"):
                role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
                # Convert prompt format in content
                content = self._convert_prompt_to_model_format(content)
                # Wrap observations (executor output) in tool response tokens,
                # since that is the format the model saw during training.
                if "Observation:" in content or "Out:" in content:
                    # Extract the observation content
                    obs_match = re.search(r'(?:Observation:|Out:)\s*(.*)', content, re.DOTALL)
                    if obs_match:
                        obs_content = obs_match.group(1).strip()
                        content = f"{START_TOOL_RESPONSE}\n{obs_content}\n{END_TOOL_RESPONSE}"
                messages_dicts.append({"role": role, "content": content})
            else:
                # Pass through anything that is not ChatMessage-shaped.
                messages_dicts.append(msg)
        # Convert messages to prompt using chat template
        prompt = self.tokenizer.apply_chat_template(
            messages_dicts,
            add_generation_prompt=True,
            tokenize=False
        )
        # Check prompt length
        if DEBUG:
            full_tokens = self.tokenizer(prompt, return_tensors="pt")
            print(f"[DEBUG] Prompt length: {full_tokens['input_ids'].shape[1]} tokens (max: 2048)")
        # Truncate to fit model's context window (2048 tokens, leave room for
        # generation). NOTE(review): assumes a 2048-token context -- confirm
        # against the checkpoint's config.
        max_input_tokens = 1536  # Leave 512 for generation
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                # Stop on EOS or on the end-of-tool-call marker.
                eos_token_id=[self.tokenizer.eos_token_id, self._end_tool_id],
            )
        # Decode only the newly generated continuation; keep special tokens so
        # the marker-translation step below can see them.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
        # Handle stop sequences
        if stop_sequences:
            for seq in stop_sequences:
                if seq in response:
                    response = response.split(seq)[0]
        # Convert response format for smolagents
        response = self._convert_response_to_smolagent_format(response)
        return ChatMessage(role=MessageRole.ASSISTANT, content=response)
# Example tools
class CalculatorTool(Tool):
name = "calculator"
description = "Evaluates a mathematical expression and returns the result."
inputs = {
"expression": {
"type": "string",
"description": "The mathematical expression to evaluate (e.g., '2 + 2 * 3')"
}
}
output_type = "number"
def forward(self, expression: str) -> float:
# Safe eval for math expressions
allowed = set("0123456789+-*/().^ ")
if not all(c in allowed for c in expression):
raise ValueError("Invalid characters in expression")
return eval(expression.replace("^", "**"))
class FibonacciTool(Tool):
    name = "fibonacci"
    description = "Calculate the nth Fibonacci number."
    inputs = {
        "n": {
            "type": "integer",
            "description": "The position in Fibonacci sequence (0-indexed)"
        }
    }
    output_type = "integer"

    def forward(self, n: int) -> int:
        """Return the n-th Fibonacci number (0-indexed), iteratively in O(n).

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        # Invariant: after k iterations, prev == fib(k) and curr == fib(k + 1).
        prev, curr = 0, 1
        for _ in range(n):
            prev, curr = curr, prev + curr
        return prev
# Compact prompt templates for models with a small context window.
# The planning / managed-agent / final-answer sections are intentionally empty
# strings so smolagents adds no extra prompt text for those phases.
SHORT_PROMPT_TEMPLATES = {
    "system_prompt": """You solve tasks by writing Python code.
Rules:
- Write code inside <code> and </code> tags
- Use print() to show results
- Use final_answer(result) when done
Format:
Thoughts: your reasoning
<code>
# your code
</code>""",
    "planning": {
        "initial_plan": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {
        "task": "",
        "report": "",
    },
    "final_answer": {
        "pre_messages": "",
        "post_messages": "",
    },
}
def create_agent(
    model_id: str = "AutomatedScientist/pynb-73m-base",
    tools: list[Tool] | None = None,
    additional_authorized_imports: list[str] | None = None,
    max_steps: int = 5,
    use_short_prompt: bool = True,
) -> CodeAgent:
    """
    Create a CodeAgent with LocalPythonExecutor.

    Args:
        model_id: HuggingFace model ID or local path
        tools: List of tools to provide to the agent
        additional_authorized_imports: Extra imports to allow in executor
        max_steps: Maximum agent steps before stopping
        use_short_prompt: Use shorter system prompt for small context models

    Returns:
        Configured CodeAgent instance
    """
    model = LocalCodeModel(model_id)
    # Safe-by-default sandbox imports, extended with any caller-supplied extras.
    allowed_imports = [
        "math", "statistics", "random", "datetime",
        "collections", "itertools", "re", "json",
        "functools", "operator"
    ]
    if additional_authorized_imports:
        allowed_imports.extend(additional_authorized_imports)
    # Sandboxed Python executor for the agent's generated code.
    sandbox = LocalPythonExecutor(
        additional_authorized_imports=allowed_imports,
        max_print_outputs_length=10000,
    )
    # Assemble the agent configuration in one literal.
    config = {
        "tools": tools or [],
        "model": model,
        "executor": sandbox,
        "max_steps": max_steps,
        "verbosity_level": 1,
    }
    # Small-context models get the trimmed prompt templates.
    if use_short_prompt:
        config["prompt_templates"] = SHORT_PROMPT_TEMPLATES
    return CodeAgent(**config)
def run_task(agent: "CodeAgent", task: str) -> object:
    """
    Run a task through the agent, echoing the task and result to stdout.

    Args:
        agent: CodeAgent instance (anything exposing ``.run(task)``)
        task: Natural language task description

    Returns:
        Agent output (whatever ``agent.run`` returns)
    """
    # NOTE: the return annotation was the builtin function `any`, which is not
    # a type; `object` expresses "any value" correctly.
    print(f"\n{'='*60}")
    print(f"Task: {task}")
    print(f"{'='*60}\n")
    result = agent.run(task)
    print(f"\n{'='*60}")
    print(f"Result: {result}")
    print(f"{'='*60}\n")
    return result
# Script entry point: build an agent (local checkpoint preferred) and run one
# demo task, taken from argv[1] if provided.
if __name__ == "__main__":
    import sys
    # Use local checkpoint if available, otherwise HuggingFace
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"
    agent = create_agent(
        model_id=model_id,
        tools=[CalculatorTool(), FibonacciTool()],
        max_steps=8,
    )
    # Run example task
    task = sys.argv[1] if len(sys.argv) > 1 else "Calculate 15 * 7 + 23"
    try:
        result = run_task(agent, task)
    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        print(f"Error: {e}")
|