# hrbot/src/knowledge/embeddings.py
"""Embedding model wrapper for document vectorization."""
from pathlib import Path
from typing import Optional
import numpy as np
from sentence_transformers import SentenceTransformer
from src.config import settings
from src.document_processor.chunker import DocumentChunk


class EmbeddingModel:
    """Wrapper for sentence-transformers embedding models.

    Provides efficient batch embedding with lazy model loading.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Initialize the embedding model.

        Args:
            model_name: HuggingFace model name. Defaults to settings.embedding_model.
        """
        self.model_name = model_name or settings.embedding_model
        self._model: Optional[SentenceTransformer] = None

    @property
    def model(self) -> SentenceTransformer:
        """Lazily load the embedding model on first access."""
        if self._model is None:
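            # First access instantiates the model, downloading weights from the
            # HuggingFace Hub (or reusing the local cache) as needed.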
self._model = SentenceTransformer(self.model_name)
return self._model

    @property
    def embedding_dimension(self) -> int:
        """Get the dimension of embeddings produced by this model."""
return self.model.get_sentence_embedding_dimension()

    def embed_text(self, text: str) -> np.ndarray:
        """Embed a single text string.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as numpy array.
        """
return self.model.encode(text, convert_to_numpy=True, normalize_embeddings=True)

    def embed_texts(self, texts: list[str], batch_size: int = 32) -> np.ndarray:
        """Embed multiple texts efficiently.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Array of embedding vectors (num_texts x embedding_dim).
        """
return self.model.encode(
texts,
batch_size=batch_size,
convert_to_numpy=True,
normalize_embeddings=True,
show_progress_bar=len(texts) > 100,
)

    def embed_chunks(
        self, chunks: list[DocumentChunk], batch_size: int = 32
    ) -> list[tuple[DocumentChunk, np.ndarray]]:
        """Embed document chunks with their metadata.

        Args:
            chunks: List of DocumentChunks to embed.
            batch_size: Batch size for processing.

        Returns:
            List of (chunk, embedding) tuples.
        """
        texts = [chunk.content for chunk in chunks]
        embeddings = self.embed_texts(texts, batch_size=batch_size)
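        # encode() preserves input order, so embeddings[i] belongs to chunks[i].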
return list(zip(chunks, embeddings))

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a query for retrieval.

        Some models use different prompting for queries vs. documents.

        Args:
            query: Query text to embed.

        Returns:
            Query embedding vector.
        """
        # BGE v1/v1.5 models are trained with this query-side instruction,
        # so prepending it improves retrieval quality.
        if "bge" in self.model_name.lower():
            query = f"Represent this sentence for searching relevant passages: {query}"
        return self.embed_text(query)
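

# A minimal usage sketch (illustrative; the example texts and query are made
# up). It assumes settings.embedding_model names a sentence-transformers
# model such as "sentence-transformers/all-MiniLM-L6-v2" whose weights are
# available locally or from the HuggingFace Hub.
if __name__ == "__main__":
    embedder = EmbeddingModel()
    docs = [
        "Employees accrue 1.5 vacation days per month.",
        "Expense reports must be submitted within 30 days.",
    ]
    doc_vectors = embedder.embed_texts(docs)
    query_vector = embedder.embed_query("How many vacation days do I get?")
    # Embeddings are unit-normalized, so dot product == cosine similarity.
    scores = doc_vectors @ query_vector
    print(f"dim={embedder.embedding_dimension}, scores={scores.tolist()}")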