improve query functions

#40
by nolanzandi - opened
functions/query_functions.py CHANGED
@@ -81,8 +81,8 @@ class PostgreSQLQuery:
81
 
82
 
83
 
84
- def sql_query_func(queries: List[str], session_hash, args, **kwargs):
85
- sql_query = PostgreSQLQuery(args[0], args[1], args[2], args[3], args[4])
86
  try:
87
  result = sql_query.run(queries, session_hash)
88
  print("RESULT")
@@ -150,8 +150,8 @@ class DocDBQuery:
150
 
151
 
152
 
153
- def doc_db_query_func(aggregation_pipeline: List[str], db_collection: AnyStr, session_hash, args, **kwargs):
154
- doc_db_query = DocDBQuery(args[0], args[1])
155
  try:
156
  result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
157
  print("RESULT")
@@ -206,10 +206,10 @@ class GraphQLQuery:
206
 
207
 
208
 
209
- def graphql_query_func(graphql_query: AnyStr, session_hash, args, **kwargs):
210
  graphql_object = GraphQLQuery()
211
  try:
212
- result = graphql_object.run(graphql_query, args[0], args[1], args[2], session_hash)
213
  print("RESULT")
214
  if len(result["results"][0]) > 1000:
215
  print("QUERY TOO LARGE")
 
81
 
82
 
83
 
84
+ def sql_query_func(queries: List[str], session_hash, db_url, db_port, db_user, db_pass, db_name, **kwargs):
85
+ sql_query = PostgreSQLQuery(db_url, db_port, db_user, db_pass, db_name)
86
  try:
87
  result = sql_query.run(queries, session_hash)
88
  print("RESULT")
 
150
 
151
 
152
 
153
+ def doc_db_query_func(aggregation_pipeline: List[str], db_collection: AnyStr, session_hash, connection_string, doc_db_name, **kwargs):
154
+ doc_db_query = DocDBQuery(connection_string, doc_db_name)
155
  try:
156
  result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
157
  print("RESULT")
 
206
 
207
 
208
 
209
+ def graphql_query_func(graphql_query: AnyStr, session_hash, graphql_api_string, graphql_api_token, graphql_token_header, **kwargs):
210
  graphql_object = GraphQLQuery()
211
  try:
212
+ result = graphql_object.run(graphql_query, graphql_api_string, graphql_api_token, graphql_token_header, session_hash)
213
  print("RESULT")
214
  if len(result["results"][0]) > 1000:
215
  print("QUERY TOO LARGE")
templates/sql_db.py CHANGED
@@ -1,6 +1,6 @@
1
  import ast
2
  import gradio as gr
3
- from functions import example_question_generator, chatbot_func
4
  from data_sources import connect_sql_db
5
  from utils import message_dict
6
 
@@ -55,7 +55,7 @@ with gr.Blocks() as demo:
55
  ]
56
  else:
57
  try:
58
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'sql', sql_db_name, process_message[2], ""))
59
  example_questions = [
60
  ["Describe the dataset"]
61
  ]
@@ -75,18 +75,16 @@ with gr.Blocks() as demo:
75
  db_user = gr.Textbox(visible=False, value=sql_user)
76
  db_pass = gr.Textbox(visible=False, value=sql_pass)
77
  db_name = gr.Textbox(visible=False, value=sql_db_name)
78
- titles = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
79
- data_source = gr.Textbox(visible=False, value='sql')
80
- schema = gr.Textbox(visible=False, value='')
81
  bot = gr.Chatbot(type='messages', label="SQL DB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
82
  chat = gr.ChatInterface(
83
- fn=chatbot_func,
84
  type='messages',
85
  chatbot=bot,
86
  title="Chat with your Database",
87
  examples=example_questions,
88
  concurrency_limit=None,
89
- additional_inputs=[session_hash, data_source, titles, schema, db_url, db_port, db_user, db_pass, db_name]
90
  )
91
 
92
  def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
 
1
  import ast
2
  import gradio as gr
3
+ from functions import sql_example_question_generator, sql_chatbot_with_fc
4
  from data_sources import connect_sql_db
5
  from utils import message_dict
6
 
 
55
  ]
56
  else:
57
  try:
58
+ generated_examples = ast.literal_eval(sql_example_question_generator(request.session_hash, process_message[2], sql_db_name))
59
  example_questions = [
60
  ["Describe the dataset"]
61
  ]
 
75
  db_user = gr.Textbox(visible=False, value=sql_user)
76
  db_pass = gr.Textbox(visible=False, value=sql_pass)
77
  db_name = gr.Textbox(visible=False, value=sql_db_name)
78
+ db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
 
 
79
  bot = gr.Chatbot(type='messages', label="SQL DB Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
80
  chat = gr.ChatInterface(
81
+ fn=sql_chatbot_with_fc,
82
  type='messages',
83
  chatbot=bot,
84
  title="Chat with your Database",
85
  examples=example_questions,
86
  concurrency_limit=None,
87
+ additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables]
88
  )
89
 
90
  def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):