File size: 15,212 Bytes
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd4236
 
8816dfd
 
5dd4236
 
 
 
 
8816dfd
93850a2
8816dfd
 
 
 
 
 
 
 
 
93850a2
 
 
 
 
 
 
 
 
 
8816dfd
 
 
0297f14
 
 
 
 
 
 
 
 
 
 
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
"""
Deploy Model Graph - FIXED

This module implements the model deployment workflow graph for the ComputeAgent.

KEY FIX: DeployModelState now correctly inherits from AgentState (TypedDict) 
instead of StateGraph.

Author: ComputeAgent Team
License: Private
"""

import logging
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.graph.state import AgentState

# Import nodes from ReAct_DeployModel package
from ComputeAgent.nodes.ReAct_DeployModel.extract_model_info import extract_model_info_node
from ComputeAgent.nodes.ReAct_DeployModel.generate_additional_info import generate_additional_info_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_estimation import capacity_estimation_node
from ComputeAgent.nodes.ReAct_DeployModel.capacity_approval import capacity_approval_node, auto_capacity_approval_node
from ComputeAgent.models.model_manager import ModelManager
from langchain_mcp_adapters.client import MultiServerMCPClient
import os

# Import constants for human approval settings
from constant import Constants

# Initialize model manager for dynamic LLM loading and management
model_manager = ModelManager()

# Get the project root directory (parent of ComputeAgent folder)
import sys
# __file__ is in ComputeAgent/graph/graph_deploy.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

# Use sys.executable so the MCP subprocess runs under the same Python interpreter
python_executable = sys.executable

# MCP client that spawns the Compute MCP server as a stdio subprocess.
mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Pass HF Spaces secrets to the MCP subprocess
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # Also pass these to ensure Python works correctly
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            }
        }
    }
)

# FIX: logger was previously assigned twice — first as getLogger("ComputeAgent"),
# then immediately overwritten with getLogger("DeployModelGraph"). Only the
# second assignment ever took effect, so keep the single effective definition.
logger = logging.getLogger("DeployModelGraph")


# Inherits from AgentState (TypedDict) so it can be used as a StateGraph schema.
class DeployModelState(AgentState):
    """State schema for the deploy-model workflow.

    Subclassing AgentState (a TypedDict) makes this class directly usable
    as the schema argument of ``StateGraph``. All keys are inherited
    verbatim from AgentState — query/response/step tracking, decision
    flags (``agent_decision``, ``deployment_approved``), model metadata
    (``model_name``, ``model_card``, ``model_info``), capacity and
    deployment results, and the ReAct bookkeeping fields (``react_results``,
    ``tool_calls``, ``tool_results``). No additional keys are declared here.
    """


class DeployModelAgent:
    """
    Standalone Deploy Model Agent with memory and streaming support.

    Wraps a LangGraph workflow that extracts model information from a user
    query, estimates deployment capacity, routes through human or automatic
    approval, and finally hands off to a nested ReAct subgraph that performs
    the actual deployment via MCP tools.
    """

    def __init__(self, llm, react_tools):
        """
        Initialize the agent and compile its workflow graph.

        Args:
            llm: Pre-loaded language model shared by the workflow nodes.
            react_tools: Tools injected into the nested ReactWorkflow.
        """
        self.llm = llm
        self.react_tools = react_tools
        # FIX: ainvoke() reads self.checkpointer, but it was never initialized,
        # so every invocation raised AttributeError. Default to None (memory
        # disabled) until a checkpointer is explicitly attached.
        self.checkpointer = None
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.react_tools)
        self.graph = self._create_graph()

    @classmethod
    async def create(cls, llm=None, custom_tools=None):
        """
        Async factory method for DeployModelAgent.

        Args:
            llm: Optional pre-loaded LLM; loaded via ModelManager if omitted.
            custom_tools: Optional pre-loaded tools for the nested
                ReactWorkflow; fetched from the MCP client if omitted.

        Returns:
            DeployModelAgent instance
        """
        if llm is None:
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)

        if custom_tools is None:
            # Load a separate MCP toolset for deployment React
            custom_tools = await mcp_client.get_tools()

        return cls(llm=llm, react_tools=custom_tools)

    def _create_graph(self) -> "CompiledStateGraph":
        """
        Create and configure the deploy-model workflow graph.

        Returns:
            The compiled LangGraph state graph.
        """
        # DeployModelState is a TypedDict (via AgentState), as StateGraph requires.
        workflow = StateGraph(DeployModelState)

        # Add nodes
        workflow.add_node("extract_model_info", extract_model_info_node)
        workflow.add_node("generate_model_name", generate_additional_info_node)
        workflow.add_node("capacity_estimation", capacity_estimation_node)
        workflow.add_node("capacity_approval", capacity_approval_node)
        workflow.add_node("auto_capacity_approval", auto_capacity_approval_node)
        workflow.add_node("react_deployment", self.react_subgraph.get_compiled_graph())

        # Set entry point
        workflow.set_entry_point("extract_model_info")

        # Decision point after model extraction: valid model vs. suggestion path
        workflow.add_conditional_edges(
            "extract_model_info",
            self.should_validate_or_generate,
            {
                "generate_model_name": "generate_model_name",
                "capacity_estimation": "capacity_estimation"
            }
        )

        # After capacity estimation: human approval, auto-approval, or end
        workflow.add_conditional_edges(
            "capacity_estimation",
            self.should_continue_to_capacity_approval,
            {
                "capacity_approval": "capacity_approval",
                "auto_capacity_approval": "auto_capacity_approval",
                "end": END
            }
        )

        # After human approval: deploy, re-estimate, or end
        workflow.add_conditional_edges(
            "capacity_approval",
            self.should_continue_after_capacity_approval,
            {
                "react_deployment": "react_deployment",
                "capacity_estimation": "capacity_estimation",
                "end": END
            }
        )

        # Auto approval always goes to deployment
        workflow.add_edge("auto_capacity_approval", "react_deployment")

        # Final edges
        workflow.add_edge("generate_model_name", END)
        workflow.add_edge("react_deployment", END)

        # Compile
        return workflow.compile()

    def get_compiled_graph(self):
        """Return the compiled graph for embedding in a parent graph."""
        return self.graph

    def should_validate_or_generate(self, state: Dict[str, Any]) -> str:
        """
        Decision routing function after model extraction.

        Path 1: model found and valid -> proceed to capacity estimation.
        Path 1A: no model info, or lookup returned an error -> generate a
        helpful response with suggestions instead.

        Args:
            state: Current workflow state

        Returns:
            Next node name ("capacity_estimation" or "generate_model_name")
        """
        model_info = state.get("model_info") or {}
        if state.get("model_name") and state.get("model_info") and not model_info.get("error"):
            return "capacity_estimation"  # Path 1: Valid model case
        return "generate_model_name"  # Path 1A: No info case

    def should_continue_to_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Determine whether to proceed to human approval, auto-approval, or end.

        Controls the flow after capacity estimation based on the
        HUMAN_APPROVAL_CAPACITY setting:
        - "true": route to capacity_approval for manual approval
        - otherwise: route to auto_capacity_approval for automatic approval
        - if capacity estimation failed: route to end

        Args:
            state: Current workflow state containing capacity estimation results

        Returns:
            Next node name: "capacity_approval", "auto_capacity_approval", or "end"
        """
        # Check if capacity estimation was successful
        if state.get("capacity_estimation_status") != "success":
            logger.info("πŸ”„ Capacity estimation failed - routing to end")
            return "end"

        # Constants stores the flag as a string; compare directly instead of
        # the `True if ... else False` anti-idiom.
        human_approval_enabled = Constants.HUMAN_APPROVAL_CAPACITY == "true"
        if not human_approval_enabled:
            logger.info("πŸ”„ HUMAN_APPROVAL_CAPACITY disabled - routing to auto-approval")
            return "auto_capacity_approval"
        logger.info("πŸ”„ HUMAN_APPROVAL_CAPACITY enabled - routing to human approval")
        return "capacity_approval"

    def should_continue_after_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Decide whether to proceed to ReAct deployment, re-estimate capacity, or end.

        Priority order: re-estimation request, explicit approval, explicit
        rejection; any other combination falls through to "end" to prevent
        infinite loops.
        """
        logger.info("πŸ” Routing after capacity approval:")
        logger.info(f"   - capacity_approved: {state.get('capacity_approved')}")
        logger.info(f"   - needs_re_estimation: {state.get('needs_re_estimation')}")
        logger.info(f"   - capacity_approval_status: {state.get('capacity_approval_status')}")

        # 1. FIRST check for re-estimation (highest priority)
        needs_re_estimation = state.get("needs_re_estimation")
        if needs_re_estimation is True:
            logger.info("πŸ”„ Re-estimation requested - routing to capacity_estimation")
            return "capacity_estimation"

        # 2. THEN check if APPROVED (explicit True check)
        capacity_approved = state.get("capacity_approved")
        if capacity_approved is True:
            logger.info("βœ… Capacity approved - proceeding to react_deployment")
            return "react_deployment"

        # 3. Check if REJECTED (explicit False check)
        if capacity_approved is False:
            logger.info("❌ Capacity rejected - ending workflow")
            return "end"

        # 4. capacity_approved is None and no re-estimation: unexpected state
        logger.warning("⚠️ Unexpected state in capacity approval routing")
        logger.warning(f"   capacity_approved: {capacity_approved} (type: {type(capacity_approved)})")
        logger.warning(f"   needs_re_estimation: {needs_re_estimation} (type: {type(needs_re_estimation)})")
        logger.warning(f"   Full state keys: {list(state.keys())}")

        # Default to end to prevent infinite loops
        return "end"

    async def ainvoke(self,
                      query: str,
                      user_id: str = "default_user",
                      session_id: str = "default_session",
                      enable_memory: bool = False,
                      config: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Asynchronously invoke the Deploy Model Agent workflow.

        Args:
            query: User's model deployment query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management
            enable_memory: Whether to enable conversation memory management
            config: Optional config dict (may carry "configurable" overrides)

        Returns:
            Final workflow state with deployment results
        """
        # Initialize state with all required fields from AgentState
        initial_state = {
            # Core fields
            "query": query,
            "response": "",
            "current_step": "initialized",
            "messages": [],

            # Decision fields
            "agent_decision": "",
            "deployment_approved": False,

            # Model deployment fields
            "model_name": "",
            "llm": None,
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},

            # React agent fields
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
        }

        # Extract approval from config if provided
        if config and "configurable" in config:
            if "capacity_approved" in config["configurable"]:
                initial_state["deployment_approved"] = config["configurable"]["capacity_approved"]
                logger.info(f"πŸ“‹ DeployModelAgent received approval: {config['configurable']['capacity_approved']}")

        # Configure memory if a checkpointer is available (None by default;
        # see __init__ — previously this attribute was never set at all).
        memory_config = None
        if self.checkpointer:
            thread_id = f"{user_id}:{session_id}"
            memory_config = {"configurable": {"thread_id": thread_id}}

        # Merge configs: caller-supplied "configurable" entries override/extend
        # the memory thread config.
        final_config = memory_config or {}
        if config:
            if "configurable" in final_config:
                final_config["configurable"].update(config.get("configurable", {}))
            else:
                final_config = config

        logger.info("πŸš€ Starting Deploy Model workflow")

        # Execute the graph
        if final_config:
            result = await self.graph.ainvoke(initial_state, final_config)
        else:
            result = await self.graph.ainvoke(initial_state)

        return result

    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session", enable_memory: bool = False) -> Dict[str, Any]:
        """
        Synchronously invoke the Deploy Model Agent workflow.

        Args:
            query: User's model deployment query
            user_id: User identifier for memory management
            session_id: Session identifier for memory management
            enable_memory: Whether to enable conversation memory management

        Returns:
            Final workflow state with deployment results
        """
        import asyncio
        return asyncio.run(self.ainvoke(query, user_id, session_id, enable_memory))

    def draw_graph(self, output_file_path: str = "deploy_model_graph.png"):
        """
        Generate and save a visual representation of the workflow graph.

        Args:
            output_file_path: Path where to save the graph PNG file
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"βœ… Graph visualization saved to: {output_file_path}")
        except Exception as e:
            # Best-effort: visualization failure must not break the agent.
            logger.error(f"❌ Failed to generate graph visualization: {e}")
            print(f"Error generating graph: {e}")