|
|
""" |
|
|
Basic Agent Main Graph Module (FastAPI Compatible - Minimal Changes) |
|
|
|
|
|
This module implements the core workflow graph for the Basic Agent system. |
|
|
It defines the agent's decision-making flow between model deployment and |
|
|
React-based compute workflows. |
|
|
|
|
|
CHANGES FROM ORIGINAL: |
|
|
- __init__ now accepts optional tools and llm parameters |
|
|
- Added async create() classmethod for FastAPI |
|
|
- Fully backwards compatible with existing CLI code |
|
|
|
|
|
Author: Your Name |
|
|
License: Private |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
from typing import Dict, Any, List, Optional |
|
|
import uuid |
|
|
import json |
|
|
import logging |
|
|
|
|
|
from langgraph.graph import StateGraph, END, START |
|
|
from typing_extensions import TypedDict |
|
|
from constant import Constants |
|
|
|
|
|
|
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
from ComputeAgent.graph.graph_deploy import DeployModelAgent |
|
|
from ComputeAgent.graph.graph_ReAct import ReactWorkflow |
|
|
from ComputeAgent.models.model_manager import ModelManager |
|
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
|
from langchain_mcp_adapters.client import MultiServerMCPClient |
|
|
from ComputeAgent.graph.state import AgentState |
|
|
import os |
|
|
|
|
|
|
|
|
# Shared ModelManager instance: loads/caches LLMs for the whole module
# (used by ComputeAgent.__init__, create() and decision_node).
model_manager = ModelManager()

# In-memory LangGraph checkpointer; thread_id-scoped state lives only for
# the lifetime of this process (no persistence across restarts).
memory_saver = MemorySaver()

logger = logging.getLogger("ComputeAgent")

import sys

# Resolve the repository root three levels up from this file so the MCP
# server script can be located regardless of the current working directory.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

# Use the current interpreter so the MCP subprocess runs in the same
# virtualenv as this process.
python_executable = sys.executable

# MCP client for the "hivecompute" tool server, spawned as a stdio
# subprocess.  Constructed at import time; tools are fetched lazily via
# mcp_client.get_tools().  Missing env vars fall back to empty strings
# (except the API URL, which has a production default).
mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Credentials / endpoint forwarded to the tool subprocess.
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # PATH/PYTHONPATH must be forwarded explicitly or the
                # subprocess may fail to import its own dependencies.
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            }
        }
    }
)
|
|
|
|
|
class ComputeAgent:
    """
    Main Compute Agent class providing AI-powered decision routing and execution.

    This class orchestrates the complete agent workflow including:
    - Decision routing between model deployment and React agent
    - Model deployment workflow with capacity estimation and approval
    - React agent execution with compute capabilities
    - Error handling and state management

    Attributes:
        tools: MCP tools shared by both subgraphs
        llm: Function-calling LLM shared by both subgraphs
        deploy_subgraph: DeployModelAgent sub-workflow
        react_subgraph: ReactWorkflow sub-workflow
        graph: Compiled LangGraph workflow

    Usage:
        # For CLI (backwards compatible, blocking):
        agent = ComputeAgent()

        # For FastAPI (async, safe inside a running event loop):
        agent = await ComputeAgent.create()
    """

    def __init__(self, tools=None, llm=None):
        """
        Initialize Compute Agent with optional pre-loaded dependencies.

        Args:
            tools: Pre-loaded MCP tools (optional, loaded synchronously if omitted)
            llm: Pre-loaded LLM model (optional, loaded synchronously if omitted)

        Note:
            When tools/llm are omitted, asyncio.run() is used to load them.
            asyncio.run() raises RuntimeError inside a running event loop,
            so use the async ComputeAgent.create() factory there (FastAPI).
        """
        # Blocking loads keep the original CLI call sites working unchanged.
        self.tools = tools if tools is not None else asyncio.run(mcp_client.get_tools())
        self.llm = llm if llm is not None else asyncio.run(
            model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
        )

        self.deploy_subgraph = DeployModelAgent(llm=self.llm, react_tools=self.tools)
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.tools)
        self.graph = self._create_graph()

    @classmethod
    async def create(cls):
        """
        Async factory method for creating ComputeAgent.
        Use this in FastAPI to avoid asyncio.run() issues.

        Returns:
            Initialized ComputeAgent instance
        """
        logger.info("🔧 Loading tools and LLM asynchronously...")
        tools = await mcp_client.get_tools()
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
        # FIX: removed a stray `await DeployModelAgent.create(...)` whose
        # result was discarded — __init__ builds the deploy subgraph itself,
        # so the extra call only duplicated (potentially expensive) setup.
        return cls(tools=tools, llm=llm)

    async def decision_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Node that handles routing decisions for the ComputeAgent workflow.

        Analyzes the user query to determine whether to route to:
        - Model deployment workflow (deploy_model)
        - React agent workflow (react_agent)

        Args:
            state: Current agent state with memory fields

        Returns:
            Updated state with "agent_decision" set to "deploy_model" or
            "react_agent".  On ambiguity or any error the decision defaults
            to "react_agent" so the graph can always route.
        """
        user_id = state.get("user_id", "")
        session_id = state.get("session_id", "")
        query = state.get("query", "")

        logger.info(f"🎯 Decision node processing query for {user_id}:{session_id}")

        # Best-effort memory context: routing still works without it, so
        # failures here are logged as warnings rather than propagated.
        memory_context = ""
        if user_id and session_id:
            try:
                from helpers.memory import get_memory_manager
                memory_manager = get_memory_manager()
                memory_context = await memory_manager.build_context_for_node(user_id, session_id, "decision")
                if memory_context:
                    logger.info(f"🧠 Using memory context for decision routing")
            except Exception as e:
                logger.warning(f"⚠️ Could not load memory context for decision: {e}")

        try:
            # Routing uses a (possibly different) general-purpose model, not
            # the function-calling LLM held by the subgraphs.
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)

            decision_system_prompt = f"""
You are a routing assistant for ComputeAgent. Analyze the user's query and decide which workflow to use.

Choose between:
1. DEPLOY_MODEL - For queries about deploy AI model from HuggingFace. In this case the user MUST specify the model card name (like meta-llama/Meta-Llama-3-70B).
- The user can specify the hardware capacity needed.
- The user can ask for model analysis, deployment steps, or capacity estimation.

2. REACT_AGENT - For all the rest of queries.

{f"Conversation Context: {memory_context}" if memory_context else "No conversation context available."}

User Query: {query}

Respond with only: "DEPLOY_MODEL" or "REACT_AGENT"
"""

            decision_response = await llm.ainvoke([
                SystemMessage(content=decision_system_prompt)
            ])

            routing_decision = decision_response.content.strip().upper()

            # Substring match tolerates extra tokens around the keyword.
            if "DEPLOY_MODEL" in routing_decision:
                agent_decision = "deploy_model"
                logger.info(f"📦 Routing to model deployment workflow")
            elif "REACT_AGENT" in routing_decision:
                agent_decision = "react_agent"
                logger.info(f"⚛️ Routing to React agent workflow")
            else:
                # Ambiguous LLM output: default to the general-purpose path.
                agent_decision = "react_agent"
                logger.warning(f"⚠️ Ambiguous routing decision '{routing_decision}', defaulting to React agent")

            updated_state = state.copy()
            updated_state["agent_decision"] = agent_decision
            updated_state["current_step"] = "decision_complete"

            logger.info(f"✅ Decision node complete: {agent_decision}")
            return updated_state

        except Exception as e:
            logger.error(f"❌ Error in decision node: {e}")

            updated_state = state.copy()
            # FIX: actually apply the advertised fallback.  The conditional
            # edge reads state["agent_decision"]; previously this branch
            # left it empty ("") which matched no route, so any LLM failure
            # made the graph unroutable instead of degrading gracefully.
            updated_state["agent_decision"] = "react_agent"
            updated_state["current_step"] = "decision_complete"
            updated_state["error"] = f"Decision error (fallback used): {str(e)}"
            return updated_state

    def _create_graph(self) -> StateGraph:
        """
        Create and configure the Compute Agent workflow graph.

        Topology:
            decision --(agent_decision)--> deploy_model --> END
                                      \\--> react_agent  --> END

        The two subgraphs are mounted as pre-compiled nodes; checkpointing
        uses the module-level MemorySaver keyed by thread_id.

        Returns:
            Compiled StateGraph ready for execution
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("decision", self.decision_node)
        workflow.add_node("deploy_model", self.deploy_subgraph.get_compiled_graph())
        workflow.add_node("react_agent", self.react_subgraph.get_compiled_graph())

        workflow.set_entry_point("decision")

        # Route on the decision node's output; decision_node guarantees the
        # value is always one of the two mapped keys (including on error).
        workflow.add_conditional_edges(
            "decision",
            lambda state: state["agent_decision"],
            {
                "deploy_model": "deploy_model",
                "react_agent": "react_agent",
            }
        )

        workflow.add_edge("deploy_model", END)
        workflow.add_edge("react_agent", END)

        return workflow.compile(checkpointer=memory_saver)

    def get_compiled_graph(self):
        """Return the compiled graph for use in FastAPI"""
        return self.graph

    def _build_initial_state(self, query: str, user_id: str, session_id: str) -> Dict[str, Any]:
        """
        Build the fresh AgentState dict used to start a graph run.

        Extracted so ainvoke() and astream_generate_nodes() share one
        definition instead of two diverging copies.
        """
        return {
            "user_id": user_id,
            "session_id": session_id,
            "query": query,
            "response": "",
            "current_step": "start",
            "agent_decision": "",
            "deployment_approved": False,
            "model_name": "",
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
            "messages": [],
            # Human-in-the-loop tool approval fields.
            "pending_tool_calls": [],
            "approved_tool_calls": [],
            "rejected_tool_calls": [],
            "modified_tool_calls": [],
            "needs_re_reasoning": False,
            "re_reasoning_feedback": ""
        }

    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (synchronous wrapper for async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution
        """
        # CLI-only entry point; raises RuntimeError inside a running loop.
        return asyncio.run(self.ainvoke(query, user_id, session_id))

    async def ainvoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Returns:
            Final result from the graph execution containing:
            - response: Final response to user
            - agent_decision: Which path was taken
            - deployment_result: If deployment path was taken
            - react_results: If React agent path was taken
            On failure, the initial state augmented with "error",
            "error_step" and an error "response".
        """
        initial_state = self._build_initial_state(query, user_id, session_id)

        # One checkpointer thread per (user, session) pair.
        config = {
            "configurable": {
                "thread_id": f"{user_id}_{session_id}"
            }
        }

        try:
            result = await self.graph.ainvoke(initial_state, config)
            return result
        except Exception as e:
            logger.error(f"Error in graph execution: {e}")
            return {
                **initial_state,
                "error": str(e),
                "error_step": initial_state.get("current_step", "unknown"),
                "response": f"An error occurred during execution: {str(e)}"
            }

    async def astream_generate_nodes(self, query: str, user_id: str = "default_user", session_id: str = "default_session"):
        """
        Stream the graph execution node by node (async).

        Args:
            query: User's query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management

        Yields:
            Dict containing node execution updates: {"node", "output"} plus
            the node output's keys flattened in.  On failure, a single
            {"error", "status", "error_step"} dict is yielded.
        """
        initial_state = self._build_initial_state(query, user_id, session_id)

        config = {
            "configurable": {
                "thread_id": f"{user_id}_{session_id}"
            }
        }

        try:
            async for chunk in self.graph.astream(initial_state, config):
                # Each chunk maps node name -> that node's state update.
                for node_name, node_output in chunk.items():
                    yield {
                        "node": node_name,
                        "output": node_output,
                        **node_output
                    }
        except Exception as e:
            logger.error(f"Error in graph streaming: {e}")
            yield {
                "error": str(e),
                "status": "error",
                "error_step": initial_state.get("current_step", "unknown")
            }

    def draw_graph(self, output_file_path: str = "basic_agent_graph.png"):
        """
        Generate and save a visual representation of the agent workflow graph.

        Args:
            output_file_path: Path where to save the graph PNG file
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Basic Agent graph visualization saved to: {output_file_path}")
        except Exception as e:
            # Visualization is optional tooling; never let it crash callers.
            logger.error(f"❌ Failed to generate Basic Agent graph visualization: {e}")