File size: 6,994 Bytes
f1e6b80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""Use a single chain to route an input to one of multiple llm chains."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.router.base import MultiRouteChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
@deprecated(
since="0.2.12",
removal="1.0",
message=(
"Please see migration guide here for recommended implementation: "
"https://python.langchain.com/docs/versions/migrating_chains/multi_prompt_chain/" # noqa: E501
),
)
class MultiPromptChain(MultiRouteChain):
"""A multi-route chain that uses an LLM router chain to choose amongst prompts.
This class is deprecated. See below for a replacement, which offers several
benefits, including streaming and batch support.
Below is an example implementation:
.. code-block:: python
from operator import itemgetter
from typing import Literal
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
llm = ChatOpenAI(model="gpt-4o-mini")
# Define the prompts we will route to
prompt_1 = ChatPromptTemplate.from_messages(
[
("system", "You are an expert on animals."),
("human", "{input}"),
]
)
prompt_2 = ChatPromptTemplate.from_messages(
[
("system", "You are an expert on vegetables."),
("human", "{input}"),
]
)
# Construct the chains we will route to. These format the input query
# into the respective prompt, run it through a chat model, and cast
# the result to a string.
chain_1 = prompt_1 | llm | StrOutputParser()
chain_2 = prompt_2 | llm | StrOutputParser()
# Next: define the chain that selects which branch to route to.
# Here we will take advantage of tool-calling features to force
# the output to select one of two desired branches.
route_system = "Route the user's query to either the animal or vegetable expert."
route_prompt = ChatPromptTemplate.from_messages(
[
("system", route_system),
("human", "{input}"),
]
)
# Define schema for output:
class RouteQuery(TypedDict):
\"\"\"Route query to destination expert.\"\"\"
destination: Literal["animal", "vegetable"]
route_chain = route_prompt | llm.with_structured_output(RouteQuery)
# For LangGraph, we will define the state of the graph to hold the query,
# destination, and final answer.
class State(TypedDict):
query: str
destination: RouteQuery
answer: str
# We define functions for each node, including routing the query:
async def route_query(state: State, config: RunnableConfig):
destination = await route_chain.ainvoke(state["query"], config)
return {"destination": destination}
# And one node for each prompt
async def prompt_1(state: State, config: RunnableConfig):
return {"answer": await chain_1.ainvoke(state["query"], config)}
async def prompt_2(state: State, config: RunnableConfig):
return {"answer": await chain_2.ainvoke(state["query"], config)}
# We then define logic that selects the prompt based on the classification
def select_node(state: State) -> Literal["prompt_1", "prompt_2"]:
if state["destination"] == "animal":
return "prompt_1"
else:
return "prompt_2"
# Finally, assemble the multi-prompt chain. This is a sequence of two steps:
# 1) Select "animal" or "vegetable" via the route_chain, and collect the answer
# alongside the input query.
# 2) Route the input query to chain_1 or chain_2, based on the
# selection.
graph = StateGraph(State)
graph.add_node("route_query", route_query)
graph.add_node("prompt_1", prompt_1)
graph.add_node("prompt_2", prompt_2)
graph.add_edge(START, "route_query")
graph.add_conditional_edges("route_query", select_node)
graph.add_edge("prompt_1", END)
graph.add_edge("prompt_2", END)
app = graph.compile()
result = await app.ainvoke({"query": "what color are carrots"})
print(result["destination"])
print(result["answer"])
""" # noqa: E501
@property
def output_keys(self) -> List[str]:
return ["text"]
@classmethod
def from_prompts(
cls,
llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]],
default_chain: Optional[Chain] = None,
**kwargs: Any,
) -> MultiPromptChain:
"""Convenience constructor for instantiating from destination prompts."""
destinations = [f"{p['name']}: {p['description']}" for p in prompt_infos]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
destination_chains = {}
for p_info in prompt_infos:
name = p_info["name"]
prompt_template = p_info["prompt_template"]
prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
chain = LLMChain(llm=llm, prompt=prompt)
destination_chains[name] = chain
_default_chain = default_chain or ConversationChain(llm=llm, output_key="text")
return cls(
router_chain=router_chain,
destination_chains=destination_chains,
default_chain=_default_chain,
**kwargs,
)
|