# DB_Chatbot / rag / embeddings.py
# Uploaded by Vanshcc ("Upload 34 files"), commit f9ad313 (verified)
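"""
Embedding Generation Module.

Supports:
- Sentence Transformers (local, free)
- OpenAI Embeddings (cloud, paid)

The caller chooses the provider (for example from environment-based
settings) via ``create_embedding_provider`` and can install it process-wide
with ``set_embedding_provider``; this module does not read environment
variables itself.
"""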
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
import numpy as np
# Module-level logger named after this module; used to report embedding-model loading.
logger = logging.getLogger(__name__)
class EmbeddingProvider(ABC):
    """
    Common interface implemented by every embedding backend.

    A provider converts text into NumPy vectors and reports the length of
    those vectors through the ``dimension`` property. Subclasses must
    implement all three abstract members before they can be instantiated.
    """

    @abstractmethod
    def embed_text(self, text: str) -> np.ndarray:
        """Return the embedding vector for one piece of text."""

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """
        Return embeddings for a batch of texts.

        The implementations in this module return one row per input text.
        """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors this provider produces."""
class SentenceTransformerEmbedding(EmbeddingProvider):
    """
    Sentence Transformers embedding provider.

    Uses local models, no API key required.
    Default: all-MiniLM-L6-v2 (384 dimensions)

    The model is loaded lazily: nothing is imported or constructed until
    ``model``, ``dimension``, ``embed_text`` or ``embed_texts`` is first used.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the provider without loading the model.

        Args:
            model_name: HuggingFace model name passed to ``SentenceTransformer``.
        """
        self.model_name = model_name
        self._model = None  # set on first access of ``model``
        self._dimension = None  # set together with ``_model``

    @property
    def model(self):
        """
        Return the loaded ``SentenceTransformer``, creating it on first access.

        Raises:
            ImportError: If the ``sentence-transformers`` package cannot be
                imported (the original error is chained as ``__cause__``).
                Import errors raised while constructing the model itself are
                not rewritten and propagate with their own message.
        """
        if self._model is None:
            # Guard only the import: an ImportError raised while *building* the
            # model (e.g. a broken transitive dependency) must not be reported
            # as "package not installed".
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as err:
                raise ImportError(
                    "sentence-transformers is required. Install with: pip install sentence-transformers"
                ) from err
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(self.model_name)
            self._dimension = self._model.get_sentence_embedding_dimension()
            logger.info("Model loaded. Embedding dimension: %s", self._dimension)
        return self._model

    @property
    def dimension(self) -> int:
        """
        Return the embedding dimension reported by the model.

        Accessing this loads the model if it has not been loaded yet.
        """
        if self._dimension is None:
            _ = self.model  # force the lazy load, which fills in _dimension
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Encode a single text and return its embedding as a NumPy array."""
        return self.model.encode(text, convert_to_numpy=True)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """
        Encode several texts and return their embeddings as a NumPy array.

        A progress bar is shown only when more than 100 texts are encoded.
        """
        return self.model.encode(texts, convert_to_numpy=True, show_progress_bar=len(texts) > 100)
class OpenAIEmbedding(EmbeddingProvider):
    """
    OpenAI embedding provider.

    Uses the OpenAI API and requires an API key.
    Default: text-embedding-3-small (1536 dimensions)
    """

    # Known output sizes of OpenAI embedding models. Unknown model names fall
    # back to 1536 in ``__init__``.
    # NOTE(review): that fallback is a guess; confirm the size for any model
    # not listed here.
    DIMENSION_MAP = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Number of texts sent per embeddings request in ``embed_texts``.
    # The API accepts up to 2048 inputs per request; 100 keeps requests small.
    BATCH_SIZE = 100

    def __init__(self, api_key: str, model_name: str = "text-embedding-3-small"):
        """
        Initialize the provider; the OpenAI client is created lazily.

        Args:
            api_key: OpenAI API key
            model_name: OpenAI embedding model name
        """
        self.api_key = api_key
        self.model_name = model_name
        self._client = None  # created on first access of ``client``
        self._dimension = self.DIMENSION_MAP.get(model_name, 1536)

    @property
    def client(self):
        """
        Return the OpenAI client, creating it on first access.

        Raises:
            ImportError: If the ``openai`` package cannot be imported (the
                original error is chained as ``__cause__``).
        """
        if self._client is None:
            # Guard only the import so errors from the client constructor
            # are not misreported as a missing package.
            try:
                from openai import OpenAI
            except ImportError as err:
                raise ImportError(
                    "openai is required. Install with: pip install openai"
                ) from err
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    @property
    def dimension(self) -> int:
        """Return the embedding dimension looked up from ``DIMENSION_MAP``."""
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Embed a single text with one API call; returns a float32 vector."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name
        )
        return np.array(response.data[0].embedding, dtype=np.float32)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """
        Embed several texts, sending ``BATCH_SIZE`` texts per API request.

        Returns:
            A float32 array with one row per input text. An empty input yields
            an array of shape ``(0, dimension)`` without calling the API.
        """
        if not texts:
            # np.array([]) would be 1-D float64; keep the (n, dimension)
            # float32 shape contract for n == 0 as well.
            return np.empty((0, self.dimension), dtype=np.float32)

        rows = []
        for start in range(0, len(texts), self.BATCH_SIZE):
            batch = texts[start:start + self.BATCH_SIZE]
            response = self.client.embeddings.create(
                input=batch,
                model=self.model_name
            )
            # Each returned item carries the position of its input; sort on it
            # so rows line up with ``batch``. sorted() is stable, so items
            # without an ``index`` keep the order they arrived in.
            ordered = sorted(response.data, key=lambda item: getattr(item, "index", 0))
            rows.extend(item.embedding for item in ordered)
        return np.asarray(rows, dtype=np.float32)
def create_embedding_provider(
    provider_type: str = "sentence_transformers",
    model_name: Optional[str] = None,
    api_key: Optional[str] = None
) -> EmbeddingProvider:
    """
    Factory function to create the appropriate embedding provider.

    Args:
        provider_type: "sentence_transformers" or "openai". Matching ignores
            case and surrounding whitespace, and "-" is treated as "_".
            Any other value logs a warning and falls back to
            sentence_transformers.
        model_name: Model name (optional; each provider has its own default).
        api_key: API key for OpenAI (required if using OpenAI).

    Returns:
        Configured EmbeddingProvider instance.

    Raises:
        ValueError: If the OpenAI provider is requested without an API key.
    """
    # Values often come from configuration, so tolerate "OpenAI", " openai "
    # or "sentence-transformers" instead of silently picking the wrong backend.
    normalized = str(provider_type or "").strip().lower().replace("-", "_")

    if normalized == "openai":
        if not api_key:
            raise ValueError("OpenAI API key is required for OpenAI embeddings")
        return OpenAIEmbedding(
            api_key=api_key,
            model_name=model_name or "text-embedding-3-small"
        )

    if normalized != "sentence_transformers":
        # Keep the historical fallback, but make it visible: switching backend
        # may change the embedding dimension of everything indexed afterwards.
        logger.warning(
            "Unknown embedding provider %r; falling back to sentence_transformers",
            provider_type,
        )
    return SentenceTransformerEmbedding(
        model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2"
    )
# Process-wide provider handed out by get_embedding_provider(); created lazily
# on first request unless set_embedding_provider() installed one beforehand.
_embedding_provider: Optional[EmbeddingProvider] = None


def get_embedding_provider() -> EmbeddingProvider:
    """
    Return the shared embedding provider, creating it on first use.

    If nothing has been installed with ``set_embedding_provider``, a
    ``SentenceTransformerEmbedding`` with its default model is created and
    cached for all later calls.
    """
    global _embedding_provider
    if _embedding_provider is not None:
        return _embedding_provider
    # Default backend: local sentence-transformers model (no API key needed).
    _embedding_provider = SentenceTransformerEmbedding()
    return _embedding_provider


def set_embedding_provider(provider: EmbeddingProvider):
    """
    Replace the shared provider returned by ``get_embedding_provider``.

    Args:
        provider: Instance to hand out from now on.
    """
    global _embedding_provider
    _embedding_provider = provider