Spaces:
Runtime error
Runtime error
| """Hypothetical Document Embeddings. | |
| https://arxiv.org/abs/2212.10496 | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.pydantic_v1 import Extra | |
| from langchain.callbacks.manager import CallbackManagerForChainRun | |
| from langchain.chains.base import Chain | |
| from langchain.chains.hyde.prompts import PROMPT_MAP | |
| from langchain.chains.llm import LLMChain | |
class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """Generate a hypothetical document for a query, and then embed that.

    The wrapped ``llm_chain`` generates one or more candidate documents for
    the query text, ``base_embeddings`` embeds each of them, and the
    resulting vectors are averaged element-wise into one query embedding.

    Based on https://arxiv.org/abs/2212.10496
    """

    # Embedding model applied to the generated documents.
    base_embeddings: Embeddings
    # Chain used to generate the hypothetical document text from a query.
    llm_chain: LLMChain

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # Exposed as a property: this class (and ``embed_query`` below) reads
    # ``llm_chain.input_keys`` as an attribute, so the same access style has
    # to work on this chain as well.
    @property
    def input_keys(self) -> List[str]:
        """Input keys for Hyde's LLM chain."""
        return self.llm_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Output keys for Hyde's LLM chain."""
        return self.llm_chain.output_keys

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` by delegating directly to the base embeddings."""
        return self.base_embeddings.embed_documents(texts)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine several embeddings into one by taking their element-wise mean.

        Args:
            embeddings: One vector per generated document (presumably all of
                the same length -- NumPy rejects or mis-shapes ragged input).

        Returns:
            The averaged vector as a list of built-in Python floats.

        Raises:
            ValueError: If ``embeddings`` is empty, since a mean over zero
                vectors is undefined.
        """
        if len(embeddings) == 0:
            raise ValueError("Cannot combine an empty list of embeddings.")
        # ``tolist`` converts NumPy scalars to plain floats, which
        # ``list(ndarray)`` does not.
        return np.array(embeddings).mean(axis=0).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Generate hypothetical document(s) for ``text`` and embed them.

        Args:
            text: The query to embed.

        Returns:
            A single embedding: the mean of the embeddings of every document
            the LLM chain generated for the query.
        """
        # The query is fed to the chain under its first input key.
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        # Exactly one input was submitted, so its generations are at index 0.
        documents = [generation.text for generation in result.generations[0]]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Call the internal llm chain.

        Args:
            inputs: Values passed through unchanged to ``llm_chain``.
            run_manager: Callback manager for this run; a no-op manager is
                substituted when none is supplied.

        Returns:
            Whatever the internal LLM chain returns for ``inputs``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain(inputs, callbacks=_run_manager.get_child())

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        base_embeddings: Embeddings,
        prompt_key: str,
        **kwargs: Any,
    ) -> HypotheticalDocumentEmbedder:
        """Load and use LLMChain for a specific prompt key.

        Args:
            llm: Language model used to generate the hypothetical documents.
            base_embeddings: Embedding model applied to the generated text.
            prompt_key: Key selecting the prompt from ``PROMPT_MAP``; an
                unknown key propagates the error raised by that lookup.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance wired to an ``LLMChain`` built from ``llm`` and the
            selected prompt.
        """
        prompt = PROMPT_MAP[prompt_key]
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "hyde_chain"