feat(app): more examples, better parsing and error handling
Browse files- app.py +34 -14
- models.py +41 -0
- parsing.py +260 -0
- schema_to_sql.py +38 -0
- shared.py +308 -0
- utils.py +225 -4
app.py
CHANGED
|
@@ -1,21 +1,24 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import zipfile
|
| 4 |
|
| 5 |
-
import torch
|
| 6 |
import gradio as gr
|
| 7 |
import spaces
|
|
|
|
| 8 |
from peft import PeftConfig, PeftModel
|
| 9 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
|
| 11 |
from schema_to_sql import dd_to_sql
|
| 12 |
from utils import (
|
|
|
|
|
|
|
| 13 |
create_summary_tables,
|
| 14 |
-
|
|
|
|
| 15 |
get_prompt_with_files_uploaded,
|
| 16 |
)
|
| 17 |
-
|
| 18 |
-
from utils import MAX_NEW_TOKENS, TEMPERATURE
|
| 19 |
|
| 20 |
LOCAL_DIR = "tsvs"
|
| 21 |
ZIP_PATH = "tsvs.zip"
|
|
@@ -98,6 +101,11 @@ def run_llm_inference(model_prompt):
|
|
| 98 |
output_data_model = tokenizer.decode(outputs[0][prompt_length:])
|
| 99 |
output_data_model = output_data_model.split("<|eot_id|>")[0]
|
| 100 |
print(output_data_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# Test output for JSON schema validity
|
| 102 |
try:
|
| 103 |
json.loads(output_data_model)
|
|
@@ -141,8 +149,16 @@ def gen_output_from_files_uploaded(filepaths: list[str] = None):
|
|
| 141 |
return model_response, sql, nodes_df, properties_df
|
| 142 |
|
| 143 |
|
| 144 |
-
def
|
| 145 |
-
model_response =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
model_response_json = json.loads(model_response)
|
| 147 |
sql, validation = dd_to_sql(model_response_json)
|
| 148 |
|
|
@@ -167,14 +183,12 @@ with gr.Blocks() as demo:
|
|
| 167 |
|
| 168 |
gr.Markdown("## (Optional) Get Sample TSV(s) to Upload")
|
| 169 |
|
| 170 |
-
gr.Markdown("### Example 1: A
|
| 171 |
download_btn = gr.DownloadButton(
|
| 172 |
-
label="Download
|
| 173 |
)
|
| 174 |
gr.Markdown("### Example 2: Many TSVs in a single .zip file.")
|
| 175 |
-
download_btn = gr.DownloadButton(
|
| 176 |
-
label="Download All Sample TSVs as .zip", value=zip_tsvs
|
| 177 |
-
)
|
| 178 |
gr.Markdown("You need to extract the .zip if you want to use them.")
|
| 179 |
|
| 180 |
gr.Markdown("## Upload TSVs With Headers (No Data Rows Required)")
|
|
@@ -222,10 +236,16 @@ with gr.Blocks() as demo:
|
|
| 222 |
outputs=[json_out, sql_out],
|
| 223 |
)
|
| 224 |
|
| 225 |
-
gr.Markdown("Run out of FreeGPU or having issues? Try the example
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
demo_btn.click(
|
| 228 |
-
fn=
|
| 229 |
outputs=[json_out, sql_out],
|
| 230 |
)
|
| 231 |
|
|
|
|
| 1 |
+
import copy
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
import zipfile
|
| 5 |
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import spaces
|
| 8 |
+
import torch
|
| 9 |
from peft import PeftConfig, PeftModel
|
| 10 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
|
| 12 |
from schema_to_sql import dd_to_sql
|
| 13 |
from utils import (
|
| 14 |
+
MAX_NEW_TOKENS,
|
| 15 |
+
TEMPERATURE,
|
| 16 |
create_summary_tables,
|
| 17 |
+
get_example_ai_model_output_many,
|
| 18 |
+
get_example_ai_model_output_simple,
|
| 19 |
get_prompt_with_files_uploaded,
|
| 20 |
)
|
| 21 |
+
from parsing import try_parsing_actual_model_output
|
|
|
|
| 22 |
|
| 23 |
LOCAL_DIR = "tsvs"
|
| 24 |
ZIP_PATH = "tsvs.zip"
|
|
|
|
| 101 |
output_data_model = tokenizer.decode(outputs[0][prompt_length:])
|
| 102 |
output_data_model = output_data_model.split("<|eot_id|>")[0]
|
| 103 |
print(output_data_model)
|
| 104 |
+
|
| 105 |
+
parsed_output_data_model = try_parsing_actual_model_output(output_data_model)
|
| 106 |
+
if "error" not in parsed_output_data_model:
|
| 107 |
+
output_data_model = copy.deepcopy(parsed_output_data_model)
|
| 108 |
+
|
| 109 |
# Test output for JSON schema validity
|
| 110 |
try:
|
| 111 |
json.loads(output_data_model)
|
|
|
|
| 149 |
return model_response, sql, nodes_df, properties_df
|
| 150 |
|
| 151 |
|
| 152 |
+
def gen_output_from_example_simple():
|
| 153 |
+
model_response = get_example_ai_model_output_simple()
|
| 154 |
+
model_response_json = json.loads(model_response)
|
| 155 |
+
sql, validation = dd_to_sql(model_response_json)
|
| 156 |
+
|
| 157 |
+
return model_response, sql
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def gen_output_from_example_many():
|
| 161 |
+
model_response = get_example_ai_model_output_many()
|
| 162 |
model_response_json = json.loads(model_response)
|
| 163 |
sql, validation = dd_to_sql(model_response_json)
|
| 164 |
|
|
|
|
| 183 |
|
| 184 |
gr.Markdown("## (Optional) Get Sample TSV(s) to Upload")
|
| 185 |
|
| 186 |
+
gr.Markdown("### Example 1: A Simple TSV")
|
| 187 |
download_btn = gr.DownloadButton(
|
| 188 |
+
label="Download Simple TSV", value="sample_metadata.tsv"
|
| 189 |
)
|
| 190 |
gr.Markdown("### Example 2: Many TSVs in a single .zip file.")
|
| 191 |
+
download_btn = gr.DownloadButton(label="Download Many TSVs as .zip", value=zip_tsvs)
|
|
|
|
|
|
|
| 192 |
gr.Markdown("You need to extract the .zip if you want to use them.")
|
| 193 |
|
| 194 |
gr.Markdown("## Upload TSVs With Headers (No Data Rows Required)")
|
|
|
|
| 236 |
outputs=[json_out, sql_out],
|
| 237 |
)
|
| 238 |
|
| 239 |
+
gr.Markdown("Run out of FreeGPU or having issues? Try the example outputs!")
|
| 240 |
+
demo_btn2 = gr.Button("Manually Load 'Simple' Example Output from Previous Run")
|
| 241 |
+
demo_btn2.click(
|
| 242 |
+
fn=gen_output_from_example_simple,
|
| 243 |
+
outputs=[json_out, sql_out],
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
demo_btn = gr.Button("Manually Load 'Many' Example Output from Previous Run")
|
| 247 |
demo_btn.click(
|
| 248 |
+
fn=gen_output_from_example_many,
|
| 249 |
outputs=[json_out, sql_out],
|
| 250 |
)
|
| 251 |
|
models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file should contain all relevant data-only information, it should contain minimal to no behavior.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field, model_validator
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Property(BaseModel):
|
| 11 |
+
name: Optional[str] = None
|
| 12 |
+
type: Optional[str] = None
|
| 13 |
+
description: Optional[str] = None
|
| 14 |
+
model_config = {"extra": "ignore"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def upsert_dict(mapping, kvp):
|
| 18 |
+
k, v = kvp
|
| 19 |
+
mapping[k] = v
|
| 20 |
+
return mapping
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DataModel(BaseModel):
|
| 24 |
+
name: Optional[str] = None
|
| 25 |
+
description: Optional[str] = None
|
| 26 |
+
links: Optional[List[str]] = Field(default={})
|
| 27 |
+
properties: Optional[Dict[str, Property]] = Field(default={})
|
| 28 |
+
model_config = {"extra": "ignore"}
|
| 29 |
+
|
| 30 |
+
@model_validator(mode="before")
|
| 31 |
+
def _normalize_properties(cls, node_content):
|
| 32 |
+
props = node_content.get("properties", {})
|
| 33 |
+
if not isinstance(props, list):
|
| 34 |
+
return node_content
|
| 35 |
+
props_to_use = [prop for prop in props if "name" in prop]
|
| 36 |
+
prop_items = [Property(**prop) for prop in props_to_use]
|
| 37 |
+
prop_dict = {
|
| 38 |
+
prop.name: Property(type=prop.type, description=prop.description)
|
| 39 |
+
for prop in prop_items
|
| 40 |
+
}
|
| 41 |
+
return upsert_dict(node_content, ("properties", prop_dict))
|
parsing.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file handles reshaping raw results data into a list of nodes that match
|
| 3 |
+
what we expect the model output to be. That is to say, it handles parsing any raw data.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import reduce
|
| 9 |
+
from typing import Any, Dict, Hashable, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
from models import DataModel
|
| 12 |
+
from shared import (
|
| 13 |
+
TOP_LEVEL_IDENTIFIERS,
|
| 14 |
+
attempt,
|
| 15 |
+
get_json_from_model_output,
|
| 16 |
+
keep_errors,
|
| 17 |
+
on_fail,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def handle_parsing_schema_files(expected_location: str, actual_location: str):
    """Read and parse the reference schema and the AI-generated schema files.

    Args:
        expected_location: Path to the hand-written reference JSON file.
        actual_location: Path to the raw AI-generated model output file.

    Returns:
        Tuple of (generated nodes mapping, reference nodes mapping).

    Raises:
        ValueError: If either file cannot be read, or parsing produced errors.
    """
    raw_reference, raw_generated = read_in_expected_and_actual_json(
        expected_location, actual_location
    )
    # Either read may come back as an {"error": ...} dict instead of text.
    read_errors = keep_errors((raw_reference, raw_generated))
    if len(read_errors) > 0:
        raise ValueError(f"Could not ingest raw data: {read_errors}")
    # Generated output may be malformed, so it goes through the recovery-aware
    # parser; the reference is expected to be valid JSON already.
    generated_nodes_to_content = try_parsing_actual_model_output(raw_generated)
    reference_nodes_to_content = derive_nodes_from_actual_json_output(
        parse_json(raw_reference)
    )
    errors = keep_errors((reference_nodes_to_content, generated_nodes_to_content))
    if len(errors) > 0:
        raise ValueError(f"Error parsing files: {errors}")
    return generated_nodes_to_content, reference_nodes_to_content
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def read_in_expected_and_actual_json(
|
| 39 |
+
expected_json_location: str, actual_json_location: str
|
| 40 |
+
):
|
| 41 |
+
unparsed_expected, err = read_json_as_text(expected_json_location)
|
| 42 |
+
if err:
|
| 43 |
+
return {"error": f"Could not read in expected file. e: {err}"}, None
|
| 44 |
+
unparsed_actual, err = read_json_as_text(actual_json_location)
|
| 45 |
+
if err:
|
| 46 |
+
return None, {"error": str(err)}
|
| 47 |
+
return unparsed_expected, unparsed_actual
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def read_json_as_text(file_path: str) -> Tuple[Optional[str], Optional[Exception]]:
|
| 51 |
+
"""
|
| 52 |
+
Reads the contents of a file as text without attempting to parse it as JSON.
|
| 53 |
+
"""
|
| 54 |
+
try:
|
| 55 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
| 56 |
+
return file.read(), None
|
| 57 |
+
except Exception as e:
|
| 58 |
+
return None, e
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def try_parsing_actual_model_output(model_output: str):
    """Parse raw model output into a node-name -> node-content mapping.

    Tries a strict JSON parse first, then falls back to the recovery parser
    (which repairs fenced or truncated output). Returns an {"error": ...}
    dict on any unrecoverable failure instead of raising.
    """
    first_parse_json = parse_json(model_output)
    # A top-level JSON list has no node names to key on, so reject it.
    if isinstance(first_parse_json, list):
        return {
            "error": "Could not parse json file. Model output should not be a list."
        }
    first_pass_failed = "error" in first_parse_json
    # Only attempt the expensive recovery parse when strict parsing failed.
    recovered_json, errors = (
        get_json_from_model_output(model_output) if first_pass_failed else ({}, 0)
    )
    if errors > 0:
        return {"error": "Could not parse json file, no metrics to calculate"}
    parsed_json = recovered_json if first_pass_failed else first_parse_json
    # Node derivation may itself raise; `attempt` boxes the exception and
    # `on_fail` converts it to an empty (falsy) fallback.
    node_derivation_outcome = on_fail(
        attempt(derive_nodes_from_actual_json_output, (parsed_json,)), []
    )
    if not node_derivation_outcome:
        return {"error": f"Could not derive nodes. Parsed json: {parsed_json}"}
    return node_derivation_outcome
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def find_all_nodes(name_and_contents) -> Dict:
    """Recursively collect node definitions from a (name, contents) pair.

    A dict counts as a node when it carries any of the TOP_LEVEL_IDENTIFIERS
    keys (description / links / properties). Otherwise its dict-valued
    children are searched recursively and the results merged.
    """
    name, contents = name_and_contents
    content_contains_nodes = bool(set(contents.keys()) & TOP_LEVEL_IDENTIFIERS)
    if content_contains_nodes:
        # This dict is itself a node: return it keyed by its name.
        return dict([name_and_contents])
    # Not a node: descend into nested dicts only (lists/scalars are skipped).
    sub_dicts = list(filter(lambda kvp: isinstance(kvp[1], dict), contents.items()))
    all_sub_nodes = {}
    for sub_name_and_contents in sub_dicts:
        sub_nodes = find_all_nodes(sub_name_and_contents)
        all_sub_nodes.update(sub_nodes)
    return all_sub_nodes
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def assign_to_key(key: Hashable):
    """Build a reducer that files each dict into an accumulator under dict[key].

    The returned function takes (accumulator, dict) and stores the dict in
    the accumulator keyed by the dict's value at *key*; later entries with
    the same identifier overwrite earlier ones.
    """

    def add_at_key(
        assignment_mapping: Dict[Hashable, Any], mapping_to_add: Dict[Hashable, Any]
    ):
        assignment_mapping[mapping_to_add[key]] = mapping_to_add
        return assignment_mapping

    return add_at_key
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def key_exists(key):
    """Return a predicate that tests whether *key* is present in a mapping.

    Replaces a name-bound lambda (PEP 8 E731) with an equivalent def.
    """
    return lambda mapping: key in mapping
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def handle_property_correction(all_nodes):
    """Return *all_nodes* with any list-shaped "properties" converted to dicts.

    Only nodes whose "properties" value is a list are rewritten; all other
    nodes pass through unchanged.
    """
    has_incorrect_property_shape = lambda node: (
        "properties" in node[1] and isinstance(node[1].get("properties", {}), list)
    )
    nodes_that_need_corrected = dict(
        filter(has_incorrect_property_shape, all_nodes.items())
    )
    nodes_with_corrected_properties = dict(
        map(correct_properties_for_node, nodes_that_need_corrected.items())
    )
    # Corrected entries overwrite their originals; untouched nodes remain.
    all_corrected_nodes = {**all_nodes, **nodes_with_corrected_properties}
    return all_corrected_nodes
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def correct_properties_for_node(node):
    """Maps node's property names to their actual content.

    Takes a (name, data) pair whose "properties" is a *list* of property
    dicts and rewrites it as a {property_name: property_dict} mapping.
    List entries without a "name" key are dropped.

    NOTE: mutates node_data["properties"] in place and returns the same pair.
    """
    node_name, node_data = node
    properties = node_data["properties"]
    identified_properties = list(filter(key_exists("name"), properties))
    actual_properties = reduce(assign_to_key("name"), identified_properties, {})
    node_data["properties"] = actual_properties
    return node
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def derive_nodes_from_actual_json_output(json_data: Union[dict, list]):
    """
    Find nodes from non-deterministic AI output

    Returns a {node_name: node_content} mapping. A top-level list yields {}.
    Property-shape correction is only applied when the payload uses an
    explicit top-level "nodes" key.
    """
    if isinstance(json_data, list):
        return {}
    all_nodes = flatten_all_nodes(json_data)
    if json_data.get("nodes", None) is None:
        return all_nodes
    nodes_with_properties_corrected = handle_property_correction(all_nodes)
    return nodes_with_properties_corrected
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def flatten_all_nodes(json_data) -> Dict[Hashable, Any]:
    """
    Model output could have nested nodes, this extracts them.

    Supports two payload shapes: a top-level "nodes" list (entries keyed by
    their "name" field), or a plain mapping whose dict values are searched
    recursively via find_all_nodes.
    """
    nodes = json_data.get("nodes", None)
    if nodes is None:
        # No "nodes" list: treat each dict-valued top-level entry as a
        # potential (possibly nested) node container.
        sub_nodes_list = [
            find_all_nodes((name, contents))
            for name, contents in json_data.items()
            if isinstance(contents, dict)
        ]
    else:
        # Explicit "nodes" list: entries must be dicts with a usable name.
        sub_nodes_list = [
            find_all_nodes((node["name"], node))
            for node in nodes
            if isinstance(node, dict) and node.get("name") is not None
        ]
    # Merge per-entry mappings; later duplicates overwrite earlier ones.
    all_nodes = {k: v for sub_nodes in sub_nodes_list for k, v in sub_nodes.items()}
    return all_nodes
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def aggregate_desc(acc, node):
    """Reducer: group node names by their description string.

    Args:
        acc: Mapping of description -> set of node names (e.g. defaultdict(set)).
        node: (node_name, node_info) pair.

    Returns:
        The (mutated) accumulator.
    """
    node_name, node_info = node
    desc = node_info.get("description", None)
    # isinstance(desc, str) already rejects None; the original's separate
    # `is not None` check was redundant.
    if isinstance(desc, str):
        acc[desc].add(node_name)
    return acc
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def reform_links(outer_acc, node):
    """Reducer: invert node -> links into a link -> {node names} mapping.

    Args:
        outer_acc: Accumulator mapping link -> set of node names.
        node: (node_name, node_info) pair; node_info may carry a "links" list.

    Returns:
        The accumulator with this node's links folded in.
    """
    node_name, node_info = node
    links = node_info.get("links", [])
    # For each link, record that this node references it.
    collect_links_to_aggregator = lambda inner_acc, link: upsert_set(
        inner_acc, (link, node_name)
    )
    links_to_node_names = reduce(collect_links_to_aggregator, links, outer_acc)
    return links_to_node_names
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def lens(key, default=None):
    """Simple way to interface with the contents of a dict.

    Returns a getter that extracts *key* from a mapping, falling back to
    *default* when the key is absent.
    """

    def getter(mapping):
        return mapping.get(key, default)

    return getter
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def aggregate_properties(outer_acc, node):
    """Reducer: invert node -> properties into property_name -> {node names}.

    Handles both list-shaped properties (extracting each entry's "name" via
    `lens`) and dict-shaped properties (using the keys directly).

    Args:
        outer_acc: Accumulator mapping property name -> set of node names.
        node: (node_name, node_info) pair.

    Returns:
        The accumulator with this node's property names folded in.
    """
    node_name, node_info = node
    properties = node_info.get("properties", {})
    is_list = isinstance(properties, list)
    property_names = (
        list(map(lens("name"), properties)) if is_list else list(properties.keys())
    )
    # Local renamed from `aggregate_properties`, which shadowed this function.
    aggregated = reduce(
        lambda inner_acc, prop_name: upsert_set(inner_acc, (prop_name, node_name)),
        property_names,
        outer_acc,
    )
    return aggregated
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def conform_node_to_expected_schema(name_to_data_model):
    """Validate a (name, raw node dict) pair against the DataModel schema.

    Returns:
        ((name, dumped_model_or_empty_dict), errors) where errors is the
        {"error": ...} dict produced on validation failure, or {} on success.
    """
    name, dm = name_to_data_model
    conform_result = attempt(DataModel.model_validate, (dm,))
    # `attempt` returns either a DataModel instance or an {"error": ...}
    # dict. Check the type explicitly instead of `"error" in conform_result`,
    # which on a pydantic model silently iterates its field pairs.
    failed = isinstance(conform_result, dict) and "error" in conform_result
    model_outcome = (
        {}
        if failed
        else conform_result.model_dump(exclude_none=True, exclude_unset=True)
    )
    errors = conform_result if failed else {}
    return (name, model_outcome), errors
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def aggregate_parsed_file(nodes: dict):
    """Conform every node to the DataModel schema and build summary indexes.

    Args:
        nodes: Mapping of node name -> raw node content.

    Returns:
        Dict with "node_names" (the validated nodes) plus inverted "links",
        "properties", and "description" indexes (each value a set of node
        names). Nodes that fail schema validation are dropped.
    """
    conformed_nodes_result = [
        conform_node_to_expected_schema(node) for node in nodes.items()
    ]
    # Keep only nodes that validated cleanly (empty errors dict is falsy).
    conformed_nodes = [node for node, errors in conformed_nodes_result if not errors]
    aggregated_links = reduce(reform_links, conformed_nodes, defaultdict(set))
    aggregated_properties = reduce(
        aggregate_properties, conformed_nodes, defaultdict(set)
    )
    aggregated_descriptions = reduce(aggregate_desc, conformed_nodes, defaultdict(set))

    parsed_results = {
        "node_names": dict(conformed_nodes),
        "links": aggregated_links,
        "properties": aggregated_properties,
        "description": aggregated_descriptions,
    }
    return parsed_results
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def upsert_set(accumulator, kvp):
    """Add *value* to the set stored under *key*; unhashable keys are ignored.

    The accumulator is expected to yield sets on lookup (e.g. defaultdict(set)).
    """
    key, value = kvp
    if not isinstance(key, Hashable):
        # Malformed output can produce list-valued keys; skip them silently.
        return accumulator
    accumulator[key].add(value)
    return accumulator
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def parse_json(json_string: str) -> Optional[Union[dict, list]]:
    """
    Safely parses a JSON string into a Python dictionary or list.

    Args:
        json_string (str): The JSON string to parse

    Returns:
        dict/list: Parsed JSON data if successful
        dict["error"]: If parsing fails, provides error as string in response
    """
    try:
        return json.loads(json_string)
    except (json.JSONDecodeError, TypeError) as e:
        # JSONDecodeError covers malformed text; TypeError covers non-string
        # input (e.g. None). Both collapse to the same error envelope, so the
        # original's two identical handlers are merged.
        return {"error": str(e)}
|
schema_to_sql.py
CHANGED
|
@@ -99,6 +99,43 @@ def get_foreign_table_and_field(prop_name, node_name):
|
|
| 99 |
return None, None
|
| 100 |
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def generate_create_table(node, table_lookup):
|
| 103 |
"""
|
| 104 |
Returns SQL for the given AI generated node
|
|
@@ -189,6 +226,7 @@ def dd_to_sql(dd):
|
|
| 189 |
sql (str): SQL
|
| 190 |
validation (str): Validation result
|
| 191 |
"""
|
|
|
|
| 192 |
# Build a lookup for table columns in all nodes
|
| 193 |
table_lookup = {}
|
| 194 |
for node in dd["nodes"]:
|
|
|
|
| 99 |
return None, None
|
| 100 |
|
| 101 |
|
| 102 |
+
def transform_dd(dd):
    """
    Returns transformed DD

    This function takes AI generated DD and ensures all required fields are
    present in properties and properties are dictionaries.

    Parameters:
        dd (dict): AI generated DD

    Returns:
        dd (dict): Transformed DD (mutated in place and returned)
    """
    for node in dd.get("nodes", []):
        props = node.get("properties", [])
        if props and all(isinstance(x, dict) for x in props):
            # Use .get(): AI output may omit "name" on a property dict, and
            # indexing would raise KeyError. Nameless entries are skipped.
            prop_names = {p.get("name") for p in props if p.get("name")}
        elif props and all(isinstance(x, str) for x in props):
            prop_names = set(props)
            # Upgrade plain name strings to the full property-dict shape
            props = [
                {"name": prop, "description": "", "type": "string"} for prop in props
            ]
        else:
            # Empty or mixed-type lists are discarded entirely
            props = []
            prop_names = set()

        # Ensure each required field is present in properties
        for req in node.get("required", []):
            if req not in prop_names:
                props.append({"name": req, "description": "", "type": "string"})
                prop_names.add(req)

        node["properties"] = props
    return dd
|
| 137 |
+
|
| 138 |
+
|
| 139 |
def generate_create_table(node, table_lookup):
|
| 140 |
"""
|
| 141 |
Returns SQL for the given AI generated node
|
|
|
|
| 226 |
sql (str): SQL
|
| 227 |
validation (str): Validation result
|
| 228 |
"""
|
| 229 |
+
dd = transform_dd(dd)
|
| 230 |
# Build a lookup for table columns in all nodes
|
| 231 |
table_lookup = {}
|
| 232 |
for node in dd["nodes"]:
|
shared.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions used in several different places. This file should not import from any other non-lib files to prevent
|
| 3 |
+
circular dependencies.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
from copy import copy
|
| 9 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
TOP_LEVEL_IDENTIFIERS = {"description", "links", "properties"}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_json_from_model_output(input_generated_json: str):
    """
    Parses a string, potentially containing Markdown code fences, into a JSON object.

    This function attempts to extract and parse a JSON object from a string,
    often the output of a language model. It handles cases where the JSON
    is enclosed in Markdown code fences (```json ... ``` or ``` ... ```).
    If the initial parsing fails, it attempts a more robust parsing using
    `_get_valid_json_from_string` and
    logs debug messages indicating success or failure. If all attempts fail,
    it returns an empty dictionary.

    Args:
        input_generated_json: A string potentially containing a JSON object.

    Returns:
        A tuple containing:
        - The parsed JSON object (a dictionary) or an empty dictionary if parsing failed.
        - An integer representing the number of times parsing failed initially.
    """
    originally_invalid_json_count = 0

    # ----- Attempt 1: strip code fences and parse strictly -----
    generated_json_attempt_1 = copy(input_generated_json)
    try:
        code_split = generated_json_attempt_1.split("```")
        if len(code_split) > 1:
            # Fenced: re-attach the opening fence to the first fenced chunk,
            # then strip the "```json" marker before parsing.
            generated_json_attempt_1 = json.loads(
                ("```" + code_split[1]).replace("```json", "")
            )
        else:
            # Unfenced (or markers without content): strip markers and parse.
            generated_json_attempt_1 = json.loads(
                generated_json_attempt_1.replace("```json", "").replace("```", "")
            )
    except Exception as exc:
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        # originally_invalid_json_count += 1
        generated_json_attempt_1 = {}
    # Truthy when some sub-item is not a dict (the offending item itself).
    some_value_in_attempt_1_is_not_a_dict = check_contents_valid(
        generated_json_attempt_1
    )
    attempt_1_failed = (
        not bool(generated_json_attempt_1) or some_value_in_attempt_1_is_not_a_dict
    )
    # ----- Attempt 2: repair truncated output, only if attempt 1 failed -----
    generated_json_attempt_2 = copy(input_generated_json) if attempt_1_failed else {}
    if attempt_1_failed:
        logging.debug(
            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
            "the model output was simply cut off)"
        )
        try:
            code_split = generated_json_attempt_2.split("```")
            if len(code_split) > 1:
                generated_json_attempt_2 = json.loads(
                    _get_valid_json_from_string(
                        ("```" + code_split[1]).replace("```json", "")
                    )
                )
            else:
                stripped_output = generated_json_attempt_2.replace(
                    "```json", ""
                ).replace("```", "")
                # Try the cheap brace-balancing repair first; fall back to
                # the fuller quote/bracket-closing repair if it fails.
                balance_outcome = attempt(
                    json.loads, (balance_braces(stripped_output),)
                )
                if "error" not in balance_outcome:
                    generated_json_attempt_2 = balance_outcome
                else:
                    generated_json_attempt_2 = json.loads(
                        _get_valid_json_from_string(stripped_output)
                    )

            logging.debug(
                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
            )
        except Exception as exc:
            logging.debug(
                "Failed. Setting model output as empty JSON to enable metrics comparison."
            )
            generated_json_attempt_2 = {}
    some_value_in_attempt_2_is_not_a_dict = (
        attempt_1_failed
        and isinstance(generated_json_attempt_2, dict)
        and check_contents_valid(generated_json_attempt_2)
    )
    # Both attempts produced structurally invalid content: count the failure.
    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
        logging.debug(f"Could not recover model output json, aborting!")
        originally_invalid_json_count += 1
    generated_json = (
        generated_json_attempt_1 if not attempt_1_failed else generated_json_attempt_2
    )
    return generated_json, originally_invalid_json_count
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def check_contents_valid(generated_json_attempt_1: Union[list, dict]):
    """
    Checks that the sub nodes are not lists or anything

    Scans the relevant collection of sub-items (list elements, the "nodes"
    list, or the dict's values) and returns the first item that is not a
    dict, or None when every item is a dict.

    Args:
        generated_json_attempt_1 (Union[list, dict]): data to check

    Returns:
        truthy based on contents of input
    """
    if isinstance(generated_json_attempt_1, list):
        items = generated_json_attempt_1
    elif (
        isinstance(generated_json_attempt_1, dict)
        and "nodes" in generated_json_attempt_1.keys()
    ):
        items = generated_json_attempt_1.get("nodes", [])
    else:
        items = generated_json_attempt_1.values()
    return next((item for item in items if not isinstance(item, dict)), None)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _get_valid_json_from_string(s):
    """
    Given a JSON string with potentially unclosed strings, arrays, or objects, close those things
    to hopefully be able to parse as valid JSON

    Best-effort repair heuristic; the returned string is not guaranteed to
    parse. NOTE(review): quote tracking does not account for escaped quotes
    (\\") inside strings, and opening brackets are recorded even when they
    occur inside a string literal — confirm these limitations are acceptable
    for the truncated-output cases this targets.
    """
    double_quotes = 0  # 1 while inside a double-quoted string
    single_quotes = 0  # 1 while inside a single-quoted string
    brackets = []  # stack of (index, opening char) still unclosed

    for i, c in enumerate(s):
        if c == '"':
            double_quotes = 1 - double_quotes  # Toggle between 0 and 1
        elif c == "'":
            single_quotes = 1 - single_quotes  # Toggle between 0 and 1
        elif c in "{[":
            brackets.append((i, c))
        elif c in "}]":
            # Only treat closers as structural when outside any string.
            if double_quotes == 0 and single_quotes == 0:
                if brackets:
                    last_opened = brackets.pop()
                    if (c == "}" and last_opened[1] != "{") or (
                        c == "]" and last_opened[1] != "["
                    ):
                        raise ValueError(
                            f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} "
                            f"but closed {c} @ {i}"
                        )
                else:
                    # If no matching opening bracket, it's an error, but we can skip this for the task
                    pass

    # Remove trailing comma if it exists
    if s.strip().endswith(","):
        logging.debug("Removing ending ,")
        s = s.strip().rstrip(",")

    closing_chars = ""

    # Adding closing quotes if there are missing ones
    if double_quotes > 0:
        closing_chars += '"'
    if single_quotes > 0:
        closing_chars += "'"

    # Add closing brackets for any unmatched opening brackets
    while brackets:
        last_opened = brackets.pop()
        if last_opened[1] == "{":
            closing_chars += "}"
        elif last_opened[1] == "[":
            closing_chars += "]"

    logging.debug(f"closing_chars: {closing_chars}")

    output_string = s + closing_chars

    try:
        json.loads(output_string)
    except Exception:
        logging.debug(
            "JSON string still fails to be parseable, attempting another modification..."
        )
        # it's possible the closing quotes were on a property that didn't have a value, let's
        # fix that and see if it works
        new_closing_chars = ""
        # NOTE(review): this flag is never set to True, so *every* '"' in
        # closing_chars is expanded — harmless today since closing_chars
        # contains at most one '"', but confirm the intent.
        found_first_double_quote = False
        for char in closing_chars:
            if not found_first_double_quote and char == '"':
                # for keys in objects with no value, append an empty value
                #
                # For example:
                # ```
                # {
                #     "properties": {
                #         "annotation
                # ```
                new_closing_chars += '": ""'
            else:
                new_closing_chars += char

        logging.debug(f"new closing_chars: {new_closing_chars}")
        output_string = s + new_closing_chars

    return output_string
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Recover from a failed outcome by substituting a fallback.

    An outcome counts as failed when it is a dict containing an "error" key
    (the shape produced by `attempt`). Successful outcomes pass through
    untouched.

    Args:
        outcome: Either a successful result or an {"error": <msg>} dict.
        fallback: A replacement value, or a callable invoked with the failed
            outcome to produce one.

    Returns:
        The original outcome on success; otherwise `fallback(outcome)` when
        fallback is callable, or the fallback value itself.
    """
    # Success path: anything that is not an error dict is returned as-is.
    if not (isinstance(outcome, dict) and "error" in outcome):
        return outcome
    # Failure path: a callable fallback gets the error dict for context.
    if callable(fallback):
        return fallback(outcome)
    return fallback
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Execute a function, converting any raised exception into an error dict.

    Args:
        func (Callable): The function to execute.
        args (Tuple[Any, ...], optional): Positional arguments for the call.
        kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for
            the call.

    Returns:
        Whatever `func` returns on success, or {"error": <msg>} when it
        raises. Pairs with `on_fail` for recovery.
    """
    call_kwargs = kwargs or {}
    try:
        result = func(*args, **call_kwargs)
    except Exception as exc:
        # Deliberately broad: the whole point is to trap any failure and
        # surface it as data instead of an exception.
        return {"error": str(exc)}
    return result
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def balance_braces(s: str) -> str:
    """
    Primitive function that just tries to add '{}' style braces to try to
    recover the model string.

    Missing closers are appended at the end; missing openers are prepended
    at the start. Nesting order is not checked.

    Args:
        s(str): string to balance braces on.

    Returns:
        provided string with balanced braces if possible
    """
    # Positive deficit => unmatched '{', negative => unmatched '}'.
    deficit = s.count("{") - s.count("}")

    if deficit > 0:
        return s + "}" * deficit
    if deficit < 0:
        return "{" * (-deficit) + s
    return s
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def flatten_list(coll):
    """
    Flatten a collection of iterables by one level into a single list.

    Args:
        coll: An iterable of iterables (lists, sets, tuples, strings, ...).

    Returns:
        A list with every element of every inner iterable, in iteration
        order.
    """
    # Single pass instead of repeated `list + list` concatenation, which
    # copies the accumulator every iteration (accidental O(n^2)).
    return [item for inner in coll for item in inner]
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def keep_errors(collection):
    """
    Given a set of outcomes, keeps any that resulted in an error

    Args:
        collection (Collection): collection of outcomes to filter.

    Returns:
        All instances of the collection that contain an error response.
    """
    failed = []
    for outcome in collection:
        # `or []` guards against None outcomes; `in` is a key test for
        # dicts and a membership/substring test for other containers.
        if "error" in (outcome or []):
            failed.append(outcome)
    return failed
|
utils.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
-
import os
|
| 2 |
import csv
|
|
|
|
|
|
|
| 3 |
from string import Template
|
| 4 |
-
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
-
|
| 7 |
from PIL import Image
|
| 8 |
|
| 9 |
MAX_INPUT_TOKEN_LENGTH = 1024
|
|
@@ -108,7 +109,227 @@ def create_summary_tables(json_response):
|
|
| 108 |
return node_descriptions, node_property_descriptions
|
| 109 |
|
| 110 |
|
| 111 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
return """
|
| 113 |
{
|
| 114 |
"nodes": [
|
|
|
|
|
|
|
| 1 |
import csv
|
| 2 |
+
import os
|
| 3 |
+
from io import BytesIO
|
| 4 |
from string import Template
|
| 5 |
+
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
import networkx as nx
|
| 8 |
from PIL import Image
|
| 9 |
|
| 10 |
MAX_INPUT_TOKEN_LENGTH = 1024
|
|
|
|
| 109 |
return node_descriptions, node_property_descriptions
|
| 110 |
|
| 111 |
|
| 112 |
+
def get_example_ai_model_output_simple():
    """Return a canned example of the model's data-dictionary output.

    The returned value is a JSON-formatted string (with leading/trailing
    newlines) describing four nodes — project, dataset, subject, sample —
    each with its "links", "required" fields, and a list of property
    name/description/type entries.

    NOTE(review): presumably consumed the same way as a live model response
    (json.loads, then the summary-table / SQL helpers) — confirm against
    the demo wiring in app.py.

    Returns:
        str: the example model output as a JSON string.
    """
    return """
{
"nodes": [
{
"name": "project",
"description": "Any specifically defined piece of work that is undertaken or attempted to meet a single requirement. (NCIt C47885)",
"links": [],
"required": [
"dbgap_accession_number",
"project.id"
],
"properties": [
{
"name": "awg_review",
"description": "Indicates that the project is an AWG project.",
"type": "boolean"
},
{
"name": "data_citation",
"description": "The citation for the published dataset.",
"type": "string"
},
{
"name": "data_contributor",
"description": "The name of the organization or individual that the contributed dataset belongs to.",
"type": "string"
},
{
"name": "data_description",
"description": "A brief, free-text description of the data files and associated metadata provided for this dataset.",
"type": "string"
},
{
"name": "dbgap_accession_number",
"description": "The dbgap accession number provided for the project.",
"type": "string"
},
{
"name": "in_review",
"description": "Indicates that the project is under review by the submitter. Upload and data modification is disabled.",
"type": "boolean"
},
{
"name": "intended_release_date",
"description": "Tracks a Project's intended release date.",
"type": "string"
},
{
"name": "project.id",
"description": "A unique identifier for records in this 'project' table.",
"type": "string"
},
{
"name": "protocol_number",
"description": "The project's protocol number or similar amount.",
"type": "string"
},
{
"name": "releasable",
"description": "A project can only be released by the user when `releasable` is true.",
"type": "boolean"
},
{
"name": "request_submission",
"description": "Indicates that the user has requested submission to the GDC for harmonization.",
"type": "boolean"
},
{
"name": "research_design",
"description": "A summary of the goals of the research or a general description of the research's relationship to a clinical application.",
"type": "string"
},
{
"name": "research_objective",
"description": "The general objective of the research; what the researchers hope to discover or determine.",
"type": "string"
},
{
"name": "research_setup",
"description": "A high level description of the setup used to achieve the research objectives.",
"type": "string"
}
]
},
{
"name": "dataset",
"description": "A set of metadata and associated data file objects originating from single a research study, clinical trial or patient cohort.",
"links": [
"project"
],
"required": [
"dataset.id",
"project.id"
],
"properties": [
{
"name": "Class_of_Case_Desc",
"description": "The text term used to describe the kind of clinical condition that can be defined based on objective criteria or by including all patient information from the case.",
"type": "string"
},
{
"name": "data_citation",
"description": "The citation for the published dataset.",
"type": "string"
},
{
"name": "full_name",
"description": "The full name or title of the dataset or publication.",
"type": "string"
},
{
"name": "longitudinal",
"description": "Indicates whether the dataset has longitudinal or time-series data.",
"type": "boolean"
},
{
"name": "project.id",
"description": "Unique identifiers for records in the 'project' table that relate via this foreign key to records in this 'dataset' table.",
"type": "string"
},
{
"name": "dataset.id",
"description": "A unique identifier for records in this 'dataset' table.",
"type": "string"
}
]
},
{
"name": "subject",
"description": "The collection of all data related to a specific subject in the context of a specific experiment.",
"links": [
"project",
"dataset"
],
"required": [
"dataset.id",
"subject.id"
],
"properties": [
{
"name": "date_of_death",
"description": "The date of death of the subject in the context of a specific experiment.",
"type": "string"
},
{
"name": "dataset.id",
"description": "Unique identifiers for records in the 'dataset' table that relate via this foreign key to records in this'subject' table.",
"type": "string"
},
{
"name": "subject.id",
"description": "A unique identifier for records in this'subject' table.",
"type": "string"
},
{
"name": "CancerRegistry_PatientID",
"description": "The patient unique id in the case registry.",
"type": "string"
},
{
"name": "Ethnicity",
"description": "An individual's self-described social and cultural grouping, specifically whether an individual describes themselves as Hispanic or Latino. The provided values are based on the categories defined by the U.S. Office of Management and Business and used by the U.S. Census Bureau.",
"type": "string"
},
{
"name": "Last_Name",
"description": "The surname(s) of individual(s) in study, in the form used for cultural or ethnic reasons (e.g., Spanish surnames)",
"type": "string"
},
{
"name": "Sex_Desc",
"description": "The description of the individual's gender.",
"type": "string"
},
{
"name": "project.id",
"description": "Unique identifiers for records in the 'project' table that relate via this foreign key to records in this'subject' table.",
"type": "string"
}
]
},
{
"name": "sample",
"description": "Any material sample taken from a biological entity for testing, diagnostic, propagation, treatment or research purposes, including a sample obtained from a living organism or taken from the biological object after halting of all its life functions. Biospecimen can contain one or more components including but not limited to cellular molecules, cells, tissues, organs, body fluids, embryos, and body excretory products.",
"links": [
"subject"
],
"required": [
"sample.id",
"subject.id"
],
"properties": [
{
"name": "body_fluid_code",
"description": "The code for the body fluid from which the sample was taken.",
"type": "string"
},
{
"name": "procedure_date",
"description": "Year the sample was taken for analysis.",
"type": "integer"
},
{
"name": "subject.id",
"description": "Unique identifiers for records in the'subject' table that relate via this foreign key to records in this'sample' table.",
"type": "string"
},
{
"name": "sample.id",
"description": "A unique identifier for records in this'sample' table.",
"type": "string"
}
]
}
]
}
"""
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def get_example_ai_model_output_many():
|
| 333 |
return """
|
| 334 |
{
|
| 335 |
"nodes": [
|