# NOTE: extracted from a running Hugging Face Space; the "Spaces:" / "Running"
# page chrome that preceded the source has been removed.
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Qdrant wrapper that routes vectors into per-cluster custom shards.

    Points are distributed across ``num_clusters`` shards keyed by their
    cluster id, plus one dedicated "freshness" shard (``freshness_shard_id``)
    for points that have no cluster assignment yet. Searches either target a
    single cluster shard plus the freshness shard, or fall back to a global
    search when routing confidence is low.

    Custom sharding is a Qdrant *server* feature; when the client runs in
    local/in-memory mode the class degrades to a single standard collection
    and simply skips all shard-key plumbing.
    """

    def __init__(self, collection_name: str, vector_size: int, num_clusters: int = 32, freshness_shard_id: int = 999):
        """Store configuration only; no connection is made until initialize().

        Args:
            collection_name: Name of the Qdrant collection to manage.
            vector_size: Dimensionality of the stored vectors.
            num_clusters: Number of cluster shards to create (may be lowered
                by initialize() if the server rejects the shard count).
            freshness_shard_id: Shard-key id reserved for unclustered points.
        """
        self.client = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Fix: default until initialize() inspects QDRANT_URL, so methods
        # called before initialize() fail cleanly instead of raising
        # AttributeError on a missing attribute.
        self.is_local = True

    def initialize(self):
        """
        Connects to Qdrant and sets up the collection with Custom Sharding.
        Handles fallback if Free Tier limits are hit.

        Fallback ladder: full ``num_clusters`` shards -> 8 shards -> a plain
        unsharded collection (``num_clusters`` becomes 1, clusters are then
        virtual only). ``self.num_clusters`` is mutated to reflect the level
        that actually succeeded.
        """
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")
        self.client = QdrantClient(location=url, api_key=api_key, timeout=60)
        # ":memory:" or a bare path means the embedded local engine, which
        # does not support sharding_method=CUSTOM.
        self.is_local = url == ":memory:" or not url.startswith("http")
        if self.is_local:
            print("WARNING: Running in local/memory mode. Custom Sharding is NOT supported. Simulating behavior.")
        # Start from a clean slate on every initialize().
        self._drop_collection_if_exists()
        try:
            self._create_collection_and_shards(self.num_clusters)
            print(f"Successfully created collection with {self.num_clusters} clusters.")
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
            # Fallback 1: 8 Clusters
            try:
                self.num_clusters = 8
                self._drop_collection_if_exists()
                self._create_collection_and_shards(self.num_clusters)
                print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            except Exception as e2:
                print(f"Failed to create 8 clusters: {e2}")
                print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection (No Sharding).")
                # Fallback 2: Standard Collection (deliberately NOT wrapped in
                # try/except — if even this fails, let the error propagate).
                self.num_clusters = 1  # Virtual clusters only
                self._drop_collection_if_exists()
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
                )
                print("Fallback successful: Created Standard Collection.")

    def _drop_collection_if_exists(self):
        """Delete the managed collection if present (idempotent clean-slate helper)."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)

    def _create_collection_and_shards(self, n_clusters):
        """Create the collection and, on a real server, one shard key per cluster
        plus the freshness shard key.

        Args:
            n_clusters: Number of cluster shards (the freshness shard is extra).

        Raises:
            Whatever the qdrant client raises if the server rejects the shard
            count — initialize() relies on this to drive its fallback ladder.
        """
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        if self.is_local:
            # Local mode doesn't support sharding_method=CUSTOM
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
            )
        else:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
                sharding_method=models.ShardingMethod.CUSTOM,
                shard_number=n_clusters + 1  # Clusters + Freshness
            )
            # CRITICAL: with CUSTOM sharding, shard keys must be created
            # explicitly before any upsert can target them.
            print("Creating shard keys...")
            for i in range(n_clusters):
                self.client.create_shard_key(self.collection_name, str(i))
            # Create freshness shard key
            self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
            print("Shard keys created successfully.")

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]]):
        """
        Indexes data into the specific shards based on cluster_ids.
        If cluster_id is None, it goes to the Freshness Shard.

        Points are grouped by shard key first so each shard gets exactly one
        upsert call (qdrant's shard_key_selector applies per request).

        Args:
            vectors: 2D array of shape (n_points, vector_size).
            payloads: One payload dict per vector.
            cluster_ids: One cluster id (or None) per vector.

        Raises:
            ValueError: If the three inputs do not have matching lengths
                (previously this surfaced as a confusing IndexError mid-batch).
        """
        if not (len(vectors) == len(payloads) == len(cluster_ids)):
            raise ValueError(
                f"Length mismatch: {len(vectors)} vectors, "
                f"{len(payloads)} payloads, {len(cluster_ids)} cluster_ids"
            )
        data_by_shard: Dict[str, List[models.PointStruct]] = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            # None routes to the freshness shard; every point gets a fresh UUID.
            key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                )
            )
        # Upsert batches, one request per shard key.
        print(f"Indexing data across {len(data_by_shard)} shards...")
        for key, batch_points in data_by_shard.items():
            if self.is_local:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points
                    # No shard_key_selector in local
                )
            else:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                    shard_key_selector=key
                )

    def search_hybrid(self, query_vec: np.ndarray, target_cluster: int, confidence: float) -> Tuple[List[Any], str]:
        """
        Performs the hybrid search strategy.
        - Always include FRESHNESS_SHARD_ID.
        - If confidence < 0.5, Global Search (all shards).
        - Else, search [target_cluster, FRESHNESS_SHARD_ID].

        Args:
            query_vec: Query vector (1D, or 2D with a single row).
            target_cluster: Cluster id to target when confidence is high.
            confidence: Routing confidence in [0, 1].

        Returns:
            Tuple of (scored points, human-readable search-mode label).
            NOTE: the original annotation claimed ``List[Any]``; the method has
            always returned this 2-tuple — the annotation is now correct.
        """
        # Ensure query_vec is list
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
        if isinstance(query_vec[0], list):  # Handle 2D array if passed
            query_vec = query_vec[0]
        if confidence < 0.5:
            # Global Search: omitting shard_key_selector (None) searches all
            # shards in qdrant's custom-sharding mode.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            # Targeted Search: the chosen cluster shard plus the freshness shard.
            shard_keys = [str(target_cluster), str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Cluster {target_cluster} + Freshness)"
        if self.is_local:
            # Local mode has no shards, so the selector is simply dropped.
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10
            ).points
        return results, search_mode