Spaces:
Paused
Paused
File size: 8,034 Bytes
8d1819a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import uuid
from typing import Any, Dict, List, Optional
from python.helpers.print_style import PrintStyle
# FastA2A and httpx are imported defensively: if either import fails, the
# module still loads and records the outcome in FASTA2A_CLIENT_AVAILABLE.
try:
    from fasta2a.client import A2AClient # type: ignore
    import httpx # type: ignore
    FASTA2A_CLIENT_AVAILABLE = True
except ImportError:
    # AgentConnection.__init__ checks this flag and raises RuntimeError when
    # it is False; here we only emit a warning.
    FASTA2A_CLIENT_AVAILABLE = False
    PrintStyle.warning("FastA2A client not available. Agent-to-agent communication disabled.")

# Shared printer used for this module's status and error messages.
_PRINTER = PrintStyle(italic=True, font_color="cyan", padding=False)
class AgentConnection:
    """Helper class for connecting to and communicating with other Agent Zero instances via FastA2A.

    The connection owns an ``httpx.AsyncClient``; call :meth:`close` (or use the
    instance as an ``async with`` context manager) to release it. The context id
    returned by the remote agent is remembered and re-used for later messages.
    """

    def __init__(self, agent_url: str, timeout: int = 30, token: Optional[str] = None):
        """Initialize connection to an agent.

        Args:
            agent_url: The base URL of the agent (e.g., "https://agent.example.com").
                ``http://`` is prepended when no scheme is given and trailing
                slashes are removed.
            timeout: Request timeout in seconds.
            token: Optional auth token. When ``None``, the ``A2A_TOKEN``
                environment variable is used if set. A non-empty token is sent
                both as ``Authorization: Bearer <token>`` and as ``X-API-KEY``.

        Raises:
            RuntimeError: If the FastA2A client could not be imported.
        """
        if not FASTA2A_CLIENT_AVAILABLE:
            raise RuntimeError("FastA2A client not available")

        # Ensure scheme is present
        if not agent_url.startswith(('http://', 'https://')):
            agent_url = 'http://' + agent_url
        self.agent_url = agent_url.rstrip('/')
        self.timeout = timeout

        # Auth headers
        if token is None:
            import os
            token = os.getenv("A2A_TOKEN")
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
            headers["X-API-KEY"] = token

        self._http_client = httpx.AsyncClient(timeout=timeout, headers=headers)  # type: ignore
        self._a2a_client = A2AClient(base_url=self.agent_url, http_client=self._http_client)  # type: ignore
        # Cached agent card; fetched lazily by get_agent_card().
        self._agent_card: Optional[Dict[str, Any]] = None
        # Track conversation context automatically
        self._context_id: Optional[str] = None

    async def _fetch_agent_card(self, base_url: str) -> Any:
        """Download and decode ``<base_url>/.well-known/agent.json``."""
        response = await self._http_client.get(f"{base_url}/.well-known/agent.json")
        response.raise_for_status()
        card = response.json()
        _PRINTER.print(f"Retrieved agent card from {base_url}")
        return card

    async def get_agent_card(self) -> Dict[str, Any]:
        """Retrieve the agent card from the remote agent.

        The card is fetched once and cached. If the request against
        ``agent_url`` fails and the URL contains ``/a2a``, the part of the URL
        before the first ``/a2a`` is tried as a fallback.

        Returns:
            The decoded agent card.

        Raises:
            RuntimeError: If the card could not be retrieved from either location.
        """
        if self._agent_card is not None:
            return self._agent_card

        try:
            card = await self._fetch_agent_card(self.agent_url)
            _PRINTER.print(f"Agent: {card.get('name', 'Unknown')}")  # type: ignore
            _PRINTER.print(f"Description: {card.get('description', 'No description')}")  # type: ignore
            # Cache only after the card has been used successfully.
            self._agent_card = card
        except Exception as e:
            # Fallback: if URL contains '/a2a', try root path without it
            if "/a2a" in self.agent_url:
                root_url = self.agent_url.split("/a2a", 1)[0]
                try:
                    self._agent_card = await self._fetch_agent_card(root_url)
                except Exception:
                    # The original error is the one reported below.
                    pass
            # Only fail when the fallback did not produce a card either.
            if self._agent_card is None:
                _PRINTER.print(f"[!] Could not connect to {self.agent_url}\n → Ensure the server is running and reachable.\n → Full error: {e}")
                raise RuntimeError(f"Could not retrieve agent card: {e}") from e
        return self._agent_card  # type: ignore

    async def send_message(
        self,
        message: str,
        attachments: Optional[List[str]] = None,
        context_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Send a message to the remote agent and return task response.

        Args:
            message: Text of the message.
            attachments: Optional URIs, each sent as a ``file`` part.
            context_id: Conversation context to use. When ``None``, the context
                id remembered from the previous response (if any) is re-used.
            metadata: Optional metadata passed through to the A2A client.

        Returns:
            The response returned by the A2A client's ``send_message`` call.

        Raises:
            RuntimeError: If the agent card has to be fetched and cannot be.
            Exception: Any error raised by the A2A client is logged and re-raised.
        """
        if not self._agent_card:
            await self.get_agent_card()

        # Re-use context automatically if caller did not supply one
        if context_id is None:
            context_id = self._context_id

        # Build message parts: the text first, then one file part per attachment
        parts: List[Dict[str, Any]] = [{'kind': 'text', 'text': message}]
        parts.extend({'kind': 'file', 'file': {'uri': uri}} for uri in attachments or [])

        # Construct A2A message
        a2a_message: Dict[str, Any] = {
            'role': 'user',
            'parts': parts,
            'kind': 'message',
            'message_id': str(uuid.uuid4())
        }
        if context_id is not None:
            a2a_message['context_id'] = context_id

        # Send using the message/send method (not send_task)
        try:
            response = await self._a2a_client.send_message(
                message=a2a_message,  # type: ignore
                metadata=metadata,
                configuration={'accepted_output_modes': ['application/json', 'text/plain'], 'blocking': True}  # type: ignore
            )
        except Exception as e:
            _PRINTER.print(f"[A2A] Error sending message: {e}")
            raise

        # Persist context id for subsequent calls; responses with a different
        # structure are simply ignored.
        result = response.get('result') if isinstance(response, dict) else None
        ctx = result.get('context_id') if isinstance(result, dict) else None
        if isinstance(ctx, str):
            self._context_id = ctx
        return response  # type: ignore

    async def get_task(self, task_id: str) -> Dict[str, Any]:
        """Get the status and results of a task.

        Args:
            task_id: The ID of the task to query

        Returns:
            Dictionary containing the task information

        Raises:
            RuntimeError: If the A2A client call fails (original error chained).
        """
        try:
            response = await self._a2a_client.get_task(task_id)  # type: ignore
            return response  # type: ignore
        except Exception as e:
            _PRINTER.print(f"Failed to get task {task_id}: {e}")
            raise RuntimeError(f"Failed to get task: {e}") from e

    async def wait_for_completion(self, task_id: str, poll_interval: int = 2, max_wait: int = 300) -> Dict[str, Any]:
        """Wait for a task to complete and return the final result.

        Args:
            task_id: The ID of the task to wait for
            poll_interval: How often to check task status (seconds)
            max_wait: Maximum time to wait (seconds)

        Returns:
            Dictionary containing the completed task information

        Raises:
            TimeoutError: If the task does not reach ``completed``, ``failed`` or
                ``canceled`` within ``max_wait`` seconds.
            RuntimeError: If querying the task fails.
        """
        import asyncio
        import time

        # A monotonic deadline keeps max_wait meaningful even when requests are
        # slow or poll_interval is zero.
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            task_info = await self.get_task(task_id)
            if 'result' in task_info:
                task = task_info['result']
                status = task.get('status', {})
                state = status.get('state', 'unknown')
                if state in ('completed', 'failed', 'canceled'):
                    _PRINTER.print(f"Task {task_id} finished with state: {state}")
                    return task_info
                _PRINTER.print(f"Task {task_id} status: {state}")
            await asyncio.sleep(max(poll_interval, 0))
        raise TimeoutError(f"Task {task_id} did not complete within {max_wait} seconds")

    async def close(self):
        """Close the HTTP client connection."""
        await self._http_client.aclose()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
async def connect_to_agent(agent_url: str, timeout: int = 30, token: Optional[str] = None) -> AgentConnection:
    """Create a connection to a remote agent and verify it is reachable.

    Args:
        agent_url: The base URL of the agent
        timeout: Request timeout in seconds
        token: Optional auth token forwarded to ``AgentConnection``; when
            ``None`` the connection applies its own default lookup.

    Returns:
        AgentConnection instance whose agent card has already been retrieved

    Raises:
        RuntimeError: If the FastA2A client is unavailable or the agent card
            cannot be retrieved.
    """
    connection = AgentConnection(agent_url, timeout, token)
    # Verify connection by retrieving agent card
    try:
        await connection.get_agent_card()
    except Exception:
        # The caller never receives the connection on failure, so release its
        # HTTP client here instead of leaking it.
        await connection.close()
        raise
    return connection
def is_client_available() -> bool:
    """Report whether the FastA2A client libraries were imported successfully."""
    available = FASTA2A_CLIENT_AVAILABLE
    return available
|