Deepa Shalini committed on
Commit
b537031
·
1 Parent(s): f6ba808

migration from openai to groq/compound-mini

Browse files
.gitignore CHANGED
@@ -1,5 +1,5 @@
1
  # ignore virtual environment
2
- chartbot/
3
 
4
  # ignore python cache
5
  __pycache__/
 
1
  # ignore virtual environment
2
+ .venv/
3
 
4
  # ignore python cache
5
  __pycache__/
app.py CHANGED
@@ -131,4 +131,4 @@ def update_navlink_styles(pathname):
131
  ]
132
 
133
  if __name__ == "__main__":
134
- app.run_server(debug=False)
 
131
  ]
132
 
133
  if __name__ == "__main__":
134
+ app.run(debug=False)
pages/amazon_purchases_chartbot.py CHANGED
@@ -1,5 +1,6 @@
1
  import dash
2
- from dash import callback, Input, Output, State, ctx
 
3
 
4
  from utils import chartbot_dataset_layout, prompt, helpers
5
 
@@ -71,12 +72,20 @@ def update_submit_loading(n_clicks):
71
  prevent_initial_call=True
72
  )
73
  def create_graph(_, user_prompt):
74
- df_5_rows = df.head(5)
75
- data_top5_csv_string = df_5_rows.to_csv(index=False)
 
76
 
77
- result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
78
 
79
- return helpers.display_response(result_output, DATA_FILE_PATH)
 
 
 
 
 
 
 
80
 
81
  @callback(
82
  Output("amazon-purchases-download-html", "href"),
 
1
  import dash
2
+ from dash import callback, Input, Output, State, ctx, html
3
+ import dash_mantine_components as dmc
4
 
5
  from utils import chartbot_dataset_layout, prompt, helpers
6
 
 
72
  prevent_initial_call=True
73
  )
74
  def create_graph(_, user_prompt):
75
+ try:
76
+ df_5_rows = df.head(5)
77
+ data_top5_csv_string = df_5_rows.to_csv(index=False)
78
 
79
+ result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
80
 
81
+ return helpers.display_response(result_output, DATA_FILE_PATH)
82
+
83
+ except Exception as e:
84
+ error_message = str(e)
85
+ return html.Div([
86
+ html.Br(),
87
+ dmc.Alert(error_message, title="Error", color="red")
88
+ ]), None, {"display": "none"}, None, False
89
 
90
  @callback(
91
  Output("amazon-purchases-download-html", "href"),
pages/chartbot.py CHANGED
@@ -112,15 +112,23 @@ def create_graph(n_clicks, user_prompt, file_data, file_name):
112
  return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update
113
 
114
  else:
115
- df = pd.DataFrame(file_data)
116
- df_5_rows = df.head(5)
117
- data_top5_csv_string = df_5_rows.to_csv(index=False)
118
-
119
- helpers.save_dataframe_to_current_path(df, file_name)
120
-
121
- result_output = prompt.get_response(user_prompt, data_top5_csv_string, file_name)
122
-
123
- return helpers.display_response(result_output, file_name)
 
 
 
 
 
 
 
 
124
 
125
  @callback(
126
  Output("download-html", "href"),
 
112
  return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update
113
 
114
  else:
115
+ try:
116
+ df = pd.DataFrame(file_data)
117
+ df_5_rows = df.head(5)
118
+ data_top5_csv_string = df_5_rows.to_csv(index=False)
119
+
120
+ helpers.save_dataframe_to_current_path(df, file_name)
121
+
122
+ result_output = prompt.get_response(user_prompt, data_top5_csv_string, file_name)
123
+
124
+ return helpers.display_response(result_output, file_name)
125
+
126
+ except Exception as e:
127
+ error_message = str(e)
128
+ return html.Div([
129
+ html.Br(),
130
+ dmc.Alert(error_message, title="Error", color="red")
131
+ ]), None, {"display": "none"}, None, False
132
 
133
  @callback(
134
  Output("download-html", "href"),
pages/ewf_chartbot.py CHANGED
@@ -1,5 +1,6 @@
1
  import dash
2
- from dash import callback, Input, Output, State, ctx
 
3
 
4
  from utils import chartbot_dataset_layout, prompt, helpers
5
 
@@ -69,12 +70,20 @@ def update_submit_loading(n_clicks):
69
  prevent_initial_call=True
70
  )
71
  def create_graph(_, user_prompt):
72
- df_5_rows = df.head(5)
73
- data_top5_csv_string = df_5_rows.to_csv(index=False)
 
74
 
75
- result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
76
 
77
- return helpers.display_response(result_output, DATA_FILE_PATH)
 
 
 
 
 
 
 
78
 
79
  @callback(
80
  Output("ewf-download-html", "href"),
 
1
  import dash
2
+ from dash import callback, Input, Output, State, ctx, html
3
+ import dash_mantine_components as dmc
4
 
5
  from utils import chartbot_dataset_layout, prompt, helpers
6
 
 
70
  prevent_initial_call=True
71
  )
72
  def create_graph(_, user_prompt):
73
+ try:
74
+ df_5_rows = df.head(5)
75
+ data_top5_csv_string = df_5_rows.to_csv(index=False)
76
 
77
+ result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
78
 
79
+ return helpers.display_response(result_output, DATA_FILE_PATH)
80
+
81
+ except Exception as e:
82
+ error_message = str(e)
83
+ return html.Div([
84
+ html.Br(),
85
+ dmc.Alert(error_message, title="Error", color="red")
86
+ ]), None, {"display": "none"}, None, False
87
 
88
  @callback(
89
  Output("ewf-download-html", "href"),
pages/space_missions_chartbot.py CHANGED
@@ -1,5 +1,6 @@
1
  import dash
2
- from dash import callback, Input, Output, State, ctx
 
3
 
4
  from utils import chartbot_dataset_layout, prompt, helpers
5
 
@@ -69,12 +70,20 @@ def update_submit_loading(n_clicks):
69
  prevent_initial_call=True
70
  )
71
  def create_graph(_, user_prompt):
72
- df_5_rows = df.head(5)
73
- data_top5_csv_string = df_5_rows.to_csv(index=False)
 
74
 
75
- result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
76
 
77
- return helpers.display_response(result_output, DATA_FILE_PATH)
 
 
 
 
 
 
 
78
 
79
  @callback(
80
  Output("space-missions-download-html", "href"),
 
1
  import dash
2
+ from dash import callback, Input, Output, State, ctx, html
3
+ import dash_mantine_components as dmc
4
 
5
  from utils import chartbot_dataset_layout, prompt, helpers
6
 
 
70
  prevent_initial_call=True
71
  )
72
  def create_graph(_, user_prompt):
73
+ try:
74
+ df_5_rows = df.head(5)
75
+ data_top5_csv_string = df_5_rows.to_csv(index=False)
76
 
77
+ result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
78
 
79
+ return helpers.display_response(result_output, DATA_FILE_PATH)
80
+
81
+ except Exception as e:
82
+ error_message = str(e)
83
+ return html.Div([
84
+ html.Br(),
85
+ dmc.Alert(error_message, title="Error", color="red")
86
+ ]), None, {"display": "none"}, None, False
87
 
88
  @callback(
89
  Output("space-missions-download-html", "href"),
pages/us_elections_chartbot.py CHANGED
@@ -1,5 +1,6 @@
1
  import dash
2
- from dash import callback, Input, Output, State, ctx
 
3
 
4
  from utils import chartbot_dataset_layout, prompt, helpers
5
 
@@ -72,12 +73,20 @@ def update_submit_loading(n_clicks):
72
  prevent_initial_call=True
73
  )
74
  def create_graph(_, user_prompt):
75
- df_5_rows = df.head(5)
76
- data_top5_csv_string = df_5_rows.to_csv(index=False)
 
77
 
78
- result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
79
 
80
- return helpers.display_response(result_output, DATA_FILE_PATH)
 
 
 
 
 
 
 
81
 
82
  @callback(
83
  Output("us-elections-download-html", "href"),
 
1
  import dash
2
+ from dash import callback, Input, Output, State, ctx, html
3
+ import dash_mantine_components as dmc
4
 
5
  from utils import chartbot_dataset_layout, prompt, helpers
6
 
 
73
  prevent_initial_call=True
74
  )
75
  def create_graph(_, user_prompt):
76
+ try:
77
+ df_5_rows = df.head(5)
78
+ data_top5_csv_string = df_5_rows.to_csv(index=False)
79
 
80
+ result_output = prompt.get_response(user_prompt, data_top5_csv_string, DATA_FILE_PATH)
81
 
82
+ return helpers.display_response(result_output, DATA_FILE_PATH)
83
+
84
+ except Exception as e:
85
+ error_message = str(e)
86
+ return html.Div([
87
+ html.Br(),
88
+ dmc.Alert(error_message, title="Error", color="red")
89
+ ]), None, {"display": "none"}, None, False
90
 
91
  @callback(
92
  Output("us-elections-download-html", "href"),
utils/helpers.py CHANGED
@@ -16,7 +16,7 @@ from utils import prompt
16
 
17
  # Function to get the path of a file in the app source code
18
  def get_app_file_path(directory_name: str, file_name: str) -> str:
19
- return os.path.join(os.path.dirname(__file__), "..\\{}".format(directory_name), file_name)
20
 
21
  # Function to read the content of a file
22
  def read_doc(file_path: str) -> str:
@@ -34,37 +34,50 @@ def get_fig_from_code(code, file_name):
34
  exec(code, {}, local_variables)
35
 
36
  except Exception as e:
37
- result_output = prompt.get_python_exception_response(code, str(e))
38
- return display_response(result_output, file_name)
 
 
 
 
39
 
40
  return local_variables["fig"]
41
 
42
  def display_response(response, file_name):
43
- code_block_match = re.search(r"```(?:[Pp]ython)?(.*?)```", response, re.DOTALL)
44
- #print(code_block_match)
45
-
46
- if code_block_match:
47
- code_block = code_block_match.group(1).strip()
48
- cleaned_code = re.sub(r'(?m)^\s*fig\.show\(\)\s*$', '', code_block)
49
-
50
- try:
51
- fig = get_fig_from_code(cleaned_code, file_name)
52
-
53
- buffer = io.StringIO()
54
- fig.write_html(buffer)
55
- html_bytes = buffer.getvalue().encode()
56
- encoded = b64encode(html_bytes).decode()
57
-
58
- return dcc.Graph(figure=fig), response, {"display": "block"}, encoded, False
59
-
60
- except AttributeError as e:
61
- return html.Div([
62
- html.Br(),
63
- dmc.Alert("The code generated has errors, please try again.", title="Error", color="red")
64
- ]), None, {"display": "none"}, None, False
65
-
66
- else:
67
- return "", response, {"display": "none"}, None, False
 
 
 
 
 
 
 
 
 
68
 
69
  # Function to parse the contents of the uploaded file
70
  def parse_contents(contents, filename):
 
16
 
17
  # Function to get the path of a file in the app source code
18
  def get_app_file_path(directory_name: str, file_name: str) -> str:
19
+ return os.path.join(os.path.dirname(__file__), "..", directory_name, file_name)
20
 
21
  # Function to read the content of a file
22
  def read_doc(file_path: str) -> str:
 
34
  exec(code, {}, local_variables)
35
 
36
  except Exception as e:
37
+ try:
38
+ result_output = prompt.get_python_exception_response(code, str(e))
39
+ return display_response(result_output, file_name)
40
+ except Exception as api_error:
41
+ # If the API call fails, raise the error to be handled by display_response
42
+ raise api_error
43
 
44
  return local_variables["fig"]
45
 
46
  def display_response(response, file_name):
47
+ try:
48
+ code_block_match = re.search(r"```(?:[Pp]ython)?(.*?)```", response, re.DOTALL)
49
+ #print(code_block_match)
50
+
51
+ if code_block_match:
52
+ code_block = code_block_match.group(1).strip()
53
+ cleaned_code = re.sub(r'(?m)^\s*fig\.show\(\)\s*$', '', code_block)
54
+
55
+ try:
56
+ fig = get_fig_from_code(cleaned_code, file_name)
57
+
58
+ buffer = io.StringIO()
59
+ fig.write_html(buffer)
60
+ html_bytes = buffer.getvalue().encode()
61
+ encoded = b64encode(html_bytes).decode()
62
+
63
+ return dcc.Graph(figure=fig), response, {"display": "block"}, encoded, False
64
+
65
+ except AttributeError as e:
66
+ return html.Div([
67
+ html.Br(),
68
+ dmc.Alert("The code generated has errors, please try again.", title="Error", color="red")
69
+ ]), None, {"display": "none"}, None, False
70
+
71
+ else:
72
+ return "", response, {"display": "none"}, None, False
73
+
74
+ except Exception as e:
75
+ # Handle API errors gracefully
76
+ error_message = str(e)
77
+ return html.Div([
78
+ html.Br(),
79
+ dmc.Alert(error_message, title="API Error", color="red")
80
+ ]), None, {"display": "none"}, None, False
81
 
82
  # Function to parse the contents of the uploaded file
83
  def parse_contents(contents, filename):
utils/prompt.py CHANGED
@@ -3,26 +3,32 @@ import os
3
  from dotenv import load_dotenv
4
 
5
  # libraries to help with the AI model
6
- from langchain_openai import AzureChatOpenAI
7
  from langchain_core.messages import HumanMessage
8
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
9
 
10
  from utils import helpers
11
 
12
  # get the credentials from .env
13
  load_dotenv()
14
- AZURE_KEY = os.getenv('KEY')
15
- AZURE_ENDPOINT = os.getenv('LLM_ENDPOINT')
16
- AZURE_NAME = os.getenv('LLM_DEPLOYMENT_NAME')
17
- AZURE_VERSION = os.getenv('VERSION')
 
 
 
 
18
 
19
  # define connectivity to the llm
20
- llm = AzureChatOpenAI(
21
- deployment_name=AZURE_NAME,
22
- openai_api_version=AZURE_VERSION,
23
- openai_api_key=AZURE_KEY,
24
- azure_endpoint=AZURE_ENDPOINT
25
- )
 
 
26
 
27
  '''Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
28
  handle missing values appropriately based on the context, ensuring cleaner visualizations.
@@ -68,55 +74,106 @@ def get_prompt_text() -> str:
68
  Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
69
  There should be no natural language text in the python code block."""
70
 
71
- def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> None:
72
- prompt = ChatPromptTemplate.from_messages(
73
- [
74
- (
75
- "system",
76
- get_prompt_text()
77
- ),
78
- MessagesPlaceholder(variable_name="messages")
79
- ]
80
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- chain = prompt | llm
83
-
84
- response = chain.invoke(
85
- {
86
- "messages": [HumanMessage(content=user_input)],
87
- "data_visualization_best_practices": helpers.read_doc(helpers.get_app_file_path("assets", "data_viz_best_practices.txt")),
88
- "example_subplots1": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots1.txt")),
89
- "example_subplots2": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots2.txt")),
90
- "example_subplots3": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots3.txt")),
91
- "data": data_top5_csv_string,
92
- "name_of_file": file_name
93
- }
94
- )
95
 
96
- return response.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def get_python_exception_prompt_text() -> str:
99
  return """The Python code you provided {code} has an error {exception}"""
100
 
101
- def get_python_exception_response(code: str, exception: str) -> None:
102
- prompt = ChatPromptTemplate.from_messages(
103
- [
104
- (
105
- "system",
106
- get_python_exception_prompt_text()
107
- ),
108
- MessagesPlaceholder(variable_name="messages")
109
- ]
110
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- chain = prompt | llm
113
 
114
- response = chain.invoke(
115
- {
116
- "messages": [HumanMessage(content="Rewrite the entire Python code so that it does not contain any errors. The code should be able to run without any errors.")],
117
- "code": code,
118
- "exception": exception
119
- }
120
- )
121
 
122
- return response.content
 
 
 
 
 
 
 
 
 
 
 
 
3
  from dotenv import load_dotenv
4
 
5
  # libraries to help with the AI model
6
+ from langchain_groq import ChatGroq
7
  from langchain_core.messages import HumanMessage
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
 
10
  from utils import helpers
11
 
12
  # get the credentials from .env
13
  load_dotenv()
14
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
15
+
16
+ # Validate that the API key is present and not a placeholder
17
+ if not GROQ_API_KEY or GROQ_API_KEY == 'your_groq_api_key_here':
18
+ raise ValueError(
19
+ "GROQ_API_KEY environment variable is not set or is still the placeholder. "
20
+ "Please update the .env file with your actual Groq API key."
21
+ )
22
 
23
  # define connectivity to the llm
24
+ try:
25
+ llm = ChatGroq(
26
+ model="groq/compound-mini",
27
+ api_key=GROQ_API_KEY,
28
+ temperature=0
29
+ )
30
+ except Exception as e:
31
+ raise ValueError(f"Failed to initialize ChatGroq: {str(e)}")
32
 
33
  '''Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
34
  handle missing values appropriately based on the context, ensuring cleaner visualizations.
 
74
  Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
75
  There should be no natural language text in the python code block."""
76
 
77
+ def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> str:
78
+ """
79
+ Get a response from the LLM for creating data visualizations.
80
+
81
+ Args:
82
+ user_input: User's request for visualization
83
+ data_top5_csv_string: CSV string of first 5 rows of data
84
+ file_name: Name of the data file
85
+
86
+ Returns:
87
+ LLM response content
88
+
89
+ Raises:
90
+ Exception: If API call fails
91
+ """
92
+ try:
93
+ prompt = ChatPromptTemplate.from_messages(
94
+ [
95
+ (
96
+ "system",
97
+ get_prompt_text()
98
+ ),
99
+ MessagesPlaceholder(variable_name="messages")
100
+ ]
101
+ )
102
 
103
+ chain = prompt | llm
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ response = chain.invoke(
106
+ {
107
+ "messages": [HumanMessage(content=user_input)],
108
+ "data_visualization_best_practices": helpers.read_doc(helpers.get_app_file_path("assets", "data_viz_best_practices.txt")),
109
+ "example_subplots1": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots1.txt")),
110
+ "example_subplots2": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots2.txt")),
111
+ "example_subplots3": helpers.read_doc(helpers.get_app_file_path("assets", "example_subplots3.txt")),
112
+ "data": data_top5_csv_string,
113
+ "name_of_file": file_name
114
+ }
115
+ )
116
+
117
+ return response.content
118
+
119
+ except Exception as e:
120
+ error_msg = str(e)
121
+ if "rate_limit" in error_msg.lower() or "429" in error_msg:
122
+ raise Exception("Rate limit exceeded. Please wait a moment and try again.")
123
+ elif "authentication" in error_msg.lower() or "401" in error_msg or "api_key" in error_msg.lower():
124
+ raise Exception("Authentication failed. Please check your GROQ_API_KEY in the .env file.")
125
+ elif "timeout" in error_msg.lower():
126
+ raise Exception("Request timed out. Please try again.")
127
+ else:
128
+ raise Exception(f"Error communicating with Groq API: {error_msg}")
129
 
130
  def get_python_exception_prompt_text() -> str:
131
  return """The Python code you provided {code} has an error {exception}"""
132
 
133
+ def get_python_exception_response(code: str, exception: str) -> str:
134
+ """
135
+ Get a response from the LLM to fix Python code errors.
136
+
137
+ Args:
138
+ code: The Python code that has errors
139
+ exception: The exception message
140
+
141
+ Returns:
142
+ LLM response with fixed code
143
+
144
+ Raises:
145
+ Exception: If API call fails
146
+ """
147
+ try:
148
+ prompt = ChatPromptTemplate.from_messages(
149
+ [
150
+ (
151
+ "system",
152
+ get_python_exception_prompt_text()
153
+ ),
154
+ MessagesPlaceholder(variable_name="messages")
155
+ ]
156
+ )
157
 
158
+ chain = prompt | llm
159
 
160
+ response = chain.invoke(
161
+ {
162
+ "messages": [HumanMessage(content="Rewrite the entire Python code so that it does not contain any errors. The code should be able to run without any errors.")],
163
+ "code": code,
164
+ "exception": exception
165
+ }
166
+ )
167
 
168
+ return response.content
169
+
170
+ except Exception as e:
171
+ error_msg = str(e)
172
+ if "rate_limit" in error_msg.lower() or "429" in error_msg:
173
+ raise Exception("Rate limit exceeded. Please wait a moment and try again.")
174
+ elif "authentication" in error_msg.lower() or "401" in error_msg or "api_key" in error_msg.lower():
175
+ raise Exception("Authentication failed. Please check your GROQ_API_KEY in the .env file.")
176
+ elif "timeout" in error_msg.lower():
177
+ raise Exception("Request timed out. Please try again.")
178
+ else:
179
+ raise Exception(f"Error communicating with Groq API: {error_msg}")