File size: 6,994 Bytes
f1e6b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""Use a single chain to route an input to one of multiple llm chains."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate

from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.router.base import MultiRouteChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE


@deprecated(
    since="0.2.12",
    removal="1.0",
    message=(
        "Please see migration guide here for recommended implementation: "
        "https://python.langchain.com/docs/versions/migrating_chains/multi_prompt_chain/"  # noqa: E501
    ),
)
class MultiPromptChain(MultiRouteChain):
    """A multi-route chain that uses an LLM router chain to choose amongst prompts.

    This class is deprecated. See below for a replacement, which offers several
    benefits, including streaming and batch support.

    Below is an example implementation:

        .. code-block:: python

            from operator import itemgetter
            from typing import Literal

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnableConfig
            from langchain_openai import ChatOpenAI
            from langgraph.graph import END, START, StateGraph
            from typing_extensions import TypedDict

            llm = ChatOpenAI(model="gpt-4o-mini")

            # Define the prompts we will route to
            prompt_1 = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are an expert on animals."),
                    ("human", "{input}"),
                ]
            )
            prompt_2 = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are an expert on vegetables."),
                    ("human", "{input}"),
                ]
            )

            # Construct the chains we will route to. These format the input query
            # into the respective prompt, run it through a chat model, and cast
            # the result to a string.
            chain_1 = prompt_1 | llm | StrOutputParser()
            chain_2 = prompt_2 | llm | StrOutputParser()


            # Next: define the chain that selects which branch to route to.
            # Here we will take advantage of tool-calling features to force
            # the output to select one of two desired branches.
            route_system = "Route the user's query to either the animal or vegetable expert."
            route_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", route_system),
                    ("human", "{input}"),
                ]
            )


            # Define schema for output:
            class RouteQuery(TypedDict):
                \"\"\"Route query to destination expert.\"\"\"

                destination: Literal["animal", "vegetable"]


            route_chain = route_prompt | llm.with_structured_output(RouteQuery)


            # For LangGraph, we will define the state of the graph to hold the query,
            # destination, and final answer.
            class State(TypedDict):
                query: str
                destination: RouteQuery
                answer: str


            # We define functions for each node, including routing the query:
            async def route_query(state: State, config: RunnableConfig):
                destination = await route_chain.ainvoke(state["query"], config)
                return {"destination": destination}


            # And one node for each prompt
            async def prompt_1(state: State, config: RunnableConfig):
                return {"answer": await chain_1.ainvoke(state["query"], config)}


            async def prompt_2(state: State, config: RunnableConfig):
                return {"answer": await chain_2.ainvoke(state["query"], config)}


            # We then define logic that selects the prompt based on the classification.
            # state["destination"] holds the structured RouteQuery dict, so read its
            # "destination" field rather than comparing the whole dict to a string.
            def select_node(state: State) -> Literal["prompt_1", "prompt_2"]:
                if state["destination"]["destination"] == "animal":
                    return "prompt_1"
                else:
                    return "prompt_2"


            # Finally, assemble the multi-prompt graph:
            # 1) route_query classifies the query as "animal" or "vegetable" via
            # the route_chain and stores the result in the state.
            # 2) select_node sends the query to the prompt_1 or prompt_2 node,
            # which writes the final answer to the state.
            graph = StateGraph(State)
            graph.add_node("route_query", route_query)
            graph.add_node("prompt_1", prompt_1)
            graph.add_node("prompt_2", prompt_2)

            graph.add_edge(START, "route_query")
            graph.add_conditional_edges("route_query", select_node)
            graph.add_edge("prompt_1", END)
            graph.add_edge("prompt_2", END)
            app = graph.compile()

            result = await app.ainvoke({"query": "what color are carrots"})
            print(result["destination"])
            print(result["answer"])
    """  # noqa: E501

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys of this chain, which is always ``["text"]``."""
        return ["text"]

    @classmethod
    def from_prompts(
        cls,
        llm: BaseLanguageModel,
        prompt_infos: List[Dict[str, str]],
        default_chain: Optional[Chain] = None,
        **kwargs: Any,
    ) -> MultiPromptChain:
        """Convenience constructor for instantiating from destination prompts.

        Args:
            llm: Language model used by the router chain and by every
                destination ``LLMChain``. It is also used by the fallback
                ``ConversationChain`` when ``default_chain`` is not given.
            prompt_infos: One dict per destination. Each dict must provide
                ``"name"`` (the key under which the destination chain is
                registered), ``"description"`` (interpolated into the router
                prompt next to the name) and ``"prompt_template"`` (an
                f-string template whose only input variable is ``input``).
            default_chain: Chain passed to the constructor as
                ``default_chain``. When ``None``, a ``ConversationChain`` over
                ``llm`` with ``output_key="text"`` is created.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A ``MultiPromptChain`` wired with the router chain, the destination
            chains and the default chain.

        Raises:
            KeyError: If an entry of ``prompt_infos`` is missing ``"name"``,
                ``"description"`` or ``"prompt_template"``.
        """
        destinations_str = "\n".join(
            f"{p['name']}: {p['description']}" for p in prompt_infos
        )
        # The formatted router template is parsed a second time as an f-string
        # PromptTemplate, so any literal braces coming from user-supplied names
        # or descriptions must be doubled. Otherwise they would be treated as
        # template variables and formatting would fail.
        escaped_destinations = destinations_str.replace("{", "{{").replace("}", "}}")
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=escaped_destinations
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        destination_chains: Dict[str, Chain] = {}
        for p_info in prompt_infos:
            name = p_info["name"]
            prompt = PromptTemplate(
                template=p_info["prompt_template"], input_variables=["input"]
            )
            # A later entry with a duplicate name replaces an earlier one.
            destination_chains[name] = LLMChain(llm=llm, prompt=prompt)

        # Compare against None explicitly instead of relying on the truthiness
        # of the supplied chain object.
        if default_chain is None:
            default_chain = ConversationChain(llm=llm, output_key="text")

        return cls(
            router_chain=router_chain,
            destination_chains=destination_chains,
            default_chain=default_chain,
            **kwargs,
        )