|
|
"""inference_smolagent.py - Run model with smolagents CodeAgent and LocalPythonExecutor""" |
|
|
import os |
|
|
import re |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from smolagents import CodeAgent, Tool |
|
|
from smolagents.local_python_executor import LocalPythonExecutor |
|
|
from smolagents.models import ChatMessage, MessageRole, Model |
|
|
|
|
|
# Set DEBUG=1 in the environment to print message histories and prompt
# lengths during generation.
DEBUG = int(os.environ.get("DEBUG", 0))


# Markers used by the model's prompt format to delimit tool calls and tool
# responses (see the _convert_* helpers on LocalCodeModel, which translate
# between these and the smolagents markers below).
START_TOOL_CALL = "<|start_tool_call|>"
END_TOOL_CALL = "<|end_tool_call|>"
START_TOOL_RESPONSE = "<|start_tool_response|>"
END_TOOL_RESPONSE = "<|end_tool_response|>"


# Markers that smolagents' CodeAgent expects around generated code blocks.
SMOLAGENT_CODE_START = "<code>"
SMOLAGENT_CODE_END = "</code>"
|
|
|
|
|
|
|
|
class LocalCodeModel(Model):
    """
    Local model wrapper compatible with smolagents.

    Handles translation between smolagents format and model's training format:
    smolagents fences code in <code>...</code>, while the model is prompted
    with <|start_tool_call|>/<|end_tool_call|> markers.
    """

    def __init__(self, model_id: str, device: str | None = None):
        """Load tokenizer and model.

        Args:
            model_id: HuggingFace model ID or local path.
            device: Torch device string; defaults to CUDA when available.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()

        # Last token id of the end-of-tool-call marker: used as an extra EOS
        # so generation stops as soon as the model closes a tool call.
        self._end_tool_id = self.tokenizer.encode(END_TOOL_CALL, add_special_tokens=False)[-1]

    def _convert_prompt_to_model_format(self, prompt: str) -> str:
        """Convert smolagents prompt format to model's training format."""
        prompt = prompt.replace(SMOLAGENT_CODE_START, START_TOOL_CALL)
        prompt = prompt.replace(SMOLAGENT_CODE_END, END_TOOL_CALL)
        return prompt

    def _convert_response_to_smolagent_format(self, response: str) -> str:
        """Convert model's output format to smolagents expected format.

        Also repairs malformed output: strips stray/unbalanced tags, and when
        no tool-call markers were emitted, falls back to extracting a markdown
        code fence, then code-looking lines, then a harmless placeholder so
        the agent's parser always receives a well-formed <code> block.
        """
        response = response.replace(START_TOOL_CALL, SMOLAGENT_CODE_START)
        response = response.replace(END_TOOL_CALL, SMOLAGENT_CODE_END)
        response = response.replace(START_TOOL_RESPONSE, "")
        response = response.replace(END_TOOL_RESPONSE, "")

        # Drop a stray closing tag at the very start of the response.
        response = re.sub(r'^\s*</code>\s*', '', response)

        has_open = SMOLAGENT_CODE_START in response
        has_close = SMOLAGENT_CODE_END in response

        # A closing tag with no opening tag would confuse the parser; remove it.
        if has_close and not has_open:
            response = response.replace(SMOLAGENT_CODE_END, "")

        if SMOLAGENT_CODE_START not in response:
            # Fallback 1: extract a markdown code fence if present.
            code_match = re.search(r'```(?:python)?\s*(.*?)\s*```', response, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
                if code:
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
            else:
                # Fallback 2: keep only lines that look like Python code.
                lines = response.strip().split('\n')
                code_lines = [l for l in lines if any(kw in l for kw in ['def ', 'print(', 'return ', '= ', 'import ', 'for ', 'if ', 'while '])]
                if code_lines:
                    code = '\n'.join(code_lines)
                    response = f"Thoughts: Executing the code\n{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                else:
                    # Fallback 3: emit a placeholder so the step still parses.
                    clean = response.strip()
                    if clean and not clean.startswith("Thoughts"):
                        response = f"Thoughts: Attempting execution\n{SMOLAGENT_CODE_START}\nprint('No valid code generated')\n{SMOLAGENT_CODE_END}"

        # Close an unterminated code block.
        if SMOLAGENT_CODE_START in response and SMOLAGENT_CODE_END not in response:
            response = response + f"\n{SMOLAGENT_CODE_END}"

        return response

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Generate response for message history (required by smolagents Model).

        Args:
            messages: Chat history from the agent.
            stop_sequences: Strings at which the raw generation is truncated.
            grammar: Unused; accepted for interface compatibility.
            tools_to_call_from: Unused; accepted for interface compatibility.

        Returns:
            Assistant ChatMessage with content in smolagents <code> format.
        """
        if DEBUG:
            print("\n[DEBUG] Messages received by model:")
            for i, msg in enumerate(messages):
                role = msg.role.value if hasattr(msg.role, "value") else msg.role
                content = str(msg.content)[:200] if msg.content else "<empty>"
                print(f"  [{i}] {role}: {content}...")
            print()

        # Rewrite each message into the model's training format.
        messages_dicts = []
        for msg in messages:
            if hasattr(msg, "role") and hasattr(msg, "content"):
                role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                content = msg.content if isinstance(msg.content, str) else str(msg.content or "")

                content = self._convert_prompt_to_model_format(content)

                # Executor output ("Observation:"/"Out:") becomes a tool response.
                if "Observation:" in content or "Out:" in content:
                    obs_match = re.search(r'(?:Observation:|Out:)\s*(.*)', content, re.DOTALL)
                    if obs_match:
                        obs_content = obs_match.group(1).strip()
                        content = f"{START_TOOL_RESPONSE}\n{obs_content}\n{END_TOOL_RESPONSE}"
                messages_dicts.append({"role": role, "content": content})
            else:
                # Already a plain dict (or other chat-template-compatible item).
                messages_dicts.append(msg)

        prompt = self.tokenizer.apply_chat_template(
            messages_dicts,
            add_generation_prompt=True,
            tokenize=False
        )

        if DEBUG:
            full_tokens = self.tokenizer(prompt, return_tensors="pt")
            print(f"[DEBUG] Prompt length: {full_tokens['input_ids'].shape[1]} tokens (max: 2048)")

        # Truncate so prompt (1536) + max_new_tokens (512) fits a 2048 context.
        max_input_tokens = 1536
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                # Many causal-LM tokenizers define no pad token; a None
                # pad_token_id makes generate() warn or fail, so fall back
                # to the EOS token (standard practice for decoder-only LMs).
                pad_token_id=(
                    self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.eos_token_id
                ),
                eos_token_id=[self.tokenizer.eos_token_id, self._end_tool_id],
            )

        # Decode only the newly generated tokens, keeping special markers
        # so the format conversion below can see them.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)

        if stop_sequences:
            for seq in stop_sequences:
                if seq in response:
                    response = response.split(seq)[0]

        response = self._convert_response_to_smolagent_format(response)

        return ChatMessage(role=MessageRole.ASSISTANT, content=response)
|
|
|
|
|
|
|
|
|
|
|
class CalculatorTool(Tool):
    """smolagents tool that safely evaluates basic arithmetic expressions."""

    name = "calculator"
    description = "Evaluates a mathematical expression and returns the result."
    inputs = {
        "expression": {
            "type": "string",
            "description": "The mathematical expression to evaluate (e.g., '2 + 2 * 3')"
        }
    }
    output_type = "number"

    def forward(self, expression: str) -> float:
        """Evaluate *expression* using + - * / ^ ( ) and numeric literals.

        Args:
            expression: Arithmetic expression; '^' is treated as power.

        Returns:
            The numeric result (int or float).

        Raises:
            ValueError: if the expression contains disallowed characters or
                non-arithmetic constructs.
            SyntaxError: if the expression is malformed.
        """
        import ast
        import operator

        allowed = set("0123456789+-*/().^ ")
        if not all(c in allowed for c in expression):
            raise ValueError("Invalid characters in expression")

        # Evaluate via the AST instead of eval() so only arithmetic nodes can
        # execute — defense in depth on top of the character whitelist, since
        # eval() on agent-generated text is unsafe by default.
        bin_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.Pow: operator.pow,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def _eval(node):
            # Recursively reduce the expression tree to a number.
            if isinstance(node, ast.Expression):
                return _eval(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in bin_ops:
                return bin_ops[type(node.op)](_eval(node.left), _eval(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](_eval(node.operand))
            raise ValueError("Unsupported expression")

        # ast.parse raises SyntaxError on malformed input, matching eval().
        tree = ast.parse(expression.replace("^", "**"), mode="eval")
        return _eval(tree)
|
|
|
|
|
|
|
|
class FibonacciTool(Tool):
    """smolagents tool computing Fibonacci numbers iteratively."""

    name = "fibonacci"
    description = "Calculate the nth Fibonacci number."
    inputs = {
        "n": {
            "type": "integer",
            "description": "The position in Fibonacci sequence (0-indexed)"
        }
    }
    output_type = "integer"

    def forward(self, n: int) -> int:
        """Return the n-th Fibonacci number (F(0)=0, F(1)=1).

        Raises:
            ValueError: if n is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        if n <= 1:
            # F(0) = 0 and F(1) = 1 fall out directly.
            return n
        prev, curr = 0, 1
        # Advance the pair (F(k-1), F(k)) from k=1 up to k=n.
        for _ in range(n - 1):
            prev, curr = curr, prev + curr
        return curr
|
|
|
|
|
|
|
|
# Compact prompt templates for small-context models: a short system prompt,
# with the planning / managed-agent / final-answer templates left empty so
# those extra prompt sections consume no context tokens.
SHORT_PROMPT_TEMPLATES = {
    "system_prompt": """You solve tasks by writing Python code.

Rules:
- Write code inside <code> and </code> tags
- Use print() to show results
- Use final_answer(result) when done

Format:
Thoughts: your reasoning
<code>
# your code
</code>""",
    "planning": {
        "initial_plan": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {
        "task": "",
        "report": "",
    },
    "final_answer": {
        "pre_messages": "",
        "post_messages": "",
    },
}
|
|
|
|
|
|
|
|
def create_agent(
    model_id: str = "AutomatedScientist/pynb-73m-base",
    tools: list[Tool] | None = None,
    additional_authorized_imports: list[str] | None = None,
    max_steps: int = 5,
    use_short_prompt: bool = True,
) -> CodeAgent:
    """
    Create a CodeAgent with LocalPythonExecutor.

    Args:
        model_id: HuggingFace model ID or local path
        tools: List of tools to provide to the agent
        additional_authorized_imports: Extra imports to allow in executor
        max_steps: Maximum agent steps before stopping
        use_short_prompt: Use shorter system prompt for small context models

    Returns:
        Configured CodeAgent instance
    """
    wrapped_model = LocalCodeModel(model_id)

    # Safe stdlib modules the sandboxed executor may import, plus any extras
    # requested by the caller.
    allowed_modules = [
        "math", "statistics", "random", "datetime",
        "collections", "itertools", "re", "json",
        "functools", "operator",
    ]
    allowed_modules.extend(additional_authorized_imports or [])

    sandbox = LocalPythonExecutor(
        additional_authorized_imports=allowed_modules,
        max_print_outputs_length=10000,
    )

    agent_kwargs = dict(
        tools=tools or [],
        model=wrapped_model,
        executor=sandbox,
        max_steps=max_steps,
        verbosity_level=1,
    )
    # Swap in the compact templates for models with small context windows.
    if use_short_prompt:
        agent_kwargs["prompt_templates"] = SHORT_PROMPT_TEMPLATES

    return CodeAgent(**agent_kwargs)
|
|
|
|
|
|
|
|
def run_task(agent: CodeAgent, task: str) -> object:
    """
    Run a task through the agent, printing banners around task and result.

    Args:
        agent: CodeAgent instance
        task: Natural language task description

    Returns:
        Agent output (whatever ``agent.run`` returns)
    """
    # Note: the return annotation was `any` (the builtin function, not a
    # type); `object` is the correct "anything" annotation.
    print(f"\n{'='*60}")
    print(f"Task: {task}")
    print(f"{'='*60}\n")

    result = agent.run(task)

    print(f"\n{'='*60}")
    print(f"Result: {result}")
    print(f"{'='*60}\n")

    return result
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Prefer a local fine-tuned checkpoint when present; otherwise fall back
    # to the base model on the HuggingFace Hub.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"

    agent = create_agent(
        model_id=model_id,
        tools=[CalculatorTool(), FibonacciTool()],
        max_steps=8,
    )

    # Task comes from the first CLI argument, defaulting to simple arithmetic.
    task = sys.argv[1] if len(sys.argv) > 1 else "Calculate 15 * 7 + 23"
    try:
        result = run_task(agent, task)
    except Exception as e:
        # NOTE(review): broad catch keeps the CLI from dumping a traceback,
        # but consider logging the full traceback for easier debugging.
        print(f"Error: {e}")