File size: 3,249 Bytes
59abb4f
 
8bd7457
59abb4f
 
 
 
 
 
 
8bd7457
 
59abb4f
8bd7457
 
 
 
a37ed50
8bd7457
 
a37ed50
8bd7457
 
 
 
 
 
 
 
59abb4f
 
8bd7457
 
59abb4f
 
8bd7457
 
 
 
 
 
 
 
 
 
 
 
59abb4f
 
a37ed50
59abb4f
bf49c73
 
a37ed50
bf49c73
59abb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from neo4j import GraphDatabase
import os
import time
from dotenv import load_dotenv

# Runs at import time, before the NEO4J_* environment variables are read further
# down in this module. Presumably this loads a local .env file into the process
# environment (python-dotenv) — confirm against the deployment setup.
load_dotenv()


class Neo4jConnection:
    def __init__(self, uri: str, user: str, password: str, database: str = "neo4j"):
        self._uri = uri
        self._auth = (user, password)
        self.database = database
        self.driver = None
        self._connect_with_retry()

    def _connect_with_retry(self, retries: int = 10, delay: int = 6):
        auth = None if (self._auth == ("neo4j", "") or self._auth[1] == "") else self._auth
        for attempt in range(retries):
            try:
                self.driver = GraphDatabase.driver(self._uri, auth=auth)
                self.driver.verify_connectivity()
                print(f"[neo4j] Connected on attempt {attempt + 1}")
                return
            except Exception as e:
                print(f"[neo4j] Attempt {attempt + 1}/{retries} failed: {e}")
                if attempt < retries - 1:
                    time.sleep(delay)
        print("[neo4j] WARNING: Could not connect — queries will fail until Neo4j is ready")

    def close(self):
        if self.driver:
            self.driver.close()

    def run_query(self, query: str, parameters: dict | None = None) -> list:
        if not self.driver:
            self._connect_with_retry()
        try:
            with self.driver.session(database=self.database) as session:
                result = session.run(query, parameters or {})
                return [record.data() for record in result]
        except Exception:
            # Driver may have gone stale — reconnect once and retry
            self._connect_with_retry(retries=3, delay=3)
            with self.driver.session(database=self.database) as session:
                result = session.run(query, parameters or {})
                return [record.data() for record in result]


# Shared module-level connection, configured from environment variables and
# created (with connection retries) as soon as this module is imported.
#
# NEO4J_AUTH=none switches authentication off by passing an empty password,
# which Neo4jConnection turns into auth=None.
# NOTE(review): the fallback password below is hard-coded in source — consider
# requiring NEO4J_PASSWORD to be set instead.
_auth_disabled = os.environ.get("NEO4J_AUTH", "") == "none"
neo4j_conn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI") or "bolt://127.0.0.1:7687",
    user=os.environ.get("NEO4J_USERNAME") or "neo4j",
    password=(
        (os.environ.get("NEO4J_PASSWORD") or "clinicalmatch2024")
        if not _auth_disabled
        else ""
    ),
    database=os.environ.get("NEO4J_DATABASE") or "neo4j",
)


def setup_schema():
    """Create the uniqueness constraints and property indexes for the graph.

    Statements are sent one at a time through the module-level ``neo4j_conn``,
    constraints first and indexes second. Every statement uses
    ``IF NOT EXISTS``. A statement that raises is reported as a printed
    warning and the remaining statements still run; the completion message is
    printed regardless.
    """
    uniqueness_constraints = (
        "CREATE CONSTRAINT patient_id IF NOT EXISTS FOR (p:Patient) REQUIRE p.id IS UNIQUE",
        "CREATE CONSTRAINT trial_id IF NOT EXISTS FOR (t:Trial) REQUIRE t.id IS UNIQUE",
        "CREATE CONSTRAINT diagnosis_code IF NOT EXISTS FOR (d:Diagnosis) REQUIRE d.code IS UNIQUE",
        "CREATE CONSTRAINT site_id IF NOT EXISTS FOR (s:StudySite) REQUIRE s.id IS UNIQUE",
    )
    property_indexes = (
        "CREATE INDEX patient_age IF NOT EXISTS FOR (p:Patient) ON (p.age)",
        "CREATE INDEX trial_phase IF NOT EXISTS FOR (t:Trial) ON (t.phase)",
        "CREATE INDEX trial_condition IF NOT EXISTS FOR (t:Trial) ON (t.condition)",
        "CREATE INDEX trial_status IF NOT EXISTS FOR (t:Trial) ON (t.status)",
    )
    for statement in (*uniqueness_constraints, *property_indexes):
        try:
            neo4j_conn.run_query(statement)
        except Exception as err:
            # Best effort: one failing statement must not stop the rest.
            print(f"Schema warning: {err}")
    print("Schema setup complete.")


if __name__ == "__main__":
    # Script entry point: apply the schema, then release the shared driver.
    # try/finally ensures the driver is closed even if setup is interrupted
    # (e.g. Ctrl-C) or raises something setup_schema() does not catch.
    try:
        setup_schema()
    finally:
        neo4j_conn.close()