# Extracted viewer header: "File size: 5,590 Bytes", "f1e6b80"
from __future__ import annotations
from typing import Annotated, Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableSerializable
from pydantic import SkipValidation
from typing_extensions import TypedDict
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
T = TypeVar("T")
class OutputFixingParserRetryChainInput(TypedDict, total=False):
    """Input mapping passed to the retry chain of ``OutputFixingParser``.

    Declared with ``total=False``, so every key is optional.
    """

    # Format instructions obtained from the wrapped parser's
    # ``get_format_instructions()``.
    instructions: str
    # The completion text that the wrapped parser failed to parse.
    completion: str
    # ``repr()`` of the ``OutputParserException`` raised while parsing.
    error: str
class OutputFixingParser(BaseOutputParser[T]):
    """Wrap a parser and try to fix parsing errors.

    When the wrapped ``parser`` raises ``OutputParserException``, the failed
    completion and the error are handed to ``retry_chain``. The chain's output
    replaces the completion and parsing is attempted again, up to
    ``max_retries`` times. The last parsing error is re-raised once the
    retries are exhausted.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return ``True``; this class declares itself serializable."""
        return True

    parser: Annotated[Any, SkipValidation()]
    """The parser to use to parse the output."""
    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
    retry_chain: Annotated[
        Union[RunnableSerializable[OutputFixingParserRetryChainInput, str], Any],
        SkipValidation(),
    ]
    """The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
    max_retries: int = 1
    """The maximum number of times to retry the parse."""
    legacy: bool = True
    """Whether to use the run or arun method of the retry_chain."""

    @classmethod
    def from_llm(
        cls,
        llm: Runnable,
        parser: BaseOutputParser[T],
        prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
        max_retries: int = 1,
    ) -> OutputFixingParser[T]:
        """Create an OutputFixingParser from a language model and a parser.

        The retry chain is built as ``prompt | llm | StrOutputParser()``.

        Args:
            llm: llm to use for fixing
            parser: parser to use for parsing
            prompt: prompt to use for fixing
            max_retries: Maximum number of retries to parse.

        Returns:
            OutputFixingParser
        """
        chain = prompt | llm | StrOutputParser()
        return cls(parser=parser, retry_chain=chain, max_retries=max_retries)

    def _retry_chain_input(
        self, completion: str, error: OutputParserException
    ) -> OutputFixingParserRetryChainInput:
        """Build the input mapping for ``retry_chain.invoke`` / ``ainvoke``.

        Only the ``get_format_instructions()`` call is guarded, so that an
        ``AttributeError`` or ``NotImplementedError`` raised by the retry
        chain itself is never mistaken for "parser has no instructions".

        Args:
            completion: The completion that failed to parse.
            error: The parsing error raised for ``completion``.

        Returns:
            A mapping with ``completion`` and ``error`` (as ``repr``), plus
            ``instructions`` when the wrapped parser can provide them.
        """
        try:
            instructions = self.parser.get_format_instructions()
        except (NotImplementedError, AttributeError):
            # Case: self.parser does not have get_format_instructions
            return {"completion": completion, "error": repr(error)}
        return {
            "instructions": instructions,
            "completion": completion,
            "error": repr(error),
        }

    def parse(self, completion: str) -> T:
        """Parse ``completion``, asking the retry chain to fix it on failure.

        Args:
            completion: The text to parse.

        Returns:
            The value returned by the wrapped parser.

        Raises:
            OutputParserException: If parsing still fails after
                ``max_retries`` fix attempts.
        """
        retries = 0
        while retries <= self.max_retries:
            try:
                return self.parser.parse(completion)
            except OutputParserException as e:
                if retries == self.max_retries:
                    # Out of retries: surface the last parsing error as-is.
                    raise
                retries += 1
                if self.legacy and hasattr(self.retry_chain, "run"):
                    # Legacy chain interface: inputs are keyword arguments.
                    completion = self.retry_chain.run(
                        instructions=self.parser.get_format_instructions(),
                        completion=completion,
                        error=repr(e),
                    )
                else:
                    completion = self.retry_chain.invoke(
                        self._retry_chain_input(completion, e)
                    )
        # Only reachable when max_retries is negative (the loop never runs).
        raise OutputParserException("Failed to parse")

    async def aparse(self, completion: str) -> T:
        """Async variant of ``parse`` using ``aparse`` / ``arun`` / ``ainvoke``.

        Args:
            completion: The text to parse.

        Returns:
            The value returned by the wrapped parser.

        Raises:
            OutputParserException: If parsing still fails after
                ``max_retries`` fix attempts.
        """
        retries = 0
        while retries <= self.max_retries:
            try:
                return await self.parser.aparse(completion)
            except OutputParserException as e:
                if retries == self.max_retries:
                    # Out of retries: surface the last parsing error as-is.
                    raise
                retries += 1
                if self.legacy and hasattr(self.retry_chain, "arun"):
                    # Legacy chain interface: inputs are keyword arguments.
                    completion = await self.retry_chain.arun(
                        instructions=self.parser.get_format_instructions(),
                        completion=completion,
                        error=repr(e),
                    )
                else:
                    completion = await self.retry_chain.ainvoke(
                        self._retry_chain_input(completion, e)
                    )
        # Only reachable when max_retries is negative (the loop never runs).
        raise OutputParserException("Failed to parse")

    def get_format_instructions(self) -> str:
        """Return the wrapped parser's format instructions."""
        return self.parser.get_format_instructions()

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "output_fixing"

    @property
    def OutputType(self) -> type[T]:
        """Output type of the wrapped parser."""
        return self.parser.OutputType
|