File size: 3,711 Bytes
d63c832
 
e427949
 
 
 
 
 
 
 
 
d63c832
5237303
e427949
 
5949a90
 
5237303
d63c832
5949a90
 
 
e427949
 
 
 
5949a90
 
 
e427949
d63c832
5237303
 
 
d63c832
5237303
e427949
5237303
d63c832
e427949
 
45f51d2
 
 
 
 
 
 
e427949
45f51d2
 
e427949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d63c832
e427949
d63c832
e427949
 
 
 
 
 
 
 
d63c832
e427949
d63c832
 
e427949
5949a90
 
5237303
 
e427949
 
 
 
 
 
 
5237303
 
5949a90
e427949
5237303
5949a90
 
 
 
5237303
 
 
d63c832
5949a90
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
BharatGraph API Dependencies

BUG-20 FIX: verify_connectivity() was called on EVERY single API request.
With 50 concurrent cold-start requests this creates a connection storm where
50 simultaneous verify_connectivity() calls hit Neo4j AuraDB, which rate-limits
verification and causes cascading 503 errors.

Fix: TTL-based health cache -- only re-verify if more than 30 seconds have
elapsed since the last successful verification. A threading.Lock ensures
only one reconnect attempt runs at a time.
"""
import os
import time
import threading
from dotenv import load_dotenv
from loguru import logger
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, AuthError

load_dotenv()

# Module-level cache state shared by get_driver() and close_driver().

# The cached Neo4j driver; None until get_driver() connects successfully,
# and reset to None on connection failure or by close_driver().
_driver          = None
# time.monotonic() reading taken after the most recent successful
# verify_connectivity() call; compared against _VERIFY_TTL in get_driver().
_last_verified_at = 0.0
_VERIFY_TTL      = 30.0   # seconds between connectivity re-checks
# Held by get_driver() while verifying/reconnecting and by close_driver()
# while closing, so those operations do not overlap.
_driver_lock     = threading.Lock()


def get_driver():
    """Return the shared Neo4j driver, connecting or re-verifying as needed.

    Connection settings are read from the environment on every call:
    ``NEO4J_URI``, ``NEO4J_USER`` (default ``"neo4j"``) and
    ``NEO4J_PASSWORD``.

    Connectivity is verified at most once per ``_VERIFY_TTL`` seconds. Calls
    that arrive while the cached driver is still fresh return it without
    taking ``_driver_lock``; the first connect, TTL expiry and reconnects run
    under the lock so only one thread at a time talks to the server.

    Returns:
        The cached driver, or ``None`` when ``NEO4J_URI`` is unset/empty or a
        connection could not be established. Connection errors are logged
        rather than raised.
    """
    global _driver, _last_verified_at

    uri = os.getenv("NEO4J_URI", "")
    user = os.getenv("NEO4J_USER", "neo4j")
    pwd = os.getenv("NEO4J_PASSWORD", "")

    if not uri:
        logger.warning("[API] NEO4J_URI not set -- running without database")
        return None

    now = time.monotonic()

    # H-06 FIX: lockless fast path -- if the driver exists and the TTL is
    # fresh, return immediately without acquiring the lock. The lock is only
    # needed for verification/reconnection (rare), so concurrent requests do
    # not serialize behind it for a trivial check.
    if _driver is not None and (now - _last_verified_at) < _VERIFY_TTL:
        return _driver

    with _driver_lock:
        # Re-check inside the lock in case another thread just reconnected
        # or re-verified while this one was waiting.
        now = time.monotonic()
        if _driver is not None and (now - _last_verified_at) < _VERIFY_TTL:
            return _driver

        # TTL expired on an existing driver: re-verify it before reuse.
        if _driver is not None:
            try:
                _driver.verify_connectivity()
            except Exception as e:
                logger.warning(
                    f"[API] Cached Neo4j driver dead ({type(e).__name__}), reconnecting..."
                )
                try:
                    _driver.close()
                except Exception:
                    # Best-effort cleanup: the driver is already known to be
                    # broken, so a failing close() must not block reconnecting.
                    pass
                _driver = None
            else:
                _last_verified_at = time.monotonic()
                return _driver

        # First call or dead driver: build a new one. It is kept in a local
        # until verified so lockless readers never observe an unverified
        # driver, and so it can be closed if verification fails.
        candidate = None
        try:
            candidate = GraphDatabase.driver(uri, auth=(user, pwd))
            candidate.verify_connectivity()
        except AuthError as e:
            logger.error(f"[API] Neo4j auth failed -- check NEO4J_USER/NEO4J_PASSWORD: {e}")
        except ServiceUnavailable as e:
            logger.error(f"[API] Neo4j service unavailable: {e}")
        except Exception as e:
            logger.error(f"[API] Neo4j connection failed: {type(e).__name__}: {e}")
        else:
            # Publish the timestamp before the driver so the fast path never
            # pairs the new driver with a stale verification time.
            _last_verified_at = time.monotonic()
            _driver = candidate
            logger.success(f"[API] Neo4j connected: {uri[:30]}...")
            return _driver

        # Connection failed: release the half-initialised driver instead of
        # dropping the reference and leaking its resources.
        if candidate is not None:
            try:
                candidate.close()
            except Exception:
                # Best-effort cleanup of a driver that never became usable.
                pass
        _driver = None
        return None


def close_driver():
    """Close the shared Neo4j driver, if one exists, and reset the cache.

    Runs under ``_driver_lock`` so it cannot overlap with a verification or
    reconnect in ``get_driver()``. Errors raised by the driver's ``close()``
    are suppressed; the cached reference is cleared either way.

    ``_last_verified_at`` is reset to its initial value as well, so the TTL
    check in ``get_driver()`` cannot treat a later driver as recently
    verified on the strength of a timestamp that belonged to the closed one.
    """
    global _driver, _last_verified_at
    with _driver_lock:
        if _driver is not None:
            try:
                _driver.close()
            except Exception:
                # Best-effort shutdown: a failing close() must not prevent
                # the cached reference from being dropped.
                pass
            _driver = None
        _last_verified_at = 0.0


def get_db():
    """FastAPI dependency that hands out the shared Neo4j driver.

    Returns:
        The driver obtained from ``get_driver()``.

    Raises:
        fastapi.HTTPException: with status 503 when ``get_driver()`` returns
            ``None`` (no URI configured or the connection failed).
    """
    # Imported lazily, as in the rest of this function's history, so that
    # importing this module does not itself require FastAPI.
    from fastapi import HTTPException

    live_driver = get_driver()
    if live_driver is not None:
        return live_driver

    unavailable_detail = (
        "Graph database unavailable. "
        "Check NEO4J_URI and NEO4J_PASSWORD in environment secrets."
    )
    raise HTTPException(status_code=503, detail=unavailable_detail)