Spaces:
Configuration error
Configuration error
File size: 6,288 Bytes
bec06d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import logging
import os
import sys
import uuid
from typing import List, Dict, Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models

# Add the current directory to the path so we can import config
sys.path.insert(0, os.path.dirname(__file__))
from config import QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME
logger = logging.getLogger(__name__)
class VectorStore:
    """
    A class to handle vector storage and retrieval using Qdrant.

    Wraps a :class:`qdrant_client.QdrantClient` and operates on the single
    collection named by ``COLLECTION_NAME`` from ``config``.
    """

    def __init__(self):
        # Use gRPC only when an API key is configured (a managed/remote
        # Qdrant deployment); fall back to plain HTTP for local instances.
        if QDRANT_API_KEY:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                prefer_grpc=True,
            )
        else:
            self.client = QdrantClient(url=QDRANT_URL)

    def create_collection(self, vector_size: int = 1536):
        """Create a collection in Qdrant if it doesn't exist.

        Args:
            vector_size: Dimensionality of the embedding vectors
                (default 1536 — the common OpenAI embedding size).

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            # Check if collection exists before creating, so repeated
            # startup calls are idempotent.
            collections = self.client.get_collections().collections
            if not any(col.name == COLLECTION_NAME for col in collections):
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE,
                    ),
                )
                logger.info(f"Created collection: {COLLECTION_NAME}")
            else:
                logger.info(f"Collection {COLLECTION_NAME} already exists")
        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents with embeddings to the collection.

        Args:
            documents: Dicts with keys ``content`` (str), ``embedding``
                (list of floats) and ``metadata`` (dict; ``source``,
                ``file_name``, ``file_path`` and optionally ``chunk_id`` /
                ``total_chunks`` are copied into the point payload).

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            points = []
            for doc in documents:
                # Generate a unique ID for each document chunk.
                point_id = str(uuid.uuid4())
                content = doc.get('content', '')
                embedding = doc.get('embedding', [])
                metadata = doc.get('metadata', {})
                # Payload carries the text and provenance metadata so that
                # search hits can be displayed without a second lookup.
                payload = {
                    "content": content,
                    "source": metadata.get('source', ''),
                    "file_name": metadata.get('file_name', ''),
                    "file_path": metadata.get('file_path', ''),
                }
                # Optional chunking metadata, only when the loader set it.
                if 'chunk_id' in metadata:
                    payload['chunk_id'] = metadata['chunk_id']
                if 'total_chunks' in metadata:
                    payload['total_chunks'] = metadata['total_chunks']
                points.append(
                    models.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload,
                    )
                )
            # upload_points batches internally and retries on failure.
            self.client.upload_points(
                collection_name=COLLECTION_NAME,
                points=points,
            )
            logger.info(f"Added {len(points)} documents to collection {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            raise

    def delete_collection(self):
        """Delete the collection if it exists.

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            self.client.delete_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Deleted collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error deleting collection: {str(e)}")
            raise

    def delete_documents_by_source(self, source: str):
        """Delete documents that match a specific source.

        Pages through ALL matching points (a single scroll call is capped by
        its ``limit``, which previously left points behind when more than
        10000 matched), then deletes them by ID.

        Args:
            source: Value of the ``source`` payload field to match exactly.

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            source_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=source),
                    )
                ]
            )
            # Collect IDs across every scroll page; scroll() returns
            # (records, next_page_offset), with offset None on the last page.
            point_ids = []
            offset = None
            while True:
                records, offset = self.client.scroll(
                    collection_name=COLLECTION_NAME,
                    scroll_filter=source_filter,
                    limit=10000,
                    offset=offset,
                )
                point_ids.extend(record.id for record in records)
                if offset is None:
                    break
            if point_ids:
                # Delete the points
                self.client.delete(
                    collection_name=COLLECTION_NAME,
                    points_selector=models.PointIdsList(
                        points=point_ids
                    )
                )
                logger.info(f"Deleted {len(point_ids)} documents from source: {source}")
            else:
                logger.info(f"No documents found from source: {source}")
        except Exception as e:
            logger.error(f"Error deleting documents by source: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents based on embedding.

        Args:
            query_embedding: Query vector; must match the collection's
                configured vector size.
            top_k: Maximum number of hits to return.

        Returns:
            A list of dicts with ``content``, ``source``, ``score`` and
            ``id`` keys, best match first. Empty list on any error
            (deliberate best-effort: search failures must not crash callers).
        """
        try:
            results = self.client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_embedding,
                limit=top_k,
            )
            hits = []
            for hit in results:
                # payload can be None when payloads are excluded server-side.
                payload = hit.payload or {}
                hits.append({
                    'content': payload.get('content', ''),
                    'source': payload.get('source', ''),
                    'score': hit.score,
                    'id': hit.id,
                })
            return hits
        except Exception as e:
            logger.error(f"Error searching for similar documents: {str(e)}")
            return []

    def get_all_documents_count(self) -> int:
        """Get the total number of documents in the collection.

        Returns:
            The point count, or 0 on error or when the server reports no
            count (``points_count`` may be None on some client versions).
        """
        try:
            info = self.client.get_collection(collection_name=COLLECTION_NAME)
            return info.points_count or 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return 0