"""Database persistence for futures market snapshots and OHLC candles.""" from __future__ import annotations import logging from datetime import datetime, timedelta from typing import Any, Dict, List, Optional from sqlalchemy import and_, desc, or_ from database.db_manager import DatabaseManager from database.models import CachedFuturesOHLC, CachedFuturesSnapshot logger = logging.getLogger(__name__) class FuturesCacheQueries: def __init__(self, db_manager: DatabaseManager): self.db = db_manager def save_snapshot(self, payload: Dict[str, Any]) -> bool: if not payload.get("success"): return False try: with self.db.get_session() as session: record = CachedFuturesSnapshot( symbol=str(payload.get("symbol") or "").upper(), contract=str(payload.get("contract") or ""), exchange=str(payload.get("exchange") or payload.get("provider") or "unknown"), mark_price=float(payload.get("mark_price") or 0), index_price=float(payload["index_price"]) if payload.get("index_price") is not None else None, funding_rate=float(payload["funding_rate"]) if payload.get("funding_rate") is not None else None, open_interest=float(payload["open_interest"]) if payload.get("open_interest") is not None else None, volume_24h=float(payload["volume_24h"]) if payload.get("volume_24h") is not None else None, price_change_24h=float(payload["price_change_24h"]) if payload.get("price_change_24h") is not None else None, provider=str(payload.get("provider") or "unknown"), fetched_at=datetime.utcnow(), ) session.add(record) logger.info("Saved futures snapshot %s (%s)", record.symbol, record.contract) return True except Exception as exc: logger.error("save_snapshot failed: %s", exc, exc_info=True) return False def save_ohlc_batch(self, payload: Dict[str, Any]) -> int: if not payload.get("success"): return 0 saved = 0 contract = str(payload.get("contract") or "") symbol = str(payload.get("symbol") or "").upper() exchange = str(payload.get("exchange") or payload.get("provider") or "unknown") interval = str(payload.get("interval") or "1h") provider = str(payload.get("provider") or exchange) try: with self.db.get_session() as session: for candle in payload.get("candles") or []: ts = candle.get("timestamp") if isinstance(ts, (int, float)): ts = datetime.utcfromtimestamp(int(ts) / 1000 if ts > 1e12 else int(ts)) if not isinstance(ts, datetime): continue existing = session.query(CachedFuturesOHLC).filter( and_( CachedFuturesOHLC.contract == contract, CachedFuturesOHLC.interval == interval, CachedFuturesOHLC.timestamp == ts, ) ).first() if existing: existing.open = float(candle["open"]) existing.high = float(candle["high"]) existing.low = float(candle["low"]) existing.close = float(candle["close"]) existing.volume = float(candle.get("volume") or 0) existing.provider = provider existing.fetched_at = datetime.utcnow() else: session.add( CachedFuturesOHLC( contract=contract, symbol=symbol, exchange=exchange, interval=interval, timestamp=ts, open=float(candle["open"]), high=float(candle["high"]), low=float(candle["low"]), close=float(candle["close"]), volume=float(candle.get("volume") or 0), provider=provider, fetched_at=datetime.utcnow(), ) ) saved += 1 logger.info("Saved %s futures OHLC candles for %s %s", saved, contract, interval) return saved except Exception as exc: logger.error("save_ohlc_batch failed: %s", exc, exc_info=True) return saved def get_latest_snapshot(self, symbol: str) -> Optional[Dict[str, Any]]: sym = symbol.upper() try: with self.db.get_session() as session: row = ( session.query(CachedFuturesSnapshot) .filter(CachedFuturesSnapshot.symbol == sym) .order_by(desc(CachedFuturesSnapshot.fetched_at)) .first() ) if not row: return None return { "symbol": row.symbol, "contract": row.contract, "exchange": row.exchange, "mark_price": row.mark_price, "index_price": row.index_price, "funding_rate": row.funding_rate, "open_interest": row.open_interest, "volume_24h": row.volume_24h, "price_change_24h": row.price_change_24h, "provider": row.provider, "fetched_at": row.fetched_at.isoformat() if row.fetched_at else None, "source": "database", } except Exception as exc: logger.error("get_latest_snapshot failed: %s", exc) return None def get_cached_ohlc(self, contract: str, interval: str = "1h", limit: int = 200, symbol: Optional[str] = None) -> List[Dict[str, Any]]: try: sym = (symbol or "").upper() with self.db.get_session() as session: contract_filter = CachedFuturesOHLC.contract == contract if sym: contract_filter = or_( CachedFuturesOHLC.contract == contract, CachedFuturesOHLC.symbol == sym, CachedFuturesOHLC.contract == f"{sym}USDT", ) rows = ( session.query(CachedFuturesOHLC) .filter( and_( contract_filter, CachedFuturesOHLC.interval == interval, ) ) .order_by(desc(CachedFuturesOHLC.timestamp)) .limit(limit) .all() ) out = [] for row in reversed(rows): out.append({ "timestamp": row.timestamp.isoformat() if row.timestamp else None, "open": row.open, "high": row.high, "low": row.low, "close": row.close, "volume": row.volume, "provider": row.provider, "exchange": row.exchange, }) return out except Exception as exc: logger.error("get_cached_ohlc failed: %s", exc) return [] _futures_cache: Optional[FuturesCacheQueries] = None def get_futures_cache_queries(db_manager: Optional[DatabaseManager] = None) -> FuturesCacheQueries: global _futures_cache if _futures_cache is None: from database.db_manager import db_manager as default_db _futures_cache = FuturesCacheQueries(db_manager or default_db) return _futures_cache