Maheshsr commited on
Commit
22b03eb
·
1 Parent(s): d044a27
.gitattributes CHANGED
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ database/SQLAgent_DEMO_DB_V1.db filter=lfs diff=lfs merge=lfs -text
database/.DS_Store ADDED
Binary file (6.15 kB). View file
 
database/SQLAgent_DEMO_DB_V1.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53cb4c4d3e88df7721e91ba4a4f5828e773168cbe4dc934810b57e4d179d264e
3
+ size 7061504
database/db_tables.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "Patient",
4
+ "description": "The table stores the healthcare encounter information about patients. Each row has a unique patient's information. The table contains the key information by distilling and flattening the FHIR encounter schema.",
5
+ "fields": { "identifier_value": [ "patient identifier that uniquely identifies patient and links a patient from this to other tables", "varchar"],
6
+ "identifier_use": [ "if the identifier is used for any specific purpose", "varchar" ],
7
+ "identifier_type": ["type of identifier, usually means the source, 'MR' stands for medical record", "varchar" ],
8
+ "identifier_start_date": ["date on since when the identifier was valid", "date"],
9
+ "identifier_assigner": ["Identification value assignment authority", "varchar"],
10
+ "active": ["if the patient is active or not", "boolean"],
11
+ "official_name_family": ["family name of the patient", "varchar"],
12
+ "official_name_given": ["given name of the patient", "varchar"],
13
+ "usual_name_given": ["Short form of the given name", "varchar"],
14
+ "gender": ["patient's gender, male or female", "varchar"],
15
+ "birth_date": ["date of birth of the patient", "date"],
16
+ "Age": ["patient age", "integer"],
17
+ "home_address_line": ["patient's home address street", "varchar"],
18
+ "home_address_city": ["patient's home address city", "varchar"],
19
+ "home_address_district": ["patient's home county", "varchar"],
20
+ "home_address_state": ["patient's home state", "varchar"],
21
+ "home_address_postalCode": ["patient's home address zip code", "varchar"],
22
+ "home_address_period_start": ["start date of the patient's home address", "date"]
23
+ }
24
+ },
25
+ {
26
+ "name": "Encounter",
27
+ "description": "Table that stores all encounters of each patient with the healthcare providers. Every row indicates a single encounter.",
28
+ "fields": { "id": [ "encounter id that identifies an encounter uniquely", "varchar"],
29
+ "status": [ "encounter status, can be one of 'planned', 'completed', 'discharged', 'in-progress' ", "varchar" ],
30
+ "class": [ "indicates location setting of the encounter, valid values are: 'IMP' as inpatient, 'EMER' as emergency, 'AMB' as ambulatory, 'HH' as home health ", "varchar" ],
31
+ "priority": [ "indicates priority of the encounter, valid values are: 'UR' as urgent, 'A' as As soon as, 'S' as stat, 'R' as routine ", "varchar" ],
32
+ "subject_id": [ "indicates id of the patient associated with the encounter, should match with identifier_value of the Patient table", "varchar" ],
33
+ "service_provider_id": [ "contains the id of the care delivery organization where the patient had the encounter", "varchar" ],
34
+ "participant_actor_id": [ "contains the id of the provider associated with the care delivery organization who rendered the encounter", "varchar" ],
35
+ "diagnosis_condition_id": [ "contains list of diagnosis codes relevant to the patient of the encounter", "varchar" ],
36
+ "location_id": [ "location where the encounter happened or is happening or will be happening", "varchar" ],
37
+ "discharge_disposition": [ "how the patient was discharged at the end of the encounter", "varchar" ],
38
+ "diagnosis_condition_text": [ "clinical description of the diagnosis codes", "varchar" ],
39
+ "condition_class": [ "condition of the patient classified into specific broad classes, may contain multiple conditions. All lower case.", "varchar" ]
40
+ }
41
+ },
42
+ {
43
+ "name": "EpisodeOfCare",
44
+ "description": "contains continuous period of engagement by a care manager and/or a care management organization with the patient. Every row indicates a unique episode of care for a patient. One patient may have multiple episodes of care ",
45
+ "fields": { "identifier_value": [ "unique identifier of the episode", "varchar" ],
46
+ "type": [ "type of episode, can be disease management, post acute care or specialist referral", "varchar" ],
47
+ "diagnosis_condition_id": [ "ICD-10 diagnosis code associated with the episode of care", "varchar" ],
48
+ "subject_id": [ "id of the patient associated with episode, should have a corresponding 'identifier_value' in the Patient table", "varchar" ],
49
+ "managing_organization_id": [ "contains the id of the organization managing the episode", "varchar" ],
50
+ "care_manager_id": [ "contains the id of the care manager managing the episode", "varchar" ],
51
+ "care_team_id": [ "contains the id of the care team managing the episode. Care manager is part of the care team", "varchar" ]
52
+ }
53
+ },
54
+ {
55
+ "name": "RiskScore",
56
+ "description": "Contains the health risk scores of each of the patients. Only the latest risk score is stored. Every row has the risk score of a unique patient",
57
+ "fields": { "patient_id": [ "identifier that uniquely identifies a patient. Matches with at least one identifier_value of Patient table.", "varchar"],
58
+ "risk_score": [ "decimal number between 0 and 1 indicating the risk score", "decimal number" ]
59
+ }
60
+ },
61
+ {
62
+ "name": "patient_sdoh_scores",
63
+ "description": "table stores the various social determinants of quality scores about a patient obtained through assessment. Each row indicates the score about one patient and about one type of assessment",
64
+ "fields": { "Patient_Id": [ "unique identifier of the patient. Matches with at least one identifier_value of Patient table.", "varchar"],
65
+ "Assessment_Id": [ "name of the assessment", "varchar" ],
66
+ "Answer": [ "The actual answer provided in the assessment", "integer" ],
67
+ "Assessment_Type": [ "type of the assessment, can be 'Financial', 'Home', 'Food' and 'Physical'", "varchar" ],
68
+ "score": [ "Derived standardized score based on the answer provided", "decimal number" ]
69
+ }
70
+ }
71
+ ]
database/gravity_sdoh.sqbpro ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?><sqlb_project><db path="/Users/niloy/Documents/GitHub/SQL Agent Demo/sdoh/gravity_sdoh_observations.db" readonly="0" foreign_keys="1" case_sensitive_like="0" temp_store="0" wal_autocheckpoint="1000" synchronous="2"/><attached/><window><main_tabs open="structure browser pragmas query" current="0"/></window><tab_structure><column_width id="0" width="300"/><column_width id="1" width="0"/><column_width id="2" width="100"/><column_width id="3" width="10670"/><column_width id="4" width="0"/><expanded_item id="0" parent="1"/><expanded_item id="1" parent="1"/><expanded_item id="2" parent="1"/><expanded_item id="3" parent="1"/></tab_structure><tab_browse><current_table name="4,7:mainPatient"/><default_encoding codec=""/><browse_table_settings/></tab_browse><tab_sql><sql name="SQL 1">CREATE TABLE RiskScore (
2
+ patient_id TEXT, -- Reference to the patient
3
+ risk_score DECIMAL(4, 2), -- Risk score, a decimal value between 0.0 and 1.0
4
+ risk_score_date TEXT, -- Date of the risk score
5
+
6
+ PRIMARY KEY (patient_id, risk_score_date) -- Composite primary key: patient ID and risk score date
7
+ );
8
+ </sql><current_tab id="0"/></tab_sql></sqlb_project>
database/gravity_sdoh_observations.db ADDED
Binary file (86 kB). View file
 
gravity_sdoh_observations.db ADDED
Binary file (86 kB). View file
 
pages/__pycache__/solution.cpython-312.pyc CHANGED
Binary files a/pages/__pycache__/solution.cpython-312.pyc and b/pages/__pycache__/solution.cpython-312.pyc differ
 
pages/solution.py CHANGED
@@ -1,5 +1,5 @@
1
  import json
2
- # import sqlite3
3
  # import pyodbc
4
  import mysql.connector
5
  import boto3
@@ -26,8 +26,6 @@ from loguru import logger
26
  from st_aggrid import AgGrid, GridOptionsBuilder
27
  from datetime import datetime
28
 
29
- APP_TITLE = '**Social <br>Determinant<br>of Health**'
30
-
31
  # Initialize token storage
32
  token_file = "token_usage.json"
33
  if not os.path.exists(token_file):
@@ -64,6 +62,8 @@ def show_messages(message):
64
  success_msg.empty()
65
 
66
  # Locations of various files
 
 
67
  sql_dir = 'generated_sql/'
68
  method_dir = 'generated_method/'
69
  insight_lib = 'insight_library/'
@@ -72,6 +72,8 @@ report_path = 'Reports/'
72
  connection_string = "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net"
73
  container_name = "insights-lab"
74
  persona_list = ["Population Analyst", "SDoH Specialist"]
 
 
75
 
76
  def getBlobContent(dir_path):
77
  try:
@@ -447,22 +449,28 @@ def get_existing_insight(base_code, user_persona):
447
  insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
448
  try:
449
  blobs = container_client.list_blobs(name_starts_with=insights_directory)
450
- for blob in blobs:
 
 
 
451
  blob_name = blob.name # Extract the blob names
452
- print(blob_name)
453
  file_name_with_extension = blob_name.split('/')[-1]
454
  file_name = file_name_with_extension.split('.')[0]
 
455
  blob_client = container_client.get_blob_client(blob_name)
456
  blob_content = blob_client.download_blob().readall()
457
- print(blob_content)
458
  insight_data = json.loads(blob_content)
459
  if insight_data['base_code'] == base_code:
460
- logger.info("Existing insight found for base code: {}", base_code)
461
  return insight_data, file_name
462
- logger.info("No existing insight found for base code: {}", base_code)
 
 
 
463
  return None
464
  except Exception as e:
465
- logger.error("Error while retrieving insight: {}", e)
466
  return None
467
 
468
  def update_insight(insight_data, user_persona, file_number):
@@ -474,10 +482,10 @@ def update_insight(insight_data, user_persona, file_number):
474
  file_path = f"{user_directory}/{file_number}.json"
475
  file_content = json.dumps(insight_data, indent=4)
476
  container_client.upload_blob(file_path, data=file_content, overwrite=True)
477
- logger.info("Insight updated successfully: {}", file_number)
478
  return True
479
  except Exception as e:
480
- logger.error("Error while updating insight: {}", e)
481
  return False
482
 
483
  def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
@@ -590,44 +598,14 @@ def generate_sql(query, table_descriptions, table_details,selected_db):
590
  # logger.error("Error connecting to MySQL: {}", err)
591
  # return None
592
 
593
- def execute_sql(query, selected_db):
594
- update_config(selected_db)
595
- engine = create_sqlalchemy_engine()
596
- if engine:
597
- connection = engine.connect()
598
- logger.info(f"Connected to the database {selected_db}.")
599
- try:
600
- df = pd.read_sql_query(query, connection)
601
- logger.info("Query executed successfully.")
602
- return df
603
- except Exception as e:
604
- logger.error(f"Query execution failed: {e}")
605
- return pd.DataFrame()
606
- finally:
607
- connection.close()
608
- else:
609
- logger.error("Failed to create a SQLAlchemy engine.")
610
- return None
611
-
612
- # def execute_sql(query, selected_db, offset, limit=100):
613
  # update_config(selected_db)
614
  # engine = create_sqlalchemy_engine()
615
  # if engine:
616
  # connection = engine.connect()
617
  # logger.info(f"Connected to the database {selected_db}.")
618
  # try:
619
- # # Modify the query to use ROW_NUMBER() for pagination
620
- # paginated_query = f"""
621
- # WITH CTE AS (
622
- # SELECT *,
623
- # ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS RowNum
624
- # FROM ({query.rstrip(';')}) AS subquery
625
- # )
626
- # SELECT *
627
- # FROM CTE
628
- # WHERE RowNum BETWEEN {offset + 1} AND {offset + limit};
629
- # """
630
- # df = pd.read_sql_query(paginated_query, connection)
631
  # logger.info("Query executed successfully.")
632
  # return df
633
  # except Exception as e:
@@ -638,10 +616,23 @@ def execute_sql(query, selected_db):
638
  # else:
639
  # logger.error("Failed to create a SQLAlchemy engine.")
640
  # return None
641
-
642
- # def fetch_data(query, selected_db, offset, limit):
643
- # df = execute_sql(query, selected_db, offset, limit)
644
- # return drop_duplicate_columns(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
  def handle_retrieve_request(prompt):
647
  sql_generated = generate_sql(prompt, st.session_state['table_master'], st.session_state['table_details'], st.session_state['selected_db'])
@@ -830,12 +821,6 @@ def answer_guide_question(question, dframe, df_structure, selected_db):
830
  logger.debug("Code execution error state: {}", st.session_state['code_execution_error'])
831
  return result_df, last_method_num + 1, analysis_code
832
 
833
- # def get_metadata(table):
834
- # table_details = st.session_state['table_details'][table]
835
- # matadata = [[field, details[0], details[1]] for field, details in table_details.items()]
836
- # metadata_df = pd.DataFrame(matadata, columns=['Field Name', 'Field Description', 'Field Type'])
837
- # return metadata_df
838
-
839
  def generate_graph(query, df, df_structure,generate_graph):
840
  if query is None or df is None or df_structure is None:
841
  logger.error("generate_graph received None values for query, df, or df_structure")
@@ -974,79 +959,62 @@ def get_table_details(engine,selected_db):
974
  return tables_master_dict, tables_details_dict
975
 
976
  # Function to fetch database names from SQL Server
977
- def get_database_names():
978
- query = """
979
- SELECT name
980
- FROM sys.databases
981
- WHERE name NOT IN ('master', 'tempdb', 'model', 'msdb');
982
- """
983
- connection_string = (
984
- f"DRIVER={SQL_SERVER_CONFIG['driver']};"
985
- f"SERVER={SQL_SERVER_CONFIG['server']};"
986
- f"UID={SQL_SERVER_CONFIG['username']};" # Use SQL Server authentication username
987
- f"PWD={SQL_SERVER_CONFIG['password']}" # Use SQL Server authentication password
988
- )
989
- engine = create_engine(f"mssql+pyodbc:///?odbc_connect={connection_string}")
990
- try:
991
- with engine.connect() as conn:
992
- result = conn.execute(query)
993
- databases = [row['name'] for row in result]
994
- logger.info("Database names fetched successfully.")
995
- return databases
996
- except Exception as e:
997
- logger.error("Error fetching database names: {}", e)
998
- return []
999
-
1000
- def get_metadata(selected_table):
1001
- try:
1002
- metadata_df = pd.DataFrame(st.session_state['table_details'][selected_table])
1003
- logger.info("Metadata fetched for table: {}", selected_table)
1004
- return metadata_df
1005
- except Exception as e:
1006
- logger.error("Error fetching metadata for table {}: {}", selected_table, e)
1007
- return pd.DataFrame()
1008
-
1009
- # def load_data(sql_generated, selected_db):
1010
- # # Fetch data in chunks of 100 rows
1011
- # if 'offset' not in st.session_state:
1012
- # st.session_state['offset'] = 0
1013
-
1014
- # if 'data' not in st.session_state:
1015
- # st.session_state['data'] = pd.DataFrame() # Initialize as an empty DataFrame
1016
-
1017
- # new_data = fetch_data(sql_generated, selected_db, st.session_state['offset'], 100)
1018
- # if not new_data.empty:
1019
- # if st.session_state['offset'] == 0:
1020
- # st.session_state['data'] = new_data
1021
- # else:
1022
- # st.session_state['data'] = pd.concat([st.session_state['data'], new_data], ignore_index=True)
1023
-
1024
- # grid_options = get_ag_grid_options(st.session_state['data'])
1025
- # AgGrid(st.session_state['data'], gridOptions=grid_options, key=f'query_grid_{st.session_state["offset"]}', lazyloading=True)
1026
- # i=0
1027
- # if not new_data.empty :
1028
- # button_clicked = False
1029
- # while not button_clicked:
1030
- # i+=1
1031
- # if st.button('Load more', key=i):
1032
- # button_clicked = True
1033
- # st.write('Button clicked!')
1034
- # st.session_state['offset'] += 100
1035
- # load_data(sql_generated, selected_db)
1036
- # else:
1037
- # st.write('Waiting for button click...')
1038
- # time.sleep(1)
1039
- # # if st.button("Load more"):
1040
- # # logger.info(st.session_state['offset'])
1041
- # # logger.info("hi............................................................")
1042
- # # st.session_state['offset'] += 100
1043
- # # load_data(sql_generated, selected_db)
1044
- # # else:
1045
- # # logger.info(st.session_state['offset'])
1046
- # # logger.info("hi buttoon............................................................")
1047
- # else:
1048
- # logger.info(st.session_state['offset'])
1049
- # logger.info("hi next data............................................................")
1050
 
1051
  def compose_dataset():
1052
  if "messages" not in st.session_state:
@@ -1064,24 +1032,30 @@ def compose_dataset():
1064
  with col_cc:
1065
  st.markdown(APP_TITLE, unsafe_allow_html=True)
1066
 
1067
- databases = get_database_names()
1068
- selected_db = st.selectbox('Select Database:', [''] + databases)
 
1069
 
 
 
 
 
1070
  if selected_db:
1071
  if 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db:
1072
  # Clear session state data related to the previous database
1073
  st.session_state['messages'] = []
1074
  st.session_state['selected_table'] = None
1075
  logger.debug('Session state cleared due to database change.')
 
1076
 
1077
- update_config(selected_db)
1078
- engine = create_sqlalchemy_engine()
1079
 
1080
- if 'table_master' not in st.session_state or st.session_state.get('selected_db') != selected_db:
1081
- tables_master_dict, tables_details_dict = get_table_details(engine, selected_db)
1082
- st.session_state['table_master'] = tables_master_dict
1083
- st.session_state['table_details'] = tables_details_dict
1084
- st.session_state['selected_db'] = selected_db
1085
 
1086
  tables = list(st.session_state['table_master'].keys())
1087
  selected_table = st.selectbox('Tables available:', [''] + tables)
@@ -1095,7 +1069,7 @@ def compose_dataset():
1095
  st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
1096
  st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
1097
  logger.debug('Table metadata and description added to session state messages.')
1098
- # st.session_state.messages.append({"role": "assistant", "type": "text", "content": ""})
1099
  # display_paginated_dataframe(table_metadata_df, "table_metadata")
1100
  except Exception as e:
1101
  st.error("Please try again")
@@ -1166,16 +1140,7 @@ def compose_dataset():
1166
  st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
1167
  logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
1168
 
1169
- if 'graph_obj' not in st.session_state:
1170
- st.session_state['graph_obj'] = None
1171
- if 'graph_prompt' not in st.session_state:
1172
- st.session_state['graph_prompt'] = ''
1173
- if 'data_obj' not in st.session_state:
1174
- st.session_state['data_obj'] = None
1175
- if 'data_prompt' not in st.session_state:
1176
- st.session_state['data_prompt'] = ''
1177
- if 'code_execution_error' not in st.session_state:
1178
- st.session_state['code_execution_error'] = (None, None)
1179
 
1180
  def design_insight():
1181
  col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
@@ -1186,6 +1151,17 @@ def design_insight():
1186
  st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
1187
  with col_cc:
1188
  st.markdown(APP_TITLE, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
1189
 
1190
  get_saved_query_blob_list()
1191
  selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))
@@ -1272,7 +1248,7 @@ def design_insight():
1272
  if st.session_state['data_obj'] is not None:
1273
  display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
1274
  st.session_state['data_prompt'] = data_prompt
1275
-
1276
  with st.container():
1277
  st.subheader('Generate Graph')
1278
  graph_prompt_value = st.session_state.get('graph_prompt', '')
@@ -1296,7 +1272,10 @@ def design_insight():
1296
  logger.error("Error in generating graph: %s", e)
1297
  st.write("Error in generating graph, please try again")
1298
  else:
1299
- st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
 
 
 
1300
  st.session_state['graph_prompt'] = graph_prompt
1301
  else:
1302
  if st.session_state['graph_obj'] is not None:
@@ -1304,7 +1283,7 @@ def design_insight():
1304
  st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
1305
  except Exception as e:
1306
  st.write("Error in displaying graph, please try again")
1307
- logger.error("Error in displaying graph: %s", e)
1308
  with st.container():
1309
  if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
1310
  user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
@@ -1322,15 +1301,17 @@ def design_insight():
1322
  try:
1323
  result = get_existing_insight(base_code, user_persona)
1324
  if result:
1325
- existing_insight, file_number = result
1326
- existing_insight['prompt'][f'prompt_{len(existing_insight["prompt"]) + 1}'] = {
1327
- 'insight_prompt': insight_prompt,
1328
- 'insight_code': insight_code
1329
- }
1330
- existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
1331
- 'chart_prompt': chart_prompt,
1332
- 'chart_code': chart_code
1333
- }
 
 
1334
  try:
1335
  update_insight(existing_insight, user_persona, file_number)
1336
  st.text('Insight updated with new Graph and/or Data.')
 
1
  import json
2
+ import sqlite3
3
  # import pyodbc
4
  import mysql.connector
5
  import boto3
 
26
  from st_aggrid import AgGrid, GridOptionsBuilder
27
  from datetime import datetime
28
 
 
 
29
  # Initialize token storage
30
  token_file = "token_usage.json"
31
  if not os.path.exists(token_file):
 
62
  success_msg.empty()
63
 
64
  # Locations of various files
65
+ APP_TITLE = '**Social <br>Determinant<br>of Health**'
66
+
67
  sql_dir = 'generated_sql/'
68
  method_dir = 'generated_method/'
69
  insight_lib = 'insight_library/'
 
72
  connection_string = "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net"
73
  container_name = "insights-lab"
74
  persona_list = ["Population Analyst", "SDoH Specialist"]
75
+ DB_List=["Patient SDOH"]
76
+
77
 
78
  def getBlobContent(dir_path):
79
  try:
 
449
  insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
450
  try:
451
  blobs = container_client.list_blobs(name_starts_with=insights_directory)
452
+ for index, blob in enumerate(blobs):
453
+ # Skip the first item
454
+ if index == 0:
455
+ continue
456
  blob_name = blob.name # Extract the blob names
 
457
  file_name_with_extension = blob_name.split('/')[-1]
458
  file_name = file_name_with_extension.split('.')[0]
459
+
460
  blob_client = container_client.get_blob_client(blob_name)
461
  blob_content = blob_client.download_blob().readall()
462
+
463
  insight_data = json.loads(blob_content)
464
  if insight_data['base_code'] == base_code:
465
+ logger.info("Existing insight found for base code: %s", base_code)
466
  return insight_data, file_name
467
+ logger.info("No existing insight found for base code: %s", base_code)
468
+ return None
469
+ except json.JSONDecodeError as e:
470
+ logger.error("Error while retrieving insight: %s", e)
471
  return None
472
  except Exception as e:
473
+ logger.error("Error while retrieving insight: %s", e)
474
  return None
475
 
476
  def update_insight(insight_data, user_persona, file_number):
 
482
  file_path = f"{user_directory}/{file_number}.json"
483
  file_content = json.dumps(insight_data, indent=4)
484
  container_client.upload_blob(file_path, data=file_content, overwrite=True)
485
+ logger.info("Insight updated successfully: %s", file_number)
486
  return True
487
  except Exception as e:
488
+ logger.error("Error while updating insight: %s", e)
489
  return False
490
 
491
  def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
 
598
  # logger.error("Error connecting to MySQL: {}", err)
599
  # return None
600
 
601
+ # def execute_sql(query, selected_db):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  # update_config(selected_db)
603
  # engine = create_sqlalchemy_engine()
604
  # if engine:
605
  # connection = engine.connect()
606
  # logger.info(f"Connected to the database {selected_db}.")
607
  # try:
608
+ # df = pd.read_sql_query(query, connection)
 
 
 
 
 
 
 
 
 
 
 
609
  # logger.info("Query executed successfully.")
610
  # return df
611
  # except Exception as e:
 
616
  # else:
617
  # logger.error("Failed to create a SQLAlchemy engine.")
618
  # return None
619
+
620
def execute_sql(query,selected_db):
    """Run ``query`` against the SQLite database file ``selected_db``.

    Args:
        query: SQL text handed to ``cursor.execute``.
        selected_db: Path handed to ``sqlite3.connect``.

    Returns:
        A ``pandas.DataFrame`` holding every fetched row, with columns named
        after the cursor description, or ``None`` when a ``sqlite3.Error``
        was raised while connecting or querying.
    """
    df = None
    # Bind the name before the try block so the finally clause never touches
    # an unbound variable when sqlite3.connect() itself fails.
    conn = None
    try:
        conn = sqlite3.connect(selected_db)
        curr = conn.cursor()
        curr.execute(query)

        results = curr.fetchall()
        # cursor.description is None for statements that produce no result
        # set; fall back to an empty column list instead of raising TypeError.
        columns = [desc[0] for desc in (curr.description or [])]
        df = pd.DataFrame(results, columns=columns)
        logger.info("Query executed successfully.")
    except sqlite3.Error as e:
        logger.error(f"Error while querying the DB : {e}")
    finally:
        if conn is not None:
            conn.close()
    return df
636
 
637
  def handle_retrieve_request(prompt):
638
  sql_generated = generate_sql(prompt, st.session_state['table_master'], st.session_state['table_details'], st.session_state['selected_db'])
 
821
  logger.debug("Code execution error state: {}", st.session_state['code_execution_error'])
822
  return result_df, last_method_num + 1, analysis_code
823
 
 
 
 
 
 
 
824
  def generate_graph(query, df, df_structure,generate_graph):
825
  if query is None or df is None or df_structure is None:
826
  logger.error("generate_graph received None values for query, df, or df_structure")
 
959
  return tables_master_dict, tables_details_dict
960
 
961
  # Function to fetch database names from SQL Server
962
+ # def get_database_names():
963
+ # query = """
964
+ # SELECT name
965
+ # FROM sys.databases
966
+ # WHERE name NOT IN ('master', 'tempdb', 'model', 'msdb');
967
+ # """
968
+ # connection_string = (
969
+ # f"DRIVER={SQL_SERVER_CONFIG['driver']};"
970
+ # f"SERVER={SQL_SERVER_CONFIG['server']};"
971
+ # f"UID={SQL_SERVER_CONFIG['username']};" # Use SQL Server authentication username
972
+ # f"PWD={SQL_SERVER_CONFIG['password']}" # Use SQL Server authentication password
973
+ # )
974
+ # engine = create_engine(f"mssql+pyodbc:///?odbc_connect={connection_string}")
975
+ # try:
976
+ # with engine.connect() as conn:
977
+ # result = conn.execute(query)
978
+ # databases = [row['name'] for row in result]
979
+ # logger.info("Database names fetched successfully.")
980
+ # return databases
981
+ # except Exception as e:
982
+ # logger.error("Error fetching database names: {}", e)
983
+ # return []
984
+
985
+ # def get_metadata(selected_table):
986
+ # try:
987
+ # metadata_df = pd.DataFrame(st.session_state['table_details'][selected_table])
988
+ # logger.info("Metadata fetched for table: {}", selected_table)
989
+ # return metadata_df
990
+ # except Exception as e:
991
+ # logger.error("Error fetching metadata for table {}: {}", selected_table, e)
992
+ # return pd.DataFrame()
993
+
994
def get_metadata(table):
    """Return a DataFrame describing the fields of ``table``.

    Reads ``st.session_state['table_details'][table]``, a mapping of field
    name to a sequence whose first two items are used as the field's
    description and type, and lays it out one row per field.
    """
    field_specs = st.session_state['table_details'][table]
    rows = []
    for field_name, spec in field_specs.items():
        rows.append([field_name, spec[0], spec[1]])
    return pd.DataFrame(rows, columns=['Field Name', 'Field Description', 'Field Type'])
999
+
1000
def get_meta():
    """Load table metadata from ``./database/db_tables.json`` into session state.

    Does nothing beyond the first trace print when ``'table_master'`` is
    already present in ``st.session_state``. Otherwise it reads the JSON
    list of table entries and stores two dicts in session state:
    ``'table_master'`` (table name -> description) and ``'table_details'``
    (table name -> fields).
    """
    print("---------------step1 -------------------------")
    if 'table_master' not in st.session_state:
        # load db metadata file
        print("---------------step2 -------------------------")
        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 rather than the platform's locale encoding.
        with open('./database/db_tables.json', encoding='utf-8') as meta_file:
            db_js = json.load(meta_file)
        tables_master_dict = {}
        tables_details_dict = {}
        for j in db_js:
            tables_master_dict[j['name']] = j['description']
            tables_details_dict[j['name']] = j['fields']
        print(tables_details_dict)
        print(tables_master_dict)
        st.session_state['table_master'] = tables_master_dict
        st.session_state['table_details'] = tables_details_dict
    return
1016
+
1017
+ get_meta()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1018
 
1019
  def compose_dataset():
1020
  if "messages" not in st.session_state:
 
1032
  with col_cc:
1033
  st.markdown(APP_TITLE, unsafe_allow_html=True)
1034
 
1035
+ # databases = get_database_names()
1036
+ # selected_db = st.selectbox('Select Database:', [''] + databases)
1037
+ selected = st.selectbox('Select Database:', DB_List)
1038
 
1039
+ if selected == "Patient SDOH":
1040
+ selected_db = './gravity_sdoh_observations.db'
1041
+ st.session_state['selected_db'] = selected_db
1042
+
1043
  if selected_db:
1044
  if 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db:
1045
  # Clear session state data related to the previous database
1046
  st.session_state['messages'] = []
1047
  st.session_state['selected_table'] = None
1048
  logger.debug('Session state cleared due to database change.')
1049
+ st.session_state['selected_db'] = selected_db
1050
 
1051
+ # update_config(selected_db)
1052
+ # engine = create_sqlalchemy_engine()
1053
 
1054
+ # if 'table_master' not in st.session_state or st.session_state.get('selected_db') != selected_db:
1055
+ # tables_master_dict, tables_details_dict = get_table_details(engine, selected_db)
1056
+ # st.session_state['table_master'] = tables_master_dict
1057
+ # st.session_state['table_details'] = tables_details_dict
1058
+ # st.session_state['selected_db'] = selected_db
1059
 
1060
  tables = list(st.session_state['table_master'].keys())
1061
  selected_table = st.selectbox('Tables available:', [''] + tables)
 
1069
  st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
1070
  st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
1071
  logger.debug('Table metadata and description added to session state messages.')
1072
+ st.session_state.messages.append({"role": "assistant", "type": "text", "content": ""})
1073
  # display_paginated_dataframe(table_metadata_df, "table_metadata")
1074
  except Exception as e:
1075
  st.error("Please try again")
 
1140
  st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
1141
  logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
1142
 
1143
+
 
 
 
 
 
 
 
 
 
1144
 
1145
  def design_insight():
1146
  col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
 
1151
  st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
1152
  with col_cc:
1153
  st.markdown(APP_TITLE, unsafe_allow_html=True)
1154
+
1155
+ if 'graph_obj' not in st.session_state:
1156
+ st.session_state['graph_obj'] = None
1157
+ if 'graph_prompt' not in st.session_state:
1158
+ st.session_state['graph_prompt'] = ''
1159
+ if 'data_obj' not in st.session_state:
1160
+ st.session_state['data_obj'] = None
1161
+ if 'data_prompt' not in st.session_state:
1162
+ st.session_state['data_prompt'] = ''
1163
+ if 'code_execution_error' not in st.session_state:
1164
+ st.session_state['code_execution_error'] = (None, None)
1165
 
1166
  get_saved_query_blob_list()
1167
  selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))
 
1248
  if st.session_state['data_obj'] is not None:
1249
  display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
1250
  st.session_state['data_prompt'] = data_prompt
1251
+
1252
  with st.container():
1253
  st.subheader('Generate Graph')
1254
  graph_prompt_value = st.session_state.get('graph_prompt', '')
 
1272
  logger.error("Error in generating graph: %s", e)
1273
  st.write("Error in generating graph, please try again")
1274
  else:
1275
+ try:
1276
+ st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
1277
+ except Exception as e:
1278
+ st.write("Error in displaying graph, please try again")
1279
  st.session_state['graph_prompt'] = graph_prompt
1280
  else:
1281
  if st.session_state['graph_obj'] is not None:
 
1283
  st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
1284
  except Exception as e:
1285
  st.write("Error in displaying graph, please try again")
1286
+ logger.error("Error in displaying graph: %s", e)
1287
  with st.container():
1288
  if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
1289
  user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
 
1301
  try:
1302
  result = get_existing_insight(base_code, user_persona)
1303
  if result:
1304
+ existing_insight, file_number = result
1305
+ if insight_prompt and insight_code is not None:
1306
+ existing_insight['prompt'][f'prompt_{len(existing_insight["prompt"]) + 1}'] = {
1307
+ 'insight_prompt': insight_prompt,
1308
+ 'insight_code': insight_code
1309
+ }
1310
+ if chart_prompt and chart_code is not None:
1311
+ existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
1312
+ 'chart_prompt': chart_prompt,
1313
+ 'chart_code': chart_code
1314
+ }
1315
  try:
1316
  update_insight(existing_insight, user_persona, file_number)
1317
  st.text('Insight updated with new Graph and/or Data.')
requirements.txt CHANGED
@@ -6,6 +6,7 @@ altair==5.4.1
6
  reportlab==4.2.4
7
  streamlit_navigation_bar==3.3.0
8
  altair_saver==0.5.0
 
9
  plotly
10
  boto3
11
  azure.storage.blob
 
6
  reportlab==4.2.4
7
  streamlit_navigation_bar==3.3.0
8
  altair_saver==0.5.0
9
+ sqlite3
10
  plotly
11
  boto3
12
  azure.storage.blob