File size: 16,272 Bytes
8816dfd 5dd4236 8816dfd 5dd4236 93850a2 8816dfd 93850a2 8816dfd 0297f14 8816dfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 |
"""
Basic Agent Main Graph Module (FastAPI Compatible - Minimal Changes)
This module implements the core workflow graph for the Basic Agent system.
It defines the agent's decision-making flow between model deployment and
React-based compute workflows.
CHANGES FROM ORIGINAL:
- __init__ now accepts optional tools and llm parameters
- Added async create() classmethod for FastAPI
- Fully backwards compatible with existing CLI code
Author: Your Name
License: Private
"""
import asyncio
from typing import Dict, Any, List, Optional
import uuid
import json
import logging
from langgraph.graph import StateGraph, END, START
from typing_extensions import TypedDict
from constant import Constants
# Import node functions (to be implemented in separate files)
from langgraph.checkpoint.memory import MemorySaver
from ComputeAgent.graph.graph_deploy import DeployModelAgent
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.models.model_manager import ModelManager
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from ComputeAgent.graph.state import AgentState
import os
# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()

# Global MemorySaver (persists graph state across requests; shared by every
# compiled graph in this process)
memory_saver = MemorySaver()

logger = logging.getLogger("ComputeAgent")

# Get the project root directory (parent of ComputeAgent folder)
import sys

# __file__ is in ComputeAgent/graph/graph.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

# Use sys.executable to get the current Python interpreter path, so the MCP
# subprocess runs under the same interpreter/venv as this process
python_executable = sys.executable

# MCP client that launches the Hive Compute MCP server as a stdio subprocess.
mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Pass HF Spaces secrets to the MCP subprocess
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # Also pass these to ensure Python works correctly
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            }
        }
    }
)
class ComputeAgent:
"""
Main Compute Agent class providing AI-powered decision routing and execution.
This class orchestrates the complete agent workflow including:
- Decision routing between model deployment and React agent
- Model deployment workflow with capacity estimation and approval
- React agent execution with compute capabilities
- Error handling and state management
Attributes:
graph: Compiled LangGraph workflow
model_name: Default model name for operations
Usage:
# For CLI (backwards compatible):
agent = ComputeAgent()
# For FastAPI (async):
agent = await ComputeAgent.create()
"""
def __init__(self, tools=None, llm=None):
"""
Initialize Compute Agent with optional pre-loaded dependencies.
Args:
tools: Pre-loaded MCP tools (optional, will load if not provided)
llm: Pre-loaded LLM model (optional, will load if not provided)
"""
# If tools/llm not provided, load them synchronously (for CLI)
if tools is None:
self.tools = asyncio.run(mcp_client.get_tools())
else:
self.tools = tools
if llm is None:
self.llm = asyncio.run(model_manager.load_llm_model(Constants.DEFAULT_LLM_FC))
else:
self.llm = llm
self.deploy_subgraph = DeployModelAgent(llm=self.llm, react_tools=self.tools)
self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.tools)
self.graph = self._create_graph()
@classmethod
async def create(cls):
"""
Async factory method for creating ComputeAgent.
Use this in FastAPI to avoid asyncio.run() issues.
Returns:
Initialized ComputeAgent instance
"""
logger.info("π§ Loading tools and LLM asynchronously...")
tools = await mcp_client.get_tools()
llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
# Initialize DeployModelAgent with its own tools
deploy_subgraph = await DeployModelAgent.create(llm=llm, custom_tools=None)
return cls(tools=tools, llm=llm)
async def decision_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Node that handles routing decisions for the ComputeAgent workflow.
Analyzes the user query to determine whether to route to:
- Model deployment workflow (deploy_model)
- React agent workflow (react_agent)
Args:
state: Current agent state with memory fields
Returns:
Updated state with routing decision
"""
# Get user context
user_id = state.get("user_id", "")
session_id = state.get("session_id", "")
query = state.get("query", "")
logger.info(f"π― Decision node processing query for {user_id}:{session_id}")
# Build memory context for decision making
memory_context = ""
if user_id and session_id:
try:
from helpers.memory import get_memory_manager
memory_manager = get_memory_manager()
memory_context = await memory_manager.build_context_for_node(user_id, session_id, "decision")
if memory_context:
logger.info(f"π§ Using memory context for decision routing")
except Exception as e:
logger.warning(f"β οΈ Could not load memory context for decision: {e}")
try:
# Create a simple LLM for decision making
# Load main LLM using ModelManager
llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)
# Create decision prompt
decision_system_prompt = f"""
You are a routing assistant for ComputeAgent. Analyze the user's query and decide which workflow to use.
Choose between:
1. DEPLOY_MODEL - For queries about deploy AI model from HuggingFace. In this case the user MUST specify the model card name (like meta-llama/Meta-Llama-3-70B).
- The user can specify the hardware capacity needed.
- The user can ask for model analysis, deployment steps, or capacity estimation.
2. REACT_AGENT - For all the rest of queries.
{f"Conversation Context: {memory_context}" if memory_context else "No conversation context available."}
User Query: {query}
Respond with only: "DEPLOY_MODEL" or "REACT_AGENT"
"""
# Get routing decision
decision_response = await llm.ainvoke([
SystemMessage(content=decision_system_prompt)
])
routing_decision = decision_response.content.strip().upper()
# Validate and set decision
if "DEPLOY_MODEL" in routing_decision:
agent_decision = "deploy_model"
logger.info(f"π¦ Routing to model deployment workflow")
elif "REACT_AGENT" in routing_decision:
agent_decision = "react_agent"
logger.info(f"βοΈ Routing to React agent workflow")
else:
# Default fallback to React agent for general queries
agent_decision = "react_agent"
logger.warning(f"β οΈ Ambiguous routing decision '{routing_decision}', defaulting to React agent")
# Update state with decision
updated_state = state.copy()
updated_state["agent_decision"] = agent_decision
updated_state["current_step"] = "decision_complete"
logger.info(f"β
Decision node complete: {agent_decision}")
return updated_state
except Exception as e:
logger.error(f"β Error in decision node: {e}")
# Update state with fallback decision
updated_state = state.copy()
updated_state["error"] = f"Decision error (fallback used): {str(e)}"
return updated_state
def _create_graph(self) -> StateGraph:
"""
Create and configure the Compute Agent workflow graph.
This method builds the complete workflow including:
1. Initial decision node - routes to deployment or React agent
2. Model deployment path:
- Fetch model card from HuggingFace
- Extract model information
- Estimate capacity requirements
- Human approval checkpoint
- Deploy model or provide info
3. React agent path:
- Execute React agent with compute MCP capabilities
Returns:
Compiled StateGraph ready for execution
"""
workflow = StateGraph(AgentState)
# Add decision node
workflow.add_node("decision", self.decision_node)
# Add model deployment workflow nodes
workflow.add_node("deploy_model", self.deploy_subgraph.get_compiled_graph())
# Add React agent node
workflow.add_node("react_agent", self.react_subgraph.get_compiled_graph())
# Set entry point
workflow.set_entry_point("decision")
# Add conditional edges from decision node
workflow.add_conditional_edges(
"decision",
lambda state: state["agent_decision"],
{
"deploy_model": "deploy_model",
"react_agent": "react_agent",
}
)
# Add edges to END
workflow.add_edge("deploy_model", END)
workflow.add_edge("react_agent", END)
# Compile with checkpointer
return workflow.compile(checkpointer=memory_saver)
def get_compiled_graph(self):
"""Return the compiled graph for use in FastAPI"""
return self.graph
def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
"""
Execute the graph with a given query and memory context (synchronous wrapper for async).
Args:
query: User's query
user_id: User identifier for memory management
session_id: Session identifier for memory management
Returns:
Final result from the graph execution
"""
return asyncio.run(self.ainvoke(query, user_id, session_id))
async def ainvoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
"""
Execute the graph with a given query and memory context (async).
Args:
query: User's query
user_id: User identifier for memory management
session_id: Session identifier for memory management
Returns:
Final result from the graph execution containing:
- response: Final response to user
- agent_decision: Which path was taken
- deployment_result: If deployment path was taken
- react_results: If React agent path was taken
"""
initial_state = {
"user_id": user_id,
"session_id": session_id,
"query": query,
"response": "",
"current_step": "start",
"agent_decision": "",
"deployment_approved": False,
"model_name": "",
"model_card": {},
"model_info": {},
"capacity_estimate": {},
"deployment_result": {},
"react_results": {},
"tool_calls": [],
"tool_results": [],
"messages": [],
# Approval fields for ReactWorkflow
"pending_tool_calls": [],
"approved_tool_calls": [],
"rejected_tool_calls": [],
"modified_tool_calls": [],
"needs_re_reasoning": False,
"re_reasoning_feedback": ""
}
# Create config with thread_id for checkpointer
config = {
"configurable": {
"thread_id": f"{user_id}_{session_id}"
}
}
try:
result = await self.graph.ainvoke(initial_state, config)
return result
except Exception as e:
logger.error(f"Error in graph execution: {e}")
return {
**initial_state,
"error": str(e),
"error_step": initial_state.get("current_step", "unknown"),
"response": f"An error occurred during execution: {str(e)}"
}
async def astream_generate_nodes(self, query: str, user_id: str = "default_user", session_id: str = "default_session"):
"""
Stream the graph execution node by node (async).
Args:
query: User's query
user_id: User identifier for memory management
session_id: Session identifier for memory management
Yields:
Dict containing node execution updates
"""
initial_state = {
"user_id": user_id,
"session_id": session_id,
"query": query,
"response": "",
"current_step": "start",
"agent_decision": "",
"deployment_approved": False,
"model_name": "",
"model_card": {},
"model_info": {},
"capacity_estimate": {},
"deployment_result": {},
"react_results": {},
"tool_calls": [],
"tool_results": [],
"messages": [],
# Approval fields for ReactWorkflow
"pending_tool_calls": [],
"approved_tool_calls": [],
"rejected_tool_calls": [],
"modified_tool_calls": [],
"needs_re_reasoning": False,
"re_reasoning_feedback": ""
}
# Create config with thread_id for checkpointer
config = {
"configurable": {
"thread_id": f"{user_id}_{session_id}"
}
}
try:
# Stream through the graph execution
async for chunk in self.graph.astream(initial_state, config):
# Each chunk contains the node name and its output
for node_name, node_output in chunk.items():
yield {
"node": node_name,
"output": node_output,
**node_output # Include all state updates
}
except Exception as e:
logger.error(f"Error in graph streaming: {e}")
yield {
"error": str(e),
"status": "error",
"error_step": initial_state.get("current_step", "unknown")
}
    def draw_graph(self, output_file_path: str = "basic_agent_graph.png"):
        """
        Generate and save a visual representation of the Basic Agent workflow graph.

        Args:
            output_file_path: Path where to save the graph PNG file
        """
        try:
            # Render the compiled LangGraph topology via its Mermaid PNG exporter.
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Basic Agent graph visualization saved to: {output_file_path}")
        except Exception as e:
            # Visualization is best-effort; a drawing failure must not crash the agent.
            logger.error(f"❌ Failed to generate Basic Agent graph visualization: {e}")