Spaces:
Paused
Paused
| """Interactive demo showing the NexaSci agent reasoning, using tools, and citing sources.""" | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| import time | |
| from pathlib import Path | |
# Make the project's own packages (e.g. `agent`, `tools`) importable when this
# file is executed directly as a script: the parent of this file's directory
# (presumably the repository root) goes to the front of the module search path.
project_root = Path(__file__).resolve().parents[1]
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry  # temporary only; keep the module namespace unchanged
| from agent.controller import AgentController, AgentRunResult | |
| from agent.client_llm import Message | |
def print_section(title: str, char: str = "=") -> None:
    """Print *title* centred in an 80-column banner framed by two rows of *char*.

    The banner is preceded by one blank line and followed by one blank line.
    """
    rule = char * 80
    banner_lines = ["", rule, f"{title:^80}", rule, ""]
    print("\n".join(banner_lines))
def print_tool_call(turn: int, tool_name: str, args: dict) -> None:
    """Echo one tool invocation: a turn-tagged header, then the arguments as JSON.

    Args:
        turn: 1-based turn number shown in the header.
        tool_name: Name of the tool being invoked.
        args: Tool arguments, rendered with ``json.dumps(..., indent=4)``.
    """
    header = f"\n[Turn {turn}] π§ Tool Call: {tool_name}"
    print(header)
    rendered_args = json.dumps(args, indent=4)
    print(f" Arguments: {rendered_args}")
def print_tool_result(result: dict, max_chars: int = 500) -> None:
    """Print a tool's output: either its error message or a success line plus JSON.

    Args:
        result: The tool output mapping. If it contains an ``"error"`` key, only
            that message is printed; otherwise the whole mapping is dumped as
            indented JSON.
        max_chars: Maximum number of JSON characters shown before the output is
            cut and marked as truncated (default 500, the previous fixed limit).

    The JSON is for display only, so values the ``json`` module cannot encode
    natively (e.g. ``set``, ``Path``, ``datetime``) are rendered with ``str()``
    instead of raising and aborting the demo.
    """
    if "error" in result:
        print(f" β Error: {result['error']}")
        return

    print(" β Success")
    # Deliberate display-only fallback: stringify anything json can't encode.
    output_str = json.dumps(result, indent=2, default=str)
    if len(output_str) > max_chars:
        output_str = output_str[:max_chars] + "\n ... (truncated)"
    print(f" Output: {output_str}")
def print_agent_response(turn: int, response: str) -> None:
    """Show the model's reply for *turn*, clipped to its first 500 characters.

    An ellipsis is appended only when the reply was actually clipped.
    """
    print(f"\n[Turn {turn}] π€ Agent Response:")
    limit = 500
    clipped = response[:limit]
    suffix = "..." if len(response) > limit else ""
    print(f" {clipped}{suffix}")
def run_demo(controller: AgentController, prompt: str) -> AgentRunResult:
    """Drive the agent loop turn by turn, printing each step as it happens.

    Instead of one controller entry point, this calls
    ``controller.llm_client.generate`` and ``controller._dispatch_tool``
    directly so intermediate steps can be shown. Each turn:

      1. Generate a reply from the conversation so far.
      2. If the reply contains ``~~~toolcall ... ~~~`` fences, parse each JSON
         body into a ``ToolCall``, dispatch it, append its output as a
         ``tool`` message, and move on to the next turn. Malformed tool-call
         bodies are reported and skipped.
      3. Otherwise, if the first ``~~~final ... ~~~`` fence has a non-empty
         body, return it as the final response (a body that is not valid JSON
         is returned as ``{}``).

    Args:
        controller: Supplies ``llm_client``, ``max_turns`` and tool dispatch.
        prompt: The initial user message.

    Returns:
        An ``AgentRunResult`` holding the final payload, the full message list
        and every tool result collected along the way.

    Raises:
        RuntimeError: If no final response appears within ``controller.max_turns`` turns.
        Exception: Any error raised by ``llm_client.generate`` is reported, then re-raised.
    """
    import re

    # Imported outside any try block so a broken install fails loudly instead
    # of silently discarding every tool call.
    from tools.schemas import ToolCall

    # Compiled once per run rather than once per turn.
    toolcall_regex = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL)
    final_regex = re.compile(r"~~~final(.*?)~~~", re.DOTALL)

    print_section("NexaSci Agent Demo", "=")
    print(f"π User Prompt: {prompt}\n")

    messages: list[Message] = [Message(role="user", content=prompt)]
    tool_results: list = []

    for turn in range(1, controller.max_turns + 1):
        print_section(f"Turn {turn}", "-")

        # Generate response
        print("π Generating response...")
        start_time = time.perf_counter()
        try:
            response_text = controller.llm_client.generate(messages)
        except Exception as e:
            elapsed = time.perf_counter() - start_time
            print(f" β Generation failed after {elapsed:.2f}s")
            print(f" Error: {e}")
            raise
        elapsed = time.perf_counter() - start_time
        print(f" β±οΈ Generated in {elapsed:.2f}s")

        print_agent_response(turn, response_text)
        messages.append(Message(role="assistant", content=response_text))

        # Parse every tool-call fence; a malformed one (bad JSON, non-object
        # payload, schema mismatch) is reported and skipped, not fatal.
        tool_calls = []
        for raw in toolcall_regex.findall(response_text):
            snippet = raw.strip()
            if not snippet:
                continue
            try:
                tool_calls.append(ToolCall(**json.loads(snippet)))
            except Exception as exc:
                print(f"\n Skipping malformed tool call: {exc}")

        if tool_calls:
            print(f"\n π Found {len(tool_calls)} tool call(s)")
            for call in tool_calls:
                print_tool_call(turn, call.tool, call.arguments)
                result = controller._dispatch_tool(call)
                tool_results.append(result)
                print_tool_result(result.output)
                # Feed the tool output back so the next generation can use it.
                messages.append(
                    Message(
                        role="tool",
                        content=json.dumps(result.output, ensure_ascii=False),
                    )
                )
            continue

        # No tool calls this turn: look for the final answer (first fence only).
        final_payload = None
        final_match = final_regex.search(response_text)
        if final_match:
            snippet = final_match.group(1).strip()
            if snippet:
                try:
                    final_payload = json.loads(snippet)
                except json.JSONDecodeError:
                    print(" Final block is not valid JSON; returning an empty payload.")
                    final_payload = {}

        if final_payload is not None:
            print_section("Final Response", "=")
            print(json.dumps(final_payload, indent=2, ensure_ascii=False))
            return AgentRunResult(
                final_response=final_payload,
                messages=messages,
                tool_results=tool_results,
            )

        print(" No tool call or final response found; continuing to next turn.")

    raise RuntimeError(f"Agent did not produce a final response within {controller.max_turns} turns.")
def main() -> None:
    """Entry point: parse CLI options, build an ``AgentController`` and run the demo.

    Backend selection reads the ``model_server`` section of ``agent/config.yaml``.
    That path is looked up relative to the current directory first, then under
    the project root. When ``enabled`` is true, a ``RemoteNexaSciClient`` at
    ``base_url`` is used; otherwise a lazily loaded local ``NexaSciModelClient``
    is used. Any error during the run is printed with a traceback, and the
    process exits with status 1.
    """
    import argparse
    import traceback

    import yaml

    parser = argparse.ArgumentParser(
        description="Interactive demo of NexaSci agent with tool usage visualization."
    )
    parser.add_argument(
        "--prompt",
        default="Design a Python simulation to model the diffusion of nanoparticles in a fluid. Include visualization and cite relevant literature.",
        help="The prompt to send to the agent.",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=10,
        help="Maximum number of agent turns (default: 10).",
    )
    args = parser.parse_args()

    print("π Initializing NexaSci Agent...")

    # Check if we should use remote model server. Fall back to the project-root
    # config so the lookup also works when run from another directory.
    config_path = Path("agent/config.yaml")
    if not config_path.exists():
        config_path = project_root / "agent" / "config.yaml"

    use_remote = False
    model_server_url = "http://127.0.0.1:8001"
    if config_path.exists():
        with config_path.open(encoding="utf-8") as f:
            # safe_load returns None for an empty file; treat that as "no settings".
            config = yaml.safe_load(f) or {}
        model_server_cfg = config.get("model_server") or {}
        use_remote = model_server_cfg.get("enabled", False)
        model_server_url = model_server_cfg.get("base_url", model_server_url)

    if use_remote:
        print(f" Connecting to remote model server at {model_server_url}...")
        from agent.client_llm_remote import RemoteNexaSciClient

        llm_client = RemoteNexaSciClient(base_url=model_server_url)
        # Check health; an unreachable server gets the same warning as an
        # unhealthy one instead of crashing before the demo starts.
        try:
            health = llm_client.health_check()
        except Exception as exc:
            health = {"status": "unreachable", "error": str(exc)}
        if health.get("status") != "healthy":
            print(f" β οΈ Model server not ready: {health}")
            print(" Make sure model server is running in agent tmux!")
        else:
            print(" β Connected to model server")
    else:
        print(" (Model will load locally on first generation - this may take 30-60 seconds)")
        from agent.client_llm import NexaSciModelClient

        llm_client = NexaSciModelClient(lazy_load=True)

    controller = AgentController(llm_client=llm_client, max_turns=args.max_turns)
    print("β Agent initialized\n")

    try:
        result = run_demo(controller, args.prompt)
        print_section("Summary", "=")
        # run_demo records 1 user message, 1 assistant message per turn and 1 tool
        # message per tool result, so turns = total - 1 - tool results.
        # (len(messages) // 2 over-counts whenever a turn calls several tools.)
        turns = len(result.messages) - 1 - len(result.tool_results)
        print(f"β Total turns: {turns}")
        print(f"β Tools used: {len(result.tool_results)}")
        if result.tool_results:
            # Deduplicate while keeping first-use order (a set's order is arbitrary).
            tool_names = dict.fromkeys(r.tool_name for r in result.tool_results)
            print(f" Tools: {', '.join(tool_names)}")
        print("β Final response generated")
    except Exception as e:
        print(f"\nβ Error: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()