Spaces:
Sleeping
Sleeping
| """Vector store management and operations.""" | |
| from pathlib import Path | |
| from typing import Dict, Any, List, Optional | |
| import torch | |
| from langchain_qdrant import QdrantVectorStore | |
| from langchain.docstore.document import Document | |
| from langchain_core.embeddings import Embeddings | |
| from sentence_transformers import SentenceTransformer | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
class MatryoshkaEmbeddings(Embeddings):
    """Embeddings wrapper that supports Matryoshka dimension truncation.

    When the model name contains "matryoshka" AND an explicit
    ``truncate_dim`` is given, a ``SentenceTransformer`` is loaded directly
    so that embedding vectors can be truncated to the requested dimension.
    Every other configuration falls back to the standard
    ``HuggingFaceEmbeddings``.
    """

    def __init__(self, model_name: str, truncate_dim: Optional[int] = None, **kwargs):
        """
        Initialize Matryoshka embeddings.

        Args:
            model_name: Name of the model.
            truncate_dim: Dimension to truncate to (for Matryoshka models).
            **kwargs: Extra arguments forwarded to ``HuggingFaceEmbeddings``
                (ignored when the Matryoshka path is taken).
        """
        self.model_name = model_name
        self.truncate_dim = truncate_dim
        if self._uses_truncation:
            # SentenceTransformer is used directly because it natively
            # supports truncate_dim for Matryoshka models.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
            print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
        else:
            # Standard (non-truncating) embedding path.
            self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    @property
    def _uses_truncation(self) -> bool:
        """True when the Matryoshka/SentenceTransformer code path applies.

        Centralizes the check that was previously duplicated in
        ``__init__``, ``embed_documents`` and ``embed_query``.
        """
        return bool(self.truncate_dim) and "matryoshka" in self.model_name.lower()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents, truncating dimensions when configured."""
        if self._uses_truncation:
            embeddings = self.model.encode(texts, normalize_embeddings=True)
            return embeddings.tolist()
        return self.model.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string, truncating dimensions when configured."""
        if self._uses_truncation:
            embedding = self.model.encode([text], normalize_embeddings=True)
            return embedding[0].tolist()
        return self.model.embed_query(text)
class VectorStoreManager:
    """Manages vector store operations and connections."""

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize vector store manager.

        Args:
            config: Configuration dictionary; expected to contain
                "retriever" (model settings) and "qdrant" (connection
                settings) sections.
        """
        self.config = config
        self.embeddings = self._create_embeddings()
        self.vectorstore: Optional[QdrantVectorStore] = None
        # Metadata fields that need payload indexes for filtering to work
        # (especially in Qdrant Cloud).
        self.metadata_fields = [
            ("metadata.year", "keyword"),
            ("metadata.source", "keyword"),
            ("metadata.filename", "keyword"),
            # Add more metadata fields as needed
        ]

    def _create_embeddings(self) -> Embeddings:
        """Create embeddings model from configuration.

        Returns:
            ``MatryoshkaEmbeddings`` for Matryoshka models (truncated to 768
            dimensions only for the known
            "modernbert-embed-base-akryl-matryoshka" collection), otherwise
            a standard ``HuggingFaceEmbeddings``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_name = self.config["retriever"]["model"]
        normalize = self.config["retriever"]["normalize"]
        model_kwargs = {"device": device}
        encode_kwargs = {
            "normalize_embeddings": normalize,
            "batch_size": 100,
        }
        if "matryoshka" in model_name.lower():
            collection_name = self.config.get("qdrant", {}).get("collection_name", "")
            # BUGFIX: truncate_dim was previously unbound (NameError) when
            # the collection name did not match; default to None so the
            # wrapper falls back to the standard embedding path.
            truncate_dim = None
            if "modernbert-embed-base-akryl-matryoshka" in collection_name:
                # This collection expects 768 dimensions.
                truncate_dim = 768
                print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
            return MatryoshkaEmbeddings(
                model_name=model_name,
                truncate_dim=truncate_dim,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs,
                show_progress=True,
            )
        # Standard HuggingFaceEmbeddings for non-Matryoshka models.
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            show_progress=True,
        )

    def ensure_metadata_indexes(self) -> None:
        """
        Create payload indexes for all required metadata fields.

        This ensures filtering works properly, especially in Qdrant Cloud.
        No-op when no vector store connection exists.
        """
        if not self.vectorstore:
            return
        qdrant_config = self.config["qdrant"]
        collection_name = qdrant_config["collection_name"]
        for field_name, field_type in self.metadata_fields:
            try:
                self.vectorstore.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_type=field_type,
                )
                print(f"Created payload index for {field_name} ({field_type})")
            except Exception as e:
                # Best-effort: index may already exist or the server may
                # reject it — log and continue rather than aborting setup.
                print(f"Index creation for {field_name} ({field_type}): {str(e)}")

    def _qdrant_connection_kwargs(self) -> Dict[str, Any]:
        """Build the shared Qdrant connection kwargs from configuration."""
        qdrant_config = self.config["qdrant"]
        return {
            "url": qdrant_config["url"],
            "collection_name": qdrant_config["collection_name"],
            "prefer_grpc": qdrant_config.get("prefer_grpc", True),
            "api_key": qdrant_config.get("api_key", None),
        }

    def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
        """
        Connect to existing Qdrant collection.

        Args:
            force_recreate: If True, recreate the collection if dimension
                mismatch occurs.

        Returns:
            QdrantVectorStore instance
        """
        kwargs_qdrant = self._qdrant_connection_kwargs()
        if force_recreate:
            kwargs_qdrant["force_recreate"] = True
        self.vectorstore = QdrantVectorStore.from_existing_collection(
            embedding=self.embeddings,
            **kwargs_qdrant,
        )
        # Ensure payload indexes exist for metadata filtering.
        self.ensure_metadata_indexes()
        return self.vectorstore

    def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
        """
        Create new Qdrant collection from documents.

        Args:
            documents: List of Document objects

        Returns:
            QdrantVectorStore instance
        """
        self.vectorstore = QdrantVectorStore.from_documents(
            documents=documents,
            embedding=self.embeddings,
            **self._qdrant_connection_kwargs(),
        )
        # Ensure payload indexes exist for metadata filtering.
        self.ensure_metadata_indexes()
        return self.vectorstore

    def delete_collection(self) -> Optional[QdrantVectorStore]:
        """
        Delete the current Qdrant collection.

        Returns:
            The (now stale) QdrantVectorStore instance, or None when no
            vector store connection exists.
        """
        # Guard: previously this raised AttributeError when called before
        # any connection was established.
        if not self.vectorstore:
            return None
        collection_name = self.config["qdrant"].get("collection_name")
        self.vectorstore.client.delete_collection(
            collection_name=collection_name,
        )
        return self.vectorstore

    def get_vectorstore(self) -> Optional[QdrantVectorStore]:
        """Get current vectorstore instance (None before any connection)."""
        return self.vectorstore
def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
    """
    Get local Qdrant vector store (legacy function for compatibility).

    Thin wrapper that delegates to :class:`VectorStoreManager`.

    Args:
        config: Configuration dictionary

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).connect_to_existing()
def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
    """
    Create new vector store from documents.

    Thin wrapper that delegates to :class:`VectorStoreManager`.

    Args:
        config: Configuration dictionary
        documents: List of Document objects

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).create_from_documents(documents)
def get_embeddings_model(config: Dict[str, Any]) -> Embeddings:
    """
    Create embeddings model from configuration (legacy function).

    Args:
        config: Configuration dictionary

    Returns:
        Embeddings instance — a ``HuggingFaceEmbeddings`` or, for Matryoshka
        models, a ``MatryoshkaEmbeddings``. (The previous
        ``HuggingFaceEmbeddings`` annotation was wrong for the Matryoshka
        case.)
    """
    return VectorStoreManager(config).embeddings