Spaces:
Sleeping
Sleeping
| import json | |
| import sqlite3 | |
| # import pyodbc | |
| import mysql.connector | |
| import boto3 | |
| import time | |
| import pandas as pd | |
| import duckdb | |
| import ydata_profiling | |
| from streamlit_pandas_profiling import st_profile_report | |
| from pygwalker.api.streamlit import StreamlitRenderer | |
| import streamlit.components.v1 as components | |
| from openai import AzureOpenAI | |
| import os | |
| import json | |
| import altair as alt | |
| import plotly.express as px | |
| import ast | |
| import streamlit as st | |
| from streamlit_navigation_bar import st_navbar | |
| from glob import glob | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.lib import colors | |
| from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, Image | |
| from altair_saver import save | |
| from azure.storage.blob import BlobServiceClient, ContainerClient | |
| import re | |
| from sqlalchemy import create_engine | |
| from pages.config import SQL_SERVER_CONFIG, update_config, create_sqlalchemy_engine | |
| from loguru import logger | |
| from st_aggrid import AgGrid, GridOptionsBuilder | |
| from datetime import datetime | |
# Initialize token storage.
# Per-month token totals are persisted in this local JSON file.
token_file = "token_usage.json"

# Seed the file with an empty mapping on first run so that later reads always
# find valid JSON.
if not os.path.exists(token_file):
    with open(token_file, 'w') as f:
        json.dump({}, f)
def store_token_usage(token_usage):
    """Add ``token_usage`` to the running total stored for the current month.

    The totals live in the JSON file named by ``token_file`` as a mapping of
    ``'YYYY-MM'`` keys to counts. The whole file is read, updated and
    rewritten on every call.
    """
    month_key = datetime.now().strftime('%Y-%m')

    with open(token_file, 'r') as source:
        totals = json.load(source)

    # Start a new month at the given usage, otherwise accumulate.
    totals[month_key] = (
        totals[month_key] + token_usage if month_key in totals else token_usage
    )

    with open(token_file, 'w') as target:
        json.dump(totals, target)
def get_monthly_token_usage():
    """Return the complete month -> token-count mapping read from ``token_file``."""
    with open(token_file, 'r') as source:
        return json.load(source)
# Example usage of get_monthly_token_usage: the stored totals are loaded once
# when this module is imported and echoed to stdout.
monthly_token_usage = get_monthly_token_usage()
print(monthly_token_usage)
def show_messages(message):
    """Flash ``message`` in a Streamlit info box for 1.5 seconds, then clear it."""
    placeholder = st.info(message)
    time.sleep(1.5)
    placeholder.empty()
# Locations of various files
APP_TITLE = ' '  # '**Social <br>Determinant<br>of Health**'
sql_dir = 'generated_sql/'
method_dir = 'generated_method/'
insight_lib = 'insight_library/'
query_lib = 'query_library/'
report_path = 'Reports/'

# SECURITY: a storage account key is committed in source below. It should be
# rotated and supplied only through the AZURE_STORAGE_CONNECTION_STRING
# environment variable. The literal is kept purely as a backward-compatible
# fallback so existing deployments keep working until that is done.
connection_string = os.environ.get(
    "AZURE_STORAGE_CONNECTION_STRING",
    "DefaultEndpointsProtocol=https;AccountName=phsstorageacc;AccountKey=cEvoESH5CknyeZtbe8eCFuebwr7lRFi1EyO8smA35i5EuoSOfnzRXX/4337Y743B05tQsGPoQbsr+AStNRWeBg==;EndpointSuffix=core.windows.net",
)
container_name = "insights-lab"

# Choices presented by the application (used outside this section).
persona_list = ["Population Analyst", "SDoH Specialist"]
DB_List = ["Patient SDOH"]
def getBlobContent(dir_path):
    """Download the blob at ``dir_path`` and return its content as UTF-8 text.

    Any failure is logged and reported to the caller as an empty string.
    """
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        raw_bytes = container.get_blob_client(dir_path).download_blob().readall()
        text = raw_bytes.decode("utf-8")
        logger.info("Blob content retrieved successfully from: {}", dir_path)
        return text
    except Exception as ex:
        logger.error("Exception while retrieving blob content: {}", ex)
        return ""
def check_blob_exists(dir):
    """Report whether at least one blob name starts with the prefix ``dir``.

    Returns:
        True or False when the listing succeeds, or None when the storage
        call fails (the error is logged). Callers that test the result with
        ``not`` therefore treat a failure like "does not exist".
    """
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        blobs = container.list_blobs(name_starts_with=f"{dir}")
        # Only the first item is needed to answer the question, so stop
        # iterating there instead of materialising the whole listing.
        file_exists = any(True for _ in blobs)
        logger.info("Blob exists check for {}: {}", dir, file_exists)
        return file_exists
    except Exception as ex:
        logger.error("Exception while checking if blob exists: {}", ex)
        return None
def get_max_blob_num(dir):
    """Return the largest number embedded in the blob names under ``dir``.

    The prefix ``dir`` is removed from each blob name and the first run of
    digits in the remainder is taken as that blob's number. Returns 0 when no
    blob carries a number, when nothing is listed, or when the listing fails.
    """
    logger.debug("Directory for max blob num check: {}", dir)
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        blobs = list(container.list_blobs(name_starts_with=f"{dir}"))
        logger.debug("Blob list: {}", blobs)

        if not blobs:
            logger.debug("No blobs found in directory: {}", dir)
            highest = 0
        else:
            numbers = []
            for blob in blobs:
                relative_name = blob.name.removeprefix(dir)
                # Adjust regex if file names have a different pattern
                found = re.search(r"(\d+)", relative_name)
                if found:
                    numbers.append(int(found.group(1)))
            highest = max(numbers, default=0)

        logger.info("Latest file number in {}: {}", dir, highest)
        return highest
    except Exception as ex:
        logger.error("Exception while getting max blob number: {}", ex)
        return 0
def save_sql_query_blob(prompt, sql, sql_num, df_structure, dir, database):
    """Store a prompt/SQL record as ``<dir><userId>/<sql_num>.json`` in blob storage.

    An empty placeholder blob is uploaded for the user's folder when nothing
    exists under it yet. Returns True on success and False if any step inside
    the upload sequence raises (the error is logged).
    """
    payload = {
        "prompt": prompt,
        "sql": sql,
        "structure": df_structure,
        "database": database,
    }
    user_directory = dir + st.session_state.userId
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    logger.debug("Saving SQL query blob in directory: {}, SQL number: {}", user_directory, sql_num)
    logger.debug("Data to be saved: {}", payload)

    try:
        folder_path = f"{user_directory}/"
        if not check_blob_exists(folder_path):
            logger.debug("Creating directory: {}", user_directory)
            container.upload_blob(folder_path, data=b'')

        file_path = f"{user_directory}/{sql_num}.json"
        serialized = json.dumps(payload, indent=4)
        logger.debug("File path: {}", file_path)
        container.upload_blob(file_path, data=serialized)
        logger.info("SQL query blob saved successfully: {}", file_path)
        return True
    except Exception as e:
        logger.error("Exception while saving SQL query blob: {}", e)
        return False
def save_python_method_blob(method_num, code):
    """Store ``code`` as ``<method_dir><userId>/<method_num>.py`` in blob storage.

    An empty placeholder blob is uploaded for the user's folder when nothing
    exists under it yet. Returns True on success and False if any step inside
    the upload sequence raises (the error is logged).
    """
    user_directory = method_dir + st.session_state.userId
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    logger.debug("Saving Python method blob in directory: {}, Method number: {}", user_directory, method_num)

    try:
        folder_path = f"{user_directory}/"
        if not check_blob_exists(folder_path):
            logger.debug("Creating directory: {}", user_directory)
            container.upload_blob(folder_path, data=b'')

        file_path = f"{user_directory}/{method_num}.py"
        # NOTE(review): the payload is JSON-encoded, so the .py blob holds a
        # quoted, escaped string rather than raw source (save_python_method
        # writes the code verbatim). Confirm how readers decode this blob
        # before changing it.
        serialized = json.dumps(code, indent=4)
        logger.debug("File path: {}", file_path)
        container.upload_blob(file_path, data=serialized)
        logger.info("Python method blob saved successfully: {}", file_path)
        return True
    except Exception as e:
        logger.error("Exception while saving Python method blob: {}", e)
        return False
def list_blobs_sorted(directory, extension, session_key, latest_first=True):
    """List blobs under ``directory`` with the given ``extension``, ordered by time.

    Each entry is a ``(blob_name, 'YYYY-MM-DD HH:MM:SS')`` tuple built from
    the blob's last-modified time. Names whose final path segment is empty
    are ignored. The list is also stored in
    ``st.session_state[session_key]``. On failure the error is logged and an
    empty list is returned without touching the session state.
    """
    logger.debug("Listing blobs in directory: {}", directory)
    try:
        service = BlobServiceClient.from_connection_string(connection_string)
        container = service.get_container_client(container_name)
        blobs = list(container.list_blobs(name_starts_with=f"{directory}"))

        entries = [
            (blob.name, blob.last_modified.strftime('%Y-%m-%d %H:%M:%S'))
            for blob in blobs
            if blob.name.split('/')[-1] != "" and blob.name.split('.')[-1] == extension
        ]
        # The timestamp format sorts chronologically as plain text.
        files_with_dates = sorted(entries, key=lambda entry: entry[1], reverse=latest_first)

        logger.debug("Files with dates: {}", files_with_dates)
        st.session_state[session_key] = files_with_dates
        return files_with_dates
    except Exception as e:
        logger.error("Exception while listing blobs: {}", e)
        return []
| # def get_saved_query_blob_list(): | |
| # try: | |
| # user_id = st.session_state.userId | |
| # query_library = query_lib + user_id + "/" | |
| # if 'query_files' not in st.session_state: | |
| # list_blobs_sorted(query_library, 'json', 'query_files') | |
| # query_files = st.session_state['query_files'] | |
| # logger.debug("Query files: {}", query_files) | |
| # query_display_dict = {} | |
| # for file, dt in query_files: | |
| # id = file[len(query_library):-5] | |
| # content = getBlobContent(file) | |
| # content_dict = json.loads(content) | |
| # query_display_dict[f"ID: {id}, Query: \"{content_dict['prompt']}\", Created on {dt}"] = content_dict['sql'] | |
| # st.session_state['query_display_dict']=query_display_dict | |
| # except Exception as e: | |
| # logger.error("Exception while getting saved query blob list: {}", e) | |
def get_saved_query_blob_list():
    """Refresh ``st.session_state['query_display_dict']`` from the user's saved queries.

    Each saved query under ``query_lib/<userId>/`` becomes one entry that maps
    a display label (ID, prompt text and timestamp) to the stored SQL. Errors
    are logged and leave the session entry unchanged.
    """
    try:
        query_library = query_lib + st.session_state.userId + "/"

        # Always call list_blobs_sorted to get the most recent list of query files
        list_blobs_sorted(query_library, 'json', 'query_files')
        query_files = st.session_state['query_files']
        logger.debug("Query files: {}", query_files)

        prefix_length = len(query_library)
        labelled_queries = {}
        for blob_path, dt in query_files:
            # Strip the directory prefix and the ".json" suffix to get the ID.
            query_id = blob_path[prefix_length:-5]
            record = json.loads(getBlobContent(blob_path))
            label = f"ID: {query_id}, Query: \"{record['prompt']}\", Created on {dt}"
            labelled_queries[label] = record['sql']

        st.session_state['query_display_dict'] = labelled_queries
    except Exception as e:
        logger.error("Exception while getting saved query blob list: {}", e)
def get_existing_token(current_month):
    """Find the current user's token ledger for ``current_month``.

    Scans the blobs under ``token_consumed/<userId>/`` for a JSON document
    whose ``'year-month'`` field equals ``current_month``.

    Returns:
        ``(token_data, file_name)``, where ``file_name`` is the blob's base
        name without extension, or None when no ledger matches or the
        storage call fails (the error is logged).
    """
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    token_directory = f"token_consumed/{st.session_state.userId}/"
    try:
        for blob in container.list_blobs(name_starts_with=token_directory):
            blob_name = blob.name
            # The save path uploads an empty "<dir>/" placeholder blob. Feeding
            # it to json.loads used to raise and abort the whole search, so an
            # existing ledger was never found.
            if blob_name.endswith('/'):
                continue

            raw = container.get_blob_client(blob_name).download_blob().readall()
            try:
                token_data = json.loads(raw)
            except ValueError as e:  # covers JSONDecodeError and bad encodings
                logger.warning("Skipping unreadable token blob {}: {}", blob_name, e)
                continue

            if isinstance(token_data, dict) and token_data.get('year-month') == current_month:
                file_name = blob_name.split('/')[-1].split('.')[0]
                logger.info("Existing token_consumed found for month: {}", current_month)
                return token_data, file_name

        logger.info("No existing token_consumed found for month: {}", current_month)
        return None
    except Exception as e:
        logger.error("Error while retrieving token_consumed: {}", e)
        return None
def update_token(token_data, file_number):
    """Overwrite ``token_consumed/<userId>/<file_number>.json`` with ``token_data``.

    Returns True on success, False if serialisation or the upload raises
    (the error is logged).
    """
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    target_path = f"token_consumed/{st.session_state.userId}/{file_number}.json"
    try:
        serialized = json.dumps(token_data, indent=4)
        container.upload_blob(target_path, data=serialized, overwrite=True)
    except Exception as e:
        logger.error("Error while updating token: {}", e)
        return False
    logger.info("token updated successfully: {}", file_number)
    return True
def save_token(current_month, token_usage, userprompt, purpose, selected_db, time):
    """Create the user's token ledger blob for ``current_month``.

    The ledger starts with a single ``prompt_1`` entry describing this call
    and is uploaded as ``token_consumed/<userId>/<current_month>.json``. The
    upload does not request overwrite. Returns True on success and False if
    any step in the upload sequence raises (the error is logged).
    """
    first_entry = {
        'user_prompt': userprompt,
        'prompt_purpose': purpose,
        'database': selected_db,
        'date,time': time,
        'token': token_usage,
    }
    new_token = {
        'year-month': current_month,
        'total_token': token_usage,
        'prompt': {'prompt_1': first_entry},
    }

    user_directory = f"token_consumed/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    try:
        folder_path = f"{user_directory}/"
        if not check_blob_exists(folder_path):
            # Empty placeholder blob standing in for the user's folder.
            container.upload_blob(folder_path, data=b'')

        file_path = f"{user_directory}/{current_month}.json"
        container.upload_blob(file_path, data=json.dumps(new_token, indent=4))
        logger.info("New token created: {}", file_path)
        return True
    except Exception as e:
        logger.error("Error while creating new token: {}", e)
        return False
def _record_token_usage(current_month, timestamp, token_usage, userprompt, purpose, selected_db):
    """Best-effort: add this call to the user's monthly token ledger in blob storage."""
    try:
        result = get_existing_token(current_month)
        if result:
            existing_token, file_number = result
            existing_token['total_token'] += token_usage
            next_key = f'prompt_{len(existing_token["prompt"]) + 1}'
            existing_token['prompt'][next_key] = {
                'user_prompt': userprompt,
                'prompt_purpose': purpose,
                'database': selected_db,
                'date,time': timestamp,
                'token': token_usage,
            }
            # update_token reports failure through its return value and logs
            # the cause itself, so only announce success when it succeeded.
            if update_token(existing_token, file_number):
                logger.info("token updated successfully.")
        else:
            # No ledger for this month yet. save_token also creates the
            # user's folder placeholder when it is missing.
            save_token(current_month, token_usage, userprompt, purpose, selected_db, timestamp)
    except Exception as e:
        st.write("Please try again")
        logger.error("Error checking existing token: {}", e)


def run_prompt(prompt, userprompt, purpose, selected_db, model="provider-gpt4"):
    """Send ``prompt`` to the Azure OpenAI chat model and return the reply text.

    Token consumption reported by the service is added to the local monthly
    total (``store_token_usage``) and to the user's ledger in blob storage.
    Ledger bookkeeping is best-effort and never prevents the reply from being
    returned.

    Args:
        prompt: Full prompt text sent to the model.
        userprompt: The user's original request, stored in the ledger entry.
        purpose: Short label for why the prompt was run, stored in the ledger.
        selected_db: Database the prompt relates to, stored in the ledger.
        model: Model name passed to the chat completions call.

    Returns:
        The content of the first choice, or ``""`` if the request fails
        (the error is logged).
    """
    current_month = datetime.now().strftime('%Y-%m')
    # Named ``timestamp`` so the ``time`` module is not shadowed in here.
    timestamp = datetime.now().strftime('%d/%m/%Y, %H:%M:%S')
    try:
        client = AzureOpenAI(
            azure_endpoint="https://provider-openai-2.openai.azure.com/",
            # SECURITY: this key is committed in source. Rotate it and supply
            # it only through AZURE_OPENAI_API_KEY. The literal remains as a
            # backward-compatible fallback.
            api_key=os.environ.get("AZURE_OPENAI_API_KEY", "84a58994fdf64338b8c8f0610d63f81c"),
            api_version="2024-02-15-preview",
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        logger.debug("Prompt response: {}", response)

        # Ensure 'usage' attribute exists and is not None
        if response.usage is not None:
            token_usage = response.usage.total_tokens
            logger.info("Tokens consumed: {}", token_usage)
            store_token_usage(token_usage)  # local per-month total
        else:
            token_usage = 0
            logger.warning("Token usage information is not available in the response")

        _record_token_usage(current_month, timestamp, token_usage, userprompt, purpose, selected_db)
        return response.choices[0].message.content
    except Exception as e:
        logger.error("Exception while running prompt: {}", e)
        return ""
def list_files_sorted(directory, extension, session_key, latest_first=True):
    """List local ``*.<extension>`` files in ``directory`` ordered by modification time.

    Each entry is a ``(path, 'YYYY-MM-DD HH:MM:SS')`` tuple. Ordering uses
    ``os.path.getmtime`` while the reported timestamp comes from
    ``os.path.getctime``. The list is also stored in
    ``st.session_state[session_key]``. On failure the error is logged and an
    empty list is returned.
    """
    try:
        pattern = os.path.join(directory, f"*.{extension}")
        files = glob(pattern)
        logger.debug("Files found: {}", files)

        files.sort(key=os.path.getmtime, reverse=latest_first)
        logger.debug("Sorted files: {}", files)

        files_with_dates = []
        for path in files:
            stamp = datetime.fromtimestamp(os.path.getctime(path))
            files_with_dates.append((path, stamp.strftime('%Y-%m-%d %H:%M:%S')))

        st.session_state[session_key] = files_with_dates
        return files_with_dates
    except Exception as e:
        logger.error("Exception while listing files: {}", e)
        return []
def get_column_types(df):
    """Map every column of ``df`` to a type label.

    ``int64``, ``float64`` and ``bool`` columns are reported under those
    names. Object columns are reported as ``'datetime'`` when every value
    parses as ``%Y-%m-%d %H:%M:%S``, ``'date'`` when every value parses as
    ``%Y-%m-%d``, and ``'string'`` otherwise. Any other dtype is reported by
    its dtype name. A column whose inspection fails is labelled
    ``'unknown'``; if building the mapping itself fails, ``{}`` is returned.
    """
    passthrough_dtypes = ('int64', 'float64', 'bool')
    # Tried in order: the full timestamp format first, then the date-only format.
    temporal_formats = (
        ('%Y-%m-%d %H:%M:%S', 'datetime'),
        ('%Y-%m-%d', 'date'),
    )

    def infer_type(column, series):
        try:
            for dtype_name in passthrough_dtypes:
                if series.dtype == dtype_name:
                    return dtype_name

            if series.dtype == 'object':
                for fmt, label in temporal_formats:
                    try:
                        pd.to_datetime(series, format=fmt, errors='raise')
                    except (ValueError, TypeError):
                        continue
                    return label
                return 'string'

            return series.dtype.name  # fallback for any other dtype
        except Exception as e:
            logger.error("Exception while inferring column type for {}: {}", column, e)
            return 'unknown'

    try:
        column_types = {}
        for col in df.columns:
            column_types[col] = infer_type(col, df[col])
        return column_types
    except Exception as e:
        logger.error("Exception while getting column types: {}", e)
        return {}
def save_sql_query(prompt, sql, sql_num, df_structure, dir):
    """Write a prompt/SQL record to the local file ``<dir><userId>/<sql_num>.json``.

    The user's directory is created if needed. Returns True on success and
    False if writing the file raises (the error is logged).
    """
    record = {"prompt": prompt, "sql": sql, "structure": df_structure}
    user_directory = dir + st.session_state.userId
    os.makedirs(user_directory, exist_ok=True)
    logger.debug("Saving SQL query to directory: {}, SQL number: {}", user_directory, sql_num)
    logger.debug("Data to be saved: {}", record)

    target_path = f"{user_directory}/{sql_num}.json"
    try:
        with open(target_path, 'w') as json_file:
            json.dump(record, json_file, indent=4)
        logger.info("SQL query saved successfully.")
        return True
    except Exception as e:
        logger.error("Exception while saving SQL query: {}", e)
        return False
def save_python_method(method_num, code):
    """Write ``code`` verbatim to the local file ``<method_dir><method_num>.py``.

    Returns True on success and False if writing raises (the error is logged).
    """
    target_path = f"{method_dir}{method_num}.py"
    try:
        with open(target_path, 'w') as code_file:
            code_file.write(code)
        logger.info("Python method saved successfully: {}", method_num)
        return True
    except Exception as e:
        logger.error("Exception while saving Python method: {}", e)
        return False
def get_ag_grid_options(df):
    """Build AG Grid options for ``df``: 20 rows per page, with the default
    column configured as resizable, sortable and filterable."""
    builder = GridOptionsBuilder.from_dataframe(df)
    builder.configure_pagination(paginationPageSize=20)  # Limit to 20 rows per page
    builder.configure_default_column(resizable=True, sortable=True, filterable=True)
    grid_options = builder.build()
    return grid_options
def get_existing_insight(base_code, user_persona):
    """Find a saved insight of the current user whose ``base_code`` matches.

    Scans the blobs under ``insight_library/<user_persona>/<userId>/``.

    Returns:
        ``(insight_data, file_name)``, where ``file_name`` is the blob's base
        name without extension, or None when nothing matches or the storage
        call fails (the error is logged).
    """
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    insights_directory = f"insight_library/{user_persona}/{st.session_state.userId}/"
    try:
        for blob in container.list_blobs(name_starts_with=insights_directory):
            blob_name = blob.name
            # Skip the empty "<dir>/" placeholder by name. Blindly skipping
            # the first listed blob dropped a real insight whenever no
            # placeholder was present.
            if blob_name.endswith('/'):
                continue

            raw = container.get_blob_client(blob_name).download_blob().readall()
            try:
                insight_data = json.loads(raw)
            except ValueError as e:  # covers JSONDecodeError and bad encodings
                # One unreadable blob should not hide the remaining insights.
                logger.warning("Skipping unreadable insight blob {}: {}", blob_name, e)
                continue

            if isinstance(insight_data, dict) and insight_data.get('base_code') == base_code:
                file_name = blob_name.split('/')[-1].split('.')[0]
                # loguru interpolates "{}" placeholders; "%s" is left as-is.
                logger.info("Existing insight found for base code: {}", base_code)
                return insight_data, file_name

        logger.info("No existing insight found for base code: {}", base_code)
        return None
    except Exception as e:
        logger.error("Error while retrieving insight: {}", e)
        return None
def update_insight(insight_data, user_persona, file_number):
    """Overwrite the user's insight blob ``<file_number>.json`` with ``insight_data``.

    The blob lives under ``<insight_lib><user_persona>/<userId>/``. Returns
    True on success and False if serialisation or the upload raises (the
    error is logged).
    """
    user_directory = f"{insight_lib}{user_persona}/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    try:
        file_path = f"{user_directory}/{file_number}.json"
        file_content = json.dumps(insight_data, indent=4)
        container.upload_blob(file_path, data=file_content, overwrite=True)
        # loguru interpolates "{}" placeholders; the former "%s" was logged
        # literally and the value was dropped.
        logger.info("Insight updated successfully: {}", file_number)
        return True
    except Exception as e:
        logger.error("Error while updating insight: {}", e)
        return False
def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_query, chart_code):
    """Create a new insight blob ``<next_file_number>.json`` for the current user.

    The document holds the description, base prompt/code, database and one
    initial ``prompt_1`` and ``chart_1`` entry. It is uploaded under
    ``<insight_lib><user_persona>/<userId>/`` without requesting overwrite.
    Returns True on success and False if any step in the upload sequence
    raises (the error is logged).
    """
    first_prompt = {
        'insight_prompt': insight_prompt,
        'insight_code': insight_code,
    }
    first_chart = {
        'chart_prompt': chart_prompt,
        'chart_query': chart_query,
        'chart_code': chart_code,
    }
    new_insight = {
        'description': insight_desc,
        'base_prompt': base_prompt,
        'base_code': base_code,
        'database': selected_db,
        'prompt': {'prompt_1': first_prompt},
        'chart': {'chart_1': first_chart},
    }

    user_directory = f"{insight_lib}{user_persona}/{st.session_state.userId}"
    service = BlobServiceClient.from_connection_string(connection_string)
    container = service.get_container_client(container_name)
    try:
        folder_path = f"{user_directory}/"
        if not check_blob_exists(folder_path):
            # Empty placeholder blob standing in for the user's folder.
            container.upload_blob(folder_path, data=b'')

        file_path = f"{user_directory}/{next_file_number}.json"
        container.upload_blob(file_path, data=json.dumps(new_insight, indent=4))
        logger.info("New insight created: {}", file_path)
        return True
    except Exception as e:
        logger.error("Error while creating new insight: {}", e)
        return False
def generate_sql(query, table_descriptions, table_details, selected_db):
    """Ask the language model to translate an English ``query`` into SQLite SQL.

    Args:
        query: The user's English-language request.
        table_descriptions: Table names and purposes, embedded in the prompt.
        table_details: Table structures, embedded in the prompt.
        selected_db: Database identifier, forwarded to ``run_prompt``.

    Returns:
        The model's reply with backslashes removed, or None when ``query``
        is empty or no reply object is available.
    """
    # ``not query`` also covers None, which previously crashed in len().
    if not query:
        return None

    with st.spinner('Generating Query'):
        query_prompt = f"""
You are an expert in understanding an English language healthcare data query and translating it into an SQL Query that can be executed on a SQLite database.
I am providing you the table names and their purposes that you need to use as a dictionary within double backticks. There may be more than one table.
Table descriptions: ``{table_descriptions}``
I am providing you the table structure as a dictionary. For this dictionary, table names are the keys. Values within this dictionary
are other dictionaries (nested dictionaries). In each nested dictionary, the keys are the field names and the values are dictionaries
where each key is the column name and each value is the datatype. There may be multiple table structures described here.
The table structure is enclosed in triple backticks.
Table Structures: ```{table_details}```
Pay special attention to the field names. Some field names have an underscore ('_') and some do not. You need to be accurate while generating the query.
If there is a space in the column name, then you need to fully enclose each occurrence of the column name with double quotes in the query.
This is the English language query that needs to be converted into an SQL Query within four backticks.
English language query: ````{query}````
Your task is to generate an SQL query that can be executed on a SQLite database.
Only produce the SQL query as a string.
Do NOT produce any backticks before or after.
Do NOT produce any JSON tags.
Do NOT produce any additional text that is not part of the query itself.
"""
        logger.info(f"Generating SQL query with prompt:{query_prompt}")
        query_response = run_prompt(query_prompt, query, "generate query", selected_db)

    # The old tuple "unpacking" reassigned the tuple to itself. Take the
    # first element so a tuple reply does not crash on .replace() below.
    if isinstance(query_response, tuple):
        query_response = query_response[0]
    if query_response is None:
        logger.error("Query response is None")
        return None

    q = query_response.replace('\\', '')
    # loguru interpolates "{}" placeholders; "%s" is left as-is.
    logger.debug("Generated SQL query: {}", q)
    return q
| # def create_connection(): | |
| # if USE_SQL_SERVER: | |
| # try: | |
| # conn = pyodbc.connect( | |
| # f"DRIVER={SQL_SERVER_CONFIG['driver']};" | |
| # f"SERVER={SQL_SERVER_CONFIG['server']};" | |
| # f"DATABASE={SQL_SERVER_CONFIG['database']};" | |
| # "Trusted_Connection=yes;" | |
| # ) | |
| # logger.info("Connected to SQL Server") | |
| # return conn | |
| # except Exception as e: | |
| # logger.error("Error connecting to SQL Server: {}", e) | |
| # return None | |
| # else: | |
| # try: | |
| # conn = mysql.connector.connect( | |
| # host=MYSQL_SERVER_CONFIG['host'], | |
| # user=MYSQL_SERVER_CONFIG['user'], | |
| # password=MYSQL_SERVER_CONFIG['password'], | |
| # database=MYSQL_SERVER_CONFIG['database'] | |
| # ) | |
| # logger.info("Connected to MySQL Server") | |
| # return conn | |
| # except mysql.connector.Error as err: | |
| # logger.error("Error connecting to MySQL: {}", err) | |
| # return None | |
| # def execute_sql(query, selected_db): | |
| # update_config(selected_db) | |
| # engine = create_sqlalchemy_engine() | |
| # if engine: | |
| # connection = engine.connect() | |
| # logger.info(f"Connected to the database {selected_db}.") | |
| # try: | |
| # df = pd.read_sql_query(query, connection) | |
| # logger.info("Query executed successfully.") | |
| # return df | |
| # except Exception as e: | |
| # logger.error(f"Query execution failed: {e}") | |
| # return pd.DataFrame() | |
| # finally: | |
| # connection.close() | |
| # else: | |
| # logger.error("Failed to create a SQLAlchemy engine.") | |
| # return None | |
def execute_sql(query, selected_db):
    """Run ``query`` against the SQLite database file ``selected_db``.

    Returns:
        A DataFrame with the fetched rows (empty, with no columns, for
        statements that produce no result set), or None when SQLite reports
        an error (the error is logged).

    NOTE(review): callers in this file pass SQL text generated by a language
    model, and the connection is opened read-write. Consider a read-only
    connection if only SELECT statements are expected.
    """
    df = None
    conn = None  # stays None if connect() itself fails
    try:
        conn = sqlite3.connect(selected_db)
        cursor = conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        # cursor.description is None for statements without a result set;
        # iterating it directly raised TypeError.
        columns = [desc[0] for desc in (cursor.description or ())]
        df = pd.DataFrame(rows, columns=columns)
        logger.info("Query executed successfully.")
    except sqlite3.Error as e:
        logger.error("Error while querying the DB : {}", e)
    finally:
        # Close only a connection that was actually opened. Previously a
        # failed connect() raised a second error here and hid the real one.
        if conn is not None:
            conn.close()
    return df
def _strip_code_fence(sql_text):
    """Remove a surrounding Markdown code fence (```sql ... ```) if one is present."""
    text = sql_text.strip()
    if not text.startswith("```"):
        return sql_text

    lines = text.splitlines()
    if len(lines) == 1:
        # Whole reply on one line: only the backticks can be removed safely.
        return text.strip("`").strip()

    lines = lines[1:]  # opening fence line, including any language tag
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]  # closing fence line
    return "\n".join(lines)


def handle_retrieve_request(prompt):
    """Generate SQL for ``prompt``, execute it, and return the outcome.

    Table metadata and the selected database are read from
    ``st.session_state``.

    Returns:
        ``(results_df, sql)`` on success; ``(None, None)`` when no SQL was
        generated or the query returned no rows; ``(None, sql)`` when
        execution failed.
    """
    sql_generated = generate_sql(
        prompt,
        st.session_state['table_master'],
        st.session_state['table_details'],
        st.session_state['selected_db'],
    )
    # loguru interpolates "{}" placeholders; "%s" is left as-is.
    logger.debug("Type of sql_generated: {}", type(sql_generated))
    logger.debug("Content of sql_generated: {}", sql_generated)

    # Check if sql_generated is a tuple and unpack it
    if isinstance(sql_generated, tuple):
        logger.debug("Unpacking tuple returned by generate_sql")
        sql_generated = sql_generated[0]
    if sql_generated is None:
        logger.error("Generated SQL is None")
        return None, None
    logger.debug("Generated SQL: {}", sql_generated)

    # Strip only a real Markdown fence. The old substring test on "sql" also
    # fired for ordinary queries (e.g. one mentioning sqlite_master) and cut
    # off their first and last line, or the last character of a one-liner.
    sql_generated = _strip_code_fence(sql_generated)

    results_df = None
    try:
        logger.debug("Executing SQL: {}", sql_generated)
        sql_generated = sql_generated.replace('###', '')
        selected_db = st.session_state.get('selected_db')
        results_df = execute_sql(sql_generated, selected_db)
        logger.debug("Query result: {}", results_df)

        if results_df is None:
            # execute_sql already logged the database error.
            logger.error("Error while executing generated query: no result returned")
            return None, sql_generated
        if results_df.empty:
            return None, None
    except Exception as e:
        logger.error("Error while executing generated query: {}", e)
    return results_df, sql_generated
def display_historical_responses(messages):
    """Re-render every chat message except the last one.

    Messages are rendered by their ``"type"``: text as Markdown, dataframes
    and tables through ``display_paginated_dataframe``, charts as Plotly
    figures. Messages without a ``"type"`` get an empty chat bubble.
    """
    for index, message in enumerate(messages[:-1]):
        # loguru interpolates "{}" placeholders; the former "%s" was logged
        # literally and the message was dropped.
        logger.debug("Displaying historical response: {}", message)
        with st.chat_message(message["role"]):
            if 'type' not in message:
                continue
            kind = message["type"]
            if kind == "text":
                st.markdown(message["content"])
            elif kind in ("dataframe", "table"):
                display_paginated_dataframe(
                    message["content"],
                    f"message_historical_{index}_{id(message)}",
                )
            elif kind == "chart":
                st.plotly_chart(message["content"])
def display_paginated_dataframe(df, key):
    """Show ``df`` in an AG Grid, 100 rows at a time, with a page selector.

    The chosen page is remembered in ``st.session_state[key]`` so it survives
    reruns. An empty frame produces a short notice instead of a grid.
    """
    if key not in st.session_state:
        st.session_state[key] = {'page_number': 1}

    if df.empty:
        st.write("No data available to display.")
        return

    page_size = 100  # Number of rows per page
    full_pages, leftover = divmod(len(df), page_size)
    total_pages = full_pages + (1 if leftover != 0 else 0)

    # Get the current page number from the user
    page_number = st.number_input(
        'Page number',
        min_value=1,
        max_value=total_pages,
        value=st.session_state[key]['page_number'],
        key=f'page_number_{key}',
    )
    st.session_state[key]['page_number'] = page_number

    # Slice out the rows belonging to the selected page.
    first_row = (page_number - 1) * page_size
    page_rows = df.iloc[first_row:first_row + page_size]

    # Configure and display AG Grid
    builder = GridOptionsBuilder.from_dataframe(page_rows)
    builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=page_size)
    AgGrid(page_rows, gridOptions=builder.build(), key=f"query_result_{key}_{page_number}")
def display_new_responses(response):
    """Render the parts of a fresh assistant ``response`` mapping.

    ``'text'`` values are appended to ``st.session_state.messages`` and shown
    as Markdown. A ``'footnote'`` value, a ``(sequence number, SQL)`` pair,
    is shown as an italic line naming the SQL and its saved file path. Other
    keys are only logged.
    """
    for part_kind, part_value in response.items():
        logger.debug("Displaying new response: {} - {}", part_kind, part_value)
        if part_kind == 'text':
            st.session_state.messages.append(
                {"role": "assistant", "content": part_value, "type": "text"}
            )
            st.markdown(part_value)
        elif part_kind == 'footnote':
            seq_no, sql_str = part_value
            filename = f"{sql_dir}{st.session_state.userId}/{seq_no}.json"
            st.markdown(f"*SQL: {sql_str}', File: {filename}*")
def drop_duplicate_columns(df):
    """Return ``df`` restricted to the first occurrence of each column name."""
    first_occurrence = ~df.columns.duplicated()
    return df.loc[:, first_occurrence]
def recast_object_columns_to_string(df):
    """Convert every object-dtype column of ``df`` to ``str`` in place.

    The same frame is returned so calls can be chained.
    """
    for col in df.columns:
        if df[col].dtype != 'object':
            continue
        df[col] = df[col].astype(str)
        logger.debug("Column '{}' recast to string.", col)
    return df
def clean_generated_sql(raw_sql):
    """Normalise an LLM-produced SQL string before it is run with DuckDB.

    The model is asked for a bare query but may still wrap it in a code fence
    or refer to the table as ``dataframe``. This helper:

    * trims surrounding whitespace,
    * removes a leading fence (three backticks or three single quotes,
      optionally followed by the word ``sql``) and a trailing fence,
    * rewrites ``FROM dataframe`` / ``JOIN dataframe`` (any letter case, any
      whitespace between the words) so the table name becomes ``mydf``.

    Args:
        raw_sql: The text returned by the model.

    Returns:
        The cleaned query string.
    """
    query = raw_sql.strip()
    # Only fences at the very start/end are removed so quotes inside the query
    # (e.g. SQL string literals) are left untouched.
    query = re.sub(r"^(?:```|''')\s*(?:sql\b)?\s*", "", query, flags=re.IGNORECASE)
    query = re.sub(r"\s*(?:```|''')$", "", query)
    query = re.sub(r"\b(from|join)(\s+)dataframe\b", r"\1\2mydf", query, flags=re.IGNORECASE)
    return query.strip()


def answer_guide_question(question, dframe, df_structure, selected_db):
    """Ask the model for a DuckDB query that answers ``question``.

    Args:
        question: The user's insight request in plain English.
        dframe: The dataset being explored; only its ``head()`` is logged here.
        df_structure: Description of the dataset's columns and types, embedded
            verbatim in the prompt.
        selected_db: Passed through to ``run_prompt``.

    Returns:
        ``(duckdb_query, method_num)`` where ``duckdb_query`` is the cleaned
        SQL string and ``method_num`` is the number under which the raw model
        output was (attempted to be) saved.

    Raises:
        ValueError: If ``run_prompt`` does not return a string.
    """
    logger.debug("Question: {}", question)
    logger.debug("DataFrame Structure: {}", df_structure)
    logger.debug("DataFrame Preview: {}", dframe.head())
    with st.spinner('Generating analysis code'):
        # The prompt asks for the raw SQL query only, without extra formatting.
        code_gen_prompt = f"""
You are an expert in writing SQL queries for DuckDB. Given the task and the structure of a dataframe, your goal is to generate only the SQL query string that can be executed directly on DuckDB, **without any extra code or formatting**.
The task is provided in double backticks:
Task: ``{question}``
The dataframe structure is provided as a dictionary where the column names are the keys, and their data types are the values:
DataFrame Structure: ```{df_structure}```
Your goal is to generate a **clean, valid DuckDB SQL query** that can be executed with `duckdb.query()`. Do **NOT** include any assignment to variables (e.g., `result_df`), comments, backticks, or any additional text.
The **output should be a valid SQL query string**, ready to be executed directly in DuckDB. **Do not include any extra SQL keywords like `sql` or backticks around the query**.
Return **only the raw SQL query string**, without any additional formatting, comments, or explanation.
"""
        logger.info("Generating insight with prompt: {}", code_gen_prompt)
        analysis_code = run_prompt(code_gen_prompt, question, "generate insight", selected_db)
        if not isinstance(analysis_code, str):
            logger.error("Generated code is not a string: {}", analysis_code)
            raise ValueError("Generated code is not a string")
        duckdb_query = clean_generated_sql(analysis_code)

        next_method_num = get_max_blob_num(method_dir + st.session_state.userId + '/') + 1
        try:
            # The raw model output (not the cleaned query) is what gets stored.
            save_python_method_blob(next_method_num, analysis_code)
            logger.info("Code generated and written in {}/{}.py", method_dir, next_method_num)
        except Exception as e:
            # Saving is best-effort: the query is still returned to the caller.
            logger.error("Trouble writing the code file for {} and method number {}: {}", question, next_method_num, e)
        return duckdb_query, next_method_num


def generate_duckdb_query(question, mydf , df_structure, selected_db):
    """Ask the model for a DuckDB query producing a 2-column dataset for a graph.

    Args:
        question: The user's graph request in plain English.
        mydf: Unused here; kept so existing callers keep working.
        df_structure: Description of the dataset's columns and types, embedded
            verbatim in the prompt.
        selected_db: Passed through to ``run_prompt``.

    Returns:
        The cleaned SQL string.

    Raises:
        ValueError: If ``run_prompt`` does not return a string.
    """
    code_gen_prompt = f"""
You are an expert in writing SQL queries for DuckDB. Given the task and the structure of a dataframe, your goal is to generate only the SQL query string that can be executed directly on DuckDB, **without any extra code or formatting**.
The user prompt is a graph prompt: generate a 2-column dataset for that graph.
Task: ``{question}``
The dataframe structure is provided as a dictionary where the column names are the keys, and their data types are the values:
DataFrame Structure: ```{df_structure}```
Your goal is to generate a **clean, valid DuckDB SQL query** that can be executed with `duckdb.query()`. Do **NOT** include any assignment to variables (e.g., `result_df`), comments, backticks, or any additional text.
The **output should be a valid SQL query string**, ready to be executed directly in DuckDB. **Do not include any extra SQL keywords like `sql` or backticks around the query**.
Return **only the raw SQL query string**, without any additional formatting, comments, or explanation.
"""
    logger.info("Generating insight with prompt: {}", code_gen_prompt)
    analysis_code = run_prompt(code_gen_prompt, question, "generate graph query", selected_db)
    if not isinstance(analysis_code, str):
        logger.error("Generated code is not a string: {}", analysis_code)
        raise ValueError("Generated code is not a string")
    graph_query = clean_generated_sql(analysis_code)
    # A successfully generated query is diagnostic output, not an error.
    logger.debug("Generated graph query: {}", graph_query)
    return graph_query
def extract_plotly_call(text):
    """Return the first complete ``px.<name>(...)`` call found in ``text``.

    The call is located by its ``px.<name>(`` opener and then scanned
    character by character until the matching closing parenthesis, ignoring
    parentheses that appear inside single- or double-quoted strings. This
    keeps calls intact when an argument contains ``)`` (for example a title
    with parentheses or ``labels=dict(...)``) and accepts function names with
    underscores or digits such as ``px.scatter_3d``.

    Args:
        text: Model output that should contain Plotly Express code.

    Returns:
        The call as a string, or ``""`` when ``text`` is not a string, holds no
        ``px.`` call, or the call's parentheses never balance.
    """
    if not isinstance(text, str):
        return ""
    opener = re.search(r"\bpx\.[A-Za-z_]\w*\(", text)
    if opener is None:
        return ""
    open_at = opener.end() - 1  # index of the "(" that starts the argument list
    depth = 0
    quote = None      # quote character of the string literal being scanned
    escaped = False   # True right after a backslash inside a string literal
    for pos, ch in enumerate(text[open_at:], start=open_at):
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[opener.start():pos + 1]
    return ""


def generate_graph(query, df_structure, selected_db):
    """Ask the model for Plotly Express code that draws the requested graph.

    Args:
        query: The user's graph instruction in plain English.
        df_structure: Description of the dataset columns, embedded in the prompt.
        selected_db: Passed through to ``run_prompt``.

    Returns:
        The extracted ``px.<chart>(...)`` call as a string, or ``""`` when the
        model reply contains none.

        NOTE(review): when ``query`` or ``df_structure`` is ``None``, or
        ``query`` is empty, the function returns the tuple ``(None, None)``
        instead of a string. That shape is kept unchanged because callers
        outside this view may rely on it.
    """
    if query is None or df_structure is None:
        logger.error("generate_graph received None values for query or df_structure")
        return None, None
    if not query:
        return None, None
    with st.spinner('Generating graph'):
        graph_prompt = f"""
You are an expert in understanding English language instructions to generate a graph based on a given dataframe.
I am providing you the dataframe structure as a dictionary in double backticks.
Dataframe structure: ``{df_structure}``
I am also giving you the intent instruction in triple backticks.
Instruction for generating the graph: ```{query}```
# Ensure deterministic behavior in graph code
Only produce the Python code for creating the Plotly chart.
based on the query i want the type of graph/plotly chart. px.bar is just an example type of graph should be genearate based on graph
Do NOT produce any backticks or double quotes or single quotes before or after the code.
Do generate the Plotly import statement as part of the code.
Do NOT justify your code.
Do not generate any narrative or comments in the code.
Do NOT produce any JSON tags.
Do not print or return the chart object at the end.
Do NOT produce any additional text that is not part of the query itself.
Always name the final Plotly chart object as 'chart'.
The task is to generate a Plotly chart using the 2-coloum dataset. Mention the x, y, title, and type of chart based on the user prompt and dataframe structure.
Extract only the Plotly chart creation code segment like `px.bar(graph_df, x='discharge_disposition', y='record_count', color='condition_class', title='Count of Records for Every Condition Class with X Axis Showing Discharge Dispositions')`.
"""
        logger.info("Generating graph with prompt: {}", graph_prompt)
        graph_response = run_prompt(graph_prompt, query, "generate graph", selected_db)
        logger.debug("Graph response: {}", graph_response)
        # Keep only the chart-construction call; imports and other text are dropped.
        return extract_plotly_call(graph_response)
def build_table_dicts(tables_df, descriptions_df, selected_db):
    """Shape the metadata query results into the two lookup dictionaries.

    Args:
        tables_df: One row per column with ``TABLE_NAME``, ``COLUMN_NAME``,
            ``DATA_TYPE`` and ``COLUMN_DESCRIPTION`` columns.
        descriptions_df: One row per table with ``TABLE_NAME`` and
            ``TABLE_DESCRIPTION`` columns.
        selected_db: Label prefixed to every table summary string.

    Returns:
        ``(tables_master_dict, tables_details_dict)``: the first maps a table
        name to ``"<selected_db> - <table> - <description>"`` (first row wins
        when a table appears more than once); the second maps a table name to
        a list of ``{"name", "type", "description"}`` dicts, one per column.
    """
    tables_master_dict = {}
    for table_name, description in zip(descriptions_df['TABLE_NAME'], descriptions_df['TABLE_DESCRIPTION']):
        tables_master_dict.setdefault(table_name, f"{selected_db} - {table_name} - {description}")

    tables_details_dict = {
        table_name: [
            {"name": col.COLUMN_NAME, "type": col.DATA_TYPE, "description": col.COLUMN_DESCRIPTION}
            for col in group.itertuples()
        ]
        for table_name, group in tables_df.groupby('TABLE_NAME')
    }
    return tables_master_dict, tables_details_dict


def get_table_details(engine,selected_db):
    """Fetch table and column metadata (with descriptions) from SQL Server.

    Args:
        engine: Connection/engine object accepted by ``pandas.read_sql``.
        selected_db: Label prefixed to every table summary string.

    Returns:
        ``(tables_master_dict, tables_details_dict)`` as produced by
        ``build_table_dicts``.
    """
    # Column-level descriptions: in sys.extended_properties, class 1 rows with
    # minor_id equal to the column id describe a column of the object. The
    # class filter keeps properties of other kinds (e.g. indexes) out.
    query_tables = """
    SELECT
        c.TABLE_NAME,
        c.TABLE_SCHEMA,
        c.COLUMN_NAME,
        c.DATA_TYPE,
        ep.value AS COLUMN_DESCRIPTION
    FROM
        INFORMATION_SCHEMA.COLUMNS c
    LEFT JOIN
        sys.extended_properties ep
        ON OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME) = ep.major_id
        AND c.ORDINAL_POSITION = ep.minor_id
        AND ep.class = 1
        AND ep.name = 'MS_Description'
    ORDER BY
        c.TABLE_NAME,
        c.ORDINAL_POSITION;
    """
    # Table-level descriptions: minor_id = 0 selects the property attached to
    # the table itself. Without it every column description of the table also
    # joined, so a column's text could be reported as the table description.
    query_descriptions = """
    SELECT
        t.TABLE_NAME,
        t.TABLE_SCHEMA,
        t.TABLE_TYPE,
        ep.value AS TABLE_DESCRIPTION
    FROM
        INFORMATION_SCHEMA.TABLES t
    LEFT JOIN
        sys.extended_properties ep
        ON OBJECT_ID(t.TABLE_SCHEMA + '.' + t.TABLE_NAME) = ep.major_id
        AND ep.class = 1
        AND ep.minor_id = 0
        AND ep.name = 'MS_Description'
    WHERE
        t.TABLE_TYPE='BASE TABLE';
    """
    tables_df = pd.read_sql(query_tables, engine)
    descriptions_df = pd.read_sql(query_descriptions, engine)
    logger.debug("Column metadata:\n{}", tables_df)
    logger.debug("Table descriptions:\n{}", descriptions_df)

    tables_master_dict, tables_details_dict = build_table_dicts(tables_df, descriptions_df, selected_db)
    logger.info("Table details fetched successfully.")
    return tables_master_dict, tables_details_dict
| # Function to fetch database names from SQL Server | |
| # def get_database_names(): | |
| # query = """ | |
| # SELECT name | |
| # FROM sys.databases | |
| # WHERE name NOT IN ('master', 'tempdb', 'model', 'msdb'); | |
| # """ | |
| # connection_string = ( | |
| # f"DRIVER={SQL_SERVER_CONFIG['driver']};" | |
| # f"SERVER={SQL_SERVER_CONFIG['server']};" | |
| # f"UID={SQL_SERVER_CONFIG['username']};" # Use SQL Server authentication username | |
| # f"PWD={SQL_SERVER_CONFIG['password']}" # Use SQL Server authentication password | |
| # ) | |
| # engine = create_engine(f"mssql+pyodbc:///?odbc_connect={connection_string}") | |
| # try: | |
| # with engine.connect() as conn: | |
| # result = conn.execute(query) | |
| # databases = [row['name'] for row in result] | |
| # logger.info("Database names fetched successfully.") | |
| # return databases | |
| # except Exception as e: | |
| # logger.error("Error fetching database names: {}", e) | |
| # return [] | |
| # def get_metadata(selected_table): | |
| # try: | |
| # metadata_df = pd.DataFrame(st.session_state['table_details'][selected_table]) | |
| # logger.info("Metadata fetched for table: {}", selected_table) | |
| # return metadata_df | |
| # except Exception as e: | |
| # logger.error("Error fetching metadata for table {}: {}", selected_table, e) | |
| # return pd.DataFrame() | |
def get_metadata(table):
    """Build a field-level metadata frame for one table.

    Reads ``st.session_state['table_details'][table]``, a mapping of field
    name to a sequence whose items 0 and 1 are placed in the
    'Field Description' and 'Field Type' columns respectively.

    Args:
        table: Key of the table inside ``st.session_state['table_details']``.

    Returns:
        A DataFrame with the columns 'Field Name', 'Field Description' and
        'Field Type', one row per field.
    """
    field_specs = st.session_state['table_details'][table]
    rows = []
    for field_name, spec in field_specs.items():
        rows.append([field_name, spec[0], spec[1]])
    return pd.DataFrame(rows, columns=['Field Name', 'Field Description', 'Field Type'])
def get_meta():
    """Load table metadata from ``database/db_tables.json`` into session state.

    Does nothing when ``'table_master'`` is already present in
    ``st.session_state``. Otherwise each entry of the JSON list contributes
    its ``'description'`` to ``st.session_state['table_master']`` and its
    ``'fields'`` to ``st.session_state['table_details']``, both keyed by the
    entry's ``'name'``.

    Returns:
        None.
    """
    if 'table_master' in st.session_state:
        return
    # A context manager closes the file; the original left the handle open.
    with open('database/db_tables.json', encoding='utf-8') as metadata_file:
        db_js = json.load(metadata_file)
    tables_master_dict = {entry['name']: entry['description'] for entry in db_js}
    tables_details_dict = {entry['name']: entry['fields'] for entry in db_js}
    logger.debug("Table details loaded: {}", tables_details_dict)
    logger.debug("Table master loaded: {}", tables_master_dict)
    st.session_state['table_master'] = tables_master_dict
    st.session_state['table_details'] = tables_details_dict
    return
def compose_dataset():
    """Render the "Compose Dataset" page.

    The page lets the user pick a database and a table, shows the table's
    description and field metadata in the chat history, accepts a question
    through the chat input, runs it via ``handle_retrieve_request`` and shows
    the resulting dataset with pagination. The generated SQL is saved under
    ``sql_dir`` and can additionally be stored in ``query_lib`` with the
    'Save Query' button.

    NOTE(review): the source this was taken from had lost its indentation;
    the nesting below is reconstructed from the statements' meaning.
    """
    if "messages" not in st.session_state:
        logger.debug('Initializing session state messages.')
        st.session_state.messages = []
    if "query_result" not in st.session_state:
        st.session_state.query_result = pd.DataFrame()

    col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
    with col_aa:
        st.image('logo.png')
    with col_bb:
        st.subheader("InsightLab - Compose Dataset", divider='blue')
        st.markdown('**Generate a custom dataset by combining any table with English language questions.**')
    with col_cc:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    selected_db = None
    selected = st.selectbox('Select Database:', DB_List)
    if selected == "Patient SDOH":
        selected_db = './gravity_sdoh_observations.db'

    if selected_db:
        # Compare against the stored value BEFORE overwriting it; previously the
        # new value was stored first, so a database change was never detected.
        db_changed = 'selected_db' in st.session_state and st.session_state['selected_db'] != selected_db
        if db_changed:
            st.session_state['messages'] = []
            logger.debug('Session state cleared due to database change.')
        st.session_state['selected_db'] = selected_db
        if 'table_master' not in st.session_state or db_changed:
            get_meta()
        table_keys = list(st.session_state['table_master'].keys())
        selected_table = st.selectbox('Tables available:', [''] + table_keys)
        if selected_table:
            if 'selected_table' not in st.session_state or st.session_state['selected_table'] != selected_table:
                try:
                    table_metadata_df = get_metadata(selected_table).copy()
                    table_desc = st.session_state['table_master'][selected_table]
                    st.session_state['table_metadata_df'] = table_metadata_df
                    st.session_state.messages.append({"role": "assistant", "type": "text", "content": table_desc})
                    st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": table_metadata_df})
                    logger.debug('Table metadata and description added to session state messages.')
                    # Empty entry appended after each dataframe message, as elsewhere in this function.
                    st.session_state.messages.append({"role": "", "type": "", "content": ""})
                except Exception as e:
                    st.error("Please try again")
                    logger.error("Error while loading the metadata: {}", e)
            st.session_state['selected_table'] = selected_table
    else:
        logger.debug("table_master is None or not in session_state")

    message_container = st.container()
    logger.debug("Message container initialized.")
    with message_container:
        display_historical_responses(st.session_state.messages)

    if prompt := st.chat_input("What is your question?"):
        logger.debug('User question received.')
        st.session_state.messages.append({"role": "user", "content": prompt, 'type': 'text'})
        with message_container:
            with st.chat_message("user"):
                st.markdown(prompt)
            logger.debug('Processing user question...')
            with st.chat_message("assistant"):
                st.empty()
                response = {}
                with st.spinner("Working..."):
                    logger.debug('Executing user query...')
                    try:
                        query_result, sql_generated = handle_retrieve_request(prompt)
                        if query_result is not None:
                            # De-duplication and history updates need a real
                            # frame, so they run only after the None check.
                            query_result = drop_duplicate_columns(query_result)
                            logger.debug("Query result: {}", query_result)
                            st.session_state.messages.append({"role": "assistant", "type": "dataframe", "content": query_result})
                            st.session_state.messages.append({"role": "", "type": "", "content": ""})
                            response['dataframe'] = query_result
                            logger.debug("userId" + st.session_state.userId)
                            st.session_state.query_result = pd.DataFrame(query_result)
                            last_sql = get_max_blob_num(sql_dir + st.session_state.userId + '/')
                            logger.debug("Last SQL file number: {}", last_sql)
                            st.session_state['last_sql'] = last_sql
                            result_structure = get_column_types(query_result)
                            sql_saved = save_sql_query_blob(prompt, sql_generated, last_sql + 1, result_structure, sql_dir, selected_db)
                            if sql_saved:
                                response['footnote'] = (last_sql + 1, sql_generated)
                            else:
                                response['text'] = 'Error while saving generated SQL.'
                            st.session_state['retrieval_query'] = prompt
                            st.session_state['retrieval_query_no'] = last_sql + 1
                            st.session_state['retrieval_sql'] = sql_generated
                            st.session_state['retrieval_result_structure'] = result_structure
                        else:
                            # display_new_responses both shows this text and
                            # records it in the chat history.
                            response['text'] = 'The data set is empty'
                    except Exception as e:
                        st.write("Please try again with another prompt, the dataset is empty")
                        logger.error("Error processing request: {}", e)
                display_new_responses(response)

    if 'query_result' in st.session_state and not st.session_state.query_result.empty:
        display_paginated_dataframe(st.session_state.query_result, st.session_state['retrieval_query_no'])

    with st.container():
        if 'retrieval_sql' in st.session_state and 'selected_db' in st.session_state:
            if st.button('Save Query'):
                database_name = st.session_state['selected_db']
                sql_saved = save_sql_query_blob(st.session_state['retrieval_query'], st.session_state['retrieval_sql'], st.session_state['retrieval_query_no'], st.session_state['retrieval_result_structure'], query_lib, database_name)
                if sql_saved:
                    st.write(f"Query saved in the library with id {st.session_state['retrieval_query_no']}.")
                    logger.info("Query saved in the library with id {}.", st.session_state['retrieval_query_no'])
def design_insight():
    """Render the "Design Insights" page.

    After a saved query is selected the page re-runs its SQL, lists the
    dataset's columns by type, and offers three sections: 'Generate Insight'
    (model-written DuckDB query shown as a table), 'Generate Graph'
    (model-written DuckDB query plus Plotly Express call) and 'Save in
    Library' (stores the prompts and generated code for a persona).

    NOTE(review): the source this was taken from had lost its indentation;
    the nesting below is reconstructed from the statements' meaning.
    """
    col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
    with col_aa:
        st.image('logo.png')
    with col_bb:
        st.subheader("InsightLab - Design Insights", divider='blue')
        st.markdown('**Select a dataset that you generated and ask for different types of tabular insight or graphical charts.**')
    with col_cc:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    session_defaults = {
        'graph_obj': None,
        'graph_prompt': '',
        'data_obj': None,
        'data_prompt': '',
        'code_execution_error': (None, None),
    }
    for state_key, default in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default

    get_saved_query_blob_list()
    selected_query = st.selectbox('Select a saved query', [""] + list(st.session_state['query_display_dict'].keys()))
    if len(selected_query) == 0:
        return

    if 'selected_query' not in st.session_state or st.session_state['selected_query'] != selected_query:
        # A different dataset invalidates everything generated for the previous one.
        st.session_state['selected_query'] = selected_query
        st.session_state['data_obj'] = None
        st.session_state['graph_query'] = None
        st.session_state['graph_obj'] = None
        st.session_state['graph_chart'] = None
        st.session_state['data_prompt'] = ''
        st.session_state['graph_prompt'] = ''
        st.session_state['data_prompt_value'] = ''
        st.session_state['graph_prompt_value'] = ''

    with st.container():
        st.subheader('Dataset Columns')
        # The selectbox entry starts with "ID: <id>,"; take the text up to the first comma.
        query_id = selected_query[len("ID: "):].partition(",")[0]
        try:
            blob_content = getBlobContent(f"{query_lib}{st.session_state.userId}/{query_id}.json")
            content = json.loads(blob_content)
            st.session_state['query_file_content'] = content
            sql_query = content['sql']
            # NOTE(review): selected_db is only bound when loading succeeds; the
            # sections below rely on their own try/except if it is missing.
            selected_db = content['database']
            df = execute_sql(sql_query, selected_db)
            df = drop_duplicate_columns(df)
            df_dict = get_column_types(df)
            df_dtypes = pd.DataFrame.from_dict(df_dict, orient='index', columns=['Dtype'])
            df_dtypes.reset_index(inplace=True)
            df_dtypes.rename(columns={'index': 'Column'}, inplace=True)
            column_groups = [
                ("Integer Columns", ":material/looks_one:", 'int64'),
                ("Decimal Columns", ":material/pin:", 'float64'),
                ("String Columns", ":material/abc:", 'string'),
                ("Datetime Columns", ":material/calendar_month:", 'datetime'),
            ]
            for slot, (title, icon, dtype_name) in zip(st.columns(4), column_groups):
                names = df_dtypes[df_dtypes['Dtype'] == dtype_name]['Column'].reset_index(drop=True)
                with slot:
                    with st.expander(title, icon=icon):
                        st.write("\n\n".join(list(names.values)))
            st.session_state['explore_df'] = df
            st.session_state['explore_dtype'] = df_dtypes
            logger.info("Dataset columns displayed using AG Grid.")
        except Exception as e:
            st.error("Error while loading the dataset")
            logger.error("Error loading dataset: {}", e)

    with st.container():
        st.subheader('Generate Insight')
        data_prompt = st.text_area("What insight would you like to generate?")
        if st.button('Generate Insight'):
            st.session_state['data_obj'] = None
            if data_prompt:
                st.session_state['data_prompt'] = data_prompt
                try:
                    query, method_num = answer_guide_question(data_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
                    if query:
                        try:
                            # Not unused: the generated SQL refers to the table
                            # "mydf", which DuckDB resolves from this local.
                            mydf = st.session_state['explore_df']
                            st.session_state['query'] = query
                            logger.debug("Insight query: {}", query)
                            result_df = duckdb.query(query).to_df()
                            st.session_state['data_obj'] = result_df
                            logger.info("Insight generated and displayed using AG Grid.")
                        except Exception as e:
                            st.write('Error executing the query. Please try again.')
                            logger.error("Error executing the query: {}", e)
                    else:
                        st.write('Please retry again.')
                    if 'code_execution_error' in st.session_state:
                        del st.session_state['code_execution_error']
                except Exception as e:
                    st.write("Please try again with another prompt")
                    logger.error("Error generating insight: {}", e)
        if st.session_state['data_obj'] is not None:
            display_paginated_dataframe(st.session_state['data_obj'], "ag_grid_insight")
            st.session_state['data_prompt'] = data_prompt

    with st.container():
        st.subheader('Generate Graph')
        graph_prompt = st.text_area("What graph would you like to generate?")
        if st.button('Generate Graph'):
            if graph_prompt:
                logger.debug("Graph prompt: {} | Previous graph prompt: {}", graph_prompt, st.session_state.get('graph_prompt'))
                if st.session_state['graph_prompt'] != graph_prompt:
                    try:
                        duckdb_query = generate_duckdb_query(graph_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
                        logger.debug("Graph query: {}", duckdb_query)
                        # Not unused: DuckDB resolves the table "mydf" from this local.
                        mydf = st.session_state['explore_df']
                        result_df = duckdb.query(duckdb_query).to_df()
                        result_df = drop_duplicate_columns(result_df)
                        result_df_dict = get_column_types(result_df)
                        result_df_dtypes = pd.DataFrame.from_dict(result_df_dict, orient='index', columns=['Dtype'])
                        result_df_dtypes.reset_index(inplace=True)
                        result_df_dtypes.rename(columns={'index': 'Column'}, inplace=True)
                        graph_df = result_df
                        graph_code = generate_graph(graph_prompt, result_df_dtypes, selected_db)
                        if not isinstance(graph_code, str) or not graph_code:
                            raise ValueError("No Plotly chart code was generated")
                        logger.debug("Generated graph code: {}", graph_code)
                        # The stored code keeps the name 'graph_df'; only the
                        # executed copy uses 'df'.
                        runnable_code = graph_code.replace('graph_df', 'df')
                        local_vars = {'df': graph_df}
                        # SECURITY(review): this executes model-generated code.
                        # It is limited to the extracted px.<chart>(...) call
                        # but is not sandboxed.
                        exec(f"import plotly.express as px\nchart = {runnable_code}", local_vars)
                        if 'chart' in local_vars:
                            chart = local_vars['chart']
                            # Commit session state only after success, so a
                            # failed attempt can be retried with the same prompt
                            # and saved prompt/query/code always match the chart.
                            st.session_state['graph_query'] = duckdb_query
                            st.session_state['graph_obj'] = graph_code
                            st.session_state['graph_chart'] = chart
                            st.session_state['graph_df'] = graph_df
                            st.session_state['graph_prompt'] = graph_prompt
                            st.plotly_chart(chart, use_container_width=True)
                        else:
                            st.write("please try agiain with another prompt.")
                    except Exception as e:
                        logger.error("Error in generating graph: {}", e)
                        st.write("please mention the type of chart/change the prompt and try again")
                else:
                    # Same prompt as the last successful run: show the cached chart.
                    try:
                        st.plotly_chart(st.session_state.get('graph_chart'), use_container_width=True)
                    except Exception as e:
                        st.write("Error in displaying graph, please try again")
                        logger.error("Error in displaying graph: {}", e)
        else:
            if st.session_state.get('graph_chart') is not None:
                try:
                    st.plotly_chart(st.session_state['graph_chart'], use_container_width=True)
                except Exception as e:
                    st.write("Error in displaying graph, please try again")
                    logger.error("Error in displaying graph: {}", e)

    with st.container():
        if 'graph_obj' in st.session_state or 'data_obj' in st.session_state:
            user_persona = st.selectbox('Select a persona to save the result of your exploration', persona_list)
            start_index = selected_query.find('Query: "') + len('Query: "')
            end_index = selected_query.find('", Created on')
            query = selected_query[start_index:end_index]
            insight_desc = st.text_area("Enter your insight discribtion", value=query)
            if st.button('Save in Library'):
                try:
                    base_prompt = st.session_state['query_file_content']['prompt']
                    base_code = st.session_state['query_file_content']['sql']
                    insight_prompt = st.session_state.get('data_prompt', '')
                    insight_code = st.session_state.get('query', '')
                    chart_prompt = st.session_state.get('graph_prompt', '')
                    chart_query = st.session_state.get('graph_query', '')
                    chart_code = st.session_state.get('graph_obj', '')
                    result = get_existing_insight(base_code, user_persona)
                    if result:
                        existing_insight, file_number = result
                        # Truthiness (not "is not None") so empty generated code is never stored.
                        if insight_prompt and insight_code:
                            existing_insight['prompt'][f'prompt_{len(existing_insight["prompt"]) + 1}'] = {
                                'insight_prompt': insight_prompt,
                                'insight_code': insight_code
                            }
                        if chart_prompt and chart_code:
                            existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
                                'chart_prompt': chart_prompt,
                                'chart_query': chart_query,
                                'chart_code': chart_code
                            }
                        try:
                            update_insight(existing_insight, user_persona, file_number)
                            st.text('Insight updated with new Graph and/or Data.')
                            logger.info("Insight updated successfully.")
                        except Exception as e:
                            st.write('Could not update the insight file. Please try again')
                            logger.error("Error while updating insight file: {}", e)
                    else:
                        # No insight exists yet for this base query: create a new entry.
                        if not check_blob_exists(f"insight_library/{user_persona}/{st.session_state.userId}"):
                            blob_service_client = BlobServiceClient.from_connection_string(connection_string)
                            container_client = blob_service_client.get_container_client(container_name)
                            folder_path = f"insight_library/{user_persona}/{st.session_state.userId}/"
                            logger.info("Creating a new folder in the blob storage: {}", folder_path)
                            container_client.upload_blob(folder_path, data=b'')
                        next_file_number = get_max_blob_num(f"insight_library/{user_persona}/{st.session_state.userId}/") + 1
                        try:
                            save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code, selected_db, insight_prompt, insight_code, chart_prompt, chart_query, chart_code)
                            st.text(f'Insight #{next_file_number} with Graph and/or Data saved.')
                        except Exception as e:
                            st.write('Could not write the insight file.')
                            logger.error("Error while writing insight file: {}", e)
                except Exception as e:
                    st.write("Please try again")
                    logger.error("Error checking existing insights: {}", e)
def get_insight_list(persona):
    """List the saved insight files of the current user for ``persona``.

    Calls ``list_blobs_sorted`` for the user's folder under ``insight_lib``,
    then reads each listed blob to pick up its ``'description'`` field.

    Args:
        persona: Persona name used as a folder component of the blob path.

    Returns:
        ``(files, labels)``: blob names and, in the same order, display labels
        of the form ``ID: <id>, Description: "<text>", Created on <dt>``.
        Both lists are empty when anything raises.
    """
    try:
        user_id = st.session_state.userId
        list_blobs_sorted(f"{insight_lib}{persona}/{user_id}/", 'json', 'library_files')
        library_files = st.session_state['library_files']
        logger.debug("Library files: {}", library_files)

        blob_names = []
        labels = []
        for blob_name, created_on in library_files:
            # NOTE(review): the id is sliced from a fixed offset (prefix
            # lengths + 3) up to the trailing 5 characters; confirm "+ 3"
            # matches the real blob-name layout.
            id_start = len(insight_lib) + len(persona) + len(user_id) + 3
            insight_id = blob_name[id_start:-5]
            insight = json.loads(getBlobContent(blob_name))
            description = insight.get('description', 'No description available')
            labels.append(f"ID: {insight_id}, Description: \"{description}\", Created on {created_on}")
            blob_names.append(blob_name)

        logger.info("Insight list generated successfully.")
        return blob_names, labels
    except Exception as e:
        logger.error("Error generating insight list: {}", e)
        return [], []
def insight_library():
    """Render the "Personalized Insight Library" page.

    After a persona and one of its saved insights are selected, the stored
    base SQL is re-executed and every stored insight query and chart is
    re-run against the fresh result and displayed.

    NOTE(review): the source this was taken from had lost its indentation;
    the nesting below is reconstructed from the statements' meaning.
    """
    col_aa, col_bb, col_cc = st.columns([1, 4, 1], gap="small", vertical_alignment="center")
    with col_aa:
        st.image('logo.png')
    with col_bb:
        st.subheader("InsightLab - Personalized Insight Library", divider='blue')
        st.markdown('**Select one of the pre-configured insights and get the result on the latest data.**')
    with col_cc:
        st.markdown(APP_TITLE, unsafe_allow_html=True)
    selected_persona = st.selectbox('Select an analyst persona:', [''] + persona_list)
    if selected_persona:
        st.session_state['selected_persona'] = selected_persona
        try:
            file_list, file_description_list = get_insight_list(selected_persona)
            selected_insight = st.selectbox(label='Select an insight from the library', options=[""] + file_description_list)
            if selected_insight:
                # The two lists are parallel, so the label's position gives the blob name.
                idx = file_description_list.index(selected_insight)
                file = file_list[idx]
                st.session_state['insight_file'] = file
                content = getBlobContent(file)
                task_dict = json.loads(content)
                base_prompt = task_dict.get('base_prompt', 'No base prompt available')
                base_code = task_dict.get('base_code', '')
                selected_db = task_dict.get('database', '') # Retrieve the database name from the task dictionary
                prompts = task_dict.get('prompt', {})
                charts = task_dict.get('chart', {})
                # Get base dataset
                df = execute_sql(base_code, selected_db)
                df = drop_duplicate_columns(df)
                # Display insights
                st.subheader("Insight Generated")
                for key, value in prompts.items():
                    st.markdown(f"**{value.get('insight_prompt', 'No insight prompt available')}**")
                    output = {}
                    try:
                        # Presumably needed even though it looks unused: the stored
                        # SQL refers to a table named "mydf", which DuckDB is
                        # expected to resolve from this local - confirm before removing.
                        mydf=df
                        query_code = value.get('insight_code', '')
                        result_df = duckdb.query(query_code).to_df()
                        if result_df is not None:
                            st.session_state['code_execution_error'] = (value.get('insight_code', ''), None)
                            display_paginated_dataframe(result_df, f"insight_value_{key}")
                            st.session_state['print_result_df'] = result_df
                        else:
                            logger.warning("result_df is not defined in the output dictionary")
                    except Exception as e:
                        # NOTE(review): a failing insight is only logged; nothing
                        # is shown to the user for it.
                        logger.error(f"Error executing generated insight code: {repr(e)}")
                        logger.debug(f"Generated code:\n{value.get('insight_code', '')}")
                # Display charts
                st.subheader("Chart Generated")
                for key, value in charts.items():
                    st.markdown(f"**{value.get('chart_prompt', 'No chart prompt available')}**")
                    try:
                        # Same "mydf" convention as in the insight loop above.
                        mydf=df
                        query_code = value.get('chart_query','')
                        result_df = duckdb.query(query_code).to_df()
                        graph_df=result_df
                        graph_code = value.get('chart_code', '')
                        # Stored chart code names the frame 'graph_df'; it is run with the name 'df'.
                        graph_code = graph_code.replace('graph_df', 'df')
                        local_vars = {'df': graph_df} # Define the dataframe as 'df'
                        # SECURITY(review): executes code loaded from the stored
                        # insight file; it is not sandboxed.
                        exec(f"import plotly.express as px\nchart = {graph_code}", local_vars)
                        if 'chart' in local_vars:
                            chart = local_vars['chart'] # Extract the Plotly chart object
                            st.plotly_chart(chart, use_container_width=True, key=f"chart_{key}")
                            st.session_state[f'print_chart_{key}'] = chart
                    except Exception as e:
                        logger.error(f"Error generating chart: {repr(e)}")
                        st.error("Please try again")
                with st.expander('See base dataset'):
                    st.subheader("Dataset Retrieved")
                    st.markdown(f"**{base_prompt}**")
                    display_paginated_dataframe(df, "base_dataset")
                    st.session_state['print_df'] = df
        except Exception as e:
            st.error("Please try again")
            logger.error(f"Error loading insights: {e}")
def data_visualize():
    """Render the "Data Visualize" page for a saved query.

    Draws the page header (logo, title, application banner), calls
    ``get_saved_query_blob_list()`` and offers the keys of
    ``st.session_state['query_display_dict']`` in a select box. When an entry
    is chosen, its JSON definition is loaded with ``getBlobContent``, the
    stored SQL is run with ``execute_sql`` against the stored database, and
    the result is shown as a paginated table plus an interactive PyGWalker
    explorer.

    The loaded dataset is stored in ``st.session_state['visualize_df']``.
    Any exception raised while loading or rendering is logged and reported
    to the user through ``st.error``; nothing is returned.
    """
    logo_col, title_col, banner_col = st.columns(
        [1, 4, 1], gap="small", vertical_alignment="center"
    )
    with logo_col:
        st.image('logo.png')
    with title_col:
        st.subheader("InsightLab - Data Visualize", divider='blue')
        st.markdown('**Select a dataset that you generated to visualize the dataset.**')
    with banner_col:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    # NOTE(review): presumably this call populates
    # st.session_state['query_display_dict'] -- confirm against its definition.
    get_saved_query_blob_list()
    selected_query = st.selectbox(
        'Select a saved query',
        [""] + list(st.session_state['query_display_dict'].keys()),
    )

    # Skip the empty placeholder entry, and skip a query that session state
    # already records as the selected one.
    already_selected = st.session_state.get('selected_query') == selected_query
    if selected_query and not already_selected:
        with st.container():
            # Labels are expected to look like "ID: <id>, ...". partition()
            # keeps the whole remainder when there is no comma, whereas the
            # previous find()-based slice dropped the last character (find
            # returns -1 in that case).
            query_id = selected_query[len("ID: "):].partition(",")[0]
            try:
                blob_content = getBlobContent(
                    f"{query_lib}{st.session_state.userId}/{query_id}.json"
                )
                saved_query = json.loads(blob_content)
                dataset = execute_sql(saved_query['sql'], saved_query['database'])
                st.session_state['visualize_df'] = dataset
                if dataset is not None:
                    with st.expander(label='**Raw Dataset**'):
                        display_paginated_dataframe(dataset, "base_dataset_for_visualization")
                    # Interactive PyGWalker exploration of the same dataset.
                    pyg_app = StreamlitRenderer(dataset)
                    pyg_app.explorer()
            except Exception as e:
                # Page-level boundary: keep the app alive, but leave a trace
                # in the log as well as in the UI.
                logger.error(f"Error loading dataset: {e!r}")
                st.error(f"Error loading dataset: {e}")
def data_profiler():
    """Render the "Data Profiler" page for a saved query.

    Draws the page header (logo, title, application banner), calls
    ``get_saved_query_blob_list()`` and offers the keys of
    ``st.session_state['query_display_dict']`` in a select box. When an entry
    is chosen, its JSON definition is loaded with ``getBlobContent``, the
    stored SQL is run with ``execute_sql`` against the stored database, and
    the result is shown as a paginated table followed by a profiling report
    built with ``DataFrame.profile_report()`` and rendered by
    ``st_profile_report``.

    The loaded dataset is stored in ``st.session_state['profile_df']``.
    Any exception raised while loading or profiling is logged and reported
    to the user through ``st.error``; nothing is returned.
    """
    logo_col, title_col, banner_col = st.columns(
        [1, 4, 1], gap="small", vertical_alignment="center"
    )
    with logo_col:
        st.image('logo.png')
    with title_col:
        st.subheader("InsightLab - Data Profiler", divider='blue')
        st.markdown('**Select a dataset that you generated for detailed profiling report.**')
    with banner_col:
        st.markdown(APP_TITLE, unsafe_allow_html=True)

    # NOTE(review): presumably this call populates
    # st.session_state['query_display_dict'] -- confirm against its definition.
    get_saved_query_blob_list()
    selected_query = st.selectbox(
        'Select a saved query',
        [""] + list(st.session_state['query_display_dict'].keys()),
    )

    # Skip the empty placeholder entry, and skip a query that session state
    # already records as the selected one.
    already_selected = st.session_state.get('selected_query') == selected_query
    if selected_query and not already_selected:
        with st.container():
            # Labels are expected to look like "ID: <id>, ...". partition()
            # keeps the whole remainder when there is no comma, whereas the
            # previous find()-based slice dropped the last character (find
            # returns -1 in that case).
            query_id = selected_query[len("ID: "):].partition(",")[0]
            try:
                blob_content = getBlobContent(
                    f"{query_lib}{st.session_state.userId}/{query_id}.json"
                )
                saved_query = json.loads(blob_content)
                dataset = execute_sql(saved_query['sql'], saved_query['database'])
                st.session_state['profile_df'] = dataset
                if dataset is not None:
                    with st.expander(label='**Raw Dataset**'):
                        display_paginated_dataframe(dataset, "base_dataset_for_profiling")
                    # Build the profiling report and embed it in the page.
                    report = dataset.profile_report()
                    st_profile_report(report)
            except Exception as e:
                # Page-level boundary: keep the app alive, but leave a trace
                # in the log as well as in the UI.
                logger.error(f"Error loading dataset: {e!r}")
                st.error(f"Error loading dataset: {e}")