import hashlib
import logging
from functools import cached_property
from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception
from chromadb.api.types import (
    Document,
    Documents,
    Embedding,
    Image,
    Images,
    EmbeddingFunction,
    Embeddings,
    is_image,
    is_document,
)
from pathlib import Path
import os
import tarfile
import requests
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast
import numpy as np
import numpy.typing as npt
import importlib
import inspect
import json
import sys

try:
    from chromadb.is_thin_client import is_thin_client
except ImportError:
    is_thin_client = False

if TYPE_CHECKING:
    from onnxruntime import InferenceSession
    from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


def _verify_sha256(fname: str, expected_sha256: str) -> bool:
    sha256_hash = hashlib.sha256()
    with open(fname, "rb") as f:
        # Read and update hash in chunks to avoid using too much memory
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest() == expected_sha256


class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    # Since we do dynamic imports we have to type this as Any
    models: Dict[str, Any] = {}

    # If you have a beefier machine, try "gtr-t5-large".
    # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: str = "cpu",
        normalize_embeddings: bool = False,
    ):
        if model_name not in self.models:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError:
                raise ValueError(
                    "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
                )
            self.models[model_name] = SentenceTransformer(model_name, device=device)
        self._model = self.models[model_name]
        self._normalize_embeddings = normalize_embeddings

    def __call__(self, input: Documents) -> Embeddings:
        return cast(
            Embeddings,
            self._model.encode(
                list(input),
                convert_to_numpy=True,
                normalize_embeddings=self._normalize_embeddings,
            ).tolist(),
        )
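
# Example usage (illustrative sketch; assumes `sentence_transformers` is installed
# and the default "all-MiniLM-L6-v2" model can be downloaded on first use):
#   >>> sentence_transformer_ef = SentenceTransformerEmbeddingFunction()
#   >>> embeddings = sentence_transformer_ef(["Hello, world!", "How are you?"])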


class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
        try:
            from text2vec import SentenceModel
        except ImportError:
            raise ValueError(
                "The text2vec python package is not installed. Please install it with `pip install text2vec`"
            )
        self._model = SentenceModel(model_name_or_path=model_name)

    def __call__(self, input: Documents) -> Embeddings:
        return cast(
            Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()
        )  # noqa E501
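
# Example usage (illustrative sketch; assumes `text2vec` is installed and the
# default "shibing624/text2vec-base-chinese" model is available):
#   >>> text2vec_ef = Text2VecEmbeddingFunction()
#   >>> embeddings = text2vec_ef(["你好，世界！", "How are you?"])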


class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-ada-002",
        organization_id: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        deployment_id: Optional[str] = None,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        """
        Initialize the OpenAIEmbeddingFunction.

        Args:
            api_key (str, optional): Your API key for the OpenAI API. If it is not
                provided here and not already set on the openai module, an error is raised.
            organization_id (str, optional): The OpenAI organization ID, if applicable.
            model_name (str, optional): The name of the model to use for text
                embeddings. Defaults to "text-embedding-ada-002".
            api_base (str, optional): The base path for the API. If not provided,
                it will use the base path for the OpenAI API. This can be used to
                point to a different deployment, such as an Azure deployment.
            api_type (str, optional): The type of the API deployment. This can be
                used to specify a different deployment, such as 'azure'. If not
                provided, it will use the default OpenAI deployment.
            api_version (str, optional): The API version to use. If not provided,
                it will use the default version for the OpenAI API. This can be used
                to point to a different deployment, such as an Azure deployment.
            deployment_id (str, optional): Deployment ID for Azure OpenAI.
            default_headers (Mapping, optional): A mapping of default headers to be sent with each API request.
        """
        try:
            import openai
        except ImportError:
            raise ValueError(
                "The openai python package is not installed. Please install it with `pip install openai`"
            )

        if api_key is not None:
            openai.api_key = api_key
        # If the api key is still not set, raise an error
        elif openai.api_key is None:
            raise ValueError(
                "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
            )

        if api_base is not None:
            openai.api_base = api_base

        if api_version is not None:
            openai.api_version = api_version

        self._api_type = api_type
        if api_type is not None:
            openai.api_type = api_type

        if organization_id is not None:
            openai.organization = organization_id

        self._v1 = openai.__version__.startswith("1.")
        if self._v1:
            if api_type == "azure":
                self._client = openai.AzureOpenAI(
                    api_key=api_key,
                    api_version=api_version,
                    azure_endpoint=api_base,
                    default_headers=default_headers,
                ).embeddings
            else:
                self._client = openai.OpenAI(
                    api_key=api_key, base_url=api_base, default_headers=default_headers
                ).embeddings
        else:
            self._client = openai.Embedding
        self._model_name = model_name
        self._deployment_id = deployment_id

    def __call__(self, input: Documents) -> Embeddings:
        # Replace newlines, which can negatively affect performance.
        input = [t.replace("\n", " ") for t in input]

        # Call the OpenAI Embedding API
        if self._v1:
            embeddings = self._client.create(
                input=input, model=self._deployment_id or self._model_name
            ).data

            # Sort resulting embeddings by index
            sorted_embeddings = sorted(embeddings, key=lambda e: e.index)

            # Return just the embeddings
            return cast(Embeddings, [result.embedding for result in sorted_embeddings])
        else:
            if self._api_type == "azure":
                embeddings = self._client.create(
                    input=input, engine=self._deployment_id or self._model_name
                )["data"]
            else:
                embeddings = self._client.create(input=input, model=self._model_name)[
                    "data"
                ]

            # Sort resulting embeddings by index
            sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

            # Return just the embeddings
            return cast(
                Embeddings, [result["embedding"] for result in sorted_embeddings]
            )
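
# Example usage (illustrative sketch; assumes `openai` is installed and
# "your_api_key" is replaced with a real key):
#   >>> openai_ef = OpenAIEmbeddingFunction(api_key="your_api_key", model_name="text-embedding-ada-002")
#   >>> embeddings = openai_ef(["Hello, world!", "How are you?"])
# For an Azure deployment, pass api_type="azure" together with api_base, api_version
# and deployment_id instead of relying on the OpenAI defaults.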


class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(self, api_key: str, model_name: str = "large"):
        try:
            import cohere
        except ImportError:
            raise ValueError(
                "The cohere python package is not installed. Please install it with `pip install cohere`"
            )
        self._client = cohere.Client(api_key)
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        # Call Cohere Embedding API for each document.
        return [
            embeddings
            for embeddings in self._client.embed(
                texts=input, model=self._model_name, input_type="search_document"
            )
        ]
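
# Example usage (illustrative sketch; assumes `cohere` is installed and a valid API key):
#   >>> cohere_ef = CohereEmbeddingFunction(api_key="your_api_key", model_name="large")
#   >>> embeddings = cohere_ef(["Hello, world!", "How are you?"])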


class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace API.
    It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2".
    """

    def __init__(
        self, api_key: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    ):
        """
        Initialize the HuggingFaceEmbeddingFunction.

        Args:
            api_key (str): Your API key for the HuggingFace API.
            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
        """
        self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> hugging_face = HuggingFaceEmbeddingFunction(api_key="your_api_key")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = hugging_face(texts)
        """
        # Call HuggingFace Embedding API for each document
        return cast(
            Embeddings,
            self._session.post(
                self._api_url,
                json={"inputs": input, "options": {"wait_for_model": True}},
            ).json(),
        )


class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the Jina AI API.
    It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
    """

    def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"):
        """
        Initialize the JinaEmbeddingFunction.

        Args:
            api_key (str): Your API key for the Jina AI API.
            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".
        """
        self._model_name = model_name
        self._api_url = "https://api.jina.ai/v1/embeddings"
        self._session = requests.Session()
        self._session.headers.update(
            {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key")
            >>> input = ["Hello, world!", "How are you?"]
            >>> embeddings = jina_ai_fn(input)
        """
        # Call Jina AI Embedding API
        resp = self._session.post(
            self._api_url, json={"input": input, "model": self._model_name}
        ).json()
        if "data" not in resp:
            raise RuntimeError(resp["detail"])

        embeddings = resp["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

        # Return just the embeddings
        return cast(Embeddings, [result["embedding"] for result in sorted_embeddings])


class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
    # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
    # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
    def __init__(
        self,
        model_name: str = "hkunlp/instructor-base",
        device: str = "cpu",
        instruction: Optional[str] = None,
    ):
        try:
            from InstructorEmbedding import INSTRUCTOR
        except ImportError:
            raise ValueError(
                "The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`"
            )
        self._model = INSTRUCTOR(model_name, device=device)
        self._instruction = instruction

    def __call__(self, input: Documents) -> Embeddings:
        if self._instruction is None:
            return cast(Embeddings, self._model.encode(input).tolist())

        texts_with_instructions = [[self._instruction, text] for text in input]

        return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())
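
# Example usage (illustrative sketch; assumes `InstructorEmbedding` is installed;
# the instruction text below is only an example prompt, not a required value):
#   >>> instructor_ef = InstructorEmbeddingFunction(
#   ...     model_name="hkunlp/instructor-base",
#   ...     instruction="Represent the document for retrieval:",
#   ... )
#   >>> embeddings = instructor_ef(["Hello, world!", "How are you?"])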


# In order to remove dependencies on sentence-transformers, which in turn depends on
# pytorch and sentence-piece we have created a default ONNX embedding function that
# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers.
# visit https://github.com/chroma-core/onnx-embedding for the source code to generate
# and verify the ONNX model.
class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
    MODEL_NAME = "all-MiniLM-L6-v2"
    DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME
    EXTRACTED_FOLDER_NAME = "onnx"
    ARCHIVE_FILENAME = "onnx.tar.gz"
    MODEL_DOWNLOAD_URL = (
        "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
    )
    _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"

    # https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if
    # no args
    def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
        # Import dependencies on demand to mirror other embedding functions. This
        # breaks typechecking, thus the ignores.
        if preferred_providers and not all(
            [isinstance(i, str) for i in preferred_providers]
        ):
            raise ValueError("Preferred providers must be a list of strings")
        # check for duplicate providers (convert the list to a set for unique values)
        if preferred_providers and len(preferred_providers) != len(
            set(preferred_providers)
        ):
            raise ValueError("Preferred providers must be unique")
        self._preferred_providers = preferred_providers

        try:
            # Equivalent to import onnxruntime
            self.ort = importlib.import_module("onnxruntime")
        except ImportError:
            raise ValueError(
                "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`"
            )
        try:
            # Equivalent to from tokenizers import Tokenizer
            self.Tokenizer = importlib.import_module("tokenizers").Tokenizer
        except ImportError:
            raise ValueError(
                "The tokenizers python package is not installed. Please install it with `pip install tokenizers`"
            )
        try:
            # Equivalent to from tqdm import tqdm
            self.tqdm = importlib.import_module("tqdm").tqdm
        except ImportError:
            raise ValueError(
                "The tqdm python package is not installed. Please install it with `pip install tqdm`"
            )

    # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
    # Download with tqdm to preserve the sentence-transformers experience
    # Retry the download a few times (using the tenacity helpers imported above) if
    # the archive fails the SHA256 integrity check.
    @retry(  # type: ignore
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
    )
    def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
        resp = requests.get(url, stream=True)
        total = int(resp.headers.get("content-length", 0))
        with open(fname, "wb") as file, self.tqdm(
            desc=str(fname),
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
        if not _verify_sha256(fname, self._MODEL_SHA256):
            # if the integrity of the file is not verified, remove it
            os.remove(fname)
            raise ValueError(
                f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file."
            )

    # Use pytorch's default epsilon for division by zero
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
        norm = np.linalg.norm(v, axis=1)
        norm[norm == 0] = 1e-12
        return cast(npt.NDArray, v / norm[:, np.newaxis])

    def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
        # We need to cast to the correct type because the type checker doesn't know
        # that the cached tokenizer/model properties set the values
        self.tokenizer = cast(self.Tokenizer, self.tokenizer)
        self.model = cast(self.ort.InferenceSession, self.model)
        all_embeddings = []
        for i in range(0, len(documents), batch_size):
            batch = documents[i : i + batch_size]
            encoded = [self.tokenizer.encode(d) for d in batch]
            input_ids = np.array([e.ids for e in encoded])
            attention_mask = np.array([e.attention_mask for e in encoded])
            onnx_input = {
                "input_ids": np.array(input_ids, dtype=np.int64),
                "attention_mask": np.array(attention_mask, dtype=np.int64),
                "token_type_ids": np.array(
                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
                    dtype=np.int64,
                ),
            }
            model_output = self.model.run(None, onnx_input)
            last_hidden_state = model_output[0]
            # Perform mean pooling with attention weighting
            input_mask_expanded = np.broadcast_to(
                np.expand_dims(attention_mask, -1), last_hidden_state.shape
            )
            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
            )
            embeddings = self._normalize(embeddings).astype(np.float32)
            all_embeddings.append(embeddings)
        return np.concatenate(all_embeddings)

    @cached_property
    def tokenizer(self) -> "Tokenizer":
        tokenizer = self.Tokenizer.from_file(
            os.path.join(
                self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
            )
        )
        # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
        # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
        tokenizer.enable_truncation(max_length=256)
        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
        return tokenizer

    @cached_property
    def model(self) -> "InferenceSession":
        if self._preferred_providers is None or len(self._preferred_providers) == 0:
            if len(self.ort.get_available_providers()) > 0:
                logger.debug(
                    f"WARNING: No ONNX providers provided, defaulting to available providers: "
                    f"{self.ort.get_available_providers()}"
                )
            self._preferred_providers = self.ort.get_available_providers()
        elif not set(self._preferred_providers).issubset(
            set(self.ort.get_available_providers())
        ):
            raise ValueError(
                f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
            )
        return self.ort.InferenceSession(
            os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
            # Since 1.9 onnxruntime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html
            # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs
            providers=self._preferred_providers,
        )

    def __call__(self, input: Documents) -> Embeddings:
        # Only download the model when it is actually used
        self._download_model_if_not_exists()
        return cast(Embeddings, self._forward(input).tolist())

    def _download_model_if_not_exists(self) -> None:
        onnx_files = [
            "config.json",
            "model.onnx",
            "special_tokens_map.json",
            "tokenizer_config.json",
            "tokenizer.json",
            "vocab.txt",
        ]
        extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
        onnx_files_exist = True
        for f in onnx_files:
            if not os.path.exists(os.path.join(extracted_folder, f)):
                onnx_files_exist = False
                break
        # Model is not downloaded yet
        if not onnx_files_exist:
            os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
            if not os.path.exists(
                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
            ) or not _verify_sha256(
                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                self._MODEL_SHA256,
            ):
                self._download(
                    url=self.MODEL_DOWNLOAD_URL,
                    fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                )
            with tarfile.open(
                name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                mode="r:gz",
            ) as tar:
                tar.extractall(path=self.DOWNLOAD_PATH)


def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]:
    if is_thin_client:
        return None
    else:
        return ONNXMiniLM_L6_V2()
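
# Example usage (illustrative sketch; the ONNX model is downloaded to
# ~/.cache/chroma/onnx_models on the first call):
#   >>> default_ef = DefaultEmbeddingFunction()
#   >>> if default_ef is not None:  # None only in thin-client installs
#   ...     embeddings = default_ef(["Hello, world!", "How are you?"])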


class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""

    def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"):
        if not api_key:
            raise ValueError("Please provide a PaLM API key.")

        if not model_name:
            raise ValueError("Please provide the model name.")

        try:
            import google.generativeai as palm
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        palm.configure(api_key=api_key)
        self._palm = palm
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        return [
            self._palm.generate_embeddings(model=self._model_name, text=text)[
                "embedding"
            ]
            for text in input
        ]


class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key.

    Use the RETRIEVAL_DOCUMENT task_type when embedding documents and RETRIEVAL_QUERY when embedding queries for retrieval.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "models/embedding-001",
        task_type: str = "RETRIEVAL_DOCUMENT",
    ):
        if not api_key:
            raise ValueError("Please provide a Google API key.")

        if not model_name:
            raise ValueError("Please provide the model name.")

        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        genai.configure(api_key=api_key)
        self._genai = genai
        self._model_name = model_name
        self._task_type = task_type
        self._task_title = None
        if self._task_type == "RETRIEVAL_DOCUMENT":
            self._task_title = "Embedding of single string"

    def __call__(self, input: Documents) -> Embeddings:
        return [
            self._genai.embed_content(
                model=self._model_name,
                content=text,
                task_type=self._task_type,
                title=self._task_title,
            )["embedding"]
            for text in input
        ]
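
# Example usage (illustrative sketch; assumes `google-generativeai` is installed):
#   >>> google_ef = GoogleGenerativeAiEmbeddingFunction(api_key="your_api_key")
#   >>> embeddings = google_ef(["Hello, world!", "How are you?"])
# For query-time embeddings, construct a second instance with task_type="RETRIEVAL_QUERY".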


class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
    # Follow API Quickstart for Google Vertex AI
    # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
    # Information about the text embedding modules in Google Vertex AI
    # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
    def __init__(
        self,
        api_key: str,
        model_name: str = "textembedding-gecko",
        project_id: str = "cloud-large-language-models",
        region: str = "us-central1",
    ):
        self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model_name}:predict"
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = []
        for text in input:
            response = self._session.post(
                self._api_url, json={"instances": [{"content": text}]}
            ).json()
            if "predictions" in response:
                # "predictions" is a list with one entry per instance sent in the request
                embeddings.append(response["predictions"][0]["embeddings"]["values"])
        return embeddings
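
# Example usage (illustrative sketch; "your_access_token" stands in for a real
# OAuth access token, e.g. obtained via `gcloud auth print-access-token`):
#   >>> vertex_ef = GoogleVertexEmbeddingFunction(api_key="your_access_token", project_id="your-project")
#   >>> embeddings = vertex_ef(["Hello, world!", "How are you?"])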


class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
    def __init__(
        self, model_name: str = "ViT-B-32", checkpoint: str = "laion2b_s34b_b79k"
    ) -> None:
        try:
            import open_clip
        except ImportError:
            raise ValueError(
                "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip"
            )
        try:
            self._torch = importlib.import_module("torch")
        except ImportError:
            raise ValueError(
                "The torch python package is not installed. Please install it with `pip install torch`"
            )
        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name=model_name, pretrained=checkpoint
        )
        self._model = model
        self._preprocess = preprocess
        self._tokenizer = open_clip.get_tokenizer(model_name=model_name)

    def _encode_image(self, image: Image) -> Embedding:
        pil_image = self._PILImage.fromarray(image)
        with self._torch.no_grad():
            image_features = self._model.encode_image(
                self._preprocess(pil_image).unsqueeze(0)
            )
            image_features /= image_features.norm(dim=-1, keepdim=True)
            return cast(Embedding, image_features.squeeze().tolist())

    def _encode_text(self, text: Document) -> Embedding:
        with self._torch.no_grad():
            text_features = self._model.encode_text(self._tokenizer(text))
            text_features /= text_features.norm(dim=-1, keepdim=True)
            return cast(Embedding, text_features.squeeze().tolist())

    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        embeddings: Embeddings = []
        for item in input:
            if is_image(item):
                embeddings.append(self._encode_image(cast(Image, item)))
            elif is_document(item):
                embeddings.append(self._encode_text(cast(Document, item)))
        return embeddings
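
# Example usage (illustrative sketch; assumes `open-clip-torch`, `torch` and `pillow`
# are installed; images are passed as numpy uint8 arrays):
#   >>> clip_ef = OpenCLIPEmbeddingFunction()
#   >>> embeddings = clip_ef(["A photo of a cat", np.zeros((224, 224, 3), dtype=np.uint8)])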


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",  # noqa: F821 # Quote for forward reference
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """
        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        accept = "application/json"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings


class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference).
    The embedding model is configured in the server.
    """

    def __init__(self, url: str):
        """
        Initialize the HuggingFaceEmbeddingServer.

        Args:
            url (str): The URL of the HuggingFace Embedding Server.
        """
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. Please install it with `pip install requests`"
            )
        self._api_url = f"{url}"
        self._session = requests.Session()

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = hugging_face(texts)
        """
        # Call HuggingFace Embedding Server API for each document
        return cast(
            Embeddings, self._session.post(self._api_url, json={"inputs": input}).json()
        )


# List of all classes in this module
_classes = [
    name
    for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
    if obj.__module__ == __name__
]


def get_builtins() -> List[str]:
    return _classes
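
# Example usage (illustrative sketch; assumes this module is importable as
# chromadb.utils.embedding_functions): list the built-in embedding function classes.
#   >>> from chromadb.utils import embedding_functions
#   >>> embedding_functions.get_builtins()
#   ['AmazonBedrockEmbeddingFunction', 'CohereEmbeddingFunction', ...]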