| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import dataclasses |
| import enum |
| import functools |
| import math |
| import re |
|
|
| |
| |
| from typing import Any, List, Text |
|
|
|
|
# Placeholder emitted by `retrieve_wikisql_query_answer_tapas` when a query
# yields no answer text at all.
EMPTY_ANSWER = "none"
# Placeholder returned by `_get_float_answer` when no cell was selected and the
# aggregation is anything other than COUNT.
EMPTY_ANSWER_AGG = "none"
|
|
|
|
| def _split_thousands(delimiter, value): |
| split = value.split(delimiter) |
| return len(split) > 1 and any((len(x) == 3 for x in split)) |
|
|
|
|
def convert_to_float(value):
    """Converts value to a float using a series of increasingly complex heuristics.

    The heuristics, tried in order on string input, are:
      * "1,000.7" (both separators present): commas are dropped.
      * "1,000"   (comma groups of three digits): commas are dropped.
      * "5,5556"  (one comma, not a thousands group): comma is a decimal point.
      * "0.0.0.1" (several dots): dots are dropped.
      * "0,0,0,1" (several commas): commas are dropped.
      * anything else is handed to `float` unchanged.

    Args:
      value: object that needs to be converted. Allowed types include
        float/int/strings.
    Returns:
      A float interpretation of value.
    Raises:
      ValueError if the float conversion of value fails, or if value is neither
      a number nor a string.
    """
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if not isinstance(value, str):
        raise ValueError("Argument value is not a string. Can't parse it as float")

    comma_count = value.count(",")
    try:
        if comma_count:
            # "1,000.7" or "1,000": the commas only group digits.
            if "." in value or _split_thousands(",", value):
                return float(value.replace(",", ""))
            # "5,5556": a lone comma that is not a thousands separator is
            # treated as the decimal mark.
            if comma_count == 1:
                return float(value.replace(",", "."))
        # "0.0.0.1": more than one dot cannot be a decimal point.
        if value.count(".") > 1:
            return float(value.replace(".", ""))
        # "0,0,0,1": several commas that are not thousands groups.
        if comma_count > 1:
            return float(value.replace(",", ""))
        return float(value)
    except ValueError:
        # The message is intentionally generic; `from None` also keeps the
        # original float() error (which embeds the raw value) out of the
        # chained traceback.
        raise ValueError("Unable to convert value to float") from None
|
|
|
|
| def _normalize_float(answer): |
| if answer is None: |
| return None |
| try: |
| value = convert_to_float(answer) |
| if isinstance(value, float) and math.isnan(value): |
| return None |
| return value |
| except ValueError: |
| return answer.lower() |
|
|
|
|
# Maps a column type name (as found in table["types"]) to the callable used to
# parse a value of that column: text is kept as is, reals become floats.
_TYPE_CONVERTER = {
    "text": lambda value: value,
    "real": convert_to_float,
}
|
|
|
|
| class _Aggregation(enum.Enum): |
| """Aggregations as defined by WikiSQL. Indexes match the data.""" |
|
|
| NONE = 0 |
| MAX = 1 |
| MIN = 2 |
| COUNT = 3 |
| SUM = 4 |
| AVERAGE = 5 |
|
|
|
|
| class _Operator(enum.Enum): |
| """The boolean operators used by WikiSQL. Indexes match the data.""" |
|
|
| EQUALS = 0 |
| GREATER = 1 |
| LESSER = 2 |
|
|
|
|
@dataclasses.dataclass
class _Condition:
    """A single SQL WHERE clause, such as A = "a" or B > 5."""

    # Identifies the column the clause applies to; it is used both to index a
    # row and to index table["types"].
    column: str
    # Comparison applied between the cell and `cmp_value`.
    operator: _Operator
    # Right-hand side of the comparison, still unparsed.
    cmp_value: Any
|
|
|
|
| _TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL) |
|
|
|
|
def _normalize_for_match(x):
    """Lower-case `x` and split it into the list of tokens found by `_TOKENIZER`."""
    lowered = x.lower()
    tokens = _TOKENIZER.findall(lowered)
    return tokens
|
|
|
|
def _compare(operator, src, tgt):
    """Evaluate `src <operator> tgt`.

    Args:
      operator: the `_Operator` to apply.
      src: left-hand operand.
      tgt: right-hand operand.
    Returns:
      The result of the corresponding ==, > or < comparison.
    Raises:
      ValueError if `operator` is none of the known operators.
    """
    if operator == _Operator.EQUALS:
        return src == tgt
    if operator == _Operator.GREATER:
        return src > tgt
    if operator == _Operator.LESSER:
        return src < tgt
    raise ValueError(f"Unknown operator: {operator}")
|
|
|
|
def _parse_value(table, column, cell_value):
    """Parse `cell_value` with the converter registered for the column's type.

    The type name is read from table["types"][column] and looked up in
    `_TYPE_CONVERTER`, so "real" values become floats and "text" values are
    returned unchanged.
    """
    column_type = table["types"][column]
    converter = _TYPE_CONVERTER[column_type]
    return converter(cell_value)
|
|
|
|
| def _is_string(x): |
| return isinstance(x, str) |
|
|
|
|
| def _respect_conditions(table, row, conditions): |
| """True if 'row' satisfies all 'conditions'.""" |
| for cond in conditions: |
| table_value = row[cond.column] |
|
|
| cmp_value = _parse_value(table, cond.column, cond.cmp_value) |
|
|
| if _is_string(table_value) and _is_string(cmp_value): |
| table_value = _normalize_for_match(table_value) |
| cmp_value = _normalize_for_match(cmp_value) |
|
|
| if not isinstance(table_value, type(cmp_value)): |
| raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value))) |
|
|
| if not _compare(cond.operator, table_value, cmp_value): |
| return False |
| return True |
|
|
|
|
def _get_float_answer(table, answer_coordinates, aggregation_op):
    """Compute the numeric reference answer for the selected cells.

    Args:
      table: table dict; cell values are read from table["rows"].
      answer_coordinates: list of (row, column) pairs of the selected cells.
      aggregation_op: the `_Aggregation` to apply.
    Returns:
      * COUNT: the number of selected cells as a float (0.0 when none).
      * no selected cell and another aggregation: `EMPTY_ANSWER_AGG`.
      * exactly one cell: its value converted with `convert_to_float`.
      * SUM / AVERAGE over numeric cells: the float result.
      * None when the aggregation is NONE and no float could be produced, or
        when some selected cell is not an int/float.
    Raises:
      ValueError if a single selected cell cannot be converted while an
      aggregation is requested, or if the aggregation is not handled here.
    """
    counting = aggregation_op == _Aggregation.COUNT
    if not answer_coordinates:
        return 0.0 if counting else EMPTY_ANSWER_AGG

    # COUNT never looks at the cell contents, so it works for any cell type.
    if counting:
        return float(len(answer_coordinates))

    cells = [table["rows"][row][col] for row, col in answer_coordinates]

    # A lone cell is returned directly when it can be read as a float.
    if len(answer_coordinates) == 1:
        try:
            return convert_to_float(cells[0])
        except ValueError:
            # Only a plain selection may silently fall back to a text answer.
            if aggregation_op != _Aggregation.NONE:
                raise

    if aggregation_op == _Aggregation.NONE:
        return None

    # The remaining aggregations need numeric cells; give up otherwise.
    if any(not isinstance(cell, (int, float)) for cell in cells):
        return None

    if aggregation_op == _Aggregation.SUM:
        return float(sum(cells))
    if aggregation_op == _Aggregation.AVERAGE:
        return sum(cells) / len(answer_coordinates)
    raise ValueError(f"Unknown aggregation: {aggregation_op}")
|
|
|
|
def _get_answer_coordinates(table, sql_query):
    """Execute `sql_query` on `table` and return the coordinates of the answer cells.

    Args:
      table: table dict with "rows" (and "types", used to parse condition values).
      sql_query: dict with "agg" (aggregation index), "sel" (selected column)
        and "conds", itself holding the parallel lists "column_index",
        "operator_index" and "condition".
    Returns:
      A pair (coordinates, aggregation): the list of (row, column) tuples of the
      matching cells and the `_Aggregation` still to be applied to them. MAX and
      MIN (indexes 1 and 2) are resolved here by keeping only the extreme cell,
      so they are always reported as `_Aggregation.NONE`.
    """
    agg_index = sql_query["agg"]
    # Only COUNT/SUM/AVERAGE are left for the caller to compute.
    aggregation_op = _Aggregation(agg_index) if agg_index >= 3 else _Aggregation.NONE

    target_column = sql_query["sel"]
    where = sql_query["conds"]
    conditions = [
        _Condition(column, _Operator(operator_index), cmp_value)
        for column, operator_index, cmp_value in zip(
            where["column_index"], where["operator_index"], where["condition"]
        )
    ]

    rows = table["rows"]
    indices = [
        (row_index, target_column)
        for row_index, row in enumerate(rows)
        if _respect_conditions(table, row, conditions)
    ]

    # Nothing to reduce when zero or one row matched.
    if len(indices) <= 1:
        return indices, aggregation_op

    if agg_index in (1, 2):
        pick = max if agg_index == 1 else min
        # Pair each value with its position so the winning coordinate can be
        # recovered after the reduction.
        candidates = [(rows[i][j], position) for position, (i, j) in enumerate(indices)]
        _, best_position = pick(candidates)
        return [indices[best_position]], _Aggregation.NONE

    return indices, aggregation_op
|
|
|
|
| def _get_answer_text(table, answer_coordinates, float_answer): |
| if float_answer is not None: |
| return [str(float_answer)] |
| return [str(table["real_rows"][r][c]) for r, c in answer_coordinates] |
|
|
|
|
def retrieve_wikisql_query_answer_tapas(table, example) -> List:
    """Answer a WikiSQL query on `table` and return the answer as a list of strings.

    Args:
      table: table dict, passed on to the coordinate, float and text helpers.
      example: the SQL query dict, executed by `_get_answer_coordinates`.
    Returns:
      The list produced by `_get_answer_text`, or [EMPTY_ANSWER] when that list
      is empty.
    """
    coordinates, aggregation = _get_answer_coordinates(table, example)
    number = _get_float_answer(table, coordinates, aggregation)
    texts = _get_answer_text(table, coordinates, number)
    return texts or [EMPTY_ANSWER]
|
|