Spaces:
Runtime error
Runtime error
| """Wrapper around OpenAI embedding models.""" | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Extra, root_validator | |
| from langchain.embeddings.base import Embeddings | |
| from langchain.utils import get_from_dict_or_env | |
| from tenacity import ( | |
| retry, | |
| retry_if_exception_type, | |
| stop_after_attempt, | |
| wait_exponential, | |
| ) | |
| from openai.error import Timeout, APIError, APIConnectionError, RateLimitError | |
class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key or pass it
    as a named parameter to the constructor.

    Documents and queries may be embedded with different engines; both
    default to ``text-embedding-ada-002``.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(openai_api_key="my-api-key")
    """

    # Set by ``validate_environment`` to ``openai.Embedding``.
    client: Any  #: :meta private:
    # Engine used by ``embed_documents``.
    document_model_name: str = "text-embedding-ada-002"
    # Engine used by ``embed_query``.
    query_model_name: str = "text-embedding-ada-002"
    # Falls back to the ``OPENAI_API_KEY`` environment variable when unset.
    openai_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # TODO: deprecate this
    # Must be a *pre* validator: ``model_name`` is not a declared field, so
    # with ``extra = Extra.forbid`` it has to be popped before field
    # validation, and the "was it provided?" checks below are only meaningful
    # on the raw input (after validation the defaults are always present).
    @root_validator(pre=True)
    def get_model_names(cls, values: Dict) -> Dict:
        """Translate the legacy ``model_name`` argument into the two engine names.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            ``values`` with ``model_name`` removed and ``document_model_name`` /
            ``query_model_name`` derived from it, or unchanged when
            ``model_name`` was not supplied.

        Raises:
            ValueError: If ``model_name`` is combined with either explicit
                engine name.
        """
        if "model_name" in values:
            if "document_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `document_model_name` were provided, "
                    "but only one should be."
                )
            if "query_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `query_model_name` were provided, "
                    "but only one should be."
                )
            model_name = values.pop("model_name")
            values["document_model_name"] = f"text-search-{model_name}-doc-001"
            values["query_model_name"] = f"text-search-{model_name}-query-001"
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key from ``values`` or ``OPENAI_API_KEY``, assigns it
        to the module-level ``openai.api_key`` and stores ``openai.Embedding``
        as ``client``.

        Raises:
            ValueError: If the ``openai`` package cannot be imported.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        # Keep the try body to the single statement that can raise ImportError.
        try:
            import openai
        except ImportError as err:
            raise ValueError(
                "Could not import openai python package. "
                "Please it install it with `pip install openai`."
            ) from err
        openai.api_key = openai_api_key
        values["client"] = openai.Embedding
        return values

    # NOTE(review): the attempt count and wait bounds below are
    # reviewer-chosen defaults for transient API failures; tune as needed.
    @retry(
        reraise=True,
        stop=stop_after_attempt(6),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (Timeout, APIError, APIConnectionError, RateLimitError)
        ),
    )
    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint with exponential backoff.

        Args:
            text: The text to embed.
            engine: Name of the embedding engine to request.

        Returns:
            The embedding of ``text``, taken from the first entry of the
            response's ``data`` list.
        """
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        response = self.client.create(input=[text], engine=engine)
        return response["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Each text is embedded with a separate request using
        ``document_model_name``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return [
            self._embedding_func(text, engine=self.document_model_name)
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text, produced with ``query_model_name``.
        """
        return self._embedding_func(text, engine=self.query_model_name)