Spaces:
Running
Running
File size: 9,348 Bytes
b92d96d 3621168 b92d96d 3621168 b92d96d 3621168 b92d96d 3621168 b92d96d 3621168 b92d96d 3621168 b92d96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Qdrant wrapper that routes vectors to per-cluster custom shards plus a
    dedicated "freshness" shard for unclustered points.

    Shard layout (remote mode only):
      - one custom shard key per cluster id: "0" .. str(num_clusters - 1)
      - one extra shard keyed by str(freshness_shard_id) for points whose
        cluster id is None
    Local / ":memory:" mode does not support custom sharding, so the class
    degrades to a single standard collection and ignores shard selectors.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 num_clusters: int = 32, freshness_shard_id: int = 999):
        # Client is created lazily in initialize(); methods called before
        # that fail fast on self.client being None.
        self.client: Optional[QdrantClient] = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Defined here (not only in initialize()) so early method calls
        # don't raise a confusing AttributeError on is_local.
        self.is_local = True

    def initialize(self):
        """
        Connects to Qdrant and sets up the collection with Custom Sharding.
        Handles fallback if Free Tier limits are hit.

        Reads QDRANT_URL (default ":memory:") and QDRANT_API_KEY from the
        environment. Fallback chain: num_clusters shards -> 8 shards ->
        standard (unsharded) collection.
        """
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")
        # Relaxed connection settings for HF Spaces.
        # NOTE(review): verify=False disables TLS certificate verification —
        # acceptable for a demo Space, but confirm before production use.
        port = 443 if url.startswith("https") else 6333
        self.client = QdrantClient(
            location=url,
            port=port,
            api_key=api_key,
            timeout=60,
            check_compatibility=False,
            verify=False,  # passed through to httpx
        )
        self.is_local = url == ":memory:" or not url.startswith("http")
        if self.is_local:
            print("Running in local/memory mode. Custom Sharding is NOT supported. Simulating behavior.")
            self.num_clusters = 1
            self._recreate_standard_collection()
            print(f"Created standard collection '{self.collection_name}'.")
            return
        # Remote: check existence first to avoid accidental deletion.
        if self.client.collection_exists(self.collection_name):
            print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
            return
        # Try to create the collection with the full cluster count.
        try:
            self._create_collection_and_shards(self.num_clusters)
            print(f"Successfully created collection with {self.num_clusters} clusters.")
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
            # Fallback 1: 8 clusters.
            try:
                self.num_clusters = 8
                if self.client.collection_exists(self.collection_name):
                    self.client.delete_collection(self.collection_name)
                self._create_collection_and_shards(self.num_clusters)
                print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            except Exception as e2:
                print(f"Failed to create 8 clusters: {e2}")
                print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection (No Sharding).")
                # Fallback 2: standard collection, no sharding.
                self.num_clusters = 1
                self._recreate_standard_collection()
                print("Fallback successful: Created Standard Collection.")

    def _recreate_standard_collection(self):
        """Drop the collection if it exists, then create it without sharding."""
        if self.client.collection_exists(collection_name=self.collection_name):
            self.client.delete_collection(collection_name=self.collection_name)
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
        )

    def _create_collection_and_shards(self, n_clusters):
        """Create the collection (custom-sharded when remote) and its shard keys.

        Remote mode allocates n_clusters + 1 shards: one per cluster plus the
        freshness shard. Local mode cannot use sharding_method=CUSTOM, so a
        plain collection is created instead.
        """
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        if self.client.collection_exists(self.collection_name):
            print(f"Collection '{self.collection_name}' already exists. Skipping creation.")
            return
        if self.is_local:
            # Local mode doesn't support sharding_method=CUSTOM.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            )
        else:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
                sharding_method=models.ShardingMethod.CUSTOM,
                shard_number=n_clusters + 1,  # clusters + freshness
            )
        # CRITICAL: shard keys must be created explicitly for custom sharding.
        if not self.is_local:
            print("Creating shard keys...")
            for i in range(n_clusters):
                self.client.create_shard_key(self.collection_name, str(i))
            # Freshness shard key for unclustered (new) points.
            self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
            print("Shard keys created successfully.")

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: List[Optional[int]]):
        """
        Indexes data into the specific shards based on cluster_ids.
        If cluster_id is None, it goes to the Freshness Shard.

        vectors, payloads and cluster_ids are expected to be parallel and of
        equal length (one payload and cluster id per vector).
        """
        # Batch by shard key: upsert takes one shard_key_selector per call,
        # so grouping minimizes network round-trips.
        data_by_shard: Dict[str, List[models.PointStruct]] = {}
        for vec, payload, cluster_id in zip(vectors, payloads, cluster_ids):
            key = str(self.freshness_shard_id if cluster_id is None else cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payload,
                )
            )
        print(f"Indexing data across {len(data_by_shard)} shards...")
        for key, batch_points in data_by_shard.items():
            if self.is_local:
                # No shard_key_selector in local mode.
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                )
            else:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=batch_points,
                    shard_key_selector=key,
                )

    def search_hybrid(self, query_vec: np.ndarray, target_cluster: int, confidence: float) -> Tuple[List[Any], str]:
        """
        Performs the hybrid search strategy.
        - Always include FRESHNESS_SHARD_ID.
        - If confidence < 0.5, Global Search (all shards).
        - Else, search [target_cluster, FRESHNESS_SHARD_ID].

        Returns:
            (results, search_mode): the top-10 matched points and a
            human-readable description of which shards were searched.
            (Fixed annotation: the original declared List[Any] but the
            method has always returned this 2-tuple.)
        """
        # Normalize the query to a flat Python list.
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
        if isinstance(query_vec[0], list):  # handle a 2D array being passed
            query_vec = query_vec[0]
        if confidence < 0.5:
            # Global search: omitting shard_key_selector searches all shards.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            # Targeted search: the predicted cluster plus the freshness shard.
            shard_keys = [str(target_cluster), str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Cluster {target_cluster} + Freshness)"
        if self.is_local:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10,
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10,
            ).points
        return results, search_mode
|