# Source: wayydb-api / api / pubsub.py
# (rcgalbo — "Deploy wayyDB to HuggingFace Spaces", commit bf20cb7)
"""
WayyDB PubSub Abstraction Layer
Provides a pluggable pub/sub transport for real-time tick distribution.
Two backends:
- InMemoryPubSub: Default, zero dependencies, single-process
- RedisPubSub: Optional, requires redis-py, multi-process capable
Configure via REDIS_URL environment variable:
- Not set or empty: uses InMemoryPubSub
- Set to redis://...: uses RedisPubSub
Channel naming convention:
ticks:{symbol} - Trade ticks for a symbol
quotes:{symbol} - Quote updates for a symbol
ticks:* - All trade ticks
{table}:{symbol} - Generic table:symbol pattern
"""
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
# Type alias for async callback
AsyncCallback = Callable[[dict], Coroutine[Any, Any, None]]
@dataclass
class Message:
    """A pub/sub message with metadata.

    Stored in per-channel ring buffers so late subscribers can replay
    recent traffic (see InMemoryPubSub.get_channel_buffer).
    """

    channel: str  # channel this message was published on, e.g. "ticks:AAPL"
    data: dict  # caller-supplied payload
    sequence: int  # per-channel, monotonically increasing sequence number
    timestamp: float = field(default_factory=time.time)  # publish time, epoch seconds
class PubSubBackend(ABC):
    """Abstract pub/sub backend interface.

    Implementations must provide publish, subscribe, and unsubscribe.
    This abstraction allows swapping between in-memory, Redis, NATS, etc.
    Backends are created idle: callers await start() before publishing
    and stop() on shutdown.
    """

    @abstractmethod
    async def publish(self, channel: str, data: dict) -> int:
        """Publish a message to a channel.

        Args:
            channel: Channel name (e.g., "ticks:AAPL")
            data: Message payload
        Returns:
            Sequence number of the published message
        """
        ...

    @abstractmethod
    async def subscribe(
        self,
        channel: str,
        callback: AsyncCallback,
        subscriber_id: str = "",
    ) -> None:
        """Subscribe to a channel with a callback.

        Args:
            channel: Channel name or pattern (e.g., "ticks:AAPL" or "ticks:*")
            callback: Async function called with each message dict
            subscriber_id: Unique identifier for this subscriber
        """
        ...

    @abstractmethod
    async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None:
        """Unsubscribe from a channel.

        Args:
            channel: Channel name or pattern
            subscriber_id: The subscriber to remove
        """
        ...

    @abstractmethod
    async def publish_batch(self, channel: str, messages: List[dict]) -> int:
        """Publish a batch of messages to a channel.

        Args:
            channel: Channel name
            messages: List of message payloads
        Returns:
            Sequence number of the last message
        """
        ...

    @abstractmethod
    def get_stats(self) -> dict:
        """Get pub/sub statistics (counters plus backend-specific fields)."""
        ...

    @abstractmethod
    async def start(self) -> None:
        """Start the backend (connect, initialize)."""
        ...

    @abstractmethod
    async def stop(self) -> None:
        """Stop the backend (disconnect, cleanup)."""
        ...
class InMemoryPubSub(PubSubBackend):
    """In-process pub/sub using asyncio.

    Features:
    - Channel-based routing with trailing-wildcard support ("ticks:*")
    - Per-channel monotonically increasing sequence numbers
    - Ring buffer per channel for backpressure (drops oldest on overflow)
    - Concurrent broadcast via asyncio.gather with a per-subscriber timeout
    - Message replay from buffer via get_channel_buffer()

    Subscribers that time out or raise are treated as dead and removed
    from every channel they are registered on.
    """

    def __init__(
        self,
        max_buffer_per_channel: int = 10000,
        broadcast_timeout: float = 5.0,
    ):
        """
        Args:
            max_buffer_per_channel: Ring buffer capacity per channel.
            broadcast_timeout: Seconds to wait on each subscriber callback
                before declaring it dead.
        """
        # channel (or pattern) -> {subscriber_id -> callback}
        self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict)
        # channel -> last assigned sequence number
        self._sequence: Dict[str, int] = defaultdict(int)
        # channel -> deque[Message] ring buffer used for replay
        self._buffers: Dict[str, deque] = {}
        self._max_buffer = max_buffer_per_channel
        self._broadcast_timeout = broadcast_timeout
        self._stats = {
            "messages_published": 0,
            "messages_delivered": 0,
            "messages_dropped": 0,
            "active_subscriptions": 0,
            "channels": 0,
        }
        self._running = False

    async def start(self) -> None:
        """Mark the backend as running; no external resources to acquire."""
        self._running = True
        logger.info("InMemoryPubSub started")

    async def stop(self) -> None:
        """Drop all subscribers and buffered messages."""
        self._running = False
        self._subscribers.clear()
        self._buffers.clear()
        logger.info("InMemoryPubSub stopped")

    def _recount_subscriptions(self) -> None:
        """Recompute the active-subscription gauge after a membership change."""
        self._stats["active_subscriptions"] = sum(
            len(subs) for subs in self._subscribers.values()
        )

    async def publish(self, channel: str, data: dict) -> int:
        """Buffer a message for replay, then fan it out to subscribers.

        Args:
            channel: Channel name (e.g., "ticks:AAPL")
            data: Message payload
        Returns:
            The per-channel sequence number assigned to this message.
        """
        self._sequence[channel] += 1
        seq = self._sequence[channel]
        msg = Message(channel=channel, data=data, sequence=seq)
        # Buffer the message. deque(maxlen=...) silently evicts the oldest
        # entry on overflow, so count the drop before appending.
        if channel not in self._buffers:
            self._buffers[channel] = deque(maxlen=self._max_buffer)
        buf = self._buffers[channel]
        if len(buf) >= self._max_buffer:
            self._stats["messages_dropped"] += 1
        buf.append(msg)
        self._stats["messages_published"] += 1
        self._stats["channels"] = len(self._buffers)
        # Deliver to subscribers
        await self._deliver(channel, data, seq)
        return seq

    async def publish_batch(self, channel: str, messages: List[dict]) -> int:
        """Publish messages in order.

        Returns:
            Sequence number of the last message, or 0 for an empty batch.
        """
        last_seq = 0
        for data in messages:
            last_seq = await self.publish(channel, data)
        return last_seq

    async def subscribe(
        self,
        channel: str,
        callback: AsyncCallback,
        subscriber_id: str = "",
    ) -> None:
        """Register a callback for a channel or trailing-wildcard pattern.

        Args:
            channel: Channel name or pattern (e.g., "ticks:AAPL" or "ticks:*")
            callback: Async function called with each enriched message dict
            subscriber_id: Unique id; auto-derived from the callback if empty
        """
        if not subscriber_id:
            subscriber_id = f"sub_{id(callback)}"
        self._subscribers[channel][subscriber_id] = callback
        self._recount_subscriptions()
        logger.debug(f"Subscribed {subscriber_id} to {channel}")

    async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None:
        """Remove one subscriber (or, with an empty id, all) from a channel.

        Args:
            channel: Channel name or pattern
            subscriber_id: The subscriber to remove; empty clears the channel
        """
        if channel in self._subscribers:
            if subscriber_id and subscriber_id in self._subscribers[channel]:
                del self._subscribers[channel][subscriber_id]
            elif not subscriber_id:
                self._subscribers[channel].clear()
            # Drop empty channel entries so wildcard scans stay small.
            if not self._subscribers[channel]:
                del self._subscribers[channel]
        self._recount_subscriptions()

    async def _deliver(self, channel: str, data: dict, sequence: int) -> None:
        """Deliver a message to all matching subscribers concurrently.

        The payload is enriched with "_seq" and "_channel" metadata keys.
        Subscribers that time out or raise are removed everywhere.
        """
        enriched = {**data, "_seq": sequence, "_channel": channel}
        # Collect all matching callbacks
        callbacks: List[AsyncCallback] = []
        # Exact match subscribers
        if channel in self._subscribers:
            callbacks.extend(self._subscribers[channel].values())
        # Wildcard subscribers (e.g., "ticks:*" matches "ticks:AAPL")
        for pattern, subs in self._subscribers.items():
            if pattern.endswith(":*"):
                prefix = pattern[:-1]  # "ticks:*" -> "ticks:"
                if channel.startswith(prefix) and pattern != channel:
                    callbacks.extend(subs.values())
        if not callbacks:
            return
        # Concurrent delivery with timeout
        dead_callbacks: List[AsyncCallback] = []

        async def safe_deliver(cb: AsyncCallback) -> None:
            try:
                await asyncio.wait_for(cb(enriched), timeout=self._broadcast_timeout)
                self._stats["messages_delivered"] += 1
            except asyncio.TimeoutError:
                logger.warning(f"Subscriber timed out on {channel}")
                dead_callbacks.append(cb)
            except Exception:
                # Previously swallowed silently; log so failing subscribers
                # are visible before being dropped.
                logger.exception(f"Subscriber raised on {channel}; removing it")
                dead_callbacks.append(cb)

        await asyncio.gather(*(safe_deliver(cb) for cb in callbacks))
        # Remove dead subscribers from every channel/pattern they joined.
        for dead_cb in dead_callbacks:
            for pattern, subs in list(self._subscribers.items()):
                to_remove = [
                    sid for sid, cb in subs.items() if cb is dead_cb
                ]
                for sid in to_remove:
                    del subs[sid]
                    logger.debug(f"Removed dead subscriber {sid} from {pattern}")
        if dead_callbacks:
            self._recount_subscriptions()

    def get_channel_buffer(self, channel: str, since_seq: int = 0) -> List[Message]:
        """Get buffered messages for replay.

        Args:
            channel: Channel name
            since_seq: Only return messages with sequence > since_seq
        Returns:
            List of messages for replay (oldest first)
        """
        if channel not in self._buffers:
            return []
        return [m for m in self._buffers[channel] if m.sequence > since_seq]

    def get_stats(self) -> dict:
        """Return counters plus per-channel buffer occupancy."""
        return {
            "backend": "in_memory",
            **self._stats,
            "buffer_sizes": {ch: len(buf) for ch, buf in self._buffers.items()},
        }
class RedisPubSub(PubSubBackend):
    """Redis-backed pub/sub for multi-process deployments.

    Real-time delivery uses Redis pub/sub on "wayy:<channel>"; each
    message is also appended to a capped Redis Stream
    ("wayy:stream:<channel>") for persistence and replay.

    Requires: pip install redis[hiredis]
    Configure via REDIS_URL environment variable.

    NOTE(review): "_seq" numbers are assigned per-process, so with multiple
    publisher processes a channel will carry overlapping sequence numbers —
    confirm whether consumers rely on a globally unique ordering.
    """

    def __init__(self, redis_url: str, max_stream_len: int = 100000):
        """
        Args:
            redis_url: Redis connection URL (redis://...).
            max_stream_len: Approximate max entries kept per replay stream.
        """
        self._redis_url = redis_url
        self._max_stream_len = max_stream_len
        self._redis = None  # redis.asyncio.Redis; created in start()
        self._pubsub = None  # redis.asyncio PubSub; created in start()
        # channel or pattern -> {subscriber_id -> callback}; local process only
        self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict)
        # channel -> last assigned (per-process) sequence number
        self._sequence: Dict[str, int] = defaultdict(int)
        self._listener_task: Optional[asyncio.Task] = None
        self._running = False
        self._stats = {
            "messages_published": 0,
            "messages_delivered": 0,
            "messages_dropped": 0,
            "active_subscriptions": 0,
            "channels": 0,
            "redis_connected": False,
        }

    async def start(self) -> None:
        """Connect to Redis, attach pre-registered channels, start listening.

        Raises:
            ImportError: If the redis package is not installed.
        """
        try:
            import redis.asyncio as aioredis
        except ImportError:
            raise ImportError(
                "redis package required for RedisPubSub. "
                "Install with: pip install redis[hiredis]"
            )
        self._redis = aioredis.from_url(
            self._redis_url,
            decode_responses=True,
            socket_connect_timeout=5,
            retry_on_timeout=True,
        )
        # Test connection — fail fast if Redis is unreachable.
        await self._redis.ping()
        self._stats["redis_connected"] = True
        self._pubsub = self._redis.pubsub()
        # Fix: attach any channels registered via subscribe() before start().
        # Previously subscribe() silently skipped the Redis SUBSCRIBE when
        # _pubsub was still None, so those subscriptions never received data.
        for channel in list(self._subscribers):
            if channel.endswith(":*"):
                await self._pubsub.psubscribe(f"wayy:{channel}")
            else:
                await self._pubsub.subscribe(f"wayy:{channel}")
        self._running = True
        self._listener_task = asyncio.create_task(self._listen_loop())
        logger.info(f"RedisPubSub connected to {self._redis_url}")

    async def stop(self) -> None:
        """Cancel the listener and tear down the Redis connections."""
        self._running = False
        if self._listener_task:
            self._listener_task.cancel()
            try:
                await self._listener_task
            except asyncio.CancelledError:
                pass
        if self._pubsub:
            await self._pubsub.unsubscribe()
            await self._pubsub.close()
        if self._redis:
            await self._redis.close()
        self._stats["redis_connected"] = False
        logger.info("RedisPubSub stopped")

    async def publish(self, channel: str, data: dict) -> int:
        """Publish one message to Redis pub/sub and the replay stream.

        Args:
            channel: Channel name (e.g., "ticks:AAPL")
            data: Message payload
        Returns:
            The per-process sequence number assigned to this message.
        """
        import json
        self._sequence[channel] += 1
        seq = self._sequence[channel]
        enriched = {**data, "_seq": seq, "_ts": time.time()}
        payload = json.dumps(enriched)
        # Publish to Redis pub/sub channel
        await self._redis.publish(f"wayy:{channel}", payload)
        # Also write to Redis Stream for persistence/replay (capped length)
        stream_key = f"wayy:stream:{channel}"
        await self._redis.xadd(
            stream_key,
            {"data": payload},
            maxlen=self._max_stream_len,
        )
        self._stats["messages_published"] += 1
        return seq

    async def publish_batch(self, channel: str, messages: List[dict]) -> int:
        """Publish a batch via a single pipeline round-trip.

        Args:
            channel: Channel name
            messages: List of message payloads
        Returns:
            Sequence number of the last message, or 0 for an empty batch.
        """
        import json
        pipe = self._redis.pipeline()
        last_seq = 0
        for data in messages:
            self._sequence[channel] += 1
            seq = self._sequence[channel]
            last_seq = seq
            enriched = {**data, "_seq": seq, "_ts": time.time()}
            payload = json.dumps(enriched)
            pipe.publish(f"wayy:{channel}", payload)
            stream_key = f"wayy:stream:{channel}"
            pipe.xadd(stream_key, {"data": payload}, maxlen=self._max_stream_len)
        await pipe.execute()
        self._stats["messages_published"] += len(messages)
        return last_seq

    async def subscribe(
        self,
        channel: str,
        callback: AsyncCallback,
        subscriber_id: str = "",
    ) -> None:
        """Register a local callback and, for a new channel, subscribe on Redis.

        Args:
            channel: Channel name or pattern (e.g., "ticks:AAPL" or "ticks:*")
            callback: Async function called with each message dict
            subscriber_id: Unique id; auto-derived from the callback if empty
        """
        if not subscriber_id:
            subscriber_id = f"sub_{id(callback)}"
        is_new_channel = channel not in self._subscribers or not self._subscribers[channel]
        self._subscribers[channel][subscriber_id] = callback
        # Only the first local subscriber triggers the Redis-side subscribe.
        # If start() has not run yet (_pubsub is None), start() attaches
        # every registered channel when it connects.
        if is_new_channel and self._pubsub:
            if channel.endswith(":*"):
                await self._pubsub.psubscribe(f"wayy:{channel}")
            else:
                await self._pubsub.subscribe(f"wayy:{channel}")
        self._stats["active_subscriptions"] = sum(
            len(subs) for subs in self._subscribers.values()
        )
        self._stats["channels"] = len(self._subscribers)

    async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None:
        """Remove one subscriber (or, with an empty id, all) from a channel.

        The Redis-side unsubscribe only happens when the last local
        subscriber for the channel is gone.

        Args:
            channel: Channel name or pattern
            subscriber_id: The subscriber to remove; empty clears the channel
        """
        if channel in self._subscribers:
            if subscriber_id and subscriber_id in self._subscribers[channel]:
                del self._subscribers[channel][subscriber_id]
            elif not subscriber_id:
                self._subscribers[channel].clear()
            if not self._subscribers[channel]:
                del self._subscribers[channel]
                if self._pubsub:
                    if channel.endswith(":*"):
                        await self._pubsub.punsubscribe(f"wayy:{channel}")
                    else:
                        await self._pubsub.unsubscribe(f"wayy:{channel}")
        self._stats["active_subscriptions"] = sum(
            len(subs) for subs in self._subscribers.values()
        )

    async def _listen_loop(self) -> None:
        """Background task that pumps Redis pub/sub messages to local subscribers."""
        import json
        while self._running:
            try:
                if not self._pubsub.subscribed:
                    # get_message() errors when no channels are subscribed;
                    # idle quietly until the first subscribe() arrives instead
                    # of spamming the error log once per second.
                    await asyncio.sleep(0.1)
                    continue
                message = await self._pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=0.1
                )
                if message is None:
                    await asyncio.sleep(0.01)
                    continue
                if message["type"] not in ("message", "pmessage"):
                    continue
                # For pmessage, "channel" holds the concrete channel name.
                raw_channel = message.get("channel", "")
                # Strip "wayy:" prefix
                if raw_channel.startswith("wayy:"):
                    channel = raw_channel[5:]
                else:
                    channel = raw_channel
                data = json.loads(message["data"])
                # Deliver to local subscribers
                await self._deliver_local(channel, data)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Redis listener error: {e}")
                await asyncio.sleep(1.0)

    async def _deliver_local(self, channel: str, data: dict) -> None:
        """Deliver a received message to local subscribers (exact + wildcard)."""
        callbacks: List[AsyncCallback] = []
        if channel in self._subscribers:
            callbacks.extend(self._subscribers[channel].values())
        # Wildcard matching: "ticks:*" matches any "ticks:<suffix>"
        for pattern, subs in self._subscribers.items():
            if pattern.endswith(":*"):
                prefix = pattern[:-1]
                if channel.startswith(prefix) and pattern != channel:
                    callbacks.extend(subs.values())
        for cb in callbacks:
            try:
                await asyncio.wait_for(cb(data), timeout=5.0)
                self._stats["messages_delivered"] += 1
            except Exception:
                # Best-effort delivery: count the drop and keep the subscriber,
                # but leave a trace for debugging.
                logger.debug(f"Dropped message for a subscriber on {channel}")
                self._stats["messages_dropped"] += 1

    async def replay(
        self, channel: str, since_id: str = "0-0", count: int = 1000
    ) -> List[dict]:
        """Replay messages from the channel's Redis Stream.

        Args:
            channel: Channel name
            since_id: Redis Stream ID to start from
            count: Maximum messages to return
        Returns:
            List of message dicts
        NOTE(review): XRANGE's min bound is inclusive, so the entry at
        since_id itself is returned again; callers resuming from a seen id
        should expect (or de-duplicate) that first entry.
        """
        import json
        stream_key = f"wayy:stream:{channel}"
        messages = await self._redis.xrange(stream_key, min=since_id, count=count)
        return [json.loads(entry["data"]) for _id, entry in messages]

    def get_stats(self) -> dict:
        """Return counters plus the Redis URL with any credentials redacted."""
        return {
            "backend": "redis",
            # "scheme://user:pass@host:port" -> "host:port"
            "redis_url": self._redis_url.split("@")[-1] if "@" in self._redis_url else self._redis_url,
            **self._stats,
        }
def create_pubsub(redis_url: Optional[str] = None) -> PubSubBackend:
    """Factory function to create the appropriate PubSub backend.

    Args:
        redis_url: Redis connection URL. If None/empty, uses InMemoryPubSub.
    Returns:
        PubSubBackend instance. The backend is created idle; callers
        await .start() before publishing.
    """
    if redis_url:
        # Fixed F541: this log message had a stray f-prefix with no placeholders.
        logger.info("Using RedisPubSub backend")
        return RedisPubSub(redis_url=redis_url)
    else:
        logger.info("Using InMemoryPubSub backend (set REDIS_URL for Redis)")
        return InMemoryPubSub()