"""inference_smolagent.py - Run model with smolagents CodeAgent and LocalPythonExecutor."""
import os
import re
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from smolagents import CodeAgent, Tool
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import ChatMessage, MessageRole, Model

DEBUG = int(os.environ.get("DEBUG", 0))

# Model's special tokens (from training)
START_TOOL_CALL = "<|start_tool_call|>"
END_TOOL_CALL = "<|end_tool_call|>"
START_TOOL_RESPONSE = "<|start_tool_response|>"
END_TOOL_RESPONSE = "<|end_tool_response|>"

# Smolagents expected code-fence tokens.
# BUG FIX: these were empty strings. ``str.replace("", x)`` inserts ``x``
# between every character of the target, so every prompt/response conversion
# would have been corrupted. The intended markers (visible in the stripped
# system prompt "Write code inside ... and ... tags" and in the orphan-tag
# regex below) are smolagents' <code>...</code> fences.
SMOLAGENT_CODE_START = "<code>"
SMOLAGENT_CODE_END = "</code>"


class LocalCodeModel(Model):
    """
    Local model wrapper compatible with smolagents.

    Handles translation between the smolagents <code>...</code> format and the
    model's own <|start_tool_call|>...<|end_tool_call|> training format.
    """

    def __init__(self, model_id: str, device: str | None = None):
        """
        Load tokenizer + causal LM and move the model to the target device.

        Args:
            model_id: HuggingFace model ID or local checkpoint path.
            device: Torch device string; defaults to CUDA when available.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): ``fix_mistral_regex`` is not a documented
        # AutoTokenizer kwarg — confirm it is honored by this tokenizer class.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()
        # Cache the final token ID of the end-of-tool-call marker so
        # generation can stop as soon as a complete code block is emitted.
        self._end_tool_id = self.tokenizer.encode(END_TOOL_CALL, add_special_tokens=False)[-1]

    def _convert_prompt_to_model_format(self, prompt: str) -> str:
        """Convert smolagents prompt format to the model's training format."""
        prompt = prompt.replace(SMOLAGENT_CODE_START, START_TOOL_CALL)
        prompt = prompt.replace(SMOLAGENT_CODE_END, END_TOOL_CALL)
        return prompt

    def _convert_response_to_smolagent_format(self, response: str) -> str:
        """
        Convert the model's output format to the format smolagents expects.

        Maps the model's tool-call markers onto <code>/</code>, strips tool
        response markers, and falls back to heuristics (markdown fences, then
        code-looking lines) when the model emitted no markers at all.
        """
        # Replace model's markers with smolagents markers.
        response = response.replace(START_TOOL_CALL, SMOLAGENT_CODE_START)
        response = response.replace(END_TOOL_CALL, SMOLAGENT_CODE_END)
        response = response.replace(START_TOOL_RESPONSE, "")
        response = response.replace(END_TOOL_RESPONSE, "")

        # Clean up: remove an orphan closing tag at the very start.
        # BUG FIX: the original pattern had the tag stripped out
        # (r'^\s*\s*'), which only trimmed leading whitespace.
        response = re.sub(rf'^\s*{re.escape(SMOLAGENT_CODE_END)}\s*', '', response)

        # Check whether we have a valid <code>...</code> block.
        has_open = SMOLAGENT_CODE_START in response
        has_close = SMOLAGENT_CODE_END in response

        # If only a closing tag is present, drop it.
        if has_close and not has_open:
            response = response.replace(SMOLAGENT_CODE_END, "")

        # If no code markers, try to extract and wrap code.
        if SMOLAGENT_CODE_START not in response:
            # Look for python code inside a markdown fence first.
            code_match = re.search(r'```(?:python)?\s*(.*?)\s*```', response, re.DOTALL)
            if code_match:
                code = code_match.group(1).strip()
                if code:
                    response = (
                        f"Thoughts: Executing the code\n"
                        f"{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                    )
            else:
                # Look for any code-like content line by line.
                lines = response.strip().split('\n')
                code_lines = [
                    l for l in lines
                    if any(kw in l for kw in
                           ['def ', 'print(', 'return ', '= ', 'import ',
                            'for ', 'if ', 'while '])
                ]
                if code_lines:
                    code = '\n'.join(code_lines)
                    response = (
                        f"Thoughts: Executing the code\n"
                        f"{SMOLAGENT_CODE_START}\n{code}\n{SMOLAGENT_CODE_END}"
                    )
                else:
                    # Fallback: emit a placeholder block so the agent loop
                    # still receives a parseable step.
                    clean = response.strip()
                    if clean and not clean.startswith("Thoughts"):
                        response = (
                            f"Thoughts: Attempting execution\n"
                            f"{SMOLAGENT_CODE_START}\n"
                            f"print('No valid code generated')\n"
                            f"{SMOLAGENT_CODE_END}"
                        )

        # Ensure a closing tag exists whenever an opening tag does.
        if SMOLAGENT_CODE_START in response and SMOLAGENT_CODE_END not in response:
            response = response + f"\n{SMOLAGENT_CODE_END}"

        return response

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Generate a response for a message history (required by smolagents Model).

        Args:
            messages: Conversation so far (ChatMessage objects or dicts).
            stop_sequences: Strings at which to truncate the decoded output.
            grammar: Unused; accepted for interface compatibility.
            tools_to_call_from: Unused; accepted for interface compatibility.

        Returns:
            A ChatMessage with role ASSISTANT in smolagents <code> format.
        """
        # Debug: show what messages are passed (including executor output).
        if DEBUG:
            print("\n[DEBUG] Messages received by model:")
            for i, msg in enumerate(messages):
                role = msg.role.value if hasattr(msg.role, "value") else msg.role
                content = str(msg.content)[:200] if msg.content else ""
                print(f"  [{i}] {role}: {content}...")
            print()

        # Convert ChatMessage objects to dicts for the chat template.
        messages_dicts = []
        for msg in messages:
            if hasattr(msg, "role") and hasattr(msg, "content"):
                role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
                # Convert prompt format in content.
                content = self._convert_prompt_to_model_format(content)
                # Wrap observations (executor output) in tool response tokens
                # so the model sees them in its training format.
                if "Observation:" in content or "Out:" in content:
                    obs_match = re.search(r'(?:Observation:|Out:)\s*(.*)', content, re.DOTALL)
                    if obs_match:
                        obs_content = obs_match.group(1).strip()
                        content = f"{START_TOOL_RESPONSE}\n{obs_content}\n{END_TOOL_RESPONSE}"
                messages_dicts.append({"role": role, "content": content})
            else:
                messages_dicts.append(msg)

        # Convert messages to a prompt string using the chat template.
        prompt = self.tokenizer.apply_chat_template(
            messages_dicts,
            add_generation_prompt=True,
            tokenize=False
        )

        # Check prompt length.
        if DEBUG:
            full_tokens = self.tokenizer(prompt, return_tensors="pt")
            print(f"[DEBUG] Prompt length: {full_tokens['input_ids'].shape[1]} tokens (max: 2048)")

        # Truncate to fit the model's context window (2048 tokens total,
        # leaving room for generation).
        max_input_tokens = 1536  # Leave 512 for generation
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        ).to(self.device)

        # BUG FIX: pad_token_id may be None for models without a pad token;
        # fall back to eos so generate() does not warn/fail.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=pad_id,
                # Stop either at EOS or as soon as the tool-call block closes.
                eos_token_id=[self.tokenizer.eos_token_id, self._end_tool_id],
            )

        # Decode only the newly generated tokens.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=False)

        # Handle stop sequences.
        if stop_sequences:
            for seq in stop_sequences:
                if seq in response:
                    response = response.split(seq)[0]

        # Convert response format for smolagents.
        response = self._convert_response_to_smolagent_format(response)

        return ChatMessage(role=MessageRole.ASSISTANT, content=response)


# Example tools
class CalculatorTool(Tool):
    """Tool that evaluates a basic arithmetic expression."""

    name = "calculator"
    description = "Evaluates a mathematical expression and returns the result."
    inputs = {
        "expression": {
            "type": "string",
            "description": "The mathematical expression to evaluate (e.g., '2 + 2 * 3')"
        }
    }
    output_type = "number"

    def forward(self, expression: str) -> float:
        """
        Evaluate ``expression`` after validating its character set.

        Raises:
            ValueError: If the expression contains disallowed characters.
        """
        # SECURITY NOTE: eval() on input — the character allowlist below
        # blocks names/attribute access, but keep it strict if extended.
        allowed = set("0123456789+-*/().^ ")
        if not all(c in allowed for c in expression):
            raise ValueError("Invalid characters in expression")
        return eval(expression.replace("^", "**"))


class FibonacciTool(Tool):
    """Tool that computes Fibonacci numbers iteratively."""

    name = "fibonacci"
    description = "Calculate the nth Fibonacci number."
    inputs = {
        "n": {
            "type": "integer",
            "description": "The position in Fibonacci sequence (0-indexed)"
        }
    }
    output_type = "integer"

    def forward(self, n: int) -> int:
        """
        Return the nth Fibonacci number (0-indexed).

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        if n <= 1:
            return n
        a, b = 0, 1
        for _ in range(2, n + 1):
            a, b = b, a + b
        return b


# Compact prompt templates for small-context models.
# BUG FIX: the <code>/</code> tags had been stripped from the system prompt,
# leaving "Write code inside and tags"; restored here.
SHORT_PROMPT_TEMPLATES = {
    "system_prompt": """You solve tasks by writing Python code.

Rules:
- Write code inside <code> and </code> tags
- Use print() to show results
- Use final_answer(result) when done

Format:
Thoughts: your reasoning
<code>
# your code
</code>
""",
    "planning": {
        "initial_plan": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {
        "task": "",
        "report": "",
    },
    "final_answer": {
        "pre_messages": "",
        "post_messages": "",
    },
}


def create_agent(
    model_id: str = "AutomatedScientist/pynb-73m-base",
    tools: list[Tool] | None = None,
    additional_authorized_imports: list[str] | None = None,
    max_steps: int = 5,
    use_short_prompt: bool = True,
) -> CodeAgent:
    """
    Create a CodeAgent with a LocalPythonExecutor.

    Args:
        model_id: HuggingFace model ID or local path
        tools: List of tools to provide to the agent
        additional_authorized_imports: Extra imports to allow in executor
        max_steps: Maximum agent steps before stopping
        use_short_prompt: Use shorter system prompt for small context models

    Returns:
        Configured CodeAgent instance
    """
    model = LocalCodeModel(model_id)

    # Default authorized imports.
    authorized_imports = [
        "math", "statistics", "random", "datetime", "collections",
        "itertools", "re", "json", "functools", "operator"
    ]
    if additional_authorized_imports:
        authorized_imports.extend(additional_authorized_imports)

    # Create executor with sandbox.
    executor = LocalPythonExecutor(
        additional_authorized_imports=authorized_imports,
        max_print_outputs_length=10000,
    )

    # Build agent config.
    # NOTE(review): recent smolagents versions build their own executor and
    # may not accept an ``executor=`` kwarg — verify against the installed
    # version (the alternative is passing additional_authorized_imports
    # directly to CodeAgent).
    agent_kwargs = {
        "tools": tools or [],
        "model": model,
        "executor": executor,
        "max_steps": max_steps,
        "verbosity_level": 1,
    }

    # Use the short prompt for small-context models.
    if use_short_prompt:
        agent_kwargs["prompt_templates"] = SHORT_PROMPT_TEMPLATES

    agent = CodeAgent(**agent_kwargs)
    return agent


def run_task(agent: CodeAgent, task: str) -> Any:
    """
    Run a task through the agent.

    Args:
        agent: CodeAgent instance
        task: Natural language task description

    Returns:
        Agent output
    """
    print(f"\n{'='*60}")
    print(f"Task: {task}")
    print(f"{'='*60}\n")

    result = agent.run(task)

    print(f"\n{'='*60}")
    print(f"Result: {result}")
    print(f"{'='*60}\n")

    return result


if __name__ == "__main__":
    import sys

    # Use a local checkpoint if available, otherwise pull from HuggingFace.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"

    agent = create_agent(
        model_id=model_id,
        tools=[CalculatorTool(), FibonacciTool()],
        max_steps=8,
    )

    # Run the example task (or one supplied on the command line).
    task = sys.argv[1] if len(sys.argv) > 1 else "Calculate 15 * 7 + 23"
    try:
        result = run_task(agent, task)
    except Exception as e:
        print(f"Error: {e}")