"""
Model loading and embedding interface for the Rabbinic embedding benchmark.

Supports:
- Curated models from Hugging Face (sentence-transformers)
- Any Hugging Face sentence-transformer model
- API-based models (OpenAI, Voyage AI, Google Gemini, Cohere)
"""

import os
import random
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np

# Curated local models known to work well for multilingual tasks
CURATED_MODELS = {
    "intfloat/multilingual-e5-large": {
        "name": "Multilingual E5 Large",
        "description": "Strong multilingual model from Microsoft, 560M params",
        "type": "local",
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "intfloat/multilingual-e5-base": {
        "name": "Multilingual E5 Base",
        "description": "Smaller multilingual E5, 278M params",
        "type": "local",
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
    },
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": {
        "name": "Multilingual MPNet",
        "description": "Classic multilingual sentence transformer, 278M params",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "BAAI/bge-m3": {
        "name": "BGE-M3",
        "description": "Multi-lingual, multi-functionality, multi-granularity model from BAAI",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "intfloat/e5-mistral-7b-instruct": {
        "name": "E5 Mistral 7B",
        "description": "Large instruction-tuned embedding model, 7B params (requires GPU)",
        "type": "local",
        "query_prefix": "Instruct: Retrieve semantically similar text\nQuery: ",
        "passage_prefix": "",
    },
    "Alibaba-NLP/gte-multilingual-base": {
        "name": "GTE Multilingual Base",
        "description": "General Text Embeddings multilingual model from Alibaba",
        "type": "local",
        "query_prefix": "",
        "passage_prefix": "",
    },
    "google/embeddinggemma-300m": {
        "name": "EmbeddingGemma",
        "description": "Google's 300M param embedding model, 100+ languages, 768d (requires HF token + license)",
        "type": "local",
        "query_prefix": "task: search result | query: ",
        "passage_prefix": "title: none | text: ",
        "max_length": 2048,
    },
}

# API-based models
API_MODELS = {
    "openai/text-embedding-3-large": {
        "name": "OpenAI text-embedding-3-large",
        "description": "OpenAI's best embedding model, 3072 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-large",
        "dimensions": 3072,
    },
    "openai/text-embedding-3-small": {
        "name": "OpenAI text-embedding-3-small",
        "description": "OpenAI's efficient embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-3-small",
        "dimensions": 1536,
    },
    "openai/text-embedding-ada-002": {
        "name": "OpenAI Ada 002",
        "description": "OpenAI's legacy embedding model, 1536 dimensions (API key required)",
        "type": "openai",
        "model_name": "text-embedding-ada-002",
        "dimensions": 1536,
    },
    "voyage/voyage-3.5": {
        "name": "Voyage AI voyage-3.5",
        "description": "Voyage AI's latest embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5",
        "dimensions": 1024,
    },
    "voyage/voyage-3.5-lite": {
        "name": "Voyage AI voyage-3.5-lite",
        "description": "Voyage AI's efficient embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3.5-lite",
        "dimensions": 1024,
    },
    "voyage/voyage-3": {
        "name": "Voyage AI voyage-3",
        "description": "Voyage AI's general purpose embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3",
        "dimensions": 1024,
    },
    "voyage/voyage-3-lite": {
        "name": "Voyage AI voyage-3-lite",
        "description": "Voyage AI's lightweight embedding model (API key required)",
        "type": "voyage",
        "model_name": "voyage-3-lite",
        "dimensions": 512,
    },
    "voyage/voyage-multilingual-2": {
        "name": "Voyage AI voyage-multilingual-2",
        "description": "Voyage AI's multilingual embedding model, optimized for non-English (API key required)",
        "type": "voyage",
        "model_name": "voyage-multilingual-2",
        "dimensions": 1024,
    },
    "gemini/gemini-embedding-001": {
        "name": "Gemini Embedding 001",
        "description": "Google's Gemini embedding model, 3072 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 3072,
    },
    "gemini/gemini-embedding-001-768": {
        "name": "Gemini Embedding 001 (768d)",
        "description": "Google's Gemini embedding model, 768 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 768,
    },
    "gemini/gemini-embedding-001-1536": {
        "name": "Gemini Embedding 001 (1536d)",
        "description": "Google's Gemini embedding model, 1536 dimensions (API key required)",
        "type": "gemini",
        "model_name": "gemini-embedding-001",
        "dimensions": 1536,
    },
    "cohere/embed-multilingual-v3.0": {
        "name": "Cohere embed-multilingual-v3.0",
        "description": "Cohere's multilingual embedding model, 100+ languages (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-v3.0",
        "dimensions": 1024,
    },
    "cohere/embed-multilingual-light-v3.0": {
        "name": "Cohere embed-multilingual-light-v3.0",
        "description": "Cohere's lightweight multilingual model (API key required)",
        "type": "cohere",
        "model_name": "embed-multilingual-light-v3.0",
        "dimensions": 384,
    },
}

# Merge all models for easy lookup
ALL_MODELS = {**CURATED_MODELS, **API_MODELS}

# Known API providers, used for prefix checks and dispatch
_API_PROVIDERS = ("openai", "voyage", "gemini", "cohere")

# Environment variable holding the API key for each provider
_API_KEY_ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "voyage": "VOYAGE_API_KEY",
    "gemini": "GEMINI_API_KEY",
    "cohere": "COHERE_API_KEY",
}


def _l2_normalize(embeddings: np.ndarray) -> np.ndarray:
    """L2-normalize each row of a 2D embedding matrix (guarding zero rows)."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, 1e-10)


def _call_with_retries(
    request: Callable[[], list],
    provider: str,
    max_retries: int = 3,
) -> list:
    """
    Run an API request with simple exponential-backoff retries.

    Args:
        request: Zero-argument callable performing one API call and
            returning a list of embedding vectors.
        provider: Provider name used in error messages (e.g. "OpenAI").
        max_retries: Total attempts before giving up.

    Returns:
        The list returned by `request`.

    Raises:
        RuntimeError: If all attempts fail (chained from the last error).
    """
    for attempt in range(max_retries):
        try:
            return request()
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"  API error, retrying in {wait_time}s: {e}")
                time.sleep(wait_time)
            else:
                raise RuntimeError(
                    f"{provider} API error after {max_retries} retries: {e}"
                ) from e
    # Unreachable: the loop either returns or raises.
    raise RuntimeError(f"{provider} API error: retry loop exhausted")


class BaseEmbeddingModel(ABC):
    """Abstract base class for embedding models."""

    model_id: str
    embedding_dim: int
    # Subclasses populate `config` with at least "name" and "description".
    config: dict

    @abstractmethod
    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """Encode texts to embeddings."""
        pass

    @property
    def name(self) -> str:
        """Display name for the model (from config, falling back to the ID)."""
        return getattr(self, "config", {}).get("name", self.model_id)

    @property
    def description(self) -> str:
        """Human-readable description for the model."""
        return getattr(self, "config", {}).get("description", "")

    def encode_pairs(
        self,
        he_texts: list[str],
        en_texts: list[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Encode parallel Hebrew/English text pairs.

        Args:
            he_texts: Hebrew/Aramaic source texts
            en_texts: English translations
            batch_size: Batch size for encoding
            show_progress: Whether to show progress bar

        Returns:
            Tuple of (hebrew_embeddings, english_embeddings)
        """
        # Hebrew side is treated as the "query" for asymmetric models.
        he_embeddings = self.encode(
            he_texts,
            is_query=True,
            batch_size=batch_size,
            show_progress=show_progress,
        )
        en_embeddings = self.encode(
            en_texts,
            is_query=False,
            batch_size=batch_size,
            show_progress=show_progress,
        )
        return he_embeddings, en_embeddings


class EmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for sentence-transformer models with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        device: Optional[str] = None,
        max_length: int = 512,
        hf_token: Optional[str] = None,
    ):
        """
        Initialize the embedding model.

        Args:
            model_id: Hugging Face model ID
            device: Device to use ('cuda', 'cpu', or None for auto)
            max_length: Maximum sequence length for tokenization
            hf_token: HuggingFace token for gated models (or uses HF_TOKEN env var)
        """
        from sentence_transformers import SentenceTransformer
        import torch

        self.model_id = model_id

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Get model config if it's a curated model
        self.config = CURATED_MODELS.get(model_id, {
            "name": model_id.split("/")[-1],
            "description": "Custom model",
            "type": "local",
            "query_prefix": "",
            "passage_prefix": "",
        })

        # Use config max_length if available, otherwise use parameter
        self.max_length = self.config.get("max_length", max_length)

        # Get HF token from parameter or environment (for gated models like EmbeddingGemma)
        hf_token = hf_token or os.environ.get("HF_TOKEN")

        print(f"Loading model: {model_id} on {device}")

        # Only trust remote code from known publishers (security measure)
        trusted_publishers = ["nvidia/", "google/"]
        trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)

        # Load with float16 on CUDA to halve VRAM usage; keep float32 on CPU
        # where half precision is slow/unsupported.
        model_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
        self.model = SentenceTransformer(
            model_id,
            device=device,
            model_kwargs=model_kwargs,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )

        # Set max sequence length if supported (never raise above the model's own cap)
        if hasattr(self.model, "max_seq_length"):
            self.model.max_seq_length = min(self.max_length, self.model.max_seq_length)

        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 32,
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (vs passages) for asymmetric models
            batch_size: Batch size for encoding
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        # Add prefix if needed (for E5-style models)
        prefix = self.config["query_prefix"] if is_query else self.config["passage_prefix"]
        if prefix:
            texts = [prefix + t for t in texts]

        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
        )
        return embeddings


class OpenAIEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for OpenAI embedding API with consistent interface.
    """

    # OpenAI embedding models have an 8191 token limit
    MAX_TOKENS = 8191

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the OpenAI embedding model.

        Args:
            model_id: Model ID in format 'openai/model-name'
            api_key: OpenAI API key (or uses OPENAI_API_KEY env var)
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "OpenAI package not installed. Install with: pip install openai"
            ) from e

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key required. Set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = OpenAI(api_key=api_key)

        # Get model config
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "OpenAI embedding model",
            "type": "openai",
            "model_name": model_id.replace("openai/", ""),
            "dimensions": 1536,
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        # Initialize tokenizer for truncation
        self._encoding = None
        try:
            import tiktoken
            self._encoding = tiktoken.encoding_for_model(self._model_name)
        except Exception:
            # Fall back to cl100k_base which is used by embedding models
            try:
                import tiktoken
                self._encoding = tiktoken.get_encoding("cl100k_base")
            except Exception:
                print("Warning: tiktoken not available, using character-based truncation")

        print(f"Initialized OpenAI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to fit within the model's token limit."""
        if self._encoding is not None:
            # Use tiktoken for accurate token counting
            tokens = self._encoding.encode(text)
            if len(tokens) > self.MAX_TOKENS:
                tokens = tokens[:self.MAX_TOKENS]
                return self._encoding.decode(tokens)
            return text
        else:
            # Fallback: rough character-based truncation.
            # Assume ~3 chars per token for Hebrew/mixed text (conservative).
            max_chars = self.MAX_TOKENS * 3
            if len(text) > max_chars:
                return text[:max_chars]
            return text

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 100,  # OpenAI supports larger batches
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using OpenAI API.

        Args:
            texts: List of texts to encode
            is_query: Not used for OpenAI (symmetric embeddings)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings (OpenAI already normalizes)

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        all_embeddings: list[list[float]] = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            # BUGFIX: truncate each text so over-limit inputs don't make the
            # API reject the whole batch.
            batch = [self._truncate_text(t) for t in texts[i:i + batch_size]]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            def _request(batch=batch):
                response = self.client.embeddings.create(
                    model=self._model_name,
                    input=batch,
                )
                return [item.embedding for item in response.data]

            all_embeddings.extend(_call_with_retries(_request, "OpenAI"))

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # OpenAI embeddings are already normalized, but normalize if requested
        if normalize:
            embeddings = _l2_normalize(embeddings)
        return embeddings


class VoyageEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Voyage AI embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Voyage AI embedding model.

        Args:
            model_id: Model ID in format 'voyage/model-name'
            api_key: Voyage API key (or uses VOYAGE_API_KEY env var)
        """
        try:
            import voyageai
        except ImportError as e:
            raise ImportError(
                "Voyage AI package not installed. Install with: pip install voyageai"
            ) from e

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("VOYAGE_API_KEY")
        if not api_key:
            raise ValueError(
                "Voyage API key required. Set VOYAGE_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = voyageai.Client(api_key=api_key)

        # Get model config
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Voyage AI embedding model",
            "type": "voyage",
            "model_name": model_id.replace("voyage/", ""),
            "dimensions": 1024,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Voyage AI embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 128,  # Voyage supports larger batches
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Voyage AI API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (Voyage supports input_type)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        all_embeddings: list[list[float]] = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Voyage supports input_type for asymmetric embeddings
        input_type = "query" if is_query else "document"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            def _request(batch=batch):
                result = self.client.embed(
                    batch,
                    model=self._model_name,
                    input_type=input_type,
                )
                return result.embeddings

            all_embeddings.extend(_call_with_retries(_request, "Voyage AI"))

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)
        if normalize:
            embeddings = _l2_normalize(embeddings)
        return embeddings


class GeminiEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Google Gemini embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Gemini embedding model.

        Args:
            model_id: Model ID in format 'gemini/model-name'
            api_key: Gemini API key (optional - can use GEMINI_API_KEY env var
                or Google Cloud Application Default Credentials)
        """
        try:
            from google import genai
        except ImportError as e:
            raise ImportError(
                "Google GenAI package not installed. Install with: pip install google-genai"
            ) from e

        self.model_id = model_id

        # Get API key from parameter or environment (optional - ADC also works)
        api_key = api_key or os.environ.get("GEMINI_API_KEY")

        # Create client - if no API key, will use Application Default Credentials
        if api_key:
            self.client = genai.Client(api_key=api_key)
        else:
            # Use Application Default Credentials (gcloud auth application-default login)
            self.client = genai.Client()

        # Get model config. The fallback strips the dimension suffix
        # (e.g. "-768") to recover the underlying API model name.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Gemini embedding model",
            "type": "gemini",
            "model_name": model_id.replace("gemini/", "").split("-768")[0].split("-1536")[0],
            "dimensions": 3072,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Gemini embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 20,  # Smaller batches to avoid rate limits
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Gemini API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses RETRIEVAL_QUERY vs RETRIEVAL_DOCUMENT)
            batch_size: Batch size for API calls (smaller for Gemini to avoid rate limits)
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        from google.genai import types

        all_embeddings: list[list[float]] = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Gemini supports task_type for asymmetric embeddings
        task_type = "RETRIEVAL_QUERY" if is_query else "RETRIEVAL_DOCUMENT"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            # Bespoke retry loop: Gemini rate limits aggressively, so use
            # more attempts plus jittered exponential backoff.
            max_retries = 8
            base_delay = 2.0
            for attempt in range(max_retries):
                try:
                    # Build config with task type and output dimensionality
                    embed_config = types.EmbedContentConfig(
                        task_type=task_type,
                        output_dimensionality=self.embedding_dim,
                    )
                    result = self.client.models.embed_content(
                        model=self._model_name,
                        contents=batch,
                        config=embed_config,
                    )
                    all_embeddings.extend(e.values for e in result.embeddings)
                    break
                except Exception as e:
                    error_str = str(e)
                    is_rate_limit = "429" in error_str or "RESOURCE_EXHAUSTED" in error_str
                    if attempt < max_retries - 1:
                        # Exponential backoff with jitter; longer waits for rate limits
                        if is_rate_limit:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(1, 5)
                            print(f"  Rate limited, waiting {wait_time:.1f}s before retry {attempt + 2}/{max_retries}...")
                        else:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
                            print(f"  API error, retrying in {wait_time:.1f}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(
                            f"Gemini API error after {max_retries} retries: {e}"
                        ) from e

            # Delay between batches to avoid rate limits (longer for Gemini)
            if i + batch_size < len(texts):
                time.sleep(0.5)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Normalize if requested (Gemini's 3072d is normalized, but smaller dims need it)
        if normalize:
            embeddings = _l2_normalize(embeddings)
        return embeddings


class CohereEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Cohere embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Cohere embedding model.

        Args:
            model_id: Model ID in format 'cohere/model-name'
            api_key: Cohere API key (or uses COHERE_API_KEY env var)
        """
        try:
            import cohere
        except ImportError as e:
            raise ImportError(
                "Cohere package not installed. Install with: pip install cohere"
            ) from e

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key required. Set COHERE_API_KEY environment variable "
                "or pass api_key parameter."
            )
        self.client = cohere.Client(api_key=api_key)

        # Get model config
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Cohere embedding model",
            "type": "cohere",
            "model_name": model_id.replace("cohere/", ""),
            "dimensions": 1024,  # Default dimension
        })
        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Cohere embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 96,  # Cohere supports up to 96 texts per request
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Cohere API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses search_query vs search_document)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        all_embeddings: list[list[float]] = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Cohere v3 models require input_type for asymmetric embeddings
        input_type = "search_query" if is_query else "search_document"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            def _request(batch=batch):
                result = self.client.embed(
                    texts=batch,
                    model=self._model_name,
                    input_type=input_type,
                )
                return result.embeddings

            all_embeddings.extend(_call_with_retries(_request, "Cohere"))

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)
        if normalize:
            embeddings = _l2_normalize(embeddings)
        return embeddings


def get_curated_model_choices() -> list[tuple[str, str]]:
    """
    Get list of curated local models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return [
        (model_id, f"{info['name']} - {info['description']}")
        for model_id, info in CURATED_MODELS.items()
    ]


def get_api_model_choices() -> list[tuple[str, str]]:
    """
    Get list of API-based models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return [
        (model_id, f"{info['name']} - {info['description']}")
        for model_id, info in API_MODELS.items()
    ]


def get_all_model_choices() -> list[tuple[str, str]]:
    """
    Get list of all models (local + API) for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return get_curated_model_choices() + get_api_model_choices()


def is_api_model(model_id: str) -> bool:
    """Check if a model ID is an API-based model."""
    model_id = model_id.strip()
    if model_id in API_MODELS:
        return True
    # Fall back to recognizing any known provider prefix
    return any(model_id.startswith(provider + "/") for provider in _API_PROVIDERS)


def load_model(
    model_id: str,
    device: Optional[str] = None,
    api_key: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> BaseEmbeddingModel:
    """
    Load an embedding model by ID.

    Args:
        model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
        device: Device to use (for local models only)
        api_key: API key (for API-based models, or uses environment variable)
        hf_token: HuggingFace token for gated local models (or uses HF_TOKEN env var)

    Returns:
        Loaded embedding model instance
    """
    model_id = model_id.strip()

    if is_api_model(model_id):
        provider_classes = {
            "openai": OpenAIEmbeddingModel,
            "voyage": VoyageEmbeddingModel,
            "gemini": GeminiEmbeddingModel,
            "cohere": CohereEmbeddingModel,
        }
        provider = get_api_key_type(model_id)
        wrapper = provider_classes.get(provider)
        if wrapper is None:
            raise ValueError(f"Unknown API model type: {model_id}")
        return wrapper(model_id, api_key=api_key)

    # Otherwise, load as a local sentence-transformer model
    return EmbeddingModel(model_id, device=device, hf_token=hf_token)


def validate_model_id(model_id: str) -> tuple[bool, str]:
    """
    Check if a model ID is valid and loadable.

    Args:
        model_id: The model ID to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not model_id or not model_id.strip():
        return False, "Model ID cannot be empty"

    model_id = model_id.strip()

    # Known curated or API models are always valid
    if model_id in ALL_MODELS:
        return True, ""

    # Any recognized API-provider prefix is accepted
    if is_api_model(model_id):
        return True, ""

    # For custom models, check if it looks like a valid HF model ID
    if "/" not in model_id:
        return False, "Model ID should be in format 'organization/model-name'"

    # Could add an API check here, but that would slow down validation
    return True, ""


def requires_api_key(model_id: str) -> bool:
    """Check if a model requires an API key."""
    return is_api_model(model_id)


def api_key_optional(model_id: str) -> bool:
    """
    Check if an API key is optional for this model.

    Some providers (like Google Gemini) support Application Default
    Credentials as an alternative to explicit API keys.
    """
    # Gemini supports ADC (gcloud auth application-default login)
    return get_api_key_type(model_id) == "gemini"


def get_api_key_type(model_id: str) -> Optional[str]:
    """
    Get the type of API key required for a model.

    Args:
        model_id: The model ID

    Returns:
        'openai', 'voyage', 'gemini', 'cohere', or None if no API key needed
    """
    if not is_api_model(model_id):
        return None

    model_id = model_id.strip()
    model_type = API_MODELS.get(model_id, {}).get("type", "")
    if model_type in _API_PROVIDERS:
        return model_type
    for provider in _API_PROVIDERS:
        if model_id.startswith(provider + "/"):
            return provider
    return None


def get_api_key_env_var(model_id: str) -> Optional[str]:
    """
    Get the environment variable name for the API key required by a model.

    Args:
        model_id: The model ID

    Returns:
        Environment variable name or None
    """
    return _API_KEY_ENV_VARS.get(get_api_key_type(model_id))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Test embedding model loading and encoding"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Test only local sentence-transformer models",
    )
    parser.add_argument(
        "--remote",
        action="store_true",
        help="Test only remote/API models (requires API keys)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Test a specific model ID",
    )
    args = parser.parse_args()

    # If neither flag specified, test both
    test_local = args.local or (not args.local and not args.remote)
    test_remote = args.remote or (not args.local and not args.remote)

    print("Testing model loading...")
    print(f"\nLocal models available:")
    for model_id, display in get_curated_model_choices():
        print(f"  - {display}")
    print(f"\nAPI models available:")
    for model_id, display in get_api_model_choices():
        print(f"  - {display}")

    # Test texts
    test_texts = [
        "בראשית ברא אלהים את השמים ואת הארץ",
        "In the beginning God created the heaven and the earth",
    ]

    def run_model_test(model_id: str, model_type: str):
        """Run a test for a specific model."""
        print(f"\n{'='*60}")
        print(f"Testing {model_type}: {model_id}")
        print("="*60)
        try:
            model = load_model(model_id)
            embeddings = model.encode(test_texts, show_progress=False)
            print(f"\nEncoded {len(test_texts)} texts")
            print(f"Embedding shape: {embeddings.shape}")
            similarity = np.dot(embeddings[0], embeddings[1])
            print(f"Cosine similarity between Hebrew and English: {similarity:.4f}")
            return True
        except Exception as e:
            print(f"Test failed: {e}")
            return False

    # Test specific model if provided
    if args.model:
        run_model_test(args.model, "specified model")
    else:
        # Test local model
        if test_local:
            run_model_test(
                "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                "local sentence-transformer model"
            )

        # Test API models
        if test_remote:
            if os.environ.get("OPENAI_API_KEY"):
                run_model_test(
                    "openai/text-embedding-3-small",
                    "OpenAI API model"
                )
            else:
                print("\n(Skipping OpenAI test - OPENAI_API_KEY not set)")

            if os.environ.get("VOYAGE_API_KEY"):
                run_model_test(
                    "voyage/voyage-3.5",
                    "Voyage AI API model"
                )
            else:
                print("\n(Skipping Voyage AI test - VOYAGE_API_KEY not set)")

            if os.environ.get("GEMINI_API_KEY"):
                run_model_test(
                    "gemini/gemini-embedding-001",
                    "Gemini API model"
                )
            else:
                print("\n(Skipping Gemini test - GEMINI_API_KEY not set)")