File size: 4,050 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from string import Template
from typing import Dict, Union, Optional

from evoagentx.rag.schema import Query
from evoagentx.models.base_model import BaseLLM
from evoagentx.rag.transforms.query.base import BaseQueryTransform
from evoagentx.prompts.rag.hyde import DEFAULT_HYDE_PROMPT, HYDE_SYSTEM_IMPLE_


class HyDETransform(BaseQueryTransform):
    """
    Hypothetical Document Embeddings (HyDE) query transform.

    This class implements the HyDE technique for improving dense retrieval, as described in
    `Precise Zero-Shot Dense Retrieval without Relevance Labels` (https://arxiv.org/abs/2212.10496).
    It uses a language model to generate a hypothetical document (answer) for a given query, which
    is then used to create embedding strings for enhanced retrieval.

    Attributes:
        _llm (BaseLLM): The language model used to generate hypothetical documents.
        _hyde_prompt (Union[str, Template]): The prompt template for generating hypothetical documents.
        _include_original (bool): Whether to include the original query's embedding strings in the output.
    """

    def __init__(
        self,
        llm: BaseLLM,
        hyde_prompt: Optional[Union[str, Template]] = None,
        include_original: bool = True,
    ) -> None:
        """
        Initialize the HyDETransform.

        Args:
            llm (BaseLLM): The language model for generating hypothetical documents.
            hyde_prompt (Optional[Union[str, Template]]): Custom prompt template for HyDE generation.
                May be a plain format string (using ``{query}``) or a ``string.Template``
                (using ``$query``). Defaults to DEFAULT_HYDE_PROMPT if not provided.
            include_original (bool): Whether to include the original query's embedding strings
                alongside the hypothetical document. Defaults to True.
        """
        self._llm = llm
        self._hyde_prompt = hyde_prompt or DEFAULT_HYDE_PROMPT
        self._include_original = include_original

    def _format_prompt(self, query_str: str) -> str:
        """Substitute the query into the prompt, supporting both str and Template.

        ``string.Template`` has no ``format_map`` method, so dispatch on the
        prompt's type: ``safe_substitute`` for templates (tolerates extra/missing
        placeholders), ``format_map`` for plain format strings.
        """
        if isinstance(self._hyde_prompt, Template):
            return self._hyde_prompt.safe_substitute(query=query_str)
        return self._hyde_prompt.format_map({"query": query_str})

    def _run(self, query: Query, metadata: Dict) -> Query:
        """
        Transform a query by generating a hypothetical document and updating embedding strings.

        This method uses the LLM to generate a hypothetical answer to the query, which is then
        used as an embedding string for retrieval. If include_original is True, the original
        query's embedding strings are also retained.

        Args:
            query (Query): The input query to transform.
            metadata (Dict): Additional metadata associated with the query (not used in this implementation).

        Returns:
            Query: A new Query instance with updated embedding strings, including the hypothetical document.
        """
        query_str = query.query_str

        # Format the prompt by substituting the query string into the HyDE prompt template.
        instruction = self._format_prompt(query_str)

        hypothetical_doc = self._llm.generate(
            prompt=instruction,
            system_message=HYDE_SYSTEM_IMPLE_,
        ).content

        # The hypothetical document leads; it is the embedding HyDE retrieves with.
        embedding_strs = [hypothetical_doc]
        # Optionally retain the original query's embedding strings as well.
        if self._include_original:
            embedding_strs.extend(query.embedding_strs)

        # Deep-copy so the caller's Query instance is never mutated.
        tmp_query = query.deepcopy()
        tmp_query.custom_embedding_strs = embedding_strs
        return tmp_query


if __name__ == "__main__":
    import os

    import dotenv

    dotenv.load_dotenv()

    from evoagentx.models import OpenAILLMConfig, OpenAILLM

    # Machine-specific workaround for a broken conda SSL bundle. Apply it only
    # when the file actually exists, and never clobber a value the user has
    # already configured — otherwise this demo breaks on every other machine.
    _cert_path = r"D:\miniconda3\envs\envoagentx\Library\ssl\cacert.pem"
    if os.path.exists(_cert_path):
        os.environ.setdefault("SSL_CERT_FILE", _cert_path)

    # Fail fast with an actionable message instead of a bare KeyError.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; export it or add it to a .env file."
        )

    config = OpenAILLMConfig(
        model="gpt-4o-mini",
        temperature=0.7,
        max_tokens=1000,
        openai_key=api_key,
    )

    llm = OpenAILLM(config=config)
    hyde_trans = HyDETransform(llm=llm)
    output_query = hyde_trans(Query(query_str="Were Scott Derrickson and Ed Wood of the same nationality?"))
    print(output_query)