Spaces:
Running
Running
| from typing import List, Optional | |
| from sentence_transformers import SentenceTransformer | |
| from loguru import logger | |
| from ..src.core.config import ModelConfig | |
class EmbeddingModel:
    """
    Lazy-loading wrapper around a SentenceTransformer for dense embeddings.

    The underlying model is only constructed on the first call to ``load``
    (directly, or indirectly through one of the embed methods).

    Attributes:
        config: ModelConfig instance describing the model to load.
        model: SentenceTransformer instance, or None until ``load`` succeeds.
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """
        Load the embedding model if it has not been loaded yet.

        Calling this more than once is harmless: once ``_loaded`` is set the
        method returns immediately.

        Raises:
            Exception: Any error raised while constructing the
                SentenceTransformer is logged and re-raised unchanged.
        """
        if self._loaded:
            return

        # NOTE(review): the start message logs ``config.name`` while the
        # success/failure messages log ``config.id`` - confirm both fields
        # exist on ModelConfig and that the difference is intentional.
        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # NOTE(security): trust_remote_code=True lets the model repository
            # run its own Python code at load time; only use trusted models.
            # NOTE(review): the device is hard-coded to "cpu" - confirm that
            # GPU inference is intentionally unsupported here.
            self.model = SentenceTransformer(
                self.config.name, device="cpu", trust_remote_code=True
            )
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

        self._loaded = True
        logger.success(f"Loaded embedding model: {self.config.id}")

    def _encode(
        self, encoder_name: str, inputs: List[str], prompt: Optional[str]
    ) -> List[List[float]]:
        """Load the model if needed, call the named encoder, and convert each result to a list."""
        if not self._loaded:
            self.load()
        try:
            # Attribute lookup stays inside the try so that a model object
            # lacking the encoder method is logged like any other failure.
            encode = getattr(self.model, encoder_name)
            embeddings = encode(inputs, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def query_embed(
        self, text: List[str], prompt: Optional[str] = None
    ) -> List[List[float]]:
        """
        Generate query embeddings via ``SentenceTransformer.encode_query``.

        Args:
            text: List of query texts to embed.
            prompt: Optional prompt for instruction-based models.

        Returns:
            One embedding vector (list of floats) per input text. An empty
            input list yields an empty list without loading the model.

        Raises:
            Exception: Any error from loading or encoding is logged and
                re-raised unchanged.

        NOTE(review): if a bare string is passed instead of a list, the
        encoder presumably returns a single vector, in which case this method
        returns a flat list of floats - confirm against callers.
        """
        # Only short-circuit real empty sequences; an empty *string* is still
        # a valid text and must reach the encoder.
        if isinstance(text, (list, tuple)) and not text:
            return []
        return self._encode("encode_query", text, prompt)

    def embed_documents(
        self, texts: List[str], prompt: Optional[str] = None
    ) -> List[List[float]]:
        """
        Generate document embeddings via ``SentenceTransformer.encode_document``.

        Args:
            texts: List of input texts.
            prompt: Optional prompt for instruction-based models.

        Returns:
            List of embedding vectors, one per input text. An empty input
            list yields an empty list without loading the model.

        Raises:
            Exception: Any error from loading or encoding is logged and
                re-raised unchanged.
        """
        # Only short-circuit real empty sequences (see query_embed).
        if isinstance(texts, (list, tuple)) and not texts:
            return []
        return self._encode("encode_document", texts, prompt)