|
|
""" |
|
|
Model loading and embedding interface for the Rabbinic embedding benchmark. |
|
|
|
|
|
Supports: |
|
|
- Curated models from Hugging Face (sentence-transformers) |
|
|
- Any Hugging Face sentence-transformer model |
|
|
- API-based models (OpenAI, Voyage AI, Google Gemini, Cohere)
|
|
""" |
|
|
|
|
|
import os |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Optional |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Registry of curated local (sentence-transformers) models, keyed by
# HuggingFace model ID. Each entry's fields:
#   name           - human-readable display name for UIs
#   description    - short blurb shown next to the name in dropdowns
#   type           - always "local" in this registry (loaded via sentence-transformers)
#   query_prefix   - string prepended to query-side texts (asymmetric models)
#   passage_prefix - string prepended to passage-side texts
#   max_length     - optional per-model override of the max sequence length
# Insertion order matters: it defines the order of UI dropdown choices.
CURATED_MODELS = {
    "intfloat/multilingual-e5-large": {
        "name": "Multilingual E5 Large",
        "description": "Strong multilingual model from Microsoft, 560M params",
        "type": "local",
        # E5 models are trained with these literal prefixes; omitting them
        # degrades retrieval quality.
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "intfloat/multilingual-e5-base": {
        "name": "Multilingual E5 Base",
        "description": "Smaller multilingual E5, 278M params",
        "type": "local",
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": {
        "name": "Multilingual MPNet",
        "description": "Classic multilingual sentence transformer, 278M params",
        "type": "local",
        # Symmetric model: no prefixes needed.
        "query_prefix": "",
        "passage_prefix": "",
    },
    "BAAI/bge-m3": {
        "name": "BGE-M3",
        "description": "Multi-lingual, multi-functionality, multi-granularity model from BAAI",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "intfloat/e5-mistral-7b-instruct": {
        "name": "E5 Mistral 7B",
        "description": "Large instruction-tuned embedding model, 7B params (requires GPU)",
        "type": "local",
        # Instruction-tuned: queries carry an instruction, passages do not.
        "query_prefix": "Instruct: Retrieve semantically similar text\nQuery: ",
        "passage_prefix": "",
    },
    "Alibaba-NLP/gte-multilingual-base": {
        "name": "GTE Multilingual Base",
        "description": "General Text Embeddings multilingual model from Alibaba",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "google/embeddinggemma-300m": {
        "name": "EmbeddingGemma",
        "description": "Google's 300M param embedding model, 100+ languages, 768d (requires HF token + license)",
        "type": "local",
        "query_prefix": "task: search result | query: ",
        "passage_prefix": "title: none | text: ",
        # Supports longer inputs than the default 512-token cap.
        "max_length": 2048,
    },
}
|
|
|
|
|
|
|
|
# Registry of API-based (remote) embedding models, keyed by
# "<provider>/<model-or-variant>". Each entry's fields:
#   name        - human-readable display name for UIs
#   description - short blurb shown next to the name in dropdowns
#   type        - provider tag used for dispatch: "openai" | "voyage" | "gemini" | "cohere"
#   model_name  - the model name sent to the provider's API (variants of the
#                 same underlying model share one model_name)
#   dimensions  - embedding dimensionality expected from / requested of the API
# Insertion order matters: it defines the order of UI dropdown choices.
API_MODELS = {
    "openai/text-embedding-3-large": {
        "name": "OpenAI text-embedding-3-large",
        "description": "OpenAI's best embedding model, 3072 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-large",
        "dimensions": 3072,
    },
    "openai/text-embedding-3-small": {
        "name": "OpenAI text-embedding-3-small",
        "description": "OpenAI's efficient embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-small",
        "dimensions": 1536,
    },
    "openai/text-embedding-ada-002": {
        "name": "OpenAI Ada 002",
        "description": "OpenAI's legacy embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-ada-002",
        "dimensions": 1536,
    },
    "voyage/voyage-3.5": {
        "name": "Voyage AI voyage-3.5",
        "description": "Voyage AI's latest embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5",
        "dimensions": 1024,
    },
    "voyage/voyage-3.5-lite": {
        "name": "Voyage AI voyage-3.5-lite",
        "description": "Voyage AI's efficient embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5-lite",
        "dimensions": 1024,
    },
    "voyage/voyage-3": {
        "name": "Voyage AI voyage-3",
        "description": "Voyage AI's general purpose embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3",
        "dimensions": 1024,
    },
    "voyage/voyage-3-lite": {
        "name": "Voyage AI voyage-3-lite",
        "description": "Voyage AI's lightweight embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3-lite",
        "dimensions": 512,
    },
    "voyage/voyage-multilingual-2": {
        "name": "Voyage AI voyage-multilingual-2",
        "description": "Voyage AI's multilingual embedding model, optimized for non-English (API key required)",
        "type": "voyage",
        "model_name": "voyage-multilingual-2",
        "dimensions": 1024,
    },
    # The three gemini entries below are dimension variants of one model:
    # same model_name, different requested output_dimensionality.
    "gemini/gemini-embedding-001": {
        "name": "Gemini Embedding 001",
        "description": "Google's Gemini embedding model, 3072 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 3072,
    },
    "gemini/gemini-embedding-001-768": {
        "name": "Gemini Embedding 001 (768d)",
        "description": "Google's Gemini embedding model, 768 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 768,
    },
    "gemini/gemini-embedding-001-1536": {
        "name": "Gemini Embedding 001 (1536d)",
        "description": "Google's Gemini embedding model, 1536 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 1536,
    },
    "cohere/embed-multilingual-v3.0": {
        "name": "Cohere embed-multilingual-v3.0",
        "description": "Cohere's multilingual embedding model, 100+ languages (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-v3.0",
        "dimensions": 1024,
    },
    "cohere/embed-multilingual-light-v3.0": {
        "name": "Cohere embed-multilingual-light-v3.0",
        "description": "Cohere's lightweight multilingual model (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-light-v3.0",
        "dimensions": 384,
    },
}
|
|
|
|
|
|
|
|
# Combined registry of every known model (local + API) for unified lookup.
ALL_MODELS = {**CURATED_MODELS, **API_MODELS}
|
|
|
|
|
|
|
|
class BaseEmbeddingModel(ABC):
    """Abstract base class for embedding models.

    Subclasses must implement :meth:`encode` and the ``name`` /
    ``description`` properties, and populate ``model_id`` and
    ``embedding_dim``.
    """

    # HuggingFace-style or provider-prefixed identifier of the model.
    model_id: str
    # Dimensionality of the vectors produced by encode().
    embedding_dim: int

    @abstractmethod
    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts to embeddings."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Get display name for the model."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Get description for the model."""
        pass

    def encode_pairs(
        self,
        he_texts: list[str],
        en_texts: list[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Encode parallel Hebrew/English text pairs.

        The Hebrew/Aramaic side is embedded as queries and the English side
        as passages, so asymmetric models apply the appropriate prefixes or
        input types to each half.

        Args:
            he_texts: Hebrew/Aramaic source texts
            en_texts: English translations
            batch_size: Batch size for encoding
            show_progress: Whether to show progress bar

        Returns:
            Tuple of (hebrew_embeddings, english_embeddings)
        """
        shared = {"batch_size": batch_size, "show_progress": show_progress}
        hebrew_vectors = self.encode(he_texts, is_query=True, **shared)
        english_vectors = self.encode(en_texts, is_query=False, **shared)
        return hebrew_vectors, english_vectors
|
|
|
|
|
|
|
|
class EmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for sentence-transformer models with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        device: Optional[str] = None,
        max_length: int = 512,
        hf_token: Optional[str] = None,
    ):
        """
        Initialize the embedding model.

        Args:
            model_id: Hugging Face model ID
            device: Device to use ('cuda', 'cpu', or None for auto)
            max_length: Maximum sequence length for tokenization
            hf_token: HuggingFace token for gated models (or uses HF_TOKEN env var)
        """
        from sentence_transformers import SentenceTransformer
        import torch

        self.model_id = model_id

        # Auto-select GPU when available.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Fall back to a generic prefix-free config for models outside the
        # curated list.
        self.config = CURATED_MODELS.get(model_id, {
            "name": model_id.split("/")[-1],
            "description": "Custom model",
            "type": "local",
            "query_prefix": "",
            "passage_prefix": "",
        })

        # Curated entries may override the caller-supplied max length.
        self.max_length = self.config.get("max_length", max_length)

        hf_token = hf_token or os.environ.get("HF_TOKEN")

        print(f"Loading model: {model_id} on {device}")

        # Only allow remote code execution for explicitly trusted publishers.
        # BUG FIX: "Alibaba-NLP/" added — the curated gte-multilingual-base
        # model uses custom modeling code and fails to load without
        # trust_remote_code=True.
        trusted_publishers = ["nvidia/", "google/", "Alibaba-NLP/"]
        trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)

        # fp16 on GPU to halve memory; default (fp32) weights on CPU.
        model_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
        self.model = SentenceTransformer(
            model_id,
            device=device,
            model_kwargs=model_kwargs,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )

        # Cap the model's sequence length at our configured maximum.
        if hasattr(self.model, "max_seq_length"):
            self.model.max_seq_length = min(self.max_length, self.model.max_seq_length)

        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (vs passages) for asymmetric models
            batch_size: Batch size for encoding
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        # Asymmetric models (e.g. E5) need role-specific prefixes.
        prefix = self.config["query_prefix"] if is_query else self.config["passage_prefix"]
        if prefix:
            texts = [prefix + t for t in texts]

        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
        )

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
|
|
|
|
|
|
|
|
class OpenAIEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for OpenAI embedding API with consistent interface.
    """

    # Per-input token limit for OpenAI embedding endpoints.
    MAX_TOKENS = 8191

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the OpenAI embedding model.

        Args:
            model_id: Model ID in format 'openai/model-name'
            api_key: OpenAI API key (or uses OPENAI_API_KEY env var)
        """
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError(
                "OpenAI package not installed. Install with: pip install openai"
            )

        self.model_id = model_id

        api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key required. Set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.client = OpenAI(api_key=api_key)

        # Fall back to a generic config for unregistered OpenAI model IDs.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "OpenAI embedding model",
            "type": "openai",
            "model_name": model_id.replace("openai/", ""),
            "dimensions": 1536,
        })

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        # Tokenizer used for accurate truncation: prefer the model-specific
        # encoding, fall back to cl100k_base, then to character counting.
        self._encoding = None
        try:
            import tiktoken
            self._encoding = tiktoken.encoding_for_model(self._model_name)
        except Exception:
            try:
                import tiktoken
                self._encoding = tiktoken.get_encoding("cl100k_base")
            except Exception:
                print("Warning: tiktoken not available, using character-based truncation")

        print(f"Initialized OpenAI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to fit within token limit."""
        if self._encoding is not None:
            # Exact truncation using the tokenizer.
            tokens = self._encoding.encode(text)
            if len(tokens) > self.MAX_TOKENS:
                tokens = tokens[:self.MAX_TOKENS]
                return self._encoding.decode(tokens)
            return text
        else:
            # Conservative character-based fallback (~3 chars/token assumed
            # — TODO confirm this is conservative enough for Hebrew text).
            max_chars = self.MAX_TOKENS * 3
            if len(text) > max_chars:
                return text[:max_chars]
            return text

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 100,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using OpenAI API.

        Args:
            texts: List of texts to encode
            is_query: Not used for OpenAI (symmetric embeddings)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings (OpenAI already normalizes)

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        import time

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            # BUG FIX: texts were previously sent untruncated, so inputs over
            # MAX_TOKENS caused hard API errors and _truncate_text was dead
            # code. Truncate every text before sending.
            batch = [self._truncate_text(t) for t in texts[i:i + batch_size]]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f" Encoding batch {batch_num}/{total_batches}...")

            # Retry transient API failures with exponential backoff.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    response = self.client.embeddings.create(
                        model=self._model_name,
                        input=batch,
                    )
                    batch_embeddings = [item.embedding for item in response.data]
                    all_embeddings.extend(batch_embeddings)
                    break
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f" API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"OpenAI API error after {max_retries} retries: {e}")

            # Small pause between batches to be gentle on rate limits.
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Re-normalize defensively even though OpenAI returns unit vectors.
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
|
|
|
|
|
|
|
|
class VoyageEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Voyage AI embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Voyage AI embedding model.

        Args:
            model_id: Model ID in format 'voyage/model-name'
            api_key: Voyage API key (or uses VOYAGE_API_KEY env var)
        """
        try:
            import voyageai
        except ImportError:
            raise ImportError(
                "Voyage AI package not installed. Install with: pip install voyageai"
            )

        self.model_id = model_id

        api_key = api_key or os.environ.get("VOYAGE_API_KEY")
        if not api_key:
            raise ValueError(
                "Voyage API key required. Set VOYAGE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.client = voyageai.Client(api_key=api_key)

        # Unregistered IDs get a generic fallback config.
        fallback = {
            "name": model_id,
            "description": "Voyage AI embedding model",
            "type": "voyage",
            "model_name": model_id.replace("voyage/", ""),
            "dimensions": 1024,
        }
        self.config = API_MODELS.get(model_id, fallback)

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Voyage AI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 128,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Voyage AI API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (Voyage supports input_type)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        import time

        # Voyage distinguishes query vs document embeddings via input_type.
        input_type = "query" if is_query else "document"

        chunks = [texts[pos:pos + batch_size] for pos in range(0, len(texts), batch_size)]
        total_batches = len(chunks)
        collected: list = []

        for chunk_idx, chunk in enumerate(chunks, start=1):
            if show_progress:
                print(f" Encoding batch {chunk_idx}/{total_batches}...")

            # Retry transient API failures with exponential backoff.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        chunk,
                        model=self._model_name,
                        input_type=input_type,
                    )
                    collected.extend(result.embeddings)
                    break
                except Exception as e:
                    if attempt >= max_retries - 1:
                        raise RuntimeError(f"Voyage AI API error after {max_retries} retries: {e}")
                    wait_time = 2 ** attempt
                    print(f" API error, retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)

            # Gentle pacing between batches.
            if chunk_idx < total_batches:
                time.sleep(0.1)

        embeddings = np.array(collected, dtype=np.float32)

        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
|
|
|
|
|
|
|
|
class GeminiEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Google Gemini embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Gemini embedding model.

        Args:
            model_id: Model ID in format 'gemini/model-name'
            api_key: Gemini API key (optional - can use GEMINI_API_KEY env var
                or Google Cloud Application Default Credentials)
        """
        try:
            from google import genai
        except ImportError:
            raise ImportError(
                "Google GenAI package not installed. Install with: pip install google-genai"
            )

        self.model_id = model_id

        api_key = api_key or os.environ.get("GEMINI_API_KEY")

        # Without an explicit key, the client falls back to Application
        # Default Credentials.
        self.client = genai.Client(api_key=api_key) if api_key else genai.Client()

        # Unregistered IDs get a generic config; the "-768"/"-1536" suffixes
        # are this module's dimension-variant markers, not part of the real
        # API model name, so strip them.
        fallback = {
            "name": model_id,
            "description": "Gemini embedding model",
            "type": "gemini",
            "model_name": model_id.replace("gemini/", "").split("-768")[0].split("-1536")[0],
            "dimensions": 3072,
        }
        self.config = API_MODELS.get(model_id, fallback)

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Gemini embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 20,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Gemini API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses RETRIEVAL_QUERY vs RETRIEVAL_DOCUMENT)
            batch_size: Batch size for API calls (smaller for Gemini to avoid rate limits)
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        import time
        import random
        from google.genai import types

        # Task type steers Gemini's asymmetric retrieval embeddings.
        task_type = "RETRIEVAL_QUERY" if is_query else "RETRIEVAL_DOCUMENT"

        chunks = [texts[pos:pos + batch_size] for pos in range(0, len(texts), batch_size)]
        total_batches = len(chunks)
        collected: list = []

        # Gemini rate limits aggressively, so use more retries and jittered
        # exponential backoff.
        max_retries = 8
        base_delay = 2.0

        for chunk_idx, chunk in enumerate(chunks, start=1):
            if show_progress:
                print(f" Encoding batch {chunk_idx}/{total_batches}...")

            for attempt in range(max_retries):
                try:
                    request_config = types.EmbedContentConfig(
                        task_type=task_type,
                        output_dimensionality=self.embedding_dim,
                    )
                    result = self.client.models.embed_content(
                        model=self._model_name,
                        contents=chunk,
                        config=request_config,
                    )
                    collected.extend(e.values for e in result.embeddings)
                    break
                except Exception as e:
                    message = str(e)
                    rate_limited = "429" in message or "RESOURCE_EXHAUSTED" in message

                    if attempt >= max_retries - 1:
                        raise RuntimeError(f"Gemini API error after {max_retries} retries: {e}")

                    # Extra jitter when rate-limited to spread retries out.
                    if rate_limited:
                        wait_time = base_delay * (2 ** attempt) + random.uniform(1, 5)
                        print(f" Rate limited, waiting {wait_time:.1f}s before retry {attempt + 2}/{max_retries}...")
                    else:
                        wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
                        print(f" API error, retrying in {wait_time:.1f}s: {e}")
                    time.sleep(wait_time)

            # Longer pause between batches than other providers, again for
            # rate-limit pressure.
            if chunk_idx < total_batches:
                time.sleep(0.5)

        embeddings = np.array(collected, dtype=np.float32)

        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
|
|
|
|
|
|
|
|
class CohereEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Cohere embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Cohere embedding model.

        Args:
            model_id: Model ID in format 'cohere/model-name'
            api_key: Cohere API key (or uses COHERE_API_KEY env var)
        """
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Cohere package not installed. Install with: pip install cohere"
            )

        self.model_id = model_id

        api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key required. Set COHERE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.client = cohere.Client(api_key=api_key)

        # Unregistered IDs get a generic fallback config.
        fallback = {
            "name": model_id,
            "description": "Cohere embedding model",
            "type": "cohere",
            "model_name": model_id.replace("cohere/", ""),
            "dimensions": 1024,
        }
        self.config = API_MODELS.get(model_id, fallback)

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Cohere embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 96,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Cohere API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses search_query vs search_document)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        import time

        # Cohere distinguishes query vs document embeddings via input_type.
        input_type = "search_query" if is_query else "search_document"

        chunks = [texts[pos:pos + batch_size] for pos in range(0, len(texts), batch_size)]
        total_batches = len(chunks)
        collected: list = []

        for chunk_idx, chunk in enumerate(chunks, start=1):
            if show_progress:
                print(f" Encoding batch {chunk_idx}/{total_batches}...")

            # Retry transient API failures with exponential backoff.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        texts=chunk,
                        model=self._model_name,
                        input_type=input_type,
                    )
                    collected.extend(result.embeddings)
                    break
                except Exception as e:
                    if attempt >= max_retries - 1:
                        raise RuntimeError(f"Cohere API error after {max_retries} retries: {e}")
                    wait_time = 2 ** attempt
                    print(f" API error, retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)

            # Gentle pacing between batches.
            if chunk_idx < total_batches:
                time.sleep(0.1)

        embeddings = np.array(collected, dtype=np.float32)

        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
|
|
|
|
|
|
|
|
def get_curated_model_choices() -> list[tuple[str, str]]:
    """
    Get list of curated local models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    choices = []
    for mid, meta in CURATED_MODELS.items():
        choices.append((mid, f"{meta['name']} - {meta['description']}"))
    return choices
|
|
|
|
|
|
|
|
def get_api_model_choices() -> list[tuple[str, str]]:
    """
    Get list of API-based models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    choices = []
    for mid, meta in API_MODELS.items():
        choices.append((mid, f"{meta['name']} - {meta['description']}"))
    return choices
|
|
|
|
|
|
|
|
def get_all_model_choices() -> list[tuple[str, str]]:
    """
    Get list of all models (local + API) for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    # Local models come first, matching the registries' ordering.
    return [*get_curated_model_choices(), *get_api_model_choices()]
|
|
|
|
|
|
|
|
def is_api_model(model_id: str) -> bool:
    """Check if a model ID is an API-based model."""
    model_id = model_id.strip()

    # Registered API models are always API-based.
    if model_id in API_MODELS:
        return True

    # Unregistered IDs still count as API models when they carry a known
    # provider prefix.
    return model_id.startswith(("openai/", "voyage/", "gemini/", "cohere/"))
|
|
|
|
|
|
|
|
def load_model(
    model_id: str,
    device: Optional[str] = None,
    api_key: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> BaseEmbeddingModel:
    """
    Load an embedding model by ID.

    Args:
        model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
        device: Device to use (for local models only)
        api_key: API key (for API-based models, or uses environment variable)
        hf_token: HuggingFace token for gated local models (or uses HF_TOKEN env var)

    Returns:
        Loaded embedding model instance
    """
    model_id = model_id.strip()

    # Anything that isn't an API model is loaded locally via
    # sentence-transformers.
    if not is_api_model(model_id):
        return EmbeddingModel(model_id, device=device, hf_token=hf_token)

    model_type = API_MODELS.get(model_id, {}).get("type", "")

    # Dispatch by registered type or provider prefix; order matches the
    # original checks (voyage, gemini, cohere, openai).
    providers = {
        "voyage": VoyageEmbeddingModel,
        "gemini": GeminiEmbeddingModel,
        "cohere": CohereEmbeddingModel,
        "openai": OpenAIEmbeddingModel,
    }
    for provider, wrapper_cls in providers.items():
        if model_type == provider or model_id.startswith(provider + "/"):
            return wrapper_cls(model_id, api_key=api_key)

    raise ValueError(f"Unknown API model type: {model_id}")
|
|
|
|
|
|
|
|
def validate_model_id(model_id: str) -> tuple[bool, str]:
    """
    Check if a model ID is valid and loadable.

    Args:
        model_id: The model ID to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not model_id or not model_id.strip():
        return False, "Model ID cannot be empty"

    model_id = model_id.strip()

    # Registered models and provider-prefixed API IDs are always accepted.
    registered = model_id in CURATED_MODELS or model_id in API_MODELS
    provider_prefixed = model_id.startswith(("openai/", "voyage/", "gemini/", "cohere/"))
    if registered or provider_prefixed:
        return True, ""

    # Anything else must at least look like a HuggingFace repo ID.
    if "/" not in model_id:
        return False, "Model ID should be in format 'organization/model-name'"

    return True, ""
|
|
|
|
|
|
|
|
def requires_api_key(model_id: str) -> bool:
    """Check if a model requires an API key."""
    # Every API-based (remote) model needs a key; local models never do.
    return is_api_model(model_id)
|
|
|
|
|
|
|
|
def api_key_optional(model_id: str) -> bool:
    """
    Check if an API key is optional for this model.

    Some providers (like Google Gemini) support Application Default Credentials
    as an alternative to explicit API keys.
    """
    # Only Gemini supports keyless (ADC) authentication in this module.
    return get_api_key_type(model_id) == "gemini"
|
|
|
|
|
|
|
|
def get_api_key_type(model_id: str) -> Optional[str]:
    """
    Get the provider whose API key a model requires.

    Args:
        model_id: The model ID

    Returns:
        One of 'openai', 'voyage', 'gemini', 'cohere', or None if no
        API key is needed (e.g. local models).
    """
    if not is_api_model(model_id):
        return None

    model_id = model_id.strip()
    model_config = API_MODELS.get(model_id, {})
    model_type = model_config.get("type", "")

    # A model maps to a provider either via its curated config "type" or via
    # an explicit "<provider>/" prefix in the ID. The tuple preserves the
    # original check order (voyage, gemini, cohere, openai).
    for provider in ("voyage", "gemini", "cohere", "openai"):
        if model_type == provider or model_id.startswith(f"{provider}/"):
            return provider

    return None
|
|
|
|
|
|
|
|
def get_api_key_env_var(model_id: str) -> Optional[str]:
    """
    Get the environment variable name for the API key required by a model.

    Args:
        model_id: The model ID

    Returns:
        Environment variable name or None
    """
    # Table lookup replaces the if/elif chain; dict.get returns None both
    # for an unknown provider and when no API key is needed at all.
    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "voyage": "VOYAGE_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "cohere": "COHERE_API_KEY",
    }
    return env_var_by_provider.get(get_api_key_type(model_id))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Test embedding model loading and encoding"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Test only local sentence-transformer models",
    )
    parser.add_argument(
        "--remote",
        action="store_true",
        help="Test only remote/API models (requires API keys)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Test a specific model ID",
    )

    args = parser.parse_args()

    # With neither flag given, test both categories. These are the
    # simplified equivalents of `args.local or (not args.local and not
    # args.remote)` and its mirror.
    test_local = args.local or not args.remote
    test_remote = args.remote or not args.local

    print("Testing model loading...")

    print("\nLocal models available:")
    for model_id, display in get_curated_model_choices():
        print(f"  - {display}")

    print("\nAPI models available:")
    for model_id, display in get_api_model_choices():
        print(f"  - {display}")

    # A semantically parallel Hebrew/English pair (Genesis 1:1) to sanity
    # check cross-lingual alignment.
    test_texts = [
        "讘专讗砖讬转 讘专讗 讗诇讛讬诐 讗转 讛砖诪讬诐 讜讗转 讛讗专抓",
        "In the beginning God created the heaven and the earth",
    ]

    def run_model_test(model_id: str, model_type: str) -> bool:
        """Load *model_id*, encode the test pair, and report similarity.

        Returns True on success, False if loading or encoding failed.
        """
        print(f"\n{'='*60}")
        print(f"Testing {model_type}: {model_id}")
        print("="*60)

        try:
            model = load_model(model_id)

            embeddings = model.encode(test_texts, show_progress=False)
            print(f"\nEncoded {len(test_texts)} texts")
            print(f"Embedding shape: {embeddings.shape}")

            # Normalize before the dot product so the printed value is a
            # true cosine similarity even for models whose embeddings are
            # not unit-length (a bare np.dot is cosine only for normalized
            # vectors).
            vec_a, vec_b = embeddings[0], embeddings[1]
            denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
            similarity = np.dot(vec_a, vec_b) / denom
            print(f"Cosine similarity between Hebrew and English: {similarity:.4f}")
            return True
        except Exception as e:
            # Best-effort smoke test: report the failure and let the caller
            # continue with any remaining models.
            print(f"Test failed: {e}")
            return False

    if args.model:
        run_model_test(args.model, "specified model")
    else:
        if test_local:
            run_model_test(
                "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                "local sentence-transformer model"
            )

        if test_remote:
            # Each remote provider is exercised only when its key is set.
            if os.environ.get("OPENAI_API_KEY"):
                run_model_test(
                    "openai/text-embedding-3-small",
                    "OpenAI API model"
                )
            else:
                print("\n(Skipping OpenAI test - OPENAI_API_KEY not set)")

            if os.environ.get("VOYAGE_API_KEY"):
                run_model_test(
                    "voyage/voyage-3.5",
                    "Voyage AI API model"
                )
            else:
                print("\n(Skipping Voyage AI test - VOYAGE_API_KEY not set)")

            if os.environ.get("GEMINI_API_KEY"):
                run_model_test(
                    "gemini/gemini-embedding-001",
                    "Gemini API model"
                )
            else:
                print("\n(Skipping Gemini test - GEMINI_API_KEY not set)")