# Author: Lev Israel
# Embedding Gemma benchmark
# Commit: d1c390a
"""
Model loading and embedding interface for the Rabbinic embedding benchmark.
Supports:
- Curated models from Hugging Face (sentence-transformers)
- Any Hugging Face sentence-transformer model
- API-based models (OpenAI, Voyage AI, Google Gemini)
"""
import os
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
# Curated local models known to work well for multilingual tasks.
# Keys are Hugging Face model IDs; values hold UI metadata plus the
# query/passage prefixes each model expects at encode time (E5-style
# models were trained with prefixes; an empty string means none).
CURATED_MODELS = {
    "intfloat/multilingual-e5-large": {
        "name": "Multilingual E5 Large",
        "description": "Strong multilingual model from Microsoft, 560M params",
        "type": "local",
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "intfloat/multilingual-e5-base": {
        "name": "Multilingual E5 Base",
        "description": "Smaller multilingual E5, 278M params",
        "type": "local",
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": {
        "name": "Multilingual MPNet",
        "description": "Classic multilingual sentence transformer, 278M params",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "BAAI/bge-m3": {
        "name": "BGE-M3",
        "description": "Multi-lingual, multi-functionality, multi-granularity model from BAAI",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "intfloat/e5-mistral-7b-instruct": {
        "name": "E5 Mistral 7B",
        "description": "Large instruction-tuned embedding model, 7B params (requires GPU)",
        "type": "local",
        # Instruction-tuned: queries carry a task instruction, passages do not.
        "query_prefix": "Instruct: Retrieve semantically similar text\nQuery: ",
        "passage_prefix": "",
    },
    "Alibaba-NLP/gte-multilingual-base": {
        "name": "GTE Multilingual Base",
        "description": "General Text Embeddings multilingual model from Alibaba",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "google/embeddinggemma-300m": {
        "name": "EmbeddingGemma",
        "description": "Google's 300M param embedding model, 100+ languages, 768d (requires HF token + license)",
        "type": "local",
        "query_prefix": "task: search result | query: ",
        "passage_prefix": "title: none | text: ",
        # Overrides the default 512-token cap applied by EmbeddingModel.
        "max_length": 2048,
    },
}
# API-based models. Keys are "provider/model" IDs used throughout this
# module; "model_name" is the identifier sent to the provider's API and
# "dimensions" is the embedding size expected back.
API_MODELS = {
    # --- OpenAI ---
    "openai/text-embedding-3-large": {
        "name": "OpenAI text-embedding-3-large",
        "description": "OpenAI's best embedding model, 3072 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-large",
        "dimensions": 3072,
    },
    "openai/text-embedding-3-small": {
        "name": "OpenAI text-embedding-3-small",
        "description": "OpenAI's efficient embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-small",
        "dimensions": 1536,
    },
    "openai/text-embedding-ada-002": {
        "name": "OpenAI Ada 002",
        "description": "OpenAI's legacy embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-ada-002",
        "dimensions": 1536,
    },
    # --- Voyage AI ---
    "voyage/voyage-3.5": {
        "name": "Voyage AI voyage-3.5",
        "description": "Voyage AI's latest embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5",
        "dimensions": 1024,
    },
    "voyage/voyage-3.5-lite": {
        "name": "Voyage AI voyage-3.5-lite",
        "description": "Voyage AI's efficient embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5-lite",
        "dimensions": 1024,
    },
    "voyage/voyage-3": {
        "name": "Voyage AI voyage-3",
        "description": "Voyage AI's general purpose embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3",
        "dimensions": 1024,
    },
    "voyage/voyage-3-lite": {
        "name": "Voyage AI voyage-3-lite",
        "description": "Voyage AI's lightweight embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3-lite",
        "dimensions": 512,
    },
    "voyage/voyage-multilingual-2": {
        "name": "Voyage AI voyage-multilingual-2",
        "description": "Voyage AI's multilingual embedding model, optimized for non-English (API key required)",
        "type": "voyage",
        "model_name": "voyage-multilingual-2",
        "dimensions": 1024,
    },
    # --- Google Gemini: one API model exposed at three output sizes ---
    "gemini/gemini-embedding-001": {
        "name": "Gemini Embedding 001",
        "description": "Google's Gemini embedding model, 3072 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 3072,
    },
    "gemini/gemini-embedding-001-768": {
        "name": "Gemini Embedding 001 (768d)",
        "description": "Google's Gemini embedding model, 768 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 768,
    },
    "gemini/gemini-embedding-001-1536": {
        "name": "Gemini Embedding 001 (1536d)",
        "description": "Google's Gemini embedding model, 1536 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 1536,
    },
    # --- Cohere ---
    "cohere/embed-multilingual-v3.0": {
        "name": "Cohere embed-multilingual-v3.0",
        "description": "Cohere's multilingual embedding model, 100+ languages (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-v3.0",
        "dimensions": 1024,
    },
    "cohere/embed-multilingual-light-v3.0": {
        "name": "Cohere embed-multilingual-light-v3.0",
        "description": "Cohere's lightweight multilingual model (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-light-v3.0",
        "dimensions": 384,
    },
}

# Merge all models for easy lookup. The two namespaces are disjoint,
# so no entries shadow each other.
ALL_MODELS = {**CURATED_MODELS, **API_MODELS}
class BaseEmbeddingModel(ABC):
    """Abstract interface shared by every embedding backend.

    Concrete subclasses are expected to set ``model_id`` and
    ``embedding_dim`` during construction.
    """

    # Set by subclass constructors.
    model_id: str
    embedding_dim: int

    @abstractmethod
    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts into an array of shape (len(texts), embedding_dim)."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable display name for the model."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Short human-readable description of the model."""

    def encode_pairs(
        self,
        he_texts: list[str],
        en_texts: list[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Encode parallel Hebrew/English text pairs.

        The Hebrew side is embedded as queries and the English side as
        passages, so asymmetric backends apply their prefixes/task types.

        Args:
            he_texts: Hebrew/Aramaic source texts.
            en_texts: English translations.
            batch_size: Batch size for encoding.
            show_progress: Whether to show a progress bar.

        Returns:
            Tuple of (hebrew_embeddings, english_embeddings).
        """
        shared_kwargs = {"batch_size": batch_size, "show_progress": show_progress}
        hebrew_vectors = self.encode(he_texts, is_query=True, **shared_kwargs)
        english_vectors = self.encode(en_texts, is_query=False, **shared_kwargs)
        return hebrew_vectors, english_vectors
class EmbeddingModel(BaseEmbeddingModel):
    """Sentence-transformers backend exposing the shared encode() interface."""

    def __init__(
        self,
        model_id: str,
        device: Optional[str] = None,
        max_length: int = 512,
        hf_token: Optional[str] = None,
    ):
        """Load a sentence-transformer model.

        Args:
            model_id: Hugging Face model ID.
            device: 'cuda', 'cpu', or None to auto-detect.
            max_length: Maximum sequence length for tokenization; a curated
                model's own "max_length" config entry takes precedence.
            hf_token: HuggingFace token for gated models (falls back to the
                HF_TOKEN environment variable).
        """
        from sentence_transformers import SentenceTransformer
        import torch

        self.model_id = model_id

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Curated models carry prefix/length metadata; any other ID gets a
        # neutral default config with no prefixes.
        self.config = CURATED_MODELS.get(model_id, {
            "name": model_id.split("/")[-1],
            "description": "Custom model",
            "type": "local",
            "query_prefix": "",
            "passage_prefix": "",
        })
        self.max_length = self.config.get("max_length", max_length)

        # Gated models (e.g. EmbeddingGemma) need an HF token.
        hf_token = hf_token or os.environ.get("HF_TOKEN")

        print(f"Loading model: {model_id} on {device}")

        # Only trust remote code from known publishers (security measure).
        trust_remote_code = model_id.startswith(("nvidia/", "google/"))

        loader_kwargs = {
            "device": device,
            "trust_remote_code": trust_remote_code,
            "token": hf_token,
        }
        if device == "cuda":
            # float16 on GPU halves VRAM versus float32.
            loader_kwargs["model_kwargs"] = {"torch_dtype": torch.float16}
        self.model = SentenceTransformer(model_id, **loader_kwargs)

        # Apply our cap without exceeding the model's own limit.
        if hasattr(self.model, "max_seq_length"):
            self.model.max_seq_length = min(self.max_length, self.model.max_seq_length)

        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts to embeddings.

        Args:
            texts: Texts to encode.
            is_query: Whether these are queries (vs passages); selects the
                prefix used by asymmetric (E5-style) models.
            batch_size: Batch size for encoding.
            show_progress: Whether to show a progress bar.
            normalize: Whether to L2-normalize embeddings.

        Returns:
            numpy array of shape (len(texts), embedding_dim).
        """
        prefix = self.config["query_prefix" if is_query else "passage_prefix"]
        if prefix:
            texts = [prefix + text for text in texts]
        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
        )

    @property
    def name(self) -> str:
        """Display name from config, falling back to the model ID."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Description from config."""
        return self.config.get("description", "")
class OpenAIEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for OpenAI embedding API with consistent interface.
    """

    # OpenAI embedding models reject inputs longer than 8191 tokens.
    MAX_TOKENS = 8191

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the OpenAI embedding model.

        Args:
            model_id: Model ID in format 'openai/model-name'
            api_key: OpenAI API key (or uses OPENAI_API_KEY env var)

        Raises:
            ImportError: If the openai package is not installed.
            ValueError: If no API key is available.
        """
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError(
                "OpenAI package not installed. Install with: pip install openai"
            )

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key required. Set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = OpenAI(api_key=api_key)

        # Known models get full metadata; unknown IDs a sensible default.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "OpenAI embedding model",
            "type": "openai",
            "model_name": model_id.replace("openai/", ""),
            "dimensions": 1536,
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        # Tokenizer used by _truncate_text to keep inputs under MAX_TOKENS.
        self._encoding = None
        try:
            import tiktoken
            self._encoding = tiktoken.encoding_for_model(self._model_name)
        except Exception:
            # Fall back to cl100k_base which is used by embedding models
            try:
                import tiktoken
                self._encoding = tiktoken.get_encoding("cl100k_base")
            except Exception:
                print("Warning: tiktoken not available, using character-based truncation")

        print(f"Initialized OpenAI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to fit within the MAX_TOKENS API limit."""
        if self._encoding is not None:
            # Use tiktoken for accurate token counting
            tokens = self._encoding.encode(text)
            if len(tokens) > self.MAX_TOKENS:
                tokens = tokens[:self.MAX_TOKENS]
                return self._encoding.decode(tokens)
            return text
        else:
            # Fallback: rough character-based truncation
            # Assume ~3 chars per token for Hebrew/mixed text (conservative)
            max_chars = self.MAX_TOKENS * 3
            if len(text) > max_chars:
                return text[:max_chars]
            return text

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 100,  # OpenAI supports larger batches
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using OpenAI API.

        Args:
            texts: List of texts to encode
            is_query: Not used for OpenAI (symmetric embeddings)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings (OpenAI already normalizes)

        Returns:
            numpy array of shape (len(texts), embedding_dim)

        Raises:
            RuntimeError: If the API still fails after all retries.
        """
        import time

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            # BUGFIX: truncate each text to the token limit before sending.
            # Previously _truncate_text was defined but never applied, so
            # over-long inputs triggered API errors.
            batch = [self._truncate_text(t) for t in texts[i:i + batch_size]]
            batch_num = i // batch_size + 1
            if show_progress:
                print(f" Encoding batch {batch_num}/{total_batches}...")

            # Retry with exponential backoff on transient API errors.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    response = self.client.embeddings.create(
                        model=self._model_name,
                        input=batch,
                    )
                    # Response items preserve input order.
                    batch_embeddings = [item.embedding for item in response.data]
                    all_embeddings.extend(batch_embeddings)
                    break
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f" API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"OpenAI API error after {max_retries} retries: {e}")

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # OpenAI embeddings are already normalized, but normalize if requested
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
class VoyageEmbeddingModel(BaseEmbeddingModel):
    """Voyage AI embedding backend exposing the shared encode() interface."""

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """Set up the Voyage AI client.

        Args:
            model_id: Model ID in format 'voyage/model-name'.
            api_key: Voyage API key (falls back to VOYAGE_API_KEY env var).

        Raises:
            ImportError: If the voyageai package is not installed.
            ValueError: If no API key is available.
        """
        try:
            import voyageai
        except ImportError:
            raise ImportError(
                "Voyage AI package not installed. Install with: pip install voyageai"
            )

        self.model_id = model_id

        resolved_key = api_key or os.environ.get("VOYAGE_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Voyage API key required. Set VOYAGE_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = voyageai.Client(api_key=resolved_key)

        # Known models carry metadata; unknown IDs get a generic default.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Voyage AI embedding model",
            "type": "voyage",
            "model_name": model_id.replace("voyage/", ""),
            "dimensions": 1024,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Voyage AI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 128,  # Voyage supports larger batches
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts via the Voyage AI API.

        Args:
            texts: Texts to encode.
            is_query: Selects Voyage's 'query' vs 'document' input_type.
            batch_size: Number of texts per API call.
            show_progress: Whether to print per-batch progress.
            normalize: Whether to L2-normalize the result.

        Returns:
            numpy array of shape (len(texts), embedding_dim).

        Raises:
            RuntimeError: If the API still fails after all retries.
        """
        import time

        collected: list = []
        total_batches = (len(texts) + batch_size - 1) // batch_size
        # Voyage supports asymmetric embeddings via input_type.
        input_type = "query" if is_query else "document"

        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            batch_num = start // batch_size + 1
            if show_progress:
                print(f" Encoding batch {batch_num}/{total_batches}...")

            # Up to three attempts with exponential backoff.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        chunk,
                        model=self._model_name,
                        input_type=input_type,
                    )
                    collected.extend(result.embeddings)
                    break
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f" API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"Voyage AI API error after {max_retries} retries: {e}")

            # Brief pause between batches to stay under rate limits.
            if start + batch_size < len(texts):
                time.sleep(0.1)

        matrix = np.array(collected, dtype=np.float32)
        if normalize:
            lengths = np.linalg.norm(matrix, axis=1, keepdims=True)
            matrix = matrix / np.maximum(lengths, 1e-10)
        return matrix

    @property
    def name(self) -> str:
        """Display name from config, falling back to the model ID."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Description from config."""
        return self.config.get("description", "")
class GeminiEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Google Gemini embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Gemini embedding model.

        Args:
            model_id: Model ID in format 'gemini/model-name'
            api_key: Gemini API key (optional - can use GEMINI_API_KEY env var
                or Google Cloud Application Default Credentials)

        Raises:
            ImportError: If the google-genai package is not installed.
        """
        try:
            from google import genai
        except ImportError:
            raise ImportError(
                "Google GenAI package not installed. Install with: pip install google-genai"
            )

        self.model_id = model_id

        # Get API key from parameter or environment (optional - ADC also works)
        api_key = api_key or os.environ.get("GEMINI_API_KEY")

        # Create client - if no API key, will use Application Default Credentials
        if api_key:
            self.client = genai.Client(api_key=api_key)
        else:
            # Use Application Default Credentials (gcloud auth application-default login)
            self.client = genai.Client()

        # Get model config. For unknown IDs, strip a trailing "-768"/"-1536"
        # dimension suffix so e.g. 'gemini/gemini-embedding-001-768' maps to
        # the real API model name 'gemini-embedding-001'.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Gemini embedding model",
            "type": "gemini",
            "model_name": model_id.replace("gemini/", "").split("-768")[0].split("-1536")[0],
            "dimensions": 3072,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Gemini embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 20,  # Smaller batches to avoid rate limits
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Gemini API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses RETRIEVAL_QUERY vs RETRIEVAL_DOCUMENT)
            batch_size: Batch size for API calls (smaller for Gemini to avoid rate limits)
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)

        Raises:
            RuntimeError: If the API still fails after all retries.
        """
        import time
        import random
        from google.genai import types

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Gemini supports task_type for asymmetric embeddings
        task_type = "RETRIEVAL_QUERY" if is_query else "RETRIEVAL_DOCUMENT"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1
            if show_progress:
                print(f" Encoding batch {batch_num}/{total_batches}...")

            # Retry logic with exponential backoff for rate limits
            max_retries = 8
            base_delay = 2.0
            for attempt in range(max_retries):
                try:
                    # Build config with task type and output dimensionality.
                    # output_dimensionality lets one API model serve the
                    # 768/1536/3072-dim entries configured in API_MODELS.
                    embed_config = types.EmbedContentConfig(
                        task_type=task_type,
                        output_dimensionality=self.embedding_dim,
                    )
                    result = self.client.models.embed_content(
                        model=self._model_name,
                        contents=batch,
                        config=embed_config,
                    )
                    # Extract embeddings from response
                    batch_embeddings = [e.values for e in result.embeddings]
                    all_embeddings.extend(batch_embeddings)
                    break
                except Exception as e:
                    # Rate limiting is detected from the message text since
                    # 429 / RESOURCE_EXHAUSTED arrive as generic exceptions.
                    error_str = str(e)
                    is_rate_limit = "429" in error_str or "RESOURCE_EXHAUSTED" in error_str
                    if attempt < max_retries - 1:
                        # Exponential backoff with jitter
                        # Longer waits for rate limit errors
                        if is_rate_limit:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(1, 5)
                            print(f" Rate limited, waiting {wait_time:.1f}s before retry {attempt + 2}/{max_retries}...")
                        else:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
                            print(f" API error, retrying in {wait_time:.1f}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"Gemini API error after {max_retries} retries: {e}")

            # Delay between batches to avoid rate limits (longer for Gemini)
            if i + batch_size < len(texts):
                time.sleep(0.5)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Normalize if requested (Gemini's 3072d is normalized, but smaller dims need it)
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
class CohereEmbeddingModel(BaseEmbeddingModel):
    """Cohere embedding backend exposing the shared encode() interface."""

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """Set up the Cohere client.

        Args:
            model_id: Model ID in format 'cohere/model-name'.
            api_key: Cohere API key (falls back to COHERE_API_KEY env var).

        Raises:
            ImportError: If the cohere package is not installed.
            ValueError: If no API key is available.
        """
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Cohere package not installed. Install with: pip install cohere"
            )

        self.model_id = model_id

        resolved_key = api_key or os.environ.get("COHERE_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Cohere API key required. Set COHERE_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = cohere.Client(api_key=resolved_key)

        # Known models carry metadata; unknown IDs get a generic default.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Cohere embedding model",
            "type": "cohere",
            "model_name": model_id.replace("cohere/", ""),
            "dimensions": 1024,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Cohere embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 96,  # Cohere supports up to 96 texts per request
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts via the Cohere API.

        Args:
            texts: Texts to encode.
            is_query: Selects Cohere's 'search_query' vs 'search_document'
                input_type (required by v3 models).
            batch_size: Number of texts per API call.
            show_progress: Whether to print per-batch progress.
            normalize: Whether to L2-normalize the result.

        Returns:
            numpy array of shape (len(texts), embedding_dim).

        Raises:
            RuntimeError: If the API still fails after all retries.
        """
        import time

        collected: list = []
        total_batches = (len(texts) + batch_size - 1) // batch_size
        # Cohere v3 models require input_type for asymmetric embeddings.
        input_type = "search_query" if is_query else "search_document"

        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            batch_num = start // batch_size + 1
            if show_progress:
                print(f" Encoding batch {batch_num}/{total_batches}...")

            # Up to three attempts with exponential backoff.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        texts=chunk,
                        model=self._model_name,
                        input_type=input_type,
                    )
                    collected.extend(result.embeddings)
                    break
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f" API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"Cohere API error after {max_retries} retries: {e}")

            # Brief pause between batches to stay under rate limits.
            if start + batch_size < len(texts):
                time.sleep(0.1)

        matrix = np.array(collected, dtype=np.float32)
        if normalize:
            lengths = np.linalg.norm(matrix, axis=1, keepdims=True)
            matrix = matrix / np.maximum(lengths, 1e-10)
        return matrix

    @property
    def name(self) -> str:
        """Display name from config, falling back to the model ID."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Description from config."""
        return self.config.get("description", "")
def get_curated_model_choices() -> list[tuple[str, str]]:
    """Return (model_id, display_name) pairs for curated local models.

    Intended for populating a UI dropdown.
    """
    choices = []
    for model_id, info in CURATED_MODELS.items():
        choices.append((model_id, f"{info['name']} - {info['description']}"))
    return choices
def get_api_model_choices() -> list[tuple[str, str]]:
    """Return (model_id, display_name) pairs for API-based models.

    Intended for populating a UI dropdown.
    """
    return [
        (mid, f"{meta['name']} - {meta['description']}")
        for mid, meta in API_MODELS.items()
    ]
def get_all_model_choices() -> list[tuple[str, str]]:
    """Return dropdown choices for every model: curated local first, then API."""
    return [*get_curated_model_choices(), *get_api_model_choices()]
def is_api_model(model_id: str) -> bool:
    """Return True if *model_id* refers to an API-backed model.

    A model counts as API-based when its ID carries a known provider
    prefix or appears in the API_MODELS registry.
    """
    model_id = model_id.strip()
    return (
        model_id.startswith(("openai/", "voyage/", "gemini/", "cohere/"))
        or model_id in API_MODELS
    )
def load_model(
    model_id: str,
    device: Optional[str] = None,
    api_key: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> BaseEmbeddingModel:
    """Load an embedding model by ID.

    Args:
        model_id: Model ID (HuggingFace model ID or API model like
            'openai/text-embedding-3-large').
        device: Device to use (for local models only).
        api_key: API key (for API-based models, or uses environment variable).
        hf_token: HuggingFace token for gated local models (or uses HF_TOKEN
            env var).

    Returns:
        Loaded embedding model instance.

    Raises:
        ValueError: If an API model ID has an unrecognized provider.
    """
    model_id = model_id.strip()

    if is_api_model(model_id):
        model_type = API_MODELS.get(model_id, {}).get("type", "")
        # Provider dispatch table; insertion order matches the original
        # voyage -> gemini -> cohere -> openai check order.
        providers = {
            "voyage": VoyageEmbeddingModel,
            "gemini": GeminiEmbeddingModel,
            "cohere": CohereEmbeddingModel,
            "openai": OpenAIEmbeddingModel,
        }
        for provider, backend_cls in providers.items():
            if model_type == provider or model_id.startswith(provider + "/"):
                return backend_cls(model_id, api_key=api_key)
        raise ValueError(f"Unknown API model type: {model_id}")

    # Anything else is treated as a local sentence-transformer model.
    return EmbeddingModel(model_id, device=device, hf_token=hf_token)
def validate_model_id(model_id: str) -> tuple[bool, str]:
    """Check if a model ID is valid and plausibly loadable.

    Args:
        model_id: The model ID to validate.

    Returns:
        Tuple of (is_valid, error_message); error_message is empty when valid.
    """
    if not model_id or not model_id.strip():
        return False, "Model ID cannot be empty"

    model_id = model_id.strip()

    # Known models (curated local or registered API) are always valid.
    if model_id in CURATED_MODELS or model_id in API_MODELS:
        return True, ""

    # Any recognized API provider prefix is accepted as-is.
    if model_id.startswith(("openai/", "voyage/", "gemini/", "cohere/")):
        return True, ""

    # Custom HF models must at least look like 'org/model-name'.
    if "/" not in model_id:
        return False, "Model ID should be in format 'organization/model-name'"

    # Could query the HF Hub here, but that would slow down validation.
    return True, ""
def requires_api_key(model_id: str) -> bool:
    """Check if a model requires an API key.

    True exactly when the model is API-based (see is_api_model);
    local sentence-transformer models return False.
    """
    return is_api_model(model_id)
def api_key_optional(model_id: str) -> bool:
    """Check if an API key is optional for this model.

    Some providers (like Google Gemini) support Application Default Credentials
    as an alternative to explicit API keys.
    """
    # Gemini supports ADC (gcloud auth application-default login)
    return get_api_key_type(model_id) == "gemini"
def get_api_key_type(model_id: str) -> Optional[str]:
    """
    Get the type of API key required for a model.

    Args:
        model_id: The model ID

    Returns:
        'openai', 'voyage', 'gemini', 'cohere', or None if no API key needed
    """
    if not is_api_model(model_id):
        return None

    model_id = model_id.strip()
    model_config = API_MODELS.get(model_id, {})
    model_type = model_config.get("type", "")

    # Prefer the registered type; fall back to the ID prefix for ad-hoc models.
    for provider in ("voyage", "gemini", "cohere", "openai"):
        if model_type == provider or model_id.startswith(provider + "/"):
            return provider
    return None
def get_api_key_env_var(model_id: str) -> Optional[str]:
    """
    Get the environment variable name for the API key required by a model.

    Args:
        model_id: The model ID

    Returns:
        Environment variable name or None
    """
    env_by_provider = {
        "openai": "OPENAI_API_KEY",
        "voyage": "VOYAGE_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "cohere": "COHERE_API_KEY",
    }
    # get_api_key_type returns None for local models; dict.get(None) -> None.
    return env_by_provider.get(get_api_key_type(model_id))
# Manual smoke test: list available models and round-trip one encode call
# per backend (skipping API backends whose keys are not configured).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Test embedding model loading and encoding"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Test only local sentence-transformer models",
    )
    parser.add_argument(
        "--remote",
        action="store_true",
        help="Test only remote/API models (requires API keys)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Test a specific model ID",
    )
    args = parser.parse_args()

    # If neither flag specified, test both
    test_local = args.local or (not args.local and not args.remote)
    test_remote = args.remote or (not args.local and not args.remote)

    print("Testing model loading...")
    print(f"\nLocal models available:")
    for model_id, display in get_curated_model_choices():
        print(f" - {display}")
    print(f"\nAPI models available:")
    for model_id, display in get_api_model_choices():
        print(f" - {display}")

    # Test texts: Genesis 1:1 in Hebrew and in English (KJV).
    test_texts = [
        "讘专讗砖讬转 讘专讗 讗诇讛讬诐 讗转 讛砖诪讬诐 讜讗转 讛讗专抓",
        "In the beginning God created the heaven and the earth",
    ]

    def run_model_test(model_id: str, model_type: str):
        """Run a test for a specific model."""
        # Loads the model, encodes the test pair, and reports the
        # Hebrew/English similarity. encode() normalizes by default,
        # so the dot product equals cosine similarity.
        print(f"\n{'='*60}")
        print(f"Testing {model_type}: {model_id}")
        print("="*60)
        try:
            model = load_model(model_id)
            embeddings = model.encode(test_texts, show_progress=False)
            print(f"\nEncoded {len(test_texts)} texts")
            print(f"Embedding shape: {embeddings.shape}")
            similarity = np.dot(embeddings[0], embeddings[1])
            print(f"Cosine similarity between Hebrew and English: {similarity:.4f}")
            return True
        except Exception as e:
            print(f"Test failed: {e}")
            return False

    # Test specific model if provided
    if args.model:
        run_model_test(args.model, "specified model")
    else:
        # Test local model
        if test_local:
            run_model_test(
                "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                "local sentence-transformer model"
            )
        # Test API models
        if test_remote:
            # Test OpenAI model
            if os.environ.get("OPENAI_API_KEY"):
                run_model_test(
                    "openai/text-embedding-3-small",
                    "OpenAI API model"
                )
            else:
                print("\n(Skipping OpenAI test - OPENAI_API_KEY not set)")
            # Test Voyage AI model
            if os.environ.get("VOYAGE_API_KEY"):
                run_model_test(
                    "voyage/voyage-3.5",
                    "Voyage AI API model"
                )
            else:
                print("\n(Skipping Voyage AI test - VOYAGE_API_KEY not set)")
            # Test Gemini model
            if os.environ.get("GEMINI_API_KEY"):
                run_model_test(
                    "gemini/gemini-embedding-001",
                    "Gemini API model"
                )
            else:
                print("\n(Skipping Gemini test - GEMINI_API_KEY not set)")