| import asyncio |
| import json |
| import logging |
| import aiohttp |
| import websockets |
| from typing import Callable, Optional |
| from cryptography.hazmat.primitives import hashes |
| from cryptography.hazmat.primitives.asymmetric import padding |
| from cryptography.hazmat.primitives import serialization |
| from cryptography.hazmat.backends import default_backend |
| import base64 |
| import time |
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
class KalshiClient:
    """Async client for Kalshi's REST API and orderbook WebSocket feed.

    Requests carry the API key id and, when a PEM-encoded RSA private key is
    supplied, a base64 RSA-PSS/SHA-256 signature over ``timestamp + method + path``.
    The WebSocket is re-opened with exponential backoff whenever it drops, and
    every ticker passed to :meth:`subscribe` is re-subscribed after a reconnect.
    """

    def __init__(self, rest_url: str, ws_url: str, api_key: str, private_key: Optional[str] = None):
        """
        Kalshi Client uses API Key (key_id) and an optional RSA private key for signing requests.
        For market data, public endpoints might be sufficient, but we implement signing for future trading.

        Args:
            rest_url: Base URL that REST paths such as ``/markets/{ticker}`` are appended to.
            ws_url: WebSocket endpoint URL.
            api_key: Kalshi API key id, sent as ``KALSHI-ACCESS-KEY``.
            private_key: Optional PEM-encoded, unencrypted RSA private key used for signing.
        """
        self.rest_url = rest_url
        self.ws_url = ws_url
        self.api_key = api_key
        self.private_key_str = private_key
        self.session: Optional[aiohttp.ClientSession] = None
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.on_message_callback: Optional[Callable] = None
        # Parsed private key, loaded on first use instead of on every signature.
        self._private_key = None
        # Every ticker passed to subscribe(); replayed after each (re)connect.
        self._subscribed_tickers: list[str] = []
        # Strong reference to the listener task. The event loop keeps only weak
        # references to tasks, so an unreferenced task can vanish mid-run.
        self._listen_task: Optional[asyncio.Task] = None
        # Counter for the "id" field of outgoing WebSocket commands.
        self._next_msg_id = 1
        # Set by close() so the listener stops instead of reconnecting.
        self._closing = False

    def sign_request(self, timestamp: str, method: str, path: str) -> str:
        """Sign ``timestamp + method + path`` and return the base64-encoded signature.

        Any query string is removed from ``path`` before signing. Returns ``""``
        when no private key is configured.
        """
        if not self.private_key_str:
            return ""
        if self._private_key is None:
            self._private_key = serialization.load_pem_private_key(
                self.private_key_str.encode("utf-8"),
                password=None,
                backend=default_backend(),
            )
        message = (timestamp + method + path.split("?", 1)[0]).encode("utf-8")
        # Kalshi's API-key signing scheme is RSA-PSS (MGF1-SHA256, salt length
        # equal to the digest length) over SHA-256, not PKCS#1 v1.5.
        signature = self._private_key.sign(
            message,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.DIGEST_LENGTH,
            ),
            hashes.SHA256(),
        )
        return base64.b64encode(signature).decode("utf-8")

    def _get_headers(self, method: str = "GET", path: str = "") -> dict[str, str]:
        """Build the Kalshi auth headers for a single request.

        ``KALSHI-ACCESS-SIGNATURE`` is included only when a private key is configured.
        """
        timestamp = str(int(time.time() * 1000))  # milliseconds since the epoch
        headers = {
            "KALSHI-ACCESS-KEY": self.api_key,
            "KALSHI-ACCESS-TIMESTAMP": timestamp,
        }
        if self.private_key_str:
            headers["KALSHI-ACCESS-SIGNATURE"] = self.sign_request(timestamp, method, path)
        return headers

    async def connect(self):
        """Establish HTTP session and Kalshi WebSocket connection."""
        self._closing = False
        if self.session is None or self.session.closed:
            # Only the key id is a safe session-wide default. A timestamp or
            # signature computed here would be stale for every later request
            # (and signed over an empty path), so each request signs itself.
            self.session = aiohttp.ClientSession(headers={"KALSHI-ACCESS-KEY": self.api_key})
        await self._connect_ws()

    async def _connect_ws(self):
        """Open, authenticate and re-subscribe the WebSocket, then start the listener.

        Retries forever (until close() is called) with exponential backoff:
        1s, doubling each failure, capped at 60s.
        """
        backoff = 1
        while not self._closing:
            try:
                logger.info(f"Connecting to Kalshi WS: {self.ws_url}")
                self.ws = await websockets.connect(self.ws_url)
                logger.info("Kalshi WS connected.")
                await self._authenticate_ws()
                await self._resubscribe()
            except Exception as e:
                logger.error(f"WebSocket connection failed: {e}. Retrying in {backoff}s...")
                # Drop a half-open socket (e.g. the auth send failed) before retrying.
                await self._close_ws()
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, 60)
            else:
                self._listen_task = asyncio.create_task(self._listen())
                return

    async def _authenticate_ws(self):
        """Send authentication frame immediately upon connection (no-op without a private key)."""
        if not self.private_key_str:
            return
        timestamp = str(int(time.time() * 1000))
        sig = self.sign_request(timestamp, "GET", "/trade-api/ws/v2")
        # NOTE(review): Kalshi's docs describe WS auth via the same KALSHI-ACCESS-*
        # headers sent on the handshake request; confirm the server accepts
        # this "auth" channel message.
        auth_msg = {
            "id": self._new_msg_id(),
            "cmd": "subscribe",
            "params": {
                "channels": ["auth"],
                "key_id": self.api_key,
                "signature": sig,
                "timestamp": timestamp,
            },
        }
        await self.ws.send(json.dumps(auth_msg))

    def _new_msg_id(self) -> int:
        """Return a fresh id for an outgoing WebSocket command."""
        msg_id = self._next_msg_id
        self._next_msg_id += 1
        return msg_id

    def set_callback(self, callback: Callable):
        """Register ``callback(source, data)``; it is awaited for every decoded WS message."""
        self.on_message_callback = callback

    async def subscribe(self, tickers: list[str]):
        """Subscribe to orderbook for specific Kalshi tickers.

        The tickers are remembered and re-sent after every reconnect. If no
        socket is open yet, they are only recorded and go out on the next connect.
        """
        for ticker in tickers:
            if ticker not in self._subscribed_tickers:
                self._subscribed_tickers.append(ticker)
        if self.ws is None:
            return
        await self._send_subscribe(tickers)

    async def _resubscribe(self):
        """Re-send the subscription for every ticker recorded so far."""
        if self._subscribed_tickers:
            await self._send_subscribe(list(self._subscribed_tickers))

    async def _send_subscribe(self, tickers: list[str]):
        """Send one ``orderbook_delta`` subscribe command for ``tickers``."""
        msg = {
            "id": self._new_msg_id(),
            "cmd": "subscribe",
            "params": {
                "channels": ["orderbook_delta"],
                "market_tickers": list(tickers),
            },
        }
        await self.ws.send(json.dumps(msg))
        logger.info(f"Subscribed to Kalshi markets: {tickers}")

    async def _listen(self):
        """Forward each decoded message to the callback; reconnect when the stream ends."""
        ws = self.ws
        try:
            async for message in ws:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError as e:
                    # One bad frame should not tear down the whole connection.
                    logger.warning(f"Ignoring malformed Kalshi WS message: {e}")
                    continue
                if self.on_message_callback:
                    try:
                        await self.on_message_callback("kalshi", data)
                    except Exception:
                        # A failing consumer is not a transport problem; keep listening.
                        logger.exception("Kalshi WS message callback raised")
            # A clean close ends the async-for without raising, so the reconnect
            # below must run on normal exit too, not only from the handlers.
            reason = "stream ended"
        except websockets.exceptions.ConnectionClosed as e:
            reason = f"connection closed ({e})"
        except Exception:
            logger.exception("Kalshi WS listen error")
            reason = "listen error"
        if self._closing:
            return
        logger.warning(f"Kalshi WS {reason}. Reconnecting...")
        await self._close_ws()
        await self._connect_ws()

    async def get_market(self, ticker: str):
        """Fetch ``/markets/{ticker}``.

        Returns:
            The decoded JSON body on HTTP 200, otherwise ``None``.

        Raises:
            RuntimeError: if :meth:`connect` has not been called (no HTTP session).
        """
        from urllib.parse import urlsplit

        if self.session is None:
            raise RuntimeError("KalshiClient.connect() must be called before get_market()")
        url = f"{self.rest_url}/markets/{ticker}"
        # Sign the path actually requested, including any prefix carried by
        # rest_url (e.g. /trade-api/v2), the same full-path form the WS auth signs.
        headers = self._get_headers("GET", urlsplit(url).path)
        async with self.session.get(url, headers=headers) as response:
            if response.status == 200:
                return await response.json()
            logger.warning(f"Kalshi GET {url} returned HTTP {response.status}")
            return None

    def normalize_price(self, cents: int) -> float:
        """Convert Kalshi cent-format to implied probability."""
        return cents / 100.0

    async def _close_ws(self):
        """Close and forget the current WebSocket; errors from a dead socket are only logged."""
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.debug(f"Ignoring error while closing Kalshi WS: {e}")

    async def close(self):
        """Stop reconnecting, cancel the listener, and close the WebSocket and HTTP session."""
        # Must be set before the socket closes; otherwise the listener would see
        # the close and immediately reconnect.
        self._closing = True
        task, self._listen_task = self._listen_task, None
        if task is not None and not task.done() and task is not asyncio.current_task():
            task.cancel()
        await self._close_ws()
        if self.session:
            await self.session.close()
            self.session = None
|
|