Maheshsr committed on
Commit
485dd62
·
1 Parent(s): f130825

modifying the graph prompt

Browse files
pages/__pycache__/solution.cpython-312.pyc CHANGED
Binary files a/pages/__pycache__/solution.cpython-312.pyc and b/pages/__pycache__/solution.cpython-312.pyc differ
 
pages/solution.py CHANGED
@@ -14,7 +14,7 @@ from openai import AzureOpenAI
14
  import os
15
  import json
16
  import altair as alt
17
- import plotly
18
  import ast
19
  import streamlit as st
20
  from streamlit_navigation_bar import st_navbar
@@ -230,12 +230,12 @@ def get_existing_token(current_month):
230
  blobs = container_client.list_blobs(name_starts_with=token_directory)
231
  for blob in blobs:
232
  blob_name = blob.name # Extract the blob names
233
- print(blob_name)
234
  file_name_with_extension = blob_name.split('/')[-1]
235
  file_name = file_name_with_extension.split('.')[0]
236
  blob_client = container_client.get_blob_client(blob_name)
237
  blob_content = blob_client.download_blob().readall()
238
- print(blob_content)
239
  token_data = json.loads(blob_content)
240
  if token_data['year-month'] == current_month:
241
  logger.info("Existing token_consumed found for month: {}", current_month)
@@ -493,7 +493,7 @@ def update_insight(insight_data, user_persona, file_number):
493
  logger.error("Error while updating insight: %s", e)
494
  return False
495
 
496
- def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code):
497
  new_insight = {
498
  'description': insight_desc,
499
  'base_prompt': base_prompt,
@@ -508,6 +508,7 @@ def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base
508
  'chart': {
509
  'chart_1': {
510
  'chart_prompt': chart_prompt,
 
511
  'chart_code': chart_code
512
  }
513
  }
@@ -793,21 +794,52 @@ def answer_guide_question(question, dframe, df_structure, selected_db):
793
  logger.error("Trouble writing the code file for {} and method number {}: {}", question, last_method_num + 1, e)
794
 
795
  return duckdb_query, last_method_num + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
 
797
- def generate_graph(query, df, df_structure,generate_graph):
798
- if query is None or df is None or df_structure is None:
799
- logger.error("generate_graph received None values for query, df, or df_structure")
 
 
 
 
 
 
 
 
 
 
 
800
  return None, None
801
 
802
  if len(query) == 0:
803
  return None, None
804
 
805
- df_summary = {
806
- "columns": df.columns.tolist(),
807
- "dtypes": df.dtypes.astype(str).to_dict(),
808
- "describe": df.describe().to_dict()
809
- }
810
-
811
  with st.spinner('Generating graph'):
812
  graph_prompt = f"""
813
  You are an expert in understanding English language instructions to generate a graph based on a given dataframe.
@@ -815,39 +847,11 @@ def generate_graph(query, df, df_structure,generate_graph):
815
  I am providing you the dataframe structure as a dictionary in double backticks.
816
  Dataframe structure: ``{df_structure}``
817
 
818
- I am also providing you a summary of the dataframe as a dictionary in double backticks.
819
- Dataframe summary: ``{df_summary}``
820
-
821
- I have provided the dataframe structure and its summary. I can't provide the entire dataframe.
822
-
823
  I am also giving you the intent instruction in triple backticks.
824
  Instruction for generating the graph: ```{query}```
825
 
826
- Your task is to write the code that will generate a Plotly chart.
827
- You should be able to derive the chart type from the instruction.
828
- Graphs may need calculations, such as aggregating or calculating averages for some of the numeric columns.
829
-
830
- You should generate the code that will allow me to create the Plotly chart object that can then be used as the parameter in Streamlit's `st.plotly_chart()` method.
831
-
832
- Pay special attention to the field names. Some field names have an underscore (_) and some do not. You need to be accurate while generating the query.
833
- Pay special attention when you need to group by based on two categorical columns to create things like bubble charts. For example, the sample code within four backticks below is the correct way to prepare a dataframe with procedure code, a categorical variable in one axis, and diagnosis code, another categorical variable in another axis, and the size of the bubble would be based on the sum of 'Total Paid' values for each procedure and diagnosis code combination.
834
- Sample code: ````grouped_df = df_ma.groupby(['Procedure Code', 'Diagnosis Codes'])['Total Paid'].sum().reset_index()````
835
-
836
- If you need to add a filter criterion, then you need to add a second step as indicated in five backticks below. This shows it is filtering the dataframe for all groups with a sum of 'Total Paid' more than 1000. You can feed the last dataframe to the Plotly chart.
837
- Sample code: `````grouped_df = df.groupby(['Procedure Code', 'Diagnosis Codes'])['Total Paid'].sum().reset_index() \\n\\nfiltered_df = grouped_df[grouped_df['Total Paid'] > 1000]`````
838
-
839
- If there is a space in the column name, then you need to fully enclose each occurrence of the column name with double quotes in the query.
840
- While creating the Plotly chart, you need to get the top 5000 rows since Plotly chart cannot handle more than 5000 rows.
841
- Pay special attention to grouped bar charts. For grouped bar charts, there should be at least two x-axis columns. One can be the actual x-axis and the other can be used in the 'column' parameter of the Plotly Chart object. For example, the following code in four backticks shows a grouped bar chart with the x-axis showing 'year' and each 'site' for each year.
842
- Grouped bar chart sample code: ````alt.Chart(source).mark_bar().encode(
843
- x='year:O',
844
- y='sum(yield):Q',
845
- column='site:N'
846
- )````
847
-
848
- A grouped bar chart will be explicitly asked for in the instructions.
849
-
850
- Only produce the Python code.
851
  Do NOT produce any backticks or double quotes or single quotes before or after the code.
852
  Do generate the Plotly import statement as part of the code.
853
  Do NOT justify your code.
@@ -856,28 +860,21 @@ def generate_graph(query, df, df_structure,generate_graph):
856
  Do not print or return the chart object at the end.
857
  Do NOT produce any additional text that is not part of the query itself.
858
  Always name the final Plotly chart object as 'chart'.
859
- Go back and check if the generated code can be used in the `st.plotly_chart()` method.
 
860
  """
861
- logger.info(f"Generating graph with prompt: {graph_prompt}")
862
- graph_response = run_prompt(graph_prompt,query,"generate graph",generate_graph)
863
- logger.debug("Graph response: {}", graph_response)
864
-
865
- try:
866
- # Create a dictionary to capture local variables
867
- local_vars = {}
868
 
869
- # Execute the chart generation code and update the local_vars dictionary
870
- exec(graph_response, {}, local_vars) # type: ignore
871
- logger.debug("Graph code executed.")
872
 
873
- # Extract the chart object from local_vars
874
- chart = local_vars['chart']
875
- logger.info("Plotly chart object created successfully.")
876
- except Exception as e:
877
- logger.error("Error creating plotly chart object: {}", e)
878
- return None, None
879
 
880
- return chart, graph_response
881
 
882
  def get_table_details(engine,selected_db):
883
  query_tables = """
@@ -1134,7 +1131,9 @@ def design_insight():
1134
  if 'selected_query' not in st.session_state or st.session_state['selected_query'] != selected_query:
1135
  st.session_state['selected_query'] = selected_query
1136
  st.session_state['data_obj'] = None
 
1137
  st.session_state['graph_obj'] = None
 
1138
  st.session_state['data_prompt'] = ''
1139
  st.session_state['graph_prompt'] = ''
1140
  st.session_state['data_prompt_value']= ''
@@ -1235,29 +1234,51 @@ def design_insight():
1235
  logger.debug("Graph prompt: %s | Previous graph prompt: %s", st.session_state.get('graph_prompt'), graph_prompt)
1236
  if st.session_state['graph_prompt'] != graph_prompt:
1237
  try:
1238
- graph_obj, st.session_state['graph_code'] = generate_graph(graph_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
1239
- st.session_state['graph_obj'] = graph_obj
1240
-
1241
- if graph_obj is not None:
1242
- # st.text(st.session_state['graph_prompt'])
1243
- st.plotly_chart(graph_obj, use_container_width=True)
1244
- logger.info("Graph generated and displayed using Plotly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1245
  else:
1246
- st.session_state['graph_obj'] = None
1247
- st.text('Error in generating graph, please try again.')
1248
  except Exception as e:
1249
- logger.error("Error in generating graph: %s", e)
1250
  st.write("Error in generating graph, please try again")
1251
  else:
1252
  try:
1253
- st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
1254
  except Exception as e:
1255
  st.write("Error in displaying graph, please try again")
1256
  st.session_state['graph_prompt'] = graph_prompt
1257
  else:
1258
- if st.session_state['graph_obj'] is not None:
1259
  try:
1260
- st.plotly_chart(st.session_state['graph_obj'], use_container_width=True)
 
1261
  except Exception as e:
1262
  st.write("Error in displaying graph, please try again")
1263
  logger.error("Error in displaying graph: %s", e)
@@ -1271,9 +1292,10 @@ def design_insight():
1271
 
1272
  insight_prompt = st.session_state.get('data_prompt', '')
1273
  insight_code = st.session_state.get('query', '')
1274
-
1275
  chart_prompt = st.session_state.get('graph_prompt', '')
1276
- chart_code = st.session_state.get('graph_code', '')
 
1277
 
1278
  try:
1279
  result = get_existing_insight(base_code, user_persona)
@@ -1287,6 +1309,7 @@ def design_insight():
1287
  if chart_prompt and chart_code is not None:
1288
  existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
1289
  'chart_prompt': chart_prompt,
 
1290
  'chart_code': chart_code
1291
  }
1292
  try:
@@ -1308,7 +1331,7 @@ def design_insight():
1308
  # logger.info(f"Next file number: {next_file_number}")
1309
 
1310
  try:
1311
- save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_code)
1312
  st.text(f'Insight #{next_file_number} with Graph and/or Data saved.')
1313
  # logger.info(f'Insight #{next_file_number} with Graph and/or Data saved.')
1314
  except Exception as e:
@@ -1400,12 +1423,18 @@ def insight_library():
1400
  for key, value in charts.items():
1401
  st.markdown(f"**{value.get('chart_prompt', 'No chart prompt available')}**")
1402
  try:
1403
- local_vars = {}
1404
- exec(value.get('chart_code', ''), {}, local_vars)
1405
- chart = local_vars.get('chart', None)
1406
- if chart is not None:
1407
- st.plotly_chart(chart, use_container_width=True)
1408
- st.session_state['print_chart'] = chart
 
 
 
 
 
 
1409
  except Exception as e:
1410
  logger.error(f"Error generating chart: {repr(e)}")
1411
  st.error("Please try again")
 
14
  import os
15
  import json
16
  import altair as alt
17
+ import plotly.express as px
18
  import ast
19
  import streamlit as st
20
  from streamlit_navigation_bar import st_navbar
 
230
  blobs = container_client.list_blobs(name_starts_with=token_directory)
231
  for blob in blobs:
232
  blob_name = blob.name # Extract the blob names
233
+ # print(blob_name)
234
  file_name_with_extension = blob_name.split('/')[-1]
235
  file_name = file_name_with_extension.split('.')[0]
236
  blob_client = container_client.get_blob_client(blob_name)
237
  blob_content = blob_client.download_blob().readall()
238
+ # print(blob_content)
239
  token_data = json.loads(blob_content)
240
  if token_data['year-month'] == current_month:
241
  logger.info("Existing token_consumed found for month: {}", current_month)
 
493
  logger.error("Error while updating insight: %s", e)
494
  return False
495
 
496
+ def save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_query, chart_code):
497
  new_insight = {
498
  'description': insight_desc,
499
  'base_prompt': base_prompt,
 
508
  'chart': {
509
  'chart_1': {
510
  'chart_prompt': chart_prompt,
511
+ 'chart_query': chart_query,
512
  'chart_code': chart_code
513
  }
514
  }
 
794
  logger.error("Trouble writing the code file for {} and method number {}: {}", question, last_method_num + 1, e)
795
 
796
  return duckdb_query, last_method_num + 1
797
+
798
def generate_duckdb_query(question, mydf, df_structure, selected_db):
    """Ask the LLM for a DuckDB SQL query that builds the dataset behind a graph.

    The graph prompt and the dataframe structure are interpolated into a
    code-generation prompt, which is sent through ``run_prompt``. The reply
    is cleaned up and returned as a plain SQL string.

    Args:
        question: Natural-language graph prompt entered by the user.
        mydf: Not read inside this function. The returned query refers to a
            table named ``mydf``; the callers visible in this file bind a
            dataframe to that name before calling ``duckdb.query()``.
            NOTE(review): confirm every caller does so.
        df_structure: Mapping of column name to data type, embedded verbatim
            in the prompt.
        selected_db: Passed through unchanged as the fourth argument of
            ``run_prompt``.

    Returns:
        The SQL query string with surrounding whitespace and Markdown code
        fences removed and ``FROM dataframe`` rewritten to ``FROM mydf``.

    Raises:
        ValueError: If ``run_prompt`` returns something other than a string.
    """
    # Local import, matching how this file pulls in ``re`` elsewhere.
    import re

    # Generate the DuckDB query based on the graph prompt and dataframe structure
    code_gen_prompt = f"""
    You are an expert in writing SQL queries for DuckDB. Given the task and the structure of a dataframe, your goal is to generate only the SQL query string that can be executed directly on DuckDB, **without any extra code or formatting**.

    The user prompt is a graph prompt: generate a 2-column dataset for that graph.
    Task: ``{question}``

    The dataframe structure is provided as a dictionary where the column names are the keys, and their data types are the values:
    DataFrame Structure: ```{df_structure}```

    Your goal is to generate a **clean, valid DuckDB SQL query** that can be executed with `duckdb.query()`. Do **NOT** include any assignment to variables (e.g., `result_df`), comments, backticks, or any additional text.

    The **output should be a valid SQL query string**, ready to be executed directly in DuckDB. **Do not include any extra SQL keywords like `sql` or backticks around the query**.

    Return **only the raw SQL query string**, without any additional formatting, comments, or explanation.
    """

    logger.info(f"Generating insight with prompt: {code_gen_prompt}")
    analysis_code = run_prompt(code_gen_prompt, question, "generate graph query", selected_db)

    # Ensure analysis_code is a string
    if not isinstance(analysis_code, str):
        logger.error("Generated code is not a string: {}", analysis_code)
        raise ValueError("Generated code is not a string")

    duckdb_query = analysis_code.strip()

    # Models frequently wrap SQL in Markdown fences despite the prompt's
    # instructions; drop an opening fence line (e.g. "```sql") and any
    # remaining leading/trailing backticks so DuckDB receives bare SQL.
    duckdb_query = re.sub(r"^```[A-Za-z0-9_]*[ \t]*\r?\n", "", duckdb_query)
    duckdb_query = duckdb_query.strip("`").strip()

    # Point the query at the table name the callers bind ("mydf"). Matching is
    # case-insensitive and whole-word, so "from dataframe" is handled and an
    # identifier such as "dataframe_x" is left untouched.
    duckdb_query = re.sub(
        r"\bFROM\s+dataframe\b", "FROM mydf", duckdb_query, flags=re.IGNORECASE
    )

    graph_query = duckdb_query.strip()
    # A successfully generated query is not an error; log it for diagnostics only.
    logger.debug(graph_query)
    return graph_query
834
+
835
+ def generate_graph(query, df_structure, selected_db):
836
+ if query is None or df_structure is None:
837
+ logger.error("generate_graph received None values for query or df_structure")
838
  return None, None
839
 
840
  if len(query) == 0:
841
  return None, None
842
 
 
 
 
 
 
 
843
  with st.spinner('Generating graph'):
844
  graph_prompt = f"""
845
  You are an expert in understanding English language instructions to generate a graph based on a given dataframe.
 
847
  I am providing you the dataframe structure as a dictionary in double backticks.
848
  Dataframe structure: ``{df_structure}``
849
 
 
 
 
 
 
850
  I am also giving you the intent instruction in triple backticks.
851
  Instruction for generating the graph: ```{query}```
852
 
853
+ # Ensure deterministic behavior in graph code
854
+ Only produce the Python code for creating the Plotly chart.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
  Do NOT produce any backticks or double quotes or single quotes before or after the code.
856
  Do generate the Plotly import statement as part of the code.
857
  Do NOT justify your code.
 
860
  Do not print or return the chart object at the end.
861
  Do NOT produce any additional text that is not part of the query itself.
862
  Always name the final Plotly chart object as 'chart'.
863
+ The task is to generate a Plotly chart using the 2-coloum dataset. Mention the x, y, title, and type of chart based on the user prompt and dataframe structure.
864
+ Extract only the Plotly chart creation code segment like `px.bar(graph_df, x='discharge_disposition', y='record_count', color='condition_class', title='Count of Records for Every Condition Class with X Axis Showing Discharge Dispositions')`.
865
  """
 
 
 
 
 
 
 
866
 
867
+ logger.info(f"Generating graph with prompt: {graph_prompt}")
868
+ graph_response = run_prompt(graph_prompt, query, "generate graph", selected_db)
869
+ logger.debug(f"Graph response: {graph_response}")
870
 
871
+ # Extract the specific Plotly chart creation code segment
872
+ import re
873
+ pattern = r'px\.[a-z]+\([^\)]*\)' # Regex pattern to match Plotly chart code
874
+ match = re.search(pattern, graph_response)
875
+ graph_code = match.group(0) if match else ""
876
+ return graph_code
877
 
 
878
 
879
  def get_table_details(engine,selected_db):
880
  query_tables = """
 
1131
  if 'selected_query' not in st.session_state or st.session_state['selected_query'] != selected_query:
1132
  st.session_state['selected_query'] = selected_query
1133
  st.session_state['data_obj'] = None
1134
+ st.session_state['graph_query'] = None
1135
  st.session_state['graph_obj'] = None
1136
+ st.session_state['graph_chart'] = None
1137
  st.session_state['data_prompt'] = ''
1138
  st.session_state['graph_prompt'] = ''
1139
  st.session_state['data_prompt_value']= ''
 
1234
  logger.debug("Graph prompt: %s | Previous graph prompt: %s", st.session_state.get('graph_prompt'), graph_prompt)
1235
  if st.session_state['graph_prompt'] != graph_prompt:
1236
  try:
1237
+ duckdb_query =generate_duckdb_query(graph_prompt, st.session_state['explore_df'], st.session_state['explore_dtype'], selected_db)
1238
+ mydf=df
1239
+ st.session_state['graph_query'] = duckdb_query
1240
+ result_df = duckdb.query(duckdb_query).to_df()
1241
+ result_df = drop_duplicate_columns(result_df)
1242
+ result_df_dict = get_column_types(result_df)
1243
+ result_df_dtypes = pd.DataFrame.from_dict(result_df_dict, orient='index', columns=['Dtype'])
1244
+ result_df_dtypes.reset_index(inplace=True)
1245
+ result_df_dtypes.rename(columns={'index': 'Column'}, inplace=True)
1246
+ graph_df=result_df
1247
+ graph_response = generate_graph(graph_prompt, result_df_dtypes, selected_db)
1248
+
1249
+ graph_code = graph_response # Extract the graph code from the response
1250
+ st.session_state['graph_obj'] = graph_code
1251
+ # Ensure 'graph_df' is replaced by 'df' in the generated code
1252
+ graph_code = graph_code.replace('graph_df', 'df')
1253
+
1254
+ # Check and print the generated graph code for debugging
1255
+ print("Generated graph code:", graph_code)
1256
+
1257
+ # Execute the graph code to create the Plotly figure object
1258
+
1259
+ local_vars = {'df': graph_df} # Define the dataframe as 'df'
1260
+ exec(f"import plotly.express as px\nchart = {graph_code}", local_vars)
1261
+ if 'chart' in local_vars:
1262
+ chart = local_vars['chart'] # Extract the Plotly chart object
1263
+ st.session_state['graph_chart'] = chart
1264
+ st.session_state['graph_df'] = graph_df
1265
+ st.plotly_chart(chart, use_container_width=True)
1266
  else:
1267
+ st.write("Chart object was not created.")
 
1268
  except Exception as e:
1269
+ logger.error("Error in generating graph:", e)
1270
  st.write("Error in generating graph, please try again")
1271
  else:
1272
  try:
1273
+ st.plotly_chart(st.session_state['graph_chart'], use_container_width=True)
1274
  except Exception as e:
1275
  st.write("Error in displaying graph, please try again")
1276
  st.session_state['graph_prompt'] = graph_prompt
1277
  else:
1278
+ if st.session_state['graph_chart'] is not None:
1279
  try:
1280
+ graph_df = st.session_state['graph_df']
1281
+ st.plotly_chart(st.session_state['graph_chart'], use_container_width=True)
1282
  except Exception as e:
1283
  st.write("Error in displaying graph, please try again")
1284
  logger.error("Error in displaying graph: %s", e)
 
1292
 
1293
  insight_prompt = st.session_state.get('data_prompt', '')
1294
  insight_code = st.session_state.get('query', '')
1295
+
1296
  chart_prompt = st.session_state.get('graph_prompt', '')
1297
+ chart_query = st.session_state.get('graph_query','')
1298
+ chart_code = st.session_state.get('graph_obj', '')
1299
 
1300
  try:
1301
  result = get_existing_insight(base_code, user_persona)
 
1309
  if chart_prompt and chart_code is not None:
1310
  existing_insight['chart'][f'chart_{len(existing_insight["chart"]) + 1}'] = {
1311
  'chart_prompt': chart_prompt,
1312
+ 'chart_query' : chart_query,
1313
  'chart_code': chart_code
1314
  }
1315
  try:
 
1331
  # logger.info(f"Next file number: {next_file_number}")
1332
 
1333
  try:
1334
+ save_insight(next_file_number, user_persona, insight_desc, base_prompt, base_code,selected_db, insight_prompt, insight_code, chart_prompt, chart_query, chart_code)
1335
  st.text(f'Insight #{next_file_number} with Graph and/or Data saved.')
1336
  # logger.info(f'Insight #{next_file_number} with Graph and/or Data saved.')
1337
  except Exception as e:
 
1423
  for key, value in charts.items():
1424
  st.markdown(f"**{value.get('chart_prompt', 'No chart prompt available')}**")
1425
  try:
1426
+ mydf=df
1427
+ query_code = value.get('chart_query','')
1428
+ result_df = duckdb.query(query_code).to_df()
1429
+ graph_df=result_df
1430
+ graph_code = value.get('chart_code', '')
1431
+ graph_code = graph_code.replace('graph_df', 'df')
1432
+ local_vars = {'df': graph_df} # Define the dataframe as 'df'
1433
+ exec(f"import plotly.express as px\nchart = {graph_code}", local_vars)
1434
+ if 'chart' in local_vars:
1435
+ chart = local_vars['chart'] # Extract the Plotly chart object
1436
+ st.plotly_chart(chart, use_container_width=True, key=f"chart_{key}")
1437
+ st.session_state[f'print_chart_{key}'] = chart
1438
  except Exception as e:
1439
  logger.error(f"Error generating chart: {repr(e)}")
1440
  st.error("Please try again")