Spaces:
Sleeping
Sleeping
File size: 3,640 Bytes
97f9138 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from typing import List, Dict, Optional
import uuid
from src.config import config
from src.embeddings import embedding_service
class VectorStore:
    """Thin wrapper around one Qdrant collection holding text-chunk embeddings.

    The Qdrant client is created lazily: every public operation calls
    :meth:`connect` first, so callers never have to connect explicitly.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        """Record connection settings; no connection is opened here.

        Args:
            host: Qdrant host. Falls back to ``config.QDRANT_HOST`` when falsy.
            port: Qdrant port. Falls back to ``config.QDRANT_PORT`` when falsy.
            collection_name: Collection to operate on. Falls back to
                ``config.COLLECTION_NAME`` when falsy.
        """
        self.host = host or config.QDRANT_HOST
        self.port = port or config.QDRANT_PORT
        self.collection_name = collection_name or config.COLLECTION_NAME
        self.client = None

    def connect(self) -> None:
        """Create the Qdrant client on first use and ensure the collection exists.

        When ``config.QDRANT_URL`` is set it takes precedence (together with
        ``config.QDRANT_API_KEY``); otherwise ``self.host`` / ``self.port`` are
        used. Does nothing if a client has already been created.
        """
        if self.client is not None:
            return

        if config.QDRANT_URL:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
        else:
            self.client = QdrantClient(host=self.host, port=self.port)

        try:
            self._ensure_collection()
        except Exception:
            # Forget the client so the next connect() retries the collection
            # check instead of assuming it already succeeded.
            self.client = None
            raise

    def _ensure_collection(self) -> None:
        """Create ``self.collection_name`` if the server does not list it yet.

        The vector size comes from ``embedding_service.get_embedding_dimension()``
        and the distance metric is cosine.
        """
        existing = {col.name for col in self.client.get_collections().collections}
        if self.collection_name in existing:
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=embedding_service.get_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
        print(f"Created collection: {self.collection_name}")

    def store_chunks(
        self,
        url_id: str,
        url: str,
        chunks: List[Dict],
        embeddings: List[List[float]],
    ) -> int:
        """Upsert one point per (chunk, embedding) pair and return the count.

        Args:
            url_id: Identifier of the source URL, stored in each payload.
            url: The source URL itself, stored in each payload.
            chunks: Dicts providing the keys ``id``, ``text``, ``start_word``
                and ``end_word``.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Returns:
            Number of points written (0 when there is nothing to store).

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length;
                pairing them anyway would silently drop data.
            KeyError: If a chunk lacks one of the required keys.
        """
        chunk_count = len(chunks)
        if chunk_count != len(embeddings):
            raise ValueError(
                "chunks and embeddings must have the same length "
                f"(got {chunk_count} and {len(embeddings)})"
            )
        if chunk_count == 0:
            # Nothing to send; skip the connection and the upsert round trip.
            return 0

        self.connect()
        points = [
            PointStruct(
                # Random ids: storing the same chunk twice creates two points.
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "url_id": url_id,
                    "url": url,
                    "chunk_id": chunk["id"],
                    "text": chunk["text"],
                    "start_word": chunk["start_word"],
                    "end_word": chunk["end_word"],
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        return len(points)

    def search(self, query_embedding: List[float], top_k: Optional[int] = None) -> List[Dict]:
        """Return the nearest stored chunks for ``query_embedding``.

        Args:
            query_embedding: Query vector.
            top_k: Maximum number of hits. Falls back to
                ``config.TOP_K_RESULTS`` when falsy (``None`` or ``0``).

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``url``,
            ``url_id``, ``text`` and ``chunk_id``; payload fields that are
            missing come back as ``None``.
        """
        self.connect()
        limit = top_k or config.TOP_K_RESULTS
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )

        formatted = []
        for hit in hits:
            # A point may come back without a payload; treat that as empty.
            payload = hit.payload or {}
            formatted.append(
                {
                    "id": hit.id,
                    "score": hit.score,
                    "url": payload.get("url"),
                    "url_id": payload.get("url_id"),
                    "text": payload.get("text"),
                    "chunk_id": payload.get("chunk_id"),
                }
            )
        return formatted

    def delete_by_url_id(self, url_id: str) -> None:
        """Delete every point whose payload ``url_id`` equals ``url_id``."""
        # Imported here so the typed selector models are in scope regardless of
        # the module-level import list.
        from qdrant_client.models import (
            FieldCondition,
            Filter,
            FilterSelector,
            MatchValue,
        )

        self.connect()
        # Use the typed selector models rather than a raw dict, which the
        # client's points_selector argument is not declared to accept.
        selector = FilterSelector(
            filter=Filter(
                must=[
                    FieldCondition(key="url_id", match=MatchValue(value=url_id)),
                ]
            )
        )
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=selector,
        )
# Shared module-level instance built from the config defaults. Construction
# only records settings (the client starts as None); the Qdrant client is
# created later, on the first connect() call.
vector_store = VectorStore()
|