Spaces:
Running
Running
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Thin wrapper around QdrantClient that supports both a standard
    collection (local / baseline) and a custom-sharded collection with one
    shard per cluster plus a dedicated "freshness" shard.
    """

    def __init__(self, collection_name: str, vector_size: int, num_clusters: int = 32, freshness_shard_id: int = 999):
        """
        Args:
            collection_name: Name of the Qdrant collection to manage.
            vector_size: Dimensionality of the stored vectors.
            num_clusters: Number of cluster shards in custom-sharding mode
                (may be lowered by initialize() fallbacks).
            freshness_shard_id: Shard key reserved for fresh/unclustered points.
        """
        self.client = None  # created by initialize()
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Fix: index_data()/search_hybrid()/get_shard_sizes() read
        # self.is_local, which was only ever assigned inside initialize();
        # calling them first raised AttributeError. Default matches the
        # default QDRANT_URL of ":memory:" (local).
        self.is_local = True
| def initialize(self, is_baseline: bool = False): | |
| """ | |
| Connects to Qdrant and sets up the collection. | |
| If is_baseline=True, creates a standard collection (No Sharding). | |
| If is_baseline=False, creates a Custom Sharded collection. | |
| """ | |
| # Connect | |
| url = os.getenv("QDRANT_URL", ":memory:") | |
| api_key = os.getenv("QDRANT_API_KEY", None) | |
| print(f"Connecting to Qdrant at {url}...") | |
| # Relaxed connection settings for HF Spaces | |
| port = 443 if url.startswith("https") else 6333 | |
| self.client = QdrantClient( | |
| location=url, | |
| port=port, | |
| api_key=api_key, | |
| timeout=60, | |
| check_compatibility=False, | |
| verify=False # Passed to httpx | |
| ) | |
| self.is_local = url == ":memory:" or not url.startswith("http") | |
| if self.is_local or is_baseline: | |
| mode = "Local" if self.is_local else "Baseline" | |
| print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.") | |
| self.num_clusters = 1 | |
| if self.client.collection_exists(self.collection_name): | |
| print(f"Collection '{self.collection_name}' already exists. Skipping.") | |
| return | |
| self.client.create_collection( | |
| collection_name=self.collection_name, | |
| vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE) | |
| ) | |
| print(f"Created standard collection '{self.collection_name}'.") | |
| else: | |
| # Custom Sharding Mode | |
| if self.client.collection_exists(self.collection_name): | |
| print(f"Collection '{self.collection_name}' already exists. Skipping initialization.") | |
| return | |
| # Try to create collection with full clusters | |
| try: | |
| self._create_collection_and_shards(self.num_clusters) | |
| print(f"Successfully created collection with {self.num_clusters} clusters.") | |
| except Exception as e: | |
| print(f"Failed to create {self.num_clusters} clusters: {e}") | |
| print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...") | |
| try: | |
| self.num_clusters = 8 | |
| if self.client.collection_exists(self.collection_name): | |
| self.client.delete_collection(self.collection_name) | |
| self._create_collection_and_shards(self.num_clusters) | |
| print(f"Fallback successful: Created collection with {self.num_clusters} clusters.") | |
| except Exception as e2: | |
| print(f"Failed to create 8 clusters: {e2}") | |
| print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.") | |
| self.num_clusters = 1 | |
| if self.client.collection_exists(self.collection_name): | |
| self.client.delete_collection(self.collection_name) | |
| self.client.create_collection( | |
| collection_name=self.collection_name, | |
| vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE) | |
| ) | |
| print("Fallback successful: Created Standard Collection.") | |
| def _create_collection_and_shards(self, n_clusters): | |
| print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...") | |
| self.client.create_collection( | |
| collection_name=self.collection_name, | |
| vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE), | |
| sharding_method=models.ShardingMethod.CUSTOM, | |
| shard_number=n_clusters + 1 # Clusters + Freshness | |
| ) | |
| # Create Shard Keys | |
| print("Creating shard keys...") | |
| for i in range(n_clusters): | |
| self.client.create_shard_key(self.collection_name, str(i)) | |
| # Create freshness shard key | |
| self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id)) | |
| print("Shard keys created successfully.") | |
| def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]] = None): | |
| """ | |
| Indexes data with batching to avoid payload limits. | |
| If cluster_ids provided, uses custom sharding (Prod). | |
| If cluster_ids is None, uses standard upsert (Baseline/Local). | |
| BATCH_SIZE hardcoded to 500 for safety. | |
| """ | |
| BATCH_SIZE = 500 | |
| if cluster_ids is None or self.is_local: | |
| # Standard Upsert | |
| points = [ | |
| models.PointStruct( | |
| id=str(uuid.uuid4()), | |
| vector=vec.tolist(), | |
| payload=payloads[i] | |
| ) for i, vec in enumerate(vectors) | |
| ] | |
| # Batching | |
| total = len(points) | |
| print(f"Upserting {total} points to '{self.collection_name}' (Standard)...") | |
| for i in range(0, total, BATCH_SIZE): | |
| batch = points[i : i + BATCH_SIZE] | |
| self.client.upsert( | |
| collection_name=self.collection_name, | |
| points=batch | |
| ) | |
| return | |
| # Custom Sharding Upsert | |
| data_by_shard = {} | |
| for i, vec in enumerate(vectors): | |
| cluster_id = cluster_ids[i] | |
| key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id) | |
| if key not in data_by_shard: | |
| data_by_shard[key] = [] | |
| data_by_shard[key].append( | |
| models.PointStruct( | |
| id=str(uuid.uuid4()), | |
| vector=vec.tolist(), | |
| payload=payloads[i] | |
| ) | |
| ) | |
| print(f"Indexing data across {len(data_by_shard)} shards (Custom Sharded)...") | |
| for key, shard_points in data_by_shard.items(): | |
| # Also batch per shard if needed (though unlikely to exceed 32MB per shard with 25k samples) | |
| # 25k samples / 32 shards ~= 800 points per shard. 800 * 8KB << 32MB. | |
| # But safe is safe. | |
| total_shard = len(shard_points) | |
| for i in range(0, total_shard, BATCH_SIZE): | |
| batch = shard_points[i : i + BATCH_SIZE] | |
| self.client.upsert( | |
| collection_name=self.collection_name, | |
| points=batch, | |
| shard_key_selector=key | |
| ) | |
| def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int], confidence: float) -> List[Any]: | |
| """ | |
| Performs the hybrid search strategy (Prod). | |
| """ | |
| # Ensure query_vec is list | |
| if isinstance(query_vec, np.ndarray): | |
| query_vec = query_vec.tolist() | |
| if isinstance(query_vec[0], list): | |
| query_vec = query_vec[0] | |
| shard_keys = [] | |
| if not target_clusters: | |
| shard_keys = None | |
| search_mode = "GLOBAL" | |
| else: | |
| shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)] | |
| search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)" | |
| if self.is_local: | |
| results = self.client.query_points( | |
| collection_name=self.collection_name, | |
| query=query_vec, | |
| limit=10 | |
| ).points | |
| else: | |
| results = self.client.query_points( | |
| collection_name=self.collection_name, | |
| query=query_vec, | |
| shard_key_selector=shard_keys, | |
| limit=10 | |
| ).points | |
| return results, search_mode | |
| def search_baseline(self, query_vec: np.ndarray) -> List[Any]: | |
| """ | |
| Performs standard search (Baseline). | |
| """ | |
| if isinstance(query_vec, np.ndarray): | |
| query_vec = query_vec.tolist() | |
| if isinstance(query_vec[0], list): | |
| query_vec = query_vec[0] | |
| results = self.client.query_points( | |
| collection_name=self.collection_name, | |
| query=query_vec, | |
| limit=10 | |
| ).points | |
| return results | |
| def get_shard_sizes(self) -> Dict[str, int]: | |
| """ | |
| Returns a dictionary of {shard_key: count}. | |
| Only works for Custom Sharding collections. | |
| """ | |
| if self.is_local: | |
| return {"local": self.client.count(self.collection_name).count} | |
| sizes = {} | |
| # Iterate through expected shard keys | |
| # We assume keys are "0" to "num_clusters-1" and "freshness_shard_id" | |
| keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)] | |
| for key in keys: | |
| try: | |
| count = self.client.count( | |
| collection_name=self.collection_name, | |
| shard_key_selector=key | |
| ).count | |
| sizes[key] = count | |
| except: | |
| sizes[key] = 0 | |
| return sizes | |