# Nexa_Labs examples/demo_agent.py
"""Interactive demo showing the NexaSci agent reasoning, using tools, and citing sources."""
from __future__ import annotations
import json
import sys
import time
from pathlib import Path
# Make the repository root importable (so `agent` and `tools` resolve)
# when this script is run directly from the examples/ directory.
project_root = Path(__file__).resolve().parents[1]
_root_str = str(project_root)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
from agent.controller import AgentController, AgentRunResult
from agent.client_llm import Message
def print_section(title: str, char: str = "=") -> None:
    """Print *title* centered between two 80-column rules made of *char*."""
    rule = char * 80
    print("\n" + rule)
    print(f"{title:^80}")
    print(rule + "\n")
def print_tool_call(turn: int, tool_name: str, args: dict) -> None:
    """Announce one tool invocation: its name plus JSON-formatted arguments."""
    print(f"\n[Turn {turn}] πŸ”§ Tool Call: {tool_name}")
    print(" Arguments: " + json.dumps(args, indent=4))
def print_tool_result(result: dict) -> None:
    """Pretty-print a single tool result to stdout.

    A dict containing an ``"error"`` key is rendered as one error line;
    anything else is shown as indented JSON, truncated past 500 characters
    to keep the console readable.
    """
    if "error" in result:
        print(f" ❌ Error: {result['error']}")
        return
    print(" βœ“ Success")
    rendered = json.dumps(result, indent=2)
    if len(rendered) > 500:
        rendered = rendered[:500] + "\n ... (truncated)"
    print(f" Output: {rendered}")
def print_agent_response(turn: int, response: str) -> None:
    """Show the agent's reply, truncated to its first 500 characters."""
    print(f"\n[Turn {turn}] πŸ€– Agent Response:")
    suffix = "..." if len(response) > 500 else ""
    print(" " + response[:500] + suffix)
def run_demo(controller: AgentController, prompt: str) -> AgentRunResult:
    """Run the agent with detailed step-by-step output.

    Drives the conversation loop manually (rather than via the controller's
    own run method) so every generation, tool call, and tool result can be
    printed as it happens.

    Args:
        controller: Configured agent controller; supplies the LLM client,
            tool dispatch (``_dispatch_tool``), and the ``max_turns`` budget.
        prompt: User prompt that seeds the conversation.

    Returns:
        AgentRunResult holding the parsed final payload, the full message
        transcript, and every tool result produced along the way.

    Raises:
        RuntimeError: If no ``~~~final ...~~~`` block appears within
            ``controller.max_turns`` turns.
    """
    # Deferred imports keep module import cheap; hoisted out of the turn
    # loop (the original re-imported and re-compiled these every turn).
    import re
    from tools.schemas import ToolCall

    toolcall_regex = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL)
    final_regex = re.compile(r"~~~final(.*?)~~~", re.DOTALL)

    print_section("NexaSci Agent Demo", "=")
    print(f"πŸ“ User Prompt: {prompt}\n")
    messages: list[Message] = [Message(role="user", content=prompt)]
    tool_results: list = []

    for turn in range(1, controller.max_turns + 1):
        print_section(f"Turn {turn}", "-")

        print("πŸ”„ Generating response...")
        start_time = time.time()
        try:
            response_text = controller.llm_client.generate(messages)
            elapsed = time.time() - start_time
            print(f" ⏱️ Generated in {elapsed:.2f}s")
        except Exception as e:
            elapsed = time.time() - start_time
            print(f" ❌ Generation failed after {elapsed:.2f}s")
            print(f" Error: {e}")
            raise

        print_agent_response(turn, response_text)
        messages.append(Message(role="assistant", content=response_text))

        # Extract any ~~~toolcall ...~~~ blocks emitted by the model.
        tool_calls = []
        for raw in toolcall_regex.findall(response_text):
            snippet = raw.strip()
            if not snippet:
                continue
            try:
                tool_calls.append(ToolCall(**json.loads(snippet)))
            except Exception as exc:
                # Malformed payloads are still skipped, but no longer
                # silently (the original swallowed every exception).
                print(f" ⚠️ Skipping malformed tool call: {exc}")

        if tool_calls:
            print(f"\n πŸ” Found {len(tool_calls)} tool call(s)")
            for call in tool_calls:
                print_tool_call(turn, call.tool, call.arguments)
                result = controller._dispatch_tool(call)
                tool_results.append(result)
                print_tool_result(result.output)
                # Feed the tool output back into the conversation so the
                # model can use it on the next turn.
                messages.append(
                    Message(
                        role="tool",
                        content=json.dumps(result.output, ensure_ascii=False),
                    )
                )
            continue  # give the model another turn to consume the results

        # No tool calls this turn: look for the ~~~final ...~~~ block.
        final_payload = None
        final_match = final_regex.search(response_text)
        if final_match:
            snippet = final_match.group(1).strip()
            if snippet:
                try:
                    final_payload = json.loads(snippet)
                except json.JSONDecodeError:
                    # Final marker present but unparseable: treat as empty.
                    final_payload = {}

        if final_payload is not None:
            print_section("Final Response", "=")
            print(json.dumps(final_payload, indent=2))
            return AgentRunResult(
                final_response=final_payload,
                messages=messages,
                tool_results=tool_results,
            )

    raise RuntimeError(f"Agent did not produce a final response within {controller.max_turns} turns.")
def main() -> None:
    """Entry point: parse CLI args, build the agent, and run the demo.

    Reads the optional ``agent/config.yaml`` to decide between a remote
    model server and a lazily loaded local model, then runs `run_demo`
    and prints a summary. Exits with status 1 on any error.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Interactive demo of NexaSci agent with tool usage visualization."
    )
    parser.add_argument(
        "--prompt",
        default="Design a Python simulation to model the diffusion of nanoparticles in a fluid. Include visualization and cite relevant literature.",
        help="The prompt to send to the agent.",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=10,
        help="Maximum number of agent turns (default: 10).",
    )
    args = parser.parse_args()

    print("πŸš€ Initializing NexaSci Agent...")

    # Check if we should use a remote model server.
    import yaml

    config_path = Path("agent/config.yaml")
    use_remote = False
    model_server_url = "http://127.0.0.1:8001"
    if config_path.exists():
        with config_path.open(encoding="utf-8") as f:
            # safe_load returns None for an empty file; fall back to {}
            # so the .get() calls below don't crash.
            config = yaml.safe_load(f) or {}
        model_server_cfg = config.get("model_server", {})
        use_remote = model_server_cfg.get("enabled", False)
        model_server_url = model_server_cfg.get("base_url", model_server_url)

    if use_remote:
        print(f" Connecting to remote model server at {model_server_url}...")
        from agent.client_llm_remote import RemoteNexaSciClient

        llm_client = RemoteNexaSciClient(base_url=model_server_url)
        # Health probe is advisory only -- we still build the controller.
        health = llm_client.health_check()
        if health.get("status") != "healthy":
            print(f" ⚠️ Model server not ready: {health}")
            print(" Make sure model server is running in agent tmux!")
        else:
            print(" βœ“ Connected to model server")
    else:
        print(" (Model will load locally on first generation - this may take 30-60 seconds)")
        from agent.client_llm import NexaSciModelClient

        llm_client = NexaSciModelClient(lazy_load=True)

    # Single construction point (was duplicated in both branches).
    controller = AgentController(llm_client=llm_client, max_turns=args.max_turns)
    print("βœ“ Agent initialized\n")

    try:
        result = run_demo(controller, args.prompt)
        print_section("Summary", "=")
        print(f"βœ“ Total turns: {len(result.messages) // 2}")
        print(f"βœ“ Tools used: {len(result.tool_results)}")
        if result.tool_results:
            tool_names = [r.tool_name for r in result.tool_results]
            # dict.fromkeys de-duplicates while preserving first-use order
            # (a plain set() printed in nondeterministic order).
            print(f" Tools: {', '.join(dict.fromkeys(tool_names))}")
        print("βœ“ Final response generated")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()