|
|
|
|
|
import logging |
|
|
import time |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Any, List, Optional |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
|
|
|
# Optional dependency probe: record whether the MCP client module can be
# imported relative to this package. NOTE(review): the rest of this module
# imports ``src.tools.mcp_client`` by absolute path instead — confirm both
# resolve to the same module.
try:
    from ..tools import mcp_client as _mcp_client_module
except ImportError:
    MCP_CLIENT_AVAILABLE = False
else:
    MCP_CLIENT_AVAILABLE = True
|
|
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """
    Configuration for agents - session management handled entirely by MCP server.

    ``create_agent_config`` resolves ``model``, ``temperature``, ``max_tokens``
    and ``max_iterations`` from the global configuration when they are not
    passed explicitly; constructing this class directly keeps the defaults below.
    """

    # Agent identifier; used in the agent's logger name and for agent-type detection.
    agent_name: str = "base_agent"
    # Planner mode as read from the global configuration.
    planner_mode: str = "auto"
    # LLM model name (None = not set).
    model: Optional[str] = None
    # Iteration budget for the agent (presumably its reasoning loop — TODO confirm).
    max_iterations: int = 10
    # Sampling temperature (None = not set).
    temperature: Optional[float] = None
    # Maximum tokens per LLM response (None = not set).
    max_tokens: Optional[int] = None

    # Output / working locations taken from the global configuration.
    trajectory_storage_path: Optional[str] = None
    report_output_path: Optional[str] = None
    document_analysis_path: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass
class AgentResponse:
    """
    Standardized response format for all agents.

    Only ``success`` is required; every other field has a neutral default.
    """

    # Whether the task completed successfully.
    success: bool
    # Agent-specific result payload, if any.
    result: Optional[Dict[str, Any]] = None
    # Error description when ``success`` is False.
    error: Optional[str] = None
    # Number of iterations the agent used.
    iterations: int = 0
    # Snapshot of the agent's reasoning/action/error trace entries.
    reasoning_trace: List[Dict[str, Any]] = field(default_factory=list)
    # Name of the agent that produced this response.
    agent_name: str = ""
    # Wall-clock duration reported by the agent (units chosen by the caller).
    execution_time: float = 0.0
|
|
|
|
|
|
|
|
@dataclass
class TaskInput:
    """
    Standardized task input format for all agents.

    Carries the task text plus optional supporting context, which
    ``format_for_prompt`` renders into a prompt section.
    """

    task_content: str
    task_steps_for_reference: Optional[str] = None
    deliverable_contents: Optional[str] = None
    current_task_status: Optional[str] = None
    # Which agent should run the task.
    task_executor: str = "info_seeker"
    workspace_id: Optional[str] = None
    acceptance_checking_criteria: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the seven base fields as a plain dict, in declaration order."""
        return dict(
            task_content=self.task_content,
            task_steps_for_reference=self.task_steps_for_reference,
            deliverable_contents=self.deliverable_contents,
            current_task_status=self.current_task_status,
            task_executor=self.task_executor,
            workspace_id=self.workspace_id,
            acceptance_checking_criteria=self.acceptance_checking_criteria,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TaskInput':
        """
        Build an instance from a dict such as one produced by ``to_dict``.

        Missing keys fall back to the defaults: ``task_content`` becomes ``""``,
        ``task_executor`` becomes ``"info_seeker"``, everything else ``None``.
        """
        lookup = data.get
        return cls(
            task_content=lookup("task_content", ""),
            task_steps_for_reference=lookup("task_steps_for_reference"),
            deliverable_contents=lookup("deliverable_contents"),
            current_task_status=lookup("current_task_status"),
            task_executor=lookup("task_executor", "info_seeker"),
            workspace_id=lookup("workspace_id"),
            acceptance_checking_criteria=lookup("acceptance_checking_criteria"),
        )

    def format_for_prompt(self) -> str:
        """
        Render the task as prompt text.

        The task content always comes first. Each optional section is included
        only when its value is truthy, followed by the executor line and, if
        set, the workspace id.
        """
        optional_sections = (
            ("Task Steps for Reference", self.task_steps_for_reference),
            ("Deliverable Contents", self.deliverable_contents),
            ("Current Task Status", self.current_task_status),
            ("Acceptance Checking Criteria", self.acceptance_checking_criteria),
        )

        chunks = [f"Task Content:\n{self.task_content}\n\n"]
        chunks.extend(
            f"{heading}:\n{body}\n\n" for heading, body in optional_sections if body
        )
        chunks.append(f"Task Executor: {self.task_executor}\n")
        if self.workspace_id:
            chunks.append(f"Workspace ID: {self.workspace_id}\n")

        return "".join(chunks)
|
|
|
|
|
|
|
|
class SectionWriterTaskInput(TaskInput):
    """
    Specialized TaskInput for section writing tasks.

    Only stores the essential parameters. The section_writer agent
    will handle prompt assembly internally.

    ``task_executor`` is fixed to ``"section_writer"``. The writer-specific
    values (query, target path, outlines, key files, prior chapters) are plain
    instance attributes rather than dataclass fields. They are therefore not
    included in ``to_dict()`` and are not compared by ``==``.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        write_file_path: str,
        overall_outline: str,
        current_chapter_outline: str,
        key_files: List[Dict[str, Any]],
        written_chapters: str = "",
        workspace_id: Optional[str] = None
    ):
        # Initialise the base dataclass fields first ...
        super().__init__(
            task_content=task_content,
            task_executor="section_writer",
            workspace_id=workspace_id,
        )

        # ... then attach the section-writer context.
        self.user_query = user_query
        self.write_file_path = write_file_path
        self.overall_outline = overall_outline
        self.current_chapter_outline = current_chapter_outline
        self.key_files = key_files
        self.written_chapters = written_chapters
|
|
|
|
|
|
|
|
class WriterAgentTaskInput(TaskInput):
    """
    Specialized TaskInput for writer_agent tasks.

    Besides the base ``task_content`` and ``workspace_id``, it stores only
    ``user_query`` and ``key_files``. ``task_executor`` is fixed to
    ``"writer_agent"``; prompt assembly is left to the writer agent.

    ``user_query`` and ``key_files`` are plain instance attributes rather than
    dataclass fields. They are therefore not included in ``to_dict()`` and are
    not compared by ``==``.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        key_files: List[Dict[str, Any]],
        workspace_id: Optional[str] = None
    ):
        super().__init__(
            task_content=task_content,
            task_executor="writer_agent",
            workspace_id=workspace_id,
        )

        # Writer-specific context kept alongside the base fields.
        self.user_query = user_query
        self.key_files = key_files
|
|
|
|
|
|
|
|
class BaseAgent(ABC):
    """
    Base class for all agents with MCP server-managed sessions.

    Session management is now entirely handled by the MCP server:
    - Server assigns session IDs on connection
    - Server creates workspace folders with UUID names
    - All tool operations are performed in server-managed workspaces

    On construction the agent does three things:
    - obtains an MCP tools adapter, wrapping either a shared client or a newly
      created one;
    - discovers the tools that adapter exposes;
    - builds function-calling schemas for those tools.

    Subclasses implement ``execute_task`` and ``_build_system_prompt``.
    """

    # Endpoint used by ``_create_new_mcp_client`` when the MCP config does not
    # select a custom HTTP server. Subclasses may override it.
    DEFAULT_MCP_SERVER_URL = "http://localhost:6274/mcp"

    def __init__(self, config: 'AgentConfig', shared_mcp_client=None):
        """
        Create the agent and connect it to MCP.

        Args:
            config: Agent configuration; ``config.agent_name`` also names the logger.
            shared_mcp_client: Optional existing MCP client to reuse. When falsy,
                the agent opens its own client connection.

        Raises:
            Exception: anything raised by ``_initialize`` (logged, then re-raised).
        """
        self.config = config
        self.logger = logging.getLogger(f"{__name__}.{config.agent_name}")

        # Not populated by BaseAgent itself; kept for subclasses/callers.
        self.session_info = None

        self.mcp_tools = None
        self.available_tools: Dict[str, Any] = {}
        self.tool_schemas: List[Dict[str, Any]] = []

        # Creates fresh ``reasoning_trace`` and ``execution_stats``.
        self.reset_trace()

        self._initialize(shared_mcp_client)

    def _initialize(self, shared_mcp_client=None):
        """
        Initialize agent with MCP server connection or shared client.

        Sets ``mcp_tools``, ``available_tools`` and ``tool_schemas``. Any failure
        is logged and re-raised.
        """
        try:
            self.logger.info(f"Initializing agent {self.config.agent_name}")

            if shared_mcp_client:
                # Reuse the caller's connection, filtered to this agent's tool set.
                agent_type = self._get_agent_type()
                self.mcp_tools = self._create_filtered_mcp_tools(shared_mcp_client, agent_type)
                self.logger.info(f"Agent {self.config.agent_name} using shared MCP client with {agent_type} tools")
            else:
                # No shared client: open a dedicated connection.
                self.mcp_tools = self._create_filtered_mcp_tools_standalone()

            self.available_tools = self._discover_mcp_tools()
            self.tool_schemas = self._build_tool_schemas()

            self.logger.info(f"Agent {self.config.agent_name} initialized successfully")
            self.logger.info(f"Available tools: {list(self.available_tools.keys())}")

        except Exception as e:
            self.logger.error(f"Failed to initialize agent {self.config.agent_name}: {e}")
            raise

    def _discover_mcp_tools(self) -> Dict[str, Any]:
        """
        Discover the tools this agent may call.

        Prefers the adapter's ``get_available_tools()`` mapping. If that method
        is missing, raises, or yields nothing, falls back to collecting every
        public callable attribute of the adapter object.

        Returns:
            Mapping of tool name to tool info (or to a bound callable in the fallback).
        """
        available_tools: Dict[str, Any] = {}

        if hasattr(self.mcp_tools, 'get_available_tools'):
            try:
                available_tools.update(self.mcp_tools.get_available_tools())
                if available_tools:
                    self.logger.info(f"Discovered {len(available_tools)} tools from MCP server")
                    return available_tools
            except Exception as e:
                self.logger.warning(f"Failed to discover MCP tools: {e}")

        # Fallback: expose the adapter's public methods directly as tools.
        if hasattr(self.mcp_tools, '__dict__'):
            for attr_name in dir(self.mcp_tools):
                if attr_name.startswith('_'):
                    continue
                attr = getattr(self.mcp_tools, attr_name)
                if callable(attr):
                    available_tools[attr_name] = attr

        return available_tools

    def _get_agent_type(self) -> str:
        """
        Map ``config.agent_name`` to the tool-filtering category.

        Matching is case-insensitive and checked in this order: "planner",
        then "information"/"seeker", then "writer". Names matching none of
        these default to "planner".
        """
        agent_name = self.config.agent_name.lower()
        if "planner" in agent_name:
            return "planner"
        if "information" in agent_name or "seeker" in agent_name:
            return "information_seeker"
        if "writer" in agent_name:
            return "writer"
        return "planner"

    def _create_filtered_mcp_tools(self, shared_client, agent_type: str):
        """
        Create filtered MCP tools adapter using shared client.

        If ``create_filtered_mcp_tools_adapter`` cannot be imported, falls back
        to an *unfiltered* ``MCPToolsAdapter`` bound to ``shared_client``.
        """
        try:
            from src.tools.mcp_client import create_filtered_mcp_tools_adapter
            return create_filtered_mcp_tools_adapter(shared_client, agent_type)
        except ImportError:
            self.logger.warning("FilteredMCPToolsAdapter not available, using regular adapter")
            from src.tools.mcp_client import MCPToolsAdapter
            # Skip __init__ and attach the shared client directly
            # (presumably to avoid opening a second connection — TODO confirm).
            adapter = MCPToolsAdapter.__new__(MCPToolsAdapter)
            adapter.client = shared_client
            return adapter

    def _create_filtered_mcp_tools_standalone(self):
        """
        Create filtered MCP tools adapter with its own client connection.

        Raises:
            RuntimeError: if the client or the adapter cannot be created.
        """
        try:
            agent_type = self._get_agent_type()
            client = self._create_new_mcp_client()

            from src.tools.mcp_client import create_filtered_mcp_tools_adapter
            filtered_adapter = create_filtered_mcp_tools_adapter(client, agent_type)

            self.logger.info(f"Agent {self.config.agent_name} created filtered MCP adapter with {agent_type} tools")
            return filtered_adapter

        except Exception as e:
            self.logger.error(f"Failed to create filtered MCP tools: {e}")
            raise RuntimeError(f"Failed to create filtered MCP client for {self.config.agent_name}: {e}") from e

    def _create_new_mcp_client(self):
        """
        Create a new MCP client connection.

        Uses ``server_url`` from ``get_mcp_config()`` only when ``use_stdio`` is
        explicitly falsy; otherwise uses ``DEFAULT_MCP_SERVER_URL``.

        Raises:
            RuntimeError: if the config cannot be loaded or the client cannot be built.
        """
        try:
            from config.config import get_mcp_config
            mcp_config = get_mcp_config()

            from src.tools.mcp_client import MCPClient

            # NOTE(review): ``use_stdio`` defaults to True here, so a configured
            # ``server_url`` is ignored unless ``use_stdio`` is explicitly False.
            # No stdio transport is created on either branch — confirm intended.
            if mcp_config.get("server_url") and not mcp_config.get("use_stdio", True):
                client = MCPClient(server_url=mcp_config["server_url"])
                self.logger.info(
                    f"Agent {self.config.agent_name} connected to HTTP MCP server: {mcp_config['server_url']}")
            else:
                client = MCPClient(server_url=self.DEFAULT_MCP_SERVER_URL)
                self.logger.info(
                    f"Agent {self.config.agent_name} connected to default HTTP MCP server: {self.DEFAULT_MCP_SERVER_URL}")

            return client

        except Exception as e:
            self.logger.error(f"Failed to create MCP client: {e}")
            raise RuntimeError(f"MCP client creation failed for {self.config.agent_name}: {e}") from e

    def get_session_info(self) -> Optional[Dict[str, Any]]:
        """
        Get information about the current server-managed session.

        Sources are tried in order:
        1. the adapter's ``get_session_info()``;
        2. the adapter's ``client._session_id``;
        3. the adapter's own ``_session_id``.

        If none of these yields a session, a placeholder with
        ``session_id=None`` is returned. Errors are logged and reported in an
        ``"error"`` key; they are never raised.
        """
        try:
            if hasattr(self.mcp_tools, 'get_session_info'):
                session_info = self.mcp_tools.get_session_info()
                if session_info:
                    # Build a new dict so the adapter's own session dict is not mutated.
                    return {
                        **session_info,
                        "server_managed": True,
                        "agent_name": self.config.agent_name,
                    }

            if hasattr(self.mcp_tools, 'client'):
                client = self.mcp_tools.client
                if hasattr(client, '_session_id') and hasattr(client, 'is_connected'):
                    return {
                        "session_id": client._session_id,
                        "server_managed": True,
                        "agent_name": self.config.agent_name,
                        "connected": client.is_connected()
                    }

            if hasattr(self.mcp_tools, '_session_id'):
                return {
                    "session_id": self.mcp_tools._session_id,
                    "server_managed": True,
                    "agent_name": self.config.agent_name,
                    "connected": getattr(self.mcp_tools, 'is_connected', lambda: True)()
                }

            return {
                "session_id": None,
                "server_managed": True,
                "agent_name": self.config.agent_name,
                "connected": hasattr(self.mcp_tools, 'client') and getattr(
                    self.mcp_tools.client, 'is_connected', lambda: False)()
            }

        except Exception as e:
            self.logger.warning(f"Failed to get session info: {e}")
            return {
                "session_id": None,
                "server_managed": True,
                "agent_name": self.config.agent_name,
                "connected": False,
                "error": str(e)
            }

    def _build_tool_schemas(self) -> List[Dict[str, Any]]:
        """Build tool schemas for function calling (returns a new list)."""
        return list(self._build_agent_specific_tool_schemas())

    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build agent-specific tool schemas using proper MCP architecture.

        Schemas come from MCP server via client, not direct imports. Falls
        back to ``_build_fallback_schemas`` when the adapter lacks
        ``get_tool_schemas`` or the call fails.
        """
        try:
            if hasattr(self.mcp_tools, 'get_tool_schemas'):
                schemas = self.mcp_tools.get_tool_schemas()
                self.logger.info(f"Retrieved {len(schemas)} tool schemas from MCP server")
            else:
                self.logger.warning("MCP adapter doesn't support get_tool_schemas, using fallback")
                schemas = self._build_fallback_schemas()
        except Exception as e:
            self.logger.warning(f"Failed to get schemas from MCP client: {e}, using fallback")
            schemas = self._build_fallback_schemas()

        return schemas

    @staticmethod
    def _tool_field(tool_info: Any, names: tuple, default: Any) -> Any:
        """Return the first of ``names`` present on ``tool_info`` (dict key or attribute), else ``default``."""
        for name in names:
            if isinstance(tool_info, dict):
                if name in tool_info:
                    return tool_info[name]
            elif hasattr(tool_info, name):
                return getattr(tool_info, name)
        return default

    def _build_fallback_schemas(self) -> List[Dict[str, Any]]:
        """
        Fallback schema building if MCP client method fails.

        Builds OpenAI-style ``{"type": "function", ...}`` entries from
        ``get_available_tools()``. Tool info may be an object or a dict. The
        parameter schema is read from ``input_schema`` or, following the MCP
        tool definition, ``inputSchema``. Errors are logged; whatever was
        built so far is returned.
        """
        schemas: List[Dict[str, Any]] = []

        if hasattr(self.mcp_tools, 'get_available_tools'):
            try:
                available_tools = self.mcp_tools.get_available_tools()
                for tool_name, tool_info in available_tools.items():
                    schemas.append({
                        "type": "function",
                        "function": {
                            "name": tool_name,
                            "description": self._tool_field(
                                tool_info, ("description",), f"Tool: {tool_name}"),
                            "parameters": self._tool_field(
                                tool_info, ("input_schema", "inputSchema"),
                                {"type": "object", "properties": {}, "required": []}),
                        }
                    })
                self.logger.info(f"Built {len(schemas)} schemas using fallback method")
            except Exception as e:
                self.logger.warning(f"Fallback schema building failed: {e}")

        return schemas

    def execute_tool_call(self, tool_call) -> Dict[str, Any]:
        """
        Execute a tool call and return results using proper MCP architecture.

        Args:
            tool_call: Mapping with ``"name"`` and ``"arguments"`` (a keyword-argument
                dict; ``None`` is treated as no arguments).

        Returns:
            A result dict with at least ``"success"``. Failures are returned as
            ``{"success": False, "error": ...}`` rather than raised.

        Raises:
            KeyError: if ``tool_call`` has no ``"name"``.
        """
        tool_name = tool_call["name"]

        try:
            arguments = tool_call["arguments"]
            if arguments is None:
                arguments = {}

            if tool_name not in self.available_tools:
                return {
                    "success": False,
                    "error": f"Tool '{tool_name}' not available for this agent"
                }

            tool_entry = self.available_tools[tool_name]

            if callable(tool_entry):
                # Locally callable tool (e.g. from the attribute fallback).
                result = tool_entry(**arguments)
                if hasattr(result, 'to_dict'):
                    return result.to_dict()
                if isinstance(result, dict):
                    return result
                return {
                    "success": True,
                    "data": result,
                    "error": None,
                    "metadata": {}
                }

            if hasattr(self.mcp_tools, 'client') and hasattr(self.mcp_tools.client, 'call_tool'):
                # Tool info only: dispatch through the MCP client.
                result = self.mcp_tools.client.call_tool(tool_name, arguments)
                if hasattr(result, 'success'):
                    return {
                        "success": result.success,
                        "data": result.data,
                        "error": result.error,
                        "metadata": getattr(result, 'metadata', {})
                    }
                return result

            return {
                "success": False,
                "error": f"Tool '{tool_name}' is not executable (neither built-in nor MCP)"
            }

        except Exception as e:
            self.logger.error(f"Error executing tool {tool_name}: {e}")
            return {
                "success": False,
                "error": f"Tool execution failed: {str(e)}"
            }

    def log_reasoning(self, iteration: int, reasoning: str):
        """Append a reasoning entry to the trace and update step counters."""
        self.reasoning_trace.append({
            "type": "reasoning",
            "iteration": iteration,
            "content": reasoning,
            "timestamp": time.time()
        })
        self.execution_stats["reasoning_steps"] += 1
        self.execution_stats["total_steps"] += 1
        self.logger.info(f"Reasoning (Iter {iteration}): {reasoning[:100]}...")

    def log_action(self, iteration: int, tool: str, arguments: Dict[str, Any], result: Dict[str, Any]):
        """
        Append a tool-action entry to the trace and update counters.

        Also increments ``execution_stats["tool_usage"][tool]``. A result
        without a ``"success"`` key is counted as successful.
        """
        self.reasoning_trace.append({
            "type": "action",
            "iteration": iteration,
            "tool": tool,
            "arguments": arguments,
            "result": result,
            "timestamp": time.time()
        })
        self.execution_stats["action_steps"] += 1
        self.execution_stats["total_steps"] += 1

        # Per-tool invocation counts (declared in reset_trace).
        tool_usage = self.execution_stats["tool_usage"]
        tool_usage[tool] = tool_usage.get(tool, 0) + 1

        success = result.get("success", True)
        status = "Success" if success else "Failed"
        self.logger.info(f"Action (Iter {iteration}): {tool} -> {status} -> {str(arguments)[:400]}...")

    def log_error(self, iteration: int, error: str):
        """Append an error entry to the trace, update counters, and log it."""
        self.reasoning_trace.append({
            "type": "error",
            "iteration": iteration,
            "error": error,
            "timestamp": time.time()
        })
        self.execution_stats["error_steps"] += 1
        self.execution_stats["total_steps"] += 1
        self.logger.error(f"Error (Iter {iteration}): {error}")

    def reset_trace(self):
        """Reset the reasoning trace and execution statistics for a new task."""
        self.reasoning_trace = []
        self.execution_stats = {
            "total_steps": 0,
            "reasoning_steps": 0,
            "action_steps": 0,
            "error_steps": 0,
            "tool_usage": {},
            "success_rate": 1.0
        }

    def get_execution_stats(self) -> Dict[str, Any]:
        """
        Get execution statistics.

        Recomputes ``success_rate`` as the fraction of logged actions whose
        result is not marked unsuccessful (it stays 1.0 with no actions).
        Returns a copy; the nested ``tool_usage`` dict is copied too, so callers
        cannot mutate internal state.
        """
        action_steps = self.execution_stats["action_steps"]
        if action_steps > 0:
            failed_actions = sum(
                1 for step in self.reasoning_trace
                if step.get("type") == "action"
                and not step.get("result", {}).get("success", True)
            )
            self.execution_stats["success_rate"] = (action_steps - failed_actions) / action_steps

        stats = self.execution_stats.copy()
        stats["tool_usage"] = dict(self.execution_stats["tool_usage"])
        return stats

    def create_response(self, success: bool, result: Optional[Dict[str, Any]] = None,
                        error: Optional[str] = None, iterations: int = 0,
                        execution_time: float = 0.0) -> 'AgentResponse':
        """Create a standardized agent response carrying a snapshot of the current trace."""
        return AgentResponse(
            success=success,
            result=result,
            error=error,
            iterations=iterations,
            reasoning_trace=self.reasoning_trace.copy(),
            agent_name=self.config.agent_name,
            execution_time=execution_time
        )

    def validate_config(self) -> bool:
        """
        Validate agent configuration.

        Requires:
        - a non-empty ``agent_name`` and ``model``;
        - ``max_iterations`` > 0;
        - ``temperature`` in [0.0, 2.0];
        - ``max_tokens`` > 0.

        ``None`` numeric values are invalid. Any unexpected error while
        checking also yields False.
        """
        cfg = self.config
        try:
            if not cfg.agent_name or not cfg.model:
                return False
            if cfg.max_iterations is None or cfg.max_iterations <= 0:
                return False
            if cfg.temperature is None or not (0.0 <= cfg.temperature <= 2.0):
                return False
            if cfg.max_tokens is None or cfg.max_tokens <= 0:
                return False
            return True
        except Exception:
            return False

    @abstractmethod
    def execute_task(self, task_input: 'TaskInput') -> 'AgentResponse':
        """
        Execute a task using the standardized TaskInput format

        Args:
            task_input: TaskInput object with standardized task information

        Returns:
            AgentResponse with results and process trace
        """
        pass

    @abstractmethod
    def _build_system_prompt(self) -> str:
        """Build the system prompt for this agent"""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_agent_config(
    agent_name: str,
    model: Optional[str] = None,
    max_iterations: Optional[int] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None
) -> AgentConfig:
    """
    Create an AgentConfig instance for server-managed sessions.

    Explicit arguments take precedence; anything left as ``None`` is read from
    the global configuration (``config.config.get_config()``).
    ``max_iterations`` is chosen per agent type, based on keywords in
    ``agent_name``. Numeric values are coerced to the types declared on
    ``AgentConfig``: int for max_iterations and max_tokens, float for
    temperature.

    Args:
        agent_name: Name of the agent
        model: LLM model to use
        max_iterations: Maximum number of iterations
        temperature: LLM temperature setting
        max_tokens: Maximum tokens for LLM response

    Returns:
        Configured AgentConfig instance

    Raises:
        ValueError: if the global configuration cannot be loaded, a required
            value cannot be resolved, or a numeric value cannot be converted.
    """
    try:
        from config.config import get_config
        api_cfg = get_config()
    except Exception as e:
        raise ValueError(f"Failed to load global configuration: {e}") from e

    def _coerce(value, cast, label):
        """Convert ``value`` with ``cast``, reporting failures as ValueError."""
        try:
            return cast(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid {label} value: {value!r}") from err

    planner_mode = getattr(api_cfg, "planner_mode", "auto")

    resolved_model = model if model is not None else getattr(api_cfg, "model_name", None)
    if not resolved_model:
        raise ValueError("Model is not specified and MODEL_NAME is not set in environment")

    resolved_temperature = temperature if temperature is not None else getattr(api_cfg, "model_temperature", None)
    if resolved_temperature is None:
        raise ValueError("Temperature is not specified and MODEL_TEMPERATURE is not set in environment")

    resolved_max_tokens = max_tokens if max_tokens is not None else getattr(api_cfg, "model_max_tokens", None)
    if resolved_max_tokens is None:
        raise ValueError("Max tokens is not specified and MODEL_MAX_TOKENS is not set in environment")

    # Optional output locations; None when absent from the global config.
    trajectory_storage_path = getattr(api_cfg, "trajectory_storage_path", None)
    report_output_path = getattr(api_cfg, "report_output_path", None)
    document_analysis_path = getattr(api_cfg, "document_analysis_path", None)

    if max_iterations is None:
        agent_lower = (agent_name or "").lower()
        resolved_max_iterations = None
        # NOTE(review): this checks "writer" before "information"/"seeker",
        # whereas BaseAgent._get_agent_type checks them in the opposite order.
        # A name containing both keywords is classified differently by the two
        # — confirm which is intended.
        if "planner" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "planner_max_iterations", None)
        elif "writer" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "writer_max_iterations", None)
        elif "information" in agent_lower or "seeker" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "information_seeker_max_iterations", None)

        if resolved_max_iterations is None:
            raise ValueError("Max iterations not specified and no env override (PLANNER_MAX_ITERATION/WRITER_MAX_ITERATION/INFORMATION_SEEKER_MAX_ITERATION)")
        max_iterations = resolved_max_iterations

    return AgentConfig(
        agent_name=agent_name,
        planner_mode=planner_mode,
        model=resolved_model,
        max_iterations=_coerce(max_iterations, int, "max_iterations"),
        temperature=_coerce(resolved_temperature, float, "temperature"),
        max_tokens=_coerce(resolved_max_tokens, int, "max_tokens"),
        trajectory_storage_path=trajectory_storage_path,
        report_output_path=report_output_path,
        document_analysis_path=document_analysis_path
    )