Maheshsr committed on
Commit
8ee72dd
·
1 Parent(s): a09fb49

removing src

Browse files
Logger.txt ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py -- Insight Lab Streamlit entry point.
# Standard-library, third-party, and project imports.
import pandas as pd
from openai import OpenAI
import os
import json
import altair as alt
import sqlite3
import ast
import streamlit as st
from streamlit_navigation_bar import st_navbar
from reportlab.lib.pagesizes import letter
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, Image
from altair_saver import save
from utils.menu import menu
from loguru import logger
from datetime import datetime

# Must be the first Streamlit call in the script: sets the browser tab
# title/icon and uses the full page width.
st.set_page_config( page_title = 'Insight Lab', page_icon="chart_with_upwards_trend", layout = 'wide')
19
+
20
+ if "logger" not in st.session_state:
21
+ logger.add("Logger.txt", format="{time:YYYY-MM-DD HH:mm:ss} {level} {message}", level="INFO", enqueue=True)
22
+ st.session_state.logger = logger
23
+ def log_data(data):
24
+ current_date = datetime.now().strftime('%Y-%m-%d')
25
+ current_time = datetime.now().strftime('%H:%M:%S')
26
+ logger.info(f"Date: {current_date}\n========================================\nTime: {current_time}\nLogger Data: {data}\n----------------------------------------\n")
27
+ data = "This is Insight's lab log data."
28
+ log_data(data)
29
+ else:
30
+ logger = st.session_state.logger
31
+
32
+ # logger.add("file_{time}.log")
33
+
34
# Monthly OpenAI token accounting, persisted as a small JSON file mapping
# "YYYY-MM" -> total tokens consumed that month.
token_file = "token_usage.json"
if not os.path.exists(token_file):
    with open(token_file, 'w') as f:
        json.dump({}, f)

def store_token_usage(token_usage, path=token_file):
    """Add *token_usage* to the running total for the current month.

    Args:
        token_usage: Number of tokens consumed by one API call.
        path: JSON file holding per-month totals. Defaults to the
            module-level ``token_file`` (new keyword, backward compatible).
    """
    current_month = datetime.now().strftime('%Y-%m')
    # Robustness fix: recover from a missing or corrupt file instead of
    # crashing the app -- token accounting is best-effort bookkeeping.
    try:
        with open(path, 'r') as f:
            token_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        token_data = {}

    # dict.get collapses the original if/else branches into one statement.
    token_data[current_month] = token_data.get(current_month, 0) + token_usage

    with open(path, 'w') as f:
        json.dump(token_data, f)
52
+
53
def get_monthly_token_usage(path="token_usage.json"):
    """Return the {"YYYY-MM": total_tokens} mapping persisted on disk.

    Args:
        path: JSON file written by ``store_token_usage`` (new keyword with
            the original file name as default -- backward compatible).

    Returns:
        dict: Per-month token totals. Empty dict when the file is missing
        or unreadable (robustness fix: the original crashed in that case).
    """
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}
57
+
58
# Debug aid: dump the accumulated per-month token usage to stdout on every
# app run, then render the shared navigation menu.
monthly_token_usage = get_monthly_token_usage()
print(monthly_token_usage)
menu()
insightlab_logo.png ADDED
logo.png ADDED
pages/automator.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# pages/automator.py -- Streamlit page showing the saved insight library.
import streamlit as st
from utils.menu import menu_with_redirect
from pages.solution import insight_library

# Redirects to the login page when there is no authenticated session.
menu_with_redirect()
insight_library()
pages/composer.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pages/composer.py -- Streamlit page for composing datasets from prompts.
import streamlit as st
from utils.menu import menu_with_redirect, _logout
from pages import solution
from loguru import logger

# Bounce unauthenticated users back to the login page before rendering.
menu_with_redirect()

# All rendering for this page is delegated to the shared solution module.
solution.compose_dataset()
pages/config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pages/config.py -- database connection settings shared across pages.
import os
from sqlalchemy import create_engine
from loguru import logger

# Use SQL Server when True; the MySQL settings below are the alternative.
USE_SQL_SERVER = True

# SQL Server connection settings; each value can be overridden through the
# environment so deployments do not need code changes.
SQL_SERVER_CONFIG = {
    'server': os.getenv('SQL_SERVER', '(localdb)\\MSSQLLocalDB'), # Ensure double backslash for Windows paths
    'database': os.getenv('SQL_DATABASE', 'Insightlab'),
    'driver': '{ODBC Driver 17 for SQL Server}' # Adjust based on your ODBC driver
}

# MySQL connection settings (used when USE_SQL_SERVER is False).
# NOTE(review): 'root'/'root' defaults are development credentials -- ensure
# production supplies MYSQL_USER/MYSQL_PASSWORD via the environment.
MYSQL_SERVER_CONFIG = {
    'host': os.getenv('MYSQL_HOST', 'localhost'),
    'user': os.getenv('MYSQL_USER', 'root'),
    'password': os.getenv('MYSQL_PASSWORD', 'root'),
    'database': os.getenv('MYSQL_DATABASE', 'Insightlab')
}
22
+
23
def update_config(selected_database):
    """Point the shared SQL Server settings at *selected_database*.

    Mutates the module-level SQL_SERVER_CONFIG in place, so every later
    create_sqlalchemy_engine() call targets the new database.
    """
    logger.debug(f'Updating database configuration to use database: {selected_database}')
    SQL_SERVER_CONFIG['database'] = selected_database
26
+
27
def create_sqlalchemy_engine():
    """Build a SQLAlchemy engine for the currently configured SQL Server.

    Uses Windows integrated authentication (Trusted_Connection) via pyodbc.
    """
    cfg = SQL_SERVER_CONFIG
    connection_string = (
        f"DRIVER={cfg['driver']};"
        f"SERVER={cfg['server']};"
        f"DATABASE={cfg['database']};"
        "Trusted_Connection=yes;"
    )
    logger.debug(f'Creating SQLAlchemy engine with connection string: {connection_string}')
    return create_engine(f"mssql+pyodbc:///?odbc_connect={connection_string}")
pages/designer.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# pages/designer.py -- Streamlit page wrapping the insight designer UI.
import streamlit as st
from utils.menu import menu_with_redirect
from pages.solution import design_insight

# Redirects to the login page when there is no authenticated session.
menu_with_redirect()
design_insight()
pages/logger.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils.menu import menu_with_redirect
3
+ from pages import solution
4
+ from loguru import logger
5
+ import os
6
+
7
# Resolve Logger.txt relative to this file so the page works regardless of
# the working directory Streamlit was launched from.
log_file_path = os.path.join(os.path.dirname(__file__), '../Logger.txt')

def read_log_file(file_path):
    """Return the full text of *file_path*.

    Reads as UTF-8 (the encoding the app's loguru sink writes) instead of
    the platform default, and replaces undecodable bytes rather than
    crashing the page on a corrupted log line.
    """
    with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
        return file.read()
13
+
14
# Page chrome: bounce unauthenticated users, then lay out a three-column
# header (logo | title | app banner) above the raw log viewer.
menu_with_redirect()

APP_TITLE = '**Social <br>Determinant<br>of Health**'

col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
with col_aa:
    st.image('logo.png')
with col_bb:
    st.subheader("InsightLab - Log File", divider='blue')
    st.markdown('**Log Contents of Insights-Lab .**')
with col_cc:
    st.markdown(APP_TITLE, unsafe_allow_html=True)

# Render the whole log file in a scrollable read-only text area.
st.text_area("Log File Contents", read_log_file(log_file_path), height=400)
29
+
pages/login.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pages/login.py -- AWS Cognito based authentication (sign in / sign up).
import streamlit as st
import boto3
import time
import os
from loguru import logger

# Cognito keys declaration
# NOTE(review): pool/client ids are hardcoded; consider moving them to
# environment variables so non-prod pools can be used without code edits.
user_pool_id = 'us-east-1_KC0upUdmD'
client_id = '64f41q1d7hf9ntis90iebpjcf4'
cognito_client = boto3.client('cognito-idp', region_name='us-east-1')
# Module-level UI/auth state defaults.
access_token = ''
validation_flag = False
display_code_form = False
14
+
15
def show_messages(message):
    """Display *message* as a transient Streamlit info banner (~1.5 s),
    then clear it so the page is not cluttered with stale notices."""
    success_msg = st.info(message)
    time.sleep(1.5)
    success_msg.empty()
20
+
21
# Sign in against AWS Cognito with the USER_PASSWORD_AUTH flow.
def login(login_user, password):
    """Authenticate *login_user*; return True on success, False otherwise.

    On success also populates session state via get_user_details() and
    sets st.session_state["login"].
    """
    try:
        auth = cognito_client.initiate_auth(
            ClientId=client_id,
            AuthFlow='USER_PASSWORD_AUTH',
            AuthParameters={'USERNAME': login_user, 'PASSWORD': password},
        )
        tokens = auth['AuthenticationResult']
        access_token = tokens['AccessToken']
        id_token = tokens['IdToken']
        refresh_token = tokens['RefreshToken']
        get_user_details(access_token)
        show_messages("Login Successful!")
        st.session_state["login"] = True  # mark the session authenticated
        logger.info("Login successful for user: {}", login_user)
        return True
    except Exception as e:
        logger.error("Authentication failed for user: {}: {}", login_user, e)
        show_messages("Login Failed!")
        return False
46
+
47
def get_user_details(access_token):
    """Fetch the Cognito user for *access_token* and cache identity in session.

    Side effects: sets st.session_state.username / userId and drops any
    stale cached query-file listing so it is rebuilt for the new user.
    """
    try:
        response = cognito_client.get_user(AccessToken=access_token)
        attr_name = None
        attr_sub = None
        # UserAttributes is a list of {'Name': ..., 'Value': ...} pairs.
        for attr in response['UserAttributes']:
            if attr['Name'] == 'sub':
                attr_sub = attr['Value']
            elif attr['Name'] == 'name':
                attr_name = attr['Value']

        st.session_state.username = attr_name
        st.session_state.userId = attr_sub
        # Idiomatic deletion (the original called __delitem__ directly).
        if 'query_files' in st.session_state:
            del st.session_state['query_files']

    except Exception as e:
        logger.error('Get User failed: {}', e)
65
+
66
def confirm_sign_up_code(user, code):
    """Confirm a pending Cognito sign-up with the emailed verification code.

    Shows a success banner on confirmation; on failure shows the raw
    exception to the user and logs it.
    """
    try:
        logger.debug("Confirming code for user: {} with code: {}", user, code)
        response = cognito_client.confirm_sign_up(
            ClientId=client_id,
            Username=user,
            ConfirmationCode=code
        )
        show_messages("Sign up Successful!")
        logger.info("Sign up confirmed for user: {}", user)
    except Exception as e:
        logger.error("Code confirmation failed for user: {}: {}", user, e)
        show_messages(e)
79
+
80
def resend_sign_up_code(username):
    """Ask Cognito to email *username* a fresh sign-up confirmation code.

    Shows a banner with the outcome; errors are logged and surfaced to the
    user via the banner.
    """
    try:
        logger.debug("Hitting resend code for client_id: {}, username: {}", client_id, username)
        response = cognito_client.resend_confirmation_code(
            ClientId=client_id,
            Username=username
        )
        logger.debug("Resend code response: {}", response)
        show_messages("Code has been sent to your email id")
    except Exception as e:
        logger.error("Resend code confirmation failed for user: {}: {}", username, e)
        show_messages(e)
92
+
93
def signup_validation(new_password,confirm_password,name,email,new_username):
    """Validate the sign-up form, showing one st.error per failing field.

    Returns:
        bool: True when all fields pass; False otherwise. All checks run
        so the user sees every problem at once.
    """
    validation_flag = True
    if new_password != confirm_password:
        st.error("Password doesn't match")
        validation_flag = False
        logger.warning("Password and confirm password do not match.")
    if name == '':
        st.error("Name is required")
        validation_flag = False
        logger.warning("Name is required but not provided.")
    if email == '':
        st.error("Email is required")
        validation_flag = False
        logger.warning("Email is required but not provided.")
    if new_username == '':
        # SECURITY fix: the original debug line concatenated and logged the
        # raw password (and passed it positionally without a loguru
        # placeholder). Log only non-sensitive fields.
        logger.debug("Empty username submitted (name={}, email={})", name, email)
        st.error("Username is required")
        validation_flag = False
        logger.warning("Username is required but not provided.")
    return validation_flag
113
+
114
def signup(new_password,name,email,new_username):
    """Create a Cognito user with the given credentials and attributes.

    Cognito then emails a confirmation code that confirm_sign_up_code()
    handles. Shows a banner with the outcome; errors are logged and shown.
    """
    try:
        response = cognito_client.sign_up(
            ClientId=client_id,
            Username=new_username,
            Password=new_password,
            UserAttributes=[{'Name': 'email', 'Value': email}, {'Name': 'name', 'Value': name}]
        )
        show_messages("User created successfully!")
        logger.debug("Sign up response: {}", response)
    except Exception as e:
        logger.error("Sign up failed for user: {}: {}", new_username, e)
        show_messages(e)
pages/solution.py ADDED
@@ -0,0 +1,1459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pages/solution.py -- shared business logic for all Insight Lab pages:
# Azure Blob persistence, Azure OpenAI prompting, and dataframe helpers.
import json
# import sqlite3
import pyodbc
import mysql.connector
import boto3
import time
import pandas as pd
from openai import AzureOpenAI
import os
import json  # NOTE(review): duplicate of the import above; harmless but removable
import altair as alt
import plotly
import ast
import streamlit as st
from streamlit_navigation_bar import st_navbar
from glob import glob
from reportlab.lib.pagesizes import letter
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, Image
from altair_saver import save
from azure.storage.blob import BlobServiceClient, ContainerClient
import re
from sqlalchemy import create_engine
from pages.config import SQL_SERVER_CONFIG, update_config, create_sqlalchemy_engine
from loguru import logger
from st_aggrid import AgGrid, GridOptionsBuilder
from datetime import datetime

# Markdown banner rendered in the right-hand header column of each page.
APP_TITLE = '**Social <br>Determinant<br>of Health**'
30
+
31
# Monthly OpenAI token accounting (duplicated from app.py), persisted as a
# small JSON file mapping "YYYY-MM" -> total tokens consumed that month.
token_file = "token_usage.json"
if not os.path.exists(token_file):
    with open(token_file, 'w') as f:
        json.dump({}, f)

def store_token_usage(token_usage, path=token_file):
    """Add *token_usage* to the running total for the current month.

    Args:
        token_usage: Number of tokens consumed by one API call.
        path: JSON file holding per-month totals. Defaults to the
            module-level ``token_file`` (new keyword, backward compatible).
    """
    current_month = datetime.now().strftime('%Y-%m')
    # Robustness fix: recover from a missing or corrupt file instead of
    # crashing the app -- token accounting is best-effort bookkeeping.
    try:
        with open(path, 'r') as f:
            token_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        token_data = {}

    # dict.get collapses the original if/else branches into one statement.
    token_data[current_month] = token_data.get(current_month, 0) + token_usage

    with open(path, 'w') as f:
        json.dump(token_data, f)

def get_monthly_token_usage(path="token_usage.json"):
    """Return the {"YYYY-MM": total_tokens} mapping persisted on disk.

    Returns an empty dict when the file is missing or unreadable
    (robustness fix: the original crashed in that case).
    """
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}
54
+
55
# Debug aid: dump accumulated per-month token usage to stdout at import time.
monthly_token_usage = get_monthly_token_usage()
print(monthly_token_usage)


def show_messages(message):
    """Display *message* as a transient Streamlit info banner (~1.5 s),
    then clear it. Mirrors the helper of the same name in pages/login.py."""
    success_msg = st.info(message)
    time.sleep(1.5)
    success_msg.empty()
65
+
66
# Locations of various files (blob-storage "directory" prefixes).
sql_dir = 'generated_sql/'
method_dir = 'generated_method/'
insight_lib = 'insight_library/'
query_lib = 'query_library/'
report_path = 'Reports/'
# SECURITY(review): this storage-account key was committed to source control
# and must be rotated. The environment variable now takes precedence; the
# literal remains only as a backward-compatible fallback until rotation.
connection_string = os.getenv(
    "AZURE_STORAGE_CONNECTION_STRING",
    "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net",
)
container_name = "insights-lab"
persona_list = ["Population Analyst", "SDoH Specialist"]
75
+
76
def getBlobContent(dir_path):
    """Download the blob at *dir_path* and return it decoded as UTF-8.

    Returns an empty string on any failure (callers treat "" as missing).
    """
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        raw = container.get_blob_client(dir_path).download_blob().readall()
        text = raw.decode("utf-8")
        logger.info("Blob content retrieved successfully from: {}", dir_path)
        return text
    except Exception as ex:
        logger.error("Exception while retrieving blob content: {}", ex)
        return ""
88
+
89
def check_blob_exists(dir):
    """Return True if any blob name starts with *dir*, False if none.

    Returns None on a storage error (callers rely on truthiness, so None
    behaves like "not found" -- preserved from the original contract).
    """
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        blob_list = container.list_blobs(name_starts_with=f"{dir}")
        # any() stops at the first hit instead of materializing the whole
        # listing the way len(list(...)) did.
        file_exists = any(True for _ in blob_list)
        logger.info("Blob exists check for {}: {}", dir, file_exists)
        return file_exists
    except Exception as ex:
        logger.error("Exception while checking if blob exists: {}", ex)
        return None
102
+
103
def get_max_blob_num(dir):
    """Return the highest integer embedded in blob names under *dir* (0 if none).

    Used to allocate the next sequential number for saved artifacts.
    """
    logger.debug("Directory for max blob num check: {}", dir)
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        blob_list = list(container.list_blobs(name_starts_with=f"{dir}"))
        logger.debug("Blob list: {}", blob_list)
        if not blob_list:
            logger.debug("No blobs found in directory: {}", dir)
        latest_file_number = 0
        for blob in blob_list:
            # Fix: strip the directory prefix into a local instead of
            # mutating the SDK's blob object as the original did.
            name = blob.name.removeprefix(dir)
            match = re.search(r"(\d+)", name)  # first digit run is the file number
            if match:
                latest_file_number = max(latest_file_number, int(match.group(1)))
        logger.info("Latest file number in {}: {}", dir, latest_file_number)
        return latest_file_number
    except Exception as ex:
        logger.error("Exception while getting max blob number: {}", ex)
        return 0
127
+
128
def save_sql_query_blob(prompt, sql, sql_num, df_structure, dir, database):
    """Persist one generated SQL query (prompt + SQL + result structure +
    database) as JSON under <dir><userId>/<sql_num>.json in blob storage.

    Returns True on success, False on any failure.
    """
    data = {"prompt": prompt, "sql": sql, "structure": df_structure, "database": database}
    user_directory = dir + st.session_state.userId
    service = BlobServiceClient.from_connection_string(connection_string)
    container_client = service.get_container_client(container_name)
    logger.debug("Saving SQL query blob in directory: {}, SQL number: {}", user_directory, sql_num)
    logger.debug("Data to be saved: {}", data)
    try:
        # First save for this user: drop a zero-byte "folder" marker blob.
        if not check_blob_exists(user_directory + "/"):
            logger.debug("Creating directory: {}", user_directory)
            container_client.upload_blob(f"{user_directory}/", data=b'')

        file_path = f"{user_directory}/{sql_num}.json"
        logger.debug("File path: {}", file_path)
        result = container_client.upload_blob(file_path, data=json.dumps(data, indent=4))
        logger.info("SQL query blob saved successfully: {}", file_path)
        return True
    except Exception as e:
        logger.error("Exception while saving SQL query blob: {}", e)
        return False
150
+
151
def save_python_method_blob(method_num, code):
    """Persist generated Python code as <method_dir><userId>/<method_num>.py.

    Returns True on success, False on failure.

    NOTE(review): the payload is json.dumps(code) -- a JSON-quoted string
    written into a .py blob -- unlike save_python_method below, which
    writes the raw code. Confirm readers expect the JSON-escaped form
    before unifying the two.
    """
    user_directory = method_dir + st.session_state.userId
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)
    logger.debug("Saving Python method blob in directory: {}, Method number: {}", user_directory, method_num)
    try:
        # First save for this user: drop a zero-byte "folder" marker blob.
        if not check_blob_exists(user_directory + "/"):
            logger.debug("Creating directory: {}", user_directory)
            folder_path = f"{user_directory}/"
            container_client.upload_blob(folder_path, data=b'')

        file_path = f"{user_directory}/{method_num}.py"
        file_content = json.dumps(code, indent=4)
        logger.debug("File path: {}", file_path)
        result = container_client.upload_blob(file_path, data=file_content)
        logger.info("Python method blob saved successfully: {}", file_path)
        return True
    except Exception as e:
        logger.error("Exception while saving Python method blob: {}", e)
        return False
171
+
172
def list_blobs_sorted(directory, extension, session_key, latest_first=True):
    """List blobs under *directory* with the given *extension*, newest first.

    Caches the [(name, 'YYYY-MM-DD HH:MM:SS'), ...] list in
    st.session_state[session_key] and returns it ([] on failure).
    """
    logger.debug("Listing blobs in directory: {}", directory)
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)

        files_with_dates = []
        for blob in container.list_blobs(name_starts_with=f"{directory}"):
            basename = blob.name.split('/')[-1]
            # Skip zero-byte "folder" marker blobs and other extensions.
            if basename != "" and blob.name.split('.')[-1] == extension:
                files_with_dates.append((blob.name, blob.last_modified.strftime('%Y-%m-%d %H:%M:%S')))

        # ISO-style timestamps sort correctly as plain strings.
        files_with_dates.sort(key=lambda entry: entry[1], reverse=latest_first)
        logger.debug("Files with dates: {}", files_with_dates)
        st.session_state[session_key] = files_with_dates
        return files_with_dates
    except Exception as e:
        logger.error("Exception while listing blobs: {}", e)
        return []
194
+
195
def get_saved_query_blob_list():
    """Build the query-picker mapping for the current user.

    Populates st.session_state['query_display_dict'] with
    'ID/prompt/date' labels -> SQL text, first loading the blob listing
    into session state when it is not already cached. Errors are logged
    and leave session state untouched.
    """
    try:
        user_id = st.session_state.userId
        query_library = query_lib + user_id + "/"
        if 'query_files' not in st.session_state:
            list_blobs_sorted(query_library, 'json', 'query_files')

        query_files = st.session_state['query_files']
        logger.debug("Query files: {}", query_files)
        query_display_dict = {}

        for file, dt in query_files:
            # Blob name minus the library prefix and the ".json" suffix.
            id = file[len(query_library):-5]
            content = getBlobContent(file)
            content_dict = json.loads(content)
            query_display_dict[f"ID: {id}, Query: \"{content_dict['prompt']}\", Created on {dt}"] = content_dict['sql']
        st.session_state['query_display_dict'] = query_display_dict
    except Exception as e:
        logger.error("Exception while getting saved query blob list: {}", e)
214
+
215
+
216
def get_existing_token(current_month):
    """Find this user's token-usage blob for *current_month*.

    Returns:
        (token_data, file_name) when a blob whose 'year-month' field
        matches is found; None when there is no match or on any storage
        error (callers test truthiness).
    """
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)

    # Token usage blobs live under token_consumed/<userId>/.
    token_directory = f"token_consumed/{st.session_state.userId}/"
    try:
        for blob in container_client.list_blobs(name_starts_with=token_directory):
            blob_name = blob.name
            # Fix: the original left bare print() debug statements here;
            # route the same information through the logger instead.
            logger.debug("Inspecting token blob: {}", blob_name)
            file_name_with_extension = blob_name.split('/')[-1]
            file_name = file_name_with_extension.split('.')[0]
            blob_client = container_client.get_blob_client(blob_name)
            blob_content = blob_client.download_blob().readall()
            token_data = json.loads(blob_content)
            if token_data['year-month'] == current_month:
                logger.info("Existing token_consumed found for month: {}", current_month)
                return token_data, file_name
        logger.info("No existing token_consumed found for month: {}", current_month)
        return None
    except Exception as e:
        logger.error("Error while retrieving token_consumed: {}", e)
        return None
241
+
242
def update_token(token_data, file_number):
    """Overwrite the user's token-usage blob <file_number>.json with *token_data*.

    Returns True on success, False on failure.
    """
    user_directory = f"token_consumed/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)

    try:
        payload = json.dumps(token_data, indent=4)
        container.upload_blob(f"{user_directory}/{file_number}.json", data=payload, overwrite=True)
        logger.info("token updated successfully: {}", file_number)
        return True
    except Exception as e:
        logger.error("Error while updating token: {}", e)
        return False
256
+
257
def save_token(current_month, token_usage, userprompt, purpose, selected_db, time):
    """Create a fresh token-usage blob for *current_month* seeded with one prompt.

    Note: the *time* parameter (a preformatted timestamp string) shadows
    the time module inside this function.
    Returns True on success, False on failure.
    """
    first_prompt = {
        'user_prompt': userprompt,
        'prompt_purpose': purpose,
        'database': selected_db,
        'date,time': time,
        'token': token_usage,
    }
    new_token = {
        'year-month': current_month,
        'total_token': token_usage,
        'prompt': {'prompt_1': first_prompt},
    }
    user_directory = f"token_consumed/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)

    try:
        # First save for this user: drop a zero-byte "folder" marker blob.
        if not check_blob_exists(user_directory + "/"):
            container.upload_blob(f"{user_directory}/", data=b'')

        file_path = f"{user_directory}/{current_month}.json"
        container.upload_blob(file_path, data=json.dumps(new_token, indent=4))
        logger.info("New token created: {}", file_path)
        return True
    except Exception as e:
        logger.error("Error while creating new token: {}", e)
        return False
288
+
289
def run_prompt(prompt,userprompt,purpose,selected_db, model="provider-gpt4"):
    """Send *prompt* to Azure OpenAI and return the completion text.

    Also records token usage both locally (store_token_usage) and in the
    per-user blob ledger (update_token / save_token). Returns "" when the
    API call itself fails.

    SECURITY(review): the endpoint and api_key are hardcoded below -- the
    key is committed to source control and should be rotated and moved to
    an environment variable.
    """
    current_month = datetime.now().strftime('%Y-%m')
    # Timestamp string recorded with each prompt entry (shadows the time module).
    time=datetime.now().strftime('%d/%m/%Y, %H:%M:%S')
    try:
        client = AzureOpenAI(
            azure_endpoint="https://provider-openai-2.openai.azure.com/",
            api_key="84a58994fdf64338b8c8f0610d63f81c",
            api_version="2024-02-15-preview"
        )
        response = client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}], temperature=0)
        logger.debug("Prompt response: {}", response)

        # Ensure 'usage' attribute exists and is not None
        if response.usage is not None:
            token_usage = response.usage.total_tokens # Retrieve total tokens used
            logger.info("Tokens consumed: {}", token_usage) # Log token usage
            store_token_usage(token_usage) # Store token usage by month
        else:
            token_usage = 0
            logger.warning("Token usage information is not available in the response")
        try:
            # Append to the existing monthly ledger if one exists,
            # otherwise create a new one.
            result = get_existing_token(current_month)
            if result:
                existing_token, file_number = result
                existing_token['total_token']+= token_usage
                # Prompt entries are numbered sequentially: prompt_1, prompt_2, ...
                existing_token['prompt'][f'prompt_{len(existing_token["prompt"]) + 1}'] = {
                    'user_prompt': userprompt,
                    'prompt_purpose': purpose,
                    'database':selected_db,
                    'date,time':time,
                    'token':token_usage
                }
                try:
                    update_token(existing_token, file_number)
                    logger.info("token updated successfully.")
                except Exception as e:
                    logger.error("Error while updating token file: {}", e)
            else:
                # Create a new token entry
                if not check_blob_exists(f"token_consumed/{st.session_state.userId}"):
                    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
                    container_client = blob_service_client.get_container_client(container_name)
                    # NOTE(review): extra positional arg without a {} placeholder --
                    # loguru drops it, so the folder path never appears in this log line.
                    logger.info("Creating a new folder in the blob storage:", f"token_consumed/{st.session_state.userId}")
                    folder_path = f"token_consumed/{st.session_state.userId}/"
                    container_client.upload_blob(folder_path, data=b'')
                try:
                    save_token(current_month, token_usage, userprompt,purpose, selected_db, time)
                except Exception as e:
                    logger.error(f"Error while writing token file: {e}")
        except Exception as e:
            st.write(f"Please try again")
            logger.error(f"Error checking existing token: {e}")
        return response.choices[0].message.content # Return only the code content
    except Exception as e:
        logger.error("Exception while running prompt: {}", e)
        return ""
352
+
353
+
354
def list_files_sorted(directory, extension, session_key, latest_first=True):
    """Local-filesystem counterpart of list_blobs_sorted.

    Returns [(path, 'YYYY-MM-DD HH:MM:SS'), ...] sorted by modification
    time (newest first by default), caches it in
    st.session_state[session_key], and returns [] on failure.
    """
    try:
        matches = glob(os.path.join(directory, f"*.{extension}"))
        logger.debug("Files found: {}", matches)

        matches.sort(key=os.path.getmtime, reverse=latest_first)
        logger.debug("Sorted files: {}", matches)

        # Pair each path with its creation time, formatted for display.
        files_with_dates = [
            (path, datetime.fromtimestamp(os.path.getctime(path)).strftime('%Y-%m-%d %H:%M:%S'))
            for path in matches
        ]
        st.session_state[session_key] = files_with_dates

        return files_with_dates
    except Exception as e:
        logger.error("Exception while listing files: {}", e)
        return []
372
+
373
def get_column_types(df):
    """Infer a friendly type name for each column of *df*.

    Returns {column: 'int64'|'float64'|'bool'|'datetime'|'date'|'string'|
    <dtype name>|'unknown'}. Object columns are probed with strict
    datetime formats -- full timestamp first, then bare date -- falling
    back to 'string'.
    """
    def infer_type(column, series):
        try:
            dtype = series.dtype
            if dtype == 'int64':
                return 'int64'
            if dtype == 'float64':
                return 'float64'
            if dtype == 'bool':
                return 'bool'
            if dtype != 'object':
                return dtype.name  # fallback for any other dtype
            # Object column: probe datetime-with-time, then date-only.
            try:
                pd.to_datetime(series, format='%Y-%m-%d %H:%M:%S', errors='raise')
                return 'datetime'
            except (ValueError, TypeError):
                pass
            try:
                pd.to_datetime(series, format='%Y-%m-%d', errors='raise')
                return 'date'
            except (ValueError, TypeError):
                return 'string'
        except Exception as e:
            logger.error("Exception while inferring column type for {}: {}", column, e)
            return 'unknown'

    try:
        return {col: infer_type(col, df[col]) for col in df.columns}
    except Exception as e:
        logger.error("Exception while getting column types: {}", e)
        return {}
408
+
409
def save_sql_query(prompt, sql, sql_num, df_structure, dir):
    """Save a generated query as JSON under <dir><userId>/<sql_num>.json on
    the local filesystem (blob counterpart: save_sql_query_blob).

    Returns True on success, False on failure.
    """
    data = {"prompt": prompt, "sql": sql, "structure": df_structure}
    user_directory = dir + st.session_state.userId
    os.makedirs(user_directory, exist_ok=True)
    logger.debug("Saving SQL query to directory: {}, SQL number: {}", user_directory, sql_num)
    logger.debug("Data to be saved: {}", data)
    try:
        with open(f"{user_directory}/{sql_num}.json", 'w') as json_file:
            json.dump(data, json_file, indent=4)
    except Exception as e:
        logger.error("Exception while saving SQL query: {}", e)
        return False
    logger.info("SQL query saved successfully.")
    return True
424
+
425
def save_python_method(method_num, code):
    """Write generated analysis *code* to ``<method_dir><method_num>.py``.

    Returns True when the file was written, False on any error.
    """
    target_path = f"{method_dir}{method_num}.py"
    try:
        with open(target_path, 'w') as fh:
            fh.write(code)
    except Exception as exc:
        logger.error("Exception while saving Python method: {}", exc)
        return False
    logger.info("Python method saved successfully: {}", method_num)
    return True
435
def get_ag_grid_options(df):
    """Build AG Grid options for *df*: paginated (20 rows/page) with
    resizable, sortable and filterable columns."""
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_pagination(paginationPageSize=20)  # Limit to 20 rows per page
    # Bug fix: AG Grid's column property is `filter`, not `filterable`;
    # the old `filterable=True` kwarg was silently ignored, so column
    # filters never appeared in the grid.
    gb.configure_default_column(resizable=True, sortable=True, filter=True)
    return gb.build()
441
+
442
def get_existing_insight(base_code, user_persona):
    """Scan the current user's insight blobs for one whose stored
    'base_code' equals *base_code*.

    Returns a ``(insight_data, file_name)`` tuple on a match, otherwise
    None (also on any storage/parse error).
    NOTE(review): the found vs. not-found return shapes differ (tuple vs
    None) — callers must handle both.
    """
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)

    # Insights live under a per-persona, per-user prefix.
    insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
    try:
        for blob in container_client.list_blobs(name_starts_with=insights_directory):
            blob_name = blob.name
            # Bug fix: save_insight() creates an empty "folder marker" blob
            # under this prefix; json.loads() on its empty body used to raise
            # and abort the whole scan. Skip markers and empty blobs.
            if blob_name.endswith('/'):
                continue
            logger.debug("Inspecting insight blob: {}", blob_name)
            file_name = blob_name.split('/')[-1].split('.')[0]
            blob_client = container_client.get_blob_client(blob_name)
            blob_content = blob_client.download_blob().readall()
            if not blob_content:
                continue
            insight_data = json.loads(blob_content)
            if insight_data['base_code'] == base_code:
                logger.info("Existing insight found for base code: {}", base_code)
                return insight_data, file_name
        logger.info("No existing insight found for base code: {}", base_code)
        return None
    except Exception as e:
        logger.error("Error while retrieving insight: {}", e)
        return None
467
+
468
def update_insight(insight_data, user_persona, file_number):
    """Overwrite the stored insight JSON ``<file_number>.json`` for the
    current user/persona in blob storage.

    Returns True on success, False on failure.
    """
    base_path = f"{insight_lib}{user_persona}/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)

    try:
        blob_path = f"{base_path}/{file_number}.json"
        serialized = json.dumps(insight_data, indent=4)
        container.upload_blob(blob_path, data=serialized, overwrite=True)
    except Exception as exc:
        logger.error("Error while updating insight: {}", exc)
        return False
    logger.info("Insight updated successfully: {}", file_number)
    return True
482
+
483
def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code, selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
    """Create a brand-new insight document ``<next_file_number>.json`` in
    blob storage under the current user's persona folder.

    The document bundles the base SQL, one insight prompt/code pair and
    one chart prompt/code pair. Returns True on success, False on failure.
    """
    document = {
        'description': insight_desc,
        'base_prompt': base_prompt,
        'base_code': base_code,
        'database': selected_db,
        'prompt': {
            'prompt_1': {'insight_prompt': insight_prompt,
                         'insight_code': insight_code},
        },
        'chart': {
            'chart_1': {'chart_prompt': chart_prompt,
                        'chart_code': chart_code},
        },
    }

    base_path = f"{insight_lib}{user_persona}/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)

    try:
        # First save for this user: create an empty "folder marker" blob.
        if not check_blob_exists(base_path + "/"):
            container.upload_blob(f"{base_path}/", data=b'')

        blob_path = f"{base_path}/{next_file_number}.json"
        container.upload_blob(blob_path, data=json.dumps(document, indent=4))
        logger.info("New insight created: {}", blob_path)
        return True
    except Exception as exc:
        logger.error("Error while creating new insight: {}", exc)
        return False
520
+
521
def generate_sql(query, table_descriptions, table_details, selected_db):
    """Translate an English-language *query* into SQLite SQL via the LLM.

    Returns the generated SQL string (with backslashes stripped), or None
    when *query* is empty or the model returned nothing.
    """
    if len(query) == 0:
        return None

    with st.spinner('Generating Query'):
        query_prompt = f"""
        You are an expert in understanding an English language healthcare data query and translating it into an SQL Query that can be executed on a SQLite database.

        I am providing you the table names and their purposes that you need to use as a dictionary within double backticks. There may be more than one table.
        Table descriptions: ``{table_descriptions}``

        I am providing you the table structure as a dictionary. For this dictionary, table names are the keys. Values within this dictionary
        are other dictionaries (nested dictionaries). In each nested dictionary, the keys are the field names and the values are dictionaries
        where each key is the column name and each value is the datatype. There may be multiple table structures described here.
        The table structure is enclosed in triple backticks.
        Table Structures: ```{table_details}```

        Pay special attention to the field names. Some field names have an underscore ('_') and some do not. You need to be accurate while generating the query.
        If there is a space in the column name, then you need to fully enclose each occurrence of the column name with double quotes in the query.

        This is the English language query that needs to be converted into an SQL Query within four backticks.
        English language query: ````{query}````

        Your task is to generate an SQL query that can be executed on a SQLite database.
        Only produce the SQL query as a string.
        Do NOT produce any backticks before or after.
        Do NOT produce any JSON tags.
        Do NOT produce any additional text that is not part of the query itself.
        """
        logger.info(f"Generating SQL query with prompt:{query_prompt}")
        query_response = run_prompt(query_prompt, query, "generate query", selected_db)

        # run_prompt may return (text, usage); keep only the text part.
        if isinstance(query_response, tuple):
            query_response = query_response[0]

        if query_response is None:
            logger.error("Query response is None")
            return None

        # Strip escape backslashes the model occasionally emits.
        q = query_response.replace('\\', '')
        # Bug fix: loguru formats with `{}` placeholders, not printf-style
        # `%s`; the old call logged the literal "%s" and dropped the query.
        logger.debug("Generated SQL query: {}", q)
        return q
564
+
565
+ # def create_connection():
566
+ # if USE_SQL_SERVER:
567
+ # try:
568
+ # conn = pyodbc.connect(
569
+ # f"DRIVER={SQL_SERVER_CONFIG['driver']};"
570
+ # f"SERVER={SQL_SERVER_CONFIG['server']};"
571
+ # f"DATABASE={SQL_SERVER_CONFIG['database']};"
572
+ # "Trusted_Connection=yes;"
573
+ # )
574
+ # logger.info("Connected to SQL Server")
575
+ # return conn
576
+ # except Exception as e:
577
+ # logger.error("Error connecting to SQL Server: {}", e)
578
+ # return None
579
+ # else:
580
+ # try:
581
+ # conn = mysql.connector.connect(
582
+ # host=MYSQL_SERVER_CONFIG['host'],
583
+ # user=MYSQL_SERVER_CONFIG['user'],
584
+ # password=MYSQL_SERVER_CONFIG['password'],
585
+ # database=MYSQL_SERVER_CONFIG['database']
586
+ # )
587
+ # logger.info("Connected to MySQL Server")
588
+ # return conn
589
+ # except mysql.connector.Error as err:
590
+ # logger.error("Error connecting to MySQL: {}", err)
591
+ # return None
592
+
593
def execute_sql(query, selected_db):
    """Run *query* against *selected_db* and return the rows as a DataFrame.

    Returns an empty DataFrame when execution fails, and None when no
    SQLAlchemy engine could be created.
    """
    update_config(selected_db)
    engine = create_sqlalchemy_engine()
    if not engine:
        logger.error("Failed to create a SQLAlchemy engine.")
        return None

    connection = engine.connect()
    logger.info(f"Connected to the database {selected_db}.")
    try:
        frame = pd.read_sql_query(query, connection)
    except Exception as exc:
        logger.error(f"Query execution failed: {exc}")
        return pd.DataFrame()
    else:
        logger.info("Query executed successfully.")
        return frame
    finally:
        # Always release the connection, success or failure.
        connection.close()
611
+
612
+ # def execute_sql(query, selected_db, offset, limit=100):
613
+ # update_config(selected_db)
614
+ # engine = create_sqlalchemy_engine()
615
+ # if engine:
616
+ # connection = engine.connect()
617
+ # logger.info(f"Connected to the database {selected_db}.")
618
+ # try:
619
+ # # Modify the query to use ROW_NUMBER() for pagination
620
+ # paginated_query = f"""
621
+ # WITH CTE AS (
622
+ # SELECT *,
623
+ # ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS RowNum
624
+ # FROM ({query.rstrip(';')}) AS subquery
625
+ # )
626
+ # SELECT *
627
+ # FROM CTE
628
+ # WHERE RowNum BETWEEN {offset + 1} AND {offset + limit};
629
+ # """
630
+ # df = pd.read_sql_query(paginated_query, connection)
631
+ # logger.info("Query executed successfully.")
632
+ # return df
633
+ # except Exception as e:
634
+ # logger.error(f"Query execution failed: {e}")
635
+ # return pd.DataFrame()
636
+ # finally:
637
+ # connection.close()
638
+ # else:
639
+ # logger.error("Failed to create a SQLAlchemy engine.")
640
+ # return None
641
+
642
+ # def fetch_data(query, selected_db, offset, limit):
643
+ # df = execute_sql(query, selected_db, offset, limit)
644
+ # return drop_duplicate_columns(df)
645
+
646
def handle_retrieve_request(prompt):
    """Turn an English *prompt* into SQL, execute it, and return
    ``(results_df, sql_text)``. Returns ``(None, None)`` on any failure.
    """
    sql_generated = generate_sql(prompt, st.session_state['table_master'], st.session_state['table_details'], st.session_state['selected_db'])

    # Bug fix: loguru formats with `{}` placeholders; the original `%s`
    # placeholders were logged literally and the arguments were dropped.
    logger.debug("Type of sql_generated: {}", type(sql_generated))
    logger.debug("Content of sql_generated: {}", sql_generated)

    # generate_sql may return (text, usage); keep only the text part.
    if isinstance(sql_generated, tuple):
        logger.debug("Unpacking tuple returned by generate_sql")
        sql_generated = sql_generated[0]

    if sql_generated is None:
        logger.error("Generated SQL is None")
        return None, None

    logger.debug("Generated SQL: {}", sql_generated)

    # If the model wrapped the query in a fenced block (e.g. ```sql ... ```),
    # strip the first and last lines.
    if 'sql' in sql_generated:
        s = sql_generated.find('\n')
        rs = sql_generated.rfind('\n')
        sql_generated = sql_generated[s+1:rs]

    results_df = None
    try:
        logger.debug("Executing SQL: {}", sql_generated)
        sql_generated = sql_generated.replace('###', '')
        selected_db = st.session_state.get('selected_db')
        results_df = execute_sql(sql_generated, selected_db)
        if results_df is not None:
            results_df = results_df.copy()
    except Exception as e:
        logger.error("Error while executing generated query: {}", e)
        return None, None

    return results_df, sql_generated
681
+
682
def display_historical_responses(messages):
    """Re-render all but the most recent chat message from history on an
    app rerun, dispatching on each message's 'type' (text / dataframe /
    table / chart)."""
    for index, message in enumerate(messages[:-1]):
        # Bug fix: loguru uses `{}` placeholders; the original `%s` was
        # logged literally and the message argument was dropped.
        logger.debug("Displaying historical response: {}", message)
        with st.chat_message(message["role"]):
            if 'type' in message:
                if message["type"] == "text":
                    st.markdown(message["content"])
                elif message["type"] == "dataframe" or message["type"] == "table":
                    # id(message) keeps the widget key unique per object.
                    display_paginated_dataframe(message["content"], f"message_historical_{index}_{id(message)}")
                elif message["type"] == "chart":
                    st.plotly_chart(message["content"])
694
def display_paginated_dataframe(df, key):
    """Render *df* in an AG Grid 100 rows at a time, with a page-number
    input whose state lives in ``st.session_state[key]``."""
    if key not in st.session_state:
        st.session_state[key] = {'page_number': 1}
    if df.empty:
        st.write("No data available to display.")
        return

    rows_per_page = 100
    # Ceiling division: number of pages needed to show every row.
    page_count = -(-len(df) // rows_per_page)

    current_page = st.number_input(
        'Page number',
        min_value=1,
        max_value=page_count,
        value=st.session_state[key]['page_number'],
        key=f'page_number_{key}',
    )
    st.session_state[key]['page_number'] = current_page

    # Slice out just the rows for this page.
    lo = (current_page - 1) * rows_per_page
    page_df = df.iloc[lo:lo + rows_per_page]

    builder = GridOptionsBuilder.from_dataframe(page_df)
    builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=rows_per_page)

    # Key includes the page so the grid refreshes when the page changes.
    AgGrid(page_df, gridOptions=builder.build(), key=f"query_result_{key}_{current_page}")
723
+
724
def display_new_responses(response):
    """Render a freshly generated assistant *response* dict and append its
    text portion to the chat history kept in session state.

    Recognized keys: 'text' (markdown body) and 'footnote'
    (a (sequence_number, sql_string) tuple describing the saved query).
    """
    for k, v in response.items():
        logger.debug("Displaying new response: {} - {}", k, v)
        if k == 'text':
            st.session_state.messages.append({"role": "assistant", "content": v, "type": "text"})
            st.markdown(v)
        if k == 'footnote':
            seq_no, sql_str = v
            filename = f"{sql_dir}{st.session_state.userId}{'/'}{seq_no}.json"
            # Bug fix: the footnote string had a stray quote and `filename`
            # was computed but never displayed; include the saved file path.
            st.markdown(f"*SQL: '{sql_str}', File: '{filename}'*")
738
+
739
def drop_duplicate_columns(df):
    """Return *df* keeping only the first occurrence of each column name.

    The previous version also collected the duplicated names for a log
    line that had been commented out; that unused work is removed.
    """
    return df.loc[:, ~df.columns.duplicated()]
744
+
745
def recast_object_columns_to_string(df):
    """Cast every object-dtype column of *df* to str (mutating *df*
    in place) and return the frame."""
    object_columns = [name for name in df.columns if df[name].dtype == 'object']
    for name in object_columns:
        df[name] = df[name].astype(str)
        logger.debug("Column '{}' recast to string.", name)
    return df
751
+
752
def answer_guide_question(question, dframe, df_structure, selected_db):
    """Ask the LLM to write a Python script answering *question* against a
    copy of *dframe*, persist that script, execute it, and return
    ``(result_df, method_number, analysis_code)``.

    result_df is None when the generated code failed or did not populate
    ``output['result_df']``; the (code, error) pair is recorded in
    ``st.session_state['code_execution_error']`` either way.
    """
    logger.debug("Question: {}", question)
    logger.debug("DataFrame Structure: {}", df_structure)
    logger.debug("DataFrame Preview: {}", dframe.head())

    with st.spinner('Generating analysis code'):
        # NOTE(review): the "Dataframe:" section below interpolates
        # df_structure a second time, not the actual data from dframe —
        # looks unintentional; confirm against the prompt's wording.
        code_gen_prompt = f"""You are an expert in understanding an english langauge task and write python script that, when executed, provide correect answer by analyzing a python dataframe.

        I am providing the english language task in double backticks
        Task: ``{question}``

        I am providing you the dataframe structure as a dictionary. For this dictionary, column names are the keys and values are the datatypes of each column.
        The dataframe structure is enclosed in triple backticks.
        Dataframe Structures: ```{df_structure}```

        I am providing you the dataframe as a dictionary. For this dictionary, column names are the keys and values are the datatypes of each column.
        The dataframe is enclosed in triple backticks.
        Dataframe: ```{df_structure}```

        You are required to create a python script that will manipulate a dataframe named 'df' and generate output that satisfies the task.
        Put the final result in a dictionary called output. The output dictionary should have only one key called 'result_df' and the value of that key will be output dataframe.
        Do not define an empty output dictionary as it will be already defined outside the generated code.
        Only keep the relevant columns in the final output df, do not put unnecessary columns that are not needed for the task.
        Pay special attention to the field names. Some field names have an '_' and some do not. You need to be accurate while generating the query.
        If there is a space in the column name, then you need to fully enclose each occurrence of the column name with double quotes in the query.
        Put the given task as a comment line in the first line of the code generated.
        Do not generate a method, but generate only script.

        Your task is to generate python code that can be executed.
        Do NOT produce any backticks before or after.
        Do NOT produce any narrative or justification before or after the code
        Do NOT produce any additional text that is not part of the python code of the method itself.
        You must give a new line character before every actual line of code.
        The script you produced must be able to run on a Python runtime.
        Go back and check if the generated code can be run within a python runtime.
        Go back and check to make sure you have not produced any narrative or justification before or after the code.
        Go back and check to make sure you have not enclosed the code in triple backticks.
        """
        logger.info(f"Generating insight with prompt: {code_gen_prompt}")
        analysis_code = run_prompt(code_gen_prompt, question, "generate insight", selected_db)

        # The generated artifact must be raw source text before we exec it.
        if not isinstance(analysis_code, str):
            logger.error("Generated code is not a string: {}", analysis_code)
            raise ValueError("Generated code is not a string")

    # Persist the generated script under the next free method number.
    last_method_num = get_max_blob_num(method_dir + st.session_state.userId + '/')
    try:
        file_saved = save_python_method_blob(last_method_num + 1, analysis_code)
        logger.info("Code generated and written in {}/{}.py", method_dir, last_method_num)
    except Exception as e:
        # Saving is best-effort: execution still proceeds on failure.
        logger.error("Trouble writing the code file for {} and method number {}: {}", question, last_method_num + 1, e)

    result_df = None
    # Work on a copy so the caller's frame is never mutated by the script.
    df = dframe.copy()
    df = recast_object_columns_to_string(df)
    output = {}

    logger.debug("DataFrame after recasting object columns to string: {}", df.head())
    try:
        logger.debug('Generated code:\n{}', analysis_code)
        # SECURITY NOTE: exec of LLM-generated code — 'df' and 'output'
        # are injected as locals; the script is expected to fill
        # output['result_df'].
        exec(analysis_code, globals(), {'df': df, 'output': output})  # type: ignore
        logger.debug("Output dictionary contents: {}", output)
        result_df = output.get('result_df', None)
        if result_df is not None:
            # Success: remember the code with no error attached.
            st.session_state['code_execution_error'] = (analysis_code, None)
        else:
            logger.warning("result_df is not defined in the output dictionary")
            st.session_state['code_execution_error'] = ("", None)
    except Exception as e:
        logger.error("Error executing generated code {} for {}: {}", last_method_num, question, e)
        logger.debug("Generated code:\n{}", analysis_code)
        # Keep the failing code and its exception for the retry/repair UI.
        st.session_state['code_execution_error'] = (analysis_code, e)

    logger.debug("Code execution error state: {}", st.session_state['code_execution_error'])
    return result_df, last_method_num + 1, analysis_code
832
+
833
+ # def get_metadata(table):
834
+ # table_details = st.session_state['table_details'][table]
835
+ # matadata = [[field, details[0], details[1]] for field, details in table_details.items()]
836
+ # metadata_df = pd.DataFrame(matadata, columns=['Field Name', 'Field Description', 'Field Type'])
837
+ # return metadata_df
838
+
839
def generate_graph(query, df, df_structure, generate_graph):
    """Ask the LLM for Plotly chart code satisfying *query* over *df*,
    exec it, and return ``(chart, graph_response)``; ``(None, None)`` on
    empty/missing inputs or when the generated code fails.

    NOTE(review): the fourth parameter shadows this function's own name
    and is forwarded to run_prompt() in the position where sibling
    functions pass selected_db — presumably it IS the selected database;
    confirm with callers before renaming.
    """
    if query is None or df is None or df_structure is None:
        logger.error("generate_graph received None values for query, df, or df_structure")
        return None, None

    if len(query) == 0:
        return None, None

    # Compact stand-in for the full frame: the prompt below states the
    # whole dataframe cannot be sent to the model.
    df_summary = {
        "columns": df.columns.tolist(),
        "dtypes": df.dtypes.astype(str).to_dict(),
        "describe": df.describe().to_dict()
    }

    with st.spinner('Generating graph'):
        graph_prompt = f"""
        You are an expert in understanding English language instructions to generate a graph based on a given dataframe.

        I am providing you the dataframe structure as a dictionary in double backticks.
        Dataframe structure: ``{df_structure}``

        I am also providing you a summary of the dataframe as a dictionary in double backticks.
        Dataframe summary: ``{df_summary}``

        I have provided the dataframe structure and its summary. I can't provide the entire dataframe.

        I am also giving you the intent instruction in triple backticks.
        Instruction for generating the graph: ```{query}```

        Your task is to write the code that will generate a Plotly chart.
        You should be able to derive the chart type from the instruction.
        Graphs may need calculations, such as aggregating or calculating averages for some of the numeric columns.

        You should generate the code that will allow me to create the Plotly chart object that can then be used as the parameter in Streamlit's `st.plotly_chart()` method.

        Pay special attention to the field names. Some field names have an underscore (_) and some do not. You need to be accurate while generating the query.
        Pay special attention when you need to group by based on two categorical columns to create things like bubble charts. For example, the sample code within four backticks below is the correct way to prepare a dataframe with procedure code, a categorical variable in one axis, and diagnosis code, another categorical variable in another axis, and the size of the bubble would be based on the sum of 'Total Paid' values for each procedure and diagnosis code combination.
        Sample code: ````grouped_df = df_ma.groupby(['Procedure Code', 'Diagnosis Codes'])['Total Paid'].sum().reset_index()````

        If you need to add a filter criterion, then you need to add a second step as indicated in five backticks below. This shows it is filtering the dataframe for all groups with a sum of 'Total Paid' more than 1000. You can feed the last dataframe to the Plotly chart.
        Sample code: `````grouped_df = df.groupby(['Procedure Code', 'Diagnosis Codes'])['Total Paid'].sum().reset_index() \\n\\nfiltered_df = grouped_df[grouped_df['Total Paid'] > 1000]`````

        If there is a space in the column name, then you need to fully enclose each occurrence of the column name with double quotes in the query.
        While creating the Plotly chart, you need to get the top 5000 rows since Plotly chart cannot handle more than 5000 rows.
        Pay special attention to grouped bar charts. For grouped bar charts, there should be at least two x-axis columns. One can be the actual x-axis and the other can be used in the 'column' parameter of the Plotly Chart object. For example, the following code in four backticks shows a grouped bar chart with the x-axis showing 'year' and each 'site' for each year.
        Grouped bar chart sample code: ````alt.Chart(source).mark_bar().encode(
        x='year:O',
        y='sum(yield):Q',
        column='site:N'
        )````

        A grouped bar chart will be explicitly asked for in the instructions.

        Only produce the Python code.
        Do NOT produce any backticks or double quotes or single quotes before or after the code.
        Do generate the Plotly import statement as part of the code.
        Do NOT justify your code.
        Do not generate any narrative or comments in the code.
        Do NOT produce any JSON tags.
        Do not print or return the chart object at the end.
        Do NOT produce any additional text that is not part of the query itself.
        Always name the final Plotly chart object as 'chart'.
        Go back and check if the generated code can be used in the `st.plotly_chart()` method.
        """
        logger.info(f"Generating graph with prompt: {graph_prompt}")
        graph_response = run_prompt(graph_prompt, query, "generate graph", generate_graph)
        logger.debug("Graph response: {}", graph_response)

        try:
            # Create a dictionary to capture local variables
            local_vars = {}

            # SECURITY NOTE: exec of LLM-generated code; the contract
            # (enforced by the prompt above) is that it binds 'chart'.
            exec(graph_response, {}, local_vars)  # type: ignore
            logger.debug("Graph code executed.")

            # Extract the chart object from local_vars
            chart = local_vars['chart']
            logger.info("Plotly chart object created successfully.")
        except Exception as e:
            logger.error("Error creating plotly chart object: {}", e)
            return None, None

    return chart, graph_response
923
+
924
def get_table_details(engine, selected_db):
    """Fetch table and column metadata from a SQL Server database via
    INFORMATION_SCHEMA plus MS_Description extended properties.

    Returns a pair of dicts:
      - tables_master: {table_name: "db - table - description"}
      - tables_details: {table_name: [{"name", "type", "description"}, ...]}
    """
    query_tables = """
    SELECT
        c.TABLE_NAME,
        c.TABLE_SCHEMA,
        c.COLUMN_NAME,
        c.DATA_TYPE,
        ep.value AS COLUMN_DESCRIPTION
    FROM
        INFORMATION_SCHEMA.COLUMNS c
    LEFT JOIN
        sys.extended_properties ep
        ON OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME) = ep.major_id
        AND c.ORDINAL_POSITION = ep.minor_id
        AND ep.name = 'MS_Description'
    ORDER BY
        c.TABLE_NAME,
        c.ORDINAL_POSITION;
    """

    query_descriptions = """
    SELECT
        t.TABLE_NAME,
        t.TABLE_SCHEMA,
        t.TABLE_TYPE,
        ep.value AS TABLE_DESCRIPTION
    FROM
        INFORMATION_SCHEMA.TABLES t
    LEFT JOIN
        sys.extended_properties ep
        ON OBJECT_ID(t.TABLE_SCHEMA + '.' + t.TABLE_NAME) = ep.major_id
        AND ep.class = 1
    WHERE
        t.TABLE_TYPE='BASE TABLE';
    """

    tables_df = pd.read_sql(query_tables, engine)
    descriptions_df = pd.read_sql(query_descriptions, engine)
    # Consistency fix: this module logs through loguru everywhere else;
    # the raw debug print() calls are replaced with logger.debug.
    logger.debug("Tables metadata:\n{}", tables_df)
    logger.debug("Table descriptions:\n{}", descriptions_df)

    tables_master_dict = {}
    for _, row in descriptions_df.iterrows():  # index was unused
        if row['TABLE_NAME'] not in tables_master_dict:
            tables_master_dict[row['TABLE_NAME']] = f"{selected_db} - {row['TABLE_NAME']} - {row['TABLE_DESCRIPTION']}"

    tables_details_dict = {}
    for table_name, group in tables_df.groupby('TABLE_NAME'):
        columns = [{"name": col.COLUMN_NAME, "type": col.DATA_TYPE, "description": col.COLUMN_DESCRIPTION} for col in group.itertuples()]
        tables_details_dict[table_name] = columns

    logger.info("Table details fetched successfully.")
    return tables_master_dict, tables_details_dict
975
+
976
+ # Function to fetch database names from SQL Server
977
def get_database_names():
    """Return the names of all user databases on the configured SQL Server
    (system databases excluded). Empty list on any connection or query
    failure."""
    sql = """
    SELECT name
    FROM sys.databases
    WHERE name NOT IN ('master', 'tempdb', 'model', 'msdb');
    """
    conn_str = f"DRIVER={SQL_SERVER_CONFIG['driver']};SERVER={SQL_SERVER_CONFIG['server']};Trusted_Connection=yes;"
    try:
        with pyodbc.connect(conn_str) as connection:
            cur = connection.cursor()
            cur.execute(sql)
            names = [record[0] for record in cur.fetchall()]
        logger.info("Database names fetched successfully.")
        return names
    except Exception as exc:
        logger.error("Error fetching database names: {}", exc)
        return []
994
+
995
def get_metadata(selected_table):
    """Return the cached column metadata for *selected_table* (from
    st.session_state['table_details']) as a DataFrame; an empty DataFrame
    when the table is unknown or the cache is missing."""
    try:
        frame = pd.DataFrame(st.session_state['table_details'][selected_table])
    except Exception as exc:
        logger.error("Error fetching metadata for table {}: {}", selected_table, exc)
        return pd.DataFrame()
    logger.info("Metadata fetched for table: {}", selected_table)
    return frame
1003
+
1004
+ # def load_data(sql_generated, selected_db):
1005
+ # # Fetch data in chunks of 100 rows
1006
+ # if 'offset' not in st.session_state:
1007
+ # st.session_state['offset'] = 0
1008
+
1009
+ # if 'data' not in st.session_state:
1010
+ # st.session_state['data'] = pd.DataFrame() # Initialize as an empty DataFrame
1011
+
1012
+ # new_data = fetch_data(sql_generated, selected_db, st.session_state['offset'], 100)
1013
+ # if not new_data.empty:
1014
+ # if st.session_state['offset'] == 0:
1015
+ # st.session_state['data'] = new_data
1016
+ # else:
1017
+ # st.session_state['data'] = pd.concat([st.session_state['data'], new_data], ignore_index=True)
1018
+
1019
+ # grid_options = get_ag_grid_options(st.session_state['data'])
1020
+ # AgGrid(st.session_state['data'], gridOptions=grid_options, key=f'query_grid_{st.session_state["offset"]}', lazyloading=True)
1021
+ # i=0
1022
+ # if not new_data.empty :
1023
+ # button_clicked = False
1024
+ # while not button_clicked:
1025
+ # i+=1
1026
+ # if st.button('Load more', key=i):
1027
+ # button_clicked = True
1028
+ # st.write('Button clicked!')
1029
+ # st.session_state['offset'] += 100
1030
+ # load_data(sql_generated, selected_db)
1031
+ # else:
1032
+ # st.write('Waiting for button click...')
1033
+ # time.sleep(1)
1034
+ # # if st.button("Load more"):
1035
+ # # logger.info(st.session_state['offset'])
1036
+ # # logger.info("hi............................................................")
1037
+ # # st.session_state['offset'] += 100
1038
+ # # load_data(sql_generated, selected_db)
1039
+ # # else:
1040
+ # # logger.info(st.session_state['offset'])
1041
+ # # logger.info("hi buttoon............................................................")
1042
+ # else:
1043
+ # logger.info(st.session_state['offset'])
1044
+ # logger.info("hi next data............................................................")
1045
+
1046
+ def compose_dataset():
1047
+ if "messages" not in st.session_state:
1048
+ logger.debug('Initializing session state messages.')
1049
+ st.session_state.messages = []
1050
+ if "query_result" not in st.session_state:
1051
+ st.session_state.query_result = pd.DataFrame()
1052
+
1053
+ col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
1054
+ with col_aa:
1055
+ st.image('logo.png')
1056
+ with col_bb:
1057
+ st.subheader(f"InsightLab - Compose Dataset", divider='blue')
1058
+ st.markdown('**Generate a custom dataset by combining any table with English language questions.**')
1059
+ with col_cc:
1060
+ st.markdown(APP_TITLE, unsafe_allow_html=True)
1061
+
1062
+ databases = get_database_names()
1063
+ selected_db = st.selectbox('Select Database:', [''] + databases)
1064
+
1065
+ if selected_db:
1066
+ if 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db:
1067
+ # Clear session state data related to the previous database
1068
+ st.session_state['messages'] = []
1069
+ st.session_state['selected_table'] = None
1070
+ logger.debug('Session state cleared due to database change.')
1071
+
1072
+ update_config(selected_db)
1073
+ engine = create_sqlalchemy_engine()
1074
+
1075
+ if 'table_master' not in st.session_state or st.session_state.get('selected_db') != selected_db:
1076
+ tables_master_dict, tables_details_dict = get_table_details(engine, selected_db)
1077
+ st.session_state['table_master'] = tables_master_dict
1078
+ st.session_state['table_details'] = tables_details_dict
1079
+ st.session_state['selected_db'] = selected_db
1080
+
1081
+ tables = list(st.session_state['table_master'].keys())
1082
+ selected_table = st.selectbox('Tables available:', [''] + tables)
1083
+
1084
+ if selected_table:
1085
+ if 'selected_table' not in st.session_state or st.session_state['selected_table'] != selected_table:
1086
+ try:
1087
+ table_metadata_df = get_metadata(selected_table).copy()
1088
+ table_desc = st.session_state['table_master'][selected_table]
1089
+ st.session_state['table_metadata_df'] = table_metadata_df
1090
+ st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
1091
+ st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
1092
+ logger.debug('Table metadata and description added to session state messages.')
1093
+ # st.session_state.messages.append({"role": "assistant", "type": "text", "content": ""})
1094
+ # display_paginated_dataframe(table_metadata_df, "table_metadata")
1095
+ except Exception as e:
1096
+ st.error("Please try again")
1097
+ logger.error(f"Error while loading the metadata: {e}")
1098
+ # if 'table_metadata_df' in st.session_state and not st.session_state['table_metadata_df'].empty:
1099
+ # display_paginated_dataframe(st.session_state['table_metadata_df'], "table_metadata_selected")
1100
+ st.session_state['selected_table'] = selected_table
1101
+ message_container = st.container()
1102
+
1103
+ logger.debug("Message container initialized.")
1104
+ with message_container:
1105
+ # Display chat messages from history on app rerun
1106
+ display_historical_responses(st.session_state.messages)
1107
+
1108
+ if prompt := st.chat_input("What is your question?"):
1109
+ logger.debug('User question received.')
1110
+ st.session_state.messages.append({"role": "user", "content": prompt, 'type': 'text'})
1111
+
1112
+ with message_container:
1113
+ with st.chat_message("user"):
1114
+ st.markdown(prompt)
1115
+
1116
+ logger.debug('Processing user question...')
1117
+ with st.chat_message("assistant"):
1118
+ message_placeholder = st.empty()
1119
+ full_response = ""
1120
+ response = {}
1121
+ with st.spinner("Working..."):
1122
+ logger.debug('Executing user query...')
1123
+ try:
1124
+ query_result, sql_generated = handle_retrieve_request(prompt)
1125
+ query_result = drop_duplicate_columns(query_result)
1126
+ st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": query_result})
1127
+ # logger.debug(query_result)
1128
+ if query_result is not None:
1129
+ response['dataframe'] = query_result
1130
+ logger.debug("userId" + st.session_state.userId)
1131
+ st.session_state.query_result = pd.DataFrame(query_result)
1132
+
1133
+ last_sql = get_max_blob_num(sql_dir + st.session_state.userId + '/')
1134
+ logger.debug(f"Last SQL file number: {last_sql}")
1135
+ st.session_state['last_sql'] = last_sql
1136
+
1137
+ sql_saved = save_sql_query_blob(prompt, sql_generated, last_sql + 1, get_column_types(query_result), sql_dir, selected_db)
1138
+ if sql_saved:
1139
+ response['footnote'] = (last_sql + 1, sql_generated)
1140
+ else:
1141
+ response['text'] = 'Error while saving generated SQL.'
1142
+ st.session_state['retrieval_query'] = prompt
1143
+ st.session_state['retrieval_query_no'] = last_sql + 1
1144
+ st.session_state['retrieval_sql'] = sql_generated
1145
+ st.session_state['retrieval_result_structure'] = get_column_types(query_result)
1146
+ else:
1147
+ st.session_state.messages.append({"role": "assistant", "type": "text", "content": 'Error executing query. Please retry.'})
1148
+ except Exception as e:
1149
+ st.error("Please try again with another prompt")
1150
+ logger.error(f"Error processing request: {e}")
1151
+ display_new_responses(response)
1152
+
1153
+ if 'query_result' in st.session_state and not st.session_state.query_result.empty:
1154
+ display_paginated_dataframe(st.session_state['query_result'], st.session_state['retrieval_query_no'])
1155
+ with st.container():
1156
+ if 'retrieval_sql' in st.session_state and 'selected_db' in st.session_state:
1157
+ if st.button('Save Query'):
1158
+ database_name = st.session_state['selected_db']
1159
+ sql_saved = save_sql_query_blob(st.session_state['retrieval_query'], st.session_state['retrieval_sql'], st.session_state['retrieval_query_no'], st.session_state['retrieval_result_structure'], query_lib, database_name)
1160
+ if sql_saved:
1161
+ st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
1162
+ logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
1163
+
1164
# One-time initialization of the session-state slots used by the
# insight-design page; values already set survive Streamlit reruns
# untouched, exactly as the individual `if ... not in` guards did.
_SESSION_DEFAULTS = {
    'graph_obj': None,                      # last generated Plotly figure
    'graph_prompt': '',                     # prompt that produced graph_obj
    'data_obj': None,                       # last generated insight dataframe
    'data_prompt': '',                      # prompt that produced data_obj
    'code_execution_error': (None, None),   # (code, error) of last generated-code run
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
1174
+
1175
def design_insight():
    """Render the 'Design Insights' page.

    Lets the user pick a previously saved query, re-executes its SQL to get a
    fresh dataset, then generates tabular insights and Plotly charts from
    natural-language prompts (via ``answer_guide_question`` /
    ``generate_graph``), and finally saves the prompt/code pairs to the
    persona-scoped insight library in blob storage.

    Fixes vs. previous revision:
      * loguru format strings switched from printf-style ``%s`` (which loguru
        does not substitute — the exception text was silently dropped) to
        ``{}`` placeholders;
      * the "Creating a new folder" log call had no placeholder, so the folder
        path never appeared in the log;
      * local ``id`` renamed so it no longer shadows the builtin;
      * pointless f-string without placeholders removed.
    """
    col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
    with col_aa:
        st.image('logo.png')
    with col_bb:
        st.subheader("InsightLab - Design Insights", divider='blue')
        st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
    with col_cc:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    get_saved_query_blob_list()
    selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))

    if len(selected_query) > 0:
        # Reset page state whenever the user switches to a different saved query,
        # so stale insights/charts from the previous dataset are not shown.
        if 'selected_query' not in st.session_state or st.session_state['selected_query'] != selected_query:
            st.session_state['selected_query'] = selected_query
            st.session_state['data_obj'] = None
            st.session_state['graph_obj'] = None
            st.session_state['data_prompt'] = ''
            st.session_state['graph_prompt'] = ''
            st.session_state['data_prompt_value'] = ''
            st.session_state['graph_prompt_value'] = ''

        col1, col2 = st.columns([1, 3])
        with col1:
            with st.container():
                st.subheader('Dataset Columns')
                # Display labels look like "ID: <id>, Description: ..."; pull out <id>.
                s = selected_query[len("ID: "):]
                end_index = s.find(",")
                query_id = s[:end_index]
                try:
                    blob_content = getBlobContent(f"{query_lib}{st.session_state.userId}/{query_id}.json")
                    content = json.loads(blob_content)
                    st.session_state['query_file_content'] = content
                    sql_query = content['sql']
                    selected_db = content['database']
                    # Re-run the saved SQL so the user works on current data.
                    df = execute_sql(sql_query, selected_db)
                    df = drop_duplicate_columns(df)
                    df_dict = get_column_types(df)
                    df_dtypes = pd.DataFrame.from_dict(df_dict, orient='index', columns=['Dtype'])
                    df_dtypes.reset_index(inplace=True)
                    df_dtypes.rename(columns={'index': 'Column'}, inplace=True)

                    # Group columns by dtype so each kind gets its own expander.
                    int_cols = df_dtypes[df_dtypes['Dtype'] == 'int64']['Column'].reset_index(drop=True)
                    float_cols = df_dtypes[df_dtypes['Dtype'] == 'float64']['Column'].reset_index(drop=True)
                    string_cols = df_dtypes[df_dtypes['Dtype'] == 'string']['Column'].reset_index(drop=True)
                    datetime_cols = df_dtypes[df_dtypes['Dtype'] == 'datetime']['Column'].reset_index(drop=True)

                    with st.expander("Integer Columns", icon=":material/looks_one:"):
                        st.write("\n\n".join(list(int_cols.values)))

                    with st.expander("Decimal Number Columns", icon=":material/pin:"):
                        st.write("\n\n".join(list(float_cols.values)))

                    with st.expander("String Columns", icon=":material/abc:"):
                        st.write("\n\n".join(list(string_cols.values)))

                    with st.expander("Datetime Columns", icon=":material/calendar_month:"):
                        st.write("\n\n".join(list(datetime_cols.values)))

                    # Stash the dataset and its dtype table for the prompt handlers.
                    st.session_state['explore_df'] = df
                    st.session_state['explore_dtype'] = df_dtypes

                    logger.info("Dataset columns displayed using AG Grid.")
                except Exception as e:
                    st.error("Error while loading the dataset")
                    logger.error("Error loading dataset: {}", e)

        with col2:
            with st.container():
                st.subheader('Generate Insight')
                data_prompt_value = st.session_state.get('data_prompt', '')
                data_prompt = st.text_area("What insight would you like to generate?", value=data_prompt_value)
                if st.button('Generate Insight'):
                    st.session_state['data_obj'] = None
                    if data_prompt:
                        st.session_state['data_prompt'] = data_prompt
                        try:
                            # NOTE(review): `selected_db` is only bound if the
                            # dataset loaded successfully in col1 — confirm the
                            # failure path cannot reach here with it unbound.
                            data_obj, code_number, st.session_state['data_code'] = answer_guide_question(data_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
                            if st.session_state['code_execution_error'][1] is None:
                                if data_obj is not None:
                                    st.session_state['data_obj'] = data_obj
                                    logger.info("Insight generated and displayed using AG Grid.")
                                else:
                                    st.session_state['data_obj'] = None
                                    st.write('Dataset is empty, please try again.')
                            else:
                                st.write('Please retry again.')
                                # Drop the error so the top-level initializer
                                # recreates a clean slot on the next rerun.
                                del st.session_state['code_execution_error']
                        except Exception as e:
                            st.write("Please try again with another prompt")
                            # loguru substitutes `{}` placeholders, not `%s`.
                            logger.error("Error generating insight: {}", e)
                if st.session_state['data_obj'] is not None:
                    display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
                    st.session_state['data_prompt'] = data_prompt

            with st.container():
                st.subheader('Generate Graph')
                graph_prompt_value = st.session_state.get('graph_prompt', '')
                graph_prompt = st.text_area("What graph would you like to generate?", value=graph_prompt_value)
                if st.button('Generate Graph'):
                    graph_obj = None
                    if graph_prompt:
                        logger.debug("Graph prompt: {} | Previous graph prompt: {}", st.session_state.get('graph_prompt'), graph_prompt)
                        # Only regenerate when the prompt actually changed;
                        # otherwise redraw the cached figure.
                        if st.session_state['graph_prompt'] != graph_prompt:
                            try:
                                graph_obj, st.session_state['graph_code'] = generate_graph(graph_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
                                st.session_state['graph_obj'] = graph_obj

                                if graph_obj is not None:
                                    st.plotly_chart(graph_obj, use_container_width=True)
                                    logger.info("Graph generated and displayed using Plotly.")
                                else:
                                    st.session_state['graph_obj'] = None
                                    st.text('Error in generating graph, please try again.')
                            except Exception as e:
                                logger.error("Error in generating graph: {}", e)
                                st.write("Error in generating graph, please try again")
                        else:
                            st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
                        st.session_state['graph_prompt'] = graph_prompt
                else:
                    # Button not pressed this rerun: keep showing the last chart.
                    if st.session_state['graph_obj'] is not None:
                        try:
                            st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
                        except Exception as e:
                            st.write("Error in displaying graph, please try again")
                            logger.error("Error in displaying graph: {}", e)

        with st.container():
            if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
                user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
                insight_desc = st.text_area(label='Describe the purpose of this insight for your reference later')
                if st.button('Save in Library'):
                    base_prompt = st.session_state['query_file_content']['prompt']
                    base_code = st.session_state['query_file_content']['sql']

                    insight_prompt = st.session_state.get('data_prompt', '')
                    insight_code = st.session_state.get('data_code', '')

                    chart_prompt = st.session_state.get('graph_prompt', '')
                    chart_code = st.session_state.get('graph_code', '')

                    try:
                        result = get_existing_insight(base_code, user_persona)
                        if result:
                            # Same base query already saved for this persona:
                            # append the new prompt/chart instead of duplicating.
                            existing_insight, file_number = result
                            existing_insight['prompt'][f'prompt_{len(existing_insight["prompt"]) + 1}'] = {
                                'insight_prompt': insight_prompt,
                                'insight_code': insight_code
                            }
                            existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
                                'chart_prompt': chart_prompt,
                                'chart_code': chart_code
                            }
                            try:
                                update_insight(existing_insight, user_persona, file_number)
                                st.text('Insight updated with new Graph and/or Data.')
                                logger.info("Insight updated successfully.")
                            except Exception as e:
                                st.write('Could not update the insight file. Please try again')
                                logger.error("Error while updating insight file: {}", e)
                        else:
                            # Create a new insight entry
                            if not check_blob_exists(f"insight_library/{user_persona}/{st.session_state.userId}"):
                                blob_service_client = BlobServiceClient.from_connection_string(connection_string)
                                container_client = blob_service_client.get_container_client(container_name)
                                # `{}` placeholder so the path actually lands in the log.
                                logger.info("Creating a new folder in the blob storage: {}", f"insight_library/{user_persona}/{st.session_state.userId}")
                                folder_path = f"insight_library/{user_persona}/{st.session_state.userId}/"
                                container_client.upload_blob(folder_path, data=b'')
                            next_file_number = get_max_blob_num(f"insight_library/{user_persona}/{st.session_state.userId}/") + 1

                            try:
                                save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code, selected_db, insight_prompt, insight_code, chart_prompt, chart_code)
                                st.text(f'Insight #{next_file_number} with Graph and/or Data saved.')
                            except Exception as e:
                                st.write('Could not write the insight file.')
                                logger.error("Error while writing insight file: {}", e)
                    except Exception as e:
                        st.write("Please try again")
                        logger.error("Error checking existing insights: {}", e)
1357
+
1358
def get_insight_list(persona):
    """Return the saved insights for *persona* belonging to the current user.

    Lists the ``.json`` blobs under ``{insight_lib}{persona}/{userId}/`` and
    builds two parallel lists: the blob paths and human-readable labels of the
    form ``ID: <id>, Description: "<desc>", Created on <dt>``.

    On any failure the error is logged and two empty lists are returned so the
    caller's selectbox stays usable.

    Fix vs. previous revision: the local variable ``id`` shadowed the builtin;
    renamed to ``insight_id``.
    """
    try:
        # Populates st.session_state['library_files'] with (path, datetime) pairs.
        list_blobs_sorted(f"{insight_lib}{persona}/{st.session_state.userId}/", 'json', 'library_files')
        library_files = st.session_state['library_files']
        logger.debug("Library files: {}", library_files)

        library_file_list = []
        library_file_description_list = []

        for file, dt in library_files:
            # Strip the "<insight_lib><persona>/<userId>/" prefix and the
            # ".json" suffix to recover the numeric id.  NOTE(review): the +3
            # offset assumes a specific slash layout in `insight_lib` — verify
            # against the path built in list_blobs_sorted above.
            insight_id = file[len(insight_lib) + len(persona) + len(st.session_state.userId) + 3:-5]
            content = getBlobContent(file)
            content_dict = json.loads(content)
            description = content_dict.get('description', 'No description available')
            library_file_description_list.append(f"ID: {insight_id}, Description: \"{description}\", Created on {dt}")
            library_file_list.append(file)

        logger.info("Insight list generated successfully.")
        return library_file_list, library_file_description_list
    except Exception as e:
        logger.error("Error generating insight list: {}", e)
        return [], []
1380
+
1381
def insight_library():
    """Render the 'Personalized Insight Library' page.

    The user picks an analyst persona and one of its saved insights; the
    stored base SQL is re-executed against the live database, then the stored
    insight/chart code snippets are exec'd on the fresh dataframe and the
    results displayed.  Left token-identical to the original: the exec-based
    flow is order-sensitive.
    """
    col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
    with col_aa:
        st.image('logo.png')
    with col_bb:
        st.subheader("InsightLab - Personalized Insight Library", divider='blue')
        st.markdown('**Select one of the pre-configured insights and get the result on the latest data.**')
    with col_cc:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    selected_persona = st.selectbox('Select an analyst persona:', [''] + persona_list)

    if selected_persona:
        st.session_state['selected_persona'] = selected_persona
        try:
            # Parallel lists: blob paths and their display labels.
            file_list, file_description_list = get_insight_list(selected_persona)
            selected_insight = st.selectbox(label='Select an insight from the library', options=[""] + file_description_list)

            if selected_insight:
                # Map the chosen label back to its blob path by position.
                idx = file_description_list.index(selected_insight)
                file = file_list[idx]
                st.session_state['insight_file'] = file

                content = getBlobContent(file)
                task_dict = json.loads(content)
                base_prompt = task_dict.get('base_prompt', 'No base prompt available')
                base_code = task_dict.get('base_code', '')
                selected_db = task_dict.get('database', '')  # Retrieve the database name from the task dictionary
                prompts = task_dict.get('prompt', {})
                charts = task_dict.get('chart', {})

                # Get base dataset
                df = execute_sql(base_code, selected_db)
                df = drop_duplicate_columns(df)

                # Display insights
                st.subheader("Insight Generated")
                for key, value in prompts.items():
                    st.markdown(f"**{value.get('insight_prompt', 'No insight prompt available')}**")
                    output = {}
                    try:
                        # SECURITY NOTE(review): executes code loaded from blob
                        # storage; safe only if the insight library is trusted.
                        # The snippet is expected to read `df` and put its
                        # result into output['result_df'].
                        exec(value.get('insight_code', ''), globals(), {'df': df, 'output': output})
                        result_df = output.get('result_df', None)
                        if result_df is not None:
                            st.session_state['code_execution_error'] = (value.get('insight_code', ''), None)
                            display_paginated_dataframe(result_df, f"insight_value_{key}")
                            st.session_state['print_result_df'] = result_df
                        else:
                            logger.warning("result_df is not defined in the output dictionary")
                    except Exception as e:
                        logger.error(f"Error executing generated insight code: {repr(e)}")
                        logger.debug(f"Generated code:\n{value.get('insight_code', '')}")

                # Display charts
                st.subheader("Chart Generated")
                for key, value in charts.items():
                    st.markdown(f"**{value.get('chart_prompt', 'No chart prompt available')}**")
                    try:
                        # SECURITY NOTE(review): same trusted-code assumption as
                        # above; the snippet is expected to bind a Plotly figure
                        # to a local named `chart`.
                        local_vars = {}
                        exec(value.get('chart_code', ''), {}, local_vars)
                        chart = local_vars.get('chart', None)
                        if chart is not None:
                            st.plotly_chart(chart, use_container_width=True)
                            st.session_state['print_chart'] = chart
                    except Exception as e:
                        logger.error(f"Error generating chart: {repr(e)}")
                        st.error("Please try again")

                with st.expander('See base dataset'):
                    st.subheader("Dataset Retrieved")
                    st.markdown(f"**{base_prompt}**")
                    display_paginated_dataframe(df, "base_dataset")
                    st.session_state['print_df'] = df
        except Exception as e:
            st.error("Please try again")
            logger.error(f"Error loading insights: {e}")
1457
+
1458
+
1459
+
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai==1.54.3
2
+ numpy==2.1.1
3
+ pandas==2.2.3
4
+ streamlit==1.40.0
5
+ altair==5.4.1
6
+ reportlab==4.2.4
7
+ streamlit_navigation_bar==3.3.0
8
+ altair_saver==0.5.0
9
+ plotly
10
+ boto3
11
+ azure.storage.blob
12
+ azure
13
+ pyodbc
14
+ mysql-connector-python
15
+ sqlalchemy
16
+ loguru
17
+ streamlit_aggrid
token_usage.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"2025-01": 596483}