File size: 1,570 Bytes
24f3fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance
import boto3
from neo4j import GraphDatabase
from .config import *

# Lazily-initialised service handles. They start as None and are populated
# on first use by the get_* functions in this module.
_qdrant = None
_neo4j = None
_s3 = None
_s3_session = None

# Qdrant collection settings: vector size used when the collection is
# created, and the collection's name.
embedding_dim: int = 512
COLLECTION_NAME: str = "object_collection"

def get_qdrant():
    """Return the shared Qdrant client, creating it on first call.

    On first use this connects to ``QDRANT_HOST`` with ``QDRANT_API`` and makes
    sure the ``COLLECTION_NAME`` collection exists. If it is missing, it is
    created with ``embedding_dim``-sized vectors and cosine distance. Later
    calls return the cached client without contacting the server.

    Returns:
        The module-wide ``QdrantClient`` instance.
    """
    global _qdrant
    if _qdrant is not None:
        return _qdrant

    client = QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API)
    # Ensure the collection once, at initialisation, rather than on every
    # call: each existence check is a network round-trip to the server.
    # NOTE(review): if the collection is dropped externally while the process
    # is running, it will not be re-created until the client is re-initialised.
    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
        )

    # Publish the client only after the collection is in place, so a failure
    # above leaves _qdrant unset and the next call retries the full setup.
    _qdrant = client
    return _qdrant

def get_s3():
    """Return the shared boto3 S3 client, building it and its session lazily.

    The first call creates a ``boto3.Session`` from ``AWS_KEY``, ``AWS_SECRET``
    and ``AWS_REGION``, and stores it in ``_s3_session``. It then derives the S3
    client from that session. Subsequent calls return the cached client.
    """
    global _s3, _s3_session
    if _s3 is not None:
        return _s3

    session = boto3.Session(
        aws_access_key_id=AWS_KEY,
        aws_secret_access_key=AWS_SECRET,
        region_name=AWS_REGION,
    )
    # Keep the session so get_s3_session() can hand it out later.
    _s3_session = session
    _s3 = session.client("s3")
    return _s3

def get_s3_session():
    """Return the boto3 session that backs the shared S3 client.

    The session is created by ``get_s3()``. If that has not run yet, it is
    invoked here, so callers do not receive ``None`` merely because of the
    order in which the getters were called.

    Returns:
        The module-wide ``boto3.Session``.
    """
    # Only the module global is read here; get_s3() performs the assignment,
    # so no ``global`` declaration is needed.
    if _s3_session is None:
        get_s3()
    return _s3_session

def get_neo4j():
    """Return the shared Neo4j driver, creating it on first call.

    The driver connects to ``NEO4J_URI`` using ``NEO4J_USER``/``NEO4J_PASS``
    and is cached in ``_neo4j``, so later calls reuse the same instance.
    """
    global _neo4j
    if _neo4j is not None:
        return _neo4j

    _neo4j = GraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4J_USER, NEO4J_PASS),
        # Force a reconnect once a pooled connection is 30 seconds old.
        max_connection_lifetime=30,
        # Fail faster when a connection cannot be established.
        connection_timeout=10,
        # Limit the size of the connection pool.
        max_connection_pool_size=10,
    )
    return _neo4j