File size: 9,883 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51fc709
b92d96d
51fc709
 
 
b92d96d
 
 
 
 
3621168
 
 
 
 
 
 
 
 
 
 
b92d96d
 
3621168
51fc709
 
 
3621168
51fc709
 
 
 
 
3621168
 
 
 
 
 
51fc709
3621168
 
 
b92d96d
3621168
b92d96d
 
3621168
 
 
 
 
 
 
 
 
 
 
 
51fc709
3621168
 
 
 
 
 
 
 
 
b92d96d
 
 
 
51fc709
 
 
 
 
 
b92d96d
51fc709
 
 
 
 
 
 
 
b92d96d
51fc709
b92d96d
9a9f1fb
51fc709
 
9a9f1fb
b92d96d
9a9f1fb
 
51fc709
 
 
 
 
 
 
 
 
9a9f1fb
 
 
 
 
 
 
 
 
 
51fc709
 
 
b92d96d
 
 
51fc709
b92d96d
 
 
 
 
 
51fc709
b92d96d
 
 
 
 
9a9f1fb
 
 
 
 
 
 
 
 
 
 
 
 
b92d96d
b9df6ef
b92d96d
51fc709
b92d96d
 
 
 
51fc709
b92d96d
 
 
b9df6ef
b92d96d
 
 
b9df6ef
 
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51fc709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import os
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams
import numpy as np
from typing import List, Optional, Dict, Any
import uuid

class UnifiedQdrant:
    """Vector-store wrapper around QdrantClient.

    Supports two layouts for the same collection name:
      * Standard collection (local / baseline mode) — no sharding.
      * Custom-sharded collection — one shard key per cluster id
        ("0".."num_clusters-1") plus a dedicated "freshness" shard for points
        without a cluster assignment.
    """

    def __init__(self, collection_name: str, vector_size: int, num_clusters: int = 32, freshness_shard_id: int = 999):
        """
        Args:
            collection_name: Name of the Qdrant collection.
            vector_size: Dimensionality of the indexed vectors.
            num_clusters: Requested number of cluster shards; initialize() may
                lower this (to 8, then 1) if the server rejects the count.
            freshness_shard_id: Shard key reserved for unclustered points.
        """
        self.client = None  # set by initialize()
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.num_clusters = num_clusters
        self.freshness_shard_id = freshness_shard_id
        # Fix: was only assigned inside initialize(), so calling index_data/
        # search_*/get_shard_sizes first raised AttributeError. Default to
        # False (remote) until initialize() determines the real mode.
        self.is_local = False

    def initialize(self, is_baseline: bool = False):
        """
        Connects to Qdrant and sets up the collection.
        If is_baseline=True, creates a standard collection (No Sharding).
        If is_baseline=False, creates a Custom Sharded collection.
        """
        # Connect; ":memory:" (the default) runs an in-process instance.
        url = os.getenv("QDRANT_URL", ":memory:")
        api_key = os.getenv("QDRANT_API_KEY", None)
        print(f"Connecting to Qdrant at {url}...")

        # Relaxed connection settings for HF Spaces
        port = 443 if url.startswith("https") else 6333
        self.client = QdrantClient(
            location=url,
            port=port,
            api_key=api_key,
            timeout=60,
            check_compatibility=False,
            verify=False  # Passed to httpx
        )

        self.is_local = url == ":memory:" or not url.startswith("http")

        if self.is_local or is_baseline:
            mode = "Local" if self.is_local else "Baseline"
            print(f"Running in {mode} mode. Creating Standard Collection '{self.collection_name}'.")
            self.num_clusters = 1

            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping.")
                return

            self._create_standard_collection()
            print(f"Created standard collection '{self.collection_name}'.")
        else:
            # Custom Sharding Mode
            if self.client.collection_exists(self.collection_name):
                print(f"Collection '{self.collection_name}' already exists. Skipping initialization.")
                return

            # Try to create collection with full clusters
            try:
                self._create_collection_and_shards(self.num_clusters)
                print(f"Successfully created collection with {self.num_clusters} clusters.")
            except Exception as e:
                print(f"Failed to create {self.num_clusters} clusters: {e}")
                print("Attempting fallback to 8 clusters (Free Tier limit mitigation)...")
                try:
                    self.num_clusters = 8
                    if self.client.collection_exists(self.collection_name):
                        self.client.delete_collection(self.collection_name)
                    self._create_collection_and_shards(self.num_clusters)
                    print(f"Fallback successful: Created collection with {self.num_clusters} clusters.")
                except Exception as e2:
                    print(f"Failed to create 8 clusters: {e2}")
                    print("CRITICAL: Custom Sharding not supported. Falling back to Standard Collection.")
                    self.num_clusters = 1
                    if self.client.collection_exists(self.collection_name):
                        self.client.delete_collection(self.collection_name)

                    self._create_standard_collection()
                    print("Fallback successful: Created Standard Collection.")

    def _create_standard_collection(self):
        """Create a plain (non-sharded) cosine-distance collection."""
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE)
        )

    def _create_collection_and_shards(self, n_clusters):
        """Create a custom-sharded collection plus its shard keys.

        Allocates n_clusters cluster shards (keys "0".."n_clusters-1") and one
        extra shard for freshness data (key str(freshness_shard_id)).
        """
        print(f"Creating collection '{self.collection_name}' with custom sharding ({n_clusters} clusters)...")

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            sharding_method=models.ShardingMethod.CUSTOM,
            shard_number=n_clusters + 1  # Clusters + Freshness
        )

        # Create Shard Keys
        print("Creating shard keys...")
        for i in range(n_clusters):
            self.client.create_shard_key(self.collection_name, str(i))

        # Create freshness shard key
        self.client.create_shard_key(self.collection_name, str(self.freshness_shard_id))
        print("Shard keys created successfully.")

    def index_data(self, vectors: np.ndarray, payloads: List[Dict[str, Any]], cluster_ids: Optional[List[Optional[int]]] = None):
        """
        Indexes data with batching to avoid payload limits.
        If cluster_ids provided, uses custom sharding (Prod).
        If cluster_ids is None, uses standard upsert (Baseline/Local).
        BATCH_SIZE hardcoded to 500 for safety.

        Args:
            vectors: Array of embeddings, one row per point.
            payloads: Per-point payload dicts, parallel to `vectors`.
            cluster_ids: Per-point cluster assignment; a None entry routes that
                point to the freshness shard. The parameter as a whole is
                optional (annotation fixed accordingly).
        """
        BATCH_SIZE = 500

        if cluster_ids is None or self.is_local:
            # Standard Upsert: every point gets a fresh random UUID.
            points = [
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                ) for i, vec in enumerate(vectors)
            ]

            # Batching
            total = len(points)
            print(f"Upserting {total} points to '{self.collection_name}' (Standard)...")
            for i in range(0, total, BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=points[i : i + BATCH_SIZE]
                )
            return

        # Custom Sharding Upsert: bucket points by shard key first.
        data_by_shard: Dict[str, list] = {}
        for i, vec in enumerate(vectors):
            cluster_id = cluster_ids[i]
            key = str(self.freshness_shard_id) if cluster_id is None else str(cluster_id)
            data_by_shard.setdefault(key, []).append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec.tolist(),
                    payload=payloads[i]
                )
            )

        print(f"Indexing data across {len(data_by_shard)} shards (Custom Sharded)...")
        for key, shard_points in data_by_shard.items():
            # Also batch per shard if needed (though unlikely to exceed 32MB per shard with 25k samples)
            # 25k samples / 32 shards ~= 800 points per shard. 800 * 8KB << 32MB.
            # But safe is safe.
            total_shard = len(shard_points)
            for i in range(0, total_shard, BATCH_SIZE):
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=shard_points[i : i + BATCH_SIZE],
                    shard_key_selector=key
                )

    @staticmethod
    def _as_query_list(query_vec):
        """Coerce a numpy vector (possibly shaped (1, dim)) to a flat list.

        Non-array inputs are returned unchanged. Shared by search_hybrid and
        search_baseline, which previously duplicated this logic inline.
        """
        if isinstance(query_vec, np.ndarray):
            query_vec = query_vec.tolist()
            if isinstance(query_vec[0], list):
                query_vec = query_vec[0]
        return query_vec

    def search_hybrid(self, query_vec: np.ndarray, target_clusters: List[int], confidence: float) -> tuple:
        """
        Performs the hybrid search strategy (Prod).

        Note: `confidence` is currently unused here; kept for interface
        compatibility with callers.

        Returns:
            (points, search_mode): top-10 results plus a human-readable string
            describing whether the search was GLOBAL or TARGETED. (Return
            annotation fixed — this method returns a 2-tuple, not a list.)
        """
        query_vec = self._as_query_list(query_vec)

        if not target_clusters:
            # No target clusters: search every shard.
            shard_keys = None
            search_mode = "GLOBAL"
        else:
            # Always include the freshness shard alongside the target clusters.
            shard_keys = [str(c) for c in target_clusters] + [str(self.freshness_shard_id)]
            search_mode = f"TARGETED (Clusters {target_clusters} + Freshness)"

        if self.is_local:
            # Local collections have no shard keys; plain query.
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                limit=10
            ).points
        else:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vec,
                shard_key_selector=shard_keys,
                limit=10
            ).points

        return results, search_mode

    def search_baseline(self, query_vec: np.ndarray) -> List[Any]:
        """
        Performs standard search (Baseline). Returns the top-10 points.
        """
        query_vec = self._as_query_list(query_vec)

        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vec,
            limit=10
        ).points
        return results

    def get_shard_sizes(self) -> Dict[str, int]:
        """
        Returns a dictionary of {shard_key: count}.
        Only works for Custom Sharding collections.
        """
        if self.is_local:
            return {"local": self.client.count(self.collection_name).count}

        sizes = {}
        # Iterate through expected shard keys
        # We assume keys are "0" to "num_clusters-1" and "freshness_shard_id"
        keys = [str(i) for i in range(self.num_clusters)] + [str(self.freshness_shard_id)]

        for key in keys:
            try:
                count = self.client.count(
                    collection_name=self.collection_name,
                    shard_key_selector=key
                ).count
                sizes[key] = count
            except Exception:
                # Fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. A missing/unreachable shard
                # key is deliberately reported as 0 rather than failing.
                sizes[key] = 0
        return sizes