File size: 6,575 Bytes
f1e6b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import re
from typing import Any, Dict, List, Tuple, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.base import BaseOutputParser
from pydantic import field_validator

from langchain.output_parsers.format_instructions import (
    PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
)


class PandasDataFrameOutputParser(BaseOutputParser[Dict[str, Any]]):
    """Parse an output using Pandas DataFrame format.

    A request is a string of the form ``<request_type>:<params>``:

    * ``column:<name>`` returns the column ``<name>``; an optional bracketed
      array of row indices (``column:<name>[1,3]``) restricts the rows.
    * ``row:<n>`` returns the row at position ``<n>``; an optional bracketed
      array of column names (``row:<n>[a,b]``) restricts the columns.
    * ``<operation>:<name>`` calls the no-argument Series method
      ``<operation>`` on column ``<name>``; an optional bracketed array of
      row indices restricts the rows first.
    """

    dataframe: Any
    """The Pandas DataFrame to parse."""

    @field_validator("dataframe")
    @classmethod
    def validate_dataframe(cls, val: Any) -> Any:
        """Accept only ``pd.DataFrame`` instances (or subclasses).

        Args:
            val: The candidate value for the ``dataframe`` field.

        Returns:
            ``val`` unchanged when it is a DataFrame.

        Raises:
            ValueError: If ``val`` is not a DataFrame and converts to an
                empty DataFrame.
            TypeError: If ``val`` is any other non-DataFrame value.
        """
        import pandas as pd

        if isinstance(val, pd.DataFrame):
            return val
        # Non-DataFrame input that converts to an empty frame gets the more
        # specific "empty" error; everything else is a plain type error.
        if pd.DataFrame(val).empty:
            raise ValueError("DataFrame cannot be empty.")

        raise TypeError(
            "Wrong type for 'dataframe', must be a subclass "
            "of Pandas DataFrame (pd.DataFrame)"
        )

    def parse_array(
        self, array: str, original_request_params: str
    ) -> Tuple[List[Union[int, str]], str]:
        """Parse the bracketed array part of a request.

        Supported forms are ``[1,3,5]`` (integer list), ``[1..5]`` (inclusive
        integer range) and ``[name_a,name_b]`` (names made of letters, digits
        and underscores).

        Args:
            array: The bracketed array text, e.g. ``"[1,3,5]"``.
            original_request_params: The full parameter text the array was
                extracted from.

        Returns:
            A tuple of the parsed array and the part of
            ``original_request_params`` that precedes the first ``[``.

        Raises:
            OutputParserException: If the array matches none of the supported
                forms (or is an empty range), or if an integer index exceeds
                the maximum index of the DataFrame.
        """
        parsed_array: List[Union[int, str]] = []

        range_match = re.match(r"\[(\d+)\.\.(\d+)\]", array)
        name_match = re.match(r"\[[a-zA-Z0-9_]+(?:,[a-zA-Z0-9_]+)*\]", array)

        # Check if the format is [1,3,5]
        if re.match(r"\[\d+(,\s*\d+)*\]", array):
            parsed_array = [int(i) for i in re.findall(r"\d+", array)]
        # Check if the format is [1..5]
        elif range_match:
            start, end = map(int, range_match.groups())
            parsed_array = list(range(start, end + 1))
        # Check if the format is ["column_name"]
        elif name_match:
            parsed_array = list(name_match.group().strip("[]").split(","))

        # Validate the array
        if not parsed_array:
            raise OutputParserException(
                f"Invalid array format in '{original_request_params}'. "
                "Please check the format instructions."
            )

        # Compare the largest requested index, not merely the last one, so
        # unordered lists such as [9,1] are validated as well.
        int_indices = [i for i in parsed_array if isinstance(i, int)]
        if int_indices:
            highest_requested = max(int_indices)
            highest_available = self.dataframe.index.max()
            if highest_requested > highest_available:
                raise OutputParserException(
                    f"The maximum index {highest_requested} exceeds the "
                    "maximum index of the Pandas DataFrame "
                    f"{highest_available}."
                )

        return parsed_array, original_request_params.split("[")[0]

    def parse(self, request: str) -> Dict[str, Any]:
        """Evaluate a ``<request_type>:<params>`` request on the DataFrame.

        Args:
            request: The request string, e.g. ``"column:a"``, ``"row:1"``,
                ``"column:a[0..2]"`` or ``"mean:a[1,3]"``.

        Returns:
            A single-entry dict. For ``column`` and ``row`` requests the key
            is the column name / row number text and the value is the
            selected data; for any other request type the key is the request
            type and the value is the result of calling that method.

        Raises:
            OutputParserException: If the request is malformed, reports an
                invalid column/operation, or cannot be evaluated against the
                DataFrame.
        """
        stripped_request_params = None
        splitted_request = request.strip().split(":")
        if len(splitted_request) != 2:
            raise OutputParserException(
                f"Request '{request}' is not correctly formatted. "
                "Please refer to the format instructions."
            )
        request_type, request_params = splitted_request
        if request_type in {"Invalid column", "Invalid operation"}:
            raise OutputParserException(
                f"{request}. Please check the format instructions."
            )

        result: Dict[str, Any] = {}
        try:
            array_exists = re.search(r"(\[.*?\])", request_params)
            if array_exists:
                parsed_array, stripped_request_params = self.parse_array(
                    array_exists.group(1), request_params
                )
                if request_type == "column":
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    selected = filtered_df[stripped_request_params]
                    if len(parsed_array) == 1:
                        # The frame is already filtered down to the requested
                        # row, so the value sits at position 0; indexing with
                        # the original label would overrun the filtered data.
                        result[stripped_request_params] = selected.iloc[0]
                    else:
                        result[stripped_request_params] = selected
                elif request_type == "row":
                    filtered_df = self.dataframe[
                        self.dataframe.columns.intersection(parsed_array)
                    ]
                    row = filtered_df.iloc[int(stripped_request_params)]
                    if len(parsed_array) == 1:
                        result[stripped_request_params] = row[parsed_array[0]]
                    else:
                        result[stripped_request_params] = row
                else:
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    # NOTE(security): request_type comes from the request
                    # text, so any no-argument Series attribute can be called.
                    result[request_type] = getattr(
                        filtered_df[stripped_request_params], request_type
                    )()
            else:
                if request_type == "column":
                    result[request_params] = self.dataframe[request_params]
                elif request_type == "row":
                    result[request_params] = self.dataframe.iloc[int(request_params)]
                else:
                    # NOTE(security): see the note on the array branch above.
                    result[request_type] = getattr(
                        self.dataframe[request_params], request_type
                    )()
        except OutputParserException:
            # Already a parser error (e.g. from parse_array); do not re-wrap.
            raise
        except (AttributeError, IndexError, KeyError) as e:
            if request_type not in {"column", "row"}:
                raise OutputParserException(
                    f"Unsupported request type '{request_type}'. "
                    "Please check the format instructions."
                ) from e
            requested = (
                request_params
                if stripped_request_params is None
                else stripped_request_params
            )
            raise OutputParserException(
                f"Requested index {requested} is out of bounds."
            ) from e
        except (TypeError, ValueError) as e:
            # E.g. a non-numeric row number or an operation that does not
            # apply to the column's data; surface it as a parser error.
            raise OutputParserException(
                f"Request '{request}' could not be evaluated: {e}. "
                "Please check the format instructions."
            ) from e

        return result

    def get_format_instructions(self) -> str:
        """Return the format instructions with the DataFrame's column names.

        Column labels are converted with ``str`` so non-string labels (for
        example integers) can be listed too.
        """
        return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
            columns=", ".join(map(str, self.dataframe.columns))
        )