# dashVectorSpace / src/vector_db.py
# Uploaded by justmotes via huggingface_hub (commit 5bd11d7, verified)
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Unified wrapper around Qdrant supporting two layouts for one collection.

    Modes:
      * Standard ("Baseline"/"Local"): a single, unsharded collection.
      * Custom Sharded ("Prod"): one shard key per semantic cluster plus a
        dedicated "freshness" shard for points with no cluster assignment.

    Call ``initialize()`` before any other method; it connects the client and
    creates the collection (with graceful fallbacks when custom sharding is
    unavailable, e.g. on free-tier clusters).
    """

    def __init__(self, collection_name: str, vector_size: int, num_clusters: int = 32, freshness_shard_id: int = 999):
        """Store configuration; no connection is made until ``initialize()``.

        Args:
            collection_name: Qdrant collection to create/use.
            vector_size: Dimensionality of stored vectors.
            num_clusters: Number of cluster shards in custom-sharding mode.
            freshness_shard_id: Shard key reserved for unclustered ("fresh") points.
        """
        self.client: Optional[QdrantClient] = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # BUGFIX: previously this attribute was assigned only inside
        # initialize(), so calling index_data()/search_hybrid()/get_shard_sizes()
        # first raised AttributeError. Default to local until initialize()
        # determines the real mode.
        self.is_local = True

    def initialize(self, is_baseline: bool = False):
        """Connect to Qdrant and set up the collection.

        If ``is_baseline`` is True (or the target is local/in-memory), a
        standard collection is created (no sharding). Otherwise a custom
        sharded collection is attempted, falling back to 8 clusters and then
        to a standard collection if the server rejects custom sharding.

        Existing collections are left untouched (idempotent).
        """
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")

        # Relaxed connection settings for HF Spaces.
        port = 443 if url.startswith("https") else 6333
        # SECURITY NOTE(review): verify=False disables TLS certificate
        # verification (passed through to httpx). Acceptable for demos/HF
        # Spaces, but should be re-enabled for production endpoints.
        self.client = QdrantClient(
            location=url,
            port=port,
            api_key=api_key,
            timeout=60,
            check_compatibility=False,
            verify=False  # Passed to httpx
        )

        self.is_local = url == ":memory:" or not url.startswith("http")

        if self.is_local or is_baseline:
            mode = "Local" if self.is_local else "Baseline"
            print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.")
            self.num_clusters = 1
            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping.")
                return
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
            )
            print(f"Created standard collection '{self.collection_name}'.")
            return

        # Custom Sharding Mode
        if self.client.collection_exists(self.collection_name):
            print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
            return

        # Try full cluster count first, then degrade gracefully.
        try:
            self._create_collection_and_shards(self.num_clusters)
            print(f"Successfully created collection with {self.num_clusters} clusters.")
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
            try:
                self.num_clusters = 8
                # A partial collection may have been created; remove it first.
                if self.client.collection_exists(self.collection_name):
                    self.client.delete_collection(self.collection_name)
                self._create_collection_and_shards(self.num_clusters)
                print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            except Exception as e2:
                print(f"Failed to create 8 clusters: {e2}")
                print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.")
                self.num_clusters = 1
                if self.client.collection_exists(self.collection_name):
                    self.client.delete_collection(self.collection_name)
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
                )
                print("Fallback successful: Created Standard Collection.")

    def _create_collection_and_shards(self, n_clusters: int):
        """Create a custom-sharded collection with one key per cluster plus
        the freshness shard. Raises on server rejection (caller handles fallback)."""
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            sharding_method=models.ShardingMethod.CUSTOM,
            shard_number=n_clusters + 1  # Clusters + Freshness
        )
        # Shard keys are stringified cluster ids, plus the freshness key.
        print("Creating shard keys...")
        for i in range(n_clusters):
            self.client.create_shard_key(self.collection_name, str(i))
        self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
        print("Shard keys created successfully.")

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: Optional[List[Optional[int]]] = None):
        """Index vectors with their payloads, batching to avoid payload limits.

        Args:
            vectors: Array of shape (n, vector_size).
            payloads: One payload dict per vector (parallel to ``vectors``).
            cluster_ids: Per-vector cluster id for custom sharding (Prod), or
                None for standard upsert (Baseline/Local). A ``None`` entry
                routes that point to the freshness shard.
        """
        BATCH_SIZE = 500  # conservative cap to stay well under request limits

        if cluster_ids is None or self.is_local:
            # Standard upsert: random UUIDs as point ids, no shard routing.
            points = [
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                ) for i, vec in enumerate(vectors)
            ]
            total = len(points)
            print(f"Upserting {total} points to '{self.collection_name}' (Standard)...")
            for i in range(0, total, BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=points[i : i + BATCH_SIZE]
                )
            return

        # Custom-sharding upsert: group points by target shard key.
        data_by_shard: Dict[str, List[Any]] = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                )
            )

        print(f"Indexing data across {len(data_by_shard)} shards (Custom Sharded)...")
        for key, shard_points in data_by_shard.items():
            # Batch per shard too: unlikely to be needed at ~800 points/shard,
            # but safe is safe.
            for i in range(0, len(shard_points), BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=shard_points[i : i + BATCH_SIZE],
                    shard_key_selector=key
                )

    @staticmethod
    def _as_flat_list(query_vec) -> List[float]:
        """Normalize a query vector (ndarray or possibly-nested list) to a flat list."""
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
        # Unwrap a single-row 2D input like [[...]].
        if isinstance(query_vec[0], list):
            query_vec = query_vec[0]
        return query_vec

    def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int], confidence: float) -> Tuple[List[Any], str]:
        """Hybrid search strategy (Prod).

        Searches only the target cluster shards plus the freshness shard when
        ``target_clusters`` is non-empty; otherwise searches globally.

        Args:
            query_vec: Query embedding (ndarray or list; a single-row 2D input
                is flattened).
            target_clusters: Cluster ids to restrict the search to.
            confidence: Router confidence score. Currently unused here; kept
                for interface stability with callers.

        Returns:
            Tuple of (top-10 scored points, human-readable search-mode label).
            Note: the original annotation claimed ``List[Any]`` but the method
            has always returned this 2-tuple.
        """
        query_vec = self._as_flat_list(query_vec)

        if not target_clusters:
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)"

        if self.is_local:
            # Local/in-memory mode has no shard keys; plain query.
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10,
                with_payload=True
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10,
                with_payload=True
            ).points
        return results, search_mode

    def search_baseline(self, query_vec: np.ndarray) -> List[Any]:
        """Standard top-10 search over the whole collection (Baseline)."""
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=self._as_flat_list(query_vec),
            limit=10,
            with_payload=True
        ).points
        return results

    def get_shard_sizes(self) -> Dict[str, int]:
        """Return a {shard_key: point_count} mapping.

        Only meaningful for custom-sharded collections; in local mode returns
        a single {"local": total} entry. Shard keys are assumed to be
        "0".."num_clusters-1" plus the freshness shard id.
        """
        if self.is_local:
            return {"local": self.client.count(self.collection_name).count}
        sizes: Dict[str, int] = {}
        keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)]
        for key in keys:
            try:
                sizes[key] = self.client.count(
                    collection_name=self.collection_name,
                    shard_key_selector=key
                ).count
            except Exception:
                # Missing/unreachable shard key: report 0 rather than failing
                # the whole report. (Was a bare `except:`, which also caught
                # KeyboardInterrupt/SystemExit.)
                sizes[key] = 0
        return sizes