|
|
from __future__ import annotations |
|
|
|
|
|
from datetime import datetime |
|
|
from typing import List, Optional |
|
|
|
|
|
import numpy as np |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
class EmbeddingModel:
    """Thin wrapper around SentenceTransformer with metadata.

    On top of the wrapped model this class adds optional text prefixes
    (one for queries, one for documents) that are prepended before
    encoding, and a small ``metadata`` dict describing the loaded model.

    Attributes:
        model_name: The name passed to the constructor.
        model: The underlying ``SentenceTransformer`` instance.
        query_prefix: String prepended to texts when ``is_query=True``
            (``None`` or empty means no prefix).
        doc_prefix: String prepended to texts when ``is_query=False``
            (``None`` or empty means no prefix).
        metadata: Dict with the keys ``model_name``, ``embedding_dim``,
            ``max_seq_length`` and ``loaded_at``.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str = ".model_cache",
        device: str | None = None,
        query_prefix: Optional[str] = None,
        doc_prefix: Optional[str] = None,
    ) -> None:
        """Load ``model_name`` and record its metadata.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
            cache_dir: Forwarded to ``SentenceTransformer`` as
                ``cache_folder``.
            device: Forwarded to ``SentenceTransformer`` as ``device``.
            query_prefix: Prefix for query texts. When ``None`` and the
                model name contains "bge" (case-insensitive), a default
                instruction is used.
            doc_prefix: Prefix for document texts. When ``None`` and the
                model name contains "bge" (case-insensitive), a default
                instruction is used.
        """
        # Local import keeps this block self-contained; the file-level
        # import only brings in ``datetime``.
        from datetime import timezone

        self.model_name = model_name
        self.model = SentenceTransformer(model_name, cache_folder=cache_dir, device=device)

        # Only fill in defaults the caller did not supply explicitly.
        # NOTE(review): the upstream BGE model cards appear to recommend an
        # instruction for queries only (and with different wording) --
        # confirm these defaults against the model card before relying on
        # them. They are kept unchanged here so existing indexes stay
        # compatible.
        is_bge = "bge" in model_name.lower()
        if query_prefix is None and is_bge:
            query_prefix = "Represent this query for retrieving relevant documents: "
        if doc_prefix is None and is_bge:
            doc_prefix = "Represent this document for retrieval: "
        self.query_prefix = query_prefix
        self.doc_prefix = doc_prefix

        # ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping
        # the tzinfo afterwards keeps ``loaded_at`` in exactly the same
        # naive-UTC ISO format that was produced before.
        loaded_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self.metadata = {
            "model_name": model_name,
            "embedding_dim": self.model.get_sentence_embedding_dimension(),
            "max_seq_length": self.model.max_seq_length,
            "loaded_at": loaded_at.isoformat(),
        }

    def encode(self, texts: List[str], normalize: bool = True, batch_size: int = 32, is_query: bool = False) -> np.ndarray:
        """Embed ``texts`` after applying the query or document prefix.

        Args:
            texts: Texts to embed. Falsy entries (e.g. ``None``) are
                treated as the empty string. A single bare string is
                treated as a one-element batch rather than being split
                into characters.
            normalize: Forwarded to the model as ``normalize_embeddings``.
            batch_size: Forwarded to the model as ``batch_size``.
            is_query: Selects ``query_prefix`` when true, ``doc_prefix``
                otherwise.

        Returns:
            Whatever the underlying model's ``encode`` returns for the
            prefixed texts (requested with ``convert_to_numpy=True``).
            For empty input the model is not called and an empty array of
            shape ``(0, embedding_dim)`` is returned instead.
        """
        if isinstance(texts, str):
            # Iterating a str would embed every character separately.
            texts = [texts]

        # The prefix is the same for every item, so pick it once.
        prefix = (self.query_prefix if is_query else self.doc_prefix) or ""
        prefixed = [f"{prefix}{t or ''}" for t in texts]

        if not prefixed:
            # Skip the model for an empty batch and return a well-formed
            # 2-D result so callers can still stack or index ``shape[1]``.
            # float32 is assumed to match the model's usual output dtype.
            dim = self.metadata.get("embedding_dim") or 0
            return np.empty((0, dim), dtype=np.float32)

        return self.model.encode(
            prefixed,
            batch_size=batch_size,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
            show_progress_bar=False,
        )
|
|
|