# File size: 6,575 Bytes
import re
from typing import Any, Dict, List, Tuple, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.base import BaseOutputParser
from pydantic import field_validator
from langchain.output_parsers.format_instructions import (
PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
)
class PandasDataFrameOutputParser(BaseOutputParser[Dict[str, Any]]):
    """Parse an output using Pandas DataFrame format.

    A request is a string of the form ``<request_type>:<params>``:

    * ``column:<name>`` or ``column:<name>[<rows>]`` selects a column,
      optionally restricted to the rows whose index label is listed.
    * ``row:<position>`` or ``row:<position>[<columns>]`` selects a row by
      integer position, optionally restricted to the listed columns.
    * ``<operation>:<column>`` or ``<operation>:<column>[<rows>]`` calls the
      Series method named ``<operation>`` (with no arguments) on the column.

    The bracketed selector accepts ``[1,3,5]``, ``[1..5]`` (inclusive range)
    or ``[name_a,name_b]``.
    """

    # The Pandas DataFrame to parse.
    dataframe: Any

    @field_validator("dataframe")
    @classmethod
    def validate_dataframe(cls, val: Any) -> Any:
        """Accept only ``pd.DataFrame`` instances (including subclasses).

        Raises:
            ValueError: If ``val`` is not a DataFrame and converts to an
                empty one.
            TypeError: If ``val`` is any other non-DataFrame value.
        """
        import pandas as pd

        if isinstance(val, pd.DataFrame):
            return val
        if pd.DataFrame(val).empty:
            raise ValueError("DataFrame cannot be empty.")

        raise TypeError(
            "Wrong type for 'dataframe', must be a subclass "
            "of Pandas DataFrame (pd.DataFrame)"
        )

    def parse_array(
        self, array: str, original_request_params: str
    ) -> Tuple[List[Union[int, str]], str]:
        """Parse the bracketed selector of a request.

        Args:
            array: The bracketed text, e.g. ``"[1,3,5]"``, ``"[1..5]"`` or
                ``"[col_a,col_b]"``.
            original_request_params: The full parameter string the selector
                was extracted from, e.g. ``"age[1,3,5]"``.

        Returns:
            A tuple of the parsed selector values and the part of
            ``original_request_params`` that precedes the first ``[``.

        Raises:
            OutputParserException: If the selector matches none of the
                supported formats (or yields no values), or if a numeric
                selector exceeds the largest index label of the DataFrame.
        """
        parsed_array: List[Union[int, str]] = []

        range_match = re.match(r"\[(\d+)\.\.(\d+)\]", array)
        names_match = re.match(r"\[[a-zA-Z0-9_]+(?:,[a-zA-Z0-9_]+)*\]", array)

        if re.match(r"\[\d+(,\s*\d+)*\]", array):
            # Format [1,3,5]
            parsed_array = [int(i) for i in re.findall(r"\d+", array)]
        elif range_match:
            # Format [1..5]; the end of the range is inclusive.
            start, end = map(int, range_match.groups())
            parsed_array = list(range(start, end + 1))
        elif names_match:
            # Format [column_name] or [col_a,col_b]
            parsed_array = list(names_match.group().strip("[]").split(","))

        if not parsed_array:
            raise OutputParserException(
                f"Invalid array format in '{original_request_params}'. "
                "Please check the format instructions."
            )

        if isinstance(parsed_array[0], int):
            # Check the largest requested index rather than the last one, so
            # an unsorted selector such as [9,1] is validated as well.
            highest = max(i for i in parsed_array if isinstance(i, int))
            index_max = self.dataframe.index.max()
            if highest > index_max:
                raise OutputParserException(
                    f"The maximum index {highest} exceeds the maximum index of "
                    f"the Pandas DataFrame {index_max}."
                )

        return parsed_array, original_request_params.split("[")[0]

    @staticmethod
    def _row_position(raw: str) -> int:
        """Convert the row part of a request to an integer position.

        Raises:
            OutputParserException: If ``raw`` is not an integer literal.
        """
        try:
            return int(raw)
        except ValueError as e:
            raise OutputParserException(
                f"Row index '{raw}' is not an integer. "
                "Please check the format instructions."
            ) from e

    @staticmethod
    def _apply_operation(series: Any, operation: str) -> Any:
        """Call the zero-argument method ``operation`` on ``series``.

        NOTE(review): ``operation`` comes straight from the parsed request,
        so any attribute of the column object can be reached here; consider
        restricting it to an allow-list if requests are untrusted.

        Raises:
            AttributeError: If the attribute is missing or is not callable,
                so that the caller reports an unsupported request type.
        """
        method = getattr(series, operation)
        if not callable(method):
            raise AttributeError(f"'{operation}' is not a callable operation")
        return method()

    def parse(self, request: str) -> Dict[str, Any]:
        """Execute a ``<request_type>:<params>`` request on the DataFrame.

        Args:
            request: The request string, e.g. ``"column:age"``,
                ``"row:1[age,name]"`` or ``"mean:age[0..3]"``.

        Returns:
            A single-entry dict. The key is the column name for ``column``
            requests, the row position (as given) for ``row`` requests and
            the operation name otherwise; the value is the selected data or
            the operation result.

        Raises:
            OutputParserException: If the request is malformed, reports an
                invalid column/operation, names an unsupported operation, or
                refers to data that cannot be looked up.
        """
        splitted_request = request.strip().split(":")
        if len(splitted_request) != 2:
            raise OutputParserException(
                f"Request '{request}' is not correctly formatted. "
                "Please refer to the format instructions."
            )
        request_type, request_params = splitted_request

        if request_type in {"Invalid column", "Invalid operation"}:
            raise OutputParserException(
                f"{request}. Please check the format instructions."
            )

        stripped_request_params = None
        result: Dict[str, Any] = {}
        try:
            array_exists = re.search(r"(\[.*?\])", request_params)
            if array_exists:
                parsed_array, stripped_request_params = self.parse_array(
                    array_exists.group(1), request_params
                )
                if request_type == "column":
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    column = filtered_df[stripped_request_params]
                    if len(parsed_array) == 1:
                        # The frame is already filtered down to the requested
                        # index label, so the value sits at position 0 (not
                        # at position ``parsed_array[0]``).
                        result[stripped_request_params] = column.iloc[0]
                    else:
                        result[stripped_request_params] = column
                elif request_type == "row":
                    filtered_df = self.dataframe[
                        self.dataframe.columns.intersection(parsed_array)
                    ]
                    row = filtered_df.iloc[
                        self._row_position(stripped_request_params)
                    ]
                    if len(parsed_array) == 1:
                        result[stripped_request_params] = row[parsed_array[0]]
                    else:
                        result[stripped_request_params] = row
                else:
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    result[request_type] = self._apply_operation(
                        filtered_df[stripped_request_params], request_type
                    )
            elif request_type == "column":
                result[request_params] = self.dataframe[request_params]
            elif request_type == "row":
                result[request_params] = self.dataframe.iloc[
                    self._row_position(request_params)
                ]
            else:
                result[request_type] = self._apply_operation(
                    self.dataframe[request_params], request_type
                )
        except (AttributeError, IndexError, KeyError) as e:
            if request_type not in {"column", "row"}:
                raise OutputParserException(
                    f"Unsupported request type '{request_type}'. "
                    "Please check the format instructions."
                ) from e
            requested = (
                request_params
                if stripped_request_params is None
                else stripped_request_params
            )
            raise OutputParserException(
                f"Requested index {requested} is out of bounds."
            ) from e

        return result

    def get_format_instructions(self) -> str:
        """Return the format instructions with the DataFrame's column labels.

        Labels are converted with ``str`` so that non-string labels (for
        example integer column names) can be joined.
        """
        return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
            columns=", ".join(map(str, self.dataframe.columns))
        )
|