Spaces:
Sleeping
Sleeping
| import os | |
| from collections.abc import Sequence | |
| from typing import Literal | |
| import logfire | |
| import ollama | |
| from openai import OpenAI | |
| class Embedder: | |
| def __init__( | |
| self, | |
| *, | |
| provider: Literal["ollama", "openai", "sentence-transformers"], | |
| model_name: str, | |
| vector_type: str, | |
| base_url: str | None = None, | |
| ): | |
| """ | |
| Initialize embedder with specified provider. | |
| Params: | |
| ====== | |
| - provider: "ollama", "openai", or "sentence-transformers" | |
| - model_name: model name (e.g., "nomic-embed-text" for Ollama, "text-embedding-3-small" for OpenAI) | |
| - vector_type: vector type for database (e.g., "F32", "I8") | |
| - base_url: optional base URL for OpenAI-compatible APIs | |
| """ | |
| self._provider: Literal["ollama", "openai", "sentence-transformers"] = ( | |
| provider | |
| ) | |
| self.model_name: str = model_name | |
| self.vector_type: str = vector_type | |
| self._st_model = None | |
| # Initialize clients if needed | |
| if provider == "openai": | |
| api_key = os.getenv("OPENAI_API_KEY") or os.getenv( | |
| "BLABLADOR_API_KEY" | |
| ) | |
| if not api_key: | |
| raise ValueError( | |
| "OPENAI_API_KEY or BLABLADOR_API_KEY environment variable not set" | |
| ) | |
| resolved_base_url = ( | |
| base_url | |
| or os.getenv("OPENAI_BASE_URL") | |
| or os.getenv("OPENAI_API_BASE") | |
| or os.getenv("BLABLADOR_BASE_URL") | |
| ) | |
| if resolved_base_url: | |
| resolved_base_url = resolved_base_url.rstrip("/") + "/" | |
| self._openai_client = OpenAI( | |
| api_key=api_key, base_url=resolved_base_url | |
| ) | |
| else: | |
| self._openai_client = OpenAI(api_key=api_key) | |
| elif provider == "sentence-transformers": | |
| from sentence_transformers import SentenceTransformer | |
| self._openai_client = None | |
| self._st_model = SentenceTransformer(self.model_name) | |
| else: | |
| self._openai_client = None | |
| # Detect dimension by embedding a test string | |
| if provider == "ollama": | |
| vec = ollama.embed(model=self.model_name, input="hi").embeddings[0] | |
| self.dimension: int = len(vec) | |
| elif provider == "sentence-transformers": | |
| if self._st_model is None: | |
| raise ValueError("SentenceTransformer model not initialized") | |
| self.dimension = int( | |
| self._st_model.get_sentence_embedding_dimension() | |
| ) | |
| else: | |
| _ = logfire.instrument_openai() | |
| if self._openai_client is None: | |
| raise ValueError("OpenAI client not initialized") | |
| response = self._openai_client.embeddings.create( | |
| model=self.model_name, input="hi" | |
| ) | |
| self.dimension = len(response.data[0].embedding) | |
| def max_length(self) -> int: | |
| """Maximum token length supported for embeddings for the current provider/model.""" | |
| if self._provider == "ollama": | |
| # Try to derive from local model metadata | |
| try: | |
| info = ollama.show(self.model_name) | |
| if isinstance(info, dict): | |
| model_info = info.get("model_info") or {} # pyright: ignore[reportUnknownVariableType] | |
| if isinstance(model_info, dict): | |
| for key in ( | |
| "num_ctx", | |
| "context_length", | |
| "max_context_length", | |
| "ctx", | |
| ): | |
| val = model_info.get(key) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] | |
| if val is not None: | |
| try: | |
| return int(val) # pyright: ignore[reportUnknownArgumentType] | |
| except Exception: | |
| try: | |
| return int(str(val)) # pyright: ignore[reportUnknownArgumentType] | |
| except Exception: | |
| pass | |
| params = info.get("parameters") # pyright: ignore[reportAny] | |
| if isinstance(params, str): | |
| import re | |
| m = re.search(r"\bnum_ctx\s+(\d+)\b", params) | |
| if m: | |
| return int(m.group(1)) | |
| except Exception: | |
| # If we cannot inspect the model, fall back to sensible defaults for common models. | |
| pass | |
| name = self.model_name.lower() | |
| if "nomic-embed" in name: | |
| return 8192 | |
| if "bge" in name: | |
| return 4096 | |
| # Generic fallback for most models in Ollama | |
| return 8192 | |
| elif self._provider == "openai": | |
| name = self.model_name.lower() | |
| # OpenAI embedding models support up to 8192 tokens | |
| if name in ( | |
| "text-embedding-3-small", | |
| "text-embedding-3-large", | |
| "text-embedding-ada-002", | |
| ): | |
| return 8192 | |
| if name.startswith("text-embedding-3-"): | |
| return 8192 | |
| # Default to 8192 if unknown | |
| return 8192 | |
| elif self._provider == "sentence-transformers": | |
| if self._st_model is None: | |
| raise ValueError("SentenceTransformer model not initialized") | |
| max_len = getattr(self._st_model, "max_seq_length", None) | |
| if max_len: | |
| return int(max_len) | |
| return 512 | |
| else: | |
| raise ValueError(f"Unknown provider: {self._provider}") | |
| def _embed_ollama(self, text: str) -> list[float]: | |
| """Generate embedding using Ollama.""" | |
| res = ollama.embed(model=self.model_name, input=text, truncate=True) | |
| return list(res.embeddings[0]) | |
| def _embed_openai(self, text: str) -> list[float]: | |
| """Generate embedding using OpenAI.""" | |
| if self._openai_client is None: | |
| raise ValueError("OpenAI client not initialized") | |
| response = self._openai_client.embeddings.create( | |
| model=self.model_name, input=text | |
| ) | |
| return response.data[0].embedding | |
| def _embed_sentence_transformers(self, text: str) -> list[float]: | |
| if self._st_model is None: | |
| raise ValueError("SentenceTransformer model not initialized") | |
| vec = self._st_model.encode(text, normalize_embeddings=False) | |
| return [float(x) for x in vec.tolist()] | |
| def _embed_batch_ollama( | |
| self, texts: list[str] | |
| ) -> Sequence[Sequence[float]]: | |
| """Generate batch embeddings using Ollama.""" | |
| res = ollama.embed(model=self.model_name, input=texts, truncate=True) | |
| return res.embeddings | |
| def _embed_batch_openai( | |
| self, texts: list[str] | |
| ) -> Sequence[Sequence[float]]: | |
| """Generate batch embeddings using OpenAI.""" | |
| if self._openai_client is None: | |
| raise ValueError("OpenAI client not initialized") | |
| response = self._openai_client.embeddings.create( | |
| model=self.model_name, input=texts | |
| ) | |
| return [data.embedding for data in response.data] | |
| def _embed_batch_sentence_transformers( | |
| self, texts: list[str] | |
| ) -> Sequence[Sequence[float]]: | |
| if self._st_model is None: | |
| raise ValueError("SentenceTransformer model not initialized") | |
| vecs = self._st_model.encode(texts, normalize_embeddings=False) | |
| return [list(map(float, vec)) for vec in vecs.tolist()] | |
| def embed(self, text: str) -> list[float]: | |
| if self._provider == "ollama": | |
| return self._embed_ollama(text) | |
| if self._provider == "sentence-transformers": | |
| return self._embed_sentence_transformers(text) | |
| else: | |
| return self._embed_openai(text) | |
| def embed_batch(self, texts: list[str]) -> Sequence[Sequence[float]]: | |
| if self._provider == "ollama": | |
| return self._embed_batch_ollama(texts) | |
| if self._provider == "sentence-transformers": | |
| return self._embed_batch_sentence_transformers(texts) | |
| else: | |
| return self._embed_batch_openai(texts) | |