"""
Deploy Model Graph - FIXED
This module implements the model deployment workflow graph for the ComputeAgent.
KEY FIX: DeployModelState now correctly inherits from AgentState (TypedDict)
instead of StateGraph.
Author: ComputeAgent Team
License: Private
"""
import logging
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.graph.state import AgentState
# Import nodes from ReAct_DeployModel package
from ComputeAgent.nodes.ReAct_DeployModel.extract_model_info import extract_model_info_node
from ComputeAgent.nodes.ReAct_DeployModel.generate_additional_info import generate_additional_info_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_estimation import capacity_estimation_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_approval import capacity_approval_node, auto_capacity_approval_node
from ComputeAgent.models.model_manager import ModelManager
from langchain_mcp_adapters.client import MultiServerMCPClient
import os
import sys
# Import constants for human approval settings
from constant import Constants
# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()
logger = logging.getLogger("ComputeAgent")
# Get the project root directory (parent of ComputeAgent folder).
# __file__ is in ComputeAgent/graph/graph_deploy.py;
# go up 3 levels: graph -> ComputeAgent -> project_root.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")
# Use sys.executable to get the current Python interpreter path
python_executable = sys.executable
mcp_client = MultiServerMCPClient(
{
"hivecompute": {
"command": python_executable,
"args": [mcp_server_path],
"transport": "stdio",
"env": {
# Pass HF Spaces secrets to the MCP subprocess
"HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
"HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
# Also pass these to ensure Python works correctly
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
}
}
}
)
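# Note: the client is only configured here; the actual tool discovery happens
# asynchronously via `await mcp_client.get_tools()` in DeployModelAgent.create().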
logger = logging.getLogger("DeployModelGraph")
# Now inherits from AgentState (TypedDict) instead of StateGraph
class DeployModelState(AgentState):
"""
DeployModelState extends AgentState to inherit all base agent fields.
Inherited from AgentState (TypedDict):
- query: str
- response: str
- current_step: str
- messages: List[Dict[str, Any]]
- agent_decision: str
- deployment_approved: bool
- model_name: str
- llm: Any
- model_card: Dict[str, Any]
- model_info: Dict[str, Any]
- capacity_estimate: Dict[str, Any]
- deployment_result: Dict[str, Any]
- react_results: Dict[str, Any]
- tool_calls: List[Dict[str, Any]]
- tool_results: List[Dict[str, Any]]
All fields are inherited from AgentState - no additional fields needed.
"""
pass # Inherits all fields from AgentState
class DeployModelAgent:
"""
Standalone Deploy Model Agent class with memory and streaming support.
This class provides a dedicated interface for model deployment workflows
with full memory management and streaming capabilities.
"""
    def __init__(self, llm, react_tools):
        self.llm = llm
        self.react_tools = react_tools
        # Optional LangGraph checkpointer; stays None (no persistent memory)
        # unless set externally. `ainvoke` checks it before building a thread config.
        self.checkpointer = None
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.react_tools)
        self.graph = self._create_graph()
@classmethod
async def create(cls, llm=None, custom_tools=None):
"""
Async factory method for DeployModelAgent.
Args:
llm: Optional pre-loaded LLM
custom_tools: Optional pre-loaded tools for the nested ReactWorkflow
Returns:
DeployModelAgent instance
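        Example (illustrative; assumes a running event loop and a reachable MCP server):
            agent = await DeployModelAgent.create()
            result = await agent.ainvoke("Deploy the model gpt2")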
"""
if llm is None:
llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
if custom_tools is None:
# Load a separate MCP toolset for deployment React
custom_tools = await mcp_client.get_tools()
return cls(llm=llm, react_tools=custom_tools)
def _create_graph(self) -> CompiledStateGraph:
"""
Creates and configures the deploy model workflow.
βœ… FIXED: Now correctly creates StateGraph with DeployModelState (TypedDict)
"""
# βœ… This now works because DeployModelState is a TypedDict (via AgentState)
workflow = StateGraph(DeployModelState)
# Add nodes
workflow.add_node("extract_model_info", extract_model_info_node)
workflow.add_node("generate_model_name", generate_additional_info_node)
workflow.add_node("capacity_estimation", capacity_estimation_node)
workflow.add_node("capacity_approval", capacity_approval_node)
workflow.add_node("auto_capacity_approval", auto_capacity_approval_node)
workflow.add_node("react_deployment", self.react_subgraph.get_compiled_graph())
# Set entry point
workflow.set_entry_point("extract_model_info")
# Add conditional edges - Decision point after model extraction
workflow.add_conditional_edges(
"extract_model_info",
self.should_validate_or_generate,
{
"generate_model_name": "generate_model_name",
"capacity_estimation": "capacity_estimation"
}
)
# Add conditional edges from capacity estimation to approval
workflow.add_conditional_edges(
"capacity_estimation",
self.should_continue_to_capacity_approval,
{
"capacity_approval": "capacity_approval",
"auto_capacity_approval": "auto_capacity_approval",
"end": END
}
)
# Add conditional edges from capacity approval
workflow.add_conditional_edges(
"capacity_approval",
self.should_continue_after_capacity_approval,
{
"react_deployment": "react_deployment",
"capacity_estimation": "capacity_estimation",
"end": END
}
)
# Auto approval always goes to deployment
workflow.add_edge("auto_capacity_approval", "react_deployment")
# Final edges
workflow.add_edge("generate_model_name", END)
workflow.add_edge("react_deployment", END)
# Compile
return workflow.compile()
def get_compiled_graph(self):
"""Return the compiled graph for embedding in parent graph"""
return self.graph
def should_validate_or_generate(self, state: Dict[str, Any]) -> str:
"""
Decision routing function after model extraction.
Path 1: If model found and valid β†’ proceed to capacity estimation
Path 1A: If no model info or invalid β†’ generate helpful response with suggestions
Args:
state: Current workflow state
Returns:
            Next node key: "capacity_estimation" (valid model found) or
            "generate_model_name" (missing or invalid model info)
        """
        model_info = state.get("model_info") or {}
        if state.get("model_name") and model_info and not model_info.get("error"):
            return "capacity_estimation"  # Path 1: valid model case
        return "generate_model_name"  # Path 1A: no info case
def should_continue_to_capacity_approval(self, state: Dict[str, Any]) -> str:
"""
Determine whether to proceed to human approval, auto-approval, or end.
This function controls the flow after capacity estimation based on HUMAN_APPROVAL_CAPACITY setting:
- If HUMAN_APPROVAL_CAPACITY is True: Route to capacity_approval for manual approval
- If HUMAN_APPROVAL_CAPACITY is False: Route to auto_capacity_approval for automatic approval
- If capacity estimation failed: Route to end
Args:
state: Current workflow state containing capacity estimation results
Returns:
Next node name: "capacity_approval", "auto_capacity_approval", or "end"
"""
# Check if capacity estimation was successful
if state.get("capacity_estimation_status") != "success":
logger.info("πŸ”„ Capacity estimation failed - routing to end")
return "end"
        # Check if human approval is enabled (the constant is a string flag)
        human_approval_enabled = Constants.HUMAN_APPROVAL_CAPACITY == "true"
        if not human_approval_enabled:
            logger.info("πŸ”„ HUMAN_APPROVAL_CAPACITY disabled - routing to auto-approval")
            return "auto_capacity_approval"
        logger.info("πŸ”„ HUMAN_APPROVAL_CAPACITY enabled - routing to human approval")
        return "capacity_approval"
def should_continue_after_capacity_approval(self, state: Dict[str, Any]) -> str:
"""
Decide whether to proceed to ReAct deployment, re-estimate capacity, or end.
"""
logger.info(f"πŸ” Routing after capacity approval:")
logger.info(f" - capacity_approved: {state.get('capacity_approved')}")
logger.info(f" - needs_re_estimation: {state.get('needs_re_estimation')}")
logger.info(f" - capacity_approval_status: {state.get('capacity_approval_status')}")
# 1. FIRST check for re-estimation (highest priority)
needs_re_estimation = state.get("needs_re_estimation")
if needs_re_estimation is True:
logger.info("πŸ”„ Re-estimation requested - routing to capacity_estimation")
return "capacity_estimation"
# 2. THEN check if APPROVED (explicit True check)
capacity_approved = state.get("capacity_approved")
if capacity_approved is True:
logger.info("βœ… Capacity approved - proceeding to react_deployment")
return "react_deployment"
# 3. Check if REJECTED (explicit False check)
if capacity_approved is False:
logger.info("❌ Capacity rejected - ending workflow")
return "end"
# 4. If capacity_approved is None and no re-estimation, something is wrong
logger.warning(f"⚠️ Unexpected state in capacity approval routing")
logger.warning(f" capacity_approved: {capacity_approved} (type: {type(capacity_approved)})")
logger.warning(f" needs_re_estimation: {needs_re_estimation} (type: {type(needs_re_estimation)})")
logger.warning(f" Full state keys: {list(state.keys())}")
# Default to end to prevent infinite loops
return "end"
async def ainvoke(self,
query: str,
user_id: str = "default_user",
session_id: str = "default_session",
enable_memory: bool = False,
config: Optional[Dict] = None) -> Dict[str, Any]:
"""
Asynchronously invoke the Deploy Model Agent workflow.
Args:
query: User's model deployment query
user_id: User identifier for memory management
session_id: Session identifier for memory management
enable_memory: Whether to enable conversation memory management
config: Optional config dict
Returns:
Final workflow state with deployment results
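        Example (illustrative):
            result = await agent.ainvoke("Deploy mistralai/Mistral-7B-Instruct-v0.2")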
"""
# Initialize state with all required fields from AgentState
initial_state = {
# Core fields
"query": query,
"response": "",
"current_step": "initialized",
"messages": [],
# Decision fields
"agent_decision": "",
"deployment_approved": False,
# Model deployment fields
"model_name": "",
"llm": None,
"model_card": {},
"model_info": {},
"capacity_estimate": {},
"deployment_result": {},
# React agent fields
"react_results": {},
"tool_calls": [],
"tool_results": [],
}
# Extract approval from config if provided
if config and "configurable" in config:
if "capacity_approved" in config["configurable"]:
initial_state["deployment_approved"] = config["configurable"]["capacity_approved"]
logger.info(f"πŸ“‹ DeployModelAgent received approval: {config['configurable']['capacity_approved']}")
        # Configure memory only if requested and a checkpointer is available
        memory_config = None
        if enable_memory and self.checkpointer:
            thread_id = f"{user_id}:{session_id}"
            memory_config = {"configurable": {"thread_id": thread_id}}
        # Merge configs, preserving top-level keys from both; caller-provided
        # "configurable" values take precedence over the memory thread config
        final_config = {**(config or {}), **(memory_config or {})}
        if config and memory_config:
            final_config["configurable"] = {
                **memory_config["configurable"],
                **config.get("configurable", {}),
            }
        logger.info("πŸš€ Starting Deploy Model workflow")
# Execute the graph
if final_config:
result = await self.graph.ainvoke(initial_state, final_config)
else:
result = await self.graph.ainvoke(initial_state)
return result
    def invoke(self,
               query: str,
               user_id: str = "default_user",
               session_id: str = "default_session",
               enable_memory: bool = False) -> Dict[str, Any]:
"""
Synchronously invoke the Deploy Model Agent workflow.
Args:
query: User's model deployment query
user_id: User identifier for memory management
session_id: Session identifier for memory management
enable_memory: Whether to enable conversation memory management
Returns:
Final workflow state with deployment results
"""
import asyncio
return asyncio.run(self.ainvoke(query, user_id, session_id, enable_memory))
def draw_graph(self, output_file_path: str = "deploy_model_graph.png"):
"""
Generate and save a visual representation of the Deploy Model workflow graph.
Args:
output_file_path: Path where to save the graph PNG file
"""
try:
self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
logger.info(f"βœ… Graph visualization saved to: {output_file_path}")
except Exception as e:
logger.error(f"❌ Failed to generate graph visualization: {e}")
print(f"Error generating graph: {e}")