from typing import Any, Dict, List, Optional

from loguru import logger
from sentence_transformers import SparseEncoder

from ..src.core.config import ModelConfig


class SparseEmbeddingModel:
    """Sparse embedding model wrapper around ``sentence_transformers.SparseEncoder``.

    Loads the encoder lazily on first use and converts the sparse torch
    tensors it produces into plain ``{'indices': [...], 'values': [...]}``
    dictionaries.

    Attributes:
        config: ModelConfig instance describing the model to load.
        model: SparseEncoder instance, or None until load() has run.
        _loaded: Flag indicating whether the model has been loaded.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SparseEncoder] = None
        self._loaded = False

    def load(self) -> None:
        """Load the sparse embedding model (idempotent).

        Raises:
            Exception: Re-raised from SparseEncoder construction on failure,
                after logging the error.
        """
        if self._loaded:
            return
        logger.info(f"Loading sparse model: {self.config.name}")
        try:
            self.model = SparseEncoder(self.config.name)
            self._loaded = True
            logger.success(f"Loaded sparse model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load sparse model {self.config.id}: {e}")
            raise

    @staticmethod
    def _tensor_to_sparse_dict(tensor: Any) -> Dict[str, Any]:
        """Convert a 1-D sparse torch tensor to an indices/values dict.

        Coalesces once (the original code coalesced the same tensor twice
        per extraction) and returns plain Python lists.
        """
        coalesced = tensor.coalesce()
        return {
            "indices": coalesced.indices()[0].tolist(),
            "values": coalesced.values().tolist(),
        }

    def query_embed(
        self, text: List[str], prompt: Optional[str] = None
    ) -> Dict[Any, Any]:
        """Generate a sparse query embedding.

        Note: *text* is a list, but only the embedding of its first element
        is returned (matching the original behavior).

        Args:
            text: Input texts; only the first result is returned.
            prompt: Optional prompt for instruction-based models.

        Returns:
            Sparse embedding as a dict with 'indices' and 'values' keys.

        Raises:
            Exception: Re-raised from encoding after logging the error.
        """
        if not self._loaded:
            self.load()
        try:
            # Fix: `prompt` was accepted but never forwarded to the encoder.
            tensor = self.model.encode_query(text, prompt=prompt)
            return self._tensor_to_sparse_dict(tensor[0])
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise

    def embed_documents(
        self, text: List[str], prompt: Optional[str] = None
    ) -> Dict[Any, Any]:
        """Generate a sparse document embedding.

        Note: *text* is a list, but only the embedding of its first element
        is returned (matching the original behavior).

        Args:
            text: Input texts; only the first result is returned.
            prompt: Optional prompt for instruction-based models.

        Returns:
            Sparse embedding as a dict with 'indices' and 'values' keys.

        Raises:
            Exception: Re-raised from encoding after logging the error.
        """
        # Fix: this method was missing the lazy-load guard present in
        # query_embed/embed_batch, so calling it first hit self.model = None.
        if not self._loaded:
            self.load()
        try:
            # Fix: forward `prompt` to the encoder (was silently dropped).
            tensor = self.model.encode(text, prompt=prompt)
            return self._tensor_to_sparse_dict(tensor[0])
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise

    def embed_batch(
        self, texts: List[str], prompt: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Generate sparse embeddings for a batch of texts.

        Args:
            texts: List of input texts.
            prompt: Optional prompt for instruction-based models.

        Returns:
            List of dicts with 'text' and 'sparse_embedding' keys, in the
            same order as *texts*.

        Raises:
            Exception: Re-raised from encoding after logging the error.
        """
        if not self._loaded:
            self.load()
        try:
            # Fix: forward `prompt` to the encoder (was silently dropped).
            tensors = self.model.encode(texts, prompt=prompt)
            return [
                {
                    "text": texts[i],
                    "sparse_embedding": self._tensor_to_sparse_dict(tensor),
                }
                for i, tensor in enumerate(tensors)
            ]
        except Exception as e:
            logger.error(f"Sparse embedding generation failed: {e}")
            raise