"""Base agent abstraction, agent status constants, and the agent class registry."""

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Any

from openspace.utils.logging import Logger

if TYPE_CHECKING:
    # Needed only for type annotations; TYPE_CHECKING is False at runtime,
    # so these modules are not imported when the program runs.
    from openspace.llm import LLMClient
    from openspace.grounding.core.grounding_client import GroundingClient
    from openspace.recording import RecordingManager

# Module-level logger shared by every class in this file.
logger = Logger.get_logger(__name__)
class BaseAgent(ABC):
    """Abstract base class for agents that reason through an LLM client.

    Subclasses must implement :meth:`process` and :meth:`construct_messages`.
    Constructing an instance registers the concrete subclass in
    ``AgentRegistry`` under its class name (only if that name is not
    registered yet).
    """

    def __init__(
        self,
        name: str,
        backend_scope: Optional[List[str]] = None,
        llm_client: Optional["LLMClient"] = None,
        grounding_client: Optional["GroundingClient"] = None,
        recording_manager: Optional["RecordingManager"] = None,
    ) -> None:
        """
        Initialize the BaseAgent.

        Args:
            name: Unique name for the agent
            backend_scope: List of backend types this agent can access
                (e.g., ["gui", "shell", "mcp", "web", "system"])
            llm_client: LLM client for agent reasoning (optional, can be set later)
            grounding_client: Reference to GroundingClient for tool execution
            recording_manager: RecordingManager for recording execution
        """
        self._name = name
        self._grounding_client: Optional["GroundingClient"] = grounding_client
        # A missing (or empty) scope is normalised to an empty list.
        self._backend_scope = backend_scope or []
        self._llm_client = llm_client
        self._recording_manager: Optional["RecordingManager"] = recording_manager
        self._step = 0
        self._status = AgentStatus.ACTIVE

        self._register_self()
        logger.info(f"Initialized {self.__class__.__name__}: {name}")

    @property
    def name(self) -> str:
        """Name given to the agent at construction."""
        return self._name

    @property
    def grounding_client(self) -> Optional["GroundingClient"]:
        """Get the grounding client."""
        return self._grounding_client

    @property
    def backend_scope(self) -> List[str]:
        """Backend types this agent may access (empty list if none were given)."""
        return self._backend_scope

    @property
    def llm_client(self) -> Optional["LLMClient"]:
        """LLM client used by :meth:`get_llm_response`, or None if unset."""
        return self._llm_client

    @llm_client.setter
    def llm_client(self, client: "LLMClient") -> None:
        self._llm_client = client

    @property
    def recording_manager(self) -> Optional["RecordingManager"]:
        """Get the recording manager."""
        return self._recording_manager

    @property
    def step(self) -> int:
        """Step counter; starts at 0 and grows via :meth:`increment_step`."""
        return self._step

    @property
    def status(self) -> str:
        """Current status value (initialised to ``AgentStatus.ACTIVE``)."""
        return self._status

    @abstractmethod
    async def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Handle ``context`` and return a result dict (subclass-defined)."""
        pass

    @abstractmethod
    def construct_messages(self, context: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Construct messages for LLM reasoning.
        Context must contain 'instruction' key.
        """
        pass

    async def get_llm_response(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Forward ``messages`` to the LLM client and return its response.

        Args:
            messages: Chat messages passed to the client's ``complete``.
            tools: Optional tool definitions passed through unchanged.
            **kwargs: Extra keyword arguments forwarded to ``complete``.

        Returns:
            Whatever the client's ``complete`` coroutine returns.

        Raises:
            ValueError: If no LLM client has been set on this agent.
            Exception: Any error raised by the client is logged and re-raised.
        """
        if not self._llm_client:
            raise ValueError(f"LLM client not initialized for agent {self.name}")

        try:
            response = await self._llm_client.complete(
                messages=messages,
                tools=tools,
                **kwargs
            )
            return response
        except Exception as e:
            logger.error(f"{self.name}: LLM call failed: {e}", exc_info=True)
            raise

    def response_to_dict(self, response: str) -> Dict[str, Any]:
        """Parse an LLM text response as JSON.

        A leading Markdown code fence (with or without a language tag) is
        removed first, keeping only the lines up to the closing fence. If
        valid JSON is followed by extra text, the leading JSON value is
        returned and a warning is logged.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            The decoded JSON value, or
            ``{"error": "Failed to parse response", "raw": <text>}`` when the
            text cannot be parsed (``raw`` is the text after fence removal).
        """
        payload = response
        stripped = response.strip()
        if stripped.startswith("```"):
            # Drop the opening fence line, then cut at the first closing fence.
            body_lines = stripped.split('\n')[1:]
            end_idx = next(
                (i for i, line in enumerate(body_lines) if line.strip() == '```'),
                len(body_lines),
            )
            payload = '\n'.join(body_lines[:end_idx])

        # raw_decode (used below) does not skip leading whitespace the way
        # json.loads does, so parse a stripped copy in both places.
        candidate = payload.strip()
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            # If parsing fails, try to find and extract just the JSON object/array
            if "Extra data" in str(e):
                try:
                    obj, idx = json.JSONDecoder().raw_decode(candidate)
                    logger.warning(
                        f"{self.name}: Successfully extracted JSON but found extra text after position {idx}. "
                        f"Extra text: {candidate[idx:idx+100]}..."
                    )
                    return obj
                except json.JSONDecodeError as e2:
                    logger.error(f"{self.name}: Failed to extract JSON even with raw_decode: {e2}")

            logger.error(f"{self.name}: Failed to parse response: {e}")
            logger.error(f"{self.name}: Response content: {payload[:500]}")
            return {"error": "Failed to parse response", "raw": payload}

    def increment_step(self) -> None:
        """Advance the step counter by one."""
        self._step += 1

    @classmethod
    def _register_self(cls) -> None:
        """Register the agent class in the registry upon instantiation."""
        # __init__ calls this through the instance, so ``cls`` is the concrete
        # subclass. BaseAgent itself and already-registered names are skipped.
        if cls.__name__ != "BaseAgent" and cls.__name__ not in AgentRegistry._registry:
            AgentRegistry.register(cls.__name__, cls)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(name={self.name}, step={self.step}, status={self.status})>"
class AgentStatus:
    """Plain string constants naming the statuses an agent can hold."""

    # BaseAgent.__init__ assigns this value to every new agent.
    ACTIVE: str = "active"
    IDLE: str = "idle"
    WAITING: str = "waiting"
class AgentRegistry:
    """
    Class-level lookup table from a name to an agent class.

    All state lives in the shared ``_registry`` dict, so the registry is used
    through its classmethods and never needs to be instantiated.
    """

    _registry: Dict[str, Type["BaseAgent"]] = {}

    @classmethod
    def register(cls, name: str, agent_cls: Type["BaseAgent"]) -> None:
        """Store ``agent_cls`` under ``name``, replacing any previous entry.

        A warning is logged when an existing entry is overwritten.
        """
        already_present = name in cls._registry
        if already_present:
            logger.warning(f"Agent class '{name}' already registered, overwriting")
        cls._registry[name] = agent_cls
        logger.debug(f"Registered agent class: {name}")

    @classmethod
    def get_cls(cls, name: str) -> Type["BaseAgent"]:
        """Return the class registered under ``name``.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        if name in cls._registry:
            return cls._registry[name]
        raise ValueError(f"No agent class registered under '{name}'")

    @classmethod
    def list_registered(cls) -> List[str]:
        """Return the registered names, in registration order."""
        return [*cls._registry]

    @classmethod
    def clear(cls) -> None:
        """Remove every entry from the registry."""
        cls._registry.clear()
        logger.debug("Agent registry cleared")