Spaces:
Running
Running
File size: 5,316 Bytes
0231daa 1c7e30d 0231daa 1c7e30d 0231daa 1c7e30d 0231daa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
"""
Dense embedding model implementation.
This module provides the DenseEmbeddingModel class for generating
dense vector embeddings using sentence-transformers.
"""
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from loguru import logger
from src.config.settings import get_settings
from src.core.base import BaseEmbeddingModel
from src.core.config import ModelConfig
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
class DenseEmbeddingModel(BaseEmbeddingModel):
    """
    Dense embedding model wrapper using sentence-transformers.

    This class wraps a sentence-transformers ``SentenceTransformer`` model
    to generate dense vector embeddings for text. The underlying model is
    loaded lazily: the first call to ``embed_query`` / ``embed_documents``
    triggers ``load()`` if it has not been called explicitly.

    Attributes:
        config: ModelConfig instance (presumably stored by the base class;
            read here as ``self.config.name``).
        model: SentenceTransformer instance, or ``None`` while unloaded.
        settings: Settings object returned by ``get_settings()``; supplies
            ``DEVICE`` and ``TRUST_REMOTE_CODE``.
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: ModelConfig):
        """
        Initialize the dense embedding model.

        The model weights are NOT loaded here; call ``load()`` or let the
        first embedding request load them.

        Args:
            config: ModelConfig instance with model configuration
        """
        super().__init__(config)
        self.model: Optional[SentenceTransformer] = None
        self.settings = get_settings()

    def load(self) -> None:
        """
        Load the dense embedding model into memory.

        This is a no-op when the model is already loaded.

        Raises:
            ModelLoadError: If model fails to load
        """
        # Check the model object as well as the flag: if the two ever get
        # out of sync (flag set, model missing) we must reload rather than
        # return early and leave callers with ``self.model is None``.
        if self._loaded and self.model is not None:
            logger.debug(f"Model {self.model_id} already loaded")
            return

        logger.info(f"Loading dense embedding model: {self.config.name}")
        try:
            self.model = SentenceTransformer(
                self.config.name,
                device=self.settings.DEVICE,
                trust_remote_code=self.settings.TRUST_REMOTE_CODE,
            )
            self._loaded = True
            logger.success(f"✓ Loaded dense model: {self.model_id}")
        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            logger.error(f"✗ {error_msg}")
            # Chain explicitly so the original failure stays in the traceback.
            raise ModelLoadError(self.model_id, error_msg) from e

    def unload(self) -> None:
        """
        Unload the model and mark this wrapper as not loaded.

        Drops this object's reference to the SentenceTransformer so its
        memory can be reclaimed once no other references remain. Errors are
        logged rather than raised (best-effort cleanup).
        """
        if not self._loaded:
            logger.debug(f"Model {self.model_id} not loaded, nothing to unload")
            return

        try:
            # Rebinding to None releases our reference; a separate ``del``
            # beforehand would be redundant.
            self.model = None
            self._loaded = False
            logger.info(f"✓ Unloaded model: {self.model_id}")
        except Exception as e:
            logger.error(f"Error unloading model {self.model_id}: {e}")

    def _encode(
        self,
        encoder_name: str,
        label: str,
        texts: List[str],
        prompt: Optional[str],
        encode_kwargs: dict,
    ) -> List[List[float]]:
        """
        Shared implementation behind ``embed_query`` and ``embed_documents``.

        Args:
            encoder_name: Name of the SentenceTransformer method to call
                (``"encode_query"`` or ``"encode_document"``).
            label: Human-readable kind (``"Query"`` / ``"Document"``) used
                in the error message.
            texts: Texts to embed
            prompt: Optional instruction prompt
            encode_kwargs: Extra keyword arguments forwarded to the encoder.
                Passed as a dict so caller kwargs can never collide with
                this helper's own parameter names.

        Returns:
            List of embedding vectors

        Raises:
            ModelLoadError: If the model needs loading and fails to load
            EmbeddingGenerationError: If embedding generation fails
        """
        # Lazy load; kept outside the try so a load failure surfaces as
        # ModelLoadError rather than being rewrapped below.
        if not self._loaded or self.model is None:
            self.load()

        try:
            encoder = getattr(self.model, encoder_name)
            embeddings = encoder(texts, prompt=prompt, **encode_kwargs)
            # Convert to list format (arrays/tensors expose ``tolist``;
            # anything else is treated as a plain iterable of floats).
            return [
                emb.tolist() if hasattr(emb, "tolist") else list(emb)
                for emb in embeddings
            ]
        except Exception as e:
            error_msg = f"{label} embedding generation failed: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingGenerationError(self.model_id, error_msg) from e

    def embed_query(
        self, texts: List[str], prompt: Optional[str] = None, **kwargs
    ) -> List[List[float]]:
        """
        Generate embeddings for query texts.

        Loads the model first if it is not loaded yet.

        Args:
            texts: List of query texts to embed
            prompt: Optional instruction prompt
            **kwargs: Additional parameters for sentence-transformers:
                - normalize_embeddings (bool)
                - batch_size (int)
                - convert_to_numpy (bool)
                - etc.

        Returns:
            List of embedding vectors

        Raises:
            ModelLoadError: If the model needs loading and fails to load
            EmbeddingGenerationError: If embedding generation fails
        """
        return self._encode("encode_query", "Query", texts, prompt, kwargs)

    def embed_documents(
        self, texts: List[str], prompt: Optional[str] = None, **kwargs
    ) -> List[List[float]]:
        """
        Generate embeddings for document texts.

        Loads the model first if it is not loaded yet.

        Args:
            texts: List of document texts to embed
            prompt: Optional instruction prompt
            **kwargs: Additional parameters for sentence-transformers:
                - normalize_embeddings (bool)
                - batch_size (int)
                - convert_to_numpy (bool)
                - etc.

        Returns:
            List of embedding vectors

        Raises:
            ModelLoadError: If the model needs loading and fails to load
            EmbeddingGenerationError: If embedding generation fails
        """
        return self._encode("encode_document", "Document", texts, prompt, kwargs)
|