# carraraig's picture
# finish (#8)
# 5dd4236 verified
"""
ReAct Generate Node - Simplified version with 3 clear paths
Node that generates final response using:
1. DirectAnswerChain for direct answers (no tools)
2. ResearcherChain for researcher tool results
3. ToolResultChain for other tool results
All chains provide consistent formatting and professional presentation with memory context support.
Independent implementation for ReAct workflow - no dependency on AgenticRAG.
"""
from typing import Dict, Any
from ComputeAgent.chains.tool_result_chain import ToolResultChain
from ComputeAgent.models.model_manager import ModelManager
from constant import Constants
import asyncio
import logging
import json
from langgraph.config import get_stream_writer
from langchain_core.messages import HumanMessage, SystemMessage
# Shared model manager used by generate_node to load the synthesis LLM
model_manager = ModelManager()
# Module-level logger for this node (name contains spaces, so it is not a
# dotted-path logger; presumably intentional — TODO confirm)
logger = logging.getLogger("ReAct Generate Node")
def _create_error_response(state: Dict[str, Any], query: str, error_msg: str) -> Dict[str, Any]:
    """Build the standardized error payload, record it on a copy of the
    state, and push it over the custom stream.

    Args:
        state: Current workflow state (not mutated; a copy is returned).
        query: The user query that triggered the error.
        error_msg: Human-readable description of what went wrong.

    Returns:
        A copy of ``state`` carrying the apology response, the payload
        under ``final_response_dict``, and the ``generate_complete`` marker.
    """
    payload = {
        "query": query,
        "final_response": f"I apologize, but I encountered an error: {error_msg}",
        "sources": []
    }

    new_state = state.copy()
    new_state.update(
        response=payload["final_response"],
        final_response_dict=payload,
        current_step="generate_complete",
    )

    # Emit the payload on the custom stream channel as well.
    get_stream_writer()({"final_response_dict": payload})
    return new_state
async def _generate_deployment_instructions(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Generate deployment instructions when instance has been created.
Args:
state: Current state with instance_id and deployment info
Returns:
Updated state with deployment instructions
"""
logger.info("πŸ“ Generating deployment instructions")
# Extract deployment information
instance_id = state.get("instance_id", "")
instance_status = state.get("instance_status", "")
model_name = state.get("model_name", "Unknown Model")
model_info = state.get("model_info", {})
gpu_requirements = state.get("gpu_requirements", {})
estimated_gpu_memory = state.get("estimated_gpu_memory", 0)
# Get deployment configuration
location = model_info.get("location", "UAE-1")
gpu_type = model_info.get("GPU_type", "RTX 4090")
num_gpus = gpu_requirements.get(gpu_type, 1)
config = f"{num_gpus}x {gpu_type}"
# Determine capacity source
custom_capacity = state.get("custom_capacity", {})
capacity_source = "custom" if custom_capacity else "estimated"
# Build SSH command
ssh_command = f'ssh -i ~/.ssh/id_rsa -o "ProxyCommand=ssh bastion@ssh.hivecompute.ai %h" ubuntu@{instance_id}.ssh.hivecompute.ai'
# Get capacity estimation parameters
max_model_len = model_info.get("max_model_len", 2048)
max_num_seqs = model_info.get("max_num_seqs", 256)
max_batched_tokens = model_info.get("max_num_batched_tokens", 2048)
dtype = model_info.get("dtype", "BF16")
kv_cache_dtype = model_info.get("kv_cache_dtype", "auto")
gpu_memory_utilization = model_info.get("gpu_memory_utilization", 0.9)
# Use LLM to generate optimal vLLM command based on documentation and specs
logger.info("πŸ€– Using LLM to determine optimal vLLM parameters")
# Import vLLM documentation
try:
from vllm_engine_args import get_vllm_docs
vllm_docs = get_vllm_docs()
except ImportError:
logger.warning("⚠️ Could not import vllm_engine_args, using basic documentation")
vllm_docs = "Basic vLLM parameters: --model, --dtype, --max-model-len, --gpu-memory-utilization, --tensor-parallel-size, --enable-prefix-caching, --enable-chunked-prefill"
vllm_params_prompt = f"""You are an expert in vLLM deployment. Based on the model specifications and capacity estimation, generate an optimal vLLM serve command.
**Model Information:**
- Model: {model_name}
- GPU Type: {gpu_type}
- Number of GPUs: {num_gpus}
- GPU Memory: {estimated_gpu_memory:.2f} GB
- Location: {location}
**Capacity Estimation Parameters:**
- Max Model Length: {max_model_len}
- Max Sequences: {max_num_seqs}
- Max Batched Tokens: {max_batched_tokens}
- Data Type: {dtype}
- KV Cache dtype: {kv_cache_dtype}
- GPU Memory Utilization: {gpu_memory_utilization}
**vLLM Engine Arguments Documentation:**
{vllm_docs}
**Task:**
Generate the optimal vLLM serve command for this deployment. Consider:
1. Use the capacity estimation parameters provided
2. For multi-GPU setups ({num_gpus} GPUs), add --tensor-parallel-size {num_gpus} if num_gpus > 1
3. Add --enable-chunked-prefill if max_model_len > 8192 for better long context handling
4. Use --quantization fp8 only if dtype contains 'fp8' or 'FP8'
5. Always include --enable-prefix-caching for better performance
6. Set --host 0.0.0.0 and --port 8888
7. Use --download-dir /home/ubuntu/workspace/models
8. Consider other relevant parameters from the documentation based on the model and hardware specs
Return ONLY the complete vLLM command without any explanation, starting with 'vllm serve'."""
try:
from langchain_openai import ChatOpenAI
from constant import Constants
llm = ChatOpenAI(
base_url=Constants.LLM_BASE_URL,
api_key=Constants.LLM_API_KEY,
model=Constants.DEFAULT_LLM_NAME,
temperature=0.0
)
vllm_response = await llm.ainvoke(vllm_params_prompt)
vllm_command = vllm_response.content.strip()
logger.info(f"βœ… Generated vLLM command: {vllm_command}")
except Exception as e:
logger.error(f"❌ Failed to generate vLLM command with LLM: {e}")
# Fallback to basic command
quantization = "fp8" if "fp8" in dtype.lower() else None
vllm_command = f'vllm serve {model_name} --download-dir /home/ubuntu/workspace/models --gpu-memory-utilization {gpu_memory_utilization} --max-model-len {max_model_len} --max-num-seqs {max_num_seqs} --max-num-batched-tokens {max_batched_tokens} --dtype {dtype}'
if quantization:
vllm_command += f' --quantization {quantization}'
if num_gpus > 1:
vllm_command += f' --tensor-parallel-size {num_gpus}'
vllm_command += f' --kv-cache-dtype {kv_cache_dtype} --enable-prefix-caching --host 0.0.0.0 --port 8888'
# Build curl test command
curl_command = f'''curl -k https://{instance_id}-8888.tenants.hivecompute.ai/v1/chat/completions \\
-H "Content-Type: application/json" \\
-d '{{
"model": "{model_name}",
"messages": [
{{"role": "user", "content": "What is the capital of France?"}}
],
"max_tokens": 512
}}' '''
# Build complete deployment instructions response
final_response = f"""
# πŸš€ Deployment Instructions for {model_name}
## βœ… Instance Created Successfully
**Instance ID:** `{instance_id}`
**Status:** `{instance_status}`
**Location:** `{location}`
**Configuration:** `{config}`
---
## πŸ“Š Capacity Configuration
- **GPU Memory Required:** {estimated_gpu_memory:.2f} GB
- **GPU Type:** {gpu_type}
- **Number of GPUs:** {num_gpus}
- **Capacity Source:** {capacity_source}
---
## πŸ” Step 1: SSH to the Instance
```bash
{ssh_command}
```
---
## πŸ“ Step 2: Create Models Directory
Once connected via SSH, create the models directory inside the workspace:
```bash
mkdir -p /home/ubuntu/workspace/models
mkdir -p /home/ubuntu/workspace/tmpdir
```
**Note:** Cannot use docker file in HiveCompute since there is no VM support. Use an instance from HiveCompute with Template with Pytorch.
---
## πŸ“¦ Step 3: Install Dependencies (Using UV)
Install UV package manager:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Create and activate environment:
```bash
uv venv --python 3.12
source .venv/bin/activate
```
Install vLLM and dependencies:
```bash
uv pip install vllm==0.11.0 ray[default]
```
---
## πŸ€– Step 4: Start vLLM Server
Run the vLLM server with the following configuration:
```bash
{vllm_command}
```
**Configuration Details:**
The vLLM command above was intelligently generated based on:
- **Model Specifications:** {model_name} with {num_gpus}x {gpu_type}
- **Capacity Estimation:** {estimated_gpu_memory:.2f} GB GPU memory, {int(gpu_memory_utilization * 100)}% utilization
- **Context Length:** {max_model_len} tokens
- **Batch Configuration:** {max_num_seqs} max sequences, {max_batched_tokens} max batched tokens
- **Data Type:** {dtype} with {kv_cache_dtype} KV cache
- **vLLM Documentation:** Optimized parameters from https://docs.vllm.ai/en/v0.7.2/serving/engine_args.html
The LLM analyzed your deployment requirements and selected optimal parameters including tensor parallelism, chunked prefill, and caching strategies.
---
## πŸ§ͺ Step 5: Test the Deployment
Test your deployed model with a curl command:
```bash
{curl_command}
```
This will send a test request to your model and verify it's responding correctly.
---
## πŸ“ Additional Notes
- The vLLM server will download the model to `/home/ubuntu/workspace/models` on first run
- Make sure to monitor GPU memory usage during model loading
- The instance is accessible via the HiveCompute tenant URL: `https://{instance_id}-8888.tenants.hivecompute.ai`
- For production use, consider setting up monitoring and health checks
---
**Deployment Complete! πŸŽ‰**
"""
final_response_dict = {
"query": f"Deploy model {model_name}",
"final_response": final_response,
"instance_id": instance_id,
"instance_status": instance_status,
"sources": []
}
# Update state
updated_state = state.copy()
updated_state["response"] = final_response
updated_state["final_response_dict"] = final_response_dict
updated_state["current_step"] = "generate_complete"
# Remove tools to avoid serialization issues
if "tools" in updated_state:
del updated_state["tools"]
# Send via custom stream
writer = get_stream_writer()
writer({"final_response_dict": final_response_dict})
logger.info("βœ… Deployment instructions generated successfully")
return updated_state
async def _handle_tool_results(state: Dict[str, Any], query: str, user_id: str, session_id: str,
                               tool_results: list, memory_context: str, llm) -> Dict[str, Any]:
    """Handle general tool results using ToolResultChain.

    Args:
        state: Current workflow state (not mutated; a copy is returned).
        query: The original user query.
        user_id: Kept for interface compatibility; unused here.
        session_id: Kept for interface compatibility; unused here.
        tool_results: Tool message objects (or strings) to synthesize.
        memory_context: Pre-built memory context string (may be empty).
        llm: The language model instance passed to ToolResultChain.

    Returns:
        Updated state copy with the formatted (or raw-fallback) response
        and the ``final_response_dict`` payload, also emitted on the
        custom stream.
    """

    def _finalize(response_text: str) -> Dict[str, Any]:
        # Shared tail for both paths: package the response, update the
        # state copy, and emit the payload on the custom stream.
        final_response_dict = {
            "query": query,
            "final_response": response_text,
            "sources": []
        }
        updated_state = state.copy()
        updated_state["response"] = response_text
        updated_state["final_response_dict"] = final_response_dict
        updated_state["current_step"] = "generate_complete"
        writer = get_stream_writer()
        writer({"final_response_dict": final_response_dict})
        return updated_state

    try:
        logger.info("πŸ€– Synthesizing tool results using ToolResultChain...")
        tool_result_chain = ToolResultChain(llm=llm)
        formatted_response = await tool_result_chain.ainvoke(query, tool_results, memory_context)
        updated_state = _finalize(formatted_response)
        logger.info("βœ… Tool results synthesized successfully")
        return updated_state
    except Exception as e:
        logger.error(f"❌ ToolResultChain Error: {str(e)}")
        # Final fallback: surface the raw tool output instead of failing.
        fallback_response = "I executed the requested tools but encountered formatting issues. Here are the raw results:\n\n"
        for i, result in enumerate(tool_results, 1):
            content = result.content if hasattr(result, 'content') else str(result)
            fallback_response += f"Tool {i}: {content}\n"
        updated_state = _finalize(fallback_response)
        logger.info("βœ… Tool results formatted using raw content fallback")
        return updated_state
def _finish_generation(state: Dict[str, Any], query: str, response_text: str) -> Dict[str, Any]:
    """Package a successful response consistently with the other paths:
    build ``final_response_dict``, update a state copy, drop the
    non-serializable tools, and emit the payload on the custom stream.

    Previously the direct-answer and synthesis success paths skipped the
    payload/stream step that every other path performs; this helper makes
    all paths behave the same.
    """
    final_response_dict = {
        "query": query,
        "final_response": response_text,
        "sources": []
    }
    updated_state = state.copy()
    updated_state["response"] = response_text
    updated_state["final_response_dict"] = final_response_dict
    updated_state["current_step"] = "generate_complete"
    # Tools are not serializable; strip them before returning state.
    if "tools" in updated_state:
        del updated_state["tools"]
    writer = get_stream_writer()
    writer({"final_response_dict": final_response_dict})
    return updated_state


async def generate_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simple response generation with 4 clear paths:
    1. Deployment Instructions (when instance_created == True)
    2. Direct LLM Answer (when no tool_results exist)
    3. LLM synthesis of tool results (when tool_results exist)
    4. ToolResultChain fallback (when LLM synthesis fails)

    Args:
        state: Current ReAct state.

    Returns:
        Updated state with generated response, ``final_response_dict``,
        and ``current_step`` set to ``generate_complete``.
    """
    logger.info("πŸ€– Starting response generation")

    # Extract common variables
    query = state.get("query", "")
    user_id = state.get("user_id", "")
    session_id = state.get("session_id", "")
    current_step = state.get("current_step", "")
    tool_results = state.get("tool_results", [])
    existing_response = state.get("response", "")
    researcher_used = state.get("researcher_used", False)
    instance_created = state.get("instance_created", False)

    # Debug logging to help diagnose path selection
    logger.info(f"πŸ” DEBUG - instance_created: {instance_created}, researcher_used: {researcher_used}, tool_results count: {len(tool_results)}, current_step: {current_step}, existing_response: {bool(existing_response)}")

    # Path 1: deployment workflow gets its own dedicated generator
    if instance_created:
        logger.info("πŸš€ Deployment mode detected - generating deployment instructions")
        return await _generate_deployment_instructions(state)

    # Build memory context once (best-effort; failures are non-fatal)
    memory_context = ""
    if user_id and session_id:
        try:
            from helpers.memory import get_memory_manager
            memory_manager = get_memory_manager()
            memory_context = await memory_manager.build_context_for_node(user_id, session_id, "general")
            if memory_context:
                logger.info("🧠 Using memory context for response generation")
        except Exception as e:
            logger.warning(f"⚠️ Could not load memory context: {e}")

    # Resolve the model name once (prefer the refining LLM's model if set)
    model_name = Constants.DEFAULT_LLM_NAME
    refining_llm = state.get("refining_llm")
    if hasattr(refining_llm, 'model_name'):
        model_name = refining_llm.model_name

    try:
        llm = await model_manager.load_llm_model(model_name)
    except Exception as e:
        logger.error(f"❌ Failed to load model {model_name}: {e}")
        return _create_error_response(state, query, "Failed to load language model")

    # Path 2: no tool results -> generate a direct LLM response
    if not tool_results:
        logger.info("ℹ️ No tool results found - generating LLM response")
        system_prompt = """You are a helpful AI assistant. The user has made a request and you need to provide a comprehensive and helpful response.
If there's an existing response or context, acknowledge it and build upon it.
Be professional, clear, and concise in your response.
If you don't have specific information to provide, politely explain what you can help with instead."""
        context_info = f"Query: {query}"
        if existing_response:
            context_info += f"\nExisting context: {existing_response}"
        if memory_context:
            context_info += f"\nMemory context: {memory_context}"
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=context_info)
        ]
        try:
            response = await llm.ainvoke(messages)
            updated_state = _finish_generation(state, query, response.content)
            logger.info("βœ… Generated LLM response successfully")
            return updated_state
        except Exception as e:
            logger.error(f"❌ Error generating LLM response: {str(e)}")
            return _create_error_response(state, query, f"Failed to generate response: {str(e)}")

    # Path 3: tool results exist -> synthesize them with the LLM
    logger.info("πŸ”§ Processing tool results using LLM synthesis")
    tool_results_summary = ""
    for i, result in enumerate(tool_results, 1):
        content = result.content if hasattr(result, 'content') else str(result)
        tool_name = getattr(result, 'name', f'Tool {i}')
        tool_results_summary += f"\n{tool_name}: {content}\n"
    system_prompt = """You are a helpful AI assistant that synthesizes tool execution results into a comprehensive response.
Your task is to:
1. Analyze the tool results provided
2. Generate a clear, professional response that summarizes what was accomplished
3. Present the information in a well-structured format
4. If there are any errors or issues, explain them clearly
5. Be concise but thorough in your explanation
Always maintain a helpful and professional tone."""
    context_info = f"Query: {query}\n\nTool Results:{tool_results_summary}"
    if memory_context:
        context_info += f"\nMemory context: {memory_context}"
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=context_info)
    ]
    try:
        response = await llm.ainvoke(messages)
        updated_state = _finish_generation(state, query, response.content)
        logger.info("βœ… Synthesized tool results successfully using LLM")
        return updated_state
    except Exception as e:
        logger.error(f"❌ Error synthesizing tool results with LLM: {str(e)}")
        # Path 4: fallback to ToolResultChain if LLM synthesis fails
        logger.info("πŸ”„ Falling back to ToolResultChain")
        return await _handle_tool_results(state, query, user_id, session_id, tool_results, memory_context, llm)