| | """ |
| | Functions used in several different places. This file should not import from any other non-lib files to prevent |
| | circular dependencies. |
| | """ |
| |
|
| | import json |
| | import logging |
| | from copy import copy |
| | from typing import Any, Callable, Dict, Optional, Tuple, Union |
| |
|
# Key names treated as top-level identifiers.
# NOTE(review): not referenced anywhere else in this file; presumably consumed by
# modules that import it — confirm against callers before changing.
TOP_LEVEL_IDENTIFIERS = {"properties", "links", "description"}
| |
|
| |
|
def _extract_fenced_json(text: str) -> Tuple[str, bool]:
    """
    Pull the JSON candidate text out of a raw model response.

    If ``text`` contains a Markdown code fence, the content of the first fenced
    section is returned with an optional leading ``json`` language tag removed
    (so both ```json ... ``` and ``` ... ``` work). Otherwise ``text`` is
    returned unchanged.

    Args:
        text: raw model output.

    Returns:
        A tuple ``(candidate, was_fenced)``.
    """
    if "```" not in text:
        return text, False
    # Index 1 is the text between the first and second fence, or everything after
    # the first fence when the closing fence was cut off.
    segment = text.split("```")[1]
    if segment.startswith("json"):
        segment = segment[len("json"):]
    return segment, True


def _repair_and_parse(text: str) -> Union[dict, list]:
    """
    Try to turn truncated model output into parseable JSON, then parse it.

    Unfenced output is first passed through ``balance_braces``. If that still
    does not parse, it goes through ``_get_valid_json_from_string``. Fenced output
    goes straight to ``_get_valid_json_from_string``.

    Args:
        text: raw model output.

    Returns:
        The parsed JSON object or array.

    Raises:
        ValueError: if the repaired text is still not JSON (``json.JSONDecodeError``),
            if the repair helper detects mismatched brackets, or if the parsed value
            is not a dict or list. Other errors from the helpers propagate unchanged.
    """
    candidate, was_fenced = _extract_fenced_json(text)
    if was_fenced:
        parsed = json.loads(_get_valid_json_from_string(candidate))
    else:
        try:
            parsed = json.loads(balance_braces(candidate))
        except json.JSONDecodeError:
            parsed = json.loads(_get_valid_json_from_string(candidate))
    if not isinstance(parsed, (dict, list)):
        raise ValueError(
            f"repaired output is a {type(parsed).__name__}, not a JSON object or array"
        )
    return parsed


def get_json_from_model_output(input_generated_json: str) -> Tuple[Any, int]:
    """
    Parse a language-model response, possibly wrapped in Markdown code fences, into JSON.

    Two attempts are made:

    1. Parse the fenced content (```json ... ``` or ``` ... ```), or the raw string
       when there is no fence, directly with ``json.loads``.
    2. Run only if attempt 1 fails: the text did not parse, parsed to an empty value,
       parsed to something other than a dict/list, or parsed to a container with a
       non-dict child (``check_contents_valid`` is truthy). The text is then repaired
       with ``balance_braces`` / ``_get_valid_json_from_string`` and parsed again.
       This only helps when the model output was simply cut off.

    Args:
        input_generated_json: raw model output.

    Returns:
        A tuple ``(generated_json, originally_invalid_json_count)``:

        - ``generated_json`` is the parsed dict/list, or ``{}`` if nothing usable
          could be recovered.
        - ``originally_invalid_json_count`` is 1 when both attempts produced bad
          contents, otherwise 0. For attempt 1 that means a non-dict/list or a
          container with a non-dict child. For attempt 2 it means a dict with a
          non-dict child.
    """
    # Attempt 1: parse the (fence-stripped) output as-is.
    try:
        candidate, _ = _extract_fenced_json(input_generated_json)
        parsed_attempt_1 = json.loads(candidate)
    except Exception as exc:
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        parsed_attempt_1 = {}

    # Scalars (numbers, strings, null, booleans) are never acceptable output. They
    # must not reach check_contents_valid, which expects a list or dict.
    attempt_1_has_non_dict = not isinstance(parsed_attempt_1, (dict, list)) or bool(
        check_contents_valid(parsed_attempt_1)
    )
    if parsed_attempt_1 and not attempt_1_has_non_dict:
        return parsed_attempt_1, 0

    # Attempt 2: repair truncated output and parse again.
    logging.debug(
        "Attempting to make output valid to obtain better metrics (this works in limited cases where "
        "the model output was simply cut off)"
    )
    try:
        parsed_attempt_2 = _repair_and_parse(input_generated_json)
        logging.debug(
            "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
        )
    except Exception as exc:
        logging.debug(
            f"Failed. Setting model output as empty JSON to enable metrics comparison. Exc: {exc}"
        )
        parsed_attempt_2 = {}

    attempt_2_has_non_dict = isinstance(parsed_attempt_2, dict) and bool(
        check_contents_valid(parsed_attempt_2)
    )
    originally_invalid_json_count = 0
    if attempt_1_has_non_dict and attempt_2_has_non_dict:
        logging.debug("Could not recover model output json, aborting!")
        originally_invalid_json_count = 1
    return parsed_attempt_2, originally_invalid_json_count
| |
|
| |
|
def check_contents_valid(generated_json_attempt_1: Union[list, dict]) -> bool:
    """
    Report whether parsed model output contains a child that is not a dict.

    Despite the name, the result is truthy when the contents are INVALID.
    Which children are checked depends on the input:

    * list: every element;
    * dict with a ``"nodes"`` key: every element of ``data["nodes"]`` (a
      ``"nodes"`` value that is not a list counts as invalid);
    * any other dict: every value.

    Any other input (number, string, ``None``, ...) is itself not a dict and
    counts as invalid instead of raising.

    Args:
        generated_json_attempt_1: parsed JSON to check.

    Returns:
        True if some checked child (or the input itself) is not a dict, otherwise
        False. A plain bool is returned rather than the offending child, so falsy
        offenders such as ``0``, ``""``, ``None`` or ``[]`` are reported too.
    """
    data = generated_json_attempt_1
    if isinstance(data, list):
        children = data
    elif isinstance(data, dict):
        if "nodes" in data:
            children = data["nodes"]
            if not isinstance(children, list):
                return True
        else:
            children = data.values()
    else:
        return True
    return any(not isinstance(child, dict) for child in children)
| |
|
| |
|
def _get_valid_json_from_string(s):
    """
    Close whatever a truncated JSON string left open so it can hopefully be parsed.

    The input is scanned once. The scan tracks JSON string state, honouring
    backslash escapes, and keeps a stack of open ``{`` / ``[``. Brackets inside
    string literals are ignored. Then:

    * a trailing ``,`` outside a string is removed;
    * an unterminated string is closed with ``"``; a dangling backslash at the
      very end is dropped first so the quote is not escaped;
    * every still-open bracket is closed in reverse order.

    If that result still fails ``json.loads`` and a string had to be closed, the
    unterminated string is assumed to be an object key and completed as
    ``"key": ""``.

    Single quotes are not string delimiters, because JSON has no single-quoted
    strings. Apostrophes inside values are ordinary characters.

    Args:
        s: possibly truncated JSON text.

    Returns:
        The repaired text. It is not guaranteed to be valid JSON.

    Raises:
        ValueError: if a closing bracket does not match the most recently opened
            one (e.g. ``{"a": [1}``).
    """
    matching_opener = {"}": "{", "]": "["}
    in_string = False
    escaped = False
    brackets = []  # stack of (position, opening character)

    for i, c in enumerate(s):
        if in_string:
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
        elif c == '"':
            in_string = True
        elif c in "{[":
            brackets.append((i, c))
        elif c in "}]" and brackets:
            # Unmatched closers (empty stack) are left alone; json.loads rejects them.
            opened_at, opener = brackets.pop()
            if opener != matching_opener[c]:
                raise ValueError(
                    f"Mismatched brackets/quotes found: opened {opener} @ {opened_at} "
                    f"but closed {c} @ {i}"
                )

    if in_string:
        if escaped:
            # A trailing backslash would escape the closing quote we are about to add.
            s = s[:-1]
    elif s.strip().endswith(","):
        logging.debug("Removing ending ,")
        s = s.strip().rstrip(",")

    closing_brackets = "".join(
        "}" if opener == "{" else "]" for _, opener in reversed(brackets)
    )
    closing_chars = ('"' if in_string else "") + closing_brackets
    logging.debug(f"closing_chars: {closing_chars}")
    output_string = s + closing_chars

    try:
        json.loads(output_string)
    except Exception:
        logging.debug(
            "JSON string still fails to be parseable, attempting another modification..."
        )
        if in_string:
            # The unterminated string was probably an object key: give it an empty value.
            new_closing_chars = '": ""' + closing_brackets
            logging.debug(f"new closing_chars: {new_closing_chars}")
            output_string = s + new_closing_chars

    return output_string
| |
|
| |
|
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Recover from a failed outcome by substituting a fallback.

    An outcome counts as failed when it is a dict containing an ``"error"`` key.
    That is the shape ``attempt`` returns when the wrapped call raises. A
    successful result that happens to be a dict with an ``"error"`` key is
    therefore indistinguishable from a failure.

    Args:
        outcome: the value to inspect, typically the return value of ``attempt``.
        fallback: used only when ``outcome`` failed. If callable, it is called with
            the failed outcome and its return value is used. Otherwise it is
            returned as-is. Defaults to ``None``.

    Returns:
        ``outcome`` unchanged if it did not fail, otherwise the fallback result.
    """
    if not (isinstance(outcome, dict) and "error" in outcome):
        return outcome
    if callable(fallback):
        return fallback(outcome)
    return fallback
| |
|
| |
|
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Call ``func(*args, **kwargs)`` and turn any raised exception into a value.

    Args:
        func: the callable to invoke.
        args: positional arguments to pass. Defaults to none.
        kwargs: keyword arguments to pass. ``None`` (or any falsy value) means none.

    Returns:
        Whatever ``func`` returns, or ``{"error": str(exc)}`` if it raised an
        ``Exception`` subclass.
    """
    call_kwargs = kwargs if kwargs else {}
    try:
        result = func(*args, **call_kwargs)
    except Exception as error:
        return {"error": str(error)}
    return result
| |
|
| |
|
def balance_braces(s: str) -> str:
    """
    Naively even out ``{`` / ``}`` counts to try to recover a model string.

    Every brace in the text is counted, including any inside string literals.
    Missing ``}`` are appended at the end; missing ``{`` are prepended at the start.

    Args:
        s: string to balance braces on.

    Returns:
        ``s`` with the braces needed to make the counts equal, or ``s`` unchanged
        if they already are.
    """
    surplus_open = s.count("{") - s.count("}")
    if surplus_open > 0:
        return s + "}" * surplus_open
    if surplus_open < 0:
        return "{" * -surplus_open + s
    return s
| |
|
| |
|
def flatten_list(coll):
    """
    Concatenate the items of every iterable in ``coll`` into one new list.

    The result is built in a single pass, so the cost is linear in the total
    number of items.

    Args:
        coll: an iterable of iterables (e.g. a list of sets or lists).

    Returns:
        A new list containing the items of each inner iterable, in iteration order.
    """
    return [item for group in coll for item in group]
| |
|
| |
|
def keep_errors(collection):
    """
    Keep only the outcomes that represent an error.

    An outcome is an error when it is a dict with an ``"error"`` key. That is the
    shape ``attempt`` returns on failure, and the same check ``on_fail`` uses.
    Other values are dropped: ``None``, numbers and booleans no longer raise
    ``TypeError``, and strings or lists that merely contain "error" no longer match.

    Args:
        collection (Collection): collection of outcomes to filter.

    Returns:
        A list of the error outcomes, in their original order.
    """
    return [
        outcome
        for outcome in collection
        if isinstance(outcome, dict) and "error" in outcome
    ]
| |
|