Spaces:
Runtime error
Runtime error
| """Test HyDE.""" | |
| from typing import List, Optional | |
| import numpy as np | |
| from pydantic import BaseModel | |
| from langchain.chains.hyde.base import HypotheticalDocumentEmbedder | |
| from langchain.chains.hyde.prompts import PROMPT_MAP | |
| from langchain.embeddings.base import Embeddings | |
| from langchain.llms.base import BaseLLM | |
| from langchain.schema import Generation, LLMResult | |
class FakeEmbeddings(Embeddings):
    """Fake embedding class for tests.

    Produces random 10-dimensional vectors so tests can exercise embedding
    pipelines without calling a real embedding model.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one random 10-dimensional vector per input text.

        Args:
            texts: Documents to "embed"; only their count is used.

        Returns:
            ``len(texts)`` vectors whose entries are floats drawn uniformly
            from [0, 1).
        """
        # Emit exactly one vector per document so the output lines up with
        # ``texts`` (a fixed count would mismatch callers that zip them).
        return [np.random.uniform(0, 1, 10).tolist() for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return a random 10-dimensional vector of floats drawn from [0, 1)."""
        return np.random.uniform(0, 1, 10).tolist()
class FakeLLM(BaseLLM, BaseModel):
    """Fake LLM wrapper for testing purposes.

    Every prompt yields ``n`` generations whose text is always ``"foo"``.
    """

    # Number of generations returned for each prompt.
    n: int = 1

    def _fake_result(self, prompts: List[str]) -> LLMResult:
        """Build a result holding ``self.n`` ``"foo"`` generations per prompt."""
        # One inner list per prompt, so results stay aligned with ``prompts``
        # even when several prompts are generated in a single call.
        return LLMResult(
            generations=[
                [Generation(text="foo") for _ in range(self.n)] for _ in prompts
            ]
        )

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Return canned generations for ``prompts``; ``stop`` is ignored."""
        return self._fake_result(prompts)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Async twin of ``_generate``; returns the same canned generations."""
        return self._fake_result(prompts)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"
def _check_hyde_for_all_prompts(n: int) -> None:
    """Build a HyDE embedder for every prompt key and embed a query with it.

    Args:
        n: Number of generations the fake LLM returns per prompt.
    """
    for key in PROMPT_MAP:
        embedding = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(n=n), FakeEmbeddings(), key
        )
        result = embedding.embed_query("foo")
        # FakeEmbeddings emits 10-dimensional vectors, so the final query
        # embedding is expected to keep that dimensionality.
        assert len(result) == 10, key


def test_hyde_from_llm() -> None:
    """Test loading HyDE from all prompts with a single generation."""
    _check_hyde_for_all_prompts(n=1)


def test_hyde_from_llm_with_multiple_n() -> None:
    """Test loading HyDE from all prompts with several generations per query."""
    _check_hyde_for_all_prompts(n=8)