# dashVectorSpace / src/vector_db.py
# Deployed as dashVectorspace v1 (Full), commit b92d96d — raw size 8.29 kB.
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Qdrant wrapper that routes cluster-assigned vectors to custom shards.

    Shard layout: one shard per cluster id (0..num_clusters-1) plus one
    dedicated "freshness" shard (``freshness_shard_id``) for points that have
    no cluster assignment yet.  When custom sharding is unavailable (local /
    ``:memory:`` mode, or Free Tier shard limits) the class degrades step by
    step to a single standard collection and shard keys are ignored.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 num_clusters: int = 32, freshness_shard_id: int = 999):
        """Store configuration only; the connection is made in initialize().

        Args:
            collection_name: Name of the Qdrant collection to (re)create.
            vector_size: Dimensionality of the stored vectors.
            num_clusters: Number of cluster shards to request initially.
            freshness_shard_id: Shard key reserved for unclustered points.
        """
        self.client: Optional[QdrantClient] = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Safe default so index_data()/search_hybrid() fail soft (no shard
        # selectors) instead of raising AttributeError when initialize()
        # has not been called yet.  initialize() overwrites this.
        self.is_local: bool = True

    def initialize(self):
        """
        Connects to Qdrant and sets up the collection with Custom Sharding.
        Handles fallback if Free Tier limits are hit.
        """
        # Connect; :memory: (the default) means the pure-Python local engine.
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")
        self.client = QdrantClient(location=url, api_key=api_key, timeout=60)
        # Anything that is not an http(s) URL is treated as local mode,
        # where sharding_method=CUSTOM is not supported.
        self.is_local = url == ":memory:" or not url.startswith("http")
        if self.is_local:
            print("WARNING: Running in local/memory mode. Custom Sharding is NOT supported. Simulating behavior.")
        # Recreate the collection for a clean slate if it already exists.
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)
        # Attempt 1: full cluster count.
        try:
            self._create_collection_and_shards(self.num_clusters)
            print(f"Successfully created collection with {self.num_clusters} clusters.")
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
            # Fallback 1: 8 clusters (fits common Free Tier shard limits).
            try:
                self.num_clusters = 8
                if self.client.collection_exists(self.collection_name):
                    self.client.delete_collection(self.collection_name)
                self._create_collection_and_shards(self.num_clusters)
                print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            except Exception as e2:
                print(f"Failed to create 8 clusters: {e2}")
                print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection (No Sharding).")
                # Fallback 2: plain collection, clusters become virtual only.
                self.num_clusters = 1  # Virtual clusters only
                if self.client.collection_exists(self.collection_name):
                    self.client.delete_collection(self.collection_name)
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
                )
                print("Fallback successful: Created Standard Collection.")

    def _create_collection_and_shards(self, n_clusters):
        """Create the collection and (remotely) one shard key per cluster.

        Raises whatever the client raises (e.g. shard limits exceeded) so
        initialize() can drive its fallback chain.
        """
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        if self.is_local:
            # Local mode doesn't support sharding_method=CUSTOM.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
            )
        else:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
                sharding_method=models.ShardingMethod.CUSTOM,
                shard_number=n_clusters + 1  # Clusters + Freshness
            )
            # CRITICAL: with CUSTOM sharding, shard keys must be created
            # explicitly before any upsert can target them.
            print("Creating shard keys...")
            for i in range(n_clusters):
                self.client.create_shard_key(self.collection_name, str(i))
            # Create freshness shard key
            self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
            print("Shard keys created successfully.")

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]],
                   cluster_ids: List[Optional[int]]):
        """
        Indexes data into the specific shards based on cluster_ids.
        If cluster_id is None, it goes to the Freshness Shard.

        Args:
            vectors: (n, vector_size) array of embeddings.
            payloads: One payload dict per vector.
            cluster_ids: One cluster id (or None) per vector.

        Raises:
            ValueError: If the three inputs have mismatched lengths.
        """
        # Fail fast instead of mis-indexing via IndexError mid-loop.
        if not (len(vectors) == len(payloads) == len(cluster_ids)):
            raise ValueError(
                "vectors, payloads and cluster_ids must have equal lengths"
            )
        # Batch by shard key: upsert's shard_key_selector applies to the
        # whole batch, so each shard gets exactly one upsert call.
        data_by_shard: Dict[str, List[models.PointStruct]] = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            # None means "not clustered yet" -> freshness shard.
            key = str(self.freshness_shard_id if cluster_id is None else cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                )
            )
        # Upsert batches
        print(f"Indexing data across {len(data_by_shard)} shards...")
        for key, batch_points in data_by_shard.items():
            if self.is_local:
                # No shard_key_selector in local mode — single shard.
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points
                )
            else:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                    shard_key_selector=key
                )

    def search_hybrid(self, query_vec: np.ndarray, target_cluster: int,
                      confidence: float) -> Tuple[List[Any], str]:
        """
        Performs the hybrid search strategy.
        - Always include FRESHNESS_SHARD_ID.
        - If confidence < 0.5, Global Search (all shards).
        - Else, search [target_cluster, FRESHNESS_SHARD_ID].

        Returns:
            (points, search_mode): top-10 scored points and a human-readable
            description of which strategy was used.
        """
        # Normalize the query to a flat Python list.
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
        if isinstance(query_vec[0], list):  # Handle 2D array if passed
            query_vec = query_vec[0]
        if confidence < 0.5:
            # Global search: omitting the selector (None) searches all shards.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            # Targeted search: the chosen cluster plus the freshness shard so
            # newly-ingested (unclustered) points are never missed.
            shard_keys = [str(target_cluster), str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Cluster {target_cluster} + Freshness)"
        if self.is_local:
            # Local mode ignores shard selection — single shard only.
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10
            ).points
        return results, search_mode