Spaces:
Sleeping
Sleeping
sqllite
Browse files- .gitattributes +2 -1
- database/.DS_Store +0 -0
- database/SQLAgent_DEMO_DB_V1.db +3 -0
- database/db_tables.json +71 -0
- database/gravity_sdoh.sqbpro +8 -0
- database/gravity_sdoh_observations.db +0 -0
- gravity_sdoh_observations.db +0 -0
- pages/__pycache__/solution.cpython-312.pyc +0 -0
- pages/solution.py +139 -158
- requirements.txt +1 -0
.gitattributes
CHANGED
|
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
database/SQLAgent_DEMO_DB_V1.db filter=lfs diff=lfs merge=lfs -text
|
database/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
database/SQLAgent_DEMO_DB_V1.db
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53cb4c4d3e88df7721e91ba4a4f5828e773168cbe4dc934810b57e4d179d264e
|
| 3 |
+
size 7061504
|
database/db_tables.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"name": "Patient",
|
| 4 |
+
"description": "The table stores the healthcare encounter information about patients. Each row has an unique patient information. The table contains the key information by distilling and flattening the FHIR encounter schema.",
|
| 5 |
+
"fields": { "identifier_value": [ "patient identifier that uniquely identifies patient and links a patient from this to other tables", "varchar"],
|
| 6 |
+
"identifier_use": [ "if the identifier is used for any specific purpose", "varchar" ],
|
| 7 |
+
"identifier_type": ["type of identifier, ususally means the source, MR' stands for medical record", "varchar" ],
|
| 8 |
+
"identifier_start_date": ["date on since when the identifier was valid", "date"],
|
| 9 |
+
"identifier_assigner": ["Identification value assignment authority", "varchar"],
|
| 10 |
+
"active": ["if he patient is active or not", "boolean"],
|
| 11 |
+
"official_name_family": ["family name of the patient", "varchar"],
|
| 12 |
+
"official_name_given": ["given name of the patient", "varchar"],
|
| 13 |
+
"usual_name_given": ["Short form of the given name", "varchar"],
|
| 14 |
+
"gender": ["patient's gender, male or female", "varchar"],
|
| 15 |
+
"birth_date": ["date of birth of the patient", "date"],
|
| 16 |
+
"Age": ["patient age", "integer"],
|
| 17 |
+
"home_address_line": ["patient's home address street", "varchar"],
|
| 18 |
+
"home_address_city": ["patient's home address city", "varchar"],
|
| 19 |
+
"home_address_district": ["patient's home county", "varchar"],
|
| 20 |
+
"home_address_state": ["patient's home state", "varchar"],
|
| 21 |
+
"home_address_postalCode": ["patient's home address zip code", "varchar"],
|
| 22 |
+
"home_address_period_start": ["start date of the patient's home address", "date"]
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"name": "Encounter",
|
| 27 |
+
"description": "Table that stores all encounters of each patient with the healthcare providers. Every row indicate a single encounter.",
|
| 28 |
+
"fields": { "id": [ "encounter id that identifies an encounter uniquely", "varchar"],
|
| 29 |
+
"status": [ "encounter status, can be one of 'planned', ''completed', 'discharged', 'in-progress' ", "varchar" ],
|
| 30 |
+
"class": [ "indicates location setting of the encounter, valid values are: 'IMP' as inpatient, 'EMER' as emergency, 'AMB' as ambulatory, 'HH' as home health ", "varchar" ],
|
| 31 |
+
"priority": [ "indicates priority of the encounter, valid values are: 'UR' as urgent, 'A' as As soon as, 'S' as stat, 'R' as routine ", "varchar" ],
|
| 32 |
+
"subject_id": [ "indicates id of the patient associated with the encounter, should match with identifier_value of the Patient table", "varchar" ],
|
| 33 |
+
"service_provider_id": [ "contains the id of the care delivery organization where the patient had the encounter", "varchar" ],
|
| 34 |
+
"participant_actor_id": [ "contains the id of the provider associated with the care delivery organization who rendered the encounter", "varchar" ],
|
| 35 |
+
"diagnosis_condition_id": [ "contains list of diagnosis codes relevant to the patient of the encounter", "varchar" ],
|
| 36 |
+
"location_id": [ "location where the encounter happend or is happening or will be happening", "varchar" ],
|
| 37 |
+
"discharge_disposition": [ "how the patient was discharged at the end of the encounter", "varchar" ],
|
| 38 |
+
"diagnosis_condition_text": [ "clinical description of the diagnosis codes", "varchar" ],
|
| 39 |
+
"condition_class": [ "condition of the patient classified into specific broad classe., may contain multiple coditions. All lower case.", "varchar" ]
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"name": "EpisodeOfCare",
|
| 44 |
+
"description": "contains continuous period of engagement by a care manager and/or a care management organization with the patient. Every row indicates a unique episide of care for a patient. One patient may have multiple episodes of care ",
|
| 45 |
+
"fields": { "identifier_value": [ "unique identifier of the episode", "varchar" ],
|
| 46 |
+
"type": [ "type of episode, can be disease management, post acute care or specialist referral", "varchar" ],
|
| 47 |
+
"diagnosis_condition_id": [ "ICD-10 diagnosis code assiciated with the episode of care", "varchar" ],
|
| 48 |
+
"subject_id": [ "id of the patient associated with episode, should have a corresponding 'identifier_value' in the Patient table", "varchar" ],
|
| 49 |
+
"managing_organization_id": [ "contains the id of the organization managing the episode", "varchar" ],
|
| 50 |
+
"care_manager_id": [ "contains the id of the care manager managing the episode", "varchar" ],
|
| 51 |
+
"care_team_id": [ "contains the id of the care team managing the episode. Care manager is part of the care team", "varchar" ]
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"name": "RiskScore",
|
| 56 |
+
"description": "Contains the health risk scores of each of the patients. Only the latest risk score is stored. Every row has risk score of an unique patient",
|
| 57 |
+
"fields": { "patient_id": [ "identifier that uniquely identifies a patient. Matches with at least one identifier_value of Patient table.", "varchar"],
|
| 58 |
+
"risk_score": [ "decimal number between 0 and 1 indicating the risk score", "decimal number" ]
|
| 59 |
+
}
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"name": "patient_sdoh_scores",
|
| 63 |
+
"description": "table stores the various social determinants of quality scores about a patient obtained through assessment. Each row indicate score about one patient and about one type of assessment",
|
| 64 |
+
"fields": { "Patient_Id": [ "unique identifier of the patient. Matches with at least one identifier_value of Patient table.", "varchar"],
|
| 65 |
+
"Assessment_Id": [ "name of the assessment", "varchar" ],
|
| 66 |
+
"Answer": [ "The actual answer provided in the assessment", "integer" ],
|
| 67 |
+
"Assessment_Type": [ "type of the assessment, can be 'Financial', 'Home', 'Food' and 'Physical'", "varchar" ],
|
| 68 |
+
"score": [ "Derived standardized score based on the answer provided", "decimal number" ]
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
]
|
database/gravity_sdoh.sqbpro
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?><sqlb_project><db path="/Users/niloy/Documents/GitHub/SQL Agent Demo/sdoh/gravity_sdoh_observations.db" readonly="0" foreign_keys="1" case_sensitive_like="0" temp_store="0" wal_autocheckpoint="1000" synchronous="2"/><attached/><window><main_tabs open="structure browser pragmas query" current="0"/></window><tab_structure><column_width id="0" width="300"/><column_width id="1" width="0"/><column_width id="2" width="100"/><column_width id="3" width="10670"/><column_width id="4" width="0"/><expanded_item id="0" parent="1"/><expanded_item id="1" parent="1"/><expanded_item id="2" parent="1"/><expanded_item id="3" parent="1"/></tab_structure><tab_browse><current_table name="4,7:mainPatient"/><default_encoding codec=""/><browse_table_settings/></tab_browse><tab_sql><sql name="SQL 1">CREATE TABLE RiskScore (
|
| 2 |
+
patient_id TEXT, -- Reference to the patient
|
| 3 |
+
risk_score DECIMAL(4, 2), -- Risk score, a decimal value between 0.0 and 1.0
|
| 4 |
+
risk_score_date TEXT, -- Date of the risk score
|
| 5 |
+
|
| 6 |
+
PRIMARY KEY (patient_id, risk_score_date) -- Composite primary key: patient ID and risk score date
|
| 7 |
+
);
|
| 8 |
+
</sql><current_tab id="0"/></tab_sql></sqlb_project>
|
database/gravity_sdoh_observations.db
ADDED
|
Binary file (86 kB). View file
|
|
|
gravity_sdoh_observations.db
ADDED
|
Binary file (86 kB). View file
|
|
|
pages/__pycache__/solution.cpython-312.pyc
CHANGED
|
Binary files a/pages/__pycache__/solution.cpython-312.pyc and b/pages/__pycache__/solution.cpython-312.pyc differ
|
|
|
pages/solution.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import json
|
| 2 |
-
|
| 3 |
# import pyodbc
|
| 4 |
import mysql.connector
|
| 5 |
import boto3
|
|
@@ -26,8 +26,6 @@ from loguru import logger
|
|
| 26 |
from st_aggrid import AgGrid, GridOptionsBuilder
|
| 27 |
from datetime import datetime
|
| 28 |
|
| 29 |
-
APP_TITLE = '**Social <br>Determinant<br>of Health**'
|
| 30 |
-
|
| 31 |
# Initialize token storage
|
| 32 |
token_file = "token_usage.json"
|
| 33 |
if not os.path.exists(token_file):
|
|
@@ -64,6 +62,8 @@ def show_messages(message):
|
|
| 64 |
success_msg.empty()
|
| 65 |
|
| 66 |
# Locations of various files
|
|
|
|
|
|
|
| 67 |
sql_dir = 'generated_sql/'
|
| 68 |
method_dir = 'generated_method/'
|
| 69 |
insight_lib = 'insight_library/'
|
|
@@ -72,6 +72,8 @@ report_path = 'Reports/'
|
|
| 72 |
connection_string = "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net"
|
| 73 |
container_name = "insights-lab"
|
| 74 |
persona_list = ["Population Analyst", "SDoH Specialist"]
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def getBlobContent(dir_path):
|
| 77 |
try:
|
|
@@ -447,22 +449,28 @@ def get_existing_insight(base_code, user_persona):
|
|
| 447 |
insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
|
| 448 |
try:
|
| 449 |
blobs = container_client.list_blobs(name_starts_with=insights_directory)
|
| 450 |
-
for blob in blobs:
|
|
|
|
|
|
|
|
|
|
| 451 |
blob_name = blob.name # Extract the blob names
|
| 452 |
-
print(blob_name)
|
| 453 |
file_name_with_extension = blob_name.split('/')[-1]
|
| 454 |
file_name = file_name_with_extension.split('.')[0]
|
|
|
|
| 455 |
blob_client = container_client.get_blob_client(blob_name)
|
| 456 |
blob_content = blob_client.download_blob().readall()
|
| 457 |
-
|
| 458 |
insight_data = json.loads(blob_content)
|
| 459 |
if insight_data['base_code'] == base_code:
|
| 460 |
-
logger.info("Existing insight found for base code:
|
| 461 |
return insight_data, file_name
|
| 462 |
-
logger.info("No existing insight found for base code:
|
|
|
|
|
|
|
|
|
|
| 463 |
return None
|
| 464 |
except Exception as e:
|
| 465 |
-
logger.error("Error while retrieving insight:
|
| 466 |
return None
|
| 467 |
|
| 468 |
def update_insight(insight_data, user_persona, file_number):
|
|
@@ -474,10 +482,10 @@ def update_insight(insight_data, user_persona, file_number):
|
|
| 474 |
file_path = f"{user_directory}/{file_number}.json"
|
| 475 |
file_content = json.dumps(insight_data, indent=4)
|
| 476 |
container_client.upload_blob(file_path, data=file_content, overwrite=True)
|
| 477 |
-
logger.info("Insight updated successfully:
|
| 478 |
return True
|
| 479 |
except Exception as e:
|
| 480 |
-
logger.error("Error while updating insight:
|
| 481 |
return False
|
| 482 |
|
| 483 |
def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
|
|
@@ -590,44 +598,14 @@ def generate_sql(query, table_descriptions, table_details,selected_db):
|
|
| 590 |
# logger.error("Error connecting to MySQL: {}", err)
|
| 591 |
# return None
|
| 592 |
|
| 593 |
-
def execute_sql(query, selected_db):
|
| 594 |
-
update_config(selected_db)
|
| 595 |
-
engine = create_sqlalchemy_engine()
|
| 596 |
-
if engine:
|
| 597 |
-
connection = engine.connect()
|
| 598 |
-
logger.info(f"Connected to the database {selected_db}.")
|
| 599 |
-
try:
|
| 600 |
-
df = pd.read_sql_query(query, connection)
|
| 601 |
-
logger.info("Query executed successfully.")
|
| 602 |
-
return df
|
| 603 |
-
except Exception as e:
|
| 604 |
-
logger.error(f"Query execution failed: {e}")
|
| 605 |
-
return pd.DataFrame()
|
| 606 |
-
finally:
|
| 607 |
-
connection.close()
|
| 608 |
-
else:
|
| 609 |
-
logger.error("Failed to create a SQLAlchemy engine.")
|
| 610 |
-
return None
|
| 611 |
-
|
| 612 |
-
# def execute_sql(query, selected_db, offset, limit=100):
|
| 613 |
# update_config(selected_db)
|
| 614 |
# engine = create_sqlalchemy_engine()
|
| 615 |
# if engine:
|
| 616 |
# connection = engine.connect()
|
| 617 |
# logger.info(f"Connected to the database {selected_db}.")
|
| 618 |
# try:
|
| 619 |
-
#
|
| 620 |
-
# paginated_query = f"""
|
| 621 |
-
# WITH CTE AS (
|
| 622 |
-
# SELECT *,
|
| 623 |
-
# ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS RowNum
|
| 624 |
-
# FROM ({query.rstrip(';')}) AS subquery
|
| 625 |
-
# )
|
| 626 |
-
# SELECT *
|
| 627 |
-
# FROM CTE
|
| 628 |
-
# WHERE RowNum BETWEEN {offset + 1} AND {offset + limit};
|
| 629 |
-
# """
|
| 630 |
-
# df = pd.read_sql_query(paginated_query, connection)
|
| 631 |
# logger.info("Query executed successfully.")
|
| 632 |
# return df
|
| 633 |
# except Exception as e:
|
|
@@ -638,10 +616,23 @@ def execute_sql(query, selected_db):
|
|
| 638 |
# else:
|
| 639 |
# logger.error("Failed to create a SQLAlchemy engine.")
|
| 640 |
# return None
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
def handle_retrieve_request(prompt):
|
| 647 |
sql_generated = generate_sql(prompt, st.session_state['table_master'], st.session_state['table_details'], st.session_state['selected_db'])
|
|
@@ -830,12 +821,6 @@ def answer_guide_question(question, dframe, df_structure, selected_db):
|
|
| 830 |
logger.debug("Code execution error state: {}", st.session_state['code_execution_error'])
|
| 831 |
return result_df, last_method_num + 1, analysis_code
|
| 832 |
|
| 833 |
-
# def get_metadata(table):
|
| 834 |
-
# table_details = st.session_state['table_details'][table]
|
| 835 |
-
# matadata = [[field, details[0], details[1]] for field, details in table_details.items()]
|
| 836 |
-
# metadata_df = pd.DataFrame(matadata, columns=['Field Name', 'Field Description', 'Field Type'])
|
| 837 |
-
# return metadata_df
|
| 838 |
-
|
| 839 |
def generate_graph(query, df, df_structure,generate_graph):
|
| 840 |
if query is None or df is None or df_structure is None:
|
| 841 |
logger.error("generate_graph received None values for query, df, or df_structure")
|
|
@@ -974,79 +959,62 @@ def get_table_details(engine,selected_db):
|
|
| 974 |
return tables_master_dict, tables_details_dict
|
| 975 |
|
| 976 |
# Function to fetch database names from SQL Server
|
| 977 |
-
def get_database_names():
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
def get_metadata(selected_table):
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
#
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
# st.write('Button clicked!')
|
| 1034 |
-
# st.session_state['offset'] += 100
|
| 1035 |
-
# load_data(sql_generated, selected_db)
|
| 1036 |
-
# else:
|
| 1037 |
-
# st.write('Waiting for button click...')
|
| 1038 |
-
# time.sleep(1)
|
| 1039 |
-
# # if st.button("Load more"):
|
| 1040 |
-
# # logger.info(st.session_state['offset'])
|
| 1041 |
-
# # logger.info("hi............................................................")
|
| 1042 |
-
# # st.session_state['offset'] += 100
|
| 1043 |
-
# # load_data(sql_generated, selected_db)
|
| 1044 |
-
# # else:
|
| 1045 |
-
# # logger.info(st.session_state['offset'])
|
| 1046 |
-
# # logger.info("hi buttoon............................................................")
|
| 1047 |
-
# else:
|
| 1048 |
-
# logger.info(st.session_state['offset'])
|
| 1049 |
-
# logger.info("hi next data............................................................")
|
| 1050 |
|
| 1051 |
def compose_dataset():
|
| 1052 |
if "messages" not in st.session_state:
|
|
@@ -1064,24 +1032,30 @@ def compose_dataset():
|
|
| 1064 |
with col_cc:
|
| 1065 |
st.markdown(APP_TITLE, unsafe_allow_html=True)
|
| 1066 |
|
| 1067 |
-
databases = get_database_names()
|
| 1068 |
-
selected_db = st.selectbox('Select Database:', [''] + databases)
|
|
|
|
| 1069 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1070 |
if selected_db:
|
| 1071 |
if 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db:
|
| 1072 |
# Clear session state data related to the previous database
|
| 1073 |
st.session_state['messages'] = []
|
| 1074 |
st.session_state['selected_table'] = None
|
| 1075 |
logger.debug('Session state cleared due to database change.')
|
|
|
|
| 1076 |
|
| 1077 |
-
update_config(selected_db)
|
| 1078 |
-
engine = create_sqlalchemy_engine()
|
| 1079 |
|
| 1080 |
-
if 'table_master' not in st.session_state or st.session_state.get('selected_db') != selected_db:
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
|
| 1086 |
tables = list(st.session_state['table_master'].keys())
|
| 1087 |
selected_table = st.selectbox('Tables available:', [''] + tables)
|
|
@@ -1095,7 +1069,7 @@ def compose_dataset():
|
|
| 1095 |
st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
|
| 1096 |
st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
|
| 1097 |
logger.debug('Table metadata and description added to session state messages.')
|
| 1098 |
-
|
| 1099 |
# display_paginated_dataframe(table_metadata_df, "table_metadata")
|
| 1100 |
except Exception as e:
|
| 1101 |
st.error("Please try again")
|
|
@@ -1166,16 +1140,7 @@ def compose_dataset():
|
|
| 1166 |
st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
|
| 1167 |
logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
|
| 1168 |
|
| 1169 |
-
|
| 1170 |
-
st.session_state['graph_obj'] = None
|
| 1171 |
-
if 'graph_prompt' not in st.session_state:
|
| 1172 |
-
st.session_state['graph_prompt'] = ''
|
| 1173 |
-
if 'data_obj' not in st.session_state:
|
| 1174 |
-
st.session_state['data_obj'] = None
|
| 1175 |
-
if 'data_prompt' not in st.session_state:
|
| 1176 |
-
st.session_state['data_prompt'] = ''
|
| 1177 |
-
if 'code_execution_error' not in st.session_state:
|
| 1178 |
-
st.session_state['code_execution_error'] = (None, None)
|
| 1179 |
|
| 1180 |
def design_insight():
|
| 1181 |
col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
|
|
@@ -1186,6 +1151,17 @@ def design_insight():
|
|
| 1186 |
st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
|
| 1187 |
with col_cc:
|
| 1188 |
st.markdown(APP_TITLE, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1189 |
|
| 1190 |
get_saved_query_blob_list()
|
| 1191 |
selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))
|
|
@@ -1272,7 +1248,7 @@ def design_insight():
|
|
| 1272 |
if st.session_state['data_obj'] is not None:
|
| 1273 |
display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
|
| 1274 |
st.session_state['data_prompt'] = data_prompt
|
| 1275 |
-
|
| 1276 |
with st.container():
|
| 1277 |
st.subheader('Generate Graph')
|
| 1278 |
graph_prompt_value = st.session_state.get('graph_prompt', '')
|
|
@@ -1296,7 +1272,10 @@ def design_insight():
|
|
| 1296 |
logger.error("Error in generating graph: %s", e)
|
| 1297 |
st.write("Error in generating graph, please try again")
|
| 1298 |
else:
|
| 1299 |
-
|
|
|
|
|
|
|
|
|
|
| 1300 |
st.session_state['graph_prompt'] = graph_prompt
|
| 1301 |
else:
|
| 1302 |
if st.session_state['graph_obj'] is not None:
|
|
@@ -1304,7 +1283,7 @@ def design_insight():
|
|
| 1304 |
st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
|
| 1305 |
except Exception as e:
|
| 1306 |
st.write("Error in displaying graph, please try again")
|
| 1307 |
-
logger.error("Error in displaying graph: %s", e)
|
| 1308 |
with st.container():
|
| 1309 |
if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
|
| 1310 |
user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
|
|
@@ -1322,15 +1301,17 @@ def design_insight():
|
|
| 1322 |
try:
|
| 1323 |
result = get_existing_insight(base_code, user_persona)
|
| 1324 |
if result:
|
| 1325 |
-
existing_insight, file_number = result
|
| 1326 |
-
|
| 1327 |
-
'
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
'
|
| 1333 |
-
|
|
|
|
|
|
|
| 1334 |
try:
|
| 1335 |
update_insight(existing_insight, user_persona, file_number)
|
| 1336 |
st.text('Insight updated with new Graph and/or Data.')
|
|
|
|
| 1 |
import json
|
| 2 |
+
import sqlite3
|
| 3 |
# import pyodbc
|
| 4 |
import mysql.connector
|
| 5 |
import boto3
|
|
|
|
| 26 |
from st_aggrid import AgGrid, GridOptionsBuilder
|
| 27 |
from datetime import datetime
|
| 28 |
|
|
|
|
|
|
|
| 29 |
# Initialize token storage
|
| 30 |
token_file = "token_usage.json"
|
| 31 |
if not os.path.exists(token_file):
|
|
|
|
| 62 |
success_msg.empty()
|
| 63 |
|
| 64 |
# Locations of various files
|
| 65 |
+
APP_TITLE = '**Social <br>Determinant<br>of Health**'
|
| 66 |
+
|
| 67 |
sql_dir = 'generated_sql/'
|
| 68 |
method_dir = 'generated_method/'
|
| 69 |
insight_lib = 'insight_library/'
|
|
|
|
| 72 |
connection_string = "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net"
|
| 73 |
container_name = "insights-lab"
|
| 74 |
persona_list = ["Population Analyst", "SDoH Specialist"]
|
| 75 |
+
DB_List=["Patient SDOH"]
|
| 76 |
+
|
| 77 |
|
| 78 |
def getBlobContent(dir_path):
|
| 79 |
try:
|
|
|
|
| 449 |
insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
|
| 450 |
try:
|
| 451 |
blobs = container_client.list_blobs(name_starts_with=insights_directory)
|
| 452 |
+
for index, blob in enumerate(blobs):
|
| 453 |
+
# Skip the first item
|
| 454 |
+
if index == 0:
|
| 455 |
+
continue
|
| 456 |
blob_name = blob.name # Extract the blob names
|
|
|
|
| 457 |
file_name_with_extension = blob_name.split('/')[-1]
|
| 458 |
file_name = file_name_with_extension.split('.')[0]
|
| 459 |
+
|
| 460 |
blob_client = container_client.get_blob_client(blob_name)
|
| 461 |
blob_content = blob_client.download_blob().readall()
|
| 462 |
+
|
| 463 |
insight_data = json.loads(blob_content)
|
| 464 |
if insight_data['base_code'] == base_code:
|
| 465 |
+
logger.info("Existing insight found for base code: %s", base_code)
|
| 466 |
return insight_data, file_name
|
| 467 |
+
logger.info("No existing insight found for base code: %s", base_code)
|
| 468 |
+
return None
|
| 469 |
+
except json.JSONDecodeError as e:
|
| 470 |
+
logger.error("Error while retrieving insight: %s", e)
|
| 471 |
return None
|
| 472 |
except Exception as e:
|
| 473 |
+
logger.error("Error while retrieving insight: %s", e)
|
| 474 |
return None
|
| 475 |
|
| 476 |
def update_insight(insight_data, user_persona, file_number):
|
|
|
|
| 482 |
file_path = f"{user_directory}/{file_number}.json"
|
| 483 |
file_content = json.dumps(insight_data, indent=4)
|
| 484 |
container_client.upload_blob(file_path, data=file_content, overwrite=True)
|
| 485 |
+
logger.info("Insight updated successfully: %s", file_number)
|
| 486 |
return True
|
| 487 |
except Exception as e:
|
| 488 |
+
logger.error("Error while updating insight: %s", e)
|
| 489 |
return False
|
| 490 |
|
| 491 |
def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
|
|
|
|
| 598 |
# logger.error("Error connecting to MySQL: {}", err)
|
| 599 |
# return None
|
| 600 |
|
| 601 |
+
# def execute_sql(query, selected_db):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
# update_config(selected_db)
|
| 603 |
# engine = create_sqlalchemy_engine()
|
| 604 |
# if engine:
|
| 605 |
# connection = engine.connect()
|
| 606 |
# logger.info(f"Connected to the database {selected_db}.")
|
| 607 |
# try:
|
| 608 |
+
# df = pd.read_sql_query(query, connection)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
# logger.info("Query executed successfully.")
|
| 610 |
# return df
|
| 611 |
# except Exception as e:
|
|
|
|
| 616 |
# else:
|
| 617 |
# logger.error("Failed to create a SQLAlchemy engine.")
|
| 618 |
# return None
|
| 619 |
+
|
| 620 |
+
def execute_sql(query,selected_db):
|
| 621 |
+
df = None
|
| 622 |
+
try:
|
| 623 |
+
conn = sqlite3.connect(selected_db)
|
| 624 |
+
curr = conn.cursor()
|
| 625 |
+
curr.execute(query)
|
| 626 |
+
|
| 627 |
+
results = curr.fetchall()
|
| 628 |
+
columns = [desc[0] for desc in curr.description]
|
| 629 |
+
df = pd.DataFrame(results, columns=columns).copy()
|
| 630 |
+
logger.info("Query executed successfully.")
|
| 631 |
+
except sqlite3.Error as e:
|
| 632 |
+
logger.error(f"Error while querying the DB : {e}")
|
| 633 |
+
finally:
|
| 634 |
+
conn.close()
|
| 635 |
+
return df
|
| 636 |
|
| 637 |
def handle_retrieve_request(prompt):
|
| 638 |
sql_generated = generate_sql(prompt, st.session_state['table_master'], st.session_state['table_details'], st.session_state['selected_db'])
|
|
|
|
| 821 |
logger.debug("Code execution error state: {}", st.session_state['code_execution_error'])
|
| 822 |
return result_df, last_method_num + 1, analysis_code
|
| 823 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
def generate_graph(query, df, df_structure,generate_graph):
|
| 825 |
if query is None or df is None or df_structure is None:
|
| 826 |
logger.error("generate_graph received None values for query, df, or df_structure")
|
|
|
|
| 959 |
return tables_master_dict, tables_details_dict
|
| 960 |
|
| 961 |
# Function to fetch database names from SQL Server
|
| 962 |
+
# def get_database_names():
|
| 963 |
+
# query = """
|
| 964 |
+
# SELECT name
|
| 965 |
+
# FROM sys.databases
|
| 966 |
+
# WHERE name NOT IN ('master', 'tempdb', 'model', 'msdb');
|
| 967 |
+
# """
|
| 968 |
+
# connection_string = (
|
| 969 |
+
# f"DRIVER={SQL_SERVER_CONFIG['driver']};"
|
| 970 |
+
# f"SERVER={SQL_SERVER_CONFIG['server']};"
|
| 971 |
+
# f"UID={SQL_SERVER_CONFIG['username']};" # Use SQL Server authentication username
|
| 972 |
+
# f"PWD={SQL_SERVER_CONFIG['password']}" # Use SQL Server authentication password
|
| 973 |
+
# )
|
| 974 |
+
# engine = create_engine(f"mssql+pyodbc:///?odbc_connect={connection_string}")
|
| 975 |
+
# try:
|
| 976 |
+
# with engine.connect() as conn:
|
| 977 |
+
# result = conn.execute(query)
|
| 978 |
+
# databases = [row['name'] for row in result]
|
| 979 |
+
# logger.info("Database names fetched successfully.")
|
| 980 |
+
# return databases
|
| 981 |
+
# except Exception as e:
|
| 982 |
+
# logger.error("Error fetching database names: {}", e)
|
| 983 |
+
# return []
|
| 984 |
+
|
| 985 |
+
# def get_metadata(selected_table):
|
| 986 |
+
# try:
|
| 987 |
+
# metadata_df = pd.DataFrame(st.session_state['table_details'][selected_table])
|
| 988 |
+
# logger.info("Metadata fetched for table: {}", selected_table)
|
| 989 |
+
# return metadata_df
|
| 990 |
+
# except Exception as e:
|
| 991 |
+
# logger.error("Error fetching metadata for table {}: {}", selected_table, e)
|
| 992 |
+
# return pd.DataFrame()
|
| 993 |
+
|
| 994 |
+
def get_metadata(table):
|
| 995 |
+
table_details = st.session_state['table_details'][table]
|
| 996 |
+
matadata = [[field, details[0], details[1]] for field, details in table_details.items()]
|
| 997 |
+
metadata_df = pd.DataFrame(matadata, columns=['Field Name', 'Field Description', 'Field Type'])
|
| 998 |
+
return metadata_df
|
| 999 |
+
|
| 1000 |
+
def get_meta():
|
| 1001 |
+
print("---------------step1 -------------------------")
|
| 1002 |
+
if 'table_master' not in st.session_state:
|
| 1003 |
+
# load db metadata file
|
| 1004 |
+
print("---------------step2 -------------------------")
|
| 1005 |
+
db_js = json.load(open('./database/db_tables.json'))
|
| 1006 |
+
tables_master_dict = {}
|
| 1007 |
+
tables_details_dict = {}
|
| 1008 |
+
for j in db_js:
|
| 1009 |
+
tables_master_dict[j['name']] = j['description']
|
| 1010 |
+
tables_details_dict[j['name']] = j['fields']
|
| 1011 |
+
print(tables_details_dict)
|
| 1012 |
+
print(tables_master_dict)
|
| 1013 |
+
st.session_state['table_master'] = tables_master_dict
|
| 1014 |
+
st.session_state['table_details'] = tables_details_dict
|
| 1015 |
+
return
|
| 1016 |
+
|
| 1017 |
+
get_meta()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1018 |
|
| 1019 |
def compose_dataset():
|
| 1020 |
if "messages" not in st.session_state:
|
|
|
|
| 1032 |
with col_cc:
|
| 1033 |
st.markdown(APP_TITLE, unsafe_allow_html=True)
|
| 1034 |
|
| 1035 |
+
# databases = get_database_names()
|
| 1036 |
+
# selected_db = st.selectbox('Select Database:', [''] + databases)
|
| 1037 |
+
selected = st.selectbox('Select Database:', DB_List)
|
| 1038 |
|
| 1039 |
+
if selected == "Patient SDOH":
|
| 1040 |
+
selected_db = './gravity_sdoh_observations.db'
|
| 1041 |
+
st.session_state['selected_db'] = selected_db
|
| 1042 |
+
|
| 1043 |
if selected_db:
|
| 1044 |
if 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db:
|
| 1045 |
# Clear session state data related to the previous database
|
| 1046 |
st.session_state['messages'] = []
|
| 1047 |
st.session_state['selected_table'] = None
|
| 1048 |
logger.debug('Session state cleared due to database change.')
|
| 1049 |
+
st.session_state['selected_db'] = selected_db
|
| 1050 |
|
| 1051 |
+
# update_config(selected_db)
|
| 1052 |
+
# engine = create_sqlalchemy_engine()
|
| 1053 |
|
| 1054 |
+
# if 'table_master' not in st.session_state or st.session_state.get('selected_db') != selected_db:
|
| 1055 |
+
# tables_master_dict, tables_details_dict = get_table_details(engine, selected_db)
|
| 1056 |
+
# st.session_state['table_master'] = tables_master_dict
|
| 1057 |
+
# st.session_state['table_details'] = tables_details_dict
|
| 1058 |
+
# st.session_state['selected_db'] = selected_db
|
| 1059 |
|
| 1060 |
tables = list(st.session_state['table_master'].keys())
|
| 1061 |
selected_table = st.selectbox('Tables available:', [''] + tables)
|
|
|
|
| 1069 |
st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
|
| 1070 |
st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
|
| 1071 |
logger.debug('Table metadata and description added to session state messages.')
|
| 1072 |
+
st.session_state.messages.append({"role": "assistant", "type": "text", "content": ""})
|
| 1073 |
# display_paginated_dataframe(table_metadata_df, "table_metadata")
|
| 1074 |
except Exception as e:
|
| 1075 |
st.error("Please try again")
|
|
|
|
| 1140 |
st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
|
| 1141 |
logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
|
| 1142 |
|
| 1143 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1144 |
|
| 1145 |
def design_insight():
|
| 1146 |
col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
|
|
|
|
| 1151 |
st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
|
| 1152 |
with col_cc:
|
| 1153 |
st.markdown(APP_TITLE, unsafe_allow_html=True)
|
| 1154 |
+
|
| 1155 |
+
if 'graph_obj' not in st.session_state:
|
| 1156 |
+
st.session_state['graph_obj'] = None
|
| 1157 |
+
if 'graph_prompt' not in st.session_state:
|
| 1158 |
+
st.session_state['graph_prompt'] = ''
|
| 1159 |
+
if 'data_obj' not in st.session_state:
|
| 1160 |
+
st.session_state['data_obj'] = None
|
| 1161 |
+
if 'data_prompt' not in st.session_state:
|
| 1162 |
+
st.session_state['data_prompt'] = ''
|
| 1163 |
+
if 'code_execution_error' not in st.session_state:
|
| 1164 |
+
st.session_state['code_execution_error'] = (None, None)
|
| 1165 |
|
| 1166 |
get_saved_query_blob_list()
|
| 1167 |
selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))
|
|
|
|
| 1248 |
if st.session_state['data_obj'] is not None:
|
| 1249 |
display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
|
| 1250 |
st.session_state['data_prompt'] = data_prompt
|
| 1251 |
+
|
| 1252 |
with st.container():
|
| 1253 |
st.subheader('Generate Graph')
|
| 1254 |
graph_prompt_value = st.session_state.get('graph_prompt', '')
|
|
|
|
| 1272 |
logger.error("Error in generating graph: %s", e)
|
| 1273 |
st.write("Error in generating graph, please try again")
|
| 1274 |
else:
|
| 1275 |
+
try:
|
| 1276 |
+
st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
|
| 1277 |
+
except Exception as e:
|
| 1278 |
+
st.write("Error in displaying graph, please try again")
|
| 1279 |
st.session_state['graph_prompt'] = graph_prompt
|
| 1280 |
else:
|
| 1281 |
if st.session_state['graph_obj'] is not None:
|
|
|
|
| 1283 |
st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
|
| 1284 |
except Exception as e:
|
| 1285 |
st.write("Error in displaying graph, please try again")
|
| 1286 |
+
logger.error("Error in displaying graph: %s", e)
|
| 1287 |
with st.container():
|
| 1288 |
if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
|
| 1289 |
user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
|
|
|
|
| 1301 |
try:
|
| 1302 |
result = get_existing_insight(base_code, user_persona)
|
| 1303 |
if result:
|
| 1304 |
+
existing_insight, file_number = result
|
| 1305 |
+
if insight_prompt and insight_code is not None:
|
| 1306 |
+
existing_insight['prompt'][f'prompt_{len(existing_insight["prompt"]) + 1}'] = {
|
| 1307 |
+
'insight_prompt': insight_prompt,
|
| 1308 |
+
'insight_code': insight_code
|
| 1309 |
+
}
|
| 1310 |
+
if chart_prompt and chart_code is not None:
|
| 1311 |
+
existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
|
| 1312 |
+
'chart_prompt': chart_prompt,
|
| 1313 |
+
'chart_code': chart_code
|
| 1314 |
+
}
|
| 1315 |
try:
|
| 1316 |
update_insight(existing_insight, user_persona, file_number)
|
| 1317 |
st.text('Insight updated with new Graph and/or Data.')
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@ altair==5.4.1
|
|
| 6 |
reportlab==4.2.4
|
| 7 |
streamlit_navigation_bar==3.3.0
|
| 8 |
altair_saver==0.5.0
|
|
|
|
| 9 |
plotly
|
| 10 |
boto3
|
| 11 |
azure.storage.blob
|
|
|
|
| 6 |
reportlab==4.2.4
|
| 7 |
streamlit_navigation_bar==3.3.0
|
| 8 |
altair_saver==0.5.0
|
| 9 |
+
# sqlite3 is part of the Python standard library; it is not a pip-installable package and must not be listed as a requirement
|
| 10 |
plotly
|
| 11 |
boto3
|
| 12 |
azure.storage.blob
|