# NOTE: removed non-code artifacts (Hugging Face Spaces file-listing UI text:
# space status, file size, git blob hashes, and gutter line numbers) that were
# accidentally captured above the module source.
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
class UnifiedQdrant:
    """Manage a Qdrant collection in one of two layouts.

    * Standard collection (local / baseline mode): one implicit shard.
    * Custom-sharded collection (prod mode): one shard key per semantic
      cluster, plus a dedicated "freshness" shard for unclustered points.

    Call ``initialize()`` before any indexing or search method; it connects
    the client and creates the collection if it does not already exist.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 num_clusters: int = 32, freshness_shard_id: int = 999):
        """
        Args:
            collection_name: Name of the Qdrant collection to manage.
            vector_size: Dimensionality of the stored vectors.
            num_clusters: Number of cluster shards (custom-sharding mode).
            freshness_shard_id: Shard key reserved for unclustered points.
        """
        self.client = None
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Fix: previously this attribute was created only inside initialize(),
        # so calling index_data()/search_*()/get_shard_sizes() first raised
        # AttributeError. Default to local mode until initialize() runs.
        self.is_local = True

    def initialize(self, is_baseline: bool = False):
        """
        Connects to Qdrant and sets up the collection.
        If is_baseline=True, creates a standard collection (No Sharding).
        If is_baseline=False, creates a Custom Sharded collection.
        """
        # Connect
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")
        # Relaxed connection settings for HF Spaces
        port = 443 if url.startswith("https") else 6333
        # SECURITY NOTE(review): verify=False (passed through to httpx)
        # disables TLS certificate verification — fine for a demo Space,
        # not for production.
        self.client = QdrantClient(
            location=url,
            port=port,
            api_key=api_key,
            timeout=60,
            check_compatibility=False,
            verify=False
        )
        self.is_local = url == ":memory:" or not url.startswith("http")
        if self.is_local or is_baseline:
            mode = "Local" if self.is_local else "Baseline"
            print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.")
            self.num_clusters = 1
            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping.")
                return
            self._create_standard_collection()
            print(f"Created standard collection '{self.collection_name}'.")
            return
        # Custom Sharding Mode
        if self.client.collection_exists(self.collection_name):
            print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
            return
        # Degradation ladder: full cluster count -> 8 clusters (free-tier
        # shard limits) -> plain unsharded collection.
        try:
            self._create_collection_and_shards(self.num_clusters)
            print(f"Successfully created collection with {self.num_clusters} clusters.")
        except Exception as e:
            print(f"Failed to create {self.num_clusters} clusters: {e}")
            print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
            try:
                self.num_clusters = 8
                self._drop_collection_if_exists()
                self._create_collection_and_shards(self.num_clusters)
                print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
            except Exception as e2:
                print(f"Failed to create 8 clusters: {e2}")
                print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.")
                self.num_clusters = 1
                self._drop_collection_if_exists()
                self._create_standard_collection()
                print("Fallback successful: Created Standard Collection.")

    def _create_standard_collection(self):
        """Create a plain (non-sharded) cosine-distance collection."""
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
        )

    def _drop_collection_if_exists(self):
        """Delete the collection if present (used before fallback re-creation)."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)

    def _create_collection_and_shards(self, n_clusters):
        """Create a custom-sharded collection: one shard key per cluster
        plus the freshness shard key."""
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            sharding_method=models.ShardingMethod.CUSTOM,
            shard_number=n_clusters + 1  # Clusters + Freshness
        )
        # Create Shard Keys
        print("Creating shard keys...")
        for i in range(n_clusters):
            self.client.create_shard_key(self.collection_name, str(i))
        # Create freshness shard key
        self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
        print("Shard keys created successfully.")

    @staticmethod
    def _to_query_vector(query_vec) -> List[float]:
        """Normalize a query vector to the flat Python list Qdrant expects.

        Accepts a 1-D array/list or a (1, dim) 2-D array/list; shared by
        search_hybrid() and search_baseline() (was duplicated inline).
        """
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
        if isinstance(query_vec[0], list):
            query_vec = query_vec[0]
        return query_vec

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]],
                   cluster_ids: Optional[List[Optional[int]]] = None):
        """
        Indexes data with batching to avoid payload limits.
        If cluster_ids provided, uses custom sharding (Prod).
        If cluster_ids is None, uses standard upsert (Baseline/Local).
        BATCH_SIZE hardcoded to 500 for safety.
        """
        BATCH_SIZE = 500
        if cluster_ids is None or self.is_local:
            # Standard Upsert
            points = [
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                ) for i, vec in enumerate(vectors)
            ]
            total = len(points)
            print(f"Upserting {total} points to '{self.collection_name}' (Standard)...")
            for start in range(0, total, BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=points[start:start + BATCH_SIZE]
                )
            return
        # Custom Sharding Upsert: group points by their target shard key.
        data_by_shard: Dict[str, List[Any]] = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            # Points without a cluster go to the dedicated freshness shard.
            key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                )
            )
        print(f"Indexing data across {len(data_by_shard)} shards (Custom Sharded)...")
        for key, shard_points in data_by_shard.items():
            # Batch per shard too — cheap insurance against the 32MB
            # request-payload limit even though per-shard counts are small.
            for start in range(0, len(shard_points), BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=shard_points[start:start + BATCH_SIZE],
                    shard_key_selector=key
                )

    def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int],
                      confidence: float) -> Tuple[List[Any], str]:
        """
        Performs the hybrid search strategy (Prod).

        Note: `confidence` is accepted for interface compatibility but is
        currently unused inside this method.

        Returns:
            (results, search_mode): top-10 points plus a human-readable
            description of the shard-targeting strategy used. (Fix: the
            previous annotation claimed `List[Any]` but a tuple is returned.)
        """
        query_vec = self._to_query_vector(query_vec)
        if not target_clusters:
            # No confident clusters: search all shards.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)"
        if self.is_local:
            # Local in-memory mode has no custom shards; plain query.
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10
            ).points
        return results, search_mode

    def search_baseline(self, query_vec: np.ndarray) -> List[Any]:
        """
        Performs standard search (Baseline).
        """
        query_vec = self._to_query_vector(query_vec)
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vec,
            limit=10
        ).points
        return results

    def get_shard_sizes(self) -> Dict[str, int]:
        """
        Returns a dictionary of {shard_key: count}.
        Only works for Custom Sharding collections.
        """
        if self.is_local:
            return {"local": self.client.count(self.collection_name).count}
        sizes = {}
        # Iterate through expected shard keys
        # We assume keys are "0" to "num_clusters-1" and "freshness_shard_id"
        keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)]
        for key in keys:
            try:
                count = self.client.count(
                    collection_name=self.collection_name,
                    shard_key_selector=key
                ).count
                sizes[key] = count
            except Exception:
                # Fix: was a bare `except:` — keep the best-effort behavior
                # but stop swallowing SystemExit/KeyboardInterrupt.
                sizes[key] = 0
        return sizes