# pynb-73m-base / inference_smolagent.py
# (repo upload metadata: AutomatedScientist, commit 9ab70a9, uploaded via huggingface_hub)
"""inference_smolagent.py - Run model with smolagents CodeAgent and LocalPythonExecutor"""
import os
import re
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from smolagents import CodeAgent, Tool
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import ChatMessage, MessageRole, Model
# Set DEBUG=1 in the environment to trace messages and prompt lengths in generate().
DEBUG = int(os.environ.get("DEBUG", 0))
# Model's special tokens (from training) — the model emits code between the
# tool-call markers and expects executor output between the response markers.
START_TOOL_CALL = "<|start_tool_call|>"
END_TOOL_CALL = "<|end_tool_call|>"
START_TOOL_RESPONSE = "<|start_tool_response|>"
END_TOOL_RESPONSE = "<|end_tool_response|>"
# Smolagents expected tokens — the CodeAgent parser looks for these tags.
SMOLAGENT_CODE_START = "<code>"
SMOLAGENT_CODE_END = "</code>"
class LocalCodeModel(Model):
    """
    Local model wrapper compatible with smolagents.

    Handles translation between the smolagents format (<code>...</code>,
    "Observation:" blocks) and the model's training format
    (<|start_tool_call|>/<|end_tool_call|> plus tool-response tokens).
    """

    def __init__(self, model_id: str, device: str | None = None):
        """Load tokenizer and model.

        Args:
            model_id: HuggingFace model ID or local path.
            device: Torch device string; auto-detects CUDA when None.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()
        # Cache the final token id of END_TOOL_CALL so generation can stop
        # as soon as the model closes a tool call.
        self._end_tool_id = self.tokenizer.encode(END_TOOL_CALL, add_special_tokens=False)[-1]

    def _convert_prompt_to_model_format(self, prompt: str) -> str:
        """Convert smolagents prompt format to the model's training format."""
        # Replace smolagents code markers with the model's markers.
        prompt = prompt.replace(SMOLAGENT_CODE_START, START_TOOL_CALL)
        prompt = prompt.replace(SMOLAGENT_CODE_END, END_TOOL_CALL)
        return prompt

    def _convert_response_to_smolagent_format(self, response: str) -> str:
        """Convert the model's output format to the smolagents expected format.

        Falls back to extracting markdown code fences or code-looking lines
        when the model emitted no explicit tool-call markers, and guarantees
        a closing </code> tag whenever an opening <code> tag is present.
        """
        # Replace the model's markers with smolagents markers; tool-response
        # tokens have no smolagents equivalent and are stripped.
        response = response.replace(START_TOOL_CALL, SMOLAGENT_CODE_START)
        response = response.replace(END_TOOL_CALL, SMOLAGENT_CODE_END)
        response = response.replace(START_TOOL_RESPONSE, "")
        response = response.replace(END_TOOL_RESPONSE, "")
        # Clean up: remove an orphan closing tag at the start.
        response = re.sub(r'^\s*</code>\s*', '', response)
        # Check whether we have a valid <code>...</code> block.
        has_open = SMOLAGENT_CODE_START in response
        has_close = SMOLAGENT_CODE_END in response
        # If only a closing tag remains, drop it.
        if has_close and not has_open:
            response = response.replace(SMOLAGENT_CODE_END, "")
        # If no code markers, try to extract and wrap code.
        if SMOLAGENT_CODE_START not in response:
            # Look for python code fenced in markdown.
            code_match = re.search(r'```(?:python)?\s*(.*?)\s*```', response, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
                if code:
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
            else:
                # Heuristic: keep lines that contain code-like keywords.
                lines = response.strip().split('\n')
                code_lines = [l for l in lines if any(kw in l for kw in ['def ', 'print(', 'return ', '= ', 'import ', 'for ', 'if ', 'while '])]
                if code_lines:
                    code = '\n'.join(code_lines)
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                else:
                    # Fallback: emit a placeholder block so the agent loop
                    # still receives parseable output.
                    clean = response.strip()
                    if clean and not clean.startswith("Thoughts"):
                        response = f"Thoughts: Attempting execution\n{SMOLAGENT_CODE_START}\nprint('No valid code generated')\n{SMOLAGENT_CODE_END}"
        # Ensure a closing tag exists whenever an opening one does.
        if SMOLAGENT_CODE_START in response and SMOLAGENT_CODE_END not in response:
            response = response + f"\n{SMOLAGENT_CODE_END}"
        return response

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Generate a response for the message history (required by smolagents Model).

        Args:
            messages: Chat history; observation messages ("Observation:"/"Out:")
                are re-wrapped in the model's tool-response tokens.
            stop_sequences: Strings at which the decoded response is cut.
            grammar: Unused; accepted for smolagents API compatibility.
            tools_to_call_from: Unused; accepted for smolagents API compatibility.

        Returns:
            A ChatMessage with role ASSISTANT, content in smolagents format.
        """
        # Debug: show what messages were passed (including executor output).
        if DEBUG:
            print("\n[DEBUG] Messages received by model:")
            for i, msg in enumerate(messages):
                role = msg.role.value if hasattr(msg.role, "value") else msg.role
                content = str(msg.content)[:200] if msg.content else "<empty>"
                print(f" [{i}] {role}: {content}...")
            print()
        # Convert ChatMessage objects to dicts for the chat template.
        messages_dicts = []
        for msg in messages:
            if hasattr(msg, "role") and hasattr(msg, "content"):
                role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
                # Translate smolagents code markers into model markers.
                content = self._convert_prompt_to_model_format(content)
                # Wrap observations (executor output) in tool-response tokens.
                if "Observation:" in content or "Out:" in content:
                    # Extract the observation text that follows the label.
                    obs_match = re.search(r'(?:Observation:|Out:)\s*(.*)', content, re.DOTALL)
                    if obs_match:
                        obs_content = obs_match.group(1).strip()
                        content = f"{START_TOOL_RESPONSE}\n{obs_content}\n{END_TOOL_RESPONSE}"
                messages_dicts.append({"role": role, "content": content})
            else:
                messages_dicts.append(msg)
        # Render the chat history into a single prompt string.
        prompt = self.tokenizer.apply_chat_template(
            messages_dicts,
            add_generation_prompt=True,
            tokenize=False
        )
        # Check prompt length.
        if DEBUG:
            full_tokens = self.tokenizer(prompt, return_tensors="pt")
            print(f"[DEBUG] Prompt length: {full_tokens['input_ids'].shape[1]} tokens (max: 2048)")
        # Truncate to fit the model's context window (2048 tokens, leave room
        # for generation). NOTE(review): default truncation keeps the *start*
        # of the prompt, so the most recent turns are dropped on long
        # histories — confirm this is intended.
        max_input_tokens = 1536  # Leave 512 for generation
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        ).to(self.device)
        # BUGFIX: fall back to EOS when the tokenizer defines no pad token;
        # passing pad_token_id=None makes generate() warn or fail on models
        # whose tokenizer (like many base LMs) ships without one.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=pad_id,
                eos_token_id=[self.tokenizer.eos_token_id, self._end_tool_id],
            )
        # Decode only the newly generated tokens (skip the echoed prompt).
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
        # Honor caller-provided stop sequences.
        if stop_sequences:
            for seq in stop_sequences:
                if seq in response:
                    response = response.split(seq)[0]
        # Convert response format for smolagents.
        response = self._convert_response_to_smolagent_format(response)
        return ChatMessage(role=MessageRole.ASSISTANT, content=response)
# Example tools
class CalculatorTool(Tool):
    """Arithmetic helper exposed to the agent as the `calculator` tool."""

    name = "calculator"
    description = "Evaluates a mathematical expression and returns the result."
    inputs = {
        "expression": {
            "type": "string",
            "description": "The mathematical expression to evaluate (e.g., '2 + 2 * 3')"
        }
    }
    output_type = "number"

    def forward(self, expression: str) -> float:
        """Evaluate *expression* after validating its character set."""
        # The whitelist contains no letters or underscores, so eval() cannot
        # resolve names or call functions — only arithmetic is reachable.
        permitted = "0123456789+-*/().^ "
        if set(expression) - set(permitted):
            raise ValueError("Invalid characters in expression")
        # '^' is accepted as a friendlier spelling of exponentiation.
        return eval(expression.replace("^", "**"))
class FibonacciTool(Tool):
    """Iterative Fibonacci lookup exposed as the `fibonacci` tool."""

    name = "fibonacci"
    description = "Calculate the nth Fibonacci number."
    inputs = {
        "n": {
            "type": "integer",
            "description": "The position in Fibonacci sequence (0-indexed)"
        }
    }
    output_type = "integer"

    def forward(self, n: int) -> int:
        """Return fib(n) with fib(0)=0, fib(1)=1.

        Raises:
            ValueError: if n is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        # Advance the pair n times; `prev` then holds fib(n).
        prev, curr = 0, 1
        for _ in range(n):
            prev, curr = curr, prev + curr
        return prev
# Compact prompt templates sized for a small context window: the planning,
# managed-agent, and final-answer sections are intentionally blank so only
# the system prompt consumes tokens.
SHORT_PROMPT_TEMPLATES = {
    "system_prompt": """You solve tasks by writing Python code.
Rules:
- Write code inside <code> and </code> tags
- Use print() to show results
- Use final_answer(result) when done
Format:
Thoughts: your reasoning
<code>
# your code
</code>""",
    # Planning prompts disabled to save context.
    "planning": {
        "initial_plan": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    # Managed-agent prompts disabled (no sub-agents are used here).
    "managed_agent": {
        "task": "",
        "report": "",
    },
    # No extra wrapping around the final answer.
    "final_answer": {
        "pre_messages": "",
        "post_messages": "",
    },
}
def create_agent(
    model_id: str = "AutomatedScientist/pynb-73m-base",
    tools: list[Tool] | None = None,
    additional_authorized_imports: list[str] | None = None,
    max_steps: int = 5,
    use_short_prompt: bool = True,
) -> CodeAgent:
    """
    Build a CodeAgent backed by a local model and a sandboxed Python executor.

    Args:
        model_id: HuggingFace model ID or local path.
        tools: Tools exposed to the agent (defaults to none).
        additional_authorized_imports: Extra module names the executor may import.
        max_steps: Upper bound on agent reasoning/execution steps.
        use_short_prompt: Swap in the compact system prompt for small-context models.

    Returns:
        A configured CodeAgent instance.
    """
    model = LocalCodeModel(model_id)
    # Baseline imports the sandbox always allows.
    allowed = [
        "math", "statistics", "random", "datetime",
        "collections", "itertools", "re", "json",
        "functools", "operator"
    ]
    allowed += additional_authorized_imports or []
    # Sandboxed executor for the agent's generated code.
    sandbox = LocalPythonExecutor(
        additional_authorized_imports=allowed,
        max_print_outputs_length=10000,
    )
    # Assemble the agent configuration.
    config = {
        "tools": tools or [],
        "model": model,
        "executor": sandbox,
        "max_steps": max_steps,
        "verbosity_level": 1,
    }
    # Small-context models get the compact prompt templates.
    if use_short_prompt:
        config["prompt_templates"] = SHORT_PROMPT_TEMPLATES
    return CodeAgent(**config)
def run_task(agent: "CodeAgent", task: str) -> Any:
    """
    Run a task through the agent, printing banners before and after.

    BUGFIX: the return annotation was the builtin function ``any`` (a
    callable, not a type); it is now ``typing.Any``. The parameter
    annotation is a string forward reference so the def evaluates even
    when imported in isolation.

    Args:
        agent: CodeAgent instance (anything exposing ``.run(task)``).
        task: Natural language task description.

    Returns:
        Whatever ``agent.run`` returns.
    """
    print(f"\n{'='*60}")
    print(f"Task: {task}")
    print(f"{'='*60}\n")
    result = agent.run(task)
    print(f"\n{'='*60}")
    print(f"Result: {result}")
    print(f"{'='*60}\n")
    return result
if __name__ == "__main__":
    import sys
    # Prefer a local checkpoint directory when present; otherwise pull the
    # published weights from the Hub.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"
    agent = create_agent(
        model_id=model_id,
        tools=[CalculatorTool(), FibonacciTool()],
        max_steps=8,
    )
    # The task comes from argv, falling back to a small arithmetic demo.
    task = sys.argv[1] if len(sys.argv) > 1 else "Calculate 15 * 7 + 23"
    try:
        result = run_task(agent, task)
    except Exception as e:
        # Top-level boundary: surface the failure instead of a traceback.
        print(f"Error: {e}")