File size: 2,011 Bytes
5a3b322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from __future__ import annotations

from datetime import datetime, timezone
from typing import List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
class EmbeddingModel:
    """Thin wrapper around SentenceTransformer with metadata.

    Loads a ``SentenceTransformer`` model, keeps optional text prefixes that
    are prepended to queries and documents before encoding, and records a
    small ``metadata`` dict describing the loaded model.

    Attributes set by ``__init__``: ``model_name``, ``model``,
    ``query_prefix``, ``doc_prefix`` and ``metadata``.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str = ".model_cache",
        device: Optional[str] = None,
        query_prefix: Optional[str] = None,
        doc_prefix: Optional[str] = None,
    ) -> None:
        """Load the model and record its prefixes and metadata.

        Args:
            model_name: Name handed to ``SentenceTransformer``. It is also
                matched case-insensitively to choose default prefixes.
            cache_dir: Forwarded to ``SentenceTransformer`` as
                ``cache_folder``.
            device: Forwarded to ``SentenceTransformer`` as ``device``.
            query_prefix: Text prepended to every input when encoding
                queries. When ``None`` and the model name contains "bge",
                a built-in default is used. An explicit empty string is
                kept as-is and disables prefixing.
            doc_prefix: Same as ``query_prefix`` but applied when encoding
                documents.
        """
        self.model_name = model_name
        self.model = SentenceTransformer(model_name, cache_folder=cache_dir, device=device)

        # Sensible defaults for some popular retrieval models. Only a prefix
        # the caller left as None is filled in.
        # NOTE(review): confirm these instruction strings against the model
        # card of the specific BGE variant in use.
        if "bge" in model_name.lower():
            if query_prefix is None:
                query_prefix = "Represent this query for retrieving relevant documents: "
            if doc_prefix is None:
                doc_prefix = "Represent this document for retrieval: "
        self.query_prefix = query_prefix
        self.doc_prefix = doc_prefix

        self.metadata = {
            "model_name": model_name,
            "embedding_dim": self.model.get_sentence_embedding_dimension(),
            "max_seq_length": self.model.max_seq_length,
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
            # and produced a naive value with no offset information.
            "loaded_at": datetime.now(timezone.utc).isoformat(),
        }

    def encode(self, texts: List[str], normalize: bool = True, batch_size: int = 32, is_query: bool = False) -> np.ndarray:
        """Encode texts with the wrapped model, applying the matching prefix.

        Args:
            texts: Texts to embed. Falsy items (``None``, ``""``) are
                treated as the empty string. A single bare string is
                accepted and treated as a one-item batch.
            normalize: Forwarded to the model as ``normalize_embeddings``.
            batch_size: Forwarded to the model as ``batch_size``.
            is_query: When true, ``query_prefix`` is prepended to each text;
                otherwise ``doc_prefix`` is. A missing or empty prefix
                leaves the texts unchanged.

        Returns:
            Whatever ``self.model.encode`` returns when called with
            ``convert_to_numpy=True``.
        """
        # A lone string would otherwise be iterated character by character
        # and each letter embedded separately.
        if isinstance(texts, str):
            texts = [texts]

        # The prefix is the same for every item, so pick it once.
        prefix = self.query_prefix if is_query else self.doc_prefix
        cleaned = [t or "" for t in texts]
        prefixed = [f"{prefix}{t}" for t in cleaned] if prefix else cleaned

        return self.model.encode(
            prefixed,
            batch_size=batch_size,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
            show_progress_bar=False,
        )
|