Spaces:
Running
Running
File size: 4,930 Bytes
0231daa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
"""
Abstract base classes for embedding models.
This module defines the interface that all embedding model implementations
must follow, ensuring consistency across dense and sparse embeddings.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class BaseEmbeddingModel(ABC):
    """
    Common interface shared by every embedding backend.

    Concrete dense and sparse embedding implementations subclass this type
    and override the four abstract operations (load, unload, embed_query,
    embed_documents). The class cannot be instantiated until all four are
    provided.

    Attributes:
        config: ModelConfig instance holding the model metadata. It must
            expose ``id`` and ``type`` attributes, which back the
            ``model_id`` and ``model_type`` properties.
        _loaded: True while the model is resident in memory. Subclasses are
            expected to update it from load()/unload().
    """

    def __init__(self, config: Any):
        """
        Remember the configuration. The model itself starts out unloaded.

        Args:
            config: ModelConfig instance describing the model.
        """
        self._loaded = False
        self.config = config

    # ------------------------------------------------------------------
    # Lifecycle (must be implemented by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def load(self) -> None:
        """
        Bring the model into memory.

        Implementations are expected to:
            - return early when the model is already loaded (idempotent);
            - construct the underlying model;
            - set ``self._loaded`` to True;
            - report failures cleanly instead of leaving partial state.

        Raises:
            RuntimeError: If the model could not be loaded.
        """
        ...

    @abstractmethod
    def unload(self) -> None:
        """
        Drop the model from memory and release its resources.

        Implementations are expected to:
            - release the underlying model;
            - clear any caches they keep;
            - set ``self._loaded`` to False;
            - tolerate being called repeatedly.
        """
        ...

    # ------------------------------------------------------------------
    # Embedding (must be implemented by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def embed_query(
        self,
        texts: List[str],
        prompt: Optional[str] = None,
        **kwargs,
    ) -> Union[List[List[float]], List[Dict[str, Any]]]:
        """
        Embed texts that will be used as search queries.

        Some models (e.g. asymmetric retrieval models) embed queries
        differently from documents, which is why this is separate from
        embed_documents().

        Args:
            texts: Query strings to embed (required).
            prompt: Optional instruction prompt passed to the model.
            **kwargs: Backend-specific options, for example:
                - normalize_embeddings (bool): L2-normalise the vectors
                - batch_size (int): number of texts per forward pass
                - max_length (int): maximum token sequence length
                - convert_to_numpy (bool): return numpy arrays, not lists
                - precision (str): e.g. 'float32' or 'int8'

        Returns:
            One embedding per input text. Its shape depends on the model kind:
                - dense models: List[List[float]]
                - sparse models: List[Dict[str, Any]] with 'indices' and
                  'values' entries

        Raises:
            RuntimeError: If the model has not been loaded.
            ValueError: If the input fails validation.

        Note:
            Which kwargs are honoured depends on the concrete backend; see
            the sentence-transformers documentation for the full list.
        """
        ...

    @abstractmethod
    def embed_documents(
        self,
        texts: List[str],
        prompt: Optional[str] = None,
        **kwargs,
    ) -> Union[List[List[float]], List[Dict[str, Any]]]:
        """
        Embed texts that will be indexed and stored as documents.

        Args:
            texts: Document strings to embed (required).
            prompt: Optional instruction prompt passed to the model.
            **kwargs: Backend-specific options (same set as embed_query).

        Returns:
            One embedding per input text. Its shape depends on the model kind:
                - dense models: List[List[float]]
                - sparse models: List[Dict[str, Any]] with 'indices' and
                  'values' entries

        Raises:
            RuntimeError: If the model has not been loaded.
            ValueError: If the input fails validation.

        Note:
            Which kwargs are honoured depends on the concrete backend; see
            the sentence-transformers documentation for the full list.
        """
        ...

    # ------------------------------------------------------------------
    # Read-only accessors
    # ------------------------------------------------------------------

    @property
    def is_loaded(self) -> bool:
        """True when the model is currently loaded, otherwise False."""
        return self._loaded

    @property
    def model_id(self) -> str:
        """Identifier of the model, read from ``config.id``."""
        return self.config.id

    @property
    def model_type(self) -> str:
        """Kind of model ('embeddings' or 'sparse-embeddings'), from ``config.type``."""
        return self.config.type

    def __repr__(self) -> str:
        """Return a short summary: class name, id, type and load state."""
        cls_name = type(self).__name__
        fields = f"id={self.model_id}, type={self.model_type}, loaded={self.is_loaded}"
        return f"{cls_name}({fields})"
|