File size: 5,590 Bytes
f1e6b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations

from typing import Annotated, Any, TypeVar, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableSerializable
from pydantic import SkipValidation
from typing_extensions import TypedDict

from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT

T = TypeVar("T")


class OutputFixingParserRetryChainInput(TypedDict, total=False):
    """Input mapping handed to the retry chain of ``OutputFixingParser``.

    Declared with ``total=False``, so every key is optional; the parser omits
    ``instructions`` when the wrapped parser cannot supply them.
    """

    # Format instructions obtained from the wrapped parser.
    instructions: str
    # The completion text that failed to parse.
    completion: str
    # ``repr`` of the exception raised while parsing ``completion``.
    error: str


class OutputFixingParser(BaseOutputParser[T]):
    """Wrap a parser and try to fix parsing errors.

    When the wrapped ``parser`` raises ``OutputParserException``, the failed
    completion and the ``repr`` of the error (plus the parser's format
    instructions, when it provides them) are sent to ``retry_chain``. The
    chain's output replaces the completion and parsing is attempted again,
    at most ``max_retries`` times.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this parser as serializable (always ``True``)."""
        return True

    parser: Annotated[Any, SkipValidation()]
    """The parser to use to parse the output."""
    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
    retry_chain: Annotated[
        Union[RunnableSerializable[OutputFixingParserRetryChainInput, str], Any],
        SkipValidation(),
    ]
    """The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
    max_retries: int = 1
    """The maximum number of times to retry the parse."""
    legacy: bool = True
    """Whether to use the run or arun method of the retry_chain.

    Only honoured when the retry_chain actually has that method; otherwise
    invoke / ainvoke is used.
    """

    @classmethod
    def from_llm(
        cls,
        llm: Runnable,
        parser: BaseOutputParser[T],
        prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
        max_retries: int = 1,
    ) -> OutputFixingParser[T]:
        """Create an OutputFixingParser from a language model and a parser.

        The retry chain is built as ``prompt | llm | StrOutputParser()``.

        Args:
            llm: llm to use for fixing
            parser: parser to use for parsing
            prompt: prompt to use for fixing
            max_retries: Maximum number of retries to parse.

        Returns:
            OutputFixingParser
        """
        chain = prompt | llm | StrOutputParser()
        return cls(parser=parser, retry_chain=chain, max_retries=max_retries)

    def _retry_chain_inputs(
        self, completion: str, error: OutputParserException
    ) -> OutputFixingParserRetryChainInput:
        """Build the input mapping for an ``invoke`` / ``ainvoke`` retry call.

        Only the format-instructions lookup is guarded, so an error raised by
        the retry chain itself is never mistaken for a parser that lacks
        ``get_format_instructions``.

        Args:
            completion: The completion that failed to parse.
            error: The exception raised by the wrapped parser.

        Returns:
            A mapping with ``completion`` and ``error``, plus ``instructions``
            when the wrapped parser can provide them.
        """
        try:
            instructions = self.parser.get_format_instructions()
        except (NotImplementedError, AttributeError):
            # Case: self.parser does not have get_format_instructions
            return {"completion": completion, "error": repr(error)}
        return {
            "instructions": instructions,
            "completion": completion,
            "error": repr(error),
        }

    def parse(self, completion: str) -> T:
        """Parse ``completion``, asking the retry chain to fix it on failure.

        Args:
            completion: The text to parse.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: If the wrapped parser still fails after
                ``max_retries`` fix attempts.
        """
        retries = 0

        while retries <= self.max_retries:
            try:
                return self.parser.parse(completion)
            except OutputParserException as e:
                if retries == self.max_retries:
                    # Out of attempts: propagate the last parse failure as-is.
                    raise
                retries += 1
                if self.legacy and hasattr(self.retry_chain, "run"):
                    completion = self.retry_chain.run(
                        instructions=self.parser.get_format_instructions(),
                        completion=completion,
                        error=repr(e),
                    )
                else:
                    completion = self.retry_chain.invoke(
                        self._retry_chain_inputs(completion, e)
                    )

        # Only reached when max_retries is negative and the loop never runs.
        raise OutputParserException("Failed to parse")

    async def aparse(self, completion: str) -> T:
        """Asynchronous counterpart of ``parse``.

        Uses the wrapped parser's ``aparse`` and the retry chain's ``arun``
        (legacy) or ``ainvoke``.

        Args:
            completion: The text to parse.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: If the wrapped parser still fails after
                ``max_retries`` fix attempts.
        """
        retries = 0

        while retries <= self.max_retries:
            try:
                return await self.parser.aparse(completion)
            except OutputParserException as e:
                if retries == self.max_retries:
                    # Out of attempts: propagate the last parse failure as-is.
                    raise
                retries += 1
                if self.legacy and hasattr(self.retry_chain, "arun"):
                    completion = await self.retry_chain.arun(
                        instructions=self.parser.get_format_instructions(),
                        completion=completion,
                        error=repr(e),
                    )
                else:
                    completion = await self.retry_chain.ainvoke(
                        self._retry_chain_inputs(completion, e)
                    )

        # Only reached when max_retries is negative and the loop never runs.
        raise OutputParserException("Failed to parse")

    def get_format_instructions(self) -> str:
        """Return the wrapped parser's format instructions."""
        return self.parser.get_format_instructions()

    @property
    def _type(self) -> str:
        """Return the string identifier ``"output_fixing"``."""
        return "output_fixing"

    @property
    def OutputType(self) -> type[T]:
        """Return the wrapped parser's ``OutputType``."""
        return self.parser.OutputType