|
|
import logging
import os
import sys
import time
import uuid
from dataclasses import dataclass, asdict, field
from typing import Dict, List, Any, Optional, Tuple

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
|
# Load environment variables from a local .env file before anything reads them.
load_dotenv()

# Make the project root importable so `src.*` and `logger.*` resolve when this
# module is executed directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.document_processor import DocumentChunk, ChunkMetadata

# Prefer the project's custom logger; fall back to the stdlib logger when the
# custom package is not importable (e.g. running outside the full project tree).
try:
    from logger.custom_logger import CustomLoggerTracker

    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("vector_store")
except ImportError:
    # NOTE(review): no handlers or level are configured on this fallback
    # logger — confirm the application configures logging elsewhere.
    logger = logging.getLogger("vector_store")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SearchResult:
    """Result of vector similarity search.

    Attributes:
        chunk: The matched document chunk (embedding is typically None,
            since search does not return vectors).
        similarity_score: Raw similarity score reported by the vector store.
        rerank_score: Optional score assigned by a downstream reranker.
        metadata: Free-form extra information about the match (e.g. the
            Qdrant point id, indexing timestamps).
    """
    chunk: DocumentChunk
    similarity_score: float
    rerank_score: Optional[float] = None
    # Use default_factory rather than `= None` so the declared type matches
    # the default and each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Preserve the old contract: callers that explicitly pass
        # metadata=None still end up with an empty dict.
        if self.metadata is None:
            self.metadata = {}
|
|
|
|
|
|
|
|
@dataclass
class IndexStats:
    """Statistics about the vector index."""
    # Number of points (chunks) stored in the collection.
    total_points: int
    # Name of the Qdrant collection these stats describe.
    collection_name: str
    # Dimensionality of the stored vectors.
    vector_size: int
    # Name of the distance metric (e.g. "Cosine").
    distance_metric: str
    # Count of distinct document_ids found in the collection payloads.
    indexed_documents: int
    # Human-readable timestamp of when these stats were gathered.
    last_updated: str
|
|
|
|
|
|
|
|
class QdrantVectorStore:
    """Qdrant-backed vector store for DocumentChunk embeddings.

    Handles collection lifecycle, chunk indexing, filtered similarity
    search, per-document deletion and collection statistics against a
    Qdrant server.
    """

    def __init__(self, config: Dict[str, Any]):
        """Connect to Qdrant and ensure the target collection exists.

        Args:
            config: Settings dict; recognized keys are 'qdrant_url',
                'qdrant_api_key', 'qdrant_collection' and 'vector_size'.

        Raises:
            Exception: Propagated when the collection cannot be
                created or verified.
        """
        self.config = config
        self.url = config.get('qdrant_url', 'http://localhost:6333')
        self.api_key = config.get('qdrant_api_key')
        self.collection_name = config.get('qdrant_collection', 'manufacturing_docs')
        self.vector_size = config.get('vector_size', 1024)
        self.distance_metric = Distance.COSINE

        # SECURITY FIX: connect with the configured endpoint/credentials.
        # The previous code hard-coded a cloud URL and API key in source
        # (a credential leak) and logged os.environ['QDRANT_URL'], which
        # raises KeyError when the variable is unset.
        logger.info(f"Connecting to Qdrant at URL: {self.url}")
        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
        )

        self._ensure_collection_exists()
        logger.info(f"Qdrant vector store initialized: {self.url}, collection: {self.collection_name}")

    def _ensure_collection_exists(self):
        """Create the collection (with payload indexes) if it is missing.

        Raises:
            Exception: Propagated after logging when Qdrant cannot be
                queried or the collection cannot be created.
        """
        try:
            collections = self.client.get_collections()
            existing = {col.name for col in collections.collections}

            if self.collection_name not in existing:
                logger.info(f"Creating collection: {self.collection_name}")
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.vector_size,
                        distance=self.distance_metric
                    )
                )
                self._create_payload_indexes()
                logger.info(f"Collection {self.collection_name} created successfully")
            else:
                logger.debug(f"Collection {self.collection_name} already exists")

        except Exception as e:
            logger.error(f"Failed to ensure collection exists: {e}")
            raise

    def _create_payload_indexes(self):
        """Create payload indexes used by filtered search (best-effort).

        Uses PayloadSchemaType instead of bare KeywordIndexParams() /
        IntegerIndexParams() — those model classes require a `type`
        field, so constructing them without arguments fails validation
        and no index was ever created. Each index is created
        independently so one failure does not abort the others.
        """
        index_specs = [
            ("document_id", models.PayloadSchemaType.KEYWORD),
            ("document_type", models.PayloadSchemaType.KEYWORD),
            ("page_number", models.PayloadSchemaType.INTEGER),
            ("worksheet_name", models.PayloadSchemaType.KEYWORD),
        ]
        for field_name, schema in index_specs:
            try:
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema)
            except Exception as e:
                # Index creation is an optimization; log and continue.
                logger.warning(f"Failed to create payload index for {field_name}: {e}")
        logger.debug("Payload indexes created successfully")

    def add_documents(self, chunks: List[DocumentChunk]) -> bool:
        """Index chunk embeddings into the collection.

        Args:
            chunks: Chunks to index; chunks without an embedding are
                skipped with a warning.

        Returns:
            True on success (including the empty / all-skipped no-op
            case), False when the upsert fails.
        """
        if not chunks:
            logger.warning("No chunks provided for indexing")
            return True
        try:
            points = []
            for chunk in chunks:
                if not chunk.embedding:
                    logger.warning(f"Chunk {chunk.metadata.chunk_id} has no embedding, skipping")
                    continue

                points.append(PointStruct(
                    # Random point id: re-indexing the same chunk creates a
                    # duplicate point rather than overwriting it.
                    id=str(uuid.uuid4()),
                    vector=chunk.embedding,
                    payload={
                        # Identity
                        "chunk_id": chunk.metadata.chunk_id,
                        "document_id": chunk.metadata.document_id,
                        "chunk_index": chunk.metadata.chunk_index,
                        "content": chunk.content,
                        # Location within the source document
                        "page_number": chunk.metadata.page_number,
                        "worksheet_name": chunk.metadata.worksheet_name,
                        "cell_range": chunk.metadata.cell_range,
                        "section_title": chunk.metadata.section_title,
                        # Cross-references
                        "image_references": chunk.metadata.image_references,
                        "table_references": chunk.metadata.table_references,
                        # Provenance
                        "extraction_timestamp": chunk.metadata.extraction_timestamp.isoformat(),
                        "confidence_score": chunk.metadata.confidence_score,
                        # Index bookkeeping
                        "content_length": len(chunk.content),
                        "indexed_at": time.time()
                    }
                ))

            if not points:
                logger.warning("No valid points to index")
                return True

            self.client.upsert(
                collection_name=self.collection_name,
                points=points)
            logger.info(f"Successfully indexed {len(points)} chunks to Qdrant")
            return True

        except Exception as e:
            logger.error(f"Failed to add documents to vector store: {e}")
            return False

    def similarity_search(self, query_embedding: List[float], k: int = 10,
                          filters: Optional[Dict[str, Any]] = None) -> List[SearchResult]:
        """Return the k most similar chunks to a query embedding.

        Args:
            query_embedding: Query vector (must match the collection size).
            k: Maximum number of results.
            filters: Optional filter dict (see _build_filter_conditions).

        Returns:
            SearchResult list ordered by similarity; empty list on error.
        """
        try:
            filter_conditions = self._build_filter_conditions(filters) if filters else None

            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=k,
                query_filter=filter_conditions,
                with_payload=True,
                with_vectors=False
            )

            results = [self._to_search_result(hit) for hit in search_results]
            logger.debug(f"Found {len(results)} similar chunks")
            return results

        except Exception as e:
            logger.error(f"Similarity search failed: {e}")
            return []

    def _to_search_result(self, hit) -> SearchResult:
        """Convert one Qdrant scored point into a SearchResult."""
        payload = hit.payload

        # NOTE(review): extraction_timestamp is not restored onto
        # ChunkMetadata here — confirm the dataclass has a default for it.
        metadata = ChunkMetadata(
            chunk_id=payload.get("chunk_id", ""),
            document_id=payload.get("document_id", ""),
            chunk_index=payload.get("chunk_index", 0),
            page_number=payload.get("page_number"),
            worksheet_name=payload.get("worksheet_name"),
            cell_range=payload.get("cell_range"),
            section_title=payload.get("section_title"),
            image_references=payload.get("image_references", []),
            table_references=payload.get("table_references", []),
            confidence_score=payload.get("confidence_score"))

        # Vectors were not requested (with_vectors=False), so embedding is None.
        chunk = DocumentChunk(
            content=payload.get("content", ""),
            metadata=metadata,
            embedding=None
        )

        return SearchResult(
            chunk=chunk,
            similarity_score=hit.score,
            metadata={
                "qdrant_id": hit.id,
                "content_length": payload.get("content_length", 0),
                "indexed_at": payload.get("indexed_at"),
                "extraction_timestamp": payload.get("extraction_timestamp")
            }
        )

    def filtered_search(self, query_embedding: List[float], filters: Dict[str, Any],
                        k: int = 10) -> List[SearchResult]:
        """Convenience wrapper: similarity search with mandatory filters."""
        return self.similarity_search(query_embedding, k, filters)

    def delete_document(self, document_id: str) -> bool:
        """Delete every chunk belonging to a document.

        Args:
            document_id: Value of the `document_id` payload field to match.

        Returns:
            True on success, False when the delete request fails.
        """
        try:
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=models.FilterSelector(
                    filter=Filter(
                        must=[
                            FieldCondition(
                                key="document_id",
                                match=MatchValue(value=document_id)
                            )
                        ]
                    )
                )
            )
            logger.info(f"Deleted all chunks for document: {document_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to delete document {document_id}: {e}")
            return False

    def get_collection_info(self) -> Optional[IndexStats]:
        """Return statistics about the collection, or None on failure.

        The distinct-document count comes from scrolling payloads with a
        10000-point cap, so it is approximate for larger collections.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)

            # scroll() returns (points, next_page_offset).
            points, _next_offset = self.client.scroll(
                collection_name=self.collection_name,
                limit=10000,
                with_payload=["document_id"],
                with_vectors=False
            )

            unique_documents = {
                point.payload["document_id"]
                for point in points
                if point.payload and "document_id" in point.payload
            }

            return IndexStats(
                total_points=collection_info.points_count,
                collection_name=self.collection_name,
                vector_size=collection_info.config.params.vectors.size,
                distance_metric=collection_info.config.params.vectors.distance.name,
                indexed_documents=len(unique_documents),
                last_updated=time.strftime("%Y-%m-%d %H:%M:%S")
            )

        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return None

    def _build_filter_conditions(self, filters: Dict[str, Any]) -> Optional[Filter]:
        """Build Qdrant filter conditions from a filter dictionary.

        Supported keys: document_id, document_type, page_number and
        worksheet_name (exact match); min_content_length /
        max_content_length (range over the stored content_length).
        Unrecognized keys are ignored.

        Args:
            filters: Dictionary of filter conditions.

        Returns:
            A Filter AND-combining all conditions, or None when no
            supported key is present (annotation fixed: the original
            declared `-> Filter` but could return None).
        """
        conditions = []

        # Exact-match filters share one shape; build them in a loop.
        for key in ("document_id", "document_type", "page_number", "worksheet_name"):
            if key in filters:
                conditions.append(
                    FieldCondition(
                        key=key,
                        match=MatchValue(value=filters[key])
                    )
                )

        # Range filters over the indexed content length.
        if "min_content_length" in filters:
            conditions.append(
                FieldCondition(
                    key="content_length",
                    range=models.Range(gte=filters["min_content_length"])
                )
            )
        if "max_content_length" in filters:
            conditions.append(
                FieldCondition(
                    key="content_length",
                    range=models.Range(lte=filters["max_content_length"])
                )
            )

        return Filter(must=conditions) if conditions else None

    def health_check(self) -> bool:
        """Check if the vector store is healthy and accessible.

        Returns:
            True if the collection can be fetched, False otherwise.
        """
        try:
            self.client.get_collection(self.collection_name)
            return True
        except Exception as e:
            logger.error(f"Vector store health check failed: {e}")
            return False

    def create_collection(self, vector_size: int, distance_metric: Distance = Distance.COSINE) -> bool:
        """Create the collection with an explicit vector configuration.

        Args:
            vector_size: Dimensionality for stored vectors.
            distance_metric: Similarity metric (default cosine).

        Returns:
            True on success, False when creation fails (e.g. it exists).
        """
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=vector_size,
                    distance=distance_metric
                )
            )

            # Keep local state in sync with what was actually created.
            self.vector_size = vector_size
            self.distance_metric = distance_metric

            self._create_payload_indexes()
            logger.info(f"Created collection {self.collection_name} with vector size {vector_size}")
            return True
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            return False

    def delete_collection(self) -> bool:
        """Drop the collection and all points in it.

        Returns:
            True on success, False when deletion fails.
        """
        try:
            self.client.delete_collection(self.collection_name)
            logger.info(f"Deleted collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a store from environment configuration and report
    # whether the Qdrant server is reachable.
    logger.info("Vector store init ..")  # plain string: no placeholder needed
    config = {
        'qdrant_url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
        'qdrant_api_key': os.getenv('QDRANT_API_KEY'),
        'qdrant_collection': 'manufacturing_docs',
        'vector_size': 1024
    }
    vector_store = QdrantVectorStore(config)
    if vector_store.health_check():
        logger.info("Vector store is healthy and ready.")
    else:
        logger.error("Vector store is not accessible.")