# Exported from hosted file viewer — commit 0297f14 (verified), "revert (#20)" by carraraig.
"""
Basic Agent Main Graph Module (FastAPI Compatible - Minimal Changes)
This module implements the core workflow graph for the Basic Agent system.
It defines the agent's decision-making flow between model deployment and
React-based compute workflows.
CHANGES FROM ORIGINAL:
- __init__ now accepts optional tools and llm parameters
- Added async create() classmethod for FastAPI
- Fully backwards compatible with existing CLI code
Author: Your Name
License: Private
"""
import asyncio
from typing import Dict, Any, List, Optional
import uuid
import json
import logging
from langgraph.graph import StateGraph, END, START
from typing_extensions import TypedDict
from constant import Constants
# Import node functions (to be implemented in separate files)
from langgraph.checkpoint.memory import MemorySaver
from ComputeAgent.graph.graph_deploy import DeployModelAgent
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.models.model_manager import ModelManager
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from ComputeAgent.graph.state import AgentState
import os
# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()
# Global MemorySaver (persists state across requests)
memory_saver = MemorySaver()
logger = logging.getLogger("ComputeAgent")
# Get the project root directory (parent of ComputeAgent folder)
import sys
# __file__ is in ComputeAgent/graph/graph.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")
# Use sys.executable to get the current Python interpreter path
python_executable = sys.executable
mcp_client = MultiServerMCPClient(
{
"hivecompute": {
"command": python_executable,
"args": [mcp_server_path],
"transport": "stdio",
"env": {
# Pass HF Spaces secrets to the MCP subprocess
"HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
"HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
# Also pass these to ensure Python works correctly
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
}
}
}
)
class ComputeAgent:
    """
    Main Compute Agent class providing AI-powered decision routing and execution.

    This class orchestrates the complete agent workflow including:
    - Decision routing between model deployment and React agent
    - Model deployment workflow with capacity estimation and approval
    - React agent execution with compute capabilities
    - Error handling and state management

    Attributes:
        tools: MCP tools shared by both subgraphs
        llm: Function-calling LLM shared by both subgraphs
        deploy_subgraph: Model-deployment workflow wrapper
        react_subgraph: React agent workflow wrapper
        graph: Compiled LangGraph workflow

    Usage:
        # For CLI (backwards compatible):
        agent = ComputeAgent()
        # For FastAPI (async):
        agent = await ComputeAgent.create()
    """

    def __init__(self, tools=None, llm=None):
        """
        Initialize Compute Agent with optional pre-loaded dependencies.

        Args:
            tools: Pre-loaded MCP tools (optional, loaded synchronously if omitted)
            llm: Pre-loaded LLM model (optional, loaded synchronously if omitted)

        NOTE: the synchronous fallbacks use asyncio.run(), which raises if an
        event loop is already running — use ``await ComputeAgent.create()``
        from async contexts such as FastAPI.
        """
        # If tools/llm not provided, load them synchronously (CLI path).
        if tools is None:
            self.tools = asyncio.run(mcp_client.get_tools())
        else:
            self.tools = tools
        if llm is None:
            self.llm = asyncio.run(model_manager.load_llm_model(Constants.DEFAULT_LLM_FC))
        else:
            self.llm = llm
        # Both subgraphs share the same LLM and tool set.
        self.deploy_subgraph = DeployModelAgent(llm=self.llm, react_tools=self.tools)
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.tools)
        self.graph = self._create_graph()

    @classmethod
    async def create(cls):
        """
        Async factory method for creating ComputeAgent.

        Use this in FastAPI to avoid asyncio.run() issues.

        Returns:
            Initialized ComputeAgent instance
        """
        logger.info("🔧 Loading tools and LLM asynchronously...")
        tools = await mcp_client.get_tools()
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
        # BUGFIX: a previous version also awaited DeployModelAgent.create(...)
        # here and discarded the result; __init__ builds the deploy subgraph
        # itself, so that call only duplicated work and has been removed.
        return cls(tools=tools, llm=llm)

    async def decision_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Node that handles routing decisions for the ComputeAgent workflow.

        Analyzes the user query to determine whether to route to:
        - Model deployment workflow ("deploy_model")
        - React agent workflow ("react_agent")

        Args:
            state: Current agent state with memory fields

        Returns:
            Updated state with "agent_decision" and "current_step" set. On any
            error the decision falls back to "react_agent" and an "error"
            entry is recorded so routing can still proceed.
        """
        # Get user context
        user_id = state.get("user_id", "")
        session_id = state.get("session_id", "")
        query = state.get("query", "")
        logger.info(f"🎯 Decision node processing query for {user_id}:{session_id}")

        # Best-effort: enrich the routing prompt with conversation memory.
        memory_context = ""
        if user_id and session_id:
            try:
                from helpers.memory import get_memory_manager
                memory_manager = get_memory_manager()
                memory_context = await memory_manager.build_context_for_node(user_id, session_id, "decision")
                if memory_context:
                    logger.info(f"🧠 Using memory context for decision routing")
            except Exception as e:
                # Memory is optional context — log and continue without it.
                logger.warning(f"⚠️ Could not load memory context for decision: {e}")

        try:
            # Routing uses the main LLM (not the function-calling one).
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)
            decision_system_prompt = f"""
You are a routing assistant for ComputeAgent. Analyze the user's query and decide which workflow to use.
Choose between:
1. DEPLOY_MODEL - For queries about deploy AI model from HuggingFace. In this case the user MUST specify the model card name (like meta-llama/Meta-Llama-3-70B).
- The user can specify the hardware capacity needed.
- The user can ask for model analysis, deployment steps, or capacity estimation.
2. REACT_AGENT - For all the rest of queries.
{f"Conversation Context: {memory_context}" if memory_context else "No conversation context available."}
User Query: {query}
Respond with only: "DEPLOY_MODEL" or "REACT_AGENT"
"""
            # Get routing decision
            decision_response = await llm.ainvoke([
                SystemMessage(content=decision_system_prompt)
            ])
            routing_decision = decision_response.content.strip().upper()

            # Validate the decision; anything ambiguous defaults to React agent.
            if "DEPLOY_MODEL" in routing_decision:
                agent_decision = "deploy_model"
                logger.info(f"📦 Routing to model deployment workflow")
            elif "REACT_AGENT" in routing_decision:
                agent_decision = "react_agent"
                logger.info(f"⚛️ Routing to React agent workflow")
            else:
                agent_decision = "react_agent"
                logger.warning(f"⚠️ Ambiguous routing decision '{routing_decision}', defaulting to React agent")

            updated_state = state.copy()
            updated_state["agent_decision"] = agent_decision
            updated_state["current_step"] = "decision_complete"
            logger.info(f"✅ Decision node complete: {agent_decision}")
            return updated_state
        except Exception as e:
            logger.error(f"❌ Error in decision node: {e}")
            updated_state = state.copy()
            updated_state["error"] = f"Decision error (fallback used): {str(e)}"
            # BUGFIX: the fallback decision was announced but never written;
            # without it the conditional edge reads an empty "agent_decision"
            # that is not in the routing map and the graph run fails.
            updated_state["agent_decision"] = "react_agent"
            updated_state["current_step"] = "decision_complete"
            return updated_state

    def _create_graph(self) -> StateGraph:
        """
        Create and configure the Compute Agent workflow graph.

        Topology:
            decision --(agent_decision)--> deploy_model | react_agent --> END

        The two workflow nodes are the compiled subgraphs built in __init__.

        Returns:
            Compiled StateGraph (with the global MemorySaver checkpointer)
            ready for execution
        """
        workflow = StateGraph(AgentState)
        # Routing node plus the two compiled subgraphs as workflow nodes.
        workflow.add_node("decision", self.decision_node)
        workflow.add_node("deploy_model", self.deploy_subgraph.get_compiled_graph())
        workflow.add_node("react_agent", self.react_subgraph.get_compiled_graph())
        workflow.set_entry_point("decision")
        # Route on the "agent_decision" value written by decision_node.
        workflow.add_conditional_edges(
            "decision",
            lambda state: state["agent_decision"],
            {
                "deploy_model": "deploy_model",
                "react_agent": "react_agent",
            }
        )
        workflow.add_edge("deploy_model", END)
        workflow.add_edge("react_agent", END)
        # Compile with the shared checkpointer so state survives across requests.
        return workflow.compile(checkpointer=memory_saver)

    def get_compiled_graph(self):
        """Return the compiled graph for use in FastAPI"""
        return self.graph

    @staticmethod
    def _build_initial_state(query: str, user_id: str, session_id: str) -> Dict[str, Any]:
        """Build the zeroed AgentState dict used to start a graph run."""
        return {
            "user_id": user_id,
            "session_id": session_id,
            "query": query,
            "response": "",
            "current_step": "start",
            "agent_decision": "",
            "deployment_approved": False,
            "model_name": "",
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
            "messages": [],
            # Approval fields for ReactWorkflow
            "pending_tool_calls": [],
            "approved_tool_calls": [],
            "rejected_tool_calls": [],
            "modified_tool_calls": [],
            "needs_re_reasoning": False,
            "re_reasoning_feedback": ""
        }

    @staticmethod
    def _thread_config(user_id: str, session_id: str) -> Dict[str, Any]:
        """Build the checkpointer config keyed by a per-user/session thread_id."""
        return {
            "configurable": {
                "thread_id": f"{user_id}_{session_id}"
            }
        }

    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (synchronous wrapper for async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution
        """
        return asyncio.run(self.ainvoke(query, user_id, session_id))

    async def ainvoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution containing:
            - response: Final response to user
            - agent_decision: Which path was taken
            - deployment_result: If deployment path was taken
            - react_results: If React agent path was taken
            On failure, the initial state plus "error"/"error_step"/"response".
        """
        initial_state = self._build_initial_state(query, user_id, session_id)
        config = self._thread_config(user_id, session_id)
        try:
            result = await self.graph.ainvoke(initial_state, config)
            return result
        except Exception as e:
            logger.error(f"Error in graph execution: {e}")
            return {
                **initial_state,
                "error": str(e),
                "error_step": initial_state.get("current_step", "unknown"),
                "response": f"An error occurred during execution: {str(e)}"
            }

    async def astream_generate_nodes(self, query: str, user_id: str = "default_user", session_id: str = "default_session"):
        """
        Stream the graph execution node by node (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Yields:
            Dict containing node execution updates (node name, output, and the
            flattened state updates); on failure a single error dict is yielded.
        """
        initial_state = self._build_initial_state(query, user_id, session_id)
        config = self._thread_config(user_id, session_id)
        try:
            # Each chunk maps node name -> that node's state updates.
            async for chunk in self.graph.astream(initial_state, config):
                for node_name, node_output in chunk.items():
                    yield {
                        "node": node_name,
                        "output": node_output,
                        **node_output  # Include all state updates
                    }
        except Exception as e:
            logger.error(f"Error in graph streaming: {e}")
            yield {
                "error": str(e),
                "status": "error",
                "error_step": initial_state.get("current_step", "unknown")
            }

    def draw_graph(self, output_file_path: str = "basic_agent_graph.png"):
        """
        Generate and save a visual representation of the Basic Agent workflow graph.

        Args:
            output_file_path: Path where to save the graph PNG file
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Basic Agent graph visualization saved to: {output_file_path}")
        except Exception as e:
            logger.error(f"❌ Failed to generate Basic Agent graph visualization: {e}")