Spaces:
Running
Running
| from loguru import logger | |
| from typing import Dict, List, Optional, Any | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers import SparseEncoder | |
class ModelConfig:
    """Typed wrapper around one model entry from the raw configuration.

    Copies the fields it needs out of the config mapping and coerces the
    numeric ones to ``int``.
    """

    def __init__(self, model_id: str, config: Dict[str, Any]):
        # Registry key this model is looked up under.
        self.id = model_id
        self.name = config["name"]
        self.type = config["type"]  # "embedding" or "sparse"
        # Numeric fields may arrive as strings; normalise to int.
        self.dimension = int(config["dimension"])
        self.max_tokens = int(config["max_tokens"])
        # Free-form descriptive metadata, copied through untouched.
        for field in ("description", "language", "repository"):
            setattr(self, field, config[field])
class EmbeddingModel:
    """
    Embedding model wrapper for dense embeddings.

    Attributes:
        config: ModelConfig instance describing the model to load.
        model: Lazily created SentenceTransformer instance (None until load()).
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self.model: Optional["SentenceTransformer"] = None
        self._loaded = False

    def load(self, device: str = "cpu") -> None:
        """
        Load the embedding model. Idempotent: a second call is a no-op.

        Args:
            device: Torch device to place the model on. Defaults to "cpu",
                preserving the previously hard-coded behaviour.

        Raises:
            Exception: Re-raised after logging if model construction fails.
        """
        if self._loaded:
            return
        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # trust_remote_code is needed for repos that ship custom modules.
            self.model = SentenceTransformer(
                self.config.name, device=device, trust_remote_code=True
            )
            self._loaded = True
            logger.success(f"Loaded embedding model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

    def embed(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of input texts. An empty list yields an empty result.
            prompt: Optional prompt for instruction-based models.

        Returns:
            List of embedding vectors, one per input text.

        Raises:
            Exception: Re-raised after logging if encoding fails.
        """
        if not self._loaded:
            self.load()  # lazy-load on first use
        try:
            embeddings = self.model.encode(texts, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
class SparseEmbeddingModel:
    """
    Sparse embedding model wrapper.

    Attributes:
        config: ModelConfig instance describing the model to load.
        model: Lazily created SparseEncoder instance (None until load()).
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self.model: Optional["SparseEncoder"] = None
        self._loaded = False

    def _format_values(self, values: List[float]) -> List[float]:
        """Round float values to 7 decimal places for compact, stable output."""
        return [round(float(v), 7) for v in values]

    def _tensor_to_sparse_dict(self, tensor) -> Dict[str, Any]:
        """Convert one sparse torch tensor into {'indices', 'values'} lists."""
        # Coalesce once and reuse; the previous code coalesced each tensor twice.
        sp = tensor.coalesce()
        return {
            "indices": sp.indices()[0].tolist(),
            "values": self._format_values(sp.values().tolist()),
        }

    def load(self) -> None:
        """
        Load the sparse embedding model. Idempotent: a second call is a no-op.

        Raises:
            Exception: Re-raised after logging if model construction fails.
        """
        if self._loaded:
            return
        logger.info(f"Loading sparse model: {self.config.name}")
        try:
            self.model = SparseEncoder(self.config.name)
            self._loaded = True
            logger.success(f"Loaded sparse model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load sparse model {self.config.id}: {e}")
            raise

    def embed(self, text: str, prompt: Optional[str] = None) -> Dict[Any, Any]:
        """
        Generate a sparse embedding for a single text.

        Args:
            text: Input text.
            prompt: Optional prompt for instruction-based models
                (now forwarded to the encoder; it was previously ignored).

        Returns:
            Sparse embedding as a dictionary with 'indices' and 'values' keys.

        Raises:
            Exception: Re-raised after logging if encoding fails.
        """
        # Bug fix: lazy-load like embed_batch()/EmbeddingModel.embed().
        # Previously, calling embed() before load() crashed on self.model = None.
        if not self._loaded:
            self.load()
        try:
            tensors = self.model.encode([text], prompt=prompt)
            return self._tensor_to_sparse_dict(tensors[0])
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise

    def embed_batch(self, texts: List[str], prompt: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Generate sparse embeddings for a batch of texts.

        Args:
            texts: List of input texts. An empty list yields an empty result.
            prompt: Optional prompt for instruction-based models
                (now forwarded to the encoder; it was previously ignored).

        Returns:
            List of dictionaries with 'text' and 'sparse_embedding' keys.

        Raises:
            Exception: Re-raised after logging if encoding fails.
        """
        if not self._loaded:
            self.load()
        try:
            tensors = self.model.encode(texts, prompt=prompt)
            return [
                {"text": text, "sparse_embedding": self._tensor_to_sparse_dict(tensor)}
                for text, tensor in zip(texts, tensors)
            ]
        except Exception as e:
            logger.error(f"Sparse embedding generation failed: {e}")
            raise