Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,9 +12,6 @@ from dotenv import load_dotenv
|
|
| 12 |
load_dotenv()
|
| 13 |
userdata = os.environ
|
| 14 |
|
| 15 |
-
DATA_DIR = Path(os.getcwd()) / "data"
|
| 16 |
-
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 17 |
-
|
| 18 |
|
| 19 |
def chat_with_groq(client:groq.Groq,
|
| 20 |
prompt:str,
|
|
@@ -49,19 +46,29 @@ def chat_with_groq(client:groq.Groq,
|
|
| 49 |
# logger.info(f"Completion: {completion}")
|
| 50 |
return completion.choices[0].message.content
|
| 51 |
|
| 52 |
-
def execute_duckdb_query(query:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
try:
|
| 54 |
conn = duckdb.connect(database=":memory:", read_only=False)
|
| 55 |
-
|
| 56 |
-
# Load all CSV files from the data directory
|
| 57 |
-
for csv_file in DATA_DIR.glob("*.csv"):
|
| 58 |
-
table_name = csv_file.stem
|
| 59 |
-
conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM read_csv_auto('{csv_file}')")
|
| 60 |
-
|
| 61 |
query_result = conn.execute(query).fetch_df().reset_index()
|
|
|
|
| 62 |
return query_result
|
| 63 |
except Exception as e:
|
| 64 |
-
print(f"Error
|
|
|
|
| 65 |
raise e
|
| 66 |
def get_summarization(client:groq.Groq,
|
| 67 |
use_question:str,
|
|
@@ -251,6 +258,10 @@ base_prompt = """
|
|
| 251 |
* Ensure that the entire output is returned on only one single line
|
| 252 |
* Keep your query as simple and straightforward as possible; do not use subqueries
|
| 253 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
user_question = """"""
|
| 255 |
|
| 256 |
# And some rules for querying the dataset:
|
|
@@ -260,32 +271,39 @@ user_question = """"""
|
|
| 260 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
| 261 |
|
| 262 |
|
| 263 |
-
def upload_file(files) ->
|
|
|
|
| 264 |
model = "llama3-8b-8192"
|
| 265 |
-
api_key:
|
| 266 |
-
|
| 267 |
-
|
|
|
|
| 268 |
files = [files]
|
| 269 |
-
|
| 270 |
-
uploaded_files = []
|
| 271 |
stored_table_descriptions = []
|
| 272 |
-
|
| 273 |
for file in files:
|
| 274 |
filename = Path(file.name).name
|
| 275 |
-
path =
|
| 276 |
|
| 277 |
# Copy the content of the temporary file to our destination
|
| 278 |
-
|
|
|
|
| 279 |
|
| 280 |
-
|
| 281 |
-
table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path),
|
| 282 |
-
desc =
|
| 283 |
stored_table_descriptions.append(desc)
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
def user_prompt_sanitization(user_prompt:str)->str:
|
| 291 |
guide = """
|
|
@@ -308,52 +326,56 @@ def user_prompt_sanitization(user_prompt:str)->str:
|
|
| 308 |
client = groq.Groq(api_key=api_key)
|
| 309 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
| 310 |
|
| 311 |
-
def queryModel(user_prompt:
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
try:
|
| 331 |
-
|
| 332 |
-
"type": "json_object"
|
| 333 |
-
})
|
| 334 |
except Exception as e:
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
|
|
|
|
|
|
| 357 |
|
| 358 |
with gr.Blocks() as demo:
|
| 359 |
gr.Markdown("# CSV Database Query Interface")
|
|
@@ -363,28 +385,16 @@ with gr.Blocks() as demo:
|
|
| 363 |
upload_button = gr.Button("Load CSV Files")
|
| 364 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
| 365 |
|
| 366 |
-
|
| 367 |
-
result = upload_file(files)
|
| 368 |
-
uploaded_files.extend(result["files"])
|
| 369 |
-
return result["descriptions"], uploaded_files
|
| 370 |
-
|
| 371 |
-
upload_button.click(handle_upload,
|
| 372 |
-
inputs=[file_output, uploaded_files],
|
| 373 |
-
outputs=[upload_output, uploaded_files])
|
| 374 |
-
|
| 375 |
with gr.Tab("Query Interface"):
|
| 376 |
chatbot = gr.Chatbot()
|
| 377 |
with gr.Row():
|
| 378 |
user_input = gr.Textbox(label="Enter your question")
|
| 379 |
submit_button = gr.Button("Submit")
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
return queryModel(user_prompt, files)
|
| 383 |
|
| 384 |
-
submit_button.click(query_model_with_files,
|
| 385 |
-
inputs=[user_input, uploaded_files],
|
| 386 |
-
outputs=chatbot)
|
| 387 |
|
| 388 |
-
demo.launch()
|
| 389 |
|
| 390 |
|
|
|
|
| 12 |
load_dotenv()
|
| 13 |
userdata = os.environ
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def chat_with_groq(client:groq.Groq,
|
| 17 |
prompt:str,
|
|
|
|
| 46 |
# logger.info(f"Completion: {completion}")
|
| 47 |
return completion.choices[0].message.content
|
| 48 |
|
| 49 |
+
def execute_duckdb_query(query:str)->pd.DataFrame:
|
| 50 |
+
"""
|
| 51 |
+
Execute a DuckDB query and return the result as a pandas DataFrame.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
query (str): The DuckDB query to execute.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
pd.DataFrame: The result of the query as a pandas DataFrame.
|
| 58 |
+
"""
|
| 59 |
+
original_cwd = os.getcwd()
|
| 60 |
+
print(f"PATH:{original_cwd}")
|
| 61 |
+
os.chdir('data')
|
| 62 |
+
print(f"PATH:{os.getcwd()}")
|
| 63 |
+
|
| 64 |
try:
|
| 65 |
conn = duckdb.connect(database=":memory:", read_only=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
query_result = conn.execute(query).fetch_df().reset_index()
|
| 67 |
+
os.chdir(original_cwd)
|
| 68 |
return query_result
|
| 69 |
except Exception as e:
|
| 70 |
+
print(f"Error: {e}")
|
| 71 |
+
os.chdir(original_cwd)
|
| 72 |
raise e
|
| 73 |
def get_summarization(client:groq.Groq,
|
| 74 |
use_question:str,
|
|
|
|
| 258 |
* Ensure that the entire output is returned on only one single line
|
| 259 |
* Keep your query as simple and straightforward as possible; do not use subqueries
|
| 260 |
"""
|
| 261 |
+
table_description = """"""
|
| 262 |
+
tables_string = """"""
|
| 263 |
+
table_1 = """"""
|
| 264 |
+
table_1_wt_xt = """"""
|
| 265 |
user_question = """"""
|
| 266 |
|
| 267 |
# And some rules for querying the dataset:
|
|
|
|
| 271 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
| 272 |
|
| 273 |
|
| 274 |
+
def upload_file(files) -> List[str]:
|
| 275 |
+
# will have to change to the private system is initiializes
|
| 276 |
model = "llama3-8b-8192"
|
| 277 |
+
api_key:str=userdata.get("GROQ_API_KEY")
|
| 278 |
+
data_dir = Path("data")
|
| 279 |
+
data_dir.mkdir(parents=True, exist_ok=True)
|
| 280 |
+
if type(files) == str:
|
| 281 |
files = [files]
|
| 282 |
+
stored_paths = []
|
|
|
|
| 283 |
stored_table_descriptions = []
|
| 284 |
+
tables = []
|
| 285 |
for file in files:
|
| 286 |
filename = Path(file.name).name
|
| 287 |
+
path = data_dir / filename
|
| 288 |
|
| 289 |
# Copy the content of the temporary file to our destination
|
| 290 |
+
with open(file.name, "rb") as source, open(path, "wb") as destination:
|
| 291 |
+
destination.write(source.read())
|
| 292 |
|
| 293 |
+
stored_paths.append(str(path.absolute()))
|
| 294 |
+
table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path),api_key,model)
|
| 295 |
+
desc = "Table: " + filename + "\n Columns:\n" + table_description
|
| 296 |
stored_table_descriptions.append(desc)
|
| 297 |
+
tables.append(filename)
|
| 298 |
+
# constructing a string
|
| 299 |
+
tables_string = join_with_and(tables)
|
| 300 |
+
final = "\n".join(stored_table_descriptions)
|
| 301 |
+
table_1_wt_xt = tables[0].split('.')[0]
|
| 302 |
+
table_description = final
|
| 303 |
+
tables_string = tables_string
|
| 304 |
+
table_1 = tables[0]
|
| 305 |
+
table_1_wt_xt = table_1_wt_xt
|
| 306 |
+
return final
|
| 307 |
|
| 308 |
def user_prompt_sanitization(user_prompt:str)->str:
|
| 309 |
guide = """
|
|
|
|
| 326 |
client = groq.Groq(api_key=api_key)
|
| 327 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
| 328 |
|
| 329 |
+
def queryModel(user_prompt:str,model:str = "llama3-70b-8192",api_key:str=userdata.get("GROQ_API_KEY")):
|
| 330 |
+
client = groq.Groq(api_key=api_key)
|
| 331 |
+
user_prompt = user_prompt_sanitization(user_prompt)
|
| 332 |
+
print(user_prompt)
|
| 333 |
+
full_prompt = base_prompt.format(
|
| 334 |
+
user_question=user_prompt,
|
| 335 |
+
table_description=table_description,
|
| 336 |
+
tables=tables_string,
|
| 337 |
+
table_1=table_1,
|
| 338 |
+
table_1_wt_xt=table_1_wt_xt
|
| 339 |
+
)
|
| 340 |
+
try:
|
| 341 |
+
response = chat_with_groq(client,full_prompt,model,{
|
| 342 |
+
"type":"json_object"
|
| 343 |
+
})
|
| 344 |
+
except Exception as e:
|
| 345 |
+
return [(
|
| 346 |
+
"Groq Advisor",
|
| 347 |
+
"Error: " + str(e)
|
| 348 |
+
)]
|
| 349 |
+
response = json.loads(response)
|
| 350 |
+
if "sql" in response:
|
| 351 |
+
sql_query = response["sql"]
|
| 352 |
try:
|
| 353 |
+
results_df = execute_duckdb_query(sql_query)
|
|
|
|
|
|
|
| 354 |
except Exception as e:
|
| 355 |
+
return [(
|
| 356 |
+
"Groq Advisor",
|
| 357 |
+
"Error: " + str(e)
|
| 358 |
+
)]
|
| 359 |
+
|
| 360 |
+
fotmatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
|
| 361 |
+
query_n_results = "SQL Query: " + fotmatted_sql_query + "\n\n" + results_df.to_markdown()
|
| 362 |
+
summarization = get_summarization(client,user_prompt,results_df,model)
|
| 363 |
+
query_n_results += "\n\n" + summarization
|
| 364 |
+
|
| 365 |
+
return [(
|
| 366 |
+
"Groq Advisor",
|
| 367 |
+
query_n_results
|
| 368 |
+
)]
|
| 369 |
+
elif "error" in response:
|
| 370 |
+
return [(
|
| 371 |
+
"Groq Advisor",
|
| 372 |
+
"Error: " + response["error"]
|
| 373 |
+
)]
|
| 374 |
+
else:
|
| 375 |
+
return [(
|
| 376 |
+
"Groq Advisor",
|
| 377 |
+
"Error: Unknown error"
|
| 378 |
+
)]
|
| 379 |
|
| 380 |
with gr.Blocks() as demo:
|
| 381 |
gr.Markdown("# CSV Database Query Interface")
|
|
|
|
| 385 |
upload_button = gr.Button("Load CSV Files")
|
| 386 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
| 387 |
|
| 388 |
+
upload_button.click(upload_file, inputs=file_output, outputs=upload_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
with gr.Tab("Query Interface"):
|
| 390 |
chatbot = gr.Chatbot()
|
| 391 |
with gr.Row():
|
| 392 |
user_input = gr.Textbox(label="Enter your question")
|
| 393 |
submit_button = gr.Button("Submit")
|
| 394 |
+
submit_button.click(queryModel, inputs=[user_input], outputs=chatbot)
|
| 395 |
+
|
|
|
|
| 396 |
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
+
demo.launch(share=True)
|
| 399 |
|
| 400 |
|