# object-memory/core/clients.py
# russ4stall
# fresh history
# 24f3fb6
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance
import boto3
from neo4j import GraphDatabase
from .config import *
# Lazily created singletons; the get_* functions in this module fill them
# on first use and hand back the cached object afterwards.
_qdrant = _s3 = _neo4j = _s3_session = None

# Qdrant collection that get_qdrant() ensures exists.
COLLECTION_NAME = "object_collection"
# Vector size used when that collection has to be created.
embedding_dim = 512
def get_qdrant():
    """Return the shared Qdrant client, creating it on first use.

    On the first call a ``QdrantClient`` is built from ``QDRANT_HOST`` and
    ``QDRANT_API`` (star-imported, presumably from ``.config``). Before it is
    cached, the client makes sure ``COLLECTION_NAME`` exists. If the
    collection is missing, it is created with ``embedding_dim``-sized
    vectors and cosine distance.

    Returns:
        The cached ``QdrantClient`` instance.
    """
    global _qdrant
    if _qdrant is None:
        client = QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API)
        if not client.collection_exists(COLLECTION_NAME):
            client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=embedding_dim,
                    distance=Distance.COSINE,
                ),
            )
        # Cache the client only after the collection check/creation succeeded.
        # If either call raises, nothing is cached and the next call retries
        # the whole setup. Otherwise it would hand out a client whose
        # collection was never ensured.
        _qdrant = client
    return _qdrant
def get_s3():
    """Return the shared boto3 S3 client, building it on first use.

    The boto3 session that produces the client is cached in the
    module-level ``_s3_session`` alongside it, so ``get_s3_session`` can
    return the same session.
    """
    global _s3, _s3_session
    if _s3 is not None:
        return _s3

    session_kwargs = dict(
        aws_access_key_id=AWS_KEY,
        aws_secret_access_key=AWS_SECRET,
        region_name=AWS_REGION,
    )
    _s3_session = boto3.Session(**session_kwargs)
    _s3 = _s3_session.client("s3")
    return _s3
def get_s3_session():
    """Return the boto3 session that backs the shared S3 client.

    If no session has been created yet, ``get_s3()`` is called first to
    build it. Callers therefore never receive ``None`` just because they
    asked for the session before the client.
    """
    # Reading the module global needs no ``global`` declaration here.
    if _s3_session is None:
        get_s3()
    return _s3_session
def get_neo4j():
    """Return the shared Neo4j driver, creating it on first use.

    The driver connects to ``NEO4J_URI`` with basic auth built from
    ``NEO4J_USER`` and ``NEO4J_PASS``, using the pool settings below.
    """
    global _neo4j
    if _neo4j is not None:
        return _neo4j

    _neo4j = GraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4J_USER, NEO4J_PASS),
        # Reconnect once a pooled connection has lived for 30 seconds.
        max_connection_lifetime=30,
        # Give up on establishing a connection after 10 seconds.
        connection_timeout=10,
        # Keep at most 10 connections in the pool.
        max_connection_pool_size=10,
    )
    return _neo4j