Spaces:
Sleeping
Sleeping
"""
Embedding Generation Module.

Supports:
- Sentence Transformers (local, free)
- OpenAI Embeddings (cloud, paid)

Configurable via environment variables.
"""
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import List, Optional | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
class EmbeddingProvider(ABC):
    """Abstract base class for embedding providers.

    Concrete providers must implement single-text and batch embedding,
    and report the dimensionality of the vectors they produce.
    """

    @abstractmethod
    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text.

        Args:
            text: The input string to embed.

        Returns:
            A 1-D numpy array of length ``dimension()``.
        """
        raise NotImplementedError

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts.

        Args:
            texts: The input strings to embed.

        Returns:
            A 2-D numpy array of shape (len(texts), dimension()).
        """
        raise NotImplementedError

    @abstractmethod
    def dimension(self) -> int:
        """Return the embedding dimension."""
        raise NotImplementedError
class SentenceTransformerEmbedding(EmbeddingProvider):
    """
    Sentence Transformers embedding provider.

    Uses local models, no API key required.
    Default: all-MiniLM-L6-v2 (384 dimensions)
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the Sentence Transformer model.

        Args:
            model_name: HuggingFace model name
        """
        self.model_name = model_name
        # Model and dimension are resolved lazily on first access, so
        # constructing the provider never triggers a slow model load.
        self._model = None
        self._dimension = None

    @property
    def model(self):
        """Lazy-load and cache the SentenceTransformer model.

        Declared as a property: callers access it as ``self.model.encode``
        and ``dimension()`` relies on attribute access to force the load.

        Raises:
            ImportError: if sentence-transformers is not installed.
        """
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as exc:
                raise ImportError(
                    "sentence-transformers is required. Install with: pip install sentence-transformers"
                ) from exc
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(self.model_name)
            self._dimension = self._model.get_sentence_embedding_dimension()
            logger.info("Model loaded. Embedding dimension: %s", self._dimension)
        return self._model

    def dimension(self) -> int:
        """Get embedding dimension (loads the model on first call)."""
        if self._dimension is None:
            _ = self.model  # Force model load; loading also sets _dimension
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text."""
        return self.model.encode(text, convert_to_numpy=True)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts."""
        # Only show a progress bar for large batches.
        return self.model.encode(texts, convert_to_numpy=True, show_progress_bar=len(texts) > 100)
class OpenAIEmbedding(EmbeddingProvider):
    """
    OpenAI embedding provider.

    Uses OpenAI API, requires API key.
    Default: text-embedding-3-small (1536 dimensions)
    """

    # Known model -> output dimension; unknown models fall back to 1536.
    DIMENSION_MAP = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536
    }

    def __init__(self, api_key: str, model_name: str = "text-embedding-3-small"):
        """
        Initialize OpenAI embedding client.

        Args:
            api_key: OpenAI API key
            model_name: OpenAI embedding model name
        """
        self.api_key = api_key
        self.model_name = model_name
        # Client is created lazily so construction succeeds even when the
        # openai package is absent.
        self._client = None
        self._dimension = self.DIMENSION_MAP.get(model_name, 1536)

    @property
    def client(self):
        """Lazy-load and cache the OpenAI client.

        Declared as a property: callers access it as
        ``self.client.embeddings.create(...)``.

        Raises:
            ImportError: if the openai package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise ImportError(
                    "openai is required. Install with: pip install openai"
                ) from exc
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    def dimension(self) -> int:
        """Get embedding dimension."""
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name
        )
        return np.array(response.data[0].embedding, dtype=np.float32)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts (batch).

        Returns:
            A 2-D float32 array of shape (len(texts), dimension()).
        """
        if not texts:
            # Avoid issuing an invalid empty-input API request.
            return np.empty((0, self._dimension), dtype=np.float32)
        # OpenAI API supports batching up to 2048 inputs; stay well under.
        batch_size = 100
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(
                input=batch,
                model=self.model_name
            )
            all_embeddings.extend(
                np.array(d.embedding, dtype=np.float32) for d in response.data
            )
        return np.array(all_embeddings)
| def create_embedding_provider( | |
| provider_type: str = "sentence_transformers", | |
| model_name: Optional[str] = None, | |
| api_key: Optional[str] = None | |
| ) -> EmbeddingProvider: | |
| """ | |
| Factory function to create the appropriate embedding provider. | |
| Args: | |
| provider_type: "sentence_transformers" or "openai" | |
| model_name: Model name (optional, uses defaults) | |
| api_key: API key for OpenAI (required if using OpenAI) | |
| Returns: | |
| Configured EmbeddingProvider instance | |
| """ | |
| if provider_type == "openai": | |
| if not api_key: | |
| raise ValueError("OpenAI API key is required for OpenAI embeddings") | |
| return OpenAIEmbedding( | |
| api_key=api_key, | |
| model_name=model_name or "text-embedding-3-small" | |
| ) | |
| else: | |
| return SentenceTransformerEmbedding( | |
| model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2" | |
| ) | |
# Global embedding provider instance
_embedding_provider: Optional[EmbeddingProvider] = None

def get_embedding_provider() -> EmbeddingProvider:
    """Get or create the global embedding provider."""
    global _embedding_provider
    if _embedding_provider is not None:
        return _embedding_provider
    # First use: default to sentence transformers (free, local).
    _embedding_provider = SentenceTransformerEmbedding()
    return _embedding_provider
| def set_embedding_provider(provider: EmbeddingProvider): | |
| """Set the global embedding provider.""" | |
| global _embedding_provider | |
| _embedding_provider = provider | |