""" PostgreSQL LISTEN/NOTIFY — event-driven change detection. Tables emit NOTIFY on INSERT/UPDATE via triggers. The hub listens and sets asyncio.Event flags so WS channels only re-query when data actually changed. """ import asyncio import asyncpg from loguru import logger from src.config import get_settings # ── Channel → event flag mapping ────────────────────────────────────────────── # Any PG channel name maps to one or more dirty flags. # Dirty flags are checked by the WS hub; cleared after a re-query. _flags: dict[str, asyncio.Event] = { "posts_changed": asyncio.Event(), "trends_changed": asyncio.Event(), "workflows_changed": asyncio.Event(), "notifs_changed": asyncio.Event(), "post_step": asyncio.Event(), "scheduler_changed": asyncio.Event(), } # Map PG NOTIFY channel → which flags to set _PG_TO_FLAGS: dict[str, list[str]] = { "posts_changed": ["posts_changed"], "trends_changed": ["trends_changed"], "workflows_changed": ["workflows_changed"], "notifs_changed": ["notifs_changed"], } def is_dirty(flag: str) -> bool: return _flags[flag].is_set() def mark_clean(flag: str): _flags[flag].clear() def mark_dirty(flag: str): _flags[flag].set() async def wait_any(flags: list[str], timeout: float) -> bool: """Return True if any flag fires within timeout, False on timeout.""" events = [_flags[f] for f in flags if f in _flags] if not events: await asyncio.sleep(timeout) return False done, _ = await asyncio.wait( [asyncio.create_task(e.wait()) for e in events], timeout=timeout, return_when=asyncio.FIRST_COMPLETED, ) return bool(done) # ── Listener task ────────────────────────────────────────────────────────────── _listener_task: asyncio.Task | None = None async def _listen_loop(): settings = get_settings() # Prefer direct URL (required for Neon pooled setups); fall back to DATABASE_URL raw_url = settings.database_direct_url or settings.database_url if raw_url.startswith("postgresql+asyncpg://"): raw_url = raw_url.replace("postgresql+asyncpg://", "postgresql://", 1) raw_url = raw_url.split("?")[0] ssl_ctx = None if "neon.tech" in raw_url or "sslmode" in settings.database_url: import ssl ssl_ctx = ssl.create_default_context() ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE while True: conn = None try: conn = await asyncpg.connect(raw_url, ssl=ssl_ctx) logger.info("[Notify] LISTEN connection established") def _on_notify(conn, pid, channel, payload): for flag in _PG_TO_FLAGS.get(channel, []): mark_dirty(flag) logger.debug(f"[Notify] dirty: {flag} (from {channel})") for ch in _PG_TO_FLAGS: await conn.add_listener(ch, _on_notify) # Keep alive — wait indefinitely, re-connect on failure while True: await asyncio.sleep(30) await conn.execute("SELECT 1") except asyncio.CancelledError: break except Exception as e: logger.warning(f"[Notify] connection lost: {e} — reconnecting in 5s") await asyncio.sleep(5) finally: if conn and not conn.is_closed(): try: await conn.close() except Exception: pass def start_listener(): global _listener_task if _listener_task and not _listener_task.done(): return _listener_task = asyncio.create_task(_listen_loop()) logger.info("[Notify] listener task started") def stop_listener(): global _listener_task if _listener_task: _listener_task.cancel() _listener_task = None