|
|
""" |
|
|
Deploy Model Graph - FIXED |
|
|
|
|
|
This module implements the model deployment workflow graph for the ComputeAgent. |
|
|
|
|
|
KEY FIX: DeployModelState now correctly inherits from AgentState (TypedDict) |
|
|
instead of StateGraph. |
|
|
|
|
|
Author: ComputeAgent Team |
|
|
License: Private |
|
|
""" |
|
|
|
|
|
import asyncio
import logging
import os
import sys
from typing import Any, Dict, Optional

from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph

from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.graph.state import AgentState
from ComputeAgent.models.model_manager import ModelManager
from ComputeAgent.nodes.ReAct_DeployModel.capacity_approval import (
    auto_capacity_approval_node,
    capacity_approval_node,
)
from ComputeAgent.nodes.ReAct_DeployModel.capacity_estimation import capacity_estimation_node
from ComputeAgent.nodes.ReAct_DeployModel.extract_model_info import extract_model_info_node
from ComputeAgent.nodes.ReAct_DeployModel.generate_additional_info import generate_additional_info_node
from constant import Constants

logger = logging.getLogger("DeployModelGraph")

model_manager = ModelManager()

# Resolve the Compute_MCP server entry point relative to this file and launch it
# with the same interpreter that runs the agent, so both share one environment.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")
python_executable = sys.executable

mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # Forward PATH and PYTHONPATH so the spawned server can resolve its imports.
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            },
        }
    }
)

class DeployModelState(AgentState):
    """
    DeployModelState extends AgentState to inherit all base agent fields.

    Inherited from AgentState (TypedDict):
    - query: str
    - response: str
    - current_step: str
    - messages: List[Dict[str, Any]]
    - agent_decision: str
    - deployment_approved: bool
    - model_name: str
    - llm: Any
    - model_card: Dict[str, Any]
    - model_info: Dict[str, Any]
    - capacity_estimate: Dict[str, Any]
    - deployment_result: Dict[str, Any]
    - react_results: Dict[str, Any]
    - tool_calls: List[Dict[str, Any]]
    - tool_results: List[Dict[str, Any]]

    All fields are inherited from AgentState; no additional fields are needed.
    """
    pass
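
# Illustrative sketch: because DeployModelState is a TypedDict, workflow state is a
# plain dict keyed by the fields above, not a class instance. A node receives the
# current dict and returns a partial update (hypothetical values shown):
#
#     state: Dict[str, Any] = {"query": "Deploy meta-llama/Meta-Llama-3-8B", "model_name": ""}
#     update = {"model_name": "meta-llama-3-8b", "current_step": "extracted"}
#     state.update(update)  # LangGraph merges node returns into the running state like this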
|
|
|
|
|
|
|
|
class DeployModelAgent:
    """
    Standalone Deploy Model Agent class.

    Provides a dedicated interface for model deployment workflows, with optional
    conversation memory via an attachable LangGraph checkpointer.
    """

    def __init__(self, llm, react_tools):
        self.llm = llm
        self.react_tools = react_tools
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.react_tools)
        # Checkpointer for conversation memory; unset by default. ainvoke() only
        # builds a thread-scoped config when one is attached and memory is enabled.
        self.checkpointer = None
        self.graph = self._create_graph()

    @classmethod
    async def create(cls, llm=None, custom_tools=None):
        """
        Async factory method for DeployModelAgent.

        Args:
            llm: Optional pre-loaded LLM
            custom_tools: Optional pre-loaded tools for the nested ReactWorkflow

        Returns:
            DeployModelAgent instance
        """
        if llm is None:
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)

        if custom_tools is None:
            custom_tools = await mcp_client.get_tools()

        return cls(llm=llm, react_tools=custom_tools)

    def _create_graph(self) -> CompiledStateGraph:
        """
        Creates and configures the deploy model workflow.

        ✅ FIXED: Now correctly creates StateGraph with DeployModelState (TypedDict).
        """
        workflow = StateGraph(DeployModelState)

        # Workflow nodes; react_deployment embeds the compiled ReAct subgraph.
        workflow.add_node("extract_model_info", extract_model_info_node)
        workflow.add_node("generate_model_name", generate_additional_info_node)
        workflow.add_node("capacity_estimation", capacity_estimation_node)
        workflow.add_node("capacity_approval", capacity_approval_node)
        workflow.add_node("auto_capacity_approval", auto_capacity_approval_node)
        workflow.add_node("react_deployment", self.react_subgraph.get_compiled_graph())

        workflow.set_entry_point("extract_model_info")

        # Route to capacity estimation when a valid model was extracted,
        # otherwise generate a helpful response with suggestions.
        workflow.add_conditional_edges(
            "extract_model_info",
            self.should_validate_or_generate,
            {
                "generate_model_name": "generate_model_name",
                "capacity_estimation": "capacity_estimation"
            }
        )

        # Route to human or automatic approval depending on configuration,
        # or end the workflow if capacity estimation failed.
        workflow.add_conditional_edges(
            "capacity_estimation",
            self.should_continue_to_capacity_approval,
            {
                "capacity_approval": "capacity_approval",
                "auto_capacity_approval": "auto_capacity_approval",
                "end": END
            }
        )

        # After human approval: deploy, re-estimate, or end on rejection.
        workflow.add_conditional_edges(
            "capacity_approval",
            self.should_continue_after_capacity_approval,
            {
                "react_deployment": "react_deployment",
                "capacity_estimation": "capacity_estimation",
                "end": END
            }
        )

        workflow.add_edge("auto_capacity_approval", "react_deployment")

        workflow.add_edge("generate_model_name", END)
        workflow.add_edge("react_deployment", END)

        return workflow.compile()
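
    # Graph topology, for reference (derived from the edges above):
    #
    #   extract_model_info     -> generate_model_name | capacity_estimation
    #   generate_model_name    -> END
    #   capacity_estimation    -> capacity_approval | auto_capacity_approval | END
    #   capacity_approval      -> react_deployment | capacity_estimation | END
    #   auto_capacity_approval -> react_deployment
    #   react_deployment       -> END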
|
|
|
|
|
    def get_compiled_graph(self):
        """Return the compiled graph for embedding in a parent graph."""
        return self.graph
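
    # Embedding sketch (illustrative): a compiled LangGraph graph is itself
    # runnable, so a parent graph can mount this agent as a single node, the
    # same way _create_graph() mounts the ReAct subgraph:
    #
    #     parent = StateGraph(AgentState)
    #     parent.add_node("deploy_model", agent.get_compiled_graph())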
|
|
|
|
|
    def should_validate_or_generate(self, state: Dict[str, Any]) -> str:
        """
        Decision routing function after model extraction.

        Path 1: If a model was found and is valid, proceed to capacity estimation.
        Path 1A: If model info is missing or invalid, generate a helpful response with suggestions.

        Args:
            state: Current workflow state

        Returns:
            Next node name
        """
        model_info = state.get("model_info") or {}
        if state.get("model_name") and model_info and not model_info.get("error"):
            return "capacity_estimation"
        else:
            return "generate_model_name"

    def should_continue_to_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Determine whether to proceed to human approval, auto-approval, or end.

        This function controls the flow after capacity estimation based on the
        HUMAN_APPROVAL_CAPACITY setting:
        - If HUMAN_APPROVAL_CAPACITY is True: route to capacity_approval for manual approval
        - If HUMAN_APPROVAL_CAPACITY is False: route to auto_capacity_approval for automatic approval
        - If capacity estimation failed: route to end

        Args:
            state: Current workflow state containing capacity estimation results

        Returns:
            Next node name: "capacity_approval", "auto_capacity_approval", or "end"
        """
        if state.get("capacity_estimation_status") != "success":
            logger.info("Capacity estimation failed - routing to end")
            return "end"

        # Constants stores the flag as the string "true"/"false".
        human_approval_capacity = Constants.HUMAN_APPROVAL_CAPACITY == "true"
        if not human_approval_capacity:
            logger.info("HUMAN_APPROVAL_CAPACITY disabled - routing to auto-approval")
            return "auto_capacity_approval"
        else:
            logger.info("HUMAN_APPROVAL_CAPACITY enabled - routing to human approval")
            return "capacity_approval"
|
|
|
|
|
    def should_continue_after_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Decide whether to proceed to ReAct deployment, re-estimate capacity, or end.
        """
        logger.info("Routing after capacity approval:")
        logger.info(f" - capacity_approved: {state.get('capacity_approved')}")
        logger.info(f" - needs_re_estimation: {state.get('needs_re_estimation')}")
        logger.info(f" - capacity_approval_status: {state.get('capacity_approval_status')}")

        # Three-state logic: True/False are explicit decisions; None (or any
        # other value) is unexpected and falls through to the warning below.
        needs_re_estimation = state.get("needs_re_estimation")
        if needs_re_estimation is True:
            logger.info("Re-estimation requested - routing to capacity_estimation")
            return "capacity_estimation"

        capacity_approved = state.get("capacity_approved")
        if capacity_approved is True:
            logger.info("✅ Capacity approved - proceeding to react_deployment")
            return "react_deployment"

        if capacity_approved is False:
            logger.info("❌ Capacity rejected - ending workflow")
            return "end"

        logger.warning("⚠️ Unexpected state in capacity approval routing")
        logger.warning(f" capacity_approved: {capacity_approved} (type: {type(capacity_approved)})")
        logger.warning(f" needs_re_estimation: {needs_re_estimation} (type: {type(needs_re_estimation)})")
        logger.warning(f" Full state keys: {list(state.keys())}")

        return "end"

    async def ainvoke(
        self,
        query: str,
        user_id: str = "default_user",
        session_id: str = "default_session",
        enable_memory: bool = False,
        config: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Asynchronously invoke the Deploy Model Agent workflow.

        Args:
            query: User's model deployment query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management
            enable_memory: Whether to enable conversation memory management
            config: Optional config dict

        Returns:
            Final workflow state with deployment results
        """
        initial_state = {
            # Base fields
            "query": query,
            "response": "",
            "current_step": "initialized",
            "messages": [],

            # Decision fields
            "agent_decision": "",
            "deployment_approved": False,

            # Deployment fields
            "model_name": "",
            "llm": None,
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},

            # ReAct fields
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
        }

        # Honor a pre-made approval decision passed down from a parent graph.
        if config and "configurable" in config:
            if "capacity_approved" in config["configurable"]:
                initial_state["deployment_approved"] = config["configurable"]["capacity_approved"]
                logger.info(f"DeployModelAgent received approval: {config['configurable']['capacity_approved']}")

        # Build a thread-scoped config when memory is enabled and a checkpointer is attached.
        memory_config = None
        if enable_memory and self.checkpointer:
            thread_id = f"{user_id}:{session_id}"
            memory_config = {"configurable": {"thread_id": thread_id}}

        # Merge the caller's configurable values into the memory config; caller values win.
        final_config = memory_config or {}
        if config:
            if "configurable" in final_config:
                final_config["configurable"].update(config.get("configurable", {}))
            else:
                final_config = config

        logger.info("Starting Deploy Model workflow")

        if final_config:
            result = await self.graph.ainvoke(initial_state, final_config)
        else:
            result = await self.graph.ainvoke(initial_state)

        return result
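
    # Invocation sketch (illustrative): a parent graph can pre-approve capacity
    # through the configurable dict handled above, e.g.:
    #
    #     result = await agent.ainvoke(
    #         "Deploy mistralai/Mistral-7B-Instruct-v0.2",
    #         config={"configurable": {"capacity_approved": True}},
    #     )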
|
|
|
|
|
|
|
|
    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session", enable_memory: bool = False) -> Dict[str, Any]:
        """
        Synchronously invoke the Deploy Model Agent workflow.

        Note: this uses asyncio.run, so it must not be called from inside an
        already-running event loop; use ainvoke there instead.

        Args:
            query: User's model deployment query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management
            enable_memory: Whether to enable conversation memory management

        Returns:
            Final workflow state with deployment results
        """
        return asyncio.run(self.ainvoke(query, user_id, session_id, enable_memory))

    def draw_graph(self, output_file_path: str = "deploy_model_graph.png"):
        """
        Generate and save a visual representation of the Deploy Model workflow graph.

        Args:
            output_file_path: Path where the graph PNG file is saved
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"✅ Graph visualization saved to: {output_file_path}")
        except Exception as e:
            logger.error(f"❌ Failed to generate graph visualization: {e}")
            print(f"Error generating graph: {e}")