"""
app/deps.py - SRE-Ready Dependency Injection

Critical improvements:
βœ… True tenant isolation: each org gets its own vector DB file
βœ… SRE observability: metrics, cached per-org connections, health checks
βœ… Backward compatible: falls back to a shared DB if org_id is not provided
βœ… HNSW index: created automatically for fast approximate vector search
βœ… Connection caching: one reused connection per org limits connection churn
"""

import os
from typing import Optional, Dict, Any, Callable
import pathlib
import logging
import time
from functools import wraps
from collections import defaultdict
import threading

# Third-party imports
import duckdb
from fastapi import HTTPException, Header
from upstash_redis import Redis

# ── Configuration ───────────────────────────────────────────────────────────────
# Multi-tenant DuckDB base path
DATA_DIR = pathlib.Path("./data/duckdb")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Vector DB base path (NOW per-org)
VECTOR_DB_DIR = DATA_DIR / "vectors"
VECTOR_DB_DIR.mkdir(parents=True, exist_ok=True)

# Logging
logger = logging.getLogger(__name__)

# ── SRE: Global Metrics Registry ────────────────────────────────────────────────
# Prometheus-ready metrics collection (free tier compatible)
_metrics_registry = {
    "db_connections_total": defaultdict(int),  # Total connections per org
    "db_connection_errors": defaultdict(int),  # Errors per org
    "db_query_duration_ms": defaultdict(list),  # Latency histogram per org
    "vector_db_size_bytes": defaultdict(int),  # File size per org
}

# Metric helpers
def track_connection(org_id: str):
    """Count one new DB connection for an org (plain counter, not a decorator)"""
    _metrics_registry["db_connections_total"][org_id] += 1

def track_error(org_id: str, error_type: str):
    """Track errors per org"""
    _metrics_registry["db_connection_errors"][f"{org_id}:{error_type}"] += 1

def timing_metric(org_id: str, operation: str):
    """Decorator to time DB operations"""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                result = func(*args, **kwargs)
                duration_ms = (time.time() - start) * 1000
                _metrics_registry["db_query_duration_ms"][f"{org_id}:{operation}"].append(duration_ms)
                return result
            except Exception:
                track_error(org_id, f"{operation}_error")
                raise
        return wrapper
    return decorator
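
# Usage sketch for the helpers above (illustrative only; `load_orders` is a
# hypothetical function name): decorate a query helper so its latency lands in
# the per-org histogram, then read it back via get_sre_metrics().
#
#     @timing_metric(org_id="acme", operation="load_orders")
#     def load_orders(conn):
#         return conn.execute("SELECT * FROM main.orders").fetchall()
#
#     get_sre_metrics()["avg_latency_ms"]  # {"acme:load_orders": ...}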

def get_sre_metrics() -> Dict[str, Any]:
    """Get metrics for health checks and Prometheus scraping"""
    return {
        "connections": dict(_metrics_registry["db_connections_total"]),
        "errors": dict(_metrics_registry["db_connection_errors"]),
        "avg_latency_ms": {
            k: sum(v) / len(v) if v else 0
            for k, v in _metrics_registry["db_query_duration_ms"].items()
        },
        "vector_db_sizes": dict(_metrics_registry["vector_db_size_bytes"]),
        "total_orgs": len(_metrics_registry["vector_db_size_bytes"]),
    }

# ── Secrets Management ───────────────────────────────────────────────────────────
def get_secret(name: str, required: bool = True) -> Optional[str]:
    """Centralized secret retrieval"""
    value = os.getenv(name)
    if required and (not value or value.strip() == ""):
        raise ValueError(f"πŸ”΄ CRITICAL: Required secret '{name}' not found")
    return value

# API keys (required; comma-separated list)
API_KEYS = [k.strip() for k in get_secret("API_KEYS").split(",") if k.strip()]

# Hugging Face API token (optional)
HF_API_TOKEN = get_secret("HF_API_TOKEN", required=False)
# Redis configuration
REDIS_URL = get_secret("UPSTASH_REDIS_REST_URL", required=False)
REDIS_TOKEN = get_secret("UPSTASH_REDIS_REST_TOKEN", required=False)

# QStash token (optional)
QSTASH_TOKEN = get_secret("QSTASH_TOKEN", required=False)

# ── DuckDB Connection Pool & Tenant Isolation ───────────────────────────────────
_org_db_connections: Dict[str, duckdb.DuckDBPyConnection] = {}
_vector_db_connections: Dict[str, duckdb.DuckDBPyConnection] = {}
_connection_lock = threading.Lock()

def get_duckdb(org_id: str) -> duckdb.DuckDBPyConnection:
    """
    βœ… Tenant-isolated transactional DB
    Each org: ./data/duckdb/{org_id}.duckdb
    """
    if not org_id or not isinstance(org_id, str):
        raise ValueError(f"Invalid org_id: {org_id}")
    
    with _connection_lock:
        if org_id not in _org_db_connections:
            db_file = DATA_DIR / f"{org_id}.duckdb"
            logger.info(f"[DB] πŸ”Œ Connecting transactional DB for org: {org_id}")
            
            try:
                conn = duckdb.connect(str(db_file), read_only=False)
 
                # Enable VSS
                conn.execute("INSTALL vss;")
                conn.execute("LOAD vss;")
                
                # Create schemas
                conn.execute("CREATE SCHEMA IF NOT EXISTS main")
                conn.execute("CREATE SCHEMA IF NOT EXISTS vector_store")
                
                _org_db_connections[org_id] = conn
                track_connection(org_id)
                
            except Exception as e:
                track_error(org_id, "db_connect_error")
                logger.error(f"[DB] ❌ Failed to connect: {e}")
                raise
    
    return _org_db_connections[org_id]
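
# Usage sketch (hypothetical org ids): each org resolves to its own file,
# ./data/duckdb/acme.duckdb vs ./data/duckdb/globex.duckdb, and repeated
# calls return the cached connection rather than reconnecting.
#
#     conn_a = get_duckdb("acme")
#     conn_b = get_duckdb("globex")
#     assert get_duckdb("acme") is conn_a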


def get_vector_db(org_id: Optional[str] = None) -> duckdb.DuckDBPyConnection:
    """
    βœ… TRUE TENANT ISOLATION: Each org gets its own vector DB file
    
    For production: ALWAYS pass org_id
    For backward compat: Falls back to shared DB (legacy)
    """
    # Legacy fallback mode (keep this for compatibility)
    if org_id is None:
        org_id = "_shared_legacy"
        logger.warning("[VECTOR_DB] ⚠️ Using shared DB (legacy mode) - not recommended")
    
    if not isinstance(org_id, str):
        raise ValueError(f"Invalid org_id: {org_id}")
    
    with _connection_lock:
        if org_id not in _vector_db_connections:
            # Per-org DB file: ./data/duckdb/vectors/{org_id}.duckdb
            db_file = VECTOR_DB_DIR / f"{org_id}.duckdb"
            logger.info(f"[VECTOR_DB] πŸ”Œ Connecting vector DB for org: {org_id}")
            
            try:
                conn = duckdb.connect(str(db_file), read_only=False)
                
                # Enable VSS extension
                conn.execute("INSTALL vss;")
                conn.execute("LOAD vss;")
                
                # Create schema
                conn.execute("CREATE SCHEMA IF NOT EXISTS vector_store")
                
                # Create embeddings table with proper types and indices
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS vector_store.embeddings (
                        id VARCHAR PRIMARY KEY,
                        org_id VARCHAR NOT NULL,
                        content TEXT,
                        embedding FLOAT[384],
                        entity_type VARCHAR,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                
                # βœ… CRITICAL: Create HNSW index for much faster approximate search
                # Using cosine similarity (matches our normalized embeddings)
                try:
                    # DuckDB's VSS extension persists HNSW indexes in disk-backed
                    # databases only when this experimental flag is enabled
                    conn.execute("SET hnsw_enable_experimental_persistence = true;")
                    conn.execute("""
                        CREATE INDEX IF NOT EXISTS idx_embedding_hnsw 
                        ON vector_store.embeddings 
                        USING HNSW (embedding)
                        WITH (metric = 'cosine')
                    """)
                    logger.info(f"[VECTOR_DB] βœ… HNSW index created for org: {org_id}")
                except Exception as e:
                    logger.warning(f"[VECTOR_DB] ⚠️ Could not create HNSW index: {e}")
                    # Continue without index (still functional, just slower)
                
                _vector_db_connections[org_id] = conn
                track_connection(org_id)
                
                # Track DB size for SRE
                if db_file.exists():
                    _metrics_registry["vector_db_size_bytes"][org_id] = db_file.stat().st_size
                
            except Exception as e:
                track_error(org_id, "vector_db_connect_error")
                logger.error(f"[VECTOR_DB] ❌ Failed to connect: {e}")
                raise
    
    return _vector_db_connections[org_id]
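
# Search sketch against the per-org store (hypothetical ids and dummy query
# vectors; assumes array_cosine_distance from DuckDB's VSS extension is
# available, which the HNSW index with metric='cosine' can accelerate):
#
#     vdb = get_vector_db("acme")
#     vdb.execute(
#         "INSERT OR REPLACE INTO vector_store.embeddings "
#         "(id, org_id, content, embedding) VALUES (?, ?, ?, ?::FLOAT[384])",
#         ["doc-1", "acme", "hello world", [0.1] * 384],
#     )
#     hits = vdb.execute(
#         "SELECT id, array_cosine_distance(embedding, ?::FLOAT[384]) AS dist "
#         "FROM vector_store.embeddings ORDER BY dist LIMIT 5",
#         [[0.1] * 384],
#     ).fetchall()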


# ── Redis Client (self-hosted TCP + Upstash compatible) ─────────────────────────
_redis_client = None
_redis_lock = threading.Lock()

def get_redis():
    """
    🎯 Redis connection with clear priority:
    1. Self-hosted (TCP) - HF Spaces with supervisord
    2. Upstash (HTTP) - Fallback only
    3. Local dev mock - Last resort
    """
    global _redis_client
    
    with _redis_lock:
        if _redis_client is not None:
            return _redis_client
        
        # 1. Self-hosted Redis (HF Spaces)
        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
        if redis_url.startswith("redis://"):
            try:
                import redis as redis_py
                _redis_client = redis_py.from_url(
                    redis_url,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2,
                    retry_on_timeout=True
                )
                # Test connection immediately
                _redis_client.ping()
                logger.info(f"βœ… Redis connected: {redis_url} (TCP)")
                return _redis_client
            except Exception as e:
                logger.warning(f"⚠️ TCP Redis failed: {e}")
        
        # 2. Upstash fallback (used only when explicitly configured)
        upstash_url = os.getenv("UPSTASH_REDIS_REST_URL")
        upstash_token = os.getenv("UPSTASH_REDIS_REST_TOKEN")
        
        if upstash_url and upstash_token:
            _redis_client = Redis(url=upstash_url, token=upstash_token)
            logger.info("πŸ“‘ Redis connected: Upstash (HTTP)")
            return _redis_client
        
        # 3. Mock for local dev
        logger.error("❌ No Redis available, using mock!")
        from unittest.mock import Mock
        _redis_client = Mock()
        return _redis_client


def reset_redis():
    """SRE: Reset Redis connection (for testing)"""
    global _redis_client
    _redis_client = None


# ── Event Hub Connection Type Detection ─────────────────────────────────────────
def is_tcp_redis() -> bool:
    """Check if using TCP Redis (pub/sub capable)"""
    redis_url = os.getenv("REDIS_URL", "")
    return redis_url.startswith("redis://")
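
# Guard sketch: Upstash's HTTP client cannot hold a blocking pub/sub
# connection, so feature-gate subscribers on the transport type.
#
#     if is_tcp_redis():
#         pubsub = get_redis().pubsub()
#         pubsub.subscribe("events")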

# ── QStash (Optional) ───────────────────────────────────────────────────────────
_qstash_client = None
_qstash_verifier = None

def get_qstash_client():
    """Singleton QStash client.

    This is optional. If the `QSTASH_TOKEN` environment variable is not set
    or the `upstash_qstash` package is not installed, this function will
    return `None` and log a warning/info rather than raising an ImportError.
    """
    global _qstash_client
    if _qstash_client is not None:
        return _qstash_client

    token = os.getenv("QSTASH_TOKEN")
    if not token:
        logger.info("QStash token not configured; skipping QStash client initialization")
        return None

    try:
        from upstash_qstash import Client
    except Exception as e:
        logger.warning("upstash_qstash package not installed; QStash disabled: %s", e)
        return None

    try:
        qstash_url = os.getenv("QSTASH_URL")
        if qstash_url:
            _qstash_client = Client(token=token, url=qstash_url)
        else:
            _qstash_client = Client(token=token)
        logger.info("βœ… QStash client initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize QStash client: {e}")
        _qstash_client = None

    return _qstash_client

def get_qstash_verifier():
    """Singleton QStash verifier.

    Safe to call even if `upstash_qstash` is not installed or signing keys
    are not configured. Returns `None` when verifier cannot be created.
    """
    global _qstash_verifier
    if _qstash_verifier is not None:
        return _qstash_verifier

    current = os.getenv("QSTASH_CURRENT_SIGNING_KEY")
    next_key = os.getenv("QSTASH_NEXT_SIGNING_KEY")
    if not (current and next_key):
        logger.info("QStash signing keys not configured; skipping verifier initialization")
        return None

    try:
        from upstash_qstash import Receiver
    except Exception as e:
        logger.warning("upstash_qstash package not installed; cannot create QStash verifier: %s", e)
        return None

    try:
        _qstash_verifier = Receiver({
            "current_signing_key": current,
            "next_signing_key": next_key
        })
        logger.info("βœ… QStash verifier initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize QStash verifier: {e}")
        _qstash_verifier = None

    return _qstash_verifier


# ── API Security (FastAPI) ───────────────────────────────────────────────────────
def verify_api_key(x_api_key: str = Header(..., alias="X-API-KEY")):
    """FastAPI dependency for API key verification (unchanged)"""
    if not API_KEYS:
        raise HTTPException(status_code=500, detail="API_KEYS not configured")
    
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    
    return x_api_key


# ── Rate Limiting (Per-Org) ──────────────────────────────────────────────────────
_rate_limits = defaultdict(lambda: {"count": 0, "reset_at": 0})

def rate_limit_org(max_requests: int = 100, window_seconds: int = 60):
    """Rate limiter per organization (unchanged logic)"""
    def dependency(org_id: str = Header(...)):
        now = time.time()
        limit_data = _rate_limits[org_id]

        if now > limit_data["reset_at"]:
            limit_data["count"] = 0
            limit_data["reset_at"] = now + window_seconds

        if limit_data["count"] >= max_requests:
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded for {org_id}: {max_requests} req/min"
            )

        limit_data["count"] += 1
        return org_id

    return dependency
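
# Wiring sketch (hypothetical app/route names): combine the API-key check and
# the per-org rate limiter as FastAPI dependencies on one endpoint.
#
#     from fastapi import Depends, FastAPI
#     app = FastAPI()
#
#     @app.get("/v1/ping")
#     def ping(api_key: str = Depends(verify_api_key),
#              org_id: str = Depends(rate_limit_org(max_requests=100))):
#         return {"ok": True, "org_id": org_id}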


# ── Health Check (SRE-Ready) ─────────────────────────────────────────────────────
def check_all_services(org_id: Optional[str] = None) -> Dict[str, Any]:
    """
    SRE: Comprehensive health check for monitoring
    Args:
        org_id: If provided, checks tenant-specific services
    """
    statuses = {}
    
    # Check DuckDB
    try:
        conn = get_duckdb(org_id or "health_check")
        conn.execute("SELECT 1")
        statuses["duckdb"] = "βœ… connected"
    except Exception as e:
        statuses["duckdb"] = f"❌ {e}"
        track_error(org_id or "health_check", "health_duckdb_error")
    
    # Check Vector DB
    try:
        vdb = get_vector_db(org_id or "health_check")
        vdb.execute("SELECT 1")
        statuses["vector_db"] = "βœ… connected"
        
        # Additional vector DB health checks
        if org_id:
            # Verify the HNSW index exists; statuses["vector_db"] is a plain
            # string, so the index status gets its own top-level key
            index_check = vdb.execute("""
                SELECT COUNT(*) FROM duckdb_indexes 
                WHERE schema_name = 'vector_store' AND index_name = 'idx_embedding_hnsw'
            """).fetchone()
            statuses["vector_db_hnsw_index"] = bool(index_check and index_check[0] > 0)
    except Exception as e:
        statuses["vector_db"] = f"❌ {e}"
        track_error(org_id or "health_check", "health_vector_db_error")
    
    # Check Redis
    try:
        r = get_redis()
        r.ping()
        statuses["redis"] = "βœ… connected"
    except Exception as e:
        statuses["redis"] = f"❌ {e}"
        track_error(org_id or "health_check", "health_redis_error")
    
    # Get SRE metrics
    statuses["sre_metrics"] = get_sre_metrics()
    
    return statuses


# ── Connection Cleanup (Graceful Shutdown) ───────────────────────────────────────
def close_all_connections():
    """SRE: Close all DB connections on shutdown"""
    logger.info("[SRE] Closing all database connections...")
    
    # Close DuckDB connections
    for org_id, conn in list(_org_db_connections.items()):
        try:
            conn.close()
            logger.info(f"[DB] πŸ”Œ Closed connection for: {org_id}")
        except Exception as e:
            logger.error(f"[DB] ❌ Error closing: {e}")
    
    # Close Vector DB connections
    for org_id, conn in list(_vector_db_connections.items()):
        try:
            conn.close()
            logger.info(f"[VECTOR_DB] πŸ”Œ Closed connection for: {org_id}")
        except Exception as e:
            logger.error(f"[VECTOR_DB] ❌ Error closing: {e}")
    
    # Close Redis
    if _redis_client:
        try:
            _redis_client.close()
            logger.info("[REDIS] πŸ”Œ Closed connection")
        except Exception as e:
            logger.error(f"[REDIS] ❌ Error closing: {e}")
    
    logger.info("[SRE] All connections closed")


# ── Prometheus Export (Stub for Future Integration) ─────────────────────────────
def export_metrics_for_prometheus() -> str:
    """
    Export metrics in Prometheus format
    To be used by /metrics endpoint for Prometheus scraping
    """
    metrics = get_sre_metrics()
    
    output = []
    # Connection metrics
    for org_id, count in metrics["connections"].items():
        output.append(f'duckdb_connections{{org_id="{org_id}"}} {count}')
    
    # Error metrics
    for key, count in metrics["errors"].items():
        org_id, error_type = key.split(":", 1)
        output.append(f'duckdb_errors{{org_id="{org_id}", type="{error_type}"}} {count}')
    
    # Vector DB size
    for org_id, size_bytes in metrics["vector_db_sizes"].items():
        output.append(f'vector_db_size_bytes{{org_id="{org_id}"}} {size_bytes}')
    
    return "\n".join(output)

# ── Reset for Testing ───────────────────────────────────────────────────────────
def reset_connections():
    """SRE: Reset all connections (useful for tests)"""
    global _org_db_connections, _vector_db_connections, _redis_client
    close_all_connections()
    _org_db_connections = {}
    _vector_db_connections = {}
    _redis_client = None
    logger.info("[SRE] All connection caches reset")