avantol committed
Commit cc69c66 · Parent(s): 5bf42c5

feat(app): more examples, better parsing and error handling

Files changed (6)
  1. app.py +34 -14
  2. models.py +41 -0
  3. parsing.py +260 -0
  4. schema_to_sql.py +38 -0
  5. shared.py +308 -0
  6. utils.py +225 -4
app.py CHANGED
@@ -1,21 +1,24 @@
+import copy
 import json
 import os
 import zipfile
 
-import torch
 import gradio as gr
 import spaces
+import torch
 from peft import PeftConfig, PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from schema_to_sql import dd_to_sql
 from utils import (
+    MAX_NEW_TOKENS,
+    TEMPERATURE,
     create_summary_tables,
-    get_example_ai_model_output,
+    get_example_ai_model_output_many,
+    get_example_ai_model_output_simple,
     get_prompt_with_files_uploaded,
 )
-
-from utils import MAX_NEW_TOKENS, TEMPERATURE
+from parsing import try_parsing_actual_model_output
 
 LOCAL_DIR = "tsvs"
 ZIP_PATH = "tsvs.zip"
@@ -98,6 +101,11 @@ def run_llm_inference(model_prompt):
     output_data_model = tokenizer.decode(outputs[0][prompt_length:])
     output_data_model = output_data_model.split("<|eot_id|>")[0]
     print(output_data_model)
+
+    parsed_output_data_model = try_parsing_actual_model_output(output_data_model)
+    if "error" not in parsed_output_data_model:
+        output_data_model = copy.deepcopy(parsed_output_data_model)
+
     # Test output for JSON schema validity
     try:
         json.loads(output_data_model)
@@ -141,8 +149,16 @@ def gen_output_from_files_uploaded(filepaths: list[str] = None):
     return model_response, sql, nodes_df, properties_df
 
 
-def gen_output_from_example():
-    model_response = get_example_ai_model_output()
+def gen_output_from_example_simple():
+    model_response = get_example_ai_model_output_simple()
+    model_response_json = json.loads(model_response)
+    sql, validation = dd_to_sql(model_response_json)
+
+    return model_response, sql
+
+
+def gen_output_from_example_many():
+    model_response = get_example_ai_model_output_many()
     model_response_json = json.loads(model_response)
     sql, validation = dd_to_sql(model_response_json)
 
@@ -167,14 +183,12 @@ with gr.Blocks() as demo:
 
     gr.Markdown("## (Optional) Get Sample TSV(s) to Upload")
 
-    gr.Markdown("### Example 1: A single TSV")
+    gr.Markdown("### Example 1: A Simple TSV")
     download_btn = gr.DownloadButton(
-        label="Download Single TSV", value="sample_metadata.tsv"
+        label="Download Simple TSV", value="sample_metadata.tsv"
     )
     gr.Markdown("### Example 2: Many TSVs in a single .zip file.")
-    download_btn = gr.DownloadButton(
-        label="Download All Sample TSVs as .zip", value=zip_tsvs
-    )
+    download_btn = gr.DownloadButton(label="Download Many TSVs as .zip", value=zip_tsvs)
     gr.Markdown("You need to extract the .zip if you want to use them.")
 
     gr.Markdown("## Upload TSVs With Headers (No Data Rows Required)")
@@ -222,10 +236,16 @@ with gr.Blocks() as demo:
         outputs=[json_out, sql_out],
     )
 
-    gr.Markdown("Run out of FreeGPU or having issues? Try the example output!")
-    demo_btn = gr.Button("Manually Load Example Output from Previous Run")
+    gr.Markdown("Run out of FreeGPU or having issues? Try the example outputs!")
+    demo_btn2 = gr.Button("Manually Load 'Simple' Example Output from Previous Run")
+    demo_btn2.click(
+        fn=gen_output_from_example_simple,
+        outputs=[json_out, sql_out],
+    )
+
+    demo_btn = gr.Button("Manually Load 'Many' Example Output from Previous Run")
     demo_btn.click(
-        fn=gen_output_from_example,
+        fn=gen_output_from_example_many,
         outputs=[json_out, sql_out],
    )
 
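For reference, a minimal sketch of the in-band error convention the new parsing step relies on (the functions are the ones added in parsing.py below; the input string is hypothetical):

```python
from parsing import try_parsing_actual_model_output

# Parse failures come back as {"error": "..."} dicts instead of raised
# exceptions, so run_llm_inference can branch without try/except.
parsed = try_parsing_actual_model_output('{"not_nodes": []}')
if "error" in parsed:
    print(parsed["error"])  # the app falls back to the raw decoded output
```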
models.py ADDED
@@ -0,0 +1,41 @@
+"""
+This file should contain all relevant data-only information; it should contain minimal to no behavior.
+"""
+
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class Property(BaseModel):
+    name: Optional[str] = None
+    type: Optional[str] = None
+    description: Optional[str] = None
+    model_config = {"extra": "ignore"}
+
+
+def upsert_dict(mapping, kvp):
+    k, v = kvp
+    mapping[k] = v
+    return mapping
+
+
+class DataModel(BaseModel):
+    name: Optional[str] = None
+    description: Optional[str] = None
+    links: Optional[List[str]] = Field(default=[])
+    properties: Optional[Dict[str, Property]] = Field(default={})
+    model_config = {"extra": "ignore"}
+
+    @model_validator(mode="before")
+    def _normalize_properties(cls, node_content):
+        props = node_content.get("properties", {})
+        if not isinstance(props, list):
+            return node_content
+        props_to_use = [prop for prop in props if "name" in prop]
+        prop_items = [Property(**prop) for prop in props_to_use]
+        prop_dict = {
+            prop.name: Property(type=prop.type, description=prop.description)
+            for prop in prop_items
+        }
+        return upsert_dict(node_content, ("properties", prop_dict))
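As a quick illustration, a sketch of how DataModel's "before" validator normalizes the list-shaped properties the model sometimes emits (assuming pydantic v2; the node below is made up):

```python
from models import DataModel

# Hypothetical node content with list-shaped "properties".
raw_node = {
    "name": "sample",
    "links": ["subject"],
    "properties": [
        {"name": "sample.id", "type": "string", "description": "Unique id."},
        {"type": "string"},  # no "name" key: dropped by the validator
    ],
}

dm = DataModel.model_validate(raw_node)
print(list(dm.properties))               # ['sample.id']
print(dm.properties["sample.id"].type)   # 'string'
```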
parsing.py ADDED
@@ -0,0 +1,260 @@
+"""
+This file handles reshaping raw results data into a list of nodes that match
+what we expect the model output to be. That is to say, it handles parsing any raw data.
+"""
+
+import json
+from collections import defaultdict
+from functools import reduce
+from typing import Any, Dict, Hashable, Optional, Tuple, Union
+
+from models import DataModel
+from shared import (
+    TOP_LEVEL_IDENTIFIERS,
+    attempt,
+    get_json_from_model_output,
+    keep_errors,
+    on_fail,
+)
+
+
+def handle_parsing_schema_files(expected_location: str, actual_location: str):
+    raw_reference, raw_generated = read_in_expected_and_actual_json(
+        expected_location, actual_location
+    )
+    read_errors = keep_errors((raw_reference, raw_generated))
+    if len(read_errors) > 0:
+        raise ValueError(f"Could not ingest raw data: {read_errors}")
+    generated_nodes_to_content = try_parsing_actual_model_output(raw_generated)
+    reference_nodes_to_content = derive_nodes_from_actual_json_output(
+        parse_json(raw_reference)
+    )
+    errors = keep_errors((reference_nodes_to_content, generated_nodes_to_content))
+    if len(errors) > 0:
+        raise ValueError(f"Error parsing files: {errors}")
+    return generated_nodes_to_content, reference_nodes_to_content
+
+
+def read_in_expected_and_actual_json(
+    expected_json_location: str, actual_json_location: str
+):
+    unparsed_expected, err = read_json_as_text(expected_json_location)
+    if err:
+        return {"error": f"Could not read in expected file. e: {err}"}, None
+    unparsed_actual, err = read_json_as_text(actual_json_location)
+    if err:
+        return None, {"error": str(err)}
+    return unparsed_expected, unparsed_actual
+
+
+def read_json_as_text(file_path: str) -> Tuple[Optional[str], Optional[Exception]]:
+    """
+    Reads the contents of a file as text without attempting to parse it as JSON.
+    """
+    try:
+        with open(file_path, "r", encoding="utf-8") as file:
+            return file.read(), None
+    except Exception as e:
+        return None, e
+
+
+def try_parsing_actual_model_output(model_output: str):
+    first_parse_json = parse_json(model_output)
+    if isinstance(first_parse_json, list):
+        return {
+            "error": "Could not parse json file. Model output should not be a list."
+        }
+    first_pass_failed = "error" in first_parse_json
+    recovered_json, errors = (
+        get_json_from_model_output(model_output) if first_pass_failed else ({}, 0)
+    )
+    if errors > 0:
+        return {"error": "Could not parse json file, no metrics to calculate"}
+    parsed_json = recovered_json if first_pass_failed else first_parse_json
+    node_derivation_outcome = on_fail(
+        attempt(derive_nodes_from_actual_json_output, (parsed_json,)), []
+    )
+    if not node_derivation_outcome:
+        return {"error": f"Could not derive nodes. Parsed json: {parsed_json}"}
+    return node_derivation_outcome
+
+
+def find_all_nodes(name_and_contents) -> Dict:
+    name, contents = name_and_contents
+    content_contains_nodes = bool(set(contents.keys()) & TOP_LEVEL_IDENTIFIERS)
+    if content_contains_nodes:
+        return dict([name_and_contents])
+    sub_dicts = list(filter(lambda kvp: isinstance(kvp[1], dict), contents.items()))
+    all_sub_nodes = {}
+    for sub_name_and_contents in sub_dicts:
+        sub_nodes = find_all_nodes(sub_name_and_contents)
+        all_sub_nodes.update(sub_nodes)
+    return all_sub_nodes
+
+
+def assign_to_key(key: Hashable):
+    def add_at_key(
+        assignment_mapping: Dict[Hashable, Any], mapping_to_add: Dict[Hashable, Any]
+    ):
+        assignment_id = mapping_to_add[key]
+        assignment_mapping[assignment_id] = mapping_to_add
+        return assignment_mapping
+
+    return add_at_key
+
+
+key_exists = lambda key: lambda mapping: key in mapping
+
+
+def handle_property_correction(all_nodes):
+    has_incorrect_property_shape = lambda node: (
+        "properties" in node[1] and isinstance(node[1].get("properties", {}), list)
+    )
+    nodes_that_need_corrected = dict(
+        filter(has_incorrect_property_shape, all_nodes.items())
+    )
+    nodes_with_corrected_properties = dict(
+        map(correct_properties_for_node, nodes_that_need_corrected.items())
+    )
+    all_corrected_nodes = {**all_nodes, **nodes_with_corrected_properties}
+    return all_corrected_nodes
+
+
+def correct_properties_for_node(node):
+    """Maps a node's property names to their actual content."""
+    node_name, node_data = node
+    properties = node_data["properties"]
+    identified_properties = list(filter(key_exists("name"), properties))
+    actual_properties = reduce(assign_to_key("name"), identified_properties, {})
+    node_data["properties"] = actual_properties
+    return node
+
+
+def derive_nodes_from_actual_json_output(json_data: Union[dict, list]):
+    """
+    Find nodes from non-deterministic AI output
+    """
+    if isinstance(json_data, list):
+        return {}
+    all_nodes = flatten_all_nodes(json_data)
+    if json_data.get("nodes", None) is None:
+        return all_nodes
+    nodes_with_properties_corrected = handle_property_correction(all_nodes)
+    return nodes_with_properties_corrected
+
+
+def flatten_all_nodes(json_data) -> Dict[Hashable, Any]:
+    """
+    Model output could have nested nodes; this extracts them.
+    """
+    nodes = json_data.get("nodes", None)
+    if nodes is None:
+        sub_nodes_list = [
+            find_all_nodes((name, contents))
+            for name, contents in json_data.items()
+            if isinstance(contents, dict)
+        ]
+    else:
+        sub_nodes_list = [
+            find_all_nodes((node["name"], node))
+            for node in nodes
+            if isinstance(node, dict) and node.get("name") is not None
+        ]
+    all_nodes = {k: v for sub_nodes in sub_nodes_list for k, v in sub_nodes.items()}
+    return all_nodes
+
+
+def aggregate_desc(acc, node):
+    node_name, node_info = node
+    desc = node_info.get("description", None)
+    if desc is not None and isinstance(desc, str):
+        acc[desc].add(node_name)
+    return acc
+
+
+def reform_links(outer_acc, node):
+    node_name, node_info = node
+    links = node_info.get("links", [])
+    collect_links_to_aggregator = lambda inner_acc, link: upsert_set(
+        inner_acc, (link, node_name)
+    )
+    links_to_node_names = reduce(collect_links_to_aggregator, links, outer_acc)
+    return links_to_node_names
+
+
+def lens(key, default=None):
+    """Simple way to interface with the contents of a dict"""
+    return lambda d: d.get(key, default)
+
+
+def aggregate_properties(outer_acc, node):
+    node_name, node_info = node
+    properties = node_info.get("properties", {})
+    is_list = isinstance(properties, list)
+    property_names = (
+        list(map(lens("name"), properties)) if is_list else list(properties.keys())
+    )
+    aggregate_properties = reduce(
+        lambda inner_acc, prop_name: upsert_set(inner_acc, (prop_name, node_name)),
+        property_names,
+        outer_acc,
+    )
+    return aggregate_properties
+
+
+def conform_node_to_expected_schema(name_to_data_model):
+    name, dm = name_to_data_model
+    conform_result = attempt(DataModel.model_validate, (dm,))
+    model_outcome = (
+        conform_result.model_dump(exclude_none=True, exclude_unset=True)
+        if "error" not in conform_result
+        else {}
+    )
+    errors = conform_result if "error" in conform_result else {}
+    return (name, model_outcome), errors
+
+
+def aggregate_parsed_file(nodes: dict):
+    conformed_nodes_result = [
+        conform_node_to_expected_schema(node) for node in nodes.items()
+    ]
+    conformed_nodes = [node for node, errors in conformed_nodes_result if not errors]
+    aggregated_links = reduce(reform_links, conformed_nodes, defaultdict(set))
+    aggregated_properties = reduce(
+        aggregate_properties, conformed_nodes, defaultdict(set)
+    )
+    aggregated_descriptions = reduce(aggregate_desc, conformed_nodes, defaultdict(set))
+
+    parsed_results = {
+        "node_names": dict(conformed_nodes),
+        "links": aggregated_links,
+        "properties": aggregated_properties,
+        "description": aggregated_descriptions,
+    }
+    return parsed_results
+
+
+def upsert_set(accumulator, kvp):
+    key, value = kvp
+    if isinstance(key, Hashable):
+        accumulator[key].add(value)
+    return accumulator
+
+
+def parse_json(json_string: str) -> Optional[Union[dict, list]]:
+    """
+    Safely parses a JSON string into a Python dictionary or list.
+
+    Args:
+        json_string (str): The JSON string to parse
+
+    Returns:
+        dict/list: Parsed JSON data if successful
+        dict["error"]: If parsing fails, provides error as string in response
+    """
+    try:
+        return json.loads(json_string)
+    except json.JSONDecodeError as e:
+        return {"error": str(e)}
+    except TypeError as e:
+        return {"error": str(e)}
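A small sketch of the happy path through try_parsing_actual_model_output (the inline JSON string is hypothetical):

```python
from parsing import try_parsing_actual_model_output

raw = '{"nodes": [{"name": "project", "description": "Top-level node.", "links": [], "properties": []}]}'

nodes = try_parsing_actual_model_output(raw)
# The node was flattened into a name-keyed dict, and its list-shaped
# "properties" was corrected into a dict:
print(nodes["project"]["properties"])  # {}
```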
schema_to_sql.py CHANGED
@@ -99,6 +99,43 @@ def get_foreign_table_and_field(prop_name, node_name):
     return None, None
 
 
+def transform_dd(dd):
+    """
+    Returns the transformed DD.
+
+    This function takes an AI-generated DD and ensures that all required
+    fields are present in properties and that each property is a dictionary.
+
+    Parameters:
+        dd (dict): AI generated DD
+
+    Returns:
+        dd (dict): Transformed DD
+    """
+    for node in dd.get("nodes", []):
+        props = node.get("properties", [])
+        if props and all(isinstance(x, dict) for x in props):
+            prop_names = {p["name"] for p in props}
+        elif props and all(isinstance(x, str) for x in props):
+            prop_names = set(props)
+            # Upgrade to list of dicts
+            props = [
+                {"name": prop, "description": "", "type": "string"} for prop in props
+            ]
+        else:
+            props = []
+            prop_names = set()
+
+        # Ensure each required field is present in properties
+        for req in node.get("required", []):
+            if req not in prop_names:
+                props.append({"name": req, "description": "", "type": "string"})
+                prop_names.add(req)
+
+        node["properties"] = props
+    return dd
+
+
 def generate_create_table(node, table_lookup):
     """
     Returns SQL for the given AI generated node
@@ -189,6 +226,7 @@ def dd_to_sql(dd):
         sql (str): SQL
         validation (str): Validation result
     """
+    dd = transform_dd(dd)
     # Build a lookup for table columns in all nodes
     table_lookup = {}
     for node in dd["nodes"]:
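To make the new normalization step concrete, a sketch of transform_dd on a hypothetical DD fragment:

```python
from schema_to_sql import transform_dd

# String-only properties, plus a required field missing from the list.
dd = {
    "nodes": [
        {
            "name": "sample",
            "required": ["sample.id"],
            "properties": ["body_fluid_code"],
        }
    ]
}

dd = transform_dd(dd)
# Each property is now a dict, and "sample.id" was backfilled:
# [{'name': 'body_fluid_code', 'description': '', 'type': 'string'},
#  {'name': 'sample.id', 'description': '', 'type': 'string'}]
print(dd["nodes"][0]["properties"])
```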
shared.py ADDED
@@ -0,0 +1,308 @@
+"""
+Functions used in several different places. This file should not import from any other non-lib files to prevent
+circular dependencies.
+"""
+
+import json
+import logging
+from copy import copy
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+TOP_LEVEL_IDENTIFIERS = {"description", "links", "properties"}
+
+
+def get_json_from_model_output(input_generated_json: str):
+    """
+    Parses a string, potentially containing Markdown code fences, into a JSON object.
+
+    This function attempts to extract and parse a JSON object from a string,
+    often the output of a language model. It handles cases where the JSON
+    is enclosed in Markdown code fences (```json ... ``` or ``` ... ```).
+    If the initial parsing fails, it attempts a more robust parsing using
+    `_get_valid_json_from_string` and
+    logs debug messages indicating success or failure. If all attempts fail,
+    it returns an empty dictionary.
+
+    Args:
+        input_generated_json: A string potentially containing a JSON object.
+
+    Returns:
+        A tuple containing:
+        - The parsed JSON object (a dictionary) or an empty dictionary if parsing failed.
+        - An integer representing the number of times parsing failed initially.
+    """
+    originally_invalid_json_count = 0
+
+    generated_json_attempt_1 = copy(input_generated_json)
+    try:
+        code_split = generated_json_attempt_1.split("```")
+        if len(code_split) > 1:
+            generated_json_attempt_1 = json.loads(
+                ("```" + code_split[1]).replace("```json", "")
+            )
+        else:
+            generated_json_attempt_1 = json.loads(
+                generated_json_attempt_1.replace("```json", "").replace("```", "")
+            )
+    except Exception as exc:
+        logging.debug(f"Could not parse AI model generated output as JSON. Exc: {exc}.")
+        # originally_invalid_json_count += 1
+        generated_json_attempt_1 = {}
+    some_value_in_attempt_1_is_not_a_dict = check_contents_valid(
+        generated_json_attempt_1
+    )
+    attempt_1_failed = (
+        not bool(generated_json_attempt_1) or some_value_in_attempt_1_is_not_a_dict
+    )
+    generated_json_attempt_2 = copy(input_generated_json) if attempt_1_failed else {}
+    if attempt_1_failed:
+        logging.debug(
+            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
+            "the model output was simply cut off)"
+        )
+        try:
+            code_split = generated_json_attempt_2.split("```")
+            if len(code_split) > 1:
+                generated_json_attempt_2 = json.loads(
+                    _get_valid_json_from_string(
+                        ("```" + code_split[1]).replace("```json", "")
+                    )
+                )
+            else:
+                stripped_output = generated_json_attempt_2.replace(
+                    "```json", ""
+                ).replace("```", "")
+                balance_outcome = attempt(
+                    json.loads, (balance_braces(stripped_output),)
+                )
+                if "error" not in balance_outcome:
+                    generated_json_attempt_2 = balance_outcome
+                else:
+                    generated_json_attempt_2 = json.loads(
+                        _get_valid_json_from_string(stripped_output)
+                    )
+
+            logging.debug(
+                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
+            )
+        except Exception as exc:
+            logging.debug(
+                "Failed. Setting model output as empty JSON to enable metrics comparison."
+            )
+            generated_json_attempt_2 = {}
+    some_value_in_attempt_2_is_not_a_dict = (
+        attempt_1_failed
+        and isinstance(generated_json_attempt_2, dict)
+        and check_contents_valid(generated_json_attempt_2)
+    )
+    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
+        logging.debug("Could not recover model output json, aborting!")
+        originally_invalid_json_count += 1
+    generated_json = (
+        generated_json_attempt_1 if not attempt_1_failed else generated_json_attempt_2
+    )
+    return generated_json, originally_invalid_json_count
+
+
+def check_contents_valid(generated_json_attempt_1: Union[list, dict]):
+    """
+    Checks that the sub-nodes are dicts rather than lists or other scalars.
+
+    Args:
+        generated_json_attempt_1 (Union[list, dict]): data to check
+
+    Returns:
+        The first non-dict item found (truthy), otherwise None.
+    """
+    if isinstance(generated_json_attempt_1, list):
+        for item in generated_json_attempt_1:
+            if not isinstance(item, dict):
+                return item
+        return None
+    elif (
+        isinstance(generated_json_attempt_1, dict)
+        and "nodes" in generated_json_attempt_1.keys()
+    ):
+        for item in generated_json_attempt_1.get("nodes", []):
+            if not isinstance(item, dict):
+                return item
+        return None
+    else:
+        for item in generated_json_attempt_1.values():
+            if not isinstance(item, dict):
+                return item
+        return None
+
+
+def _get_valid_json_from_string(s):
+    """
+    Given a JSON string with potentially unclosed strings, arrays, or objects, close those things
+    to hopefully be able to parse as valid JSON
+    """
+    double_quotes = 0
+    single_quotes = 0
+    brackets = []
+
+    for i, c in enumerate(s):
+        if c == '"':
+            double_quotes = 1 - double_quotes  # Toggle between 0 and 1
+        elif c == "'":
+            single_quotes = 1 - single_quotes  # Toggle between 0 and 1
+        elif c in "{[":
+            brackets.append((i, c))
+        elif c in "}]":
+            if double_quotes == 0 and single_quotes == 0:
+                if brackets:
+                    last_opened = brackets.pop()
+                    if (c == "}" and last_opened[1] != "{") or (
+                        c == "]" and last_opened[1] != "["
+                    ):
+                        raise ValueError(
+                            f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} "
+                            f"but closed {c} @ {i}"
+                        )
+                else:
+                    # If no matching opening bracket, it's an error, but we can skip this for the task
+                    pass
+
+    # Remove trailing comma if it exists
+    if s.strip().endswith(","):
+        logging.debug("Removing ending ,")
+        s = s.strip().rstrip(",")
+
+    closing_chars = ""
+
+    # Adding closing quotes if there are missing ones
+    if double_quotes > 0:
+        closing_chars += '"'
+    if single_quotes > 0:
+        closing_chars += "'"
+
+    # Add closing brackets for any unmatched opening brackets
+    while brackets:
+        last_opened = brackets.pop()
+        if last_opened[1] == "{":
+            closing_chars += "}"
+        elif last_opened[1] == "[":
+            closing_chars += "]"
+
+    logging.debug(f"closing_chars: {closing_chars}")
+
+    output_string = s + closing_chars
+
+    try:
+        json.loads(output_string)
+    except Exception:
+        logging.debug(
+            "JSON string still fails to be parseable, attempting another modification..."
+        )
+        # it's possible the closing quotes were on a property that didn't have a value, let's
+        # fix that and see if it works
+        new_closing_chars = ""
+        found_first_double_quote = False
+        for char in closing_chars:
+            if not found_first_double_quote and char == '"':
+                found_first_double_quote = True
+                # for keys in objects with no value, append an empty value
+                # For example:
+                # ```
+                # {
+                #     "properties": {
+                #         "annotation
+                # ```
+                new_closing_chars += '": ""'
+            else:
+                new_closing_chars += char
+
+        logging.debug(f"new closing_chars: {new_closing_chars}")
+        output_string = s + new_closing_chars
+
+    return output_string
+
+
+def on_fail(
+    outcome: Union[Any, Dict[str, str]],
+    fallback: Union[Any, Callable] = None,
+):
+    """
+    Allows you to provide a fallback to recover from a failed outcome.
+
+    Args:
+        outcome: the result to inspect; a failure is a dict containing an "error" key.
+        fallback: the value to use on failure, or a callable applied to the failed outcome.
+
+    Returns:
+        The outcome if it succeeded, otherwise the fallback.
+    """
+    is_fail = isinstance(outcome, dict) and "error" in outcome
+    is_callable = isinstance(fallback, Callable)
+    if is_fail and is_callable:
+        return fallback(outcome)
+    elif is_fail:
+        return fallback
+    return outcome
+
+
+def attempt(
+    func: Callable,
+    args: Tuple[Any, ...] = (),
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> Union[Any, Dict[str, str]]:
+    """
+    Attempts to execute a function with the provided arguments.
+
+    If the function raises an exception, the exception is caught and returned in a dict.
+    Args:
+        func (Callable): The function to execute.
+        args (Tuple[Any, ...], optional): A tuple of positional arguments for the function.
+        kwargs (Optional[Dict[str, Any]], optional): A dictionary of keyword arguments for the function.
+    Returns:
+        Function result or {"error": <msg>} response
+    """
+    kwargs = kwargs or {}
+    try:
+        return func(*args, **kwargs)
+    except Exception as exc:
+        return {"error": str(exc)}
+
+
+def balance_braces(s: str) -> str:
+    """
+    Primitive function that just tries to add '{}' style braces to try to recover
+    the model string.
+
+    Args:
+        s (str): string to balance braces on.
+
+    Returns:
+        provided string with balanced braces if possible
+    """
+    open_count = s.count("{")
+    close_count = s.count("}")
+
+    if open_count > close_count:
+        s += "}" * (open_count - close_count)
+    elif close_count > open_count:
+        s = "{" * (close_count - open_count) + s
+
+    return s
+
+
+def flatten_list(coll):
+    flattened_data = []
+    for set_list in coll:
+        flattened_data = flattened_data + list(set_list)
+    return flattened_data
+
+
+def keep_errors(collection):
+    """
+    Given a set of outcomes, keeps any that resulted in an error
+
+    Args:
+        collection (Collection): collection of outcomes to filter.
+
+    Returns:
+        All instances of the collection that contain an error response.
+    """
+    return [instance for instance in collection if "error" in (instance or [])]
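A short sketch exercising the recovery helpers above on a hypothetical truncated string:

```python
import json

from shared import attempt, balance_braces, on_fail

# attempt() converts exceptions into {"error": ...} outcomes.
truncated = '{"a": {"b": 1}'
outcome = attempt(json.loads, (truncated,))
assert "error" in outcome

# balance_braces() appends the missing closing brace, after which the
# string parses; on_fail() supplies a fallback if it had still failed.
repaired = balance_braces(truncated)         # '{"a": {"b": 1}}'
recovered = on_fail(attempt(json.loads, (repaired,)), {})
print(recovered)                             # {'a': {'b': 1}}
```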
utils.py CHANGED
@@ -1,9 +1,10 @@
-import os
 import csv
+import os
+from io import BytesIO
 from string import Template
-import networkx as nx
+
 import matplotlib.pyplot as plt
-from io import BytesIO
+import networkx as nx
 from PIL import Image
 
 MAX_INPUT_TOKEN_LENGTH = 1024
@@ -108,7 +109,227 @@ def create_summary_tables(json_response):
     return node_descriptions, node_property_descriptions
 
 
-def get_example_ai_model_output():
+def get_example_ai_model_output_simple():
+    return """
+{
+  "nodes": [
+    {
+      "name": "project",
+      "description": "Any specifically defined piece of work that is undertaken or attempted to meet a single requirement. (NCIt C47885)",
+      "links": [],
+      "required": [
+        "dbgap_accession_number",
+        "project.id"
+      ],
+      "properties": [
+        {
+          "name": "awg_review",
+          "description": "Indicates that the project is an AWG project.",
+          "type": "boolean"
+        },
+        {
+          "name": "data_citation",
+          "description": "The citation for the published dataset.",
+          "type": "string"
+        },
+        {
+          "name": "data_contributor",
+          "description": "The name of the organization or individual that the contributed dataset belongs to.",
+          "type": "string"
+        },
+        {
+          "name": "data_description",
+          "description": "A brief, free-text description of the data files and associated metadata provided for this dataset.",
+          "type": "string"
+        },
+        {
+          "name": "dbgap_accession_number",
+          "description": "The dbgap accession number provided for the project.",
+          "type": "string"
+        },
+        {
+          "name": "in_review",
+          "description": "Indicates that the project is under review by the submitter. Upload and data modification is disabled.",
+          "type": "boolean"
+        },
+        {
+          "name": "intended_release_date",
+          "description": "Tracks a Project's intended release date.",
+          "type": "string"
+        },
+        {
+          "name": "project.id",
+          "description": "A unique identifier for records in this 'project' table.",
+          "type": "string"
+        },
+        {
+          "name": "protocol_number",
+          "description": "The project's protocol number or similar amount.",
+          "type": "string"
+        },
+        {
+          "name": "releasable",
+          "description": "A project can only be released by the user when `releasable` is true.",
+          "type": "boolean"
+        },
+        {
+          "name": "request_submission",
+          "description": "Indicates that the user has requested submission to the GDC for harmonization.",
+          "type": "boolean"
+        },
+        {
+          "name": "research_design",
+          "description": "A summary of the goals of the research or a general description of the research's relationship to a clinical application.",
+          "type": "string"
+        },
+        {
+          "name": "research_objective",
+          "description": "The general objective of the research; what the researchers hope to discover or determine.",
+          "type": "string"
+        },
+        {
+          "name": "research_setup",
+          "description": "A high level description of the setup used to achieve the research objectives.",
+          "type": "string"
+        }
+      ]
+    },
+    {
+      "name": "dataset",
+      "description": "A set of metadata and associated data file objects originating from a single research study, clinical trial or patient cohort.",
+      "links": [
+        "project"
+      ],
+      "required": [
+        "dataset.id",
+        "project.id"
+      ],
+      "properties": [
+        {
+          "name": "Class_of_Case_Desc",
+          "description": "The text term used to describe the kind of clinical condition that can be defined based on objective criteria or by including all patient information from the case.",
+          "type": "string"
+        },
+        {
+          "name": "data_citation",
+          "description": "The citation for the published dataset.",
+          "type": "string"
+        },
+        {
+          "name": "full_name",
+          "description": "The full name or title of the dataset or publication.",
+          "type": "string"
+        },
+        {
+          "name": "longitudinal",
+          "description": "Indicates whether the dataset has longitudinal or time-series data.",
+          "type": "boolean"
+        },
+        {
+          "name": "project.id",
+          "description": "Unique identifiers for records in the 'project' table that relate via this foreign key to records in this 'dataset' table.",
+          "type": "string"
+        },
+        {
+          "name": "dataset.id",
+          "description": "A unique identifier for records in this 'dataset' table.",
+          "type": "string"
+        }
+      ]
+    },
+    {
+      "name": "subject",
+      "description": "The collection of all data related to a specific subject in the context of a specific experiment.",
+      "links": [
+        "project",
+        "dataset"
+      ],
+      "required": [
+        "dataset.id",
+        "subject.id"
+      ],
+      "properties": [
+        {
+          "name": "date_of_death",
+          "description": "The date of death of the subject in the context of a specific experiment.",
+          "type": "string"
+        },
+        {
+          "name": "dataset.id",
+          "description": "Unique identifiers for records in the 'dataset' table that relate via this foreign key to records in this 'subject' table.",
+          "type": "string"
+        },
+        {
+          "name": "subject.id",
+          "description": "A unique identifier for records in this 'subject' table.",
+          "type": "string"
+        },
+        {
+          "name": "CancerRegistry_PatientID",
+          "description": "The patient unique id in the case registry.",
+          "type": "string"
+        },
+        {
+          "name": "Ethnicity",
+          "description": "An individual's self-described social and cultural grouping, specifically whether an individual describes themselves as Hispanic or Latino. The provided values are based on the categories defined by the U.S. Office of Management and Business and used by the U.S. Census Bureau.",
+          "type": "string"
+        },
+        {
+          "name": "Last_Name",
+          "description": "The surname(s) of individual(s) in study, in the form used for cultural or ethnic reasons (e.g., Spanish surnames)",
+          "type": "string"
+        },
+        {
+          "name": "Sex_Desc",
+          "description": "The description of the individual's gender.",
+          "type": "string"
+        },
+        {
+          "name": "project.id",
+          "description": "Unique identifiers for records in the 'project' table that relate via this foreign key to records in this 'subject' table.",
+          "type": "string"
+        }
+      ]
+    },
+    {
+      "name": "sample",
+      "description": "Any material sample taken from a biological entity for testing, diagnostic, propagation, treatment or research purposes, including a sample obtained from a living organism or taken from the biological object after halting of all its life functions. Biospecimen can contain one or more components including but not limited to cellular molecules, cells, tissues, organs, body fluids, embryos, and body excretory products.",
+      "links": [
+        "subject"
+      ],
+      "required": [
+        "sample.id",
+        "subject.id"
+      ],
+      "properties": [
+        {
+          "name": "body_fluid_code",
+          "description": "The code for the body fluid from which the sample was taken.",
+          "type": "string"
+        },
+        {
+          "name": "procedure_date",
+          "description": "Year the sample was taken for analysis.",
+          "type": "integer"
+        },
+        {
+          "name": "subject.id",
+          "description": "Unique identifiers for records in the 'subject' table that relate via this foreign key to records in this 'sample' table.",
+          "type": "string"
+        },
+        {
+          "name": "sample.id",
+          "description": "A unique identifier for records in this 'sample' table.",
+          "type": "string"
+        }
+      ]
+    }
+  ]
+}
+"""
+
+
+def get_example_ai_model_output_many():
     return """
 {
   "nodes": [
  "nodes": [