Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .DS_Store +0 -0
- .gitignore +3 -0
- DB_SYS_PROMPT.txt +244 -0
- README.md +3 -9
- app.py +219 -0
- chat_helpers.py +44 -0
- locustfile.py +81 -0
- logger.py +43 -0
- output.log +5 -0
- server.py +126 -0
- sql_tab.py +112 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.gradio/
|
| 2 |
+
user_data
|
| 3 |
+
__pycache__
|
DB_SYS_PROMPT.txt
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DB_SYSTEM_PROMPT = (
|
| 2 |
+
"You are a helpful assistant that answers questions about a PostgreSQL database. "
|
| 3 |
+
"When you use search, include brief citations as links. "
|
| 4 |
+
"Use markdowns to separate the code and the text in your output. "
|
| 5 |
+
"Do not use the WITH clause in your SQL code. "
|
| 6 |
+
"When asked about 'Find the top 10 directors who use the word 'the' the most in their movie titles', assume that the user is looking for a partial match. Remind the user that this is partial match. "
|
| 7 |
+
"Here is the database:"
|
| 8 |
+
"""CREATE TABLE public.sales (
|
| 9 |
+
year integer,
|
| 10 |
+
release_date text,
|
| 11 |
+
title text,
|
| 12 |
+
genre text,
|
| 13 |
+
international_box_office text,
|
| 14 |
+
domestic_box_office text,
|
| 15 |
+
worldwide_box_office text,
|
| 16 |
+
production_budget text,
|
| 17 |
+
opening_weekend text,
|
| 18 |
+
theatre_count integer,
|
| 19 |
+
avg_run_per_theatre text,
|
| 20 |
+
runtime integer,
|
| 21 |
+
keywords text,
|
| 22 |
+
creative_type text,
|
| 23 |
+
url text
|
| 24 |
+
);
|
| 25 |
+
|
| 26 |
+
CREATE TABLE public.metadata (
|
| 27 |
+
|
| 28 |
+
url text,
|
| 29 |
+
title text,
|
| 30 |
+
studio text,
|
| 31 |
+
rating text,
|
| 32 |
+
runtime integer,
|
| 33 |
+
casting text,
|
| 34 |
+
director text,
|
| 35 |
+
genre text,
|
| 36 |
+
summary text,
|
| 37 |
+
awards text,
|
| 38 |
+
metascore integer,
|
| 39 |
+
userscore text,
|
| 40 |
+
RelDate text
|
| 41 |
+
);
|
| 42 |
+
|
| 43 |
+
CREATE TABLE public.user_reviews (
|
| 44 |
+
url text,
|
| 45 |
+
idvscore smallint,
|
| 46 |
+
reviewer text,
|
| 47 |
+
datep date,
|
| 48 |
+
thumbs_up integer,
|
| 49 |
+
thumbs_tot integer,
|
| 50 |
+
wc integer,
|
| 51 |
+
analytic double precision,
|
| 52 |
+
clout double precision,
|
| 53 |
+
authentic double precision,
|
| 54 |
+
tone double precision,
|
| 55 |
+
wps double precision,
|
| 56 |
+
sixltr double precision,
|
| 57 |
+
dic double precision,
|
| 58 |
+
function_ double precision,
|
| 59 |
+
pronoun double precision,
|
| 60 |
+
ppron double precision,
|
| 61 |
+
i double precision,
|
| 62 |
+
we double precision,
|
| 63 |
+
you double precision,
|
| 64 |
+
shehe double precision,
|
| 65 |
+
they double precision,
|
| 66 |
+
ipron double precision,
|
| 67 |
+
article double precision,
|
| 68 |
+
prep double precision,
|
| 69 |
+
auxverb double precision,
|
| 70 |
+
adverb double precision,
|
| 71 |
+
conj double precision,
|
| 72 |
+
negate double precision,
|
| 73 |
+
verb double precision,
|
| 74 |
+
adj double precision,
|
| 75 |
+
compare double precision,
|
| 76 |
+
interrog double precision,
|
| 77 |
+
number double precision,
|
| 78 |
+
quant double precision,
|
| 79 |
+
affect double precision,
|
| 80 |
+
posemo double precision,
|
| 81 |
+
negemo double precision,
|
| 82 |
+
anx double precision,
|
| 83 |
+
anger double precision,
|
| 84 |
+
sad double precision,
|
| 85 |
+
social double precision,
|
| 86 |
+
family double precision,
|
| 87 |
+
friend double precision,
|
| 88 |
+
female double precision,
|
| 89 |
+
male double precision,
|
| 90 |
+
cogproc double precision,
|
| 91 |
+
insight double precision,
|
| 92 |
+
cause double precision,
|
| 93 |
+
discrep double precision,
|
| 94 |
+
tentat double precision,
|
| 95 |
+
certain double precision,
|
| 96 |
+
differ double precision,
|
| 97 |
+
percept double precision,
|
| 98 |
+
see double precision,
|
| 99 |
+
hear double precision,
|
| 100 |
+
feel double precision,
|
| 101 |
+
bio double precision,
|
| 102 |
+
body double precision,
|
| 103 |
+
health double precision,
|
| 104 |
+
sexual double precision,
|
| 105 |
+
ingest double precision,
|
| 106 |
+
drives double precision,
|
| 107 |
+
affiliation double precision,
|
| 108 |
+
achieve double precision,
|
| 109 |
+
power double precision,
|
| 110 |
+
reward double precision,
|
| 111 |
+
risk double precision,
|
| 112 |
+
focuspast double precision,
|
| 113 |
+
focuspresent double precision,
|
| 114 |
+
focusfuture double precision,
|
| 115 |
+
relativ double precision,
|
| 116 |
+
motion double precision,
|
| 117 |
+
space double precision,
|
| 118 |
+
time double precision,
|
| 119 |
+
work double precision,
|
| 120 |
+
leisure double precision,
|
| 121 |
+
home double precision,
|
| 122 |
+
money double precision,
|
| 123 |
+
relig double precision,
|
| 124 |
+
death double precision,
|
| 125 |
+
informal double precision,
|
| 126 |
+
swear double precision,
|
| 127 |
+
netspeak double precision,
|
| 128 |
+
assent double precision,
|
| 129 |
+
nonflu double precision,
|
| 130 |
+
filler double precision,
|
| 131 |
+
allpunc double precision,
|
| 132 |
+
period double precision,
|
| 133 |
+
comma double precision,
|
| 134 |
+
colon double precision,
|
| 135 |
+
semic double precision,
|
| 136 |
+
qmark double precision,
|
| 137 |
+
exclam double precision,
|
| 138 |
+
dash double precision,
|
| 139 |
+
quote double precision,
|
| 140 |
+
apostro double precision,
|
| 141 |
+
parenth double precision,
|
| 142 |
+
otherp double precision
|
| 143 |
+
);
|
| 144 |
+
|
| 145 |
+
CREATE TABLE public.expert_reviews (
|
| 146 |
+
url text,
|
| 147 |
+
idvscore smallint,
|
| 148 |
+
reviewer text,
|
| 149 |
+
datep date,
|
| 150 |
+
wc integer,
|
| 151 |
+
analytic double precision,
|
| 152 |
+
clout double precision,
|
| 153 |
+
authentic double precision,
|
| 154 |
+
tone double precision,
|
| 155 |
+
wps double precision,
|
| 156 |
+
sixltr double precision,
|
| 157 |
+
dic double precision,
|
| 158 |
+
function_ double precision,
|
| 159 |
+
pronoun double precision,
|
| 160 |
+
ppron double precision,
|
| 161 |
+
i double precision,
|
| 162 |
+
we double precision,
|
| 163 |
+
you double precision,
|
| 164 |
+
shehe double precision,
|
| 165 |
+
they double precision,
|
| 166 |
+
ipron double precision,
|
| 167 |
+
article double precision,
|
| 168 |
+
prep double precision,
|
| 169 |
+
auxverb double precision,
|
| 170 |
+
adverb double precision,
|
| 171 |
+
conj double precision,
|
| 172 |
+
negate double precision,
|
| 173 |
+
verb double precision,
|
| 174 |
+
adj double precision,
|
| 175 |
+
compare double precision,
|
| 176 |
+
interrog double precision,
|
| 177 |
+
number double precision,
|
| 178 |
+
quant double precision,
|
| 179 |
+
affect double precision,
|
| 180 |
+
posemo double precision,
|
| 181 |
+
negemo double precision,
|
| 182 |
+
anx double precision,
|
| 183 |
+
anger double precision,
|
| 184 |
+
sad double precision,
|
| 185 |
+
social double precision,
|
| 186 |
+
family double precision,
|
| 187 |
+
friend double precision,
|
| 188 |
+
female double precision,
|
| 189 |
+
male double precision,
|
| 190 |
+
cogproc double precision,
|
| 191 |
+
insight double precision,
|
| 192 |
+
cause double precision,
|
| 193 |
+
discrep double precision,
|
| 194 |
+
tentat double precision,
|
| 195 |
+
certain double precision,
|
| 196 |
+
differ double precision,
|
| 197 |
+
percept double precision,
|
| 198 |
+
see double precision,
|
| 199 |
+
hear double precision,
|
| 200 |
+
feel double precision,
|
| 201 |
+
bio double precision,
|
| 202 |
+
body double precision,
|
| 203 |
+
health double precision,
|
| 204 |
+
sexual double precision,
|
| 205 |
+
ingest double precision,
|
| 206 |
+
drives double precision,
|
| 207 |
+
affiliation double precision,
|
| 208 |
+
achieve double precision,
|
| 209 |
+
power double precision,
|
| 210 |
+
reward double precision,
|
| 211 |
+
risk double precision,
|
| 212 |
+
focuspast double precision,
|
| 213 |
+
focuspresent double precision,
|
| 214 |
+
focusfuture double precision,
|
| 215 |
+
relativ double precision,
|
| 216 |
+
motion double precision,
|
| 217 |
+
space double precision,
|
| 218 |
+
time double precision,
|
| 219 |
+
work double precision,
|
| 220 |
+
leisure double precision,
|
| 221 |
+
home double precision,
|
| 222 |
+
money double precision,
|
| 223 |
+
relig double precision,
|
| 224 |
+
death double precision,
|
| 225 |
+
informal double precision,
|
| 226 |
+
swear double precision,
|
| 227 |
+
netspeak double precision,
|
| 228 |
+
assent double precision,
|
| 229 |
+
nonflu double precision,
|
| 230 |
+
filler double precision,
|
| 231 |
+
allpunc double precision,
|
| 232 |
+
period double precision,
|
| 233 |
+
comma double precision,
|
| 234 |
+
colon double precision,
|
| 235 |
+
semic double precision,
|
| 236 |
+
qmark double precision,
|
| 237 |
+
exclam double precision,
|
| 238 |
+
dash double precision,
|
| 239 |
+
quote double precision,
|
| 240 |
+
apostro double precision,
|
| 241 |
+
parenth double precision,
|
| 242 |
+
otherp double precision
|
| 243 |
+
);"""
|
| 244 |
+
)
|
README.md
CHANGED
|
@@ -1,12 +1,6 @@
|
|
| 1 |
---
|
| 2 |
-
title: SQL
|
| 3 |
-
emoji: 🐠
|
| 4 |
-
colorFrom: indigo
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.45.0
|
| 8 |
app_file: app.py
|
| 9 |
-
|
|
|
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SQL-Assignment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
app_file: app.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 5.44.1
|
| 6 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from openai import AsyncOpenAI
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import uuid, asyncio
|
| 6 |
+
|
| 7 |
+
from sql_tab import run_sql
|
| 8 |
+
from logger import log_event
|
| 9 |
+
from chat_helpers import build_input_from_history, get_db_sys_prompt
|
| 10 |
+
|
| 11 |
+
oclient = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 12 |
+
max_rows = 100
|
| 13 |
+
|
| 14 |
+
async def respond_once(message, history):
    """Return a single, non-streaming assistant reply.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[dict]
        Prior chat turns as ``{"role", "content"}`` dicts (Gradio "messages" format).

    Returns
    -------
    str
        The model's full output text (empty string if the response carries none).
    """
    # MOCK mode to isolate app/DB without burning tokens.  Checked first so a
    # mock run never builds the prompt input (which re-reads DB_SYS_PROMPT.txt)
    # and never touches the OpenAI client.
    if os.getenv("MOCK_OPENAI", "").lower() in {"1", "true", "yes"}:
        import random  # local import: only needed on the mock path
        await asyncio.sleep(random.uniform(0.05, 0.25))  # simulate model latency
        return "MOCK: Here’s a fabricated answer for load testing."

    text_input = build_input_from_history(message, history)
    kwargs = dict(
        model="gpt-4.1",
        input=text_input,
        temperature=0,  # deterministic answers for a teaching assistant
        instructions=get_db_sys_prompt(),
        tools=[{"type": "web_search"}],
        tool_choice="auto",
        parallel_tool_calls=True,
    )
    resp = await oclient.responses.create(**kwargs)
    return getattr(resp, "output_text", "")
|
| 34 |
+
|
| 35 |
+
async def respond(message, history):
    """Stream an assistant reply, yielding the accumulated text after each delta.

    Yields progressively longer strings as output-text deltas arrive; after the
    stream closes, re-yields the final response text if it differs from (or was
    missed by) the streamed deltas.
    """
    request = {
        "model": "gpt-4.1",
        "input": build_input_from_history(message, history),
        "temperature": 0,
        "instructions": get_db_sys_prompt(),
        "tools": [{"type": "web_search"}],
        "tool_choice": "auto",
        "parallel_tool_calls": True,
    }

    chunks = []
    async with oclient.responses.stream(**request) as stream:
        async for event in stream:
            if event.type != "response.output_text.delta":
                continue
            chunks.append(event.delta)
            yield "".join(chunks)

        # Safety net: emit the final text when streaming produced nothing or
        # the concatenated deltas disagree with the server's final output.
        final = await stream.get_final_response()
        final_text = getattr(final, "output_text", None)
        if final_text and (not chunks or final_text != "".join(chunks)):
            yield final_text
|
| 58 |
+
|
| 59 |
+
async def chat_driver(user_message, messages_history, _user_name, _session_id):
    """Drive one chat turn for the Gradio UI.

    Yields ``(updated_history, cleared_textbox_value)`` pairs while the
    assistant reply streams in.  User and assistant messages are logged
    fire-and-forget so the stream is never blocked on disk I/O.
    """
    prior = messages_history or []
    shown = prior + [{"role": "user", "content": user_message}]
    reply = ""

    # Log the user's message without awaiting it.
    asyncio.create_task(log_event(_user_name, _session_id, "chat_user", {"text": user_message}))

    async for partial in respond(user_message, prior):
        reply = partial
        # Stream the partial reply to the UI; second output clears the textbox.
        yield shown + [{"role": "assistant", "content": reply}], ""

    # After the stream finished, log the final assistant text.
    asyncio.create_task(log_event(_user_name, _session_id, "chat_assistant", {"text": reply}))
|
| 73 |
+
|
| 74 |
+
async def post_completion_code(_user_name, _session_id):
    """Post the survey completion code into the chat and log that it was shown.

    Returns a fresh one-message history (replaces the chatbot contents).
    """
    code = "9C1F4B2E"
    updated = [{"role": "assistant", "content": f"the completion code is {code}"}]
    # Logged awaited (not fire-and-forget) so the code request is always recorded.
    await log_event(_user_name, _session_id, "completion_code", {"code": code})
    return updated
|
| 81 |
+
|
| 82 |
+
# UI layout: a login gate (identify_view) that swaps to the main app (app_view)
# holding three tabs — Assignment (static instructions), Chatbot, and SQL.
with gr.Blocks(title="Movie Database", theme="soft") as demo:
    # gr.Markdown("## Movie Database Bot and SQL Console")
    # Per-browser-session state; filled in by do_login.
    user_name = gr.State("")
    session_id = gr.State("")

    with gr.Column(visible=True) as identify_view:
        gr.Markdown("### Login")
        name_tb = gr.Textbox(label="Student ID (required)", placeholder="Please enter your student ID", autofocus=True)
        enter_btn = gr.Button("Enter", variant="primary")
        id_msg = gr.Markdown("")

        async def do_login(name):
            # Validate the ID, mint a session id, log the login, and flip the
            # visible view.  Outputs: (identify_view, app_view, id_msg,
            # user_name, session_id).
            name = (name or "").strip()
            if not name:
                return (gr.update(visible=True), gr.update(visible=False), "⚠️ Please enter your student ID to continue.", "", "")
            sid = uuid.uuid4().hex
            await log_event(name, sid, "login", {"meta": {"agent": "gradio_app", "version": 1}})
            return (gr.update(visible=False), gr.update(visible=True), "", name, sid)

    with gr.Column(visible=False) as app_view:

        welcome_md = gr.Markdown("")
        with gr.Tabs():
            with gr.Tab("Assignment"):
                gr.Markdown("""
<h2> Platform Usage and the Assignment </h2>
<br>
<ul>
<li> You can use the "SQL" tab to run your queries and see if you have the correct results.</li>
<li> The "Chatbot" tab provides you a chatbot (that is connected to ChatGPT) to ask questions about PostgreSQL and the database.</li>
<li> The chatbot knows the tables and their columns, and would help with questions.</li>
<li> Even with its knowledge, the chatbot can still make mistakes.</li>
<li> When you are finished with all questions, the survey platform will ask for a completion code. You can find it in the "Chatbot" tab. </li>
<li> <b> Reminder: </b> This assignment is optional and ungraded. It is designed for you to practice. You can be relaxed, it is okay to have errors. Good luck! </li>
</ul>
<h3> Database </h3>
The database has 4 tables, each corresponding to the 4 excel files you have for the project:
<ul>
<li>sales</li>
<li>metadata</li>
<li>user_reviews</li>
<li>expert_reviews</li>
</ul>
<br>
<b> Important Notes: </b>
<br>
<br>
<ul>
<li> A proper ERD or foreign key relationships are not defined for the tables. You can still join them based on the column names, but be careful. </li>
<li> Some movies have the same title but they are different movies. </li>
<li> A column that stores numerical information might have the datatype "text". </li>
<li> Datatypes might not be exactly the same as the excel files. </li>
<li> Some columns might store null values as text, like "n/a" or "null". </li>
<li> Columns with the same names might store different values in different tables. Example: "url" column in metadata and sales.</li>
</ul>
""")

            with gr.Tab("Chatbot"):
                chatbot = gr.Chatbot(type="messages", label="Conversation", height=450)

                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="How can I help you with PostgreSQL today?",
                        scale=8,
                        autofocus=True,
                        container=False,
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)
                    code_btn = gr.Button("Completion code", variant="secondary", scale=1)

                def _clear_input():
                    # Clears the textbox after a turn completes.
                    return ""

                # Both the Send button and pressing Enter route through
                # chat_driver, then clear the input box.
                ev = send_btn.click(chat_driver, [chat_input, chatbot, user_name, session_id], [chatbot, chat_input])
                ev.then(_clear_input, None, [chat_input])

                ev2 = chat_input.submit(chat_driver, [chat_input, chatbot, user_name, session_id], [chatbot, chat_input])
                ev2.then(_clear_input, None, [chat_input])

                code_btn.click(
                    post_completion_code,
                    inputs=[user_name, session_id],
                    outputs=[chatbot],
                )

            with gr.Tab("SQL"):
                with gr.Column():
                    sql_input = gr.Code(
                        label="SQL",
                        language="sql",
                        value="SELECT * FROM sales;",
                        lines=10,
                    )
                    with gr.Row():
                        run_btn = gr.Button("Run", variant="primary")
                        clear_btn = gr.Button("Clear")

                    results = gr.Dataframe(
                        label="Results",
                        wrap=True,
                        interactive=True,
                    )
                    meta = gr.Markdown("")
                    plan = gr.Markdown("", label="Explain/Plan")

                    async def on_run(q, _user_name, _session_id):
                        # run_sql is synchronous; run it in a worker thread so
                        # the event loop keeps serving other users.
                        df, meta_msg, _ = await asyncio.to_thread(run_sql, q, max_rows, False)

                        await log_event(
                            _user_name, _session_id, "sql",
                            {
                                "query": q,
                                "row_limit": max_rows,
                                "row_count": int(getattr(df, "shape", [0])[0]),
                                "meta": meta_msg,
                            },
                        )
                        return df, meta_msg, ""

                    def on_clear():
                        # Reset editor, results grid, status line, and plan.
                        return "", pd.DataFrame(), "Cleared.", ""

                    run_btn.click(on_run, [sql_input, user_name, session_id], [results, meta, plan])
                    clear_btn.click(on_clear, inputs=None, outputs=[sql_input, results, meta, plan])

    # Login wiring: button click and Enter in the textbox behave the same.
    outputs = [identify_view, app_view, id_msg, user_name, session_id]
    enter_btn.click(do_login, inputs=[name_tb], outputs=outputs)
    name_tb.submit(do_login, inputs=[name_tb], outputs=outputs)

    def greet(name):
        # Welcome banner shown once user_name is set by do_login.
        return f"**Hello, {name}!**"
    user_name.change(greet, inputs=[user_name], outputs=[welcome_md])


if __name__ == "__main__":
    demo.launch(share=True)
|
chat_helpers.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tiktoken
|
| 2 |
+
|
| 3 |
+
MAX_TOKENS = 16000
|
| 4 |
+
|
| 5 |
+
SQL_SYSTEM_PROMPT = (
|
| 6 |
+
"You are a helpful assistant that answers questions about PostgreSQL databases. "
|
| 7 |
+
"When you use search, include brief citations as links. "
|
| 8 |
+
"Use markdowns to separate the code and the text in your output. "
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
def get_db_sys_prompt(path="DB_SYS_PROMPT.txt"):
    """Return the database system prompt text.

    Re-read on every call so edits to the prompt file take effect without
    restarting the app.

    Parameters
    ----------
    path : str
        Prompt file to read (default: ``"DB_SYS_PROMPT.txt"``).
    """
    # Explicit encoding: the prompt may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
|
| 14 |
+
|
| 15 |
+
def build_input_from_history(message, history):
    """Assemble the model input: system prompt, prior turns, then the new message.

    The assembled list is truncated to MAX_TOKENS (oldest non-system turns
    dropped first) before being returned.
    """
    convo = [{"role": "system", "content": get_db_sys_prompt()}]
    # Prior turns, in order; roles other than user/assistant are dropped.
    convo.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in history
        if turn["role"] in ("user", "assistant")
    )
    # The new message always goes last.
    convo.append({"role": "user", "content": message})
    return truncate_history(convo, MAX_TOKENS)
|
| 29 |
+
|
| 30 |
+
def count_tokens(messages, model="gpt-4.1"):
    """Return the total tiktoken token count over all message contents."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the GPT-4-family encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    return sum(len(encoding.encode(msg["content"])) for msg in messages)
|
| 40 |
+
|
| 41 |
+
def truncate_history(messages, max_tokens=MAX_TOKENS, model="gpt-4.1"):
    """Drop the oldest non-system turns until the list fits within max_tokens.

    Index 0 (the system prompt) and the last remaining message are never
    removed.  The list is mutated in place and also returned.
    """
    while len(messages) > 2 and count_tokens(messages, model=model) > max_tokens:
        # Remove the oldest turn that sits just after the system prompt.
        del messages[1]
    return messages
|
locustfile.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# locustfile.py
|
| 2 |
+
import json, random, string, time
|
| 3 |
+
from locust import HttpUser, task, between
|
| 4 |
+
|
| 5 |
+
CHAT_PATH = "/e2e/chat"
|
| 6 |
+
SQL_PATH = "/e2e/sql"
|
| 7 |
+
|
| 8 |
+
SIMPLE_SQL = [
|
| 9 |
+
"SELECT * FROM sales",
|
| 10 |
+
"SELECT genre, COUNT(*) AS n FROM sales GROUP BY genre ORDER BY n DESC",
|
| 11 |
+
"SELECT title, runtime FROM sales WHERE runtime > 150 ORDER BY runtime DESC",
|
| 12 |
+
"SELECT AVG(metascore) FROM metadata",
|
| 13 |
+
"SELECT s.title, m.studio FROM sales s JOIN metadata m ON s.title=m.title LIMIT 200",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
def rnd_name():
    """Return a random user id like ``"U-3F9K2A"`` (6 uppercase alphanumerics).

    Uses the module-level ``random``/``string`` imports; the original
    re-imported both locally on every call, which was redundant.
    """
    alphabet = string.ascii_uppercase + string.digits
    return "U-" + "".join(random.choices(alphabet, k=6))
|
| 19 |
+
|
| 20 |
+
class GradioDBUser(HttpUser):
    """Locust user that mixes chat requests (weight 3) and SQL requests (weight 2)."""

    wait_time = between(0.05, 0.25)

    @task(3)
    def chat(self):
        """POST a random question to the chat endpoint and validate the reply."""
        question = random.choice([
            "Which table has the highest number of rows and why?",
            "Write a SQL to list top-5 genres by revenue.",
            "Explain how to compute ROI = worldwide_box_office / production_budget.",
            "What's the average runtime by studio?",
        ])
        body = {
            "message": question,
            "history": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello!"},
            ],
        }

        started = time.perf_counter()
        with self.client.post(
            CHAT_PATH,
            data=json.dumps(body),
            headers={"Content-Type": "application/json"},
            name="chat",
            catch_response=True,
        ) as resp:
            elapsed_ms = (time.perf_counter() - started) * 1000
            if resp.status_code != 200:
                resp.failure(f"HTTP {resp.status_code}: {resp.text[:200]}")
            else:
                # A non-JSON body counts as a failure even on HTTP 200.
                try:
                    _ = resp.json().get("output", "")
                    resp.success()
                except Exception as e:
                    resp.failure(f"Bad JSON after {elapsed_ms:.1f}ms: {e}")

    @task(2)
    def sql(self):
        """POST a random read-only query to the SQL endpoint and sanity-check the shape."""
        body = {"query": random.choice(SIMPLE_SQL), "limit": 200, "allow_writes": False}

        started = time.perf_counter()
        with self.client.post(
            SQL_PATH,
            data=json.dumps(body),
            headers={"Content-Type": "application/json"},
            name="sql",
            catch_response=True,
        ) as resp:
            elapsed_ms = (time.perf_counter() - started) * 1000
            if resp.status_code != 200:
                resp.failure(f"HTTP {resp.status_code}: {resp.text[:200]}")
            else:
                try:
                    payload = resp.json()
                    # Shape check so malformed responses are marked failures.
                    if "rows" in payload and isinstance(payload["rows"], list):
                        resp.success()
                    else:
                        resp.failure(f"Unexpected JSON shape after {elapsed_ms:.1f}ms: {payload}")
                except Exception as e:
                    resp.failure(f"Bad JSON after {elapsed_ms:.1f}ms: {e}")
|
logger.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json, uuid, pathlib, asyncio, datetime, re
|
| 3 |
+
|
| 4 |
+
DATA_DIR = pathlib.Path(os.getenv("APP_DATA_DIR", "./user_data"))
|
| 5 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 6 |
+
|
| 7 |
+
_name_re = re.compile(r"[^A-Za-z0-9._-]+")
|
| 8 |
+
|
| 9 |
+
def _slugify(name: str) -> str:
|
| 10 |
+
name = (name or "").strip().lower()
|
| 11 |
+
name = _name_re.sub("_", name)
|
| 12 |
+
return name or f"anon_{uuid.uuid4().hex[:8]}"
|
| 13 |
+
|
| 14 |
+
def _user_log_path(name: str) -> pathlib.Path:
    """Return the per-user JSONL log file under DATA_DIR, keyed by slugified name."""
    filename = f"{_slugify(name)}.jsonl"
    return DATA_DIR / filename
|
| 16 |
+
|
| 17 |
+
def _utc_now():
|
| 18 |
+
# ISO 8601 with 'Z'
|
| 19 |
+
return datetime.datetime.utcnow().replace(microsecond=0).isoformat() + "Z"
|
| 20 |
+
|
| 21 |
+
async def _append_jsonl(path: pathlib.Path, obj: dict):
|
| 22 |
+
"""
|
| 23 |
+
Append 1 line of JSON to the user's file without blocking the event loop.
|
| 24 |
+
"""
|
| 25 |
+
line = json.dumps(obj, ensure_ascii=False)
|
| 26 |
+
def _write():
|
| 27 |
+
with path.open("a", encoding="utf-8") as f:
|
| 28 |
+
f.write(line + "\n")
|
| 29 |
+
await asyncio.to_thread(_write)
|
| 30 |
+
|
| 31 |
+
async def log_event(user_name: str, session_id: str, kind: str, payload: dict):
    """Write one structured event line to the user's JSONL log.

    kind: "login" | "chat_user" | "chat_assistant" | "sql"
    payload: arbitrary event fields; timestamp and user/session ids are added here.
    """
    record = {
        "ts": _utc_now(),
        "user": user_name,
        "session_id": session_id,
        "kind": kind,
    }
    # Merged last, so a payload key could deliberately override a base field.
    record.update(payload)
    await _append_jsonl(_user_log_path(user_name), record)
|
output.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
* Running on local URL: http://127.0.0.1:7860
|
| 3 |
+
* Running on public URL: https://e026dc1db28b1a0014.gradio.live
|
| 4 |
+
|
| 5 |
+
This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
|
server.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import uvicorn
|
| 5 |
+
|
| 6 |
+
from gradio_app import demo, respond_once
|
| 7 |
+
from sql_tab import run_sql
|
| 8 |
+
|
| 9 |
+
import math, uuid, decimal, datetime as dt
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from fastapi.responses import ORJSONResponse
|
| 13 |
+
|
| 14 |
+
import traceback, sys, logging
|
| 15 |
+
log = logging.getLogger("uvicorn.error")
|
| 16 |
+
|
| 17 |
+
app = FastAPI(default_response_class=ORJSONResponse)
|
| 18 |
+
|
| 19 |
+
def df_json_safe(df: pd.DataFrame) -> list[dict]:
    """Convert a DataFrame into JSON-serializable records.

    ±Inf and NaN become None, numpy scalars become native Python numbers,
    datetimes become ISO 8601 strings, and Postgres oddities (Decimal,
    bytes, UUID) are mapped to JSON-friendly equivalents.  str/dict/list/None
    pass through untouched.
    """
    # Normalize the frame first: Infs -> NaN, object dtype so None can live
    # in numeric columns, then NaN -> None.
    frame = df.replace([np.inf, -np.inf], np.nan)
    frame = frame.astype(object)
    frame = frame.where(pd.notnull(frame), None)

    def _coerce(value):
        """Map one cell to a plain-JSON value."""
        # --- numbers ---
        if isinstance(value, decimal.Decimal):
            # Convert to float; fall back to None if conversion misbehaves.
            try:
                as_float = float(value)
            except Exception:
                return None
            return None if math.isnan(as_float) or math.isinf(as_float) else as_float
        if isinstance(value, np.floating):
            as_float = float(value)
            return None if math.isnan(as_float) or math.isinf(as_float) else as_float
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.bool_):
            return bool(value)

        # --- datetimes / timedeltas ---
        if isinstance(value, (pd.Timestamp, np.datetime64, dt.datetime, dt.date, dt.time)):
            try:
                return pd.to_datetime(value).isoformat()  # ensure ISO 8601
            except Exception:
                return str(value)
        if isinstance(value, (pd.Timedelta, dt.timedelta)):
            return str(value)

        # --- misc types you can get from Postgres ---
        if isinstance(value, (bytes, bytearray, memoryview)):
            try:
                return bytes(value).decode("utf-8", "replace")
            except Exception:
                return str(value)
        if isinstance(value, uuid.UUID):
            return str(value)

        return value

    return [
        {column: _coerce(cell) for column, cell in row.items()}
        for row in frame.to_dict(orient="records")
    ]
|
| 72 |
+
|
| 73 |
+
class ChatReq(BaseModel):
    """Request body for POST /e2e/chat."""
    # The user's new message for this turn.
    message: str
    # Prior conversation turns passed through to respond_once.
    # NOTE(review): exact dict shape (e.g. {"role": ..., "content": ...})
    # is defined by gradio_app.respond_once — confirm against that module.
    history: list[dict] = []
|
| 76 |
+
|
| 77 |
+
class SqlReq(BaseModel):
    """Request body for POST /e2e/sql."""
    # Raw SQL text; run_sql rejects multi-statement input.
    query: str
    # Row cap passed to run_sql's LIMIT enforcement.
    limit: int = 200
    # Writes are blocked by default; must be explicitly enabled per request.
    allow_writes: bool = False
|
| 81 |
+
|
| 82 |
+
@app.get("/healthz")
def healthz():
    """Liveness probe: always returns {"ok": True} with no side effects."""
    return {"ok": True}
|
| 85 |
+
|
| 86 |
+
@app.post("/e2e/chat")
async def e2e_chat(req: ChatReq):
    """End-to-end chat endpoint: delegate one turn to the Gradio app's model.

    Awaits respond_once (imported from gradio_app) and wraps its text
    reply in a small JSON envelope.
    """
    text = await respond_once(req.message, req.history)
    return {"output": text}
|
| 90 |
+
|
| 91 |
+
@app.post("/e2e/sql")
def e2e_sql(req: SqlReq):
    """Run one SQL statement and return up to 200 JSON-safe rows.

    Delegates execution to run_sql (which enforces the write guard and
    LIMIT), trims the result to 200 rows, and serializes via orjson.
    Re-raises on failure so FastAPI returns a 500 after logging.
    """
    head = None  # kept so the error path below can dump the last DataFrame
    try:
        df, meta, elapsed = run_sql(req.query, req.limit, req.allow_writes)

        # Take only the head for safety; head() already clamps to len(df).
        head = df.head(200)

        # Raw preview before cleaning — debug-level, not error: this runs on
        # every request and to_string() on 200 rows is not free.
        log.debug("DEBUG DF (raw):\n%s", head.to_string())

        rows = df_json_safe(head)
        payload = {
            "meta": str(meta),
            # NaN/inf elapsed would break JSON serialization -> None.
            "elapsed": float(elapsed) if math.isfinite(elapsed) else None,
            "n": int(len(df)),
            "rows": rows,
        }

        return ORJSONResponse(payload, headers={"X-Serializer": "orjson"})
    except Exception:
        # log.exception records the message AND the full traceback in one
        # call (replaces the old log.error + traceback.print_exc pair).
        log.exception("Exception in %s", __file__)
        if head is not None:
            try:
                log.error("Last DF snapshot:\n%s", head.to_string())
            except Exception:
                pass
        raise
|
| 120 |
+
|
| 121 |
+
# Mount the Gradio UI at "/" on the same FastAPI app that serves /e2e/*.
mounted = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Run with multiple workers for concurrency in real tests (see section D).
    # NOTE(review): uvicorn.run() is given the app *object*, so a workers=N
    # argument would not work here — multiple workers require an import
    # string like "server:mounted". Confirm how section D launches this.
    uvicorn.run(mounted, host="0.0.0.0", port=7860)
|
sql_tab.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import time
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import psycopg2.extras as extras
|
| 7 |
+
from psycopg2.pool import SimpleConnectionPool
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Postgres connection settings, overridable via standard PG* env vars.
DB_NAME = os.getenv("PGDATABASE", "mert")
DB_USER = os.getenv("PGUSER", "mert")
DB_PASS = os.getenv("POSTGRES_PASSWORD")  # no default: None when unset
DB_HOST = os.getenv("PGHOST", "127.0.0.1")
DB_PORT = int(os.getenv("PGPORT", "5432"))
POOL_MAX = int(os.getenv("PG_POOL_MAX", "10"))

# Process-wide connection pool, created lazily by _get_pool().
_pool: SimpleConnectionPool | None = None
|
| 18 |
+
|
| 19 |
+
def _get_pool():
    """Return the shared connection pool, creating it on first use."""
    global _pool
    if _pool is not None:
        return _pool
    # NOTE(review): this lazy init is not lock-protected; two threads racing
    # through the first call could each build a pool — confirm startup is
    # effectively single-threaded before relying on this.
    _pool = SimpleConnectionPool(
        minconn=1,
        maxconn=POOL_MAX,
        database=DB_NAME,
        user=DB_USER,
        password=DB_PASS,
        host=DB_HOST,
        port=DB_PORT,
    )
    return _pool
|
| 28 |
+
|
| 29 |
+
def _borrow_conn():
    """Check a connection out of the pool with autocommit enabled.

    Also sets a 10s statement_timeout so a runaway query cannot hold the
    connection indefinitely.
    """
    pool = _get_pool()
    conn = pool.getconn()
    # Autocommit: each statement commits immediately; no explicit BEGIN.
    conn.autocommit = True
    try:
        with conn.cursor() as cur:
            # NOTE(review): under autocommit this SET persists for the whole
            # session (the pooled connection), not just one statement —
            # confirm that is intended for reused connections.
            cur.execute("SET statement_timeout = 10000;")  # 10s
    except Exception:
        # Best effort: failing to set the timeout shouldn't block borrowing.
        pass
    return conn
|
| 39 |
+
|
| 40 |
+
def _return_conn(conn):
    """Hand *conn* back to the pool; if that fails, close it outright."""
    try:
        _get_pool().putconn(conn)
        return
    except Exception:
        pass
    # Pool return failed (e.g. pool closed) — don't leak the connection.
    try:
        conn.close()
    except Exception:
        pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Keywords that, as the first word of a statement, indicate a write/DDL op.
WRITE_FIRST_KEYWORDS = {
    "INSERT","UPDATE","DELETE","DROP","ALTER","CREATE","TRUNCATE",
    "VACUUM","REINDEX","GRANT","REVOKE","MERGE","CALL","DO",
    "ATTACH","DETACH"
}

# Leading whitespace, `-- line` comments, and `/* block */` comments that may
# precede the first real keyword of a statement.
_LEADING_COMMENTS = re.compile(r"^(?:\s+|--[^\n]*(?:\n|$)|/\*.*?\*/)+", flags=re.DOTALL)

def is_write_query(sql: str) -> bool:
    """
    Returns True if the first *statement* is a write. Ignores function names like REPLACE().
    Handles WITH ... (SELECT ...) vs WITH ... INSERT/UPDATE/DELETE/MERGE ...
    Leading SQL comments are stripped first, so a comment-prefixed DROP/DELETE
    is still classified as a write.
    """
    # Fix: previously a query starting with a comment (e.g. "-- x\nDROP ...")
    # failed the first-keyword regex and silently bypassed the write guard.
    stripped = _LEADING_COMMENTS.sub("", sql.strip(), count=1)
    first_stmt = re.split(r";\s*", stripped.strip(), maxsplit=1)[0]

    # If it starts with WITH, decide based on the main statement following the CTEs.
    if re.match(r"^\s*WITH\b", first_stmt, flags=re.IGNORECASE):
        # Heuristic: a write keyword after the CTE block's closing paren => write.
        return bool(re.search(r"\)\s*(INSERT|UPDATE|DELETE|MERGE)\b", first_stmt, flags=re.IGNORECASE))

    # Otherwise, classify by the very first keyword.
    m = re.match(r"^\s*([A-Za-z]+)", first_stmt)
    first_kw = m.group(1).upper() if m else ""
    return first_kw in WRITE_FIRST_KEYWORDS
|
| 70 |
+
|
| 71 |
+
def enforce_limit(sql: str, limit: int) -> str:
    """
    Append a LIMIT clause when the statement is a read (SELECT/WITH) and
    no LIMIT already appears anywhere in it (naive check, but practical).
    Trailing semicolons are stripped either way.
    """
    stmt = sql.strip().strip(";")
    is_read = re.match(r"^(SELECT|WITH)\b", stmt, flags=re.IGNORECASE) is not None
    has_limit = re.search(r"\bLIMIT\b", stmt, flags=re.IGNORECASE) is not None
    if is_read and not has_limit:
        stmt = f"{stmt} LIMIT {int(limit)}"
    return stmt
|
| 81 |
+
|
| 82 |
+
def run_sql(query: str, max_rows: int, allow_writes: bool):
    """Execute one SQL statement against the pooled Postgres connection.

    Returns a (DataFrame, meta_message, elapsed_seconds) triple. On any
    validation failure or database error the DataFrame is empty, the
    message explains why, and elapsed is 0.0.
    """

    # Guard: empty / whitespace-only input.
    if not query or not query.strip():
        return pd.DataFrame(), "Provide a SQL query.", 0.0

    # Guard: a ";" remaining after trimming trailing ones means multiple
    # statements were supplied.
    if ";" in query.strip().rstrip(";"):
        return pd.DataFrame(), "Multiple statements detected; please run one at a time.", 0.0

    # Guard: block writes unless the caller explicitly opted in.
    if not allow_writes and is_write_query(query):
        return pd.DataFrame(), "Write operations are disabled. Enable the toggle to allow writes.", 0.0

    sql_to_run = enforce_limit(query, max_rows)
    started = time.perf_counter()
    conn = None
    try:
        conn = _borrow_conn()
        with conn.cursor(cursor_factory=extras.RealDictCursor) as cur:
            # NOTE(review): the borrowed connection has autocommit=True, so
            # SET LOCAL (transaction-scoped) likely has no effect here; the
            # session-level timeout from _borrow_conn is what applies. Confirm.
            cur.execute("SET LOCAL statement_timeout = 10000;")
            cur.execute(sql_to_run)
            # Statements without a result set (cur.description is None)
            # yield an empty DataFrame.
            rows = cur.fetchall() if cur.description else []
            df = pd.DataFrame(rows)
    except Exception as e:
        return pd.DataFrame(), f"Error: {e}", 0.0
    finally:
        # Always return the connection to the pool, even on error.
        if conn: _return_conn(conn)

    elapsed = time.perf_counter() - started
    meta = f"Rows: {len(df)} | Time: {elapsed:.3f}s"
    # Scrub values that don't serialize to JSON: inf -> NA -> None.
    df.replace([np.inf, -np.inf], pd.NA, inplace=True)
    df = df.where(pd.notnull(df), None)
    return df, meta, elapsed
|