refactor chat functions

#39
by nolanzandi - opened
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samples/online_retail_data.csv filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,4 +0,0 @@
1
- __pycache__/
2
- .gradio/
3
- .env
4
- temp/
 
 
 
 
 
README.md CHANGED
@@ -4,10 +4,10 @@ emoji: 📈
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.29.0
8
  app_file: app.py
9
  pinned: true
10
- short_description: Queries, visualizations, stat analysis on your data
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.23.3
8
  app_file: app.py
9
  pinned: true
10
+ short_description: Queries, visualizations, analysis on your files/DBs/APIs
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,192 +1,90 @@
1
- from utils import TEMP_DIR, message_dict, api_key_store, model_store
2
- import gradio as gr
3
- import templates.data_file as data_file, templates.sql_db as sql_db, templates.doc_db as doc_db, templates.graphql as graphql
4
-
5
- import os
6
- from dotenv import load_dotenv
7
-
8
- load_dotenv()
9
-
10
- def delete_db(req: gr.Request):
11
- import shutil
12
- dir_path = TEMP_DIR / str(req.session_hash)
13
- if os.path.exists(dir_path):
14
- shutil.rmtree(dir_path)
15
- message_dict[req.session_hash] = {}
16
- api_key_store.pop(req.session_hash, None)
17
- model_store.pop(req.session_hash, None)
18
-
19
- def set_api_key(api_key, model, request: gr.Request):
20
- api_key = api_key.strip()
21
- if not api_key:
22
- return (
23
- gr.update(visible=True),
24
- gr.update(visible=True, value="<p style='color:#b91c1c;text-align:center;margin:6px 0;font-size:14px;'>Please enter your API key.</p>"),
25
- gr.update(visible=False),
26
- )
27
- api_key_store[request.session_hash] = api_key
28
- model_store[request.session_hash] = model
29
- provider = "Anthropic" if api_key.startswith("sk-ant-") else "OpenAI"
30
- provider_icon = "fa-a" if provider == "Anthropic" else "fa-o"
31
- badge_html = f"""
32
- <div style="display:flex;flex-direction:column;align-items:center;gap:6px;padding:10px 0 4px;">
33
- <div style="display:inline-flex;align-items:center;gap:10px;background:#f0fdf4;border:1px solid #86efac;
34
- padding:8px 20px;border-radius:9999px;font-size:13px;font-weight:500;color:#15803d;
35
- box-shadow:0 1px 3px rgba(0,0,0,0.06);">
36
- <i class="fas fa-circle-check" style="font-size:14px;"></i>
37
- <span>{provider}</span>
38
- <span style="color:#86efac;">·</span>
39
- <span style="font-weight:600;">{model}</span>
40
- </div>
41
- <p style="margin:0;font-size:11px;color:#9ca3af;letter-spacing:0.02em;">
42
- Session active — use the button below to change
43
- </p>
44
- </div>
45
- """
46
- return gr.update(visible=False), gr.update(visible=True, value=badge_html), gr.update(visible=True)
47
-
48
- def show_api_form():
49
- return gr.update(visible=True), gr.update(visible=False, value=""), gr.update(visible=False)
50
-
51
- css = ".file_marker .large{min-height:50px !important;} .padding{padding:0;} .description_component{overflow:visible !important;}"
52
- head = """<meta charset="UTF-8">
53
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
54
- <title>Virtual Data Analyst</title>
55
- <!-- Tailwind CSS -->
56
- <script src="https://cdn.tailwindcss.com"></script>
57
- <!-- Google Fonts -->
58
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
59
- <!-- Font Awesome -->
60
- <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
61
- <!-- Custom Styles -->
62
- <link rel="stylesheet" href="/gradio_api/file=assets/styles.css">
63
- """
64
-
65
- theme = gr.themes.Base(primary_hue="sky", secondary_hue="slate", font=[gr.themes.GoogleFont("Inter"), "Inter", "sans-serif"]).set(
66
- button_primary_background_fill="#3B82F6",
67
- button_secondary_background_fill="#6B7280",
68
- )
69
-
70
- from pathlib import Path
71
- gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
72
-
73
- _env_api_key = os.getenv("OPENAI_API_KEY", "")
74
-
75
- OPENAI_MODELS = [
76
- "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano",
77
- "gpt-4o", "gpt-4o-mini",
78
- "o3-mini", "o4-mini",
79
- "gpt-5.4-mini", "gpt-5.4", "gpt-5.5",
80
- ]
81
- ANTHROPIC_MODELS = [
82
- "claude-sonnet-4-6",
83
- "claude-opus-4-8",
84
- "claude-haiku-4-5-20251001",
85
- ]
86
-
87
- def update_models(api_key):
88
- if api_key.strip().startswith("sk-ant-"):
89
- return gr.update(choices=ANTHROPIC_MODELS, value=ANTHROPIC_MODELS[0])
90
- return gr.update(choices=OPENAI_MODELS, value=OPENAI_MODELS[0])
91
-
92
- with gr.Blocks(theme=theme, css=css, head=head, delete_cache=(3600, 3600)) as demo:
93
-
94
- with gr.Column(visible=True) as api_key_section:
95
- gr.HTML("""
96
- <div style="max-width:640px;margin:28px auto 12px;padding:22px 28px;
97
- background:linear-gradient(135deg,#eff6ff 0%,#e0f2fe 100%);
98
- border:1px solid #bfdbfe;border-radius:14px;
99
- box-shadow:0 2px 8px rgba(59,130,246,0.08);">
100
- <div style="display:flex;align-items:flex-start;gap:16px;">
101
- <div style="width:42px;height:42px;flex-shrink:0;background:#3B82F6;
102
- border-radius:10px;display:flex;align-items:center;
103
- justify-content:center;box-shadow:0 2px 6px rgba(59,130,246,0.35);">
104
- <i class="fas fa-key" style="color:white;font-size:16px;"></i>
105
- </div>
106
- <div>
107
- <h3 style="color:#1e40af;margin:0 0 6px;font-size:16px;font-weight:700;letter-spacing:-0.01em;">
108
- Get Started
109
- </h3>
110
- <p style="color:#3730a3;font-size:13.5px;margin:0;line-height:1.6;">
111
- Enter your <strong>OpenAI</strong>
112
- (<code style="background:rgba(255,255,255,0.7);padding:1px 6px;border-radius:4px;font-size:12px;">sk-...</code>)
113
- or <strong>Anthropic</strong>
114
- (<code style="background:rgba(255,255,255,0.7);padding:1px 6px;border-radius:4px;font-size:12px;">sk-ant-...</code>)
115
- API key. The model list updates automatically. Your key is held in memory only
116
- and cleared when you leave — never saved or shared.
117
- </p>
118
- </div>
119
- </div>
120
- </div>
121
- """)
122
- with gr.Row(equal_height=True):
123
- api_key_input = gr.Textbox(
124
- label="API Key",
125
- placeholder="sk-proj-... or sk-ant-api03-...",
126
- type="password",
127
- value=_env_api_key,
128
- scale=4,
129
- )
130
- model_dropdown = gr.Dropdown(
131
- label="Model",
132
- choices=OPENAI_MODELS,
133
- value=OPENAI_MODELS[0],
134
- scale=2,
135
- )
136
- api_key_btn = gr.Button("Set API Key", variant="primary", scale=1, min_width=120)
137
-
138
- api_key_status = gr.HTML("", visible=False)
139
- change_key_btn = gr.Button("🔑 Change Key / Model", variant="secondary", visible=False, size="sm")
140
-
141
- api_key_input.change(fn=update_models, inputs=api_key_input, outputs=model_dropdown)
142
- api_key_btn.click(
143
- fn=set_api_key,
144
- inputs=[api_key_input, model_dropdown],
145
- outputs=[api_key_section, api_key_status, change_key_btn],
146
- )
147
- change_key_btn.click(fn=show_api_form, outputs=[api_key_section, api_key_status, change_key_btn])
148
-
149
- header = gr.HTML("""
150
- <header class="max-w-4xl mx-auto mb-12 text-center">
151
- <h1 class="text-4xl font-bold text-gray-900 mb-4">Virtual Data Analyst</h1>
152
- <p class="text-lg text-gray-600 mb-6">
153
- A powerful tool for data analysis, visualizations, and insights
154
- </p>
155
- </header>
156
- <main class="max-w-4xl mx-auto">
157
- <div class="mt-12 grid md:grid-cols-3 gap-6" style="margin-bottom:3px !important;">
158
- <div class="feature-card bg-white p-6 rounded-lg shadow-md">
159
- <i class="feature-icon fas fa-chart-line text-primary text-2xl mb-4"></i>
160
- <h3 class="font-semibold text-gray-800 mb-2">Advanced Analytics</h3>
161
- <p class="text-gray-600 text-sm">Run SQL queries, perform regressions, and analyze results with ease</p>
162
- </div>
163
- <div class="feature-card bg-white p-6 rounded-lg shadow-md">
164
- <i class="feature-icon fas fa-chart-pie text-primary text-2xl mb-4"></i>
165
- <h3 class="font-semibold text-gray-800 mb-2">Rich Visualizations</h3>
166
- <p class="text-gray-600 text-sm">Create scatter plots, line charts, pie charts, and more</p>
167
- </div>
168
- <div class="feature-card bg-white p-6 rounded-lg shadow-md">
169
- <i class="feature-icon fas fa-magic text-primary text-2xl mb-4"></i>
170
- <h3 class="font-semibold text-gray-800 mb-2">Automated Insights</h3>
171
- <p class="text-gray-600 text-sm">Get instant insights and recommendations for your data</p>
172
- </div>
173
- </div>
174
- </main>""")
175
-
176
- with gr.Tab("📄 Data File"):
177
- data_file.demo.render()
178
- with gr.Tab("🗄 SQL Database"):
179
- sql_db.demo.render()
180
- with gr.Tab("🍃 MongoDB"):
181
- doc_db.demo.render()
182
- with gr.Tab("⚡ GraphQL API"):
183
- graphql.demo.render()
184
-
185
- footer = gr.HTML("""
186
- <footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
187
- <p>This application is under active development. For bugs or feedback, please open a discussion in the community tab.</p>
188
- </footer>""")
189
-
190
- demo.unload(delete_db)
191
-
192
- demo.launch(debug=True, allowed_paths=["temp/", "assets/"])
 
1
+ from utils import TEMP_DIR, message_dict
2
+ import gradio as gr
3
+ import templates.data_file as data_file, templates.sql_db as sql_db, templates.doc_db as doc_db, templates.graphql as graphql
4
+
5
+ import os
6
+ from getpass import getpass
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
+ def delete_db(req: gr.Request):
12
+ import shutil
13
+ dir_path = TEMP_DIR / str(req.session_hash)
14
+ if os.path.exists(dir_path):
15
+ shutil.rmtree(dir_path)
16
+ message_dict[req.session_hash] = {}
17
+
18
+ if "OPENAI_API_KEY" not in os.environ:
19
+ os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
20
+
21
+ css= ".file_marker .large{min-height:50px !important;} .padding{padding:0;} .description_component{overflow:visible !important;}"
22
+ head = """<meta charset="UTF-8">
23
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
24
+ <title>Virtual Data Analyst</title>
25
+ <!-- Tailwind CSS -->
26
+ <script src="https://cdn.tailwindcss.com"></script>
27
+ <!-- Google Fonts -->
28
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
29
+ <!-- Font Awesome -->
30
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
31
+ <!-- Custom Styles -->
32
+ <link rel="stylesheet" href="/gradio_api/file=assets/styles.css">
33
+ """
34
+
35
+ theme = gr.themes.Base(primary_hue="sky", secondary_hue="slate",font=[gr.themes.GoogleFont("Inter"), "Inter", "sans-serif"]).set(
36
+ button_primary_background_fill="#3B82F6",
37
+ button_secondary_background_fill="#6B7280",
38
+ )
39
+
40
+ from pathlib import Path
41
+ gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
42
+
43
+ with gr.Blocks(theme=theme, css=css, head=head, delete_cache=(3600,3600)) as demo:
44
+ header = gr.HTML("""
45
+ <!-- Header -->
46
+ <header class="max-w-4xl mx-auto mb-12 text-center">
47
+ <h1 class="text-4xl font-bold text-gray-900 mb-4">Virtual Data Analyst</h1>
48
+ <p class="text-lg text-gray-600 mb-6">
49
+ A powerful tool for data analysis, visualizations, and insights
50
+ </p>
51
+ </header>
52
+ <!-- Main Content -->
53
+ <main class="max-w-4xl mx-auto">
54
+ <!-- Features Preview -->
55
+ <div class="mt-12 grid md:grid-cols-3 gap-6" style="margin-bottom:3px !important;">
56
+ <div class="feature-card bg-white p-6 rounded-lg shadow-md">
57
+ <i class="feature-icon fas fa-chart-line text-primary text-2xl mb-4"></i>
58
+ <h3 class="font-semibold text-gray-800 mb-2">Advanced Analytics</h3>
59
+ <p class="text-gray-600 text-sm">Run SQL queries, perform regressions, and analyze results with ease</p>
60
+ </div>
61
+ <div class="feature-card bg-white p-6 rounded-lg shadow-md">
62
+ <i class="feature-icon fas fa-chart-pie text-primary text-2xl mb-4"></i>
63
+ <h3 class="font-semibold text-gray-800 mb-2">Rich Visualizations</h3>
64
+ <p class="text-gray-600 text-sm">Create scatter plots, line charts, pie charts, and more</p>
65
+ </div>
66
+ <div class="feature-card bg-white p-6 rounded-lg shadow-md">
67
+ <i class="feature-icon fas fa-magic text-primary text-2xl mb-4"></i>
68
+ <h3 class="font-semibold text-gray-800 mb-2">Automated Insights</h3>
69
+ <p class="text-gray-600 text-sm">Get instant insights and recommendations for your data</p>
70
+ </div>
71
+ </div>
72
+ </main>""")
73
+ with gr.Tab("Data File"):
74
+ data_file.demo.render()
75
+ with gr.Tab("SQL Database"):
76
+ sql_db.demo.render()
77
+ with gr.Tab("Document (MongoDB) Database"):
78
+ doc_db.demo.render()
79
+ with gr.Tab("GraphQL API"):
80
+ graphql.demo.render()
81
+
82
+ footer = gr.HTML("""<!-- Footer -->
83
+ <footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
84
+ <p>This application is under active development. For bugs or feedback, please open a discussion in the community tab.</p>
85
+ </footer>""")
86
+
87
+ demo.unload(delete_db)
88
+
89
+ ## Uncomment the line below to launch the chat app with UI
90
+ demo.launch(debug=True, allowed_paths=["temp/","assets/"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
assets/styles.css CHANGED
@@ -89,7 +89,6 @@
89
  transition: all 0.3s ease;
90
  position: relative;
91
  overflow: hidden;
92
- background: linear-gradient(135deg, #3B82F6, #0ea5e9) !important;
93
  }
94
 
95
  .sample-btn::after {
@@ -99,7 +98,7 @@
99
  left: 0;
100
  width: 100%;
101
  height: 100%;
102
- background: linear-gradient(rgba(255,255,255,0.12), rgba(255,255,255,0));
103
  transform: translateY(-100%);
104
  transition: transform 0.3s ease;
105
  }
@@ -110,17 +109,7 @@
110
 
111
  .sample-btn:hover {
112
  transform: translateY(-2px);
113
- box-shadow: 0 8px 20px rgba(59,130,246,0.3);
114
- }
115
-
116
- /* Status badge fade-in */
117
- @keyframes fadeSlideIn {
118
- from { opacity: 0; transform: translateY(-6px); }
119
- to { opacity: 1; transform: translateY(0); }
120
- }
121
-
122
- .api-status-badge {
123
- animation: fadeSlideIn 0.35s ease forwards;
124
  }
125
 
126
  /* Drop Zone Enhancements */
@@ -185,14 +174,4 @@
185
  grid-template-columns: 1fr 2fr;
186
  align-items: baseline;
187
  }
188
- }
189
-
190
- dialog {
191
- margin: 10% auto;
192
- width: 80%;
193
- max-width: 350px;
194
- background-color: #fff;
195
- padding: 34px;
196
- border: 0;
197
- border-radius: 5px;
198
- }
 
89
  transition: all 0.3s ease;
90
  position: relative;
91
  overflow: hidden;
 
92
  }
93
 
94
  .sample-btn::after {
 
98
  left: 0;
99
  width: 100%;
100
  height: 100%;
101
+ background: linear-gradient(rgba(255,255,255,0.1), rgba(255,255,255,0));
102
  transform: translateY(-100%);
103
  transition: transform 0.3s ease;
104
  }
 
109
 
110
  .sample-btn:hover {
111
  transform: translateY(-2px);
112
+ box-shadow: 0 8px 15px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
113
  }
114
 
115
  /* Drop Zone Enhancements */
 
174
  grid-template-columns: 1fr 2fr;
175
  align-items: baseline;
176
  }
177
+ }
 
 
 
 
 
 
 
 
 
 
data_sources/connect_graphql.py CHANGED
@@ -1,5 +1,4 @@
1
  import requests
2
- import certifi
3
  import os
4
  import json
5
  from utils import TEMP_DIR
@@ -103,8 +102,7 @@ def connect_graphql(graphql_url, api_token, graphql_token_header, session_hash):
103
  headers = {"Content-Type": "application/json"}
104
  if graphql_token_header and api_token:
105
  headers[graphql_token_header] = api_token
106
- response = requests.post(graphql_url, headers=headers, json={"query": query},
107
- verify=certifi.where())
108
  response.raise_for_status()
109
 
110
  introspection_result = response.json()
@@ -121,8 +119,7 @@ def connect_graphql(graphql_url, api_token, graphql_token_header, session_hash):
121
  }
122
  }
123
  """
124
- types_response = requests.post(graphql_url, headers=headers, json={"query": type_names_query},
125
- verify=certifi.where())
126
 
127
  types_response_results =types_response.json()
128
 
 
1
  import requests
 
2
  import os
3
  import json
4
  from utils import TEMP_DIR
 
102
  headers = {"Content-Type": "application/json"}
103
  if graphql_token_header and api_token:
104
  headers[graphql_token_header] = api_token
105
+ response = requests.post(graphql_url, headers=headers, json={"query": query})
 
106
  response.raise_for_status()
107
 
108
  introspection_result = response.json()
 
119
  }
120
  }
121
  """
122
+ types_response = requests.post(graphql_url, headers=headers, json={"query": type_names_query})
 
123
 
124
  types_response_results =types_response.json()
125
 
data_sources/upload_file.py CHANGED
@@ -95,72 +95,7 @@ def process_data_upload(data_file, session_hash):
95
  connection.commit()
96
  connection.close()
97
 
98
- missing_per_col = {col: int(df[col].isnull().sum()) for col in df.columns}
99
- total_missing = sum(missing_per_col.values())
100
-
101
- def _simplify_dtype(d):
102
- s = str(d)
103
- if 'int' in s: return 'Integer'
104
- if 'float' in s: return 'Float'
105
- if 'datetime' in s: return 'DateTime'
106
- if 'bool' in s: return 'Boolean'
107
- return 'Text'
108
-
109
- dtypes = {col: _simplify_dtype(df[col].dtype) for col in df.columns}
110
-
111
- preview = []
112
- for _, row in df.head(5).iterrows():
113
- row_vals = []
114
- for v in row:
115
- try:
116
- row_vals.append('' if pd.isna(v) else str(v)[:60])
117
- except Exception:
118
- row_vals.append(str(v)[:60])
119
- preview.append(row_vals)
120
-
121
- duplicate_rows = int(df.duplicated().sum())
122
- unique_counts = {col: int(df[col].nunique()) for col in df.columns}
123
-
124
- col_stats = {}
125
- for col in df.columns:
126
- dtype_str = str(df[col].dtype)
127
- try:
128
- if 'int' in dtype_str or 'float' in dtype_str:
129
- col_stats[col] = {
130
- 'type': 'numeric',
131
- 'min': float(df[col].min()),
132
- 'max': float(df[col].max()),
133
- 'mean': float(df[col].mean()),
134
- }
135
- elif 'datetime' in dtype_str:
136
- col_stats[col] = {
137
- 'type': 'datetime',
138
- 'min': str(df[col].min())[:10],
139
- 'max': str(df[col].max())[:10],
140
- }
141
- except Exception:
142
- pass
143
-
144
- try:
145
- file_size_bytes = os.path.getsize(data_file)
146
- except Exception:
147
- file_size_bytes = 0
148
-
149
- stats = {
150
- 'num_rows': len(df),
151
- 'num_cols': len(df.columns),
152
- 'total_missing': total_missing,
153
- 'missing_per_col': missing_per_col,
154
- 'dtypes': dtypes,
155
- 'preview_cols': list(df.columns),
156
- 'preview': preview,
157
- 'duplicate_rows': duplicate_rows,
158
- 'unique_counts': unique_counts,
159
- 'col_stats': col_stats,
160
- 'file_size_bytes': file_size_bytes,
161
- }
162
-
163
- return ["success","<p style='color:green;text-align:center;font-size:18px;'>Data upload successful</p>", columns, stats]
164
  except Exception as e:
165
  print("UPLOAD ERROR")
166
  print(e)
 
95
  connection.commit()
96
  connection.close()
97
 
98
+ return ["success","<p style='color:green;text-align:center;font-size:18px;'>Data upload successful</p>", columns]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  except Exception as e:
100
  print("UPLOAD ERROR")
101
  print(e)
functions/__init__.py CHANGED
@@ -1,17 +1,9 @@
1
- from .query_functions import graphql_schema_query, graphql_csv_query, query_func
2
  from .chart_functions import table_generation_func, scatter_chart_generation_func, \
3
- line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, \
4
- histogram_generation_func, box_chart_generation_func, correlation_heatmap_func, \
5
- scatter_chart_fig, rolling_stats_func
6
  from .chat_functions import example_question_generator, chatbot_func
7
- from .stat_functions import regression_func, descriptive_stats_func, \
8
- kmeans_clustering_func, hypothesis_test_func
9
 
10
- __all__ = [
11
- "query_func", "graphql_schema_query", "graphql_csv_query",
12
- "table_generation_func", "scatter_chart_generation_func", "line_chart_generation_func",
13
- "bar_chart_generation_func", "pie_chart_generation_func", "histogram_generation_func",
14
- "box_chart_generation_func", "correlation_heatmap_func", "rolling_stats_func",
15
- "regression_func", "descriptive_stats_func", "kmeans_clustering_func", "hypothesis_test_func",
16
- "scatter_chart_fig", "example_question_generator", "chatbot_func",
17
- ]
 
1
+ from .query_functions import SQLiteQuery, sqlite_query_func, sql_query_func, doc_db_query_func, graphql_query_func, graphql_schema_query, graphql_csv_query
2
  from .chart_functions import table_generation_func, scatter_chart_generation_func, \
3
+ line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
 
 
4
  from .chat_functions import example_question_generator, chatbot_func
5
+ from .stat_functions import regression_func
 
6
 
7
+ __all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","doc_db_query_func","graphql_query_func","graphql_schema_query","graphql_csv_query","table_generation_func","scatter_chart_generation_func",
8
+ "line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
9
+ "scatter_chart_fig","example_question_generator","chatbot_func"]
 
 
 
 
 
functions/chart_functions.py CHANGED
@@ -9,20 +9,7 @@ from dotenv import load_dotenv
9
 
10
  load_dotenv()
11
 
12
- root_url = os.getenv("ROOT_URL", "")
13
-
14
-
15
- def _write_chart(fig, chart_path, chart_url):
16
- """Write a Plotly figure to disk and return a responsive iframe HTML string."""
17
- pio.write_html(fig, chart_path, full_html=False, config={"responsive": True})
18
- return (
19
- 'Please display this iframe: '
20
- '<div style="width:100%;overflow-x:auto;">'
21
- '<iframe style="width:100%;min-width:400px;" height="500" '
22
- f'src="{chart_url}" frameborder="0" allowfullscreen>'
23
- '</iframe></div>'
24
- )
25
-
26
 
27
  def llm_chart_data_scrub(data, layout):
28
  #Processing data to account for variation from LLM
@@ -138,8 +125,13 @@ def scatter_chart_generation_func(x_column: List[str], y_column: str, session_ha
138
  for data_item in fig["data"]:
139
  data_item[key] = value
140
 
 
 
141
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
142
- return {"reply": _write_chart(fig, chart_path, chart_url)}
 
 
 
143
 
144
  except Exception as e:
145
  print("SCATTER PLOT ERROR")
@@ -182,10 +174,15 @@ def line_chart_generation_func(x_column: str, y_column: str, session_hash, sessi
182
  for data_item in fig["data"]:
183
  data_item[key] = value
184
 
185
- print(fig)
 
 
186
 
187
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
188
- return {"reply": _write_chart(fig, chart_path, chart_url)}
 
 
 
189
 
190
  except Exception as e:
191
  print("LINE CHART ERROR")
@@ -232,10 +229,15 @@ def bar_chart_generation_func(x_column: str, y_column: str, session_hash, sessio
232
  for data_item in fig["data"]:
233
  data_item[key] = value
234
 
235
- print(fig)
 
 
236
 
237
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
238
- return {"reply": _write_chart(fig, chart_path, chart_url)}
 
 
 
239
 
240
  except Exception as e:
241
  print("BAR CHART ERROR")
@@ -274,10 +276,15 @@ def pie_chart_generation_func(values: str, names: str, session_hash, session_fol
274
  for data_item in fig["data"]:
275
  data_item[key] = value
276
 
277
- print(fig)
 
 
278
 
279
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
280
- return {"reply": _write_chart(fig, chart_path, chart_url)}
 
 
 
281
 
282
  except Exception as e:
283
  print("PIE CHART ERROR")
@@ -328,10 +335,15 @@ def histogram_generation_func(x_column: str, session_hash, session_folder, y_col
328
  for data_item in fig["data"]:
329
  data_item[key] = value
330
 
331
- print(fig)
 
 
332
 
333
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
334
- return {"reply": _write_chart(fig, chart_path, chart_url)}
 
 
 
335
 
336
  except Exception as e:
337
  print("HISTOGRAM ERROR")
@@ -342,185 +354,32 @@ def histogram_generation_func(x_column: str, session_hash, session_folder, y_col
342
  """
343
  return {"reply": reply}
344
 
345
- def box_chart_generation_func(y_column: str, session_hash, session_folder,
346
- x_column: str="", category: str="",
347
- layout: List[dict]=[{}], **kwargs):
348
- try:
349
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
350
- chart_path = f'{dir_path}/chart.html'
351
- csv_query_path = f'{dir_path}/query.csv'
352
-
353
- df = pd.read_csv(csv_query_path)
354
-
355
- function_args = {"data_frame": df, "y": y_column}
356
- if x_column:
357
- function_args["x"] = x_column
358
- if category:
359
- function_args["color"] = category
360
-
361
- initial_graph = px.box(**function_args)
362
- fig = initial_graph.to_dict()
363
-
364
- _, layout_dict = llm_chart_data_scrub({}, layout)
365
- if layout_dict:
366
- fig["layout"] = layout_dict
367
-
368
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
369
- return {"reply": _write_chart(fig, chart_path, chart_url)}
370
-
371
- except Exception as e:
372
- print("BOX CHART ERROR")
373
- print(e)
374
- return {"reply": f"There was an error generating the box plot. Error: {e}. You should probably try again."}
375
-
376
-
377
- def correlation_heatmap_func(session_hash, session_folder, columns: List[str]=[], **kwargs):
378
- try:
379
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
380
- chart_path = f'{dir_path}/chart.html'
381
- csv_query_path = f'{dir_path}/query.csv'
382
-
383
- df = pd.read_csv(csv_query_path)
384
-
385
- numeric_df = df[columns].select_dtypes(include='number') if columns else df.select_dtypes(include='number')
386
-
387
- if numeric_df.shape[1] < 2:
388
- return {"reply": "At least two numeric columns are needed for a correlation matrix. Please refine your query to include more numeric columns."}
389
-
390
- corr = numeric_df.corr().round(3)
391
-
392
- fig = px.imshow(
393
- corr,
394
- text_auto='.2f',
395
- color_continuous_scale='RdBu_r',
396
- zmin=-1,
397
- zmax=1,
398
- title='Correlation Matrix',
399
- aspect='auto',
400
- )
401
- fig.update_layout(font=dict(family='Inter, system-ui, sans-serif'))
402
-
403
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
404
- return {"reply": _write_chart(fig, chart_path, chart_url)}
405
-
406
- except Exception as e:
407
- print("CORRELATION HEATMAP ERROR")
408
- print(e)
409
- return {"reply": f"There was an error generating the correlation heatmap. Error: {e}. You should probably try again."}
410
-
411
-
412
- def rolling_stats_func(x_column: str, y_column: str, session_hash, session_folder,
413
- window: int = 7, stats: List[str] = ["mean"],
414
- layout: List[dict] = [{}], category: str = "", **kwargs):
415
- try:
416
- import plotly.graph_objects as go
417
-
418
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
419
- chart_path = f'{dir_path}/chart.html'
420
- csv_query_path = f'{dir_path}/query.csv'
421
-
422
- df = pd.read_csv(csv_query_path)
423
-
424
- try:
425
- df[x_column] = pd.to_datetime(df[x_column])
426
- except Exception:
427
- pass
428
- df = df.sort_values(x_column)
429
-
430
- valid_stats = {"mean", "std", "min", "max"}
431
- selected_stats = [s for s in stats if s in valid_stats] or ["mean"]
432
-
433
- fig = go.Figure()
434
-
435
- groups = df[category].unique().tolist() if category and category in df.columns else [None]
436
-
437
- for group in groups:
438
- group_df = df[df[category] == group] if group is not None else df
439
- prefix = f"{group} — " if group is not None else ""
440
-
441
- fig.add_trace(go.Scatter(
442
- x=group_df[x_column].values, y=group_df[y_column].values,
443
- mode="lines", name=f"{prefix}{y_column} (raw)",
444
- opacity=0.35, line=dict(width=1)
445
- ))
446
-
447
- rolling_obj = group_df[y_column].rolling(window)
448
- for stat in selected_stats:
449
- rolled = getattr(rolling_obj, stat)()
450
- fig.add_trace(go.Scatter(
451
- x=group_df[x_column].values, y=rolled.values,
452
- mode="lines", name=f"{prefix}Rolling {stat.capitalize()} (w={window})",
453
- line=dict(width=2.5)
454
- ))
455
-
456
- fig.update_layout(
457
- title=f"Rolling Statistics (window={window}) — {y_column}",
458
- xaxis_title=x_column,
459
- yaxis_title=y_column,
460
- font=dict(family="Inter, system-ui, sans-serif"),
461
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
462
- )
463
-
464
- _, layout_dict = llm_chart_data_scrub({}, layout)
465
- if layout_dict:
466
- fig.update_layout(**layout_dict)
467
-
468
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
469
- return {"reply": _write_chart(fig, chart_path, chart_url)}
470
-
471
- except Exception as e:
472
- print("ROLLING STATS ERROR")
473
- print(e)
474
- return {"reply": f"There was an error generating the rolling statistics chart. Error: {e}. You should probably try again."}
475
-
476
-
477
  def table_generation_func(session_hash, session_folder, **kwargs):
478
  print("TABLE GENERATION")
479
- try:
480
- from html import escape
481
-
482
  dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
483
  csv_query_path = f'{dir_path}/query.csv'
 
484
 
485
  df = pd.read_csv(csv_query_path)
486
 
487
- total_rows = len(df)
488
- max_rows = 200
489
- if total_rows > max_rows:
490
- df = df.head(max_rows)
491
- note = (f'<p class="vda-table-note">Showing first {max_rows} of {total_rows} rows'
492
- ' — refine your query to see more specific results.</p>')
493
- else:
494
- note = ''
495
-
496
- header_cells = ''.join(f'<th>{escape(str(col))}</th>' for col in df.columns)
497
- row_html = [
498
- '<tr>' + ''.join(f'<td>{escape(str(val))}</td>' for val in row) + '</tr>'
499
- for _, row in df.iterrows()
500
- ]
501
-
502
- style = (
503
- '<style>'
504
- '.vda-table-wrap{overflow-x:auto;margin:8px 0;border-radius:8px;border:1px solid #e5e7eb;}'
505
- '.vda-table{width:100%;border-collapse:collapse;font-size:13px;font-family:Inter,system-ui,sans-serif;}'
506
- '.vda-table thead th{background:#3B82F6;color:#fff;padding:9px 14px;text-align:left;white-space:nowrap;font-weight:600;}'
507
- '.vda-table tbody td{padding:7px 14px;border-bottom:1px solid #f1f5f9;white-space:nowrap;}'
508
- '.vda-table tbody tr:nth-child(even){background:#f8fafc;}'
509
- '.vda-table tbody tr:last-child td{border-bottom:none;}'
510
- '.vda-table-note{font-size:12px;color:#6b7280;margin:4px 0 0;text-align:right;}'
511
- '</style>'
512
- )
513
-
514
- table = (
515
- '<div class="vda-table-wrap"><table class="vda-table">'
516
- f'<thead><tr>{header_cells}</tr></thead>'
517
- '<tbody>' + '\n'.join(row_html) + '</tbody>'
518
- '</table></div>'
519
- )
520
-
521
- return {"reply": style + table + note}
522
 
 
 
 
 
 
 
 
 
 
523
  except Exception as e:
524
- print("TABLE ERROR")
525
- print(e)
526
- return {"reply": f"There was an error generating the table. Error: {e}. You should probably try again."}
 
 
 
 
 
9
 
10
  load_dotenv()
11
 
12
+ root_url = os.getenv("ROOT_URL")
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def llm_chart_data_scrub(data, layout):
15
  #Processing data to account for variation from LLM
 
125
  for data_item in fig["data"]:
126
  data_item[key] = value
127
 
128
+ pio.write_html(fig, chart_path, full_html=False)
129
+
130
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
131
+
132
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
133
+
134
+ return {"reply": iframe}
135
 
136
  except Exception as e:
137
  print("SCATTER PLOT ERROR")
 
174
  for data_item in fig["data"]:
175
  data_item[key] = value
176
 
177
+ print(fig)
178
+
179
+ pio.write_html(fig, chart_path, full_html=False)
180
 
181
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
182
+
183
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
184
+
185
+ return {"reply": iframe}
186
 
187
  except Exception as e:
188
  print("LINE CHART ERROR")
 
229
  for data_item in fig["data"]:
230
  data_item[key] = value
231
 
232
+ print(fig)
233
+
234
+ pio.write_html(fig, chart_path, full_html=False)
235
 
236
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
237
+
238
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
239
+
240
+ return {"reply": iframe}
241
 
242
  except Exception as e:
243
  print("BAR CHART ERROR")
 
276
  for data_item in fig["data"]:
277
  data_item[key] = value
278
 
279
+ print(fig)
280
+
281
+ pio.write_html(fig, chart_path, full_html=False)
282
 
283
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
284
+
285
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
286
+
287
+ return {"reply": iframe}
288
 
289
  except Exception as e:
290
  print("PIE CHART ERROR")
 
335
  for data_item in fig["data"]:
336
  data_item[key] = value
337
 
338
+ print(fig)
339
+
340
+ pio.write_html(fig, chart_path, full_html=False)
341
 
342
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
343
+
344
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
345
+
346
+ return {"reply": iframe}
347
 
348
  except Exception as e:
349
  print("HISTOGRAM ERROR")
 
354
  """
355
  return {"reply": reply}
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  def table_generation_func(session_hash, session_folder, **kwargs):
358
  print("TABLE GENERATION")
359
+ try:
 
 
360
  dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
361
  csv_query_path = f'{dir_path}/query.csv'
362
+ table_path = f'{dir_path}/table.html'
363
 
364
  df = pd.read_csv(csv_query_path)
365
 
366
+ html_table = df.to_html()
367
+ print(html_table[:1000])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ with open(table_path, "w") as file:
370
+ file.write(html_table)
371
+
372
+ table_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/table.html'
373
+
374
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + table_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
375
+ print(iframe)
376
+ return {"reply": iframe}
377
+
378
  except Exception as e:
379
+ print("TABLE ERROR")
380
+ print(e)
381
+ reply = f"""There was an error generating the Pandas DataFrame table results.
382
+ The error is {e},
383
+ You should probably try again.
384
+ """
385
+ return {"reply": reply}
functions/chat_functions.py CHANGED
@@ -1,19 +1,9 @@
1
- from utils import message_dict, api_key_store, model_store
2
 
3
  from haystack.dataclasses import ChatMessage
4
  from haystack.components.generators.chat import OpenAIChatGenerator
5
- from haystack.utils import Secret
6
-
7
- def _get_generator(session_hash):
8
- api_key = api_key_store.get(session_hash)
9
- if not api_key:
10
- raise ValueError("No API key found for this session. Please enter your API key at the top of the page.")
11
- model = model_store.get(session_hash, "gpt-4o")
12
- if api_key.startswith("sk-ant-"):
13
- from haystack_integrations.components.generators.chat import AnthropicChatGenerator
14
- return AnthropicChatGenerator(model=model, api_key=Secret.from_token(api_key))
15
- return OpenAIChatGenerator(model=model, api_key=Secret.from_token(api_key))
16
 
 
17
  response = None
18
 
19
  def example_question_message(data_source, name, titles, schema):
@@ -23,15 +13,15 @@ def example_question_message(data_source, name, titles, schema):
23
  f"""We have a SQLite database with the following {titles}.
24
  We also have an AI agent with access to the same database that will be performing data analysis.
25
  Please return an array of seven strings, each one being a question for our data analysis agent
26
- that we can suggest that you believe will be insightful or helpful to a data analyst looking for
27
  data insights. Return nothing more than the array of questions because I need that specific data structure
28
  to process your response. No other response type or data structure will work."""],
29
 
30
- 'sql' : [f"You are a helpful and knowledgeable agent who has access to a PostgreSQL database called {name}.",
31
  f"""We have a PostgreSQL database with the following tables: {titles}.
32
  We also have an AI agent with access to the same database that will be performing data analysis.
33
  Please return an array of seven strings, each one being a question for our data analysis agent
34
- that we can suggest that you believe will be insightful or helpful to a data analyst looking for
35
  data insights. Return nothing more than the array of questions because I need that specific data structure
36
  to process your response. No other response type or data structure will work."""],
37
 
@@ -40,7 +30,7 @@ def example_question_message(data_source, name, titles, schema):
40
  The schema of these collections is: {schema}.
41
  We also have an AI agent with access to the same database that will be performing data analysis.
42
  Please return an array of seven strings, each one being a question for our data analysis agent
43
- that we can suggest that you believe will be insightful or helpful to a data analyst looking for
44
  data insights. Return nothing more than the array of questions because I need that specific data structure
45
  to process your response. No other response type or data structure will work."""],
46
 
@@ -48,7 +38,7 @@ def example_question_message(data_source, name, titles, schema):
48
  f"""We have a GraphQL API endpoint with the following types: {titles}.
49
  We also have an AI agent with access to the same GraphQL API endpoint that will be performing data analysis.
50
  Please return an array of seven strings, each one being a question for our data analysis agent
51
- that we can suggest that you believe will be insightful or helpful to a data analyst looking for
52
  data insights. Return nothing more than the array of questions because I need that specific data structure
53
  to process your response. No other response type or data structure will work."""]
54
 
@@ -67,84 +57,72 @@ def example_question_generator(session_hash, data_source, name, titles, schema):
67
 
68
  example_messages.append(ChatMessage.from_user(text=example_message_list[1]))
69
 
70
- example_response = _get_generator(session_hash).run(messages=example_messages)
71
 
72
- response_text = example_response["replies"][0].text
73
- start = response_text.index("[") + 1
74
- end = response_text.index("]")
75
- response_content = response_text[start:end]
76
- response_list = '[' + response_content + ']'
77
- print(response_list)
78
-
79
- return response_list
80
 
81
  def system_message(data_source, titles, schema=""):
82
- print("TITLES")
83
- print(titles)
84
-
85
- tools_desc = (
86
- " You have access to tools for querying the data source, generating charts and visualisations,"
87
- " and performing statistical analyses — use them proactively whenever they would help answer the user's question."
88
- " Always display any charts, tables, and visualisations inline in your responses by outputting the returned HTML verbatim."
89
- )
90
 
91
  system_message_dict = {
92
- 'file_upload': (
93
- f"You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source' that contains the following columns: {titles}."
94
- + tools_desc
95
- ),
96
- 'sql': (
97
- f"You are a helpful and knowledgeable agent who has access to a PostgreSQL database which has a series of tables called {titles}."
98
- + tools_desc
99
- ),
100
- 'doc_db': (
101
- f"You are a helpful and knowledgeable agent who has access to a NoSQL MongoDB Document database which has a series of collections called {titles}. "
102
- f"The schema of these collections is: {schema}."
103
- + tools_desc
104
- ),
105
- 'graphql': (
106
- f"You are a helpful and knowledgeable agent who has access to a GraphQL API which has the following types: {titles}. "
107
- "We have also saved a schema.json file that contains the entire introspection query that we can use to find out more about each type before making a query."
108
- + tools_desc
109
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  }
111
 
112
  return system_message_dict[data_source]
113
 
114
  def chatbot_func(message, history, session_hash, data_source, titles, schema, *args):
115
- try:
116
- chat_generator = _get_generator(session_hash)
117
- except ValueError as e:
118
- return str(e)
119
-
120
- from functions import (
121
- table_generation_func, regression_func, descriptive_stats_func,
122
- scatter_chart_generation_func, line_chart_generation_func, bar_chart_generation_func,
123
- pie_chart_generation_func, histogram_generation_func,
124
- box_chart_generation_func, correlation_heatmap_func, rolling_stats_func,
125
- query_func, graphql_schema_query, graphql_csv_query,
126
- kmeans_clustering_func, hypothesis_test_func,
127
- )
128
  import tools.tools as tools
129
 
130
- available_functions = {
131
- "query_func": query_func,
132
- "graphql_schema_query": graphql_schema_query,
133
- "graphql_csv_query": graphql_csv_query,
134
- "table_generation_func": table_generation_func,
135
- "scatter_chart_generation_func": scatter_chart_generation_func,
136
- "line_chart_generation_func": line_chart_generation_func,
137
- "bar_chart_generation_func": bar_chart_generation_func,
138
- "pie_chart_generation_func": pie_chart_generation_func,
139
- "histogram_generation_func": histogram_generation_func,
140
- "box_chart_generation_func": box_chart_generation_func,
141
- "correlation_heatmap_func": correlation_heatmap_func,
142
- "rolling_stats_func": rolling_stats_func,
143
- "regression_func": regression_func,
144
- "descriptive_stats_func": descriptive_stats_func,
145
- "kmeans_clustering_func": kmeans_clustering_func,
146
- "hypothesis_test_func": hypothesis_test_func,
147
- }
148
 
149
  if message_dict[session_hash][data_source] != None:
150
  message_dict[session_hash][data_source].append(ChatMessage.from_user(message))
@@ -155,11 +133,10 @@ def chatbot_func(message, history, session_hash, data_source, titles, schema, *a
155
  messages.append(ChatMessage.from_user(message))
156
  message_dict[session_hash][data_source] = messages
157
 
158
- active_tools = tools.tools_call(session_hash, data_source, titles)
159
- response = chat_generator.run(messages=message_dict[session_hash][data_source], tools=active_tools)
160
 
161
  while True:
162
- # if the response is a tool call
163
  if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
164
  function_calls = response["replies"][0].tool_calls
165
  for function_call in function_calls:
@@ -174,7 +151,7 @@ def chatbot_func(message, history, session_hash, data_source, titles, schema, *a
174
  print(function_name)
175
  ## Append function response to the messages list using `ChatMessage.from_tool`
176
  message_dict[session_hash][data_source].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
177
- response = chat_generator.run(messages=message_dict[session_hash][data_source], tools=active_tools)
178
 
179
  # Regular Conversation
180
  else:
 
1
+ from utils import message_dict
2
 
3
  from haystack.dataclasses import ChatMessage
4
  from haystack.components.generators.chat import OpenAIChatGenerator
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ chat_generator = OpenAIChatGenerator(model="gpt-4o")
7
  response = None
8
 
9
  def example_question_message(data_source, name, titles, schema):
 
13
  f"""We have a SQLite database with the following {titles}.
14
  We also have an AI agent with access to the same database that will be performing data analysis.
15
  Please return an array of seven strings, each one being a question for our data analysis agent
16
+ that we can suggest that you believe will be insightful or helpful to a data analysis looking for
17
  data insights. Return nothing more than the array of questions because I need that specific data structure
18
  to process your response. No other response type or data structure will work."""],
19
 
20
+ 'sql' : [f"You are a helpful and knowledgeable agent who has access to an MongoDB NoSQL document database called {name}.",
21
  f"""We have a PostgreSQL database with the following tables: {titles}.
22
  We also have an AI agent with access to the same database that will be performing data analysis.
23
  Please return an array of seven strings, each one being a question for our data analysis agent
24
+ that we can suggest that you believe will be insightful or helpful to a data analysis looking for
25
  data insights. Return nothing more than the array of questions because I need that specific data structure
26
  to process your response. No other response type or data structure will work."""],
27
 
 
30
  The schema of these collections is: {schema}.
31
  We also have an AI agent with access to the same database that will be performing data analysis.
32
  Please return an array of seven strings, each one being a question for our data analysis agent
33
+ that we can suggest that you believe will be insightful or helpful to a data analysis looking for
34
  data insights. Return nothing more than the array of questions because I need that specific data structure
35
  to process your response. No other response type or data structure will work."""],
36
 
 
38
  f"""We have a GraphQL API endpoint with the following types: {titles}.
39
  We also have an AI agent with access to the same GraphQL API endpoint that will be performing data analysis.
40
  Please return an array of seven strings, each one being a question for our data analysis agent
41
+ that we can suggest that you believe will be insightful or helpful to a data analysis looking for
42
  data insights. Return nothing more than the array of questions because I need that specific data structure
43
  to process your response. No other response type or data structure will work."""]
44
 
 
57
 
58
  example_messages.append(ChatMessage.from_user(text=example_message_list[1]))
59
 
60
+ example_response = chat_generator.run(messages=example_messages)
61
 
62
+ return example_response["replies"][0].text
 
 
 
 
 
 
 
63
 
64
  def system_message(data_source, titles, schema=""):
 
 
 
 
 
 
 
 
65
 
66
  system_message_dict = {
67
+ 'file_upload' : f"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source' that contains the following columns: {titles}.
68
+ You also have access to a function, called table_generation_func, that can take a query.csv file generated from our sql query and returns an iframe that we should display in our chat window.
69
+ You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we should display in our chat window.
70
+ You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a line chart and returns an iframe that we should display in our chat window.
71
+ You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we should display in our chat window.
72
+ You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we should display in our chat window.
73
+ You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we should display in our chat window.
74
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
75
+ Could you please always display the generated charts, tables, and visualizations as part of your output?""",
76
+
77
+ 'sql' : f"""You are a helpful and knowledgeable agent who has access to an PostgreSQL database which has a series of tables called {titles}.
78
+ You also have access to a function, called table_generation_func, that can take a query.csv file generated from our sql query and returns an iframe that we should display in our chat window.
79
+ You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we should display in our chat window.
80
+ You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a line chart and returns an iframe that we should display in our chat window.
81
+ You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a bar graph and returns an iframe that we should display in our chat window.
82
+ You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a pie chart and returns an iframe that we should display in our chat window.
83
+ You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate a histogram and returns an iframe that we should display in our chat window.
84
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
85
+ Could you please always display the generated charts, tables, and visualizations as part of your output?""",
86
+
87
+ 'doc_db' : f"""You are a helpful and knowledgeable agent who has access to a NoSQL MongoDB Document database which has a series of collections called {titles}.
88
+ The schema of these collections is: {schema}.
89
+ You also have access to a function, called table_generation_func, that can take a query.csv file generated from our MongoDB query and returns an iframe that we should display in our chat window.
90
+ You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we should display in our chat window.
91
+ You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a line chart and returns an iframe that we should display in our chat window.
92
+ You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a bar graph and returns an iframe that we should display in our chat window.
93
+ You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a pie chart and returns an iframe that we should display in our chat window.
94
+ You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a histogram and returns an iframe that we should display in our chat window.
95
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our MongoDB query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
96
+ Could you please always display the generated charts, tables, and visualizations as part of your output?""",
97
+
98
+ 'graphql' : f"""You are a helpful and knowledgeable agent who has access to a GraphQL API which has the following types: {titles}.
99
+ We have also saved a schema.json file that contains the entire introspection query that we can use to find out more about each type before making a query.
100
+ You also have access to a function, called table_generation_func, that can take a query.csv file generated from our GraphQL API query and returns an iframe that we should display in our chat window.
101
+ You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our GraphQL API query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we should display in our chat window.
102
+ You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our GraphQL API query and uses plotly dictionaries to generate a line chart and returns an iframe that we should display in our chat window.
103
+ You also have access to a bar graph function, called line_chart_generation_func, that can take a query.csv file generated from our GraphQL API query and uses plotly dictionaries to generate a bar graph and returns an iframe that we should display in our chat window.
104
+ You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our GraphQL API query and uses plotly dictionaries to generate a pie chart and returns an iframe that we should display in our chat window.
105
+ You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our GraphQL API query and uses plotly dictionaries to generate a histogram and returns an iframe that we should display in our chat window.
106
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our GraphQL API query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
107
+ Could you please always display the generated charts, tables, and visualizations as part of your output?"""
108
+
109
  }
110
 
111
  return system_message_dict[data_source]
112
 
113
  def chatbot_func(message, history, session_hash, data_source, titles, schema, *args):
114
+ from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
115
+ sql_query_func, doc_db_query_func, graphql_query_func, graphql_schema_query, graphql_csv_query, \
116
+ line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
 
 
 
 
 
 
 
 
 
 
117
  import tools.tools as tools
118
 
119
+ available_functions = {"sqlite_query_func": sqlite_query_func,"sql_query_func": sql_query_func,"doc_db_query_func": doc_db_query_func,
120
+ "graphql_query_func": graphql_query_func,"graphql_schema_query": graphql_schema_query,"graphql_csv_query": graphql_csv_query,
121
+ "table_generation_func":table_generation_func,
122
+ "line_chart_generation_func":line_chart_generation_func,"bar_chart_generation_func":bar_chart_generation_func,
123
+ "scatter_chart_generation_func":scatter_chart_generation_func, "pie_chart_generation_func":pie_chart_generation_func,
124
+ "histogram_generation_func":histogram_generation_func,
125
+ "regression_func":regression_func }
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  if message_dict[session_hash][data_source] != None:
128
  message_dict[session_hash][data_source].append(ChatMessage.from_user(message))
 
133
  messages.append(ChatMessage.from_user(message))
134
  message_dict[session_hash][data_source] = messages
135
 
136
+ response = chat_generator.run(messages=message_dict[session_hash][data_source], generation_kwargs={"tools": tools.tools_call(session_hash, data_source, titles)})
 
137
 
138
  while True:
139
+ # if OpenAI response is a tool call
140
  if response and response["replies"][0].meta["finish_reason"] == "tool_calls" or response["replies"][0].tool_calls:
141
  function_calls = response["replies"][0].tool_calls
142
  for function_call in function_calls:
 
151
  print(function_name)
152
  ## Append function response to the messages list using `ChatMessage.from_tool`
153
  message_dict[session_hash][data_source].append(ChatMessage.from_tool(tool_result=function_response['reply'], origin=function_call))
154
+ response = chat_generator.run(messages=message_dict[session_hash][data_source], generation_kwargs={"tools": tools.tools_call(session_hash, data_source, titles)})
155
 
156
  # Regular Conversation
157
  else:
functions/query_functions.py CHANGED
@@ -23,16 +23,36 @@ class SQLiteQuery:
23
  self.connection = sqlite3.connect(sql_database, check_same_thread=False)
24
 
25
  @component.output_types(results=List[str], queries=List[str])
26
- def run(self, queries: AnyStr, session_hash):
27
  print("ATTEMPTING TO RUN SQLITE QUERY")
28
  dir_path = TEMP_DIR / str(session_hash)
29
  results = []
30
- result = pd.read_sql(queries, self.connection)
31
- result.to_csv(f'{dir_path}/file_upload/query.csv', index=False)
32
- column_names = list(result.columns)
33
- results.append(f"{result}")
34
  self.connection.close()
35
- return {"results": results, "queries": queries, "csv_columns": column_names}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  @component
38
  class PostgreSQLQuery:
@@ -47,16 +67,39 @@ class PostgreSQLQuery:
47
  )
48
 
49
  @component.output_types(results=List[str], queries=List[str])
50
- def run(self, queries: AnyStr, session_hash):
51
  print("ATTEMPTING TO RUN POSTGRESQL QUERY")
52
  dir_path = TEMP_DIR / str(session_hash)
53
  results = []
54
- result = pd.read_sql_query(queries, self.connection)
55
- result.to_csv(f'{dir_path}/sql/query.csv', index=False)
56
- column_names = list(result.columns)
57
- results.append(f"{result}")
 
58
  self.connection.close()
59
- return {"results": results, "queries": queries, "csv_columns": column_names}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @component
62
  class DocDBQuery:
@@ -100,11 +143,31 @@ class DocDBQuery:
100
  docs = collection.aggregate_pandas_all(query_list)
101
  print("DATA FRAME COMPLETE")
102
  docs.to_csv(f'{dir_path}/doc_db/query.csv', index=False)
103
- column_names = list(docs.columns)
104
  print("CSV COMPLETE")
105
  results.append(f"{docs}")
106
  self.client.close()
107
- return {"results": results, "queries": aggregation_pipeline, "csv_columns": column_names}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  @component
110
  class GraphQLQuery:
@@ -137,40 +200,25 @@ class GraphQLQuery:
137
  #print(response_frame)
138
 
139
  response_frame.to_csv(f'{dir_path}/graphql/query.csv', index=False)
140
- column_names = list(response_frame.columns)
141
  print("CSV COMPLETE")
142
  results.append(f"{response_frame}")
143
- return {"results": results, "queries": graphql_query, "csv_columns": column_names}
144
 
145
- def query_func(queries:AnyStr, session_hash, session_folder, args, **kwargs):
 
 
 
146
  try:
147
- print("QUERY")
148
- print(queries)
149
- if session_folder == "file_upload":
150
- dir_path = TEMP_DIR / str(session_hash)
151
- sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
152
- result = sql_query.run(queries, session_hash)
153
- elif session_folder == "sql":
154
- sql_query = PostgreSQLQuery(args[0], args[1], args[2], args[3], args[4])
155
- result = sql_query.run(queries, session_hash)
156
- elif session_folder == 'doc_db':
157
- doc_db_query = DocDBQuery(args[0], args[1])
158
- result = doc_db_query.run(queries, kwargs['db_collection'], session_hash)
159
- elif session_folder == 'graphql':
160
- graphql_object = GraphQLQuery()
161
- result = graphql_object.run(queries, args[0], args[1], args[2], session_hash)
162
  print("RESULT")
163
- print(result["csv_columns"])
164
  if len(result["results"][0]) > 1000:
165
  print("QUERY TOO LARGE")
166
- return {"reply": f"""query result too large to be processed by llm, the query results are in our query.csv file.
167
- The column names of this query.csv file are: {result["csv_columns"]}.
168
- If you need to display the results directly, perhaps use the table_generation_func function."""}
169
  else:
170
  return {"reply": result["results"][0]}
171
 
172
  except Exception as e:
173
- reply = f"""There was an error running the {session_folder} Query = {queries}
174
  The error is {e},
175
  You should probably try again.
176
  """
@@ -206,19 +254,11 @@ def graphql_csv_query(csv_query: AnyStr, session_hash, **kwargs):
206
  query = pd.read_csv(f'{dir_path}/graphql/query.csv')
207
  query.Name = 'query'
208
  print("GRAPHQL CSV QUERY")
209
- print(csv_query)
210
  queried_df = sqldf(csv_query, locals())
211
  print(queried_df)
212
- column_names = list(queried_df.columns)
213
  queried_df.to_csv(f'{dir_path}/graphql/query.csv', index=False)
214
 
215
- if len(queried_df) > 1000:
216
- print("CSV QUERY TOO LARGE")
217
- return {"reply": f"""The new query results are in our query.csv file.
218
- The column names of this query.csv file are: {column_names}.
219
- If you need to display the results directly, perhaps use the table_generation_func function."""}
220
- else:
221
- return {"reply": str(queried_df)}
222
 
223
  except Exception as e:
224
  reply = f"""There was an error querying our query.csv file with the query:{csv_query}
@@ -226,4 +266,4 @@ def graphql_csv_query(csv_query: AnyStr, session_hash, **kwargs):
226
  You should probably try again.
227
  """
228
  print(reply)
229
- return {"reply": reply}
 
23
  self.connection = sqlite3.connect(sql_database, check_same_thread=False)
24
 
25
  @component.output_types(results=List[str], queries=List[str])
26
+ def run(self, queries: List[str], session_hash):
27
  print("ATTEMPTING TO RUN SQLITE QUERY")
28
  dir_path = TEMP_DIR / str(session_hash)
29
  results = []
30
+ for query in queries:
31
+ result = pd.read_sql(query, self.connection)
32
+ result.to_csv(f'{dir_path}/file_upload/query.csv', index=False)
33
+ results.append(f"{result}")
34
  self.connection.close()
35
+ return {"results": results, "queries": queries}
36
+
37
+
38
+
39
+ def sqlite_query_func(queries: List[str], session_hash, **kwargs):
40
+ dir_path = TEMP_DIR / str(session_hash)
41
+ sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
42
+ try:
43
+ result = sql_query.run(queries, session_hash)
44
+ if len(result["results"][0]) > 1000:
45
+ print("QUERY TOO LARGE")
46
+ return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
47
+ else:
48
+ return {"reply": result["results"][0]}
49
+
50
+ except Exception as e:
51
+ reply = f"""There was an error running the SQL Query = {queries}
52
+ The error is {e},
53
+ You should probably try again.
54
+ """
55
+ return {"reply": reply}
56
 
57
  @component
58
  class PostgreSQLQuery:
 
67
  )
68
 
69
  @component.output_types(results=List[str], queries=List[str])
70
+ def run(self, queries: List[str], session_hash):
71
  print("ATTEMPTING TO RUN POSTGRESQL QUERY")
72
  dir_path = TEMP_DIR / str(session_hash)
73
  results = []
74
+ for query in queries:
75
+ print(query)
76
+ result = pd.read_sql_query(query, self.connection)
77
+ result.to_csv(f'{dir_path}/sql/query.csv', index=False)
78
+ results.append(f"{result}")
79
  self.connection.close()
80
+ return {"results": results, "queries": queries}
81
+
82
+
83
+
84
+ def sql_query_func(queries: List[str], session_hash, args, **kwargs):
85
+ sql_query = PostgreSQLQuery(args[0], args[1], args[2], args[3], args[4])
86
+ try:
87
+ result = sql_query.run(queries, session_hash)
88
+ print("RESULT")
89
+ print(result)
90
+ if len(result["results"][0]) > 1000:
91
+ print("QUERY TOO LARGE")
92
+ return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
93
+ else:
94
+ return {"reply": result["results"][0]}
95
+
96
+ except Exception as e:
97
+ reply = f"""There was an error running the SQL Query = {queries}
98
+ The error is {e},
99
+ You should probably try again.
100
+ """
101
+ print(reply)
102
+ return {"reply": reply}
103
 
104
  @component
105
  class DocDBQuery:
 
143
  docs = collection.aggregate_pandas_all(query_list)
144
  print("DATA FRAME COMPLETE")
145
  docs.to_csv(f'{dir_path}/doc_db/query.csv', index=False)
 
146
  print("CSV COMPLETE")
147
  results.append(f"{docs}")
148
  self.client.close()
149
+ return {"results": results, "queries": aggregation_pipeline}
150
+
151
+
152
+
153
+ def doc_db_query_func(aggregation_pipeline: List[str], db_collection: AnyStr, session_hash, args, **kwargs):
154
+ doc_db_query = DocDBQuery(args[0], args[1])
155
+ try:
156
+ result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
157
+ print("RESULT")
158
+ if len(result["results"][0]) > 1000:
159
+ print("QUERY TOO LARGE")
160
+ return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
161
+ else:
162
+ return {"reply": result["results"][0]}
163
+
164
+ except Exception as e:
165
+ reply = f"""There was an error running the NoSQL (Mongo) Query = {aggregation_pipeline}
166
+ The error is {e},
167
+ You should probably try again.
168
+ """
169
+ print(reply)
170
+ return {"reply": reply}
171
 
172
  @component
173
  class GraphQLQuery:
 
200
  #print(response_frame)
201
 
202
  response_frame.to_csv(f'{dir_path}/graphql/query.csv', index=False)
 
203
  print("CSV COMPLETE")
204
  results.append(f"{response_frame}")
205
+ return {"results": results, "queries": graphql_query}
206
 
207
+
208
+
209
+ def graphql_query_func(graphql_query: AnyStr, session_hash, args, **kwargs):
210
+ graphql_object = GraphQLQuery()
211
  try:
212
+ result = graphql_object.run(graphql_query, args[0], args[1], args[2], session_hash)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  print("RESULT")
 
214
  if len(result["results"][0]) > 1000:
215
  print("QUERY TOO LARGE")
216
+ return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
 
 
217
  else:
218
  return {"reply": result["results"][0]}
219
 
220
  except Exception as e:
221
+ reply = f"""There was an error running the GraphQL Query = {graphql_query}
222
  The error is {e},
223
  You should probably try again.
224
  """
 
254
  query = pd.read_csv(f'{dir_path}/graphql/query.csv')
255
  query.Name = 'query'
256
  print("GRAPHQL CSV QUERY")
 
257
  queried_df = sqldf(csv_query, locals())
258
  print(queried_df)
 
259
  queried_df.to_csv(f'{dir_path}/graphql/query.csv', index=False)
260
 
261
+ return {"reply": "The new query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
 
 
 
 
 
 
262
 
263
  except Exception as e:
264
  reply = f"""There was an error querying our query.csv file with the query:{csv_query}
 
266
  You should probably try again.
267
  """
268
  print(reply)
269
+ return {"reply": reply}
functions/stat_functions.py CHANGED
@@ -5,244 +5,12 @@ from utils import TEMP_DIR
5
  import plotly.express as px
6
  import plotly.io as pio
7
  import os
8
- from functions.chart_functions import scatter_chart_fig, llm_chart_data_scrub, _write_chart
9
  from dotenv import load_dotenv
10
 
11
  load_dotenv()
12
 
13
- root_url = os.getenv("ROOT_URL", "")
14
-
15
- def descriptive_stats_func(session_hash, session_folder, columns: List[str]=[], **kwargs):
16
- print("DESCRIPTIVE STATISTICS")
17
- try:
18
- from html import escape
19
-
20
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
21
- csv_query_path = f'{dir_path}/query.csv'
22
-
23
- df = pd.read_csv(csv_query_path)
24
-
25
- if columns:
26
- df = df[[c for c in columns if c in df.columns]]
27
-
28
- desc = df.describe().round(4)
29
-
30
- header_cells = '<th style="background:#1e40af;">Statistic</th>' + ''.join(
31
- f'<th>{escape(str(col))}</th>' for col in desc.columns
32
- )
33
- row_html = [
34
- '<tr>'
35
- + f'<td style="font-weight:600;color:#1e40af;background:#eff6ff;white-space:nowrap;">{escape(str(idx))}</td>'
36
- + ''.join(f'<td>{escape(str(val))}</td>' for val in row)
37
- + '</tr>'
38
- for idx, row in desc.iterrows()
39
- ]
40
-
41
- style = (
42
- '<style>'
43
- '.vda-table-wrap{overflow-x:auto;margin:8px 0;border-radius:8px;border:1px solid #e5e7eb;}'
44
- '.vda-table{width:100%;border-collapse:collapse;font-size:13px;font-family:Inter,system-ui,sans-serif;}'
45
- '.vda-table thead th{background:#3B82F6;color:#fff;padding:9px 14px;text-align:left;white-space:nowrap;font-weight:600;}'
46
- '.vda-table tbody td{padding:7px 14px;border-bottom:1px solid #f1f5f9;white-space:nowrap;}'
47
- '.vda-table tbody tr:nth-child(even){background:#f8fafc;}'
48
- '.vda-table tbody tr:last-child td{border-bottom:none;}'
49
- '</style>'
50
- )
51
- table = (
52
- '<div class="vda-table-wrap"><table class="vda-table">'
53
- f'<thead><tr>{header_cells}</tr></thead>'
54
- '<tbody>' + '\n'.join(row_html) + '</tbody>'
55
- '</table></div>'
56
- )
57
-
58
- return {"reply": style + table}
59
-
60
- except Exception as e:
61
- print("DESCRIPTIVE STATS ERROR")
62
- print(e)
63
- return {"reply": f"There was an error generating descriptive statistics. Error: {e}. You should probably try again."}
64
-
65
-
66
- def kmeans_clustering_func(feature_columns: List[str], x_column: str, y_column: str,
67
- session_hash, session_folder, n_clusters: int = 3,
68
- layout: List[dict] = [{}], **kwargs):
69
- print("KMEANS CLUSTERING")
70
- try:
71
- from sklearn.cluster import KMeans
72
- from sklearn.preprocessing import StandardScaler
73
- from html import escape
74
-
75
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
76
- chart_path = f'{dir_path}/chart.html'
77
- csv_query_path = f'{dir_path}/query.csv'
78
-
79
- df = pd.read_csv(csv_query_path)
80
-
81
- feature_df = df[feature_columns].select_dtypes(include='number').dropna()
82
- if feature_df.shape[1] < 1:
83
- return {"reply": "No numeric feature columns found for clustering. Please refine your query to include numeric columns."}
84
-
85
- X_scaled = StandardScaler().fit_transform(feature_df)
86
- labels = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(X_scaled)
87
-
88
- df_clustered = df.loc[feature_df.index].copy()
89
- df_clustered['Cluster'] = [f'Cluster {l}' for l in labels]
90
-
91
- fig = px.scatter(
92
- df_clustered, x=x_column, y=y_column, color='Cluster',
93
- title=f'K-Means Clustering (k={n_clusters})',
94
- )
95
- fig.update_layout(font=dict(family='Inter, system-ui, sans-serif'))
96
-
97
- _, layout_dict = llm_chart_data_scrub({}, layout)
98
- if layout_dict:
99
- fig.update_layout(**layout_dict)
100
-
101
- chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
102
- iframe = _write_chart(fig, chart_path, chart_url)
103
-
104
- cluster_summary = df_clustered.groupby('Cluster')[feature_columns].mean().round(3)
105
- header_cells = '<th style="background:#1e40af;">Cluster</th>' + ''.join(
106
- f'<th>{escape(str(col))}</th>' for col in cluster_summary.columns
107
- )
108
- row_html = [
109
- '<tr>'
110
- + f'<td style="font-weight:600;color:#1e40af;background:#eff6ff;white-space:nowrap;">{escape(str(idx))}</td>'
111
- + ''.join(f'<td>{escape(str(val))}</td>' for val in row)
112
- + '</tr>'
113
- for idx, row in cluster_summary.iterrows()
114
- ]
115
- style = (
116
- '<style>'
117
- '.vda-table-wrap{overflow-x:auto;margin:8px 0;border-radius:8px;border:1px solid #e5e7eb;}'
118
- '.vda-table{width:100%;border-collapse:collapse;font-size:13px;font-family:Inter,system-ui,sans-serif;}'
119
- '.vda-table thead th{background:#3B82F6;color:#fff;padding:9px 14px;text-align:left;white-space:nowrap;font-weight:600;}'
120
- '.vda-table tbody td{padding:7px 14px;border-bottom:1px solid #f1f5f9;white-space:nowrap;}'
121
- '.vda-table tbody tr:nth-child(even){background:#f8fafc;}'
122
- '.vda-table tbody tr:last-child td{border-bottom:none;}'
123
- '</style>'
124
- )
125
- summary_table = (
126
- '<div class="vda-table-wrap"><table class="vda-table">'
127
- f'<thead><tr>{header_cells}</tr></thead>'
128
- '<tbody>' + '\n'.join(row_html) + '</tbody>'
129
- '</table></div>'
130
- )
131
-
132
- return {"reply": f'{iframe}\n\n**Cluster Centroids (feature means per cluster):**\n{style}{summary_table}'}
133
-
134
- except Exception as e:
135
- print("KMEANS CLUSTERING ERROR")
136
- print(e)
137
- return {"reply": f"There was an error running K-Means clustering. Error: {e}. You should probably try again."}
138
-
139
-
140
- def hypothesis_test_func(test_type: str, column: str, session_hash, session_folder,
141
- column2: str = "", group_column: str = "",
142
- group_values: List[str] = [], pop_mean: float = 0.0, **kwargs):
143
- print("HYPOTHESIS TEST")
144
- try:
145
- from scipy import stats
146
- from html import escape
147
-
148
- dir_path = TEMP_DIR / str(session_hash) / str(session_folder)
149
- csv_query_path = f'{dir_path}/query.csv'
150
- df = pd.read_csv(csv_query_path)
151
-
152
- if test_type == "t_test_independent":
153
- if not group_column or group_column not in df.columns:
154
- return {"reply": "Please specify a valid group_column for the independent t-test."}
155
- unique_groups = df[group_column].dropna().unique().tolist()
156
- if group_values and len(group_values) == 2:
157
- g1_label, g2_label = group_values[0], group_values[1]
158
- elif len(unique_groups) == 2:
159
- g1_label, g2_label = unique_groups[0], unique_groups[1]
160
- else:
161
- return {"reply": f"For an independent t-test, exactly 2 groups are needed. Found: {unique_groups}. Specify group_values with 2 entries."}
162
-
163
- g1 = df[df[group_column] == g1_label][column].dropna()
164
- g2 = df[df[group_column] == g2_label][column].dropna()
165
- t_stat, p_value = stats.ttest_ind(g1, g2)
166
-
167
- result_rows = [
168
- ("Test", "Independent Samples T-Test"),
169
- ("Column", column),
170
- ("Group Column", group_column),
171
- (f"Group 1", str(g1_label)),
172
- (f"Group 2", str(g2_label)),
173
- (f"Group 1 Mean (n={len(g1)})", f"{g1.mean():.4f}"),
174
- (f"Group 2 Mean (n={len(g2)})", f"{g2.mean():.4f}"),
175
- ("T-Statistic", f"{t_stat:.4f}"),
176
- ("P-Value", f"{p_value:.6f}"),
177
- ("Significant at α=0.05", "Yes ✓" if p_value < 0.05 else "No ✗"),
178
- ]
179
- title = f"T-Test: {column} by {group_column}"
180
-
181
- elif test_type == "t_test_one_sample":
182
- sample = df[column].dropna()
183
- t_stat, p_value = stats.ttest_1samp(sample, pop_mean)
184
- result_rows = [
185
- ("Test", "One-Sample T-Test"),
186
- ("Column", column),
187
- ("Hypothesized Mean (μ₀)", f"{pop_mean:.4f}"),
188
- (f"Sample Mean (n={len(sample)})", f"{sample.mean():.4f}"),
189
- ("Sample Std Dev", f"{sample.std():.4f}"),
190
- ("T-Statistic", f"{t_stat:.4f}"),
191
- ("P-Value", f"{p_value:.6f}"),
192
- ("Significant at α=0.05", "Yes ✓" if p_value < 0.05 else "No ✗"),
193
- ]
194
- title = f"One-Sample T-Test: {column} vs μ={pop_mean}"
195
-
196
- elif test_type == "chi_square":
197
- if not column2 or column2 not in df.columns:
198
- return {"reply": "Please specify a valid column2 for the chi-square test."}
199
- contingency = pd.crosstab(df[column], df[column2])
200
- chi2, p_value, dof, _ = stats.chi2_contingency(contingency)
201
- result_rows = [
202
- ("Test", "Chi-Square Test of Independence"),
203
- ("Column 1", column),
204
- ("Column 2", column2),
205
- ("Chi-Square Statistic", f"{chi2:.4f}"),
206
- ("Degrees of Freedom", str(dof)),
207
- ("P-Value", f"{p_value:.6f}"),
208
- ("Significant at α=0.05", "Yes ✓" if p_value < 0.05 else "No ✗"),
209
- ]
210
- title = f"Chi-Square: {column} × {column2}"
211
-
212
- else:
213
- return {"reply": f"Unknown test_type '{test_type}'. Use one of: t_test_independent, t_test_one_sample, chi_square."}
214
-
215
- style = (
216
- '<style>'
217
- '.vda-table-wrap{overflow-x:auto;margin:8px 0;border-radius:8px;border:1px solid #e5e7eb;}'
218
- '.vda-table{width:100%;border-collapse:collapse;font-size:13px;font-family:Inter,system-ui,sans-serif;}'
219
- '.vda-table thead th{background:#3B82F6;color:#fff;padding:9px 14px;text-align:left;white-space:nowrap;font-weight:600;}'
220
- '.vda-table tbody td{padding:7px 14px;border-bottom:1px solid #f1f5f9;white-space:nowrap;}'
221
- '.vda-table tbody tr:nth-child(even){background:#f8fafc;}'
222
- '.vda-table tbody tr:last-child td{border-bottom:none;}'
223
- '</style>'
224
- )
225
- header_cells = f'<th style="background:#1e40af;" colspan="2">{escape(title)}</th>'
226
- row_html = [
227
- '<tr>'
228
- + f'<td style="font-weight:600;color:#1e40af;background:#eff6ff;white-space:nowrap;">{escape(label)}</td>'
229
- + f'<td>{escape(value)}</td>'
230
- + '</tr>'
231
- for label, value in result_rows
232
- ]
233
- table = (
234
- '<div class="vda-table-wrap"><table class="vda-table">'
235
- f'<thead><tr>{header_cells}</tr></thead>'
236
- '<tbody>' + '\n'.join(row_html) + '</tbody>'
237
- '</table></div>'
238
- )
239
- return {"reply": style + table}
240
-
241
- except Exception as e:
242
- print("HYPOTHESIS TEST ERROR")
243
- print(e)
244
- return {"reply": f"There was an error running the hypothesis test. Error: {e}. You should probably try again."}
245
-
246
 
247
  def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, session_folder, category: str='', **kwargs):
248
  print("LINEAR REGRESSION CALCULATION")
@@ -262,8 +30,11 @@ def regression_func(independent_variables: List[str], dependent_variable: str, s
262
  fig = scatter_chart_fig(df=df,x_column=independent_variables,y_column=dependent_variable,
263
  trendline="ols")
264
 
 
 
265
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
266
- iframe = _write_chart(fig, chart_path, chart_url)
 
267
 
268
  results_frame = px.get_trendline_results(fig)
269
 
 
5
  import plotly.express as px
6
  import plotly.io as pio
7
  import os
8
+ from functions import scatter_chart_fig
9
  from dotenv import load_dotenv
10
 
11
  load_dotenv()
12
 
13
+ root_url = os.getenv("ROOT_URL")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, session_folder, category: str='', **kwargs):
16
  print("LINEAR REGRESSION CALCULATION")
 
30
  fig = scatter_chart_fig(df=df,x_column=independent_variables,y_column=dependent_variable,
31
  trendline="ols")
32
 
33
+ pio.write_html(fig, chart_path, full_html=False)
34
+
35
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/{session_folder}/chart.html'
36
+
37
+ iframe = 'Please display this iframe: <div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
38
 
39
  results_frame = px.get_trendline_results(fig)
40
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
- haystack-ai>=2.7.0
2
- anthropic-haystack
3
  python-dotenv
4
  gradio
5
  pandas
@@ -13,6 +12,3 @@ pymongoarrow
13
  pymongo_schema
14
  pandasql
15
  pluck-graphql
16
- certifi==2025.1.31
17
- scipy
18
- scikit-learn
 
1
+ haystack-ai
 
2
  python-dotenv
3
  gradio
4
  pandas
 
12
  pymongo_schema
13
  pandasql
14
  pluck-graphql
 
 
 
samples/online_retail_data.csv CHANGED
The diff for this file is too large to render. See raw diff
 
temp/.gitignore DELETED
@@ -1,2 +0,0 @@
1
- *
2
- !.gitignore
 
 
 
templates/data_file.py CHANGED
@@ -1,286 +1,136 @@
1
- import gradio as gr
2
- from functions import example_question_generator, chatbot_func
3
- from data_sources import process_data_upload
4
- from utils import message_dict
5
- import ast
6
- import html as _html
7
-
8
- def build_summary_modal(stats):
9
- num_rows = stats['num_rows']
10
- num_cols = stats['num_cols']
11
- total_missing = stats['total_missing']
12
- duplicate_rows = stats.get('duplicate_rows', 0)
13
- file_size_bytes = stats.get('file_size_bytes', 0)
14
-
15
- def _fmt_num(v):
16
- try:
17
- if v != v: return '—' # NaN
18
- abs_v = abs(v)
19
- if abs_v >= 1e9: return f"{v/1e9:.1f}B"
20
- if abs_v >= 1e6: return f"{v/1e6:.1f}M"
21
- if abs_v >= 1e3: return f"{v:,.0f}" if v == int(v) else f"{v:,.1f}"
22
- return f"{v:,.0f}" if v == int(v) else f"{v:.2f}"
23
- except Exception:
24
- return str(v)
25
-
26
- def _fmt_size(b):
27
- if not b: return ''
28
- if b < 1024: return f"{b} B"
29
- if b < 1024 ** 2: return f"{b / 1024:.1f} KB"
30
- if b < 1024 ** 3: return f"{b / 1024 ** 2:.1f} MB"
31
- return f"{b / 1024 ** 3:.2f} GB"
32
-
33
- file_size_label = _fmt_size(file_size_bytes)
34
- dup_color = "#ef4444" if duplicate_rows > 0 else "#a16207"
35
- dup_bg = "#fef2f2" if duplicate_rows > 0 else "#fefce8"
36
- dup_border = "#fecaca" if duplicate_rows > 0 else "#fde68a"
37
-
38
- dtype_rows_html = ""
39
- for i, (col, dtype) in enumerate(stats['dtypes'].items()):
40
- bg = "#ffffff" if i % 2 == 0 else "#f9fafb"
41
- missing = stats['missing_per_col'].get(col, 0)
42
- pct_missing = (missing / num_rows * 100) if num_rows > 0 else 0
43
- missing_color = "#ef4444" if missing > 0 else "#9ca3af"
44
- missing_weight = "600" if missing > 0 else "400"
45
- missing_cell = f'{missing:,} <span style="color:#9ca3af;font-size:0.7rem;">({pct_missing:.1f}%)</span>'
46
-
47
- unique = stats.get('unique_counts', {}).get(col, '—')
48
- is_id = isinstance(unique, int) and num_rows > 0 and (unique / num_rows) >= 0.95 and unique > 10
49
- id_badge = ' <span style="background:#fef3c7;color:#92400e;padding:1px 5px;border-radius:3px;font-size:0.65rem;vertical-align:middle;">ID?</span>' if is_id else ''
50
- unique_cell = f'{unique:,}{id_badge}' if isinstance(unique, int) else str(unique)
51
-
52
- cs = stats.get('col_stats', {}).get(col, {})
53
- if cs.get('type') == 'numeric':
54
- stats_cell = (
55
- f'<span style="font-size:0.74rem;color:#6b7280;line-height:1.6;">'
56
- f'{_fmt_num(cs["min"])} {_fmt_num(cs["max"])}'
57
- f'<br><span style="color:#9ca3af;">avg {_fmt_num(cs["mean"])}</span></span>'
58
- )
59
- elif cs.get('type') == 'datetime':
60
- stats_cell = (
61
- f'<span style="font-size:0.74rem;color:#6b7280;line-height:1.6;">'
62
- f'{cs["min"]}<br>→ {cs["max"]}</span>'
63
- )
64
- else:
65
- stats_cell = '<span style="color:#d1d5db;">—</span>'
66
-
67
- dtype_rows_html += (
68
- f'<tr style="background:{bg}">'
69
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;color:#111827;white-space:nowrap;">{_html.escape(col)}</td>'
70
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;white-space:nowrap;"><span style="background:#dbeafe;color:#1e40af;padding:2px 8px;border-radius:4px;font-size:0.74rem;">{dtype}</span></td>'
71
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;text-align:right;color:{missing_color};font-weight:{missing_weight};white-space:nowrap;">{missing_cell}</td>'
72
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;text-align:right;white-space:nowrap;color:#374151;">{unique_cell}</td>'
73
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;">{stats_cell}</td>'
74
- f'</tr>'
75
- )
76
-
77
- preview_headers_html = "".join(
78
- f'<th style="padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;text-align:left;">{_html.escape(col)}</th>'
79
- for col in stats['preview_cols']
80
- )
81
-
82
- preview_rows_html = ""
83
- for i, row in enumerate(stats['preview']):
84
- bg = "#ffffff" if i % 2 == 0 else "#f9fafb"
85
- cells = "".join(
86
- f'<td style="padding:7px 12px;border-bottom:1px solid #f3f4f6;color:#374151;white-space:nowrap;">{_html.escape(str(cell))}</td>'
87
- for cell in row
88
- )
89
- preview_rows_html += f'<tr style="background:{bg}">{cells}</tr>'
90
-
91
- size_tag = f'<span style="background:rgba(255,255,255,0.2);color:#fff;padding:2px 10px;border-radius:12px;font-size:0.75rem;font-weight:400;">{file_size_label}</span>' if file_size_label else ''
92
-
93
- return f"""
94
- <div class="vda-modal-overlay" style="position:fixed;inset:0;background:rgba(0,0,0,0.55);z-index:9999;display:flex;align-items:center;justify-content:center;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',sans-serif;">
95
- <div style="background:#fff;border-radius:14px;width:90%;max-width:800px;max-height:88vh;display:flex;flex-direction:column;box-shadow:0 25px 50px -12px rgba(0,0,0,0.35);overflow:hidden;">
96
- <div style="background:linear-gradient(135deg,#3B82F6,#0ea5e9);padding:16px 20px;display:flex;justify-content:space-between;align-items:center;flex-shrink:0;gap:12px;">
97
- <div style="display:flex;align-items:center;gap:10px;">
98
- <span style="color:#fff;font-weight:600;font-size:1rem;">Dataset Summary</span>
99
- {size_tag}
100
- </div>
101
- <button onclick="document.querySelectorAll('.vda-modal-overlay').forEach(function(e){{e.remove()}})" style="background:rgba(255,255,255,0.2);border:none;color:#fff;width:30px;height:30px;border-radius:50%;cursor:pointer;font-size:1rem;line-height:1;flex-shrink:0;">&#x2715;</button>
102
- </div>
103
- <div style="padding:20px;overflow-y:auto;flex:1;">
104
- <div style="display:grid;grid-template-columns:1fr 1fr 1fr 1fr;gap:10px;margin-bottom:20px;">
105
- <div style="background:#eff6ff;border:1px solid #bfdbfe;border-radius:8px;padding:12px;text-align:center;">
106
- <div style="font-size:1.4rem;font-weight:700;color:#1d4ed8;">{num_rows:,}</div>
107
- <div style="font-size:0.7rem;color:#64748b;text-transform:uppercase;letter-spacing:0.06em;margin-top:4px;">Rows</div>
108
- </div>
109
- <div style="background:#f0fdf4;border:1px solid #bbf7d0;border-radius:8px;padding:12px;text-align:center;">
110
- <div style="font-size:1.4rem;font-weight:700;color:#15803d;">{num_cols}</div>
111
- <div style="font-size:0.7rem;color:#64748b;text-transform:uppercase;letter-spacing:0.06em;margin-top:4px;">Columns</div>
112
- </div>
113
- <div style="background:#fefce8;border:1px solid #fde68a;border-radius:8px;padding:12px;text-align:center;">
114
- <div style="font-size:1.4rem;font-weight:700;color:#a16207;">{total_missing:,}</div>
115
- <div style="font-size:0.7rem;color:#64748b;text-transform:uppercase;letter-spacing:0.06em;margin-top:4px;">Missing Values</div>
116
- </div>
117
- <div style="background:{dup_bg};border:1px solid {dup_border};border-radius:8px;padding:12px;text-align:center;">
118
- <div style="font-size:1.4rem;font-weight:700;color:{dup_color};">{duplicate_rows:,}</div>
119
- <div style="font-size:0.7rem;color:#64748b;text-transform:uppercase;letter-spacing:0.06em;margin-top:4px;">Duplicate Rows</div>
120
- </div>
121
- </div>
122
- <div style="margin-bottom:20px;">
123
- <div style="font-size:0.78rem;font-weight:600;color:#374151;text-transform:uppercase;letter-spacing:0.06em;margin-bottom:10px;">Column Info</div>
124
- <div style="border:1px solid #e5e7eb;border-radius:8px;overflow:hidden;">
125
- <div style="max-height:210px;overflow:auto;">
126
- <table style="border-collapse:collapse;font-size:0.83rem;min-width:100%;">
127
- <thead style="background:#f9fafb;position:sticky;top:0;z-index:1;">
128
- <tr>
129
- <th style="text-align:left;padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;">Column</th>
130
- <th style="text-align:left;padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;">Type</th>
131
- <th style="text-align:right;padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;">Missing</th>
132
- <th style="text-align:right;padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;">Unique</th>
133
- <th style="text-align:left;padding:8px 12px;color:#6b7280;font-weight:500;border-bottom:1px solid #e5e7eb;white-space:nowrap;">Stats / Range</th>
134
- </tr>
135
- </thead>
136
- <tbody>{dtype_rows_html}</tbody>
137
- </table>
138
- </div>
139
- </div>
140
- </div>
141
- <div>
142
- <div style="font-size:0.78rem;font-weight:600;color:#374151;text-transform:uppercase;letter-spacing:0.06em;margin-bottom:10px;">Data Preview (first 5 rows)</div>
143
- <div style="border:1px solid #e5e7eb;border-radius:8px;overflow:hidden;">
144
- <div style="overflow:auto;max-height:200px;">
145
- <table style="border-collapse:collapse;font-size:0.8rem;">
146
- <thead style="background:#f9fafb;position:sticky;top:0;z-index:1;">
147
- <tr>{preview_headers_html}</tr>
148
- </thead>
149
- <tbody>{preview_rows_html}</tbody>
150
- </table>
151
- </div>
152
- </div>
153
- </div>
154
- </div>
155
- </div>
156
- </div>
157
- """
158
-
159
- def run_example(input):
160
- return input
161
-
162
- def example_display(input):
163
- if input == None:
164
- display = True
165
- else:
166
- display = False
167
- return [gr.update(visible=display), gr.update(visible=display), gr.update(visible=display), gr.update(visible=display)]
168
-
169
- with gr.Blocks() as demo:
170
- description = gr.HTML("""
171
- <div class="max-w-4xl mx-auto mb-12 text-center">
172
- <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
173
- <h2 class="font-semibold text-blue-800 ">
174
- <i class="fas fa-info-circle mr-2"></i>Supported Files
175
- </h2>
176
- <div class="flex flex-wrap justify-center gap-3 pb-4 text-blue-700">
177
- <span class="tooltip">
178
- <i class="fas fa-file-csv mr-1"></i>CSV
179
- <span class="tooltip-text">Comma-separated values</span>
180
- </span>
181
- <span class="tooltip">
182
- <i class="fas fa-file-alt mr-1"></i>TSV
183
- <span class="tooltip-text">Tab-separated values</span>
184
- </span>
185
- <span class="tooltip">
186
- <i class="fas fa-file-alt mr-1"></i>TXT
187
- <span class="tooltip-text">Text files</span>
188
- </span>
189
- <span class="tooltip">
190
- <i class="fas fa-file-excel mr-1"></i>XLS/XLSX
191
- <span class="tooltip-text">Excel spreadsheets</span>
192
- </span>
193
- <span class="tooltip">
194
- <i class="fas fa-file-code mr-1"></i>XML
195
- <span class="tooltip-text">XML documents</span>
196
- </span>
197
- <span class="tooltip">
198
- <i class="fas fa-file-code mr-1"></i>JSON
199
- <span class="tooltip-text">JSON data files</span>
200
- </span>
201
- </div>
202
- </div>
203
- </div>
204
- """, elem_classes="description_component")
205
- example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
206
- example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
207
- example_file_3 = gr.File(visible=False, value="samples/tb_illness_data.csv")
208
- with gr.Row():
209
- example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="sample-btn bg-gradient-to-r from-blue-500 to-sky-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
210
- example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="sample-btn bg-gradient-to-r from-blue-500 to-sky-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
211
- example_btn_3 = gr.Button(value="Try Me: tb_illness_data.csv", elem_classes="sample-btn bg-gradient-to-r from-blue-500 to-sky-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
212
-
213
- file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker drop-zone border-2 border-dashed border-gray-300 rounded-lg hover:border-primary cursor-pointer bg-gray-50 hover:bg-blue-50 transition-colors duration-300", file_types=['.csv', '.xlsx', '.txt', '.json', '.ndjson', '.xml', '.xls', '.tsv'])
214
- example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
215
- example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
216
- example_btn_3.click(fn=run_example, inputs=example_file_3, outputs=file_output)
217
- file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2, example_btn_3, description])
218
-
219
- @gr.render(inputs=file_output)
220
- def data_options(filename, request: gr.Request):
221
- print(filename)
222
- if request.session_hash not in message_dict:
223
- message_dict[request.session_hash] = {}
224
- message_dict[request.session_hash]['file_upload'] = None
225
- if filename:
226
- process_message = process_upload(filename, request.session_hash)
227
- gr.HTML(value=process_message[1], padding=False)
228
- if process_message[0] == "success":
229
- gr.HTML(value=build_summary_modal(process_message[3]), padding=False)
230
- if "bank_marketing_campaign" in filename:
231
- example_questions = [
232
- ["Describe the dataset"],
233
- ["What levels of education have the highest and lowest average balance?"],
234
- ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
235
- ["Can you generate a bar chart of education vs. average balance?"],
236
- ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"],
237
- ["Can we predict the relationship between the number of contacts performed before this campaign and the average balance?"],
238
- ["Can you plot the number of contacts performed before this campaign versus the duration and use balance as the size in a bubble chart?"]
239
- ]
240
- elif "online_retail_data" in filename:
241
- example_questions = [
242
- ["Describe the dataset"],
243
- ["What month had the highest revenue?"],
244
- ["Is revenue higher in the morning or afternoon?"],
245
- ["Can you generate a line graph of revenue per month?"],
246
- ["Can you generate a table of revenue per month?"],
247
- ["Can we predict how time of day affects transaction value in this data set?"],
248
- ["Can you plot revenue per month with size being the number of units sold that month in a bubble chart?"]
249
- ]
250
- else:
251
- try:
252
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'file_upload', '', process_message[2], ''))
253
- example_questions = [["Describe the dataset"]]
254
- for example in generated_examples:
255
- example_questions.append([example])
256
- except Exception as e:
257
- print("DATA FILE QUESTION GENERATION ERROR")
258
- print(e)
259
- example_questions = [
260
- ["Describe the dataset"],
261
- ["List the columns in the dataset"],
262
- ["What could this data be used for?"],
263
- ]
264
- session_hash = gr.Textbox(visible=False, value=request.session_hash)
265
- data_source = gr.Textbox(visible=False, value='file_upload')
266
- schema = gr.Textbox(visible=False, value='')
267
- titles = gr.Textbox(value=process_message[2], interactive=False, visible=False)
268
- bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
269
- chat = gr.ChatInterface(
270
- fn=chatbot_func,
271
- type='messages',
272
- chatbot=bot,
273
- title="Chat with your data file",
274
- concurrency_limit=None,
275
- examples=example_questions,
276
- additional_inputs=[session_hash, data_source, titles, schema]
277
- )
278
-
279
- def process_upload(upload_value, session_hash):
280
- if upload_value:
281
- process_message = process_data_upload(upload_value, session_hash)
282
- return process_message
283
-
284
-
285
- if __name__ == "__main__":
286
- demo.launch()
 
1
+ import gradio as gr
2
+ from functions import example_question_generator, chatbot_func
3
+ from data_sources import process_data_upload
4
+ from utils import message_dict
5
+ import ast
6
+
7
+ def run_example(input):
8
+ return input
9
+
10
+ def example_display(input):
11
+ if input == None:
12
+ display = True
13
+ else:
14
+ display = False
15
+ return [gr.update(visible=display),gr.update(visible=display),gr.update(visible=display),gr.update(visible=display)]
16
+
17
+ with gr.Blocks() as demo:
18
+ description = gr.HTML("""
19
+ <!-- Header -->
20
+ <div class="max-w-4xl mx-auto mb-12 text-center">
21
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
22
+ <h2 class="font-semibold text-blue-800 ">
23
+ <i class="fas fa-info-circle mr-2"></i>Supported Files
24
+ </h2>
25
+ <div class="flex flex-wrap justify-center gap-3 pb-4 text-blue-700">
26
+ <span class="tooltip">
27
+ <i class="fas fa-file-csv mr-1"></i>CSV
28
+ <span class="tooltip-text">Comma-separated values</span>
29
+ </span>
30
+ <span class="tooltip">
31
+ <i class="fas fa-file-alt mr-1"></i>TSV
32
+ <span class="tooltip-text">Tab-separated values</span>
33
+ </span>
34
+ <span class="tooltip">
35
+ <i class="fas fa-file-alt mr-1"></i>TXT
36
+ <span class="tooltip-text">Text files</span>
37
+ </span>
38
+ <span class="tooltip">
39
+ <i class="fas fa-file-excel mr-1"></i>XLS/XLSX
40
+ <span class="tooltip-text">Excel spreadsheets</span>
41
+ </span>
42
+ <span class="tooltip">
43
+ <i class="fas fa-file-code mr-1"></i>XML
44
+ <span class="tooltip-text">XML documents</span>
45
+ </span>
46
+ <span class="tooltip">
47
+ <i class="fas fa-file-code mr-1"></i>JSON
48
+ <span class="tooltip-text">JSON data files</span>
49
+ </span>
50
+ </div>
51
+ </div>
52
+ </div>
53
+ """, elem_classes="description_component")
54
+ example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
55
+ example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
56
+ example_file_3 = gr.File(visible=False, value="samples/tb_illness_data.csv")
57
+ with gr.Row():
58
+ example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
59
+ example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
60
+ example_btn_3 = gr.Button(value="Try Me: tb_illness_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
61
+
62
+ file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker drop-zone border-2 border-dashed border-gray-300 rounded-lg hover:border-primary cursor-pointer bg-gray-50 hover:bg-blue-50 transition-colors duration-300", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
63
+ example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
64
+ example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
65
+ example_btn_3.click(fn=run_example, inputs=example_file_3, outputs=file_output)
66
+ file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2, example_btn_3, description])
67
+
68
+ @gr.render(inputs=file_output)
69
+ def data_options(filename, request: gr.Request):
70
+ print(filename)
71
+ if request.session_hash not in message_dict:
72
+ message_dict[request.session_hash] = {}
73
+ message_dict[request.session_hash]['file_upload'] = None
74
+ if filename:
75
+ process_message = process_upload(filename, request.session_hash)
76
+ gr.HTML(value=process_message[1], padding=False)
77
+ if process_message[0] == "success":
78
+ if "bank_marketing_campaign" in filename:
79
+ example_questions = [
80
+ ["Describe the dataset"],
81
+ ["What levels of education have the highest and lowest average balance?"],
82
+ ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
83
+ ["Can you generate a bar chart of education vs. average balance?"],
84
+ ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"],
85
+ ["Can we predict the relationship between the number of contacts performed before this campaign and the average balance?"],
86
+ ["Can you plot the number of contacts performed before this campaign versus the duration and use balance as the size in a bubble chart?"]
87
+ ]
88
+ elif "online_retail_data" in filename:
89
+ example_questions = [
90
+ ["Describe the dataset"],
91
+ ["What month had the highest revenue?"],
92
+ ["Is revenue higher in the morning or afternoon?"],
93
+ ["Can you generate a line graph of revenue per month?"],
94
+ ["Can you generate a table of revenue per month?"],
95
+ ["Can we predict how time of day affects transaction value in this data set?"],
96
+ ["Can you plot revenue per month with size being the number of units sold that month in a bubble chart?"]
97
+ ]
98
+ else:
99
+ try:
100
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'file_upload', '', process_message[1], ''))
101
+ example_questions = [
102
+ ["Describe the dataset"]
103
+ ]
104
+ for example in generated_examples:
105
+ example_questions.append([example])
106
+ except Exception as e:
107
+ print("DATA FILE QUESTION GENERATION ERROR")
108
+ print(e)
109
+ example_questions = [
110
+ ["Describe the dataset"],
111
+ ["List the columns in the dataset"],
112
+ ["What could this data be used for?"],
113
+ ]
114
+ session_hash = gr.Textbox(visible=False, value=request.session_hash)
115
+ data_source = gr.Textbox(visible=False, value='file_upload')
116
+ schema = gr.Textbox(visible=False, value='')
117
+ titles = gr.Textbox(value=process_message[1], interactive=False, visible=False)
118
+ bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
119
+ chat = gr.ChatInterface(
120
+ fn=chatbot_func,
121
+ type='messages',
122
+ chatbot=bot,
123
+ title="Chat with your data file",
124
+ concurrency_limit=None,
125
+ examples=example_questions,
126
+ additional_inputs=[session_hash, data_source, titles, schema]
127
+ )
128
+
129
+ def process_upload(upload_value, session_hash):
130
+ if upload_value:
131
+ process_message = process_data_upload(upload_value, session_hash)
132
+ return process_message
133
+
134
+
135
+ if __name__ == "__main__":
136
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates/doc_db.py CHANGED
@@ -1,105 +1,99 @@
1
- import ast
2
- import gradio as gr
3
- from functions import example_question_generator, chatbot_func
4
- from data_sources import connect_doc_db
5
- from utils import message_dict
6
-
7
- with gr.Blocks() as demo:
8
- with gr.Accordion("ℹ️ About the MongoDB Connector", open=False):
9
- gr.HTML("""
10
- <div class="max-w-4xl mx-auto text-center">
11
- <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto p-4">
12
- <p>Connect to a MongoDB database and query it using natural language.</p>
13
- <p style="font-weight:bold;">
14
- No credentials are retained they are passed as session variables and disappear when you leave or refresh.
15
- Queries use PyMongoArrow's <code>aggregate_pandas_all</code>, which cannot delete, drop, or insert documents.
16
- Use caution connecting production databases to third-party tools.
17
- </p>
18
- <p>Contact me if you'd like this built for your organization with proper infrastructure and security controls.</p>
19
- </div>
20
- </div>
21
- """)
22
-
23
- gr.HTML("""
24
- <div style="max-width:560px;margin:8px auto 4px;padding:8px 14px;background:#f0f9ff;
25
- border:1px solid #bae6fd;border-radius:8px;text-align:center;">
26
- <p style="margin:0;font-size:13px;color:#0369a1;">
27
- <i class="fas fa-flask" style="margin-right:6px;"></i>
28
- <strong>Demo credentials pre-filled.</strong>
29
- &nbsp;Replace with your own database details to analyze your own data.
30
- </p>
31
- </div>
32
- """)
33
-
34
- connection_string = gr.Textbox(label="Connection String", value="dataanalyst0.l1klmww.mongodb.net/")
35
- with gr.Row():
36
- connection_user = gr.Textbox(label="Connection User", value="virtual-data-analyst")
37
- connection_password = gr.Textbox(label="Connection Password", value="zcpbmoGJ3mC8o", type="password")
38
- doc_db_name = gr.Textbox(label="Database Name", value="sample_mflix")
39
-
40
- gr.HTML("""
41
- <p style="text-align:center;font-size:13px;color:#6b7280;margin:4px 0 8px;">
42
- <i class="fas fa-circle-info" style="margin-right:4px;"></i>
43
- Schema analysis runs on connect — this may take 1–2 minutes for large databases.
44
- </p>
45
- """)
46
- submit = gr.Button(value="Connect", variant="primary")
47
-
48
- @gr.render(inputs=[connection_string, connection_user, connection_password, doc_db_name], triggers=[submit.click])
49
- def db_chat(request: gr.Request, connection_string=connection_string.value, connection_user=connection_user.value, connection_password=connection_password.value, doc_db_name=doc_db_name.value):
50
- if request.session_hash not in message_dict:
51
- message_dict[request.session_hash] = {}
52
- message_dict[request.session_hash]['doc_db'] = None
53
- connection_login_value = "mongodb+srv://" + connection_user + ":" + connection_password + "@" + connection_string
54
- if connection_login_value:
55
- print("MONGO APP")
56
- process_message = process_doc_db(connection_login_value, doc_db_name, request.session_hash)
57
- gr.HTML(value=process_message[1], padding=False)
58
- if process_message[0] == "success":
59
- if "dataanalyst0.l1klmww.mongodb.net" in connection_login_value:
60
- example_questions = [
61
- ["Describe the dataset"],
62
- ["What are the top 5 most common movie genres?"],
63
- ["How do user comment counts on a movie correlate with the movie award wins?"],
64
- ["Can you generate a pie chart showing the top 10 states with the most movie theaters?"],
65
- ["What are the top 10 most represented directors in the database?"],
66
- ["What are the different movie categories and how many movies are in each category?"]
67
- ]
68
- else:
69
- try:
70
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'doc_db', doc_db_name, process_message[2], process_message[3]))
71
- example_questions = [["Describe the dataset"]]
72
- for example in generated_examples:
73
- example_questions.append([example])
74
- except Exception as e:
75
- print("DOC DB QUESTION GENERATION ERROR")
76
- print(e)
77
- example_questions = [
78
- ["Describe the dataset"],
79
- ["List the collections in the database"],
80
- ["What could this data be used for?"],
81
- ]
82
- session_hash = gr.Textbox(visible=False, value=request.session_hash)
83
- db_connection_string = gr.Textbox(visible=False, value=connection_login_value)
84
- db_name = gr.Textbox(visible=False, value=doc_db_name)
85
- titles = gr.Textbox(value=process_message[2], interactive=False, label="DB Collections")
86
- data_source = gr.Textbox(visible=False, value='doc_db')
87
- schema = gr.Textbox(visible=False, value=process_message[3])
88
- bot = gr.Chatbot(type='messages', label="MongoDB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
89
- chat = gr.ChatInterface(
90
- fn=chatbot_func,
91
- type='messages',
92
- chatbot=bot,
93
- title="Chat with your Database",
94
- examples=example_questions,
95
- concurrency_limit=None,
96
- additional_inputs=[session_hash, data_source, titles, schema, db_connection_string, db_name]
97
- )
98
-
99
- def process_doc_db(connection_string, nosql_db_name, session_hash):
100
- if connection_string:
101
- process_message = connect_doc_db(connection_string, nosql_db_name, session_hash)
102
- return process_message
103
-
104
- if __name__ == "__main__":
105
- demo.launch()
 
1
+ import ast
2
+ import gradio as gr
3
+ from functions import example_question_generator, chatbot_func
4
+ from data_sources import connect_doc_db
5
+ from utils import message_dict
6
+
7
+ def hide_info():
8
+ return gr.update(visible=False)
9
+
10
+ with gr.Blocks() as demo:
11
+ description = gr.HTML("""
12
+ <!-- Header -->
13
+ <div class="max-w-4xl mx-auto mb-12 text-center">
14
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
15
+ <p>This tool allows users to communicate with and query real time data from a Document DB (MongoDB for now, others can be added if requested) using natural
16
+ language and the above features.</p>
17
+ <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
18
+ refreshes the page in which they disappear. They are never saved to any files. I also make use of the PyMongoArrow aggregate_pandas_all function to apply pipelines,
19
+ which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
20
+ That being said, it's probably best to use caution when connecting to a production database to a strange AI tool with an unfamiliar author.
21
+ This should be for demonstration purposes.</p>
22
+ <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
23
+ database analytics tool requires.</p>
24
+ </div>
25
+ </div>
26
+ """, elem_classes="description_component")
27
+
28
+ status_message = gr.HTML(value='<p style="color:green;text-align:center;font-size:18px;">Please be patient while connecting as we need to generate '
29
+ 'and read a schema before connection can be successful. This process can take a few minutes.</p>', padding=False)
30
+
31
+ connection_string = gr.Textbox(label="Connection String", value="dataanalyst0.l1klmww.mongodb.net/")
32
+ with gr.Row():
33
+ connection_user = gr.Textbox(label="Connection User", value="virtual-data-analyst")
34
+ connection_password = gr.Textbox(label="Connection Password", value="zcpbmoGJ3mC8o", type="password")
35
+ doc_db_name = gr.Textbox(label="Database Name", value="sample_mflix")
36
+
37
+ submit = gr.Button(value="Submit")
38
+ submit.click(fn=hide_info, outputs=description)
39
+
40
+ @gr.render(inputs=[connection_string,connection_user,connection_password,doc_db_name], triggers=[submit.click])
41
+ def db_chat(request: gr.Request, connection_string=connection_string.value, connection_user=connection_user.value, connection_password=connection_password.value, doc_db_name=doc_db_name.value):
42
+ if request.session_hash not in message_dict:
43
+ message_dict[request.session_hash] = {}
44
+ message_dict[request.session_hash]['doc_db'] = None
45
+ connection_login_value = "mongodb+srv://" + connection_user + ":" + connection_password + "@" + connection_string
46
+ if connection_login_value:
47
+ print("MONGO APP")
48
+ process_message = process_doc_db(connection_login_value, doc_db_name, request.session_hash)
49
+ gr.HTML(value=process_message[1], padding=False)
50
+ if process_message[0] == "success":
51
+ if "dataanalyst0.l1klmww.mongodb.net" in connection_login_value:
52
+ example_questions = [
53
+ ["Describe the dataset"],
54
+ ["What are the top 5 most common movie genres?"],
55
+ ["How do user comment counts on a movie correlate with the movie award wins?"],
56
+ ["Can you generate a pie chart showing the top 10 states with the most movie theaters?"],
57
+ ["What are the top 10 most represented directors in the database?"],
58
+ ["What are the different movie categories and how many movies are in each category?"]
59
+ ]
60
+ else:
61
+ try:
62
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'graphql', doc_db_name, process_message[2], process_message[3]))
63
+ example_questions = [
64
+ ["Describe the dataset"]
65
+ ]
66
+ for example in generated_examples:
67
+ example_questions.append([example])
68
+ except Exception as e:
69
+ print("DOC DB QUESTION GENERATION ERROR")
70
+ print(e)
71
+ example_questions = [
72
+ ["Describe the dataset"],
73
+ ["List the columns in the dataset"],
74
+ ["What could this data be used for?"],
75
+ ]
76
+ session_hash = gr.Textbox(visible=False, value=request.session_hash)
77
+ db_connection_string = gr.Textbox(visible=False, value=connection_login_value)
78
+ db_name = gr.Textbox(visible=False, value=doc_db_name)
79
+ titles = gr.Textbox(value=process_message[2], interactive=False, label="DB Collections")
80
+ data_source = gr.Textbox(visible=False, value='doc_db')
81
+ schema = gr.Textbox(visible=False, value=process_message[3])
82
+ bot = gr.Chatbot(type='messages', label="DocDB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
83
+ chat = gr.ChatInterface(
84
+ fn=chatbot_func,
85
+ type='messages',
86
+ chatbot=bot,
87
+ title="Chat with your Database",
88
+ examples=example_questions,
89
+ concurrency_limit=None,
90
+ additional_inputs=[session_hash, data_source, titles, schema, db_connection_string, db_name]
91
+ )
92
+
93
+ def process_doc_db(connection_string, nosql_db_name, session_hash):
94
+ if connection_string:
95
+ process_message = connect_doc_db(connection_string, nosql_db_name, session_hash)
96
+ return process_message
97
+
98
+ if __name__ == "__main__":
99
+ demo.launch()
 
 
 
 
 
 
templates/graphql.py CHANGED
@@ -1,110 +1,110 @@
1
- import ast
2
- import gradio as gr
3
- from functions import example_question_generator, chatbot_func
4
- from data_sources import connect_graphql
5
- from utils import message_dict
6
-
7
- import os
8
- from dotenv import load_dotenv
9
-
10
- load_dotenv()
11
-
12
- graphql_sample_endpoint = os.getenv("GRAPHQL_SAMPLE_ENDPOINT")
13
- graphql_sample_api_token = os.getenv("GRAPHQL_SAMPLE_API_TOKEN")
14
- graphql_sample_header_name = os.getenv("GRAPHQL_SAMPLE_HEADER_NAME")
15
-
16
- with gr.Blocks() as demo:
17
- with gr.Accordion("ℹ️ About the GraphQL Connector", open=False):
18
- gr.HTML("""
19
- <div class="max-w-4xl mx-auto text-center">
20
- <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto p-4">
21
- <p>Connect to any GraphQL API endpoint and query it using natural language.</p>
22
- <p style="font-weight:bold;">
23
- API querying is the most experimental feature and performance may vary.
24
- No credentials are retained they are passed as session variables and disappear when you leave or refresh.
25
- Mutations are not exposed and the agent is instructed not to alter data, though restricting
26
- your API token's permissions is still strongly recommended.
27
- </p>
28
- <p>Contact me if you'd like this built for your organization with proper infrastructure and security controls.</p>
29
- </div>
30
- </div>
31
- """)
32
-
33
- gr.HTML("""
34
- <div style="max-width:560px;margin:8px auto 4px;padding:8px 14px;background:#f0f9ff;
35
- border:1px solid #bae6fd;border-radius:8px;text-align:center;">
36
- <p style="margin:0;font-size:13px;color:#0369a1;">
37
- <i class="fas fa-flask" style="margin-right:6px;"></i>
38
- <strong>Demo credentials pre-filled.</strong>
39
- &nbsp;Replace with your own endpoint and token to analyze your own API.
40
- </p>
41
- </div>
42
- """)
43
-
44
- graphql_url = gr.Textbox(label="GraphQL Endpoint URL", value=graphql_sample_endpoint)
45
- with gr.Row():
46
- api_token_header_name = gr.Textbox(label="API Token Header Name", value=graphql_sample_header_name)
47
- api_token = gr.Textbox(label="API Token", value=graphql_sample_api_token, type="password")
48
-
49
- submit = gr.Button(value="Connect", variant="primary")
50
-
51
- @gr.render(inputs=[graphql_url, api_token, api_token_header_name], triggers=[submit.click])
52
- def api_chat(request: gr.Request, graphql_url=graphql_url.value, api_token=api_token.value, api_token_header_name=api_token_header_name.value):
53
- if request.session_hash not in message_dict:
54
- message_dict[request.session_hash] = {}
55
- message_dict[request.session_hash]['graphql'] = None
56
- if graphql_url:
57
- print("GraphQL API")
58
- process_message = process_graphql(graphql_url, api_token, api_token_header_name, request.session_hash)
59
- gr.HTML(value=process_message[1], padding=False)
60
- if process_message[0] == "success":
61
- if "qdl-app-testing" in graphql_url:
62
- example_questions = [
63
- ["Describe the dataset"],
64
- ["What is the total revenue for this shopify store?"],
65
- ["What is the average duration from the fulfillment of an order to its delivery?"],
66
- ["What is the total value of orders processed in the current month?"],
67
- ["Which product has the highest number of variants in the inventory?"],
68
- ["How many gift cards have been issued this year, and what is their total value?"],
69
- ["How many active apps are currently installed on the store?"],
70
- ["What is the total count of abandoned checkouts over the last month?"]
71
- ]
72
- else:
73
- try:
74
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'graphql', graphql_url, process_message[2], ''))
75
- example_questions = [["Describe the dataset"]]
76
- for example in generated_examples:
77
- example_questions.append([example])
78
- except Exception as e:
79
- print("GRAPHQL QUESTION GENERATION ERROR")
80
- print(e)
81
- example_questions = [
82
- ["Describe the dataset"],
83
- ["List the types in this API"],
84
- ["What could this data be used for?"],
85
- ]
86
- session_hash = gr.Textbox(visible=False, value=request.session_hash)
87
- graphql_api_string = gr.Textbox(visible=False, value=graphql_url)
88
- graphql_api_token = gr.Textbox(visible=False, value=api_token)
89
- graphql_token_header = gr.Textbox(visible=False, value=api_token_header_name)
90
- titles = gr.Textbox(value=process_message[2], interactive=False, label="GraphQL Types")
91
- data_source = gr.Textbox(visible=False, value='graphql')
92
- schema = gr.Textbox(visible=False, value='')
93
- bot = gr.Chatbot(type='messages', label="GraphQL Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
94
- chat = gr.ChatInterface(
95
- fn=chatbot_func,
96
- type='messages',
97
- chatbot=bot,
98
- title="Chat with your GraphQL API",
99
- examples=example_questions,
100
- concurrency_limit=None,
101
- additional_inputs=[session_hash, data_source, titles, schema, graphql_api_string, graphql_api_token, graphql_token_header]
102
- )
103
-
104
- def process_graphql(graphql_url, api_token, api_token_header_name, session_hash):
105
- if graphql_url:
106
- process_message = connect_graphql(graphql_url, api_token, api_token_header_name, session_hash)
107
- return process_message
108
-
109
- if __name__ == "__main__":
110
- demo.launch()
 
1
+ import ast
2
+ import gradio as gr
3
+ from functions import example_question_generator, chatbot_func
4
+ from data_sources import connect_graphql
5
+ from utils import message_dict
6
+
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+ graphql_sample_endpoint = os.getenv("GRAPHQL_SAMPLE_ENDPOINT")
13
+ graphql_sample_api_token = os.getenv("GRAPHQL_SAMPLE_API_TOKEN")
14
+ graphql_sample_header_name = os.getenv("GRAPHQL_SAMPLE_HEADER_NAME")
15
+
16
+ def hide_info():
17
+ return gr.update(visible=False)
18
+
19
+ with gr.Blocks() as demo:
20
+ description = gr.HTML("""
21
+ <!-- Header -->
22
+ <div class="max-w-4xl mx-auto mb-12 text-center">
23
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
24
+ <p>This tool allows users to communicate with and query real time data from a GraphQL API endpoint using natural
25
+ language and the above features.</p>
26
+ <p style="font-weight:bold;">Notice: API querying is the most difficult and experimental feature so far.
27
+ This tool may have variable performance and quality, although it should get better over time as I evaluate use.
28
+ No login information is retained and credentials are passed as session variables until the user leaves or
29
+ refreshes the page in which they disappear. They are never saved to any files.</p>
30
+ <p style="font-weight:bold;"> I don't include a function that allows the system to run mutations and I instruct the agent to not alter any data, but it could in theory be possible,
31
+ although my testing wasn't able to get the system to alter or write to the api. I would be careful to make sure permissions are restricted for the
32
+ api token being used.
33
+ And of course, it's probably best to use caution when connecting to a strange AI tool with an unfamiliar author.
34
+ This should be for demonstration purposes.</p>
35
+ <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
36
+ database analytics tool requires.</p>
37
+ </div>
38
+ </div>
39
+ """, elem_classes="description_component")
40
+
41
+ graphql_url = gr.Textbox(label="GraphQL Endpoint URL", value=graphql_sample_endpoint)
42
+ with gr.Row():
43
+ api_token_header_name = gr.Textbox(label="API Token Header Name", value=graphql_sample_header_name)
44
+ api_token = gr.Textbox(label="API Token", value=graphql_sample_api_token, type="password")
45
+
46
+ submit = gr.Button(value="Submit")
47
+ submit.click(fn=hide_info, outputs=description)
48
+
49
+ @gr.render(inputs=[graphql_url,api_token,api_token_header_name], triggers=[submit.click])
50
+ def api_chat(request: gr.Request, graphql_url=graphql_url.value, api_token=api_token.value, api_token_header_name=api_token_header_name.value):
51
+ if request.session_hash not in message_dict:
52
+ message_dict[request.session_hash] = {}
53
+ message_dict[request.session_hash]['graphql'] = None
54
+ if graphql_url:
55
+ print("GraphQL API")
56
+ process_message = process_graphql(graphql_url, api_token, api_token_header_name, request.session_hash)
57
+ gr.HTML(value=process_message[1], padding=False)
58
+ if process_message[0] == "success":
59
+ if "qdl-app-testing" in graphql_url:
60
+ example_questions = [
61
+ ["Describe the dataset"],
62
+ ["What is the total revenue for this shopify store?"],
63
+ ["What is the average duration from the fulfillment of an order to its delivery?"],
64
+ ["What is the total value of orders processed in the current month?"],
65
+ ["Which product has the highest number of variants in the inventory?"],
66
+ ["How many gift cards have been issued this year, and what is their total value?"],
67
+ ["How many active apps are currently installed on the store?"],
68
+ ["What is the total count of abandoned checkouts over the last month?"]
69
+ ]
70
+ else:
71
+ try:
72
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'graphql', graphql_url, process_message[2], ''))
73
+ example_questions = [
74
+ ["Describe the dataset"]
75
+ ]
76
+ for example in generated_examples:
77
+ example_questions.append([example])
78
+ except Exception as e:
79
+ print("GRAPHQL QUESTION GENERATION ERROR")
80
+ print(e)
81
+ example_questions = [
82
+ ["Describe the dataset"],
83
+ ["List the columns in the dataset"],
84
+ ["What could this data be used for?"],
85
+ ]
86
+ session_hash = gr.Textbox(visible=False, value=request.session_hash)
87
+ graphql_api_string = gr.Textbox(visible=False, value=graphql_url)
88
+ graphql_api_token = gr.Textbox(visible=False, value=api_token)
89
+ graphql_token_header = gr.Textbox(visible=False, value=api_token_header_name)
90
+ titles = gr.Textbox(value=process_message[2], interactive=False, label="GraphQL Types")
91
+ data_source = gr.Textbox(visible=False, value='graphql')
92
+ schema = gr.Textbox(visible=False, value='')
93
+ bot = gr.Chatbot(type='messages', label="GraphQL Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
94
+ chat = gr.ChatInterface(
95
+ fn=chatbot_func,
96
+ type='messages',
97
+ chatbot=bot,
98
+ title="Chat with your Graphql API",
99
+ examples=example_questions,
100
+ concurrency_limit=None,
101
+ additional_inputs=[session_hash, data_source, titles, schema, graphql_api_string, graphql_api_token, graphql_token_header]
102
+ )
103
+
104
+ def process_graphql(graphql_url, api_token, api_token_header_name, session_hash):
105
+ if graphql_url:
106
+ process_message = connect_graphql(graphql_url, api_token, api_token_header_name, session_hash)
107
+ return process_message
108
+
109
+ if __name__ == "__main__":
110
+ demo.launch()
templates/sql_db.py CHANGED
@@ -1,102 +1,98 @@
1
- import ast
2
- import gradio as gr
3
- from functions import example_question_generator, chatbot_func
4
- from data_sources import connect_sql_db
5
- from utils import message_dict
6
-
7
- with gr.Blocks() as demo:
8
- with gr.Accordion("ℹ️ About the SQL Connector", open=False):
9
- gr.HTML("""
10
- <div class="max-w-4xl mx-auto text-center">
11
- <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto p-4">
12
- <p>Connect to a PostgreSQL database and query it using natural language.</p>
13
- <p style="font-weight:bold;">
14
- No credentials are retained they are passed as session variables and disappear when you leave or refresh.
15
- Queries run through Pandas <code>read_sql_query</code>, which cannot delete, drop, or insert rows.
16
- Use caution connecting production databases to third-party tools.
17
- </p>
18
- <p>Contact me if you'd like this built for your organization with proper infrastructure and security controls.</p>
19
- </div>
20
- </div>
21
- """)
22
-
23
- gr.HTML("""
24
- <div style="max-width:560px;margin:8px auto 4px;padding:8px 14px;background:#f0f9ff;
25
- border:1px solid #bae6fd;border-radius:8px;text-align:center;">
26
- <p style="margin:0;font-size:13px;color:#0369a1;">
27
- <i class="fas fa-flask" style="margin-right:6px;"></i>
28
- <strong>Demo credentials pre-filled.</strong>
29
- &nbsp;Replace with your own database details to analyze your own data.
30
- </p>
31
- </div>
32
- """)
33
-
34
- sql_url = gr.Textbox(label="URL", value="virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com")
35
- with gr.Row():
36
- sql_port = gr.Textbox(label="Port", value="5432")
37
- sql_user = gr.Textbox(label="Username", value="postgres")
38
- sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
39
- sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")
40
-
41
- submit = gr.Button(value="Connect", variant="primary")
42
-
43
- @gr.render(inputs=[sql_url, sql_port, sql_user, sql_pass, sql_db_name], triggers=[submit.click])
44
- def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
45
- if request.session_hash not in message_dict:
46
- message_dict[request.session_hash] = {}
47
- message_dict[request.session_hash]['sql'] = None
48
- if url:
49
- print("SQL APP")
50
- process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
51
- gr.HTML(value=process_message[1], padding=False)
52
- if process_message[0] == "success":
53
- if "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com" in url:
54
- example_questions = [
55
- ["Describe the dataset"],
56
- ["What is the total revenue generated by each store?"],
57
- ["Can you generate and display a bar chart of film category to number of films in that category?"],
58
- ["Can you generate a pie chart showing the top 10 most rented films by revenue?"],
59
- ["Can you generate a line chart of rental revenue over time?"],
60
- ["What is the relationship between film length and rental frequency?"]
61
- ]
62
- else:
63
- try:
64
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'sql', sql_db_name, process_message[2], ""))
65
- example_questions = [["Describe the dataset"]]
66
- for example in generated_examples:
67
- example_questions.append([example])
68
- except Exception as e:
69
- print("SQL QUESTION GENERATION ERROR")
70
- print(e)
71
- example_questions = [
72
- ["Describe the dataset"],
73
- ["List the tables in the database"],
74
- ["What could this data be used for?"],
75
- ]
76
- session_hash = gr.Textbox(visible=False, value=request.session_hash)
77
- db_url = gr.Textbox(visible=False, value=url)
78
- db_port = gr.Textbox(visible=False, value=sql_port)
79
- db_user = gr.Textbox(visible=False, value=sql_user)
80
- db_pass = gr.Textbox(visible=False, value=sql_pass)
81
- db_name = gr.Textbox(visible=False, value=sql_db_name)
82
- titles = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
83
- data_source = gr.Textbox(visible=False, value='sql')
84
- schema = gr.Textbox(visible=False, value='')
85
- bot = gr.Chatbot(type='messages', label="SQL DB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
86
- chat = gr.ChatInterface(
87
- fn=chatbot_func,
88
- type='messages',
89
- chatbot=bot,
90
- title="Chat with your Database",
91
- examples=example_questions,
92
- concurrency_limit=None,
93
- additional_inputs=[session_hash, data_source, titles, schema, db_url, db_port, db_user, db_pass, db_name]
94
- )
95
-
96
- def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
97
- if url:
98
- process_message = connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
99
- return process_message
100
-
101
- if __name__ == "__main__":
102
- demo.launch()
 
1
+ import ast
2
+ import gradio as gr
3
+ from functions import example_question_generator, chatbot_func
4
+ from data_sources import connect_sql_db
5
+ from utils import message_dict
6
+
7
+ def hide_info():
8
+ return gr.update(visible=False)
9
+
10
+ with gr.Blocks() as demo:
11
+ description = gr.HTML("""
12
+ <!-- Header -->
13
+ <div class="max-w-4xl mx-auto mb-12 text-center">
14
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
15
+ <p>This tool allows users to communicate with and query real time data from a SQL DB (postgres for now, others can be added if requested) using natural
16
+ language and the above features.</p>
17
+ <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
18
+ refreshes the page in which they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
19
+ queries, which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
20
+ That being said, it's probably best to use caution when connecting to a production database to a strange AI tool with an unfamiliar author.
21
+ This should be for demonstration purposes.</p>
22
+ <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
23
+ database analytics tool requires.</p>
24
+ </div>
25
+ </div>
26
+ """, elem_classes="description_component")
27
+ sql_url = gr.Textbox(label="URL", value="virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com")
28
+ with gr.Row():
29
+ sql_port = gr.Textbox(label="Port", value="5432")
30
+ sql_user = gr.Textbox(label="Username", value="postgres")
31
+ sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
32
+ sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")
33
+
34
+ submit = gr.Button(value="Submit")
35
+ submit.click(fn=hide_info, outputs=description)
36
+
37
+ @gr.render(inputs=[sql_url,sql_port,sql_user,sql_pass,sql_db_name], triggers=[submit.click])
38
+ def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
39
+ if request.session_hash not in message_dict:
40
+ message_dict[request.session_hash] = {}
41
+ message_dict[request.session_hash]['sql'] = None
42
+ if url:
43
+ print("SQL APP")
44
+ process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
45
+ gr.HTML(value=process_message[1], padding=False)
46
+ if process_message[0] == "success":
47
+ if "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com" in url:
48
+ example_questions = [
49
+ ["Describe the dataset"],
50
+ ["What is the total revenue generated by each store?"],
51
+ ["Can you generate and display a bar chart of film category to number of films in that category?"],
52
+ ["Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?"],
53
+ ["Can you generate a line chart of rental revenue over time?"],
54
+ ["What is the relationship between film length and rental frequency?"]
55
+ ]
56
+ else:
57
+ try:
58
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'sql', sql_db_name, process_message[2], ""))
59
+ example_questions = [
60
+ ["Describe the dataset"]
61
+ ]
62
+ for example in generated_examples:
63
+ example_questions.append([example])
64
+ except Exception as e:
65
+ print("SQL QUESTION GENERATION ERROR")
66
+ print(e)
67
+ example_questions = [
68
+ ["Describe the dataset"],
69
+ ["List the columns in the dataset"],
70
+ ["What could this data be used for?"],
71
+ ]
72
+ session_hash = gr.Textbox(visible=False, value=request.session_hash)
73
+ db_url = gr.Textbox(visible=False, value=url)
74
+ db_port = gr.Textbox(visible=False, value=sql_port)
75
+ db_user = gr.Textbox(visible=False, value=sql_user)
76
+ db_pass = gr.Textbox(visible=False, value=sql_pass)
77
+ db_name = gr.Textbox(visible=False, value=sql_db_name)
78
+ titles = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
79
+ data_source = gr.Textbox(visible=False, value='sql')
80
+ schema = gr.Textbox(visible=False, value='')
81
+ bot = gr.Chatbot(type='messages', label="SQL DB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
82
+ chat = gr.ChatInterface(
83
+ fn=chatbot_func,
84
+ type='messages',
85
+ chatbot=bot,
86
+ title="Chat with your Database",
87
+ examples=example_questions,
88
+ concurrency_limit=None,
89
+ additional_inputs=[session_hash, data_source, titles, schema, db_url, db_port, db_user, db_pass, db_name]
90
+ )
91
+
92
+ def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
93
+ if url:
94
+ process_message = connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
95
+ return process_message
96
+
97
+ if __name__ == "__main__":
98
+ demo.launch()
 
 
 
 
tools/__init__.py DELETED
File without changes
tools/chart_tools.py CHANGED
@@ -1,308 +1,371 @@
1
- # Shared parameter snippets reused across chart tool schemas.
2
- # Update here to change the description everywhere at once.
3
-
4
- _LAYOUT_PARAM = {
5
- "type": "array",
6
- "description": (
7
- "Optional. An array containing a single JSON-formatted Plotly layout dictionary. "
8
- "Use to set chart title, axis labels, colours, fonts, and other layout properties. "
9
- "Example: [{\"title\": \"Monthly Sales\", \"xaxis\": {\"title\": \"Month\"}}]"
10
- ),
11
- "items": {"type": "string"},
12
- }
13
-
14
- _TRACE_STYLE_PARAM = {
15
- "type": "array",
16
- "description": (
17
- "Optional. An array containing a single JSON-formatted Plotly trace styling dictionary. "
18
- "Use to control visual properties such as line colour, opacity, and marker style. "
19
- "Do NOT include 'x', 'y', or 'type' keys — those are set automatically from query.csv."
20
- ),
21
- "items": {"type": "string"},
22
- }
23
-
24
- chart_tool_schemas = [
25
- {
26
- "name": "scatter_chart_generation_func",
27
- "description": (
28
- "Generates a Plotly scatter plot from query.csv data. "
29
- "Use when the user wants to visualise the relationship between two numeric columns, "
30
- "create a bubble chart (via the size parameter), or overlay a trendline. "
31
- "Returns an HTML iframe — display it verbatim in the chat."
32
- ),
33
- "parameters": {
34
- "type": "object",
35
- "properties": {
36
- "x_column": {
37
- "type": "array",
38
- "description": (
39
- "One or more column names from query.csv to plot on the x-axis. "
40
- "Multiple columns produce multiple series, each plotted against y_column."
41
- ),
42
- "items": {"type": "string"},
43
- },
44
- "y_column": {
45
- "type": "string",
46
- "description": "Column name from query.csv to plot on the y-axis.",
47
- },
48
- "category": {
49
- "type": "string",
50
- "description": "Optional column name used to colour-code points by a categorical grouping.",
51
- },
52
- "trendline": {
53
- "type": "string",
54
- "description": (
55
- "Optional trendline type. One of: 'ols' (linear regression), "
56
- "'lowess' (local smoothing), 'rolling', 'ewm', 'expanding'. "
57
- "Requires trendline_options when using 'lowess', 'rolling', or 'ewm'."
58
- ),
59
- },
60
- "trendline_options": {
61
- "type": "array",
62
- "description": (
63
- "Required when trendline is 'lowess', 'rolling', or 'ewm'. "
64
- "An array containing a single JSON-formatted dict of trendline options "
65
- "(e.g. [{\"window\": 7}] for a 7-point rolling average)."
66
- ),
67
- "items": {"type": "string"},
68
- },
69
- "marginal_x": {
70
- "type": "string",
71
- "description": "Optional marginal distribution plot along the x-axis. One of: 'histogram', 'rug', 'box', 'violin'.",
72
- },
73
- "marginal_y": {
74
- "type": "string",
75
- "description": "Optional marginal distribution plot along the y-axis. One of: 'histogram', 'rug', 'box', 'violin'.",
76
- },
77
- "size": {
78
- "type": "string",
79
- "description": "Optional column name whose values control the size of each point (bubble chart). Negative values are clamped to zero.",
80
- },
81
- "data": _TRACE_STYLE_PARAM,
82
- "layout": _LAYOUT_PARAM,
83
- },
84
- "required": ["x_column", "y_column"],
85
- },
86
- },
87
- {
88
- "name": "line_chart_generation_func",
89
- "description": (
90
- "Generates a Plotly line chart from query.csv data. "
91
- "Use for trends over time or any ordered sequence. "
92
- "Returns an HTML iframe — display it verbatim in the chat."
93
- ),
94
- "parameters": {
95
- "type": "object",
96
- "properties": {
97
- "x_column": {
98
- "type": "string",
99
- "description": "Column name from query.csv for the x-axis (typically a date or ordered index).",
100
- },
101
- "y_column": {
102
- "type": "string",
103
- "description": "Column name from query.csv for the y-axis (numeric values).",
104
- },
105
- "category": {
106
- "type": "string",
107
- "description": "Optional column name used to split the data into multiple colour-coded lines.",
108
- },
109
- "data": _TRACE_STYLE_PARAM,
110
- "layout": _LAYOUT_PARAM,
111
- },
112
- "required": ["x_column", "y_column"],
113
- },
114
- },
115
- {
116
- "name": "bar_chart_generation_func",
117
- "description": (
118
- "Generates a Plotly bar chart from query.csv data. "
119
- "Use for comparing values across categories. Supports grouped/stacked bars via category, "
120
- "and faceted subplots via facet_row or facet_col. "
121
- "Returns an HTML iframe display it verbatim in the chat."
122
- ),
123
- "parameters": {
124
- "type": "object",
125
- "properties": {
126
- "x_column": {
127
- "type": "string",
128
- "description": "Column name from query.csv for the x-axis (category labels).",
129
- },
130
- "y_column": {
131
- "type": "string",
132
- "description": "Column name from query.csv for the y-axis (numeric values).",
133
- },
134
- "category": {
135
- "type": "string",
136
- "description": "Optional column name used to colour-code bars into grouped or stacked series.",
137
- },
138
- "facet_row": {
139
- "type": "string",
140
- "description": "Optional column name. Creates one subplot row per unique value — useful for comparing distributions across a second dimension.",
141
- },
142
- "facet_col": {
143
- "type": "string",
144
- "description": "Optional column name. Creates one subplot column per unique value.",
145
- },
146
- "data": _TRACE_STYLE_PARAM,
147
- "layout": _LAYOUT_PARAM,
148
- },
149
- "required": ["x_column", "y_column"],
150
- },
151
- },
152
- {
153
- "name": "pie_chart_generation_func",
154
- "description": (
155
- "Generates a Plotly pie chart from query.csv data. "
156
- "Use when the user wants to show part-to-whole proportions. "
157
- "Returns an HTML iframe — display it verbatim in the chat."
158
- ),
159
- "parameters": {
160
- "type": "object",
161
- "properties": {
162
- "values": {
163
- "type": "string",
164
- "description": "Column name from query.csv containing the numeric value for each slice.",
165
- },
166
- "names": {
167
- "type": "string",
168
- "description": "Column name from query.csv containing the label for each slice.",
169
- },
170
- "data": _TRACE_STYLE_PARAM,
171
- "layout": _LAYOUT_PARAM,
172
- },
173
- "required": ["values", "names"],
174
- },
175
- },
176
- {
177
- "name": "histogram_generation_func",
178
- "description": (
179
- "Generates a Plotly histogram from query.csv data. "
180
- "Use to show the frequency distribution of a numeric column. "
181
- "Supports normalisation (percent, probability, density) and aggregation functions per bin. "
182
- "Returns an HTML iframe — display it verbatim in the chat."
183
- ),
184
- "parameters": {
185
- "type": "object",
186
- "properties": {
187
- "x_column": {
188
- "type": "string",
189
- "description": "Column name from query.csv whose values are binned on the x-axis.",
190
- },
191
- "y_column": {
192
- "type": "string",
193
- "description": "Optional column name aggregated per bin via histfunc (e.g. sum of sales per price bucket).",
194
- },
195
- "histnorm": {
196
- "type": "string",
197
- "description": "Optional normalisation. One of: 'percent', 'probability', 'density', 'probability density'.",
198
- },
199
- "category": {
200
- "type": "string",
201
- "description": "Optional column name used to overlay multiple colour-coded histograms.",
202
- },
203
- "histfunc": {
204
- "type": "string",
205
- "description": "Optional aggregation function applied to y_column per bin. One of: 'avg', 'sum', 'count'.",
206
- },
207
- "data": _TRACE_STYLE_PARAM,
208
- "layout": _LAYOUT_PARAM,
209
- },
210
- "required": ["x_column"],
211
- },
212
- },
213
- {
214
- "name": "box_chart_generation_func",
215
- "description": (
216
- "Generates a Plotly box plot from query.csv data. "
217
- "Use to visualise the distribution of a numeric column and identify outliers. "
218
- "Especially useful for comparing distributions across categories. "
219
- "Returns an HTML iframe — display it verbatim in the chat."
220
- ),
221
- "parameters": {
222
- "type": "object",
223
- "properties": {
224
- "y_column": {
225
- "type": "string",
226
- "description": "Column name from query.csv containing the numeric values to distribute on the y-axis.",
227
- },
228
- "x_column": {
229
- "type": "string",
230
- "description": "Optional column name. Groups data into one box per unique value on the x-axis.",
231
- },
232
- "category": {
233
- "type": "string",
234
- "description": "Optional column name used to colour-code boxes by a secondary grouping.",
235
- },
236
- "layout": _LAYOUT_PARAM,
237
- },
238
- "required": ["y_column"],
239
- },
240
- },
241
- {
242
- "name": "correlation_heatmap_func",
243
- "description": (
244
- "Computes pairwise Pearson correlations between numeric columns in query.csv and renders "
245
- "the result as a colour-coded heatmap (blue = positive, red = negative). "
246
- "Use when the user asks which variables are related, correlated, or associated with each other. "
247
- "Returns an HTML iframe — display it verbatim in the chat."
248
- ),
249
- "parameters": {
250
- "type": "object",
251
- "properties": {
252
- "columns": {
253
- "type": "array",
254
- "description": "Optional list of numeric column names to include in the matrix. If omitted, all numeric columns from query.csv are used. Avoid ID or index columns.",
255
- "items": {"type": "string"},
256
- },
257
- },
258
- "required": [],
259
- },
260
- },
261
- {
262
- "name": "rolling_stats_func",
263
- "description": (
264
- "Generates a rolling statistics / moving average chart from query.csv data. "
265
- "Overlays rolling aggregations (mean, std, min, max) on top of the original series. "
266
- "Use when the user asks for a moving average, rolling average, rolling statistics, or wants to smooth a time series. "
267
- "Returns an HTML iframe — display it verbatim in the chat."
268
- ),
269
- "parameters": {
270
- "type": "object",
271
- "properties": {
272
- "x_column": {
273
- "type": "string",
274
- "description": "Column name from query.csv for the x-axis typically a date or sequential index.",
275
- },
276
- "y_column": {
277
- "type": "string",
278
- "description": "Column name from query.csv containing the numeric values to compute rolling stats on.",
279
- },
280
- "window": {
281
- "type": "integer",
282
- "description": "Rolling window size in number of rows. Default 7. Infer from the user's request.",
283
- },
284
- "stats": {
285
- "type": "array",
286
- "description": "Statistics to overlay. Valid values: 'mean', 'std', 'min', 'max'. Defaults to ['mean'] if omitted.",
287
- "items": {"type": "string"},
288
- },
289
- "category": {
290
- "type": "string",
291
- "description": "Optional column name to group the data, producing separate rolling stat lines per group.",
292
- },
293
- "layout": _LAYOUT_PARAM,
294
- },
295
- "required": ["x_column", "y_column"],
296
- },
297
- },
298
- {
299
- "name": "table_generation_func",
300
- "description": (
301
- "Formats query.csv results as a styled HTML table. "
302
- "Use when the user wants to view raw query results in a readable format, "
303
- "or when result data is too large to describe in text. Displays up to 200 rows. "
304
- "Returns an HTML table — display it verbatim in the chat."
305
- ),
306
- "parameters": {"type": "object", "properties": {}},
307
- },
308
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chart_tools = [
2
+ {
3
+ "type": "function",
4
+ "function": {
5
+ "name": "scatter_chart_generation_func",
6
+ "description": f"""This is a scatter plot generation tool useful to generate scatter plots from queried data from our data source that we are querying.
7
+ The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
8
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
9
+ from the scatter_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
10
+ to it for context if desired.""",
11
+ "parameters": {
12
+ "type": "object",
13
+ "properties": {
14
+ "data": {
15
+ "type": "array",
16
+ "description": """The array containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
17
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.
18
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
19
+ Infer this from the user's message.""",
20
+ "items": {
21
+ "type": "string",
22
+ }
23
+ },
24
+ "x_column": {
25
+ "type": "array",
26
+ "description": f"""An array of strings that correspond to the the columns in our query.csv file that contain the x values of the graph. There can be more than one column
27
+ that can each be plotted against the y_column, if needed.""",
28
+ "items": {
29
+ "type": "string",
30
+ }
31
+ },
32
+ "y_column": {
33
+ "type": "string",
34
+ "description": f"""The column in our query.csv file that contain the y values of the graph.""",
35
+ "items": {
36
+ "type": "string",
37
+ }
38
+ },
39
+ "category": {
40
+ "type": "string",
41
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the category for the data.""",
42
+ "items": {
43
+ "type": "string",
44
+ }
45
+ },
46
+ "trendline": {
47
+ "type": "string",
48
+ "description": f"""An optional field to specify the type of plotly trendline we wish to use in the scatter plot.
49
+ This trendline value can be one of ['ols','lowess','rolling','ewm','expanding'].
50
+ Do not send any values outside of this array as the function will fail.
51
+ Infer this from the user's message.""",
52
+ "items": {
53
+ "type": "string",
54
+ }
55
+ },
56
+ "trendline_options": {
57
+ "type": "array",
58
+ "description": """An array containing a dictionary that contains the 'trendline_options' portion of the plotly chart generation.
59
+ The 'lowess', 'rolling', and 'ewm' options require trendline_options to be included.
60
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
61
+ "items": {
62
+ "type": "string",
63
+ }
64
+ },
65
+ "marginal_x": {
66
+ "type": "string",
67
+ "description": f"""The type of marginal distribution plot we'd like to specify for the plotly scatter plot for the x axis.
68
+ This marginal_x value can be one of ['histogram','rug','box','violin'].
69
+ Do not send any values outside of this array as the function will fail.
70
+ Infer this from the user's message.""",
71
+ "items": {
72
+ "type": "string",
73
+ }
74
+ },
75
+ "marginal_y": {
76
+ "type": "string",
77
+ "description": f"""The type of marginal distribution plot we'd like to specify for the plotly scatter plot for the y axis.
78
+ This marginal_y value can be one of ['histogram','rug','box','violin'].
79
+ Do not send any values outside of this array as the function will fail.
80
+ Infer this from the user's message.""",
81
+ "items": {
82
+ "type": "string",
83
+ }
84
+ },
85
+ "layout": {
86
+ "type": "array",
87
+ "description": """An array containing a dictionary that contains the 'layout' portion of the plotly chart generation.
88
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
89
+ "items": {
90
+ "type": "string",
91
+ }
92
+ },
93
+ "size": {
94
+ "type": "string",
95
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the size of each plot point.
96
+ This is useful for a bubble chart where another value in our query can be represented by the size of the plotted point.
97
+ Values must be greater than or equal to 0 and so in our query, all values less than 0 should be set equal to zero.""",
98
+ "items": {
99
+ "type": "string",
100
+ }
101
+ }
102
+ },
103
+ "required": ["x_column","y_column"],
104
+ },
105
+ },
106
+ },
107
+ {
108
+ "type": "function",
109
+ "function": {
110
+ "name": "line_chart_generation_func",
111
+ "description": f"""This is a line chart generation tool useful to generate line charts from queried data from our data source that we are querying.
112
+ The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
113
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
114
+ from the line_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
115
+ to it for context if desired.""",
116
+ "parameters": {
117
+ "type": "object",
118
+ "properties": {
119
+ "data": {
120
+ "type": "array",
121
+ "description": """The array containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
122
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.
123
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
124
+ Infer this from the user's message.""",
125
+ "items": {
126
+ "type": "string",
127
+ }
128
+ },
129
+ "x_column": {
130
+ "type": "string",
131
+ "description": f"""The column in our query.csv file that contain the x values of the graph.""",
132
+ "items": {
133
+ "type": "string",
134
+ }
135
+ },
136
+ "y_column": {
137
+ "type": "string",
138
+ "description": f"""The column in our query.csv file that contain the y values of the graph.""",
139
+ "items": {
140
+ "type": "string",
141
+ }
142
+ },
143
+ "category": {
144
+ "type": "string",
145
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the category for the data.""",
146
+ "items": {
147
+ "type": "string",
148
+ }
149
+ },
150
+ "layout": {
151
+ "type": "array",
152
+ "description": """An array containing a dictionary that contains the 'layout' portion of the plotly chart generation.
153
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
154
+ "items": {
155
+ "type": "string",
156
+ }
157
+ }
158
+ },
159
+ "required": ["x_column","y_column","layout"],
160
+ },
161
+ },
162
+ },
163
+ {
164
+ "type": "function",
165
+ "function": {
166
+ "name": "bar_chart_generation_func",
167
+ "description": f"""This is a bar chart generation tool useful to generate line charts from queried data from our data source that we are querying.
168
+ The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
169
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
170
+ from the bar_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
171
+ to it for context if desired.""",
172
+ "parameters": {
173
+ "type": "object",
174
+ "properties": {
175
+ "data": {
176
+ "type": "array",
177
+ "description": """The array containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
178
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.
179
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
180
+ Infer this from the user's message.""",
181
+ "items": {
182
+ "type": "string",
183
+ }
184
+ },
185
+ "x_column": {
186
+ "type": "string",
187
+ "description": f"""The column in our query.csv file that contains the x values of the graph.""",
188
+ "items": {
189
+ "type": "string",
190
+ }
191
+ },
192
+ "y_column": {
193
+ "type": "string",
194
+ "description": f"""The column in our query.csv file that contains the y values of the graph.""",
195
+ "items": {
196
+ "type": "string",
197
+ }
198
+ },
199
+ "category": {
200
+ "type": "string",
201
+ "description": f"""An optional column in our query.csv file that contains a parameter that will define the category for the data.""",
202
+ "items": {
203
+ "type": "string",
204
+ }
205
+ },
206
+ "facet_row": {
207
+ "type": "string",
208
+ "description": f"""An optional column in our query.csv file that contains a parameter that will define a faceted subplot, where different rows
209
+ correspond to different values of the query specified in this parameter.""",
210
+ "items": {
211
+ "type": "string",
212
+ }
213
+ },
214
+ "facet_col": {
215
+ "type": "string",
216
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the faceted column, corresponding to
217
+ different values of our query specified in this parameter.""",
218
+ "items": {
219
+ "type": "string",
220
+ }
221
+ },
222
+ "layout": {
223
+ "type": "array",
224
+ "description": """An array containing a dictionary that contains the 'layout' portion of the plotly chart generation.
225
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
226
+ "items": {
227
+ "type": "string",
228
+ }
229
+ }
230
+ },
231
+ "required": ["x_column","y_column","layout"],
232
+ },
233
+ },
234
+ },
235
+ {
236
+ "type": "function",
237
+ "function": {
238
+ "name": "pie_chart_generation_func",
239
+ "description": f"""This is a pie chart generation tool useful to generate pie charts from queried data from our data source that we are querying.
240
+ The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
241
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
242
+ from the pie_chart_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
243
+ to it for context if desired.""",
244
+ "parameters": {
245
+ "type": "object",
246
+ "properties": {
247
+ "data": {
248
+ "type": "array",
249
+ "description": """The array containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
250
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.
251
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
252
+ Infer this from the user's message.""",
253
+ "items": {
254
+ "type": "string",
255
+ }
256
+ },
257
+ "values": {
258
+ "type": "string",
259
+ "description": f"""The column in our query.csv file that contain the values of the pie chart.""",
260
+ "items": {
261
+ "type": "string",
262
+ }
263
+ },
264
+ "names": {
265
+ "type": "string",
266
+ "description": f"""The column in our query.csv file that contain the label or section of each piece of the pie graph and allow us to know what each piece of the pie chart represents.""",
267
+ "items": {
268
+ "type": "string",
269
+ }
270
+ },
271
+ "layout": {
272
+ "type": "array",
273
+ "description": """An array containing a dictionary that contains the 'layout' portion of the plotly chart generation.
274
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
275
+ "items": {
276
+ "type": "string",
277
+ }
278
+ }
279
+ },
280
+ "required": ["values","names","layout"],
281
+ },
282
+ },
283
+ },
284
+ {
285
+ "type": "function",
286
+ "function": {
287
+ "name": "histogram_generation_func",
288
+ "description": f"""This is a histogram generation tool useful to generate histograms from queried data from our data source that we are querying.
289
+ The data values will come from the columns of our query.csv (the 'values' and 'names' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
290
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
291
+ from the histogram_generation_func function in any way and always display the iframe fully to the user in the chat window. You can add your own text supplementary
292
+ to it for context if desired.""",
293
+ "parameters": {
294
+ "type": "object",
295
+ "properties": {
296
+ "data": {
297
+ "type": "array",
298
+ "description": """The array containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
299
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.
300
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
301
+ Infer this from the user's message.""",
302
+ "items": {
303
+ "type": "string",
304
+ }
305
+ },
306
+ "x_column": {
307
+ "type": "string",
308
+ "description": f"""The column in our query.csv file that contains the x values of the histogram.
309
+ This would correspond to the counts that would be distributed in the histogram.""",
310
+ "items": {
311
+ "type": "string",
312
+ }
313
+ },
314
+ "y_column": {
315
+ "type": "string",
316
+ "description": f"""An optional column in our query.csv file that contains the y values of the histogram.""",
317
+ "items": {
318
+ "type": "string",
319
+ }
320
+ },
321
+ "histnorm": {
322
+ "type": "string",
323
+ "description": f"""An optional argument to specify the type of normalization if the default isn't used.
324
+ This histnorm value can be one of ['percent','probability','density','probability density'].
325
+ Do not send any values outside of this array as the function will fail.""",
326
+ "items": {
327
+ "type": "string",
328
+ }
329
+ },
330
+ "category": {
331
+ "type": "string",
332
+ "description": f"""An optional column in our query.csv file that contains a parameter that will define the category for the data.""",
333
+ "items": {
334
+ "type": "string",
335
+ }
336
+ },
337
+ "histfunc": {
338
+ "type": "string",
339
+ "description": f"""An optional value that represents the function of data to compute the function which is used on the optional y column.
340
+ This histfunc value can be one of ['avg','sum','count'].
341
+ Do not send any values outside of this array as the function will fail.""",
342
+ "items": {
343
+ "type": "string",
344
+ }
345
+ },
346
+ "layout": {
347
+ "type": "array",
348
+ "description": """An array containing a dictionary that contains the 'layout' portion of the plotly chart generation.
349
+ The array must contain a json formatted dictionary with outer brackets included, any other format will not work.""",
350
+ "items": {
351
+ "type": "string",
352
+ }
353
+ }
354
+ },
355
+ "required": ["x_column"],
356
+ },
357
+ },
358
+ },
359
+ {
360
+ "type": "function",
361
+ "function": {
362
+ "name": "table_generation_func",
363
+ "description": f"""This an table generation tool useful to format data as a table from queried data from our data source that we are querying.
364
+ Takes no parameters as it uses data queried in our query.csv file to build the table.
365
+ Call this function after running our SQLite query and generating query.csv.
366
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the iframe string returned
367
+ from the table_generation_func function in any way and always display the iframe fully to the user in the chat window.""",
368
+ "parameters": {},
369
+ },
370
+ }
371
+ ]
tools/stats_tools.py CHANGED
@@ -1,130 +1,44 @@
1
- stats_tool_schemas = [
2
- {
3
- "name": "descriptive_stats_func",
4
- "description": (
5
- "Computes summary statistics for numeric columns in query.csv: "
6
- "count, mean, std, min, 25th/50th/75th percentile, and max. "
7
- "Use when the user asks for summary statistics, descriptive statistics, or a statistical overview. "
8
- "Returns a formatted HTML table."
9
- ),
10
- "parameters": {
11
- "type": "object",
12
- "properties": {
13
- "columns": {
14
- "type": "array",
15
- "description": "Optional list of column names to include. If omitted, all numeric columns from query.csv are used. Avoid ID or index columns.",
16
- "items": {"type": "string"},
17
- },
18
- },
19
- "required": [],
20
- },
21
- },
22
- {
23
- "name": "kmeans_clustering_func",
24
- "description": (
25
- "Runs K-Means clustering on numeric feature columns from query.csv. "
26
- "Groups rows into k clusters, displays a scatter plot coloured by cluster assignment, "
27
- "and returns a centroid summary table showing the mean of each feature per cluster. "
28
- "Use when the user asks to cluster the data, find natural segments or groups, or apply K-Means. "
29
- "Returns an HTML iframe and summary table."
30
- ),
31
- "parameters": {
32
- "type": "object",
33
- "properties": {
34
- "feature_columns": {
35
- "type": "array",
36
- "description": "List of numeric column names from query.csv to use as clustering features.",
37
- "items": {"type": "string"},
38
- },
39
- "x_column": {
40
- "type": "string",
41
- "description": "Column name from query.csv for the x-axis of the scatter plot. Usually one of the feature columns.",
42
- },
43
- "y_column": {
44
- "type": "string",
45
- "description": "Column name from query.csv for the y-axis of the scatter plot. Usually one of the feature columns.",
46
- },
47
- "n_clusters": {
48
- "type": "integer",
49
- "description": "Number of clusters (k). Default 3. Infer from the user's request.",
50
- },
51
- "layout": {
52
- "type": "array",
53
- "description": "Optional. An array containing a single JSON-formatted Plotly layout dictionary.",
54
- "items": {"type": "string"},
55
- },
56
- },
57
- "required": ["feature_columns", "x_column", "y_column"],
58
- },
59
- },
60
- {
61
- "name": "hypothesis_test_func",
62
- "description": (
63
- "Performs a statistical hypothesis test on query.csv data and returns a formatted results table "
64
- "with test statistic, p-value, and significance at α=0.05. "
65
- "Supported tests:\n"
66
- "- 't_test_independent': compare means of a numeric column across two groups "
67
- "(requires group_column; use group_values if the column has more than 2 unique values).\n"
68
- "- 't_test_one_sample': test whether a column's mean equals a hypothesized value (requires pop_mean).\n"
69
- "- 'chi_square': test independence between two categorical columns (requires column and column2)."
70
- ),
71
- "parameters": {
72
- "type": "object",
73
- "properties": {
74
- "test_type": {
75
- "type": "string",
76
- "description": "Test to run. One of: 't_test_independent', 't_test_one_sample', 'chi_square'.",
77
- },
78
- "column": {
79
- "type": "string",
80
- "description": "Primary column for the test. Numeric for t-tests; first categorical column for chi-square.",
81
- },
82
- "column2": {
83
- "type": "string",
84
- "description": "Second categorical column. Required for 'chi_square'.",
85
- },
86
- "group_column": {
87
- "type": "string",
88
- "description": "Grouping column. Required for 't_test_independent'. Must have exactly 2 unique values, or specify group_values.",
89
- },
90
- "group_values": {
91
- "type": "array",
92
- "description": "Exactly 2 group labels to compare. Use when group_column has more than 2 unique values.",
93
- "items": {"type": "string"},
94
- },
95
- "pop_mean": {
96
- "type": "number",
97
- "description": "Hypothesized population mean (μ₀). Required for 't_test_one_sample'.",
98
- },
99
- },
100
- "required": ["test_type", "column"],
101
- },
102
- },
103
- {
104
- "name": "regression_func",
105
- "description": (
106
- "Runs an OLS linear regression on query.csv data. "
107
- "Use when the user wants to model the relationship between variables, assess predictors, or run a regression. "
108
- "Returns a regression summary (coefficients, R², p-values) and a scatter plot with the fitted line as an HTML iframe."
109
- ),
110
- "parameters": {
111
- "type": "object",
112
- "properties": {
113
- "independent_variables": {
114
- "type": "array",
115
- "description": "Column names from query.csv to use as independent (predictor) variables.",
116
- "items": {"type": "string"},
117
- },
118
- "dependent_variable": {
119
- "type": "string",
120
- "description": "Column name from query.csv to use as the dependent (outcome) variable.",
121
- },
122
- "category": {
123
- "type": "string",
124
- "description": "Optional column name used to colour-code points and fit separate regression lines per group.",
125
- },
126
- },
127
- "required": ["independent_variables", "dependent_variable"],
128
- },
129
- },
130
- ]
 
1
+ stats_tools = [
2
+ {
3
+ "type": "function",
4
+ "function": {
5
+ "name": "regression_func",
6
+ "description": f"""This a tool to calculate regressions on our data source that we are querying.
7
+ We can run queries with our 'sql_query_func' function and they will be available to use in this function via the query.csv file that is generated.
8
+ Returns a dictionary of values that includes a regression_summary and a regression chart (which is an iframe displaying the
9
+ linear regression in chart form and should be shown to the user).""",
10
+ "parameters": {
11
+ "type": "object",
12
+ "properties": {
13
+ "independent_variables": {
14
+ "type": "array",
15
+ "description": f"""An array of strings that states the independent variables in our data set which should be column names in our query.csv file that is generated
16
+ in the 'sql_query_func' function. This will allow us to identify the data to use for our independent variables.
17
+ Infer this from the user's message.""",
18
+ "items": {
19
+ "type": "string",
20
+ }
21
+ },
22
+ "dependent_variable": {
23
+ "type": "string",
24
+ "description": f"""A string that states the dependent variables in our data set which should be a column name in our query.csv file that is generated
25
+ in the 'sql_query_func' function. This will allow us to identify the data to use for our dependent variables.
26
+ Infer this from the user's message.""",
27
+ "items": {
28
+ "type": "string",
29
+ }
30
+ },
31
+ "category": {
32
+ "type": "string",
33
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the category for the data.
34
+ Do not send value if no category is needed or specified. This category must be present in our query.csv file to be valid.""",
35
+ "items": {
36
+ "type": "string",
37
+ }
38
+ }
39
+ },
40
+ "required": ["independent_variables","dependent_variable"],
41
+ },
42
+ },
43
+ }
44
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/tools.py CHANGED
@@ -1,130 +1,149 @@
1
- from .stats_tools import stats_tool_schemas
2
- from .chart_tools import chart_tool_schemas
3
-
4
- def tools_call(session_hash, data_source, titles):
5
- from haystack.tools import Tool
6
-
7
- _noop = lambda **kwargs: None
8
-
9
- def make_tool(schema):
10
- return Tool(
11
- name=schema["name"],
12
- description=schema["description"],
13
- parameters=schema["parameters"],
14
- function=_noop,
15
- )
16
-
17
- titles_string = (titles[:625] + '..') if len(titles) > 625 else titles
18
-
19
- query_tool_schemas = {
20
- 'file_upload': {
21
- "name": "query_func",
22
- "description": f"""This is a tool useful to query a SQLite table called 'data_source' with the following Columns: {titles_string}.
23
- There may also be more columns in the table if the number of columns is too large to process.
24
- This function also saves the results of the query to csv file called query.csv.""",
25
- "parameters": {
26
- "type": "object",
27
- "properties": {
28
- "queries": {
29
- "type": "string",
30
- "description": "The query to use in the search. Infer this from the user's message. It should be a question or a statement."
31
- }
32
- },
33
- "required": ["queries"]
34
- },
35
- },
36
- 'sql': {
37
- "name": "query_func",
38
- "description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {titles_string}.
39
- There may also be more tables in the database if the number of tables is too large to process.
40
- This function also saves the results of the query to csv file called query.csv.""",
41
- "parameters": {
42
- "type": "object",
43
- "properties": {
44
- "queries": {
45
- "type": "string",
46
- "description": "The PostgreSQL query to use in the search. Infer this from the user's message. It should be a question or a statement."
47
- }
48
- },
49
- "required": ["queries"]
50
- },
51
- },
52
- 'doc_db': {
53
- "name": "query_func",
54
- "description": f"""This is a tool useful to build an aggregation pipeline to query a MongoDB NoSQL document database with the following collections, {titles_string}.
55
- There may also be more collections in the database if the number of collections is too large to process.
56
- This function also saves the results of the query to a csv file called query.csv.""",
57
- "parameters": {
58
- "type": "object",
59
- "properties": {
60
- "queries": {
61
- "type": "string",
62
- "description": "The MongoDB aggregation pipeline to use in the search. Infer this from the user's message. It should be a question or a statement."
63
- },
64
- "db_collection": {
65
- "type": "string",
66
- "description": "The MongoDB collection to use in the search. Infer this from the user's message. It should be a question or a statement."
67
- }
68
- },
69
- "required": ["queries", "db_collection"]
70
- },
71
- },
72
- 'graphql': [
73
- {
74
- "name": "query_func",
75
- "description": f"""This is a tool useful to build a GraphQL query for a GraphQL API endpoint with the following types, {titles_string}.
76
- There may also be more types in the GraphQL endpoint if the number of types is too large to process.
77
- This function also saves the results of the query to a csv file called query.csv.""",
78
- "parameters": {
79
- "type": "object",
80
- "properties": {
81
- "queries": {
82
- "type": "string",
83
- "description": "The GraphQL query to use in the search. Infer this from the user's message. It should be a question or a statement."
84
- }
85
- },
86
- "required": ["queries"]
87
- },
88
- },
89
- {
90
- "name": "graphql_schema_query",
91
- "description": f"""This is a tool useful to query a GraphQL type and receive back information about its schema. This is useful because
92
- the GraphQL introspection query is too large to be ingested all at once and this allows us to query the schema one type at a time to
93
- view it in manageable bites. You may realize after viewing the schema, that the type you selected was not appropriate for the question
94
- you are attempting answer. You may then query additional types to find the appropriate types to use for your GraphQL API query.""",
95
- "parameters": {
96
- "type": "object",
97
- "properties": {
98
- "graphql_type": {
99
- "type": "string",
100
- "description": "The GraphQL type that we want to view the schema of in order to make the proper query with our graphql_query_func. Infer this from the user's message. It should be a question or a statement."
101
- }
102
- },
103
- "required": ["graphql_type"]
104
- },
105
- },
106
- {
107
- "name": "graphql_csv_query",
108
- "description": f"""This is a tool useful to SQL query our query.csv file that is generated from our GraphQL query. This is useful in a situation
109
- where the results of the GraphQL query need additional querying to answer the user question. The query.csv file is converted to a Pandas dataframe
110
- and we query that dataframe with SQL on a table called 'query' before converting it back to a csv file.""",
111
- "parameters": {
112
- "type": "object",
113
- "properties": {
114
- "csv_query": {
115
- "type": "string",
116
- "description": "The pandas dataframe SQL query to use in the search. The table that we query is named 'query'. Infer this from the user's message. It should be a question or a statement."
117
- }
118
- },
119
- "required": ["csv_query"]
120
- },
121
- },
122
- ]
123
- }
124
-
125
- source_schemas = query_tool_schemas[data_source]
126
- source_tools = [make_tool(s) for s in (source_schemas if isinstance(source_schemas, list) else [source_schemas])]
127
- chart_tools = [make_tool(s) for s in chart_tool_schemas]
128
- stats_tools = [make_tool(s) for s in stats_tool_schemas]
129
-
130
- return source_tools + chart_tools + stats_tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .stats_tools import stats_tools
2
+ from .chart_tools import chart_tools
3
+
4
+ def tools_call(session_hash, data_source, titles):
5
+
6
+ titles_string = (titles[:625] + '..') if len(titles) > 625 else titles
7
+
8
+ tools_calls = {
9
+ 'file_upload' : [
10
+ {
11
+ "type": "function",
12
+ "function": {
13
+ "name": "sqlite_query_func",
14
+ "description": f"""This is a tool useful to query a SQLite table called 'data_source' with the following Columns: {titles_string}.
15
+ There may also be more columns in the table if the number of columns is too large to process.
16
+ This function also saves the results of the query to csv file called query.csv.""",
17
+ "parameters": {
18
+ "type": "object",
19
+ "properties": {
20
+ "queries": {
21
+ "type": "array",
22
+ "description": "The query to use in the search. Infer this from the user's message. It should be a question or a statement",
23
+ "items": {
24
+ "type": "string",
25
+ }
26
+ }
27
+ },
28
+ "required": ["queries"],
29
+ },
30
+ },
31
+ },
32
+ ],
33
+ 'sql' : [
34
+ {
35
+ "type": "function",
36
+ "function": {
37
+ "name": "sql_query_func",
38
+ "description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {titles_string}.
39
+ There may also be more tables in the database if the number of tables is too large to process.
40
+ This function also saves the results of the query to csv file called query.csv.""",
41
+ "parameters": {
42
+ "type": "object",
43
+ "properties": {
44
+ "queries": {
45
+ "type": "array",
46
+ "description": "The PostgreSQL query to use in the search. Infer this from the user's message. It should be a question or a statement",
47
+ "items": {
48
+ "type": "string",
49
+ }
50
+ }
51
+ },
52
+ "required": ["queries"],
53
+ },
54
+ },
55
+ },
56
+ ],
57
+ 'doc_db' : [
58
+ {
59
+ "type": "function",
60
+ "function": {
61
+ "name": "doc_db_query_func",
62
+ "description": f"""This is a tool useful to build an aggregation pipeline to query a MongoDB NoSQL document database with the following collections, {titles_string}.
63
+ There may also be more collections in the database if the number of tables is too large to process.
64
+ This function also saves the results of the query to a csv file called query.csv.""",
65
+ "parameters": {
66
+ "type": "object",
67
+ "properties": {
68
+ "aggregation_pipeline": {
69
+ "type": "string",
70
+ "description": "The MongoDB aggregation pipeline to use in the search. Infer this from the user's message. It should be a question or a statement."
71
+ },
72
+ "db_collection": {
73
+ "type": "string",
74
+ "description": "The MongoDB collection to use in the search. Infer this from the user's message. It should be a question or a statement.",
75
+ }
76
+ },
77
+ "required": ["aggregation_pipeline","db_collection"],
78
+ },
79
+ },
80
+ },
81
+ ],
82
+ 'graphql' : [
83
+ {
84
+ "type": "function",
85
+ "function": {
86
+ "name": "graphql_query_func",
87
+ "description": f"""This is a tool useful to build a GraphQL query for a GraphQL API endpoint with the following types, {titles_string}.
88
+ There may also be more types in the GraphQL endpoint if the number of types is too large to process.
89
+ This function also saves the results of the query to a csv file called query.csv.""",
90
+ "parameters": {
91
+ "type": "object",
92
+ "properties": {
93
+ "graphql_query": {
94
+ "type": "string",
95
+ "description": "The GraphQL query to use in the search. Infer this from the user's message. It should be a question or a statement."
96
+ }
97
+ },
98
+ "required": ["graphql_query"],
99
+ },
100
+ },
101
+ },
102
+ {
103
+ "type": "function",
104
+ "function": {
105
+ "name": "graphql_schema_query",
106
+ "description": f"""This is a tool useful to query a GraphQL type and receive back information about its schema. This is useful because
107
+ the GraphQL introspection query is too large to be ingested all at once and this allows us to query the schema one type at a time to
108
+ view it in manageable bites. You may realize after viewing the schema, that the type you selected was not appropriate for the question
109
+ you are attempting answer. You may then query additional types to find the appropriate types to use for your GraphQL API query.""",
110
+ "parameters": {
111
+ "type": "object",
112
+ "properties": {
113
+ "graphql_type": {
114
+ "type": "string",
115
+ "description": "The GraphQL type that we want to view the schema of in order to make the proper query with our graphql_query_func. Infer this from the user's message. It should be a question or a statement."
116
+ }
117
+ },
118
+ "required": ["graphql_type"],
119
+ },
120
+ },
121
+ },
122
+ {
123
+ "type": "function",
124
+ "function": {
125
+ "name": "graphql_csv_query",
126
+ "description": f"""This is a tool useful to SQL query our query.csv file that is generated from our GraphQL query. This is useful in a situation
127
+ where the results of the GraphQL query need additional querying to answer the user question. The query.csv file is converted to a Pandas dataframe
128
+ and we query that dataframe with SQL on a table called 'query' before converting it back to a csv file.""",
129
+ "parameters": {
130
+ "type": "object",
131
+ "properties": {
132
+ "csv_query": {
133
+ "type": "string",
134
+ "description": "The pandas dataframe SQL query to use in the search. The table that we query is named 'query'. Infer this from the user's message. It should be a question or a statement"
135
+ }
136
+ },
137
+ "required": ["csv_query"],
138
+ },
139
+ },
140
+ },
141
+ ]
142
+ }
143
+
144
+ tools = tools_calls[data_source]
145
+
146
+ tools.extend(chart_tools)
147
+ tools.extend(stats_tools)
148
+
149
+ return tools
utils.py CHANGED
@@ -4,6 +4,4 @@ current_dir = Path(__file__).parent
4
 
5
  TEMP_DIR = current_dir / 'temp'
6
 
7
- message_dict = {}
8
- api_key_store = {}
9
- model_store = {}
 
4
 
5
  TEMP_DIR = current_dir / 'temp'
6
 
7
+ message_dict = {}