File size: 16,272 Bytes
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd4236
 
 
8816dfd
 
5dd4236
93850a2
8816dfd
 
 
 
 
 
 
 
 
93850a2
 
 
 
 
 
 
 
 
 
8816dfd
 
 
0297f14
 
 
 
 
 
 
 
 
 
 
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
"""
Basic Agent Main Graph Module (FastAPI Compatible - Minimal Changes)

This module implements the core workflow graph for the Basic Agent system.
It defines the agent's decision-making flow between model deployment and
React-based compute workflows.

CHANGES FROM ORIGINAL:
- __init__ now accepts optional tools and llm parameters
- Added async create() classmethod for FastAPI
- Fully backwards compatible with existing CLI code

Author: Your Name
License: Private
"""

import asyncio
from typing import Dict, Any, List, Optional
import uuid
import json
import logging

from langgraph.graph import StateGraph, END, START
from typing_extensions import TypedDict
from constant import Constants

# Import node functions (to be implemented in separate files)
from langgraph.checkpoint.memory import MemorySaver
from ComputeAgent.graph.graph_deploy import DeployModelAgent
from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.models.model_manager import ModelManager
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from ComputeAgent.graph.state import AgentState
import os

# Initialize model manager for dynamic LLM loading and management.
# Module-level singleton: shared by every ComputeAgent instance in this process.
model_manager = ModelManager()

# Global MemorySaver checkpointer — persists graph state across requests so a
# user/session thread_id resumes where it left off (in-memory only; lost on restart).
memory_saver = MemorySaver()

logger = logging.getLogger("ComputeAgent")

# Get the project root directory (parent of ComputeAgent folder)
import sys
# __file__ is in ComputeAgent/graph/graph.py
# Go up 3 levels: graph -> ComputeAgent -> project_root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

# Use sys.executable so the MCP subprocess runs under the same interpreter
# (and virtualenv) as this process, rather than whatever "python" is on PATH.
python_executable = sys.executable

# MCP client spawning the Compute MCP server as a stdio subprocess.
# NOTE(review): constructed at import time — importing this module launches no
# subprocess by itself, but get_tools() will; confirm this is intended for FastAPI workers.
mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Pass HF Spaces secrets to the MCP subprocess
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # Also pass these to ensure Python works correctly
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            }
        }
    }
)

class ComputeAgent:
    """
    Main Compute Agent class providing AI-powered decision routing and execution.

    This class orchestrates the complete agent workflow including:
    - Decision routing between model deployment and React agent
    - Model deployment workflow with capacity estimation and approval
    - React agent execution with compute capabilities
    - Error handling and state management

    Attributes:
        tools: MCP tools shared by both subgraphs.
        llm: Function-calling LLM shared by both subgraphs.
        deploy_subgraph: Compiled model-deployment workflow (DeployModelAgent).
        react_subgraph: Compiled React workflow (ReactWorkflow).
        graph: Compiled top-level LangGraph workflow.

    Usage:
        # For CLI (backwards compatible, synchronous):
        agent = ComputeAgent()

        # For FastAPI (async):
        agent = await ComputeAgent.create()
    """

    def __init__(self, tools=None, llm=None):
        """
        Initialize Compute Agent with optional pre-loaded dependencies.

        Args:
            tools: Pre-loaded MCP tools (optional, will load if not provided).
            llm: Pre-loaded LLM model (optional, will load if not provided).

        Note:
            When tools/llm are omitted they are loaded via asyncio.run(),
            which raises RuntimeError inside an already-running event loop.
            Async callers (FastAPI) must use ComputeAgent.create() instead.
        """
        # Synchronous loading path is for CLI use only (see Note above).
        self.tools = tools if tools is not None else asyncio.run(mcp_client.get_tools())
        self.llm = llm if llm is not None else asyncio.run(
            model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
        )

        self.deploy_subgraph = DeployModelAgent(llm=self.llm, react_tools=self.tools)
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.tools)
        self.graph = self._create_graph()

    @classmethod
    async def create(cls):
        """
        Async factory method for creating ComputeAgent.
        Use this in FastAPI to avoid asyncio.run() issues.

        Returns:
            Initialized ComputeAgent instance.
        """
        logger.info("πŸ”§ Loading tools and LLM asynchronously...")
        tools = await mcp_client.get_tools()
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)
        # BUGFIX: the previous version also awaited DeployModelAgent.create(...)
        # here and discarded the result; __init__ constructs its own
        # DeployModelAgent, so that call was redundant work. Removed.
        return cls(tools=tools, llm=llm)

    async def decision_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Node that handles routing decisions for the ComputeAgent workflow.

        Analyzes the user query to determine whether to route to:
        - Model deployment workflow (deploy_model)
        - React agent workflow (react_agent)

        Args:
            state: Current agent state with memory fields.

        Returns:
            Updated state with routing decision in "agent_decision"
            (always one of "deploy_model" / "react_agent", even on error).
        """
        user_id = state.get("user_id", "")
        session_id = state.get("session_id", "")
        query = state.get("query", "")

        logger.info(f"🎯 Decision node processing query for {user_id}:{session_id}")

        # Build memory context for decision making (best-effort: routing still
        # works without it, so failures here only log a warning).
        memory_context = ""
        if user_id and session_id:
            try:
                from helpers.memory import get_memory_manager
                memory_manager = get_memory_manager()
                memory_context = await memory_manager.build_context_for_node(user_id, session_id, "decision")
                if memory_context:
                    logger.info(f"🧠 Using memory context for decision routing")
            except Exception as e:
                logger.warning(f"⚠️ Could not load memory context for decision: {e}")

        try:
            # Load main LLM using ModelManager (plain chat model, no tools —
            # this node only emits a routing keyword).
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)

            decision_system_prompt = f"""
            You are a routing assistant for ComputeAgent. Analyze the user's query and decide which workflow to use.

            Choose between:
            1. DEPLOY_MODEL - For queries about deploy AI model from HuggingFace. In this case the user MUST specify the model card name (like meta-llama/Meta-Llama-3-70B).
                - The user can specify the hardware capacity needed.
                - The user can ask for model analysis, deployment steps, or capacity estimation.

            2. REACT_AGENT - For all the rest of queries.

            {f"Conversation Context: {memory_context}" if memory_context else "No conversation context available."}

            User Query: {query}

            Respond with only: "DEPLOY_MODEL" or "REACT_AGENT"
            """

            decision_response = await llm.ainvoke([
                SystemMessage(content=decision_system_prompt)
            ])

            routing_decision = decision_response.content.strip().upper()

            # Validate and set decision. Check DEPLOY_MODEL first; any other
            # or ambiguous answer falls back to the React agent.
            if "DEPLOY_MODEL" in routing_decision:
                agent_decision = "deploy_model"
                logger.info(f"πŸ“¦ Routing to model deployment workflow")
            elif "REACT_AGENT" in routing_decision:
                agent_decision = "react_agent"
                logger.info(f"βš›οΈ Routing to React agent workflow")
            else:
                agent_decision = "react_agent"
                logger.warning(f"⚠️ Ambiguous routing decision '{routing_decision}', defaulting to React agent")

            updated_state = state.copy()
            updated_state["agent_decision"] = agent_decision
            updated_state["current_step"] = "decision_complete"

            logger.info(f"βœ… Decision node complete: {agent_decision}")
            return updated_state

        except Exception as e:
            logger.error(f"❌ Error in decision node: {e}")

            updated_state = state.copy()
            updated_state["error"] = f"Decision error (fallback used): {str(e)}"
            # BUGFIX: previously no fallback decision was actually written, so
            # the conditional edge (state["agent_decision"]) routed on the
            # initial "" and raised KeyError. Default to the React agent.
            updated_state["agent_decision"] = "react_agent"

            return updated_state

    def _create_graph(self) -> StateGraph:
        """
        Create and configure the Compute Agent workflow graph.

        Workflow shape:
            START -> decision -> {deploy_model | react_agent} -> END

        1. decision node routes to deployment or React agent.
        2. deploy_model path: fetch model card, extract info, estimate
           capacity, human approval, deploy or inform (DeployModelAgent subgraph).
        3. react_agent path: React agent with compute MCP capabilities.

        Returns:
            Compiled StateGraph ready for execution (checkpointed with the
            module-level MemorySaver).
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("decision", self.decision_node)

        # Subgraphs are mounted as single nodes of the parent graph.
        workflow.add_node("deploy_model", self.deploy_subgraph.get_compiled_graph())
        workflow.add_node("react_agent", self.react_subgraph.get_compiled_graph())

        workflow.set_entry_point("decision")

        # Route on the decision node's output; decision_node guarantees
        # "agent_decision" is one of the two mapped keys.
        workflow.add_conditional_edges(
            "decision",
            lambda state: state["agent_decision"],
            {
                "deploy_model": "deploy_model",
                "react_agent": "react_agent",
            }
        )

        workflow.add_edge("deploy_model", END)
        workflow.add_edge("react_agent", END)

        # Compile with checkpointer so state persists per thread_id.
        return workflow.compile(checkpointer=memory_saver)

    def get_compiled_graph(self):
        """Return the compiled graph for use in FastAPI."""
        return self.graph

    @staticmethod
    def _build_initial_state(query: str, user_id: str, session_id: str) -> Dict[str, Any]:
        """Build the fresh AgentState dict shared by ainvoke and streaming."""
        return {
            "user_id": user_id,
            "session_id": session_id,
            "query": query,
            "response": "",
            "current_step": "start",
            "agent_decision": "",
            "deployment_approved": False,
            "model_name": "",
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
            "messages": [],
            # Approval fields for ReactWorkflow
            "pending_tool_calls": [],
            "approved_tool_calls": [],
            "rejected_tool_calls": [],
            "modified_tool_calls": [],
            "needs_re_reasoning": False,
            "re_reasoning_feedback": ""
        }

    @staticmethod
    def _thread_config(user_id: str, session_id: str) -> Dict[str, Any]:
        """Checkpointer config: one conversation thread per user/session pair."""
        return {
            "configurable": {
                "thread_id": f"{user_id}_{session_id}"
            }
        }

    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context
        (synchronous wrapper for async; CLI use only).

        Args:
            query: User's query.
            user_id: User identifier for memory management.
            session_id: Session identifier for memory management.

        Returns:
            Final result from the graph execution.
        """
        return asyncio.run(self.ainvoke(query, user_id, session_id))

    async def ainvoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session") -> Dict[str, Any]:
        """
        Execute the graph with a given query and memory context (async).

        Args:
            query: User's query.
            user_id: User identifier for memory management.
            session_id: Session identifier for memory management.

        Returns:
            Final result from the graph execution containing:
            - response: Final response to user
            - agent_decision: Which path was taken
            - deployment_result: If deployment path was taken
            - react_results: If React agent path was taken
            On failure, the initial state augmented with "error",
            "error_step" and a human-readable "response".
        """
        initial_state = self._build_initial_state(query, user_id, session_id)
        config = self._thread_config(user_id, session_id)

        try:
            result = await self.graph.ainvoke(initial_state, config)
            return result

        except Exception as e:
            logger.error(f"Error in graph execution: {e}")
            return {
                **initial_state,
                "error": str(e),
                "error_step": initial_state.get("current_step", "unknown"),
                "response": f"An error occurred during execution: {str(e)}"
            }

    async def astream_generate_nodes(self, query: str, user_id: str = "default_user", session_id: str = "default_session"):
        """
        Stream the graph execution node by node (async).

        Args:
            query: User's query.
            user_id: User identifier for memory management.
            session_id: Session identifier for memory management.

        Yields:
            Dict containing node execution updates ("node", "output", plus the
            node's state updates flattened in), or an error dict with
            "status": "error" if streaming fails.
        """
        initial_state = self._build_initial_state(query, user_id, session_id)
        config = self._thread_config(user_id, session_id)

        try:
            # Each chunk maps node name -> that node's state updates.
            async for chunk in self.graph.astream(initial_state, config):
                for node_name, node_output in chunk.items():
                    yield {
                        "node": node_name,
                        "output": node_output,
                        **node_output  # Include all state updates
                    }

        except Exception as e:
            logger.error(f"Error in graph streaming: {e}")
            yield {
                "error": str(e),
                "status": "error",
                "error_step": initial_state.get("current_step", "unknown")
            }

    def draw_graph(self, output_file_path: str = "basic_agent_graph.png"):
        """
        Generate and save a visual representation of the Basic Agent workflow graph.

        Args:
            output_file_path: Path where to save the graph PNG file.
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info(f"βœ… Basic Agent graph visualization saved to: {output_file_path}")
        except Exception as e:
            logger.error(f"❌ Failed to generate Basic Agent graph visualization: {e}")