import asyncio
import logging
from collections.abc import Iterable
import httpx
from agent.messaging.base import (
NotificationError,
NotificationProvider,
RetryableNotificationError,
)
from agent.messaging.models import (
MessagingConfig,
NotificationRequest,
NotificationResult,
)
from agent.messaging.slack import SlackProvider
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Seconds to sleep before each retry of a transient delivery failure.
# The tuple's length is the number of retries, so a send is attempted
# at most len(_RETRY_DELAYS) + 1 times.
_RETRY_DELAYS = (1, 2, 4)
class NotificationGateway:
    """Deliver notification requests through the provider of each destination.

    Requests can be delivered immediately with :meth:`send` /
    :meth:`send_many`, or handed to a background worker with
    :meth:`enqueue` after :meth:`start` has been awaited.  Delivery problems
    signalled by ``NotificationError`` are reported as a
    ``NotificationResult`` with ``ok=False`` instead of being raised.
    """

    def __init__(self, config: MessagingConfig):
        """Store *config* and create the provider registry and an empty queue.

        No worker task or HTTP client exists until :meth:`start` is awaited.
        """
        self.config = config
        # Registry keyed by the provider name a destination declares.
        self._providers: dict[str, NotificationProvider] = {
            "slack": SlackProvider(),
        }
        self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue()
        self._worker_task: asyncio.Task | None = None
        self._client: httpx.AsyncClient | None = None

    @property
    def enabled(self) -> bool:
        """Whether messaging is switched on, read live from ``config.enabled``."""
        return self.config.enabled

    async def start(self) -> None:
        """Create the shared HTTP client and launch the queue worker.

        Does nothing when messaging is disabled or a worker already exists.
        """
        if not self.enabled or self._worker_task is not None:
            return
        self._client = httpx.AsyncClient(timeout=10.0)
        self._worker_task = asyncio.create_task(
            self._worker(), name="notification-gateway"
        )

    async def flush(self) -> None:
        """Wait until every queued request has been processed by the worker.

        Returns immediately when messaging is disabled.
        """
        if not self.enabled:
            return
        await self._queue.join()

    async def close(self) -> None:
        """Drain the queue, stop the worker and close the shared HTTP client.

        Teardown is driven by the resources actually held rather than by
        ``enabled``: if the configuration is switched off after
        :meth:`start`, the worker task and HTTP client are still released
        instead of leaking.  Calling this on a gateway that was never
        started is a no-op.
        """
        worker = self._worker_task
        if worker is not None:
            # Let the worker finish what is queued before stopping it, so
            # accepted notifications are processed rather than dropped.
            await self._queue.join()
        try:
            if worker is not None:
                worker.cancel()
                try:
                    await worker
                except asyncio.CancelledError:
                    pass
                self._worker_task = None
        finally:
            # Detach the client before awaiting so it is closed exactly once,
            # and close it even if stopping the worker raised.
            client, self._client = self._client, None
            if client is not None:
                await client.aclose()

    async def send(self, request: NotificationRequest) -> NotificationResult:
        """Deliver *request* now and return the outcome.

        A failed result (``ok=False``) is returned without contacting any
        provider when messaging is disabled, when ``request.destination`` is
        not known to the configuration, or when the destination names a
        provider that has no implementation registered here.
        """
        name = request.destination
        if not self.enabled:
            return self._failure(name, "disabled", "Messaging is disabled")

        destination = self.config.get_destination(name)
        if destination is None:
            return self._failure(
                name, "unknown", f"Unknown destination '{request.destination}'"
            )

        provider = self._providers.get(destination.provider)
        if provider is None:
            return self._failure(
                name,
                destination.provider,
                f"No provider implementation for '{destination.provider}'",
            )

        return await self._send_with_retries(provider, name, destination, request)

    async def send_many(
        self, requests: Iterable[NotificationRequest]
    ) -> list[NotificationResult]:
        """Deliver *requests* one after another; results keep the input order."""
        return [await self.send(request) for request in requests]

    async def enqueue(self, request: NotificationRequest) -> bool:
        """Queue *request* for background delivery.

        Returns ``False`` (and queues nothing) when messaging is disabled or
        the worker has not been started; ``True`` once the request is queued.
        """
        if not self.enabled or self._worker_task is None:
            return False
        await self._queue.put(request)
        return True

    async def _worker(self) -> None:
        """Consume the queue forever, delivering each request via :meth:`send`.

        Failed deliveries and unexpected exceptions are logged and the loop
        continues; the task only ends when it is cancelled.
        """
        while True:
            request = await self._queue.get()
            try:
                result = await self.send(request)
                if not result.ok:
                    logger.warning(
                        "Notification delivery failed for %s: %s",
                        request.destination,
                        result.error,
                    )
            except Exception:
                logger.exception("Unexpected notification worker failure")
            finally:
                # Always balance the get() so queue.join() can complete.
                self._queue.task_done()

    async def _send_with_retries(
        self,
        provider: NotificationProvider,
        destination_name: str,
        destination,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Call ``provider.send``, retrying transient failures.

        A ``RetryableNotificationError`` triggers a sleep of the next delay
        in ``_RETRY_DELAYS`` followed by another attempt; once the delays
        are used up, or on any other ``NotificationError``, a failed result
        carrying the error text is returned.  Other exception types are not
        handled here and propagate to the caller.
        """
        # Use the gateway's shared client when one exists; otherwise create a
        # temporary client that is closed before returning.
        owns_client = self._client is None
        client = httpx.AsyncClient(timeout=10.0) if owns_client else self._client
        # Defensive default; every exit from the loop below overwrites it.
        error = "Notification delivery exhausted retries"
        try:
            # One attempt per configured delay plus a final attempt (marked
            # by ``None``) after which no further retry is scheduled.
            for delay in (*_RETRY_DELAYS, None):
                try:
                    return await provider.send(
                        client, destination_name, destination, request
                    )
                except RetryableNotificationError as exc:
                    error = str(exc)
                    if delay is None:
                        break
                    logger.warning(
                        "Retrying notification to %s in %ss after transient error: %s",
                        destination_name,
                        delay,
                        exc,
                    )
                    await asyncio.sleep(delay)
                except NotificationError as exc:
                    error = str(exc)
                    break
            return self._failure(destination_name, provider.provider_name, error)
        finally:
            if owns_client:
                await client.aclose()

    @staticmethod
    def _failure(destination: str, provider: str, error: str) -> NotificationResult:
        """Build the ``ok=False`` result returned for an undelivered request."""
        return NotificationResult(
            destination=destination,
            ok=False,
            provider=provider,
            error=error,
        )
|