Spaces:
Running
Running
"""
PostgreSQL LISTEN/NOTIFY — event-driven change detection.

Tables emit NOTIFY on INSERT/UPDATE via triggers. A background listener
(started with start_listener()) LISTENs on those channels and sets
asyncio.Event dirty flags, so WS channels only re-query when data actually
changed.
"""
| import asyncio | |
| import asyncpg | |
| from loguru import logger | |
| from src.config import get_settings | |
# ── Channel → event flag mapping ──────────────────────────────────────────────
# One asyncio.Event per dirty flag. A flag is set when its data changes; the WS
# hub checks it and clears it after re-querying.
# NOTE(review): each flag is a single shared Event, so if several consumers
# watch the same flag, the first mark_clean() hides the change from the rest —
# confirm the hub has exactly one consumer per flag.
_flags: dict[str, asyncio.Event] = {
    name: asyncio.Event()
    for name in (
        "posts_changed",
        "trends_changed",
        "workflows_changed",
        "notifs_changed",
        "post_step",
        "scheduler_changed",
    )
}

# PG NOTIFY channel → the dirty flags it sets. Every channel currently maps to
# the flag of the same name. "post_step" and "scheduler_changed" have no channel
# here, so they only become dirty when set from Python code (e.g. mark_dirty()).
_PG_TO_FLAGS: dict[str, list[str]] = {
    channel: [channel]
    for channel in (
        "posts_changed",
        "trends_changed",
        "workflows_changed",
        "notifs_changed",
    )
}
def is_dirty(flag: str) -> bool:
    """Report whether *flag* is currently set (marked dirty and not yet cleaned).

    Raises KeyError if *flag* is not a registered flag name.
    """
    event = _flags[flag]
    return event.is_set()
def mark_clean(flag: str):
    """Clear *flag*, typically once its data has been re-queried.

    Raises KeyError if *flag* is not a registered flag name.
    """
    event = _flags[flag]
    event.clear()
def mark_dirty(flag: str):
    """Set *flag*, waking every coroutine waiting on it (e.g. in wait_any()).

    Raises KeyError if *flag* is not a registered flag name.
    """
    event = _flags[flag]
    event.set()
async def wait_any(flags: list[str], timeout: float) -> bool:
    """Wait until any of *flags* is set, or until *timeout* seconds elapse.

    Unknown flag names are ignored. If none of the names is registered, the
    call just sleeps for *timeout* and reports a timeout.

    Flags are not cleared here; callers clear them with mark_clean() after
    re-querying.

    Returns:
        True if at least one flag was (or already is) set within *timeout*,
        False on timeout.
    """
    events = [_flags[f] for f in flags if f in _flags]
    if not events:
        await asyncio.sleep(timeout)
        return False
    waiters = [asyncio.create_task(e.wait()) for e in events]
    try:
        done, _ = await asyncio.wait(
            waiters,
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        # asyncio.wait() never cancels what is still pending. Without this,
        # every call that times out (or is itself cancelled) would leave one
        # orphaned Event.wait() task per flag, piling up until that flag is
        # next set. cancel() is a no-op on tasks that already finished.
        for waiter in waiters:
            waiter.cancel()
    return bool(done)
| # ββ Listener task ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _listener_task: asyncio.Task | None = None | |
def _listen_connect_args(settings):
    """Return the ``(dsn, ssl_ctx)`` pair used to open the LISTEN connection.

    URL selection:
    - Uses ``settings.database_direct_url`` when set; the original note says a
      direct URL is required for Neon pooled setups.
    - Otherwise uses ``settings.database_url``.

    URL cleanup:
    - The SQLAlchemy ``postgresql+asyncpg://`` scheme is rewritten to plain
      ``postgresql://``.
    - The whole query string is dropped.

    SSL:
    - An SSL context is returned when the host contains ``neon.tech``, or when
      ``sslmode`` appears in the selected URL or in ``settings.database_url``.
    - Otherwise ``ssl_ctx`` is None.

    Returns ``(None, None)`` when neither URL is configured.
    """
    selected = settings.database_direct_url or settings.database_url
    if not selected:
        return None, None

    dsn = selected
    if dsn.startswith("postgresql+asyncpg://"):
        dsn = dsn.replace("postgresql+asyncpg://", "postgresql://", 1)
    dsn = dsn.split("?")[0]

    # Decide on the URL actually used (before its query string was stripped).
    # Also keep the original check on database_url for backward compatibility.
    wants_ssl = (
        "neon.tech" in dsn
        or "sslmode" in selected
        or "sslmode" in (settings.database_url or "")
    )
    if not wants_ssl:
        return dsn, None

    import ssl

    ssl_ctx = ssl.create_default_context()
    # NOTE(review): this disables certificate AND hostname verification. The
    # link is encrypted but the server is not authenticated (open to MITM).
    # Kept for compatibility; consider enabling verification once the CA
    # bundle available in deployment is confirmed.
    ssl_ctx.check_hostname = False
    ssl_ctx.verify_mode = ssl.CERT_NONE
    return dsn, ssl_ctx


async def _listen_loop():
    """Keep a dedicated LISTEN connection open, reconnecting on failure.

    Behavior:
    - Each NOTIFY on a channel in ``_PG_TO_FLAGS`` marks the mapped dirty flags.
    - Every 30 s a ``SELECT 1`` probes the connection. Any error is logged, the
      connection is closed, and a reconnect is attempted 5 s later.
    - Cancellation (see stop_listener()) closes the connection and then
      propagates, so the task ends as cancelled.
    - Returns immediately, with an error log, if no database URL is configured.
    """
    settings = get_settings()
    dsn, ssl_ctx = _listen_connect_args(settings)
    if not dsn:
        logger.error("[Notify] no database URL configured — LISTEN disabled")
        return

    def _on_notify(_conn, _pid, channel, _payload):
        # asyncpg listener callback, called as (connection, pid, channel, payload).
        for flag in _PG_TO_FLAGS.get(channel, []):
            mark_dirty(flag)
            logger.debug(f"[Notify] dirty: {flag} (from {channel})")

    while True:
        conn = None
        try:
            conn = await asyncpg.connect(dsn, ssl=ssl_ctx)
            logger.info("[Notify] LISTEN connection established")
            for ch in _PG_TO_FLAGS:
                await conn.add_listener(ch, _on_notify)
            # NOTIFYs sent while no LISTEN was active (startup, or the gap
            # before a reconnect) are never delivered to this session. Assume
            # every listened-to table may have changed and force one re-query.
            for mapped_flags in _PG_TO_FLAGS.values():
                for flag in mapped_flags:
                    mark_dirty(flag)
            # Keep alive: probe periodically so a dead socket raises and
            # triggers a reconnect.
            while True:
                await asyncio.sleep(30)
                await conn.execute("SELECT 1")
        except Exception as e:
            logger.warning(f"[Notify] connection lost: {e} — reconnecting in 5s")
        finally:
            # Runs on both errors and cancellation. CancelledError is not an
            # Exception subclass, so after this cleanup it propagates and ends
            # the task instead of being swallowed.
            if conn is not None and not conn.is_closed():
                try:
                    await conn.close()
                except Exception as close_err:
                    logger.debug(f"[Notify] error closing LISTEN connection: {close_err}")
        # Only reached after a handled Exception. The broken connection is
        # already closed, so we back off without holding it open.
        await asyncio.sleep(5)
def _log_listener_exit(task: asyncio.Task) -> None:
    """Done-callback that logs an unexpected crash of the LISTEN task.

    asyncio only reports an unretrieved task exception when the task object is
    garbage-collected. ``_listener_task`` keeps the task alive, so without this
    callback a crash would go unnoticed.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.opt(exception=exc).error("[Notify] listener task crashed")


def start_listener():
    """Start the background LISTEN task unless one is already running.

    Must be called while an event loop is running (uses asyncio.create_task).
    A previous task that has finished (crashed, returned, or was cancelled)
    is replaced by a fresh one.
    """
    global _listener_task
    if _listener_task and not _listener_task.done():
        return
    _listener_task = asyncio.create_task(_listen_loop())
    _listener_task.add_done_callback(_log_listener_exit)
    logger.info("[Notify] listener task started")
def stop_listener():
    """Request cancellation of the background LISTEN task and drop the handle.

    Does not await the task: any cleanup it performs happens once the event
    loop delivers the cancellation. Safe to call when no listener is running.
    """
    global _listener_task
    task, _listener_task = _listener_task, None
    if task is not None:
        task.cancel()