"""
Basic Agent Main Graph Module (FastAPI Compatible - Minimal Changes)

This module implements the core workflow graph for the Basic Agent system.
It defines the agent's decision-making flow between model deployment and
React-based compute workflows.

CHANGES FROM ORIGINAL:
- __init__ now accepts optional tools and llm parameters
- Added async create() classmethod for FastAPI
- Fully backwards compatible with existing CLI code

Author: Your Name
License: Private
"""

import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional

from langgraph.graph import StateGraph, END, START
from typing_extensions import TypedDict
from constant import Constants

# Import node functions (to be implemented in separate files)
from langgraph.checkpoint.memory import MemorySaver
from ComputeAgent.graph.graph_deploy import DeployModelAgent
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.models.model_manager import ModelManager
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from ComputeAgent.graph.state import AgentState

# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()

# Global MemorySaver (persists state across requests)
memory_saver = MemorySaver()

logger = logging.getLogger("ComputeAgent")

# Get the project root directory (parent of ComputeAgent folder).
# __file__ is in ComputeAgent/graph/graph.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

# Use sys.executable to get the current Python interpreter path
python_executable = sys.executable

mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Pass HF Spaces secrets to the MCP subprocess
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv(
                    "HIVE_COMPUTE_DEFAULT_API_TOKEN", ""
                ),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv(
                    "HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"
                ),
                # Also pass these to ensure Python works correctly
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            },
        }
    }
)


class ComputeAgent:
    """
    Main Compute Agent class providing AI-powered decision routing and execution.

    This class orchestrates the complete agent workflow including:
    - Decision routing between model deployment and React agent
    - Model deployment workflow with capacity estimation and approval
    - React agent execution with compute capabilities
    - Error handling and state management

    Attributes:
        tools: MCP tools handed to both sub-workflows
        llm: LLM handed to both sub-workflows
        deploy_subgraph: DeployModelAgent instance backing the "deploy_model" node
        react_subgraph: ReactWorkflow instance backing the "react_agent" node
        graph: Compiled LangGraph workflow

    Usage:
        # For CLI (backwards compatible):
        agent = ComputeAgent()

        # For FastAPI (async):
        agent = await ComputeAgent.create()
    """

    # Names of the two branches reachable from the decision node.
    _ROUTES = ("deploy_model", "react_agent")
    # Branch used whenever no valid routing decision is available.
    _FALLBACK_ROUTE = "react_agent"

    def __init__(self, tools=None, llm=None):
        """
        Initialize Compute Agent with optional pre-loaded dependencies.

        Args:
            tools: Pre-loaded MCP tools (optional, will load if not provided)
            llm: Pre-loaded LLM model (optional, will load if not provided)
        """
        # If tools/llm not provided, load them synchronously (for CLI).
        # asyncio.run() cannot be called from a running event loop, which is
        # why async callers should go through create() instead.
        if tools is None:
            self.tools = asyncio.run(mcp_client.get_tools())
        else:
            self.tools = tools

        if llm is None:
            self.llm = asyncio.run(
                model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
            )
        else:
            self.llm = llm

        self.deploy_subgraph = DeployModelAgent(llm=self.llm, react_tools=self.tools)
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.tools)
        self.graph = self._create_graph()

    @classmethod
    async def create(cls):
        """
        Async factory method for creating ComputeAgent.

        Use this in FastAPI to avoid asyncio.run() issues.

        Returns:
            Initialized ComputeAgent instance
        """
        logger.info("🔧 Loading tools and LLM asynchronously...")
        tools = await mcp_client.get_tools()
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)

        # Initialize DeployModelAgent with its own tools.
        # NOTE(review): the instance returned here is discarded; __init__
        # builds its own DeployModelAgent from `tools`/`llm`. The await is kept
        # in case the factory has side effects — confirm whether this instance
        # was meant to be used as the deploy subgraph.
        await DeployModelAgent.create(llm=llm, custom_tools=None)

        return cls(tools=tools, llm=llm)

    async def decision_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Node that handles routing decisions for the ComputeAgent workflow.

        Analyzes the user query to determine whether to route to:
        - Model deployment workflow (deploy_model)
        - React agent workflow (react_agent)

        Any failure while obtaining the decision falls back to the React agent
        so the graph always has a valid branch to follow.

        Args:
            state: Current agent state with memory fields

        Returns:
            Copy of the state with "agent_decision" and "current_step" set
            (plus "error" when the fallback was used)
        """
        # Get user context
        user_id = state.get("user_id", "")
        session_id = state.get("session_id", "")
        query = state.get("query", "")

        logger.info(f"🎯 Decision node processing query for {user_id}:{session_id}")

        # Build memory context for decision making (best effort: a failure
        # here only means the prompt carries no conversation context).
        memory_context = ""
        if user_id and session_id:
            try:
                from helpers.memory import get_memory_manager

                memory_manager = get_memory_manager()
                memory_context = await memory_manager.build_context_for_node(
                    user_id, session_id, "decision"
                )
                if memory_context:
                    logger.info("🧠 Using memory context for decision routing")
            except Exception as e:
                logger.warning(f"⚠️ Could not load memory context for decision: {e}")

        try:
            # Load main LLM using ModelManager
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)

            if memory_context:
                context_line = f"Conversation Context: {memory_context}"
            else:
                context_line = "No conversation context available."

            # Create decision prompt
            decision_system_prompt = f"""
        You are a routing assistant for ComputeAgent. Analyze the user's query and decide which workflow to use.

        Choose between:
        1. DEPLOY_MODEL - For queries about deploy AI model from HuggingFace. In this case the user MUST specify the model card name (like meta-llama/Meta-Llama-3-70B).
            - The user can specify the hardware capacity needed.
            - The user can ask for model analysis, deployment steps, or capacity estimation.
        2. REACT_AGENT - For all the rest of queries.

        {context_line}

        User Query: {query}

        Respond with only: "DEPLOY_MODEL" or "REACT_AGENT"
        """

            # Get routing decision
            decision_response = await llm.ainvoke(
                [SystemMessage(content=decision_system_prompt)]
            )
            routing_decision = decision_response.content.strip().upper()

            # Validate and set decision
            if "DEPLOY_MODEL" in routing_decision:
                agent_decision = "deploy_model"
                logger.info("📦 Routing to model deployment workflow")
            elif "REACT_AGENT" in routing_decision:
                agent_decision = "react_agent"
                logger.info("⚛️ Routing to React agent workflow")
            else:
                # Default fallback to React agent for general queries
                agent_decision = self._FALLBACK_ROUTE
                logger.warning(
                    f"⚠️ Ambiguous routing decision '{routing_decision}', defaulting to React agent"
                )

            # Update state with decision
            updated_state = state.copy()
            updated_state["agent_decision"] = agent_decision
            updated_state["current_step"] = "decision_complete"

            logger.info(f"✅ Decision node complete: {agent_decision}")
            return updated_state

        except Exception as e:
            logger.error(f"❌ Error in decision node: {e}")

            # Update state with fallback decision. The decision must be set
            # explicitly: the incoming state carries an empty
            # "agent_decision", which matches no conditional edge.
            updated_state = state.copy()
            updated_state["agent_decision"] = self._FALLBACK_ROUTE
            updated_state["current_step"] = "decision_complete"
            updated_state["error"] = f"Decision error (fallback used): {str(e)}"
            return updated_state

    @staticmethod
    def _route_decision(state: Dict[str, Any]) -> str:
        """Map the state's "agent_decision" to a branch name, falling back to the React agent."""
        decision = state.get("agent_decision")
        if decision in ComputeAgent._ROUTES:
            return decision
        logger.warning(
            f"⚠️ Unknown agent_decision '{decision}', defaulting to React agent"
        )
        return ComputeAgent._FALLBACK_ROUTE

    def _create_graph(self) -> StateGraph:
        """
        Create and configure the Compute Agent workflow graph.

        This method builds the complete workflow including:
        1. Initial decision node - routes to deployment or React agent
        2. Model deployment path:
           - Fetch model card from HuggingFace
           - Extract model information
           - Estimate capacity requirements
           - Human approval checkpoint
           - Deploy model or provide info
        3. React agent path:
           - Execute React agent with compute MCP capabilities

        Returns:
            The result of StateGraph.compile(), i.e. the compiled graph ready
            for execution (compiled with the module-level MemorySaver)
        """
        workflow = StateGraph(AgentState)

        # Add decision node
        workflow.add_node("decision", self.decision_node)

        # Add model deployment workflow nodes
        workflow.add_node("deploy_model", self.deploy_subgraph.get_compiled_graph())

        # Add React agent node
        workflow.add_node("react_agent", self.react_subgraph.get_compiled_graph())

        # Set entry point
        workflow.set_entry_point("decision")

        # Add conditional edges from decision node
        workflow.add_conditional_edges(
            "decision",
            self._route_decision,
            {
                "deploy_model": "deploy_model",
                "react_agent": "react_agent",
            },
        )

        # Add edges to END
        workflow.add_edge("deploy_model", END)
        workflow.add_edge("react_agent", END)

        # Compile with checkpointer
        return workflow.compile(checkpointer=memory_saver)

    def get_compiled_graph(self):
        """Return the compiled graph for use in FastAPI"""
        return self.graph

    @staticmethod
    def _build_initial_state(query: str, user_id: str, session_id: str) -> Dict[str, Any]:
        """Return a fresh initial state dict for one graph run."""
        return {
            "user_id": user_id,
            "session_id": session_id,
            "query": query,
            "response": "",
            "current_step": "start",
            "agent_decision": "",
            "deployment_approved": False,
            "model_name": "",
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
            "messages": [],
            # Approval fields for ReactWorkflow
            "pending_tool_calls": [],
            "approved_tool_calls": [],
            "rejected_tool_calls": [],
            "modified_tool_calls": [],
            "needs_re_reasoning": False,
            "re_reasoning_feedback": "",
        }

    @staticmethod
    def _build_config(user_id: str, session_id: str) -> Dict[str, Any]:
        """Return the run config carrying the checkpointer thread_id."""
        return {"configurable": {"thread_id": f"{user_id}_{session_id}"}}

    def invoke(
        self,
        query: str,
        user_id: str = "default_user",
        session_id: str = "default_session",
    ) -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context
        (synchronous wrapper for async).

        Uses asyncio.run(), so it must not be called from inside a running
        event loop; use ainvoke() there.

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution
        """
        return asyncio.run(self.ainvoke(query, user_id, session_id))

    async def ainvoke(
        self,
        query: str,
        user_id: str = "default_user",
        session_id: str = "default_session",
    ) -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution containing:
            - response: Final response to user
            - agent_decision: Which path was taken
            - deployment_result: If deployment path was taken
            - react_results: If React agent path was taken

            If graph execution raises, the initial state is returned instead,
            extended with "error", "error_step" and an error "response".
        """
        initial_state = self._build_initial_state(query, user_id, session_id)

        # Create config with thread_id for checkpointer
        config = self._build_config(user_id, session_id)

        try:
            result = await self.graph.ainvoke(initial_state, config)
            return result
        except Exception as e:
            logger.error(f"Error in graph execution: {e}")
            return {
                **initial_state,
                "error": str(e),
                "error_step": initial_state.get("current_step", "unknown"),
                "response": f"An error occurred during execution: {str(e)}",
            }

    async def astream_generate_nodes(
        self,
        query: str,
        user_id: str = "default_user",
        session_id: str = "default_session",
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Stream the graph execution node by node (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Yields:
            Dict containing node execution updates: "node", "output", and,
            when the node output is a dict, every key of that output merged
            in. If streaming raises, a final dict with "error", "status" and
            "error_step" is yielded instead.
        """
        initial_state = self._build_initial_state(query, user_id, session_id)

        # Create config with thread_id for checkpointer
        config = self._build_config(user_id, session_id)

        try:
            # Stream through the graph execution
            async for chunk in self.graph.astream(initial_state, config):
                # Each chunk contains the node name and its output
                for node_name, node_output in chunk.items():
                    update = {"node": node_name, "output": node_output}
                    # Only dict outputs can be merged as state updates; other
                    # payloads (e.g. interrupt tuples, None) are passed
                    # through under "output" alone instead of raising.
                    if isinstance(node_output, dict):
                        update.update(node_output)  # Include all state updates
                    yield update

        except Exception as e:
            logger.error(f"Error in graph streaming: {e}")
            yield {
                "error": str(e),
                "status": "error",
                "error_step": initial_state.get("current_step", "unknown"),
            }

    def draw_graph(self, output_file_path: str = "basic_agent_graph.png"):
        """
        Generate and save a visual representation of the Basic Agent workflow graph.

        Failures are logged and not raised.

        Args:
            output_file_path: Path where to save the graph PNG file
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Basic Agent graph visualization saved to: {output_file_path}")
        except Exception as e:
            logger.error(f"❌ Failed to generate Basic Agent graph visualization: {e}")