import os
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
import numpy as np
from typing import List, Optional, Dict, Any, Tuple
import uuid


class UnifiedQdrant:
    def __init__(
        self,
        collection_name: str,
        vector_size: int,
        num_clusters: int = 32,
        freshness_shard_id: int = 999,
    ):
        self.client = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id

    def initialize(self, is_baseline: bool = False):
        """
        Connects to Qdrant and sets up the collection.
        If is_baseline=True, creates a standard collection (no sharding).
        If is_baseline=False, creates a custom-sharded collection.
        """
        # Connect
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")

        # Relaxed connection settings for HF Spaces
        port = 443 if url.startswith("https") else 6333
        self.client = QdrantClient(
            location=url,
            port=port,
            api_key=api_key,
            timeout=60,
            check_compatibility=False,
            verify=False,  # Passed through to httpx
        )
        self.is_local = url == ":memory:" or not url.startswith("http")

        if self.is_local or is_baseline:
            mode = "Local" if self.is_local else "Baseline"
            print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.")
            self.num_clusters = 1

            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping.")
                return

            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            )
            print(f"Created standard collection '{self.collection_name}'.")
        else:
            # Custom Sharding Mode
            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
                return

            # Try to create the collection with the full number of clusters
            try:
                self._create_collection_and_shards(self.num_clusters)
                print(f"Successfully created collection with {self.num_clusters} clusters.")
            except Exception as e:
                print(f"Failed to create {self.num_clusters} clusters: {e}")
                print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
                try:
                    self.num_clusters = 8
                    if self.client.collection_exists(self.collection_name):
                        self.client.delete_collection(self.collection_name)
                    self._create_collection_and_shards(self.num_clusters)
                    print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
                except Exception as e2:
                    print(f"Failed to create 8 clusters: {e2}")
                    print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.")
                    self.num_clusters = 1
                    if self.client.collection_exists(self.collection_name):
                        self.client.delete_collection(self.collection_name)
                    self.client.create_collection(
                        collection_name=self.collection_name,
                        vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
                    )
                    print("Fallback successful: Created Standard Collection.")

    def _create_collection_and_shards(self, n_clusters):
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            sharding_method=models.ShardingMethod.CUSTOM,
            shard_number=n_clusters + 1,  # Clusters + freshness shard
        )

        # Create shard keys: one per cluster, plus the dedicated freshness shard
        print("Creating shard keys...")
        for i in range(n_clusters):
            self.client.create_shard_key(self.collection_name, str(i))
        self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
        print("Shard keys created successfully.")

    def index_data(
        self,
        vectors: np.ndarray,
        payloads: List[Dict[str, Any]],
        cluster_ids: Optional[List[Optional[int]]] = None,
    ):
        """
        Indexes data with batching to avoid payload limits.
        If cluster_ids is provided, uses custom sharding (Prod).
        If cluster_ids is None, uses a standard upsert (Baseline/Local).
        BATCH_SIZE is hardcoded to 500 for safety.
        """
        BATCH_SIZE = 500

        if cluster_ids is None or self.is_local:
            # Standard upsert
            points = [
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i],
                )
                for i, vec in enumerate(vectors)
            ]

            # Batching
            total = len(points)
            print(f"Upserting {total} points to '{self.collection_name}' (Standard)...")
            for i in range(0, total, BATCH_SIZE):
                batch = points[i : i + BATCH_SIZE]
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch,
                )
            return

        # Custom Sharding Upsert: group points by their target shard key
        data_by_shard = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
            if key not in data_by_shard:
                data_by_shard[key] = []
            data_by_shard[key].append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i],
                )
            )

        print(f"Indexing data across {len(data_by_shard)} shards (Custom Sharded)...")
        for key, shard_points in data_by_shard.items():
            # Also batch per shard if needed. Unlikely to exceed 32 MB per shard with 25k samples
            # (25k samples / 32 shards ~= 800 points per shard, and 800 * 8 KB << 32 MB), but safe is safe.
            total_shard = len(shard_points)
            for i in range(0, total_shard, BATCH_SIZE):
                batch = shard_points[i : i + BATCH_SIZE]
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch,
                    shard_key_selector=key,
                )

    def search_hybrid(
        self, query_vec: np.ndarray, target_clusters: List[int], confidence: float
    ) -> Tuple[List[Any], str]:
        """
        Performs the hybrid search strategy (Prod).
        """
""" # Ensure query_vec is list if isinstance(query_vec, np.ndarray): query_vec = query_vec.tolist() if isinstance(query_vec[0], list): query_vec = query_vec[0] shard_keys = [] if not target_clusters: shard_keys = None search_mode = "GLOBAL" else: shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)] search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)" if self.is_local: results = self.client.query_points( collection_name=self.collection_name, query=query_vec, limit=10, with_payload=True ).points else: results = self.client.query_points( collection_name=self.collection_name, query=query_vec, shard_key_selector=shard_keys, limit=10, with_payload=True ).points return results, search_mode def search_baseline(self, query_vec: np.ndarray) -> List[Any]: """ Performs standard search (Baseline). """ if isinstance(query_vec, np.ndarray): query_vec = query_vec.tolist() if isinstance(query_vec[0], list): query_vec = query_vec[0] results = self.client.query_points( collection_name=self.collection_name, query=query_vec, limit=10, with_payload=True ).points return results def get_shard_sizes(self) -> Dict[str, int]: """ Returns a dictionary of {shard_key: count}. Only works for Custom Sharding collections. """ if self.is_local: return {"local": self.client.count(self.collection_name).count} sizes = {} # Iterate through expected shard keys # We assume keys are "0" to "num_clusters-1" and "freshness_shard_id" keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)] for key in keys: try: count = self.client.count( collection_name=self.collection_name, shard_key_selector=key ).count sizes[key] = count except: sizes[key] = 0 return sizes