Spaces:
Paused
Paused
| """ | |
| app/deps.py - SRE-Ready Dependency Injection | |
| Critical improvements: | |
| β True tenant isolation: Each org gets its own vector DB file | |
| β SRE observability: Metrics, connection pooling, health checks | |
| β Backward compatible: Falls back to shared DB if org_id not provided | |
| β HNSW index: Automatic creation for 100x faster vector search | |
| β Circuit breakers: Prevents DB connection exhaustion | |
| """ | |
# Standard library imports
import logging
import os
import pathlib
import secrets
import threading
import time
from collections import defaultdict
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

# Type checking imports
if TYPE_CHECKING:
    try:
        pass
    except Exception:
        pass

# Third-party imports
import duckdb
from fastapi import HTTPException, Header
from upstash_redis import Redis
# ── Configuration ────────────────────────────────────────────────────────────
# Base directory for per-org transactional DuckDB files (created at import
# time; mkdir is idempotent via exist_ok).
DATA_DIR = pathlib.Path("./data/duckdb")
DATA_DIR.mkdir(parents=True, exist_ok=True)
# Base directory for per-org vector DB files (one file per org; see
# get_vector_db for the per-tenant isolation scheme).
VECTOR_DB_DIR = DATA_DIR / "vectors"
VECTOR_DB_DIR.mkdir(parents=True, exist_ok=True)
# Module-level logger
logger = logging.getLogger(__name__)
# ── SRE: Global Metrics Registry ─────────────────────────────────────────────
# Prometheus-ready, in-process metrics collection (free tier compatible).
_metrics_registry = {
    "db_connections_total": defaultdict(int),   # total connections per org
    "db_connection_errors": defaultdict(int),   # errors keyed "org:error_type"
    "db_query_duration_ms": defaultdict(list),  # latency samples keyed "org:operation"
    "vector_db_size_bytes": defaultdict(int),   # vector DB file size per org
}


def track_connection(org_id: str) -> None:
    """Record one DB connection for *org_id*.

    NOTE: despite the original docstring, this is a plain counter helper,
    not a decorator.
    """
    _metrics_registry["db_connections_total"][org_id] += 1


def track_error(org_id: str, error_type: str) -> None:
    """Record one error of *error_type* for *org_id*."""
    _metrics_registry["db_connection_errors"][f"{org_id}:{error_type}"] += 1


def timing_metric(org_id: str, operation: str):
    """Decorator factory: time a DB operation and record its latency.

    On success the duration (ms) is appended to the "org:operation" latency
    list; on failure an "<operation>_error" is tracked and the exception is
    re-raised.
    """
    def decorator(func: Callable) -> Callable:
        # Fix: the wrapper previously lost the wrapped function's metadata
        # (__name__/__doc__); functools.wraps was imported but unused.
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                result = func(*args, **kwargs)
            except Exception:
                track_error(org_id, f"{operation}_error")
                raise
            duration_ms = (time.time() - start) * 1000
            _metrics_registry["db_query_duration_ms"][f"{org_id}:{operation}"].append(duration_ms)
            return result
        return wrapper
    return decorator


def get_sre_metrics() -> Dict[str, Any]:
    """Snapshot metrics for health checks and Prometheus scraping."""
    return {
        "connections": dict(_metrics_registry["db_connections_total"]),
        "errors": dict(_metrics_registry["db_connection_errors"]),
        "avg_latency_ms": {
            key: (sum(samples) / len(samples)) if samples else 0
            for key, samples in _metrics_registry["db_query_duration_ms"].items()
        },
        "vector_db_sizes": dict(_metrics_registry["vector_db_size_bytes"]),
        "total_orgs": len(_metrics_registry["vector_db_size_bytes"]),
    }
# ── Secrets Management ───────────────────────────────────────────────────────
def get_secret(name: str, required: bool = True) -> Optional[str]:
    """Fetch secret *name* from the environment.

    Raises:
        ValueError: when *required* is True and the variable is missing or
            blank. Otherwise returns the value (possibly None).
    """
    value = os.getenv(name)
    if required and (not value or value.strip() == ""):
        raise ValueError(f"π΄ CRITICAL: Required secret '{name}' not found")
    return value


# API keys: optional at import time.
# Fix: the previous expression evaluated get_secret("API_KEYS") with the
# default required=True inside the condition, which raised before the
# `else []` fallback could ever apply (and called the function twice).
_raw_api_keys = get_secret("API_KEYS", required=False)
API_KEYS = _raw_api_keys.split(",") if _raw_api_keys else []

# Hugging Face token (optional)
HF_API_TOKEN = get_secret("HF_API_TOKEN", required=False)

# Redis configuration (optional; see get_redis() for the fallback order)
REDIS_URL = get_secret("UPSTASH_REDIS_REST_URL", required=False)
REDIS_TOKEN = get_secret("UPSTASH_REDIS_REST_TOKEN", required=False)

# QStash token (optional)
QSTASH_TOKEN = get_secret("QSTASH_TOKEN", required=False)
# ── DuckDB Connection Pool & Tenant Isolation ────────────────────────────────
# Annotations quoted so this module section does not require duckdb at
# annotation-evaluation time.
_org_db_connections: Dict[str, "duckdb.DuckDBPyConnection"] = {}
_vector_db_connections: Dict[str, "duckdb.DuckDBPyConnection"] = {}
_connection_lock = threading.Lock()


def _validate_org_id(org_id) -> None:
    """Reject org_ids that are missing, non-string, or path-unsafe.

    org_id is interpolated into a filename and often arrives from an HTTP
    header, so it must be treated as untrusted input: reject path
    separators and parent-directory references to prevent writing DB files
    outside DATA_DIR.
    """
    if not org_id or not isinstance(org_id, str):
        raise ValueError(f"Invalid org_id: {org_id}")
    if "/" in org_id or "\\" in org_id or ".." in org_id or org_id != org_id.strip():
        raise ValueError(f"Invalid org_id: {org_id}")


def get_duckdb(org_id: str) -> "duckdb.DuckDBPyConnection":
    """Return the cached tenant-isolated transactional DuckDB connection.

    Each org gets its own file: ./data/duckdb/{org_id}.duckdb

    Raises:
        ValueError: if org_id is missing, non-string, or path-unsafe.
    """
    _validate_org_id(org_id)
    with _connection_lock:
        if org_id not in _org_db_connections:
            db_file = DATA_DIR / f"{org_id}.duckdb"
            logger.info(f"[DB] π Connecting transactional DB for org: {org_id}")
            try:
                conn = duckdb.connect(str(db_file), read_only=False)
                # Enable the VSS extension for vector search support
                conn.execute("INSTALL vss;")
                conn.execute("LOAD vss;")
                # Create the schemas the app expects
                conn.execute("CREATE SCHEMA IF NOT EXISTS main")
                conn.execute("CREATE SCHEMA IF NOT EXISTS vector_store")
                _org_db_connections[org_id] = conn
                track_connection(org_id)
            except Exception as e:
                track_error(org_id, "db_connect_error")
                logger.error(f"[DB] β Failed to connect: {e}")
                raise
        return _org_db_connections[org_id]
def get_vector_db(org_id: Optional[str] = None) -> "duckdb.DuckDBPyConnection":
    """Return the cached per-org vector DB connection (true tenant isolation).

    Production callers should ALWAYS pass org_id. Passing None falls back to
    a shared legacy DB file for backward compatibility.

    Raises:
        ValueError: if org_id is non-string, empty, or path-unsafe.
    """
    # Legacy fallback mode (kept for compatibility)
    if org_id is None:
        org_id = "_shared_legacy"
        logger.warning("[VECTOR_DB] β οΈ Using shared DB (legacy mode) - not recommended")
    # org_id becomes a filename component and may come from an HTTP header:
    # reject empty strings (previously accepted, creating ".duckdb") and
    # path-traversal attempts.
    if not org_id or not isinstance(org_id, str):
        raise ValueError(f"Invalid org_id: {org_id}")
    if "/" in org_id or "\\" in org_id or ".." in org_id:
        raise ValueError(f"Invalid org_id: {org_id}")
    with _connection_lock:
        if org_id not in _vector_db_connections:
            # Per-org DB file: ./data/duckdb/vectors/{org_id}.duckdb
            db_file = VECTOR_DB_DIR / f"{org_id}.duckdb"
            logger.info(f"[VECTOR_DB] π Connecting vector DB for org: {org_id}")
            try:
                conn = duckdb.connect(str(db_file), read_only=False)
                # Enable the VSS extension
                conn.execute("INSTALL vss;")
                conn.execute("LOAD vss;")
                conn.execute("CREATE SCHEMA IF NOT EXISTS vector_store")
                # 384-dim embeddings table (dimension fixed by the schema below)
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS vector_store.embeddings (
                        id VARCHAR PRIMARY KEY,
                        org_id VARCHAR NOT NULL,
                        content TEXT,
                        embedding FLOAT[384],
                        entity_type VARCHAR,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                # HNSW index for cosine-similarity search; failure is
                # non-fatal (search still works, just slower).
                try:
                    conn.execute("""
                        CREATE INDEX IF NOT EXISTS idx_embedding_hnsw
                        ON vector_store.embeddings
                        USING HNSW (embedding)
                        WITH (metric = 'cosine')
                    """)
                    logger.info(f"[VECTOR_DB] β HNSW index created for org: {org_id}")
                except Exception as e:
                    logger.warning(f"[VECTOR_DB] β οΈ Could not create HNSW index: {e}")
                _vector_db_connections[org_id] = conn
                track_connection(org_id)
                # SRE: record on-disk size for capacity metrics
                if db_file.exists():
                    _metrics_registry["vector_db_size_bytes"][org_id] = db_file.stat().st_size
            except Exception as e:
                track_error(org_id, "vector_db_connect_error")
                logger.error(f"[VECTOR_DB] β Failed to connect: {e}")
                raise
        return _vector_db_connections[org_id]
# ── Redis Client (self hosted TCP + Upstash Compatible) ──────────────────────
_redis_client = None
_redis_lock = threading.Lock()


def get_redis():
    """Return a cached Redis client, chosen by priority:

    1. Self-hosted TCP Redis (REDIS_URL, default redis://localhost:6379)
    2. Upstash HTTP Redis (UPSTASH_REDIS_REST_URL/TOKEN) — fallback only
    3. unittest.mock.Mock — last-resort stub for local dev
    """
    global _redis_client
    with _redis_lock:
        if _redis_client is not None:
            return _redis_client
        # 1. Self-hosted Redis (HF Spaces)
        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
        if redis_url.startswith("redis://"):
            try:
                import redis as redis_py
                _redis_client = redis_py.from_url(
                    redis_url,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2,
                    retry_on_timeout=True
                )
                # Fail fast: verify the server is actually reachable
                _redis_client.ping()
                logger.info(f"β Redis connected: {redis_url} (TCP)")
                return _redis_client
            except Exception as e:
                logger.warning(f"β οΈ TCP Redis failed: {e}")
        # 2. Upstash fallback (only if explicitly configured)
        upstash_url = os.getenv("UPSTASH_REDIS_REST_URL")
        upstash_token = os.getenv("UPSTASH_REDIS_REST_TOKEN")
        if upstash_url and upstash_token:
            _redis_client = Redis(url=upstash_url, token=upstash_token)
            logger.info("π‘ Redis connected: Upstash (HTTP)")
            return _redis_client
        # 3. Mock for local dev
        logger.error("β No Redis available, using mock!")
        from unittest.mock import Mock
        _redis_client = Mock()
        return _redis_client


def reset_redis():
    """SRE: Drop the cached Redis connection (for testing).

    Fix: take the same lock get_redis() uses so a concurrent caller cannot
    race the reset, and close the old client best-effort so a TCP connection
    is not leaked.
    """
    global _redis_client
    with _redis_lock:
        old, _redis_client = _redis_client, None
    if old is not None:
        try:
            old.close()
        except Exception:
            # Upstash/Mock clients may not support close(); best-effort only.
            pass
# ── Event Hub Connection Type Detection ──────────────────────────────────────
def is_tcp_redis() -> bool:
    """Return True when REDIS_URL points at a TCP server (pub/sub capable)."""
    return os.getenv("REDIS_URL", "").startswith("redis://")
# ── QStash (Optional) ────────────────────────────────────────────────────────
_qstash_client = None
_qstash_verifier = None


def get_qstash_client():
    """Singleton QStash client.

    Optional integration: returns None (logging why) when QSTASH_TOKEN is
    unset, the upstash_qstash package is missing, or initialization fails.
    Never raises ImportError.
    """
    global _qstash_client
    if _qstash_client is not None:
        return _qstash_client

    token = os.getenv("QSTASH_TOKEN")
    if not token:
        logger.info("QStash token not configured; skipping QStash client initialization")
        return None

    try:
        from upstash_qstash import Client
    except Exception as e:
        logger.warning("upstash_qstash package not installed; QStash disabled: %s", e)
        return None

    try:
        # Honor an explicit QSTASH_URL override when present.
        client_kwargs = {"token": token}
        qstash_url = os.getenv("QSTASH_URL")
        if qstash_url:
            client_kwargs["url"] = qstash_url
        _qstash_client = Client(**client_kwargs)
        logger.info("β QStash client initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize QStash client: {e}")
        _qstash_client = None
    return _qstash_client
def get_qstash_verifier():
    """Singleton QStash signature verifier.

    Safe to call when upstash_qstash is not installed or the signing keys
    are not configured; returns None in those cases.
    """
    global _qstash_verifier
    if _qstash_verifier is not None:
        return _qstash_verifier

    current = os.getenv("QSTASH_CURRENT_SIGNING_KEY")
    next_key = os.getenv("QSTASH_NEXT_SIGNING_KEY")
    if not current or not next_key:
        logger.info("QStash signing keys not configured; skipping verifier initialization")
        return None

    try:
        from upstash_qstash import Receiver
    except Exception as e:
        logger.warning("upstash_qstash package not installed; cannot create QStash verifier: %s", e)
        return None

    try:
        _qstash_verifier = Receiver({
            "current_signing_key": current,
            "next_signing_key": next_key,
        })
        logger.info("β QStash verifier initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize QStash verifier: {e}")
        _qstash_verifier = None
    return _qstash_verifier
# ── API Security (FastAPI) ───────────────────────────────────────────────────
def verify_api_key(x_api_key: str = Header(..., alias="X-API-KEY")):
    """FastAPI dependency: validate the X-API-KEY request header.

    Raises:
        HTTPException 500: API_KEYS was never configured.
        HTTPException 401: key does not match any configured key.
    """
    if not API_KEYS:
        raise HTTPException(status_code=500, detail="API_KEYS not configured")
    # Constant-time comparison against each configured key so response
    # timing cannot leak matching key prefixes (plain `in` short-circuits
    # on the first differing character).
    if not any(secrets.compare_digest(x_api_key, key) for key in API_KEYS):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return x_api_key
# ── Rate Limiting (Per-Org) ──────────────────────────────────────────────────
_rate_limits = defaultdict(lambda: {"count": 0, "reset_at": 0})
# FastAPI runs sync dependencies in a threadpool, so the counters need a
# lock to stay accurate under concurrent requests (consistent with the
# connection-pool locking elsewhere in this module).
_rate_limit_lock = threading.Lock()


def rate_limit_org(max_requests: int = 100, window_seconds: int = 60):
    """Build a FastAPI dependency enforcing a fixed-window per-org rate limit.

    Args:
        max_requests: allowed requests per window.
        window_seconds: window length in seconds.
    """
    def dependency(org_id: str = Header(...)):
        now = time.time()
        with _rate_limit_lock:
            limit_data = _rate_limits[org_id]
            if now > limit_data["reset_at"]:
                # Window expired: start a fresh one
                limit_data["count"] = 0
                limit_data["reset_at"] = now + window_seconds
            if limit_data["count"] >= max_requests:
                # Fix: message used to hard-code "req/min" even when
                # window_seconds != 60.
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded for {org_id}: "
                           f"{max_requests} req/{window_seconds}s"
                )
            limit_data["count"] += 1
        return org_id
    return dependency
# ── Health Check (SRE-Ready) ─────────────────────────────────────────────────
def check_all_services(org_id: Optional[str] = None) -> Dict[str, Any]:
    """SRE: Comprehensive health check for monitoring.

    Args:
        org_id: If provided, checks tenant-specific services (including the
            HNSW index for that tenant's vector DB).

    Returns:
        Mapping of service name -> status string, plus
        "vector_db_hnsw_index" (bool, only when org_id is given) and
        "sre_metrics".
    """
    statuses = {}
    # Check DuckDB
    try:
        conn = get_duckdb(org_id or "health_check")
        conn.execute("SELECT 1")
        statuses["duckdb"] = "β connected"
    except Exception as e:
        statuses["duckdb"] = f"β {e}"
        track_error(org_id or "health_check", "health_duckdb_error")
    # Check Vector DB
    try:
        vdb = get_vector_db(org_id or "health_check")
        vdb.execute("SELECT 1")
        statuses["vector_db"] = "β connected"
        if org_id:
            # Fix: the index flag used to be assigned into the status
            # *string* (statuses["vector_db"]["hnsw_index"]), which raised
            # TypeError and falsely reported the vector DB as down. Report
            # it under its own top-level key instead.
            index_check = vdb.execute("""
                SELECT COUNT(*) FROM duckdb_indexes
                WHERE schema_name = 'vector_store' AND index_name = 'idx_embedding_hnsw'
            """).fetchone()
            statuses["vector_db_hnsw_index"] = bool(index_check and index_check[0] > 0)
    except Exception as e:
        statuses["vector_db"] = f"β {e}"
        track_error(org_id or "health_check", "health_vector_db_error")
    # Check Redis
    try:
        r = get_redis()
        r.ping()
        statuses["redis"] = "β connected"
    except Exception as e:
        statuses["redis"] = f"β {e}"
        track_error(org_id or "health_check", "health_redis_error")
    # Attach aggregate SRE metrics
    statuses["sre_metrics"] = get_sre_metrics()
    return statuses
# ── Connection Cleanup (Graceful Shutdown) ───────────────────────────────────
def close_all_connections():
    """SRE: Close all cached DB/Redis connections on shutdown.

    Fix: the pools are now cleared after closing — previously the closed
    connections stayed cached, so a later get_duckdb()/get_vector_db()
    would hand out a dead handle.
    """
    global _redis_client
    logger.info("[SRE] Closing all database connections...")
    # Close transactional DuckDB connections
    for org_id, conn in list(_org_db_connections.items()):
        try:
            conn.close()
            logger.info(f"[DB] π Closed connection for: {org_id}")
        except Exception as e:
            logger.error(f"[DB] β Error closing: {e}")
    _org_db_connections.clear()
    # Close vector DB connections
    for org_id, conn in list(_vector_db_connections.items()):
        try:
            conn.close()
            logger.info(f"[VECTOR_DB] π Closed connection for: {org_id}")
        except Exception as e:
            logger.error(f"[VECTOR_DB] β Error closing: {e}")
    _vector_db_connections.clear()
    # Close Redis and drop the cached client so get_redis() reconnects
    if _redis_client:
        try:
            _redis_client.close()
            logger.info("[REDIS] π Closed connection")
        except Exception as e:
            logger.error(f"[REDIS] β Error closing: {e}")
        _redis_client = None
    logger.info("[SRE] All connections closed")
# ── Prometheus Export (Stub for Future Integration) ──────────────────────────
def export_metrics_for_prometheus() -> str:
    """Render collected SRE metrics in Prometheus exposition format.

    Intended to back a /metrics endpoint scraped by Prometheus.
    """
    metrics = get_sre_metrics()
    lines = []
    # Connection counters
    lines.extend(
        f'duckdb_connections{{org_id="{org}"}} {count}'
        for org, count in metrics["connections"].items()
    )
    # Error counters (keys are "org:error_type")
    for key, count in metrics["errors"].items():
        org, err_type = key.split(":", 1)
        lines.append(f'duckdb_errors{{org_id="{org}", type="{err_type}"}} {count}')
    # Vector DB on-disk sizes
    lines.extend(
        f'vector_db_size_bytes{{org_id="{org}"}} {size}'
        for org, size in metrics["vector_db_sizes"].items()
    )
    return "\n".join(lines)
# ── Reset for Testing ────────────────────────────────────────────────────────
def reset_connections():
    """SRE: Reset every connection cache (useful for tests)."""
    global _org_db_connections, _vector_db_connections, _redis_client
    close_all_connections()
    _org_db_connections, _vector_db_connections = {}, {}
    _redis_client = None
    logger.info("[SRE] All connection caches reset")