File size: 9,721 Bytes
43efcb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 |
"""
Unified embedding model implementation supporting multiple backends.
"""
from typing import List, Union, Optional, Dict, Any
import logging
import numpy as np
import torch
from abc import ABC, abstractmethod
# Module-level logger named after this module. Nothing in this file attaches
# handlers or sets levels, so output follows the application's logging setup.
logger: logging.Logger = logging.getLogger(__name__)
class EmbeddingModel(ABC):
    """Interface shared by every embedding backend in this module.

    Concrete subclasses turn text into numeric vectors via :meth:`embed` and
    report the length of those vectors through :attr:`dimension`.
    """

    @abstractmethod
    def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
        """
        Encode one text, or a list of texts, as embedding vectors.

        Args:
            texts: A single string or a list of strings to encode.
            batch_size: How many texts to process per batch.

        Returns:
            The embedding vector(s) as a numpy array.
        """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of each embedding vector produced by :meth:`embed`."""
class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding model backed by the ``sentence-transformers`` library."""

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = True,
        **kwargs
    ):
        """
        Initialize the sentence transformer embedding model.

        Args:
            model_name: Sentence transformer model name or path.
            device: Device to run the model on ('cpu', 'cuda', 'cuda:0', etc.).
                When None, 'cuda' is used if ``torch.cuda.is_available()``,
                otherwise 'cpu'.
            normalize: Whether ``embed`` L2-normalizes each output vector.
            **kwargs: Accepted so callers can pass a generic model-config dict,
                but currently ignored (not forwarded to SentenceTransformer).

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
            Exception: Any error raised while loading the model is logged and
                re-raised unchanged.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            raise ImportError(
                "sentence-transformers is not installed. "
                "Please install it with `pip install sentence-transformers`."
            ) from err

        self.model_name = model_name
        self.normalize = normalize

        # Determine device: explicit argument wins, otherwise prefer CUDA.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        logger.info("Loading SentenceTransformer model: %s on %s", model_name, self.device)
        try:
            self.model = SentenceTransformer(model_name, device=self.device)
            # NOTE(review): sentence-transformers may return None here for models
            # that cannot report a fixed size — confirm for the models in use.
            self._dimension = self.model.get_sentence_embedding_dimension()
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise
        logger.info("Model loaded successfully. Embedding dimension: %s", self._dimension)

    @staticmethod
    def _l2_normalize(vectors: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        """
        Scale each row of a 2-D array to unit L2 norm.

        Row norms are clamped to at least ``eps`` so an all-zero row stays zero
        instead of turning into NaN. This mirrors the clamp applied by
        ``torch.nn.functional.normalize`` (default eps=1e-12), which the
        HuggingFace backend uses, so both backends treat zero vectors alike.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.maximum(norms, eps)

    def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
        """
        Convert text(s) to embedding vector(s).

        Args:
            texts: A single string or a list of strings to embed. A single
                string is treated as a one-element list.
            batch_size: Number of texts per ``encode`` batch; must be >= 1.

        Returns:
            A 2-D numpy array with one row per input text (so a single string
            yields shape ``(1, dimension)``). Empty input yields an empty array
            of shape ``(0, dimension)``, so it can be stacked with real results.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            Exception: Errors from ``model.encode`` are logged and re-raised.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Handle single text input
        if isinstance(texts, str):
            texts = [texts]

        if not texts:
            logger.warning("Empty texts provided for embedding")
            # Keep the result 2-D; fall back to 0 columns if the model could not
            # report its dimension.
            width = self._dimension if self._dimension is not None else 0
            return np.empty((0, width), dtype=np.float32)

        try:
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            if self.normalize:
                embeddings = self._l2_normalize(embeddings)
            return embeddings
        except Exception as e:
            logger.error("Error during embedding generation: %s", e)
            raise

    @property
    def dimension(self) -> int:
        """Length of each embedding vector, as reported by the loaded model."""
        return self._dimension
class HuggingFaceEmbedding(EmbeddingModel):
    """
    Embedding model that runs a HuggingFace ``transformers`` encoder directly.

    A sentence vector is the attention-mask-weighted mean of the model's
    ``last_hidden_state`` token vectors, optionally L2-normalized.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        normalize: bool = True,
        max_length: int = 512,
        **kwargs
    ):
        """
        Initialize the HuggingFace embedding model.

        Args:
            model_name: HuggingFace model name or path (used for both the
                tokenizer and the model).
            device: Device to run the model on ('cpu', 'cuda', 'cuda:0', etc.).
                When None, 'cuda' is used if ``torch.cuda.is_available()``,
                otherwise 'cpu'.
            normalize: Whether ``embed`` L2-normalizes each output vector.
            max_length: Maximum token length; longer inputs are truncated.
            **kwargs: Accepted so callers can pass a generic model-config dict,
                but currently ignored (not forwarded to transformers).

        Raises:
            ImportError: If ``transformers`` is not installed.
            Exception: Any error raised while loading the tokenizer or model is
                logged and re-raised unchanged.
        """
        try:
            from transformers import AutoTokenizer, AutoModel
        except ImportError as err:
            raise ImportError(
                "transformers is not installed. "
                "Please install it with `pip install transformers`."
            ) from err

        self.model_name = model_name
        self.normalize = normalize
        self.max_length = max_length

        # Determine device: explicit argument wins, otherwise prefer CUDA.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        logger.info("Loading HuggingFace model: %s on %s", model_name, self.device)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
            self.model.to(self.device)
            # Inference only: put the module in eval mode (e.g. disables dropout).
            self.model.eval()
            # Mean pooling keeps the hidden size, so it is the embedding dimension.
            self._dimension = self.model.config.hidden_size
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise
        logger.info("Model loaded successfully. Embedding dimension: %s", self._dimension)

    def _mean_pooling(self, model_output, attention_mask):
        """
        Average token vectors over the sequence axis, ignoring masked positions.

        Args:
            model_output: Model output exposing ``last_hidden_state``.
            attention_mask: Tokenizer mask (non-zero for real tokens, 0 for
                padding).

        Returns:
            One pooled vector per sequence. The count is clamped to 1e-9 so a
            row whose mask is all zeros does not divide by zero.
        """
        token_embeddings = model_output.last_hidden_state
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
        """
        Convert text(s) to embedding vector(s).

        Args:
            texts: A single string or a list of strings to embed. A single
                string is treated as a one-element list.
            batch_size: Number of texts tokenized and encoded per forward pass;
                must be >= 1 (a negative value used to yield an empty result
                silently).

        Returns:
            A 2-D numpy array with one row per input text. Empty input yields
            an empty array of shape ``(0, dimension)``.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            Exception: Errors from tokenization or the forward pass are logged
                and re-raised.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Handle single text input
        if isinstance(texts, str):
            texts = [texts]

        if not texts:
            logger.warning("Empty texts provided for embedding")
            return np.empty((0, self._dimension), dtype=np.float32)

        try:
            all_embeddings = []
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]

                # Tokenize (pad to the longest item in the batch) and move to device.
                inputs = self.tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.model(**inputs)
                    embeddings = self._mean_pooling(outputs, inputs["attention_mask"])
                    if self.normalize:
                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                all_embeddings.append(embeddings.cpu().numpy())

            # texts is non-empty and batch_size >= 1, so at least one batch exists.
            return np.vstack(all_embeddings)
        except Exception as e:
            logger.error("Error during embedding generation: %s", e)
            raise

    @property
    def dimension(self) -> int:
        """Length of each embedding vector (the model's ``hidden_size``)."""
        return self._dimension
# Factory function to create embedding models
def create_embedding_model(
    backend: str = "sentence-transformers",
    model_name: Optional[str] = None,
    **kwargs
) -> EmbeddingModel:
    """
    Factory function to create an embedding model.

    The backend name is validated before any configuration is loaded. Matching
    ignores case and surrounding whitespace, and treats '_' like '-'.

    Args:
        backend: 'sentence-transformers' (or 'sentence_transformers'), or one
            of 'huggingface' / 'hf' / 'transformers'.
        model_name: Model name or path; defaults to ``EMBEDDING_MODEL_NAME``
            from the package config.
        **kwargs: Constructor arguments that override the values returned by
            ``get_model_config(model_name)``.

    Returns:
        An EmbeddingModel instance.

    Raises:
        ValueError: If ``backend`` is not a supported backend name.
    """
    normalized = backend.strip().lower().replace("_", "-")
    if normalized == "sentence-transformers":
        model_cls = SentenceTransformerEmbedding
    elif normalized in ("huggingface", "hf", "transformers"):
        model_cls = HuggingFaceEmbedding
    else:
        raise ValueError(f"Unsupported backend: {backend}")

    from ..config import EMBEDDING_MODEL_NAME, get_model_config

    # Use config model if not specified
    if model_name is None:
        model_name = EMBEDDING_MODEL_NAME

    # Merge into a fresh dict so the object returned by get_model_config is
    # never mutated. It may be shared state (e.g. a cached registry); that is
    # unverified here. A None config is treated as "no defaults".
    model_config = dict(get_model_config(model_name) or {})
    model_config.update(kwargs)

    return model_cls(model_name=model_name, **model_config)
|