import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams


class UnifiedQdrant:
    """Qdrant wrapper that routes points to per-cluster shards plus a "freshness" shard.

    Each cluster id maps to the custom shard key ``str(cluster_id)``; points
    without a cluster id go to the shard key ``str(freshness_shard_id)``.
    When running against a local / in-memory client, or when the server
    rejects custom sharding, the collection is created without shard keys and
    all shard-key selectors are omitted.
    """

    def __init__(self, collection_name: str, vector_size: int, num_clusters: int = 32, freshness_shard_id: int = 999):
        self.client: Optional[QdrantClient] = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Both flags are (re)computed by initialize().
        self.is_local = False
        # True only when the collection was created with custom sharding AND
        # all shard keys exist; index_data/search_hybrid only send
        # shard_key_selector when this is set.
        self.use_sharding = False

    def initialize(self):
        """Connect to Qdrant and (re)create the collection, degrading gracefully.

        Reads ``QDRANT_URL`` (default ``":memory:"``) and ``QDRANT_API_KEY``.
        Any existing collection with the same name is deleted first.

        Tries, in order:
          1. custom sharding with ``self.num_clusters`` cluster shard keys;
          2. custom sharding with 8 cluster shard keys (free-tier mitigation);
          3. a standard collection with no sharding (``num_clusters`` = 1).
        Exceptions from step 3 propagate to the caller.
        """
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")
        self.client = QdrantClient(location=url, api_key=api_key, timeout=60)

        self.is_local = url == ":memory:" or not url.startswith("http")
        if self.is_local:
            print("WARNING: Running in local/memory mode. Custom Sharding is NOT supported. Simulating behavior.")

        # Start from a clean slate.
        self._drop_collection_if_exists()

        # Attempt 1: the requested number of clusters.
        try:
            self._create_collection_and_shards(self.num_clusters)
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
        else:
            print(f"Successfully created collection with {self.num_clusters} clusters.")
            return

        # Attempt 2: 8 clusters. A partially created collection from attempt 1
        # (e.g. collection created but a shard key failed) is dropped first.
        try:
            self.num_clusters = 8
            self._drop_collection_if_exists()
            self._create_collection_and_shards(self.num_clusters)
        except Exception as e2:
            print(f"Failed to create 8 clusters: {e2}")
            print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection (No Sharding).")
        else:
            print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            return

        # Attempt 3: plain collection; clusters become purely virtual and
        # no shard-key selectors may be sent from now on.
        self.num_clusters = 1
        self.use_sharding = False
        self._drop_collection_if_exists()
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
        )
        print("Fallback successful: Created Standard Collection.")

    def _drop_collection_if_exists(self):
        """Delete ``self.collection_name`` if it currently exists."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)

    def _create_collection_and_shards(self, n_clusters):
        """Create the collection and, when remote, one shard key per cluster plus the freshness key.

        Sets ``self.use_sharding`` to True only after every shard key has been
        created; any exception leaves it False.
        """
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        self.use_sharding = False
        vectors_config = VectorParams(size=self.vector_size, distance=Distance.COSINE)

        if self.is_local:
            # Local mode doesn't support sharding_method=CUSTOM.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=vectors_config,
            )
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=vectors_config,
            sharding_method=models.ShardingMethod.CUSTOM,
            # Under custom sharding, shard_number is the number of shards
            # *per shard key*, not the total. One shard per key keeps the
            # total at n_clusters + 1 (clusters + freshness).
            shard_number=1,
        )

        print("Creating shard keys...")
        for i in range(n_clusters):
            self.client.create_shard_key(self.collection_name, str(i))
        self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
        print("Shard keys created successfully.")
        self.use_sharding = True

    def _shard_key(self, cluster_id: Optional[int]) -> str:
        """Return the shard key for *cluster_id*; ``None`` maps to the freshness shard.

        Cluster ids are folded modulo ``self.num_clusters`` so that ids
        produced for the originally requested cluster count still hit an
        existing shard key after initialize() fell back to fewer clusters.
        Indexing and search both use this mapping, so they stay consistent.
        """
        if cluster_id is None:
            return str(self.freshness_shard_id)
        if self.num_clusters > 0:
            return str(int(cluster_id) % self.num_clusters)
        return str(cluster_id)

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]]):
        """Upsert vectors, batched by destination shard key.

        Args:
            vectors: one vector per point (row-wise).
            payloads: payload dict for each point, aligned with ``vectors``.
            cluster_ids: cluster id for each point; ``None`` routes the point
                to the freshness shard.

        Each point gets a fresh random UUID as its id.

        Raises:
            ValueError: if the three inputs do not have the same length.
        """
        if not (len(vectors) == len(payloads) == len(cluster_ids)):
            raise ValueError(
                f"vectors ({len(vectors)}), payloads ({len(payloads)}) and "
                f"cluster_ids ({len(cluster_ids)}) must have the same length"
            )

        # One upsert per shard key: a single upsert call carries one
        # shard_key_selector, so points must be grouped by key.
        data_by_shard: Dict[str, List[models.PointStruct]] = {}
        for vec, payload, cluster_id in zip(vectors, payloads, cluster_ids):
            data_by_shard.setdefault(self._shard_key(cluster_id), []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=np.asarray(vec).tolist(),
                    payload=payload,
                )
            )

        print(f"Indexing data across {len(data_by_shard)} shards...")
        for key, batch_points in data_by_shard.items():
            if self.use_sharding:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                    shard_key_selector=key,
                )
            else:
                # No shard keys exist (local mode or standard-collection fallback).
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                )

    def search_hybrid(
        self,
        query_vec: np.ndarray,
        target_cluster: int,
        confidence: float,
        *,
        limit: int = 10,
        global_threshold: float = 0.5,
    ) -> Tuple[List[Any], str]:
        """Query either all shards or the target cluster's shard plus the freshness shard.

        Args:
            query_vec: query vector; a 2-D input is reduced to its first row.
            target_cluster: cluster id predicted for the query.
            confidence: confidence in ``target_cluster``.
            limit: maximum number of points to return.
            global_threshold: if ``confidence`` is below this, search globally.

        Returns:
            ``(points, search_mode)`` where ``search_mode`` is ``"GLOBAL"`` or
            a ``"TARGETED (...)"`` description.
        """
        vec = np.asarray(query_vec)
        if vec.ndim == 2:
            # Accept a (1, dim) batch as produced by many encoders.
            vec = vec[0]
        query = vec.tolist()

        if confidence < global_threshold:
            # No selector -> Qdrant searches across all shard keys.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            shard_keys = [self._shard_key(target_cluster), str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Cluster {target_cluster} + Freshness)"

        extra: Dict[str, Any] = {}
        if self.use_sharding:
            extra["shard_key_selector"] = shard_keys

        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query,
            limit=limit,
            **extra,
        ).points
        return results, search_mode