# orbis-backend / src/utils/notify.py
# Last commit by Deusxx1234:
# fix(scheduler): never block API event loop — mutex + thread pool
# 4a69d62
"""
PostgreSQL LISTEN/NOTIFY — event-driven change detection.
Tables emit NOTIFY on INSERT/UPDATE via triggers.
The hub listens and sets asyncio.Event flags so WS channels
only re-query when data actually changed.
"""
import asyncio
import asyncpg
from loguru import logger
from src.config import get_settings
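# ── Channel → event flag mapping ──────────────────────────────────────────────
# Each PG channel listed in _PG_TO_FLAGS maps to one or more dirty flags; flags
# that no PG channel feeds (e.g. post_step, scheduler_changed) are presumably
# set in-process via mark_dirty(). Dirty flags are checked by the WS hub and
# cleared after a re-query.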
# One asyncio.Event per dirty flag, keyed by flag name. These keys are the only
# names is_dirty / mark_clean / mark_dirty accept (anything else is a KeyError).
_flags: dict[str, asyncio.Event] = {
    flag_name: asyncio.Event()
    for flag_name in (
        "posts_changed",
        "trends_changed",
        "workflows_changed",
        "notifs_changed",
        "post_step",
        "scheduler_changed",
    )
}
# Map PG NOTIFY channel → which flags to set
# Only the channels listed here are LISTENed on. Today each channel raises the
# flag of the same name; give a channel a longer list to fan out to more flags.
_PG_TO_FLAGS: dict[str, list[str]] = {
    pg_channel: [pg_channel]
    for pg_channel in (
        "posts_changed",
        "trends_changed",
        "workflows_changed",
        "notifs_changed",
    )
}
def is_dirty(flag: str) -> bool:
    """Report whether *flag* has been raised and not cleared since.

    Raises KeyError if *flag* is not a known flag name.
    """
    event = _flags[flag]
    return event.is_set()
def mark_clean(flag: str):
    """Lower *flag* (KeyError if it is not a known flag name).

    NOTE(review): callers should clear *before* re-querying; clearing after
    the query can drop a notification that arrives while the query runs.
    """
    event = _flags[flag]
    event.clear()
def mark_dirty(flag: str):
    """Raise *flag*, waking every coroutine awaiting it (KeyError if unknown)."""
    event = _flags[flag]
    event.set()
async def wait_any(flags: list[str], timeout: float) -> bool:
    """Wait until any of *flags* is set, for at most *timeout* seconds.

    Unknown flag names are ignored. If none of the names is known, the call
    just sleeps for *timeout* and returns False. Flags are not cleared here.

    Returns:
        True if at least one flag is (or becomes) set within the timeout,
        False on timeout.
    """
    events = [_flags[f] for f in flags if f in _flags]
    if not events:
        await asyncio.sleep(timeout)
        return False
    # Fast path: something is already dirty, no need to spawn waiter tasks.
    if any(e.is_set() for e in events):
        return True
    waiters = [asyncio.create_task(e.wait()) for e in events]
    try:
        done, _ = await asyncio.wait(
            waiters,
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        # Cancel every waiter that has not finished. This covers the losers
        # after the first flag fires, all of them on timeout, and all of them
        # if we are cancelled ourselves. Otherwise each call leaves tasks
        # parked on Event.wait() until that flag is next set — possibly
        # forever. cancel() on an already-done task is a no-op.
        for waiter in waiters:
            waiter.cancel()
    return bool(done)
# ── Listener task ──────────────────────────────────────────────────────────────
_listener_task: asyncio.Task | None = None
def _listen_dsn(url: str) -> str:
    """Turn a configured (possibly SQLAlchemy-style) URL into a DSN for asyncpg."""
    # asyncpg only understands the bare "postgresql://" scheme.
    if url.startswith("postgresql+asyncpg://"):
        url = url.replace("postgresql+asyncpg://", "postgresql://", 1)
    # The query string is dropped and SSL is configured explicitly instead
    # (presumably because some query params are not accepted by asyncpg —
    # TODO confirm).
    return url.split("?")[0]


def _listen_ssl_context(dsn: str, *source_urls: "str | None") -> "ssl.SSLContext | None":
    """Return an SSL context for Neon hosts or sslmode-bearing URLs, else None.

    SSL is enabled when *dsn* points at ``neon.tech`` or when any of
    *source_urls* (with their query strings intact) mentions ``sslmode``.
    ``None``/empty entries are skipped.
    """
    wants_ssl = "neon.tech" in dsn or any(
        url and "sslmode" in url for url in source_urls
    )
    if not wants_ssl:
        return None
    import ssl

    ctx = ssl.create_default_context()
    # NOTE(review): certificate and hostname verification are disabled, so the
    # link is encrypted but the server is not authenticated (MITM-able).
    # Consider enabling verification once the server's CA is confirmed.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx


async def _close_quietly(conn) -> None:
    """Best-effort close of a LISTEN connection; errors are logged, not raised."""
    if conn is None or conn.is_closed():
        return
    try:
        await conn.close()
    except Exception as e:
        logger.debug(f"[Notify] error while closing LISTEN connection: {e}")


async def _listen_loop():
    """Keep a dedicated LISTEN connection open and turn NOTIFYs into dirty flags.

    Subscribes to every channel in ``_PG_TO_FLAGS``. Each notification sets the
    mapped flags via ``mark_dirty``. Every 30 s a ``SELECT 1`` (10 s timeout)
    probes the connection. Any failure closes the connection and retries after
    5 s. Returns when cancelled, or immediately if no database URL is set.
    """
    settings = get_settings()
    # Prefer direct URL (required for Neon pooled setups); fall back to DATABASE_URL
    source_url = settings.database_direct_url or settings.database_url
    if not source_url:
        logger.error("[Notify] no database URL configured; LISTEN disabled")
        return
    dsn = _listen_dsn(source_url)
    # Checks both the URL actually used and DATABASE_URL, so every case the
    # previous DATABASE_URL-only check enabled SSL for still enables it.
    ssl_ctx = _listen_ssl_context(dsn, source_url, settings.database_url)

    def _on_notify(connection, pid, channel, payload):
        # asyncpg listener callback: (connection, pid, channel, payload).
        for flag in _PG_TO_FLAGS.get(channel, []):
            mark_dirty(flag)
            logger.debug(f"[Notify] dirty: {flag} (from {channel})")

    # Every flag some PG channel can raise; used to resync after connecting.
    pg_fed_flags = {flag for flags in _PG_TO_FLAGS.values() for flag in flags}

    while True:
        conn = None
        try:
            conn = await asyncpg.connect(dsn, ssl=ssl_ctx)
            logger.info("[Notify] LISTEN connection established")
            for ch in _PG_TO_FLAGS:
                await conn.add_listener(ch, _on_notify)
            # NOTIFY is only delivered to sessions that are listening. Anything
            # emitted before startup or during the outage preceding this
            # reconnect is lost. Mark every PG-fed flag dirty so consumers
            # re-query once instead of serving stale data indefinitely.
            for flag in pg_fed_flags:
                mark_dirty(flag)
            # Keep alive: periodic probe. A dead or hung socket raises (or
            # times out) and triggers a reconnect.
            while True:
                await asyncio.sleep(30)
                await conn.execute("SELECT 1", timeout=10)
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.warning(f"[Notify] connection lost: {e} β€” reconnecting in 5s")
        finally:
            # Close before backing off so a broken connection is not held
            # open during the retry delay.
            await _close_quietly(conn)
        await asyncio.sleep(5)
def start_listener():
    """Launch ``_listen_loop`` as a background task, at most once at a time.

    Requires a running event loop (uses ``asyncio.create_task``). If a
    previous listener task exists and has not finished, this is a no-op.
    """
    global _listener_task
    still_running = _listener_task is not None and not _listener_task.done()
    if not still_running:
        _listener_task = asyncio.create_task(_listen_loop())
        logger.info("[Notify] listener task started")
def stop_listener():
    """Request cancellation of the background LISTEN task and forget it.

    Cancellation is requested but not awaited; the module-level handle is
    reset to None either way.
    """
    global _listener_task
    task, _listener_task = _listener_task, None
    if task:
        task.cancel()