# Page residue from the scraped source view: "Spaces: Sleeping", file size 3,368 bytes, revision 36bfe21.
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from typing import List, Dict, Any, Optional
from app.config import get_settings
from app.models.document import SearchResult
class QdrantDB:
    """Qdrant vector database client.

    Wraps a single Qdrant collection used to store and search document
    chunks. Connection details and the collection name are read from
    ``get_settings()`` (``qdrant_url``, ``qdrant_api_key``,
    ``qdrant_collection_name``).
    """

    def __init__(self):
        self.settings = get_settings()
        self.client = QdrantClient(
            url=self.settings.qdrant_url,
            api_key=self.settings.qdrant_api_key,
        )
        # Every method below operates on this one collection.
        self.collection_name = self.settings.qdrant_collection_name

    def create_collection(self, vector_size: int = 1024):
        """Create the collection if it doesn't exist.

        Args:
            vector_size: Length of the vectors the collection will hold; it
                must equal the length of the vectors later passed to
                ``upsert_chunks`` and ``search``. Only used when the
                collection is actually created; an existing collection is
                left untouched.
        """
        # Ask explicitly rather than treating *any* error from get_collection()
        # as "missing": a network/auth failure must surface as itself instead
        # of triggering a create attempt that fails with a misleading error.
        if self.client.collection_exists(self.collection_name):
            print(f"Collection '{self.collection_name}' already exists")
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=vector_size,
                distance=Distance.COSINE,
            ),
        )
        print(f"Created collection '{self.collection_name}'")

    def upsert_chunks(self, chunks: List[Dict[str, Any]], vectors: List[List[float]]):
        """Insert or update document chunks with their embeddings.

        Args:
            chunks: Chunk dicts. Each must contain ``'chunk_id'`` (used as the
                point ID); the whole dict is stored as the point payload.
            vectors: Embeddings, where ``vectors[i]`` belongs to ``chunks[i]``.

        Raises:
            ValueError: If ``chunks`` and ``vectors`` differ in length.
        """
        # zip() would silently drop the unmatched tail, losing data unnoticed.
        if len(chunks) != len(vectors):
            raise ValueError(
                f"upsert_chunks got {len(chunks)} chunks but {len(vectors)} vectors; "
                "they must pair up one-to-one"
            )
        if not chunks:
            return  # nothing to write; skip the round-trip to the server
        # NOTE(review): Qdrant point IDs must be unsigned integers or UUIDs —
        # confirm 'chunk_id' values are generated in one of those forms.
        points = [
            PointStruct(id=chunk['chunk_id'], vector=vector, payload=chunk)
            for chunk, vector in zip(chunks, vectors)
        ]
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )

    def search(
        self,
        query_vector: List[float],
        limit: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchResult]:
        """Search for chunks similar to ``query_vector``.

        Args:
            query_vector: Query embedding.
            limit: Maximum number of results to return.
            filters: Optional restrictions. Only the ``'chapter'`` key is
                recognised; it must equal the ``chapter_number`` payload
                field. Any other keys are ignored.

        Returns:
            One ``SearchResult`` per matching point, in the order returned by
            Qdrant.

        Raises:
            KeyError: If a returned point's payload lacks one of the fields
                copied into ``SearchResult``.
        """
        search_filter = self._build_filter(filters)
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=limit,
            query_filter=search_filter,
            with_payload=True,  # payload fields are required to build SearchResult
        ).points
        return [self._to_search_result(point) for point in results]

    @staticmethod
    def _build_filter(filters: Optional[Dict[str, Any]]) -> Optional[Filter]:
        """Translate the ``filters`` dict accepted by ``search`` into a Qdrant Filter, or None."""
        if not filters:
            return None
        conditions = []
        if 'chapter' in filters:
            conditions.append(
                FieldCondition(
                    key="chapter_number",
                    match=MatchValue(value=filters['chapter']),
                )
            )
        # An empty must-list would match everything anyway; send no filter instead.
        return Filter(must=conditions) if conditions else None

    @staticmethod
    def _to_search_result(point) -> SearchResult:
        """Build a SearchResult from one scored point's payload and score."""
        payload = point.payload
        return SearchResult(
            chunk_id=payload['chunk_id'],
            chapter_number=payload['chapter_number'],
            chapter_title=payload['chapter_title'],
            section_title=payload['section_title'],
            content=payload['content'],
            content_type=payload['content_type'],
            url=payload['url'],
            score=point.score,
        )

    def get_collection_info(self) -> Any:
        """Get information about the collection.

        Returns whatever ``QdrantClient.get_collection`` returns (a
        ``CollectionInfo`` model in qdrant-client), not a plain dict.
        """
        return self.client.get_collection(self.collection_name)