Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from typing import List, Optional | |
| from langchain.embeddings.base import Embeddings | |
| from langchain_qdrant import Qdrant | |
| from langchain.schema import Document | |
| from langchain_openai import OpenAIEmbeddings | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams | |
| from dotenv import load_dotenv | |
| from config import settings | |
# Load variables from a local .env file into the process environment before
# the objects below are constructed.
load_dotenv()

# Single embeddings client shared by every preset collection below.
openai_embeddings = OpenAIEmbeddings(model=settings.EMBEDDINGS_MODEL_NAME)

# Preset parameters per known Qdrant collection, keyed by collection name.
# QdrantManager falls back to these values when the caller does not pass
# embeddings / vector_size / distance explicitly. The presets currently differ
# only by name, so they are generated from the list of names.
QDRANT_COLLECTIONS_PARAMS = {
    preset_name: {
        'collection_name': preset_name,
        'embeddings_model_name': 'text-embedding-3-large',
        'vector_size': 3072,
        'distance': Distance.COSINE,
        'embeddings_model': openai_embeddings,
    }
    for preset_name in (
        'openai_large_chunks_1000char',
        'openai_large_chunks_500char',
    )
}
class QdrantManager:
    """Manage a single Qdrant collection.

    Supports creating the collection, adding documents in batches, exposing
    the underlying LangChain vector store, and deleting the collection.
    To see the collections available on the server, call
    ``manager.client.get_collections()``.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        vector_size: Optional[int] = None,
        distance: Optional[Distance] = None,
    ):
        """Connect to Qdrant and bind a vector store to ``collection_name``.

        Args:
            collection_name: Name of the collection to operate on.
            embeddings: Embeddings model; falls back to the preset for
                ``collection_name`` in ``QDRANT_COLLECTIONS_PARAMS``.
            url: Qdrant URL; falls back to the ``QDRANT_URL`` env variable.
            api_key: Qdrant API key; falls back to ``QDRANT_API_KEY``.
            vector_size: Vector dimension used when creating the collection;
                falls back to the preset for ``collection_name``.
            distance: Distance metric used when creating the collection;
                falls back to the preset for ``collection_name``.

        Raises:
            KeyError: If a preset value is needed but ``collection_name`` is
                not registered in ``QDRANT_COLLECTIONS_PARAMS``.
        """
        self.collection_name = collection_name
        # Presets are consulted lazily, so a collection that is not registered
        # still works as long as all three values are passed explicitly.
        self.embeddings = embeddings or self._preset('embeddings_model')
        self.url = url or os.getenv("QDRANT_URL")
        self.api_key = api_key or os.getenv("QDRANT_API_KEY")
        self.vector_size = vector_size or self._preset('vector_size')
        self.distance = distance or self._preset('distance')
        self.client = QdrantClient(url=self.url, api_key=self.api_key)
        self.qdrant = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            content_payload_key="page_content",
            metadata_payload_key="metadata",
        )

    def _preset(self, key: str):
        """Return the preset ``key`` registered for this collection.

        Raises:
            KeyError: With an actionable message when the collection (or the
                key) is missing from ``QDRANT_COLLECTIONS_PARAMS``.
        """
        try:
            return QDRANT_COLLECTIONS_PARAMS[self.collection_name][key]
        except KeyError:
            raise KeyError(
                f"No preset '{key}' for collection '{self.collection_name}'; "
                f"pass it explicitly or register the collection in "
                f"QDRANT_COLLECTIONS_PARAMS."
            ) from None

    def create_collection(self) -> None:
        """Create the collection if it doesn't exist yet."""
        existing_names = {c.name for c in self.client.get_collections().collections}
        if self.collection_name in existing_names:
            print(f"Collection '{self.collection_name}' already exists.")
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=self.distance),
        )
        print(f"Collection '{self.collection_name}' created.")

    def add_documents(self, documents: List[Document], batch_size: int = 1000) -> None:
        """Add documents to the collection in batches.

        Args:
            documents: Documents to upload.
            batch_size: Maximum number of documents per upload call; must be
                at least 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make range() empty and silently upload nothing
        # while still reporting success; zero would fail with a cryptic
        # range() error. Reject both up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        batch_starts = range(0, len(documents), batch_size)
        for batch_number, start in enumerate(batch_starts, start=1):
            batch = documents[start:start + batch_size]
            self.qdrant.add_documents(batch)
            print(f"Added batch {batch_number} ({len(batch)} documents)")
        print(f"Total documents added: {len(documents)}")

    def get_vectorstore(self):
        """Return the Qdrant vector store bound to this collection."""
        return self.qdrant

    def delete_collection(self) -> None:
        """Delete the collection from the Qdrant server."""
        self.client.delete_collection(self.collection_name)
        print(f"Collection '{self.collection_name}' deleted.")