"""
Deploy Model Graph - FIXED
This module implements the model deployment workflow graph for the ComputeAgent.
KEY FIX: DeployModelState now correctly inherits from AgentState (TypedDict)
instead of StateGraph.
Author: ComputeAgent Team
License: Private
"""
import logging
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.graph.state import AgentState
# Import nodes from ReAct_DeployModel package
from ComputeAgent.nodes.ReAct_DeployModel.extract_model_info import extract_model_info_node
from ComputeAgent.nodes.ReAct_DeployModel.generate_additional_info import generate_additional_info_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_estimation import capacity_estimation_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_approval import capacity_approval_node, auto_capacity_approval_node
from ComputeAgent.models.model_manager import ModelManager
from langchain_mcp_adapters.client import MultiServerMCPClient
import os
import sys
# Import constants for human approval settings
from constant import Constants
# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()
# Get the project root directory (parent of the ComputeAgent folder)
# __file__ is in ComputeAgent/graph/graph_deploy.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")
# Use sys.executable to get the current Python interpreter path
python_executable = sys.executable
mcp_client = MultiServerMCPClient(
{
"hivecompute": {
"command": python_executable,
"args": [mcp_server_path],
"transport": "stdio",
"env": {
# Pass HF Spaces secrets to the MCP subprocess
"HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
"HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
# Also pass these to ensure Python works correctly
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
}
}
}
)
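# Illustrative sanity check (an assumption, not part of the original workflow):
# from an async context, the MCP toolset can be inspected with the same call
# that DeployModelAgent.create() uses below to load its default tools:
#
#     tools = await mcp_client.get_tools()
#     print([tool.name for tool in tools])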
logger = logging.getLogger("DeployModelGraph")
# Now inherits from AgentState (TypedDict) instead of StateGraph
class DeployModelState(AgentState):
"""
DeployModelState extends AgentState to inherit all base agent fields.
Inherited from AgentState (TypedDict):
- query: str
- response: str
- current_step: str
- messages: List[Dict[str, Any]]
- agent_decision: str
- deployment_approved: bool
- model_name: str
- llm: Any
- model_card: Dict[str, Any]
- model_info: Dict[str, Any]
- capacity_estimate: Dict[str, Any]
- deployment_result: Dict[str, Any]
- react_results: Dict[str, Any]
- tool_calls: List[Dict[str, Any]]
- tool_results: List[Dict[str, Any]]
All fields are inherited from AgentState - no additional fields needed.
"""
pass # Inherits all fields from AgentState
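# Illustrative note: because DeployModelState is a TypedDict, workflow states are
# plain dicts validated against the schema, e.g. (placeholder values):
#
#     state: DeployModelState = {"query": "deploy my model", "response": "", ...}
#
# This is what lets StateGraph(DeployModelState) below introspect the state
# schema - something a StateGraph subclass could not provide.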
class DeployModelAgent:
"""
Standalone Deploy Model Agent class with memory and streaming support.
This class provides a dedicated interface for model deployment workflows
with full memory management and streaming capabilities.
"""
    def __init__(self, llm, react_tools):
        self.llm = llm
        self.react_tools = react_tools
        # Optional LangGraph checkpointer for conversation memory; stays None
        # unless set by the caller (referenced in ainvoke below)
        self.checkpointer = None
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.react_tools)
        self.graph = self._create_graph()
@classmethod
async def create(cls, llm=None, custom_tools=None):
"""
Async factory method for DeployModelAgent.
Args:
llm: Optional pre-loaded LLM
custom_tools: Optional pre-loaded tools for the nested ReactWorkflow
Returns:
DeployModelAgent instance
"""
if llm is None:
llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
if custom_tools is None:
# Load a separate MCP toolset for deployment React
custom_tools = await mcp_client.get_tools()
return cls(llm=llm, react_tools=custom_tools)
def _create_graph(self) -> CompiledStateGraph:
"""
Creates and configures the deploy model workflow.
        ✅ FIXED: Now correctly creates StateGraph with DeployModelState (TypedDict)
"""
        # ✅ This now works because DeployModelState is a TypedDict (via AgentState)
workflow = StateGraph(DeployModelState)
# Add nodes
workflow.add_node("extract_model_info", extract_model_info_node)
workflow.add_node("generate_model_name", generate_additional_info_node)
workflow.add_node("capacity_estimation", capacity_estimation_node)
workflow.add_node("capacity_approval", capacity_approval_node)
workflow.add_node("auto_capacity_approval", auto_capacity_approval_node)
workflow.add_node("react_deployment", self.react_subgraph.get_compiled_graph())
# Set entry point
workflow.set_entry_point("extract_model_info")
# Add conditional edges - Decision point after model extraction
workflow.add_conditional_edges(
"extract_model_info",
self.should_validate_or_generate,
{
"generate_model_name": "generate_model_name",
"capacity_estimation": "capacity_estimation"
}
)
# Add conditional edges from capacity estimation to approval
workflow.add_conditional_edges(
"capacity_estimation",
self.should_continue_to_capacity_approval,
{
"capacity_approval": "capacity_approval",
"auto_capacity_approval": "auto_capacity_approval",
"end": END
}
)
# Add conditional edges from capacity approval
workflow.add_conditional_edges(
"capacity_approval",
self.should_continue_after_capacity_approval,
{
"react_deployment": "react_deployment",
"capacity_estimation": "capacity_estimation",
"end": END
}
)
# Auto approval always goes to deployment
workflow.add_edge("auto_capacity_approval", "react_deployment")
# Final edges
workflow.add_edge("generate_model_name", END)
workflow.add_edge("react_deployment", END)
# Compile
return workflow.compile()
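    # Workflow topology (a summary derived from the edges registered above):
    #
    #   extract_model_info     -> generate_model_name | capacity_estimation
    #   generate_model_name    -> END
    #   capacity_estimation    -> capacity_approval | auto_capacity_approval | END
    #   capacity_approval      -> react_deployment | capacity_estimation (re-estimate) | END
    #   auto_capacity_approval -> react_deployment
    #   react_deployment       -> END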
def get_compiled_graph(self):
"""Return the compiled graph for embedding in parent graph"""
return self.graph
    def should_validate_or_generate(self, state: Dict[str, Any]) -> str:
        """
        Decision routing after model extraction.
        Path 1:  Model found and valid → proceed to capacity estimation.
        Path 1A: No model info, or the lookup returned an error → generate a helpful response with suggestions.
        Args:
            state: Current workflow state
        Returns:
            Name of the next node
        """
        model_info = state.get("model_info") or {}
        if state.get("model_name") and model_info and not model_info.get("error"):
            return "capacity_estimation"  # Path 1: valid model
        return "generate_model_name"  # Path 1A: no usable model info
def should_continue_to_capacity_approval(self, state: Dict[str, Any]) -> str:
"""
Determine whether to proceed to human approval, auto-approval, or end.
This function controls the flow after capacity estimation based on HUMAN_APPROVAL_CAPACITY setting:
- If HUMAN_APPROVAL_CAPACITY is True: Route to capacity_approval for manual approval
- If HUMAN_APPROVAL_CAPACITY is False: Route to auto_capacity_approval for automatic approval
- If capacity estimation failed: Route to end
Args:
state: Current workflow state containing capacity estimation results
Returns:
Next node name: "capacity_approval", "auto_capacity_approval", or "end"
"""
        # Check whether capacity estimation succeeded
        if state.get("capacity_estimation_status") != "success":
            logger.info("Capacity estimation failed - routing to end")
            return "end"
        # Constants.HUMAN_APPROVAL_CAPACITY is stored as the string "true"/"false"
        human_approval_enabled = Constants.HUMAN_APPROVAL_CAPACITY == "true"
        if human_approval_enabled:
            logger.info("HUMAN_APPROVAL_CAPACITY enabled - routing to human approval")
            return "capacity_approval"
        logger.info("HUMAN_APPROVAL_CAPACITY disabled - routing to auto-approval")
        return "auto_capacity_approval"
def should_continue_after_capacity_approval(self, state: Dict[str, Any]) -> str:
"""
Decide whether to proceed to ReAct deployment, re-estimate capacity, or end.
"""
        logger.info("Routing after capacity approval:")
logger.info(f" - capacity_approved: {state.get('capacity_approved')}")
logger.info(f" - needs_re_estimation: {state.get('needs_re_estimation')}")
logger.info(f" - capacity_approval_status: {state.get('capacity_approval_status')}")
# 1. FIRST check for re-estimation (highest priority)
needs_re_estimation = state.get("needs_re_estimation")
if needs_re_estimation is True:
            logger.info("Re-estimation requested - routing to capacity_estimation")
return "capacity_estimation"
# 2. THEN check if APPROVED (explicit True check)
capacity_approved = state.get("capacity_approved")
if capacity_approved is True:
            logger.info("✅ Capacity approved - proceeding to react_deployment")
return "react_deployment"
# 3. Check if REJECTED (explicit False check)
if capacity_approved is False:
            logger.info("❌ Capacity rejected - ending workflow")
return "end"
# 4. If capacity_approved is None and no re-estimation, something is wrong
        logger.warning("⚠️ Unexpected state in capacity approval routing")
logger.warning(f" capacity_approved: {capacity_approved} (type: {type(capacity_approved)})")
logger.warning(f" needs_re_estimation: {needs_re_estimation} (type: {type(needs_re_estimation)})")
logger.warning(f" Full state keys: {list(state.keys())}")
# Default to end to prevent infinite loops
return "end"
async def ainvoke(self,
query: str,
user_id: str = "default_user",
session_id: str = "default_session",
enable_memory: bool = False,
config: Optional[Dict] = None) -> Dict[str, Any]:
"""
Asynchronously invoke the Deploy Model Agent workflow.
Args:
query: User's model deployment query
user_id: User identifier for memory management
session_id: Session identifier for memory management
enable_memory: Whether to enable conversation memory management
config: Optional config dict
Returns:
Final workflow state with deployment results
"""
# Initialize state with all required fields from AgentState
initial_state = {
# Core fields
"query": query,
"response": "",
"current_step": "initialized",
"messages": [],
# Decision fields
"agent_decision": "",
"deployment_approved": False,
# Model deployment fields
"model_name": "",
"llm": None,
"model_card": {},
"model_info": {},
"capacity_estimate": {},
"deployment_result": {},
# React agent fields
"react_results": {},
"tool_calls": [],
"tool_results": [],
}
# Extract approval from config if provided
if config and "configurable" in config:
if "capacity_approved" in config["configurable"]:
initial_state["deployment_approved"] = config["configurable"]["capacity_approved"]
                logger.info(f"DeployModelAgent received approval: {config['configurable']['capacity_approved']}")
        # Configure memory only when requested and a checkpointer is available
        memory_config = None
        if enable_memory and self.checkpointer:
thread_id = f"{user_id}:{session_id}"
memory_config = {"configurable": {"thread_id": thread_id}}
# Merge configs
final_config = memory_config or {}
if config:
if "configurable" in final_config:
final_config["configurable"].update(config.get("configurable", {}))
else:
final_config = config
        logger.info("Starting Deploy Model workflow")
# Execute the graph
if final_config:
result = await self.graph.ainvoke(initial_state, final_config)
else:
result = await self.graph.ainvoke(initial_state)
return result
def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session", enable_memory: bool = False) -> Dict[str, Any]:
"""
Synchronously invoke the Deploy Model Agent workflow.
Args:
query: User's model deployment query
user_id: User identifier for memory management
session_id: Session identifier for memory management
enable_memory: Whether to enable conversation memory management
Returns:
Final workflow state with deployment results
"""
import asyncio
return asyncio.run(self.ainvoke(query, user_id, session_id, enable_memory))
def draw_graph(self, output_file_path: str = "deploy_model_graph.png"):
"""
Generate and save a visual representation of the Deploy Model workflow graph.
Args:
output_file_path: Path where to save the graph PNG file
"""
try:
self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Graph visualization saved to: {output_file_path}")
        except Exception as e:
            logger.error(f"❌ Failed to generate graph visualization: {e}")
print(f"Error generating graph: {e}") |
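
# Minimal usage sketch (illustrative, not part of the original module; assumes the
# Compute_MCP server and the HIVE_COMPUTE_* environment variables are configured;
# the query and identifiers below are placeholders):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Build the agent with the default LLM and MCP toolset
        agent = await DeployModelAgent.create()
        result = await agent.ainvoke(
            query="Deploy the model my-org/my-model",  # placeholder query
            user_id="demo_user",
            session_id="demo_session",
        )
        print(result.get("response", ""))

    asyncio.run(_demo())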