db_bot / flask_app.py
saman shrestha
removed session "
9e62ca3
from flask import Flask, g, render_template, request, jsonify, session
from flask_cors import CORS # Add this import
from helpers.GROQ import ConversationGROQ
from helpers.postgres import DatabaseConnection
from helpers.prompts import PromptManager
import re
import pandas as pd
app = Flask(__name__)

# The allowed origin must be written without a trailing slash: browsers send
# the Origin header as scheme://host[:port] only, and Flask-CORS compares
# plain-string origins for equality, so "https://host/" never matches.
CORS(
    app,
    supports_credentials=True,
    origins="https://db-bot-amber.vercel.app",
    allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"],
    expose_headers=["Set-Cookie"],
)

# Load the 'base_schema' prompt template once at import time; chat() fills it
# in per request via str.format().
prompt_manager = PromptManager()
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
base_prompt = prompt_manager.get_prompt('base_schema')
def extract_sql_regex(input_string) -> str | None:
# Pattern to match SQL query within double quotes after "sql":
pattern = r'"sql":\s*"(.*?)"'
match = re.search(pattern, input_string)
if match:
return match.group(1)
else:
return None
# Request fields that must be present and non-empty, in the order they are
# reported back to the client.
_REQUIRED_FIELDS = ('prompt', 'DB_USER', 'DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_PASSWORD')

# One row per table in the "public" schema: (table_name, {column: data_type}).
_TABLE_COLUMNS_QUERY = '''SELECT
table_name,
json_object_agg(column_name, data_type) AS columns
FROM
information_schema.columns
WHERE
table_schema = 'public'
GROUP BY
table_name
ORDER BY
table_name;'''

# Follow-up prompt asking the model to summarise the query result.
_SUMMARY_PROMPT_TEMPLATE = """
A user asked the following question:
{user_question}
Based on this question, a query was executed and returned the following data:
{df}
Please provide a clear and concise summary of this data in non-technical language.
Focus on the key insights and how they relate to the user's question.
Avoid using technical terms and present the information in a way that's easy for anyone to understand.
If there are any notable trends, patterns, or important points in the data, please highlight them.
If the data includes price or amount information, please also provide a brief comparison. For example, highlight the highest and lowest values, or compare average prices/amounts between different categories if applicable.
Additionally, if the data spans multiple time periods (e.g., different dates or years), please provide a brief overview of any trends or changes over time.
If applicable, include any relevant statistics or figures, but explain them in simple terms.
Your summary should be informative yet accessible to someone without a technical background.
"""


@app.route('/chat', methods=['POST', 'OPTIONS'])
def chat():
    """Answer a natural-language question against a caller-supplied database.

    Expects a JSON object with ``prompt``, ``DB_USER``, ``DB_HOST``,
    ``DB_PORT``, ``DB_NAME`` and ``DB_PASSWORD``. The handler connects to
    that database, describes its schemas/tables to the model, runs the SQL
    the model proposes and asks the model to summarise the result.

    Returns:
        * 200 with ``message``, ``df`` (HTML table), ``response`` and ``sql``
          when a query was produced and executed;
        * 200 with the raw model reply when it contained no SQL;
        * 400 with ``error`` when required fields are missing;
        * 500 with ``error`` when the server is missing its API key or a
          database operation fails.
    """
    import os  # local import keeps this block self-contained

    # Listing OPTIONS in the route disables Flask's automatic handler, so
    # answer the CORS preflight here before trying to read a JSON body.
    if request.method == 'OPTIONS':
        return app.make_default_options_response()

    # silent=True yields None (instead of raising) for a missing/invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    missing_fields = [name for name in _REQUIRED_FIELDS if not data.get(name)]
    if missing_fields:
        return jsonify({
            "error": f"Missing credentials: {', '.join(missing_fields)}",
            "format": "json"
        }), 400

    # The API key is a secret and must come from the environment, never from
    # source control.
    api_key = os.environ.get('GROQ_API_KEY')
    if not api_key:
        return jsonify({
            "error": "GROQ_API_KEY is not configured on the server",
            "format": "json"
        }), 500

    prompt = data['prompt']

    try:
        db = DatabaseConnection(
            db_host=data['DB_HOST'],
            db_port=data['DB_PORT'],
            db_name=data['DB_NAME'],
            db_user=data['DB_USER'],
            db_password=data['DB_PASSWORD'],
        )
        # Hand the connection to the app context so the close_db teardown
        # releases it when the request ends.
        # NOTE(review): assumes DatabaseConnection exposes close(), as
        # close_db already relies on - confirm.
        g.db = db
        schema_rows = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
        tables = db.execute_query(_TABLE_COLUMNS_QUERY).fetchall()
    except Exception as exc:  # request boundary: report as JSON, not an HTML 500
        app.logger.exception("Database connection or introspection failed")
        return jsonify({
            "error": f"Database error: {exc}",
            "format": "json"
        }), 500

    schema = [row[0] for row in schema_rows]
    table_info = {table[0]: table[1] for table in tables}
    full_prompt = base_prompt.format(schema_list=schema, tables=tables, table_info=table_info, user_question=prompt)

    groq = ConversationGROQ(api_key=api_key)
    groq.create_conversation(full_prompt)
    response = groq.chat(prompt)

    sql_query = extract_sql_regex(response)
    if sql_query is None:
        app.logger.info("No SQL query found")
        # "Sql" is kept for existing clients; "sql" matches the success payload.
        return jsonify({"message": response, "response": response, "Sql": sql_query, "sql": sql_query, "format": "json"}), 200

    # NOTE(review): the SQL below is model-generated and runs with the
    # caller's own credentials; nothing here restricts it to read-only
    # statements. Consider enforcing SELECT-only or a read-only role.
    try:
        result = db.execute_query(sql_query)
        rows = result.fetchall()
        df = pd.DataFrame(rows, columns=[desc[0] for desc in result.description])
    except Exception as exc:  # request boundary: report as JSON, not an HTML 500
        app.logger.exception("Generated SQL failed: %s", sql_query)
        return jsonify({
            "error": f"Query execution failed: {exc}",
            "response": response,
            "sql": sql_query,
            "format": "json"
        }), 500

    app.logger.debug("Executed %s; result:\n%s", sql_query, df)

    summary_prompt = _SUMMARY_PROMPT_TEMPLATE.format(user_question=prompt, df=df)
    final_response = groq.chat(summary_prompt)
    app.logger.debug("Summary: %s", final_response)

    return jsonify({"message": final_response, "df": df.to_html(), "response": response, "sql": sql_query, "format": "json"}), 200
# NOTE(review): this registers the same URL rule and method ('/chat', POST)
# as chat() above; presumably chat() handles those requests and this view is
# never reached - confirm and give it its own path if it is still needed.
@app.route('/chat', methods=['POST'])
def query():
    """Placeholder handler: log the posted JSON body and echo it back."""
    data = request.json
    print(data)
    # No real processing yet; the reply just reflects what was received.
    echoed = {"response": f"Received: {data}"}
    return jsonify(echoed)
@app.teardown_appcontext
def close_db(error):
    """Close the connection stored on ``g`` under 'db', if there is one.

    Registered as an app-context teardown handler. ``error`` is accepted
    because the teardown hook passes it, but it is not used.
    """
    connection = g.pop('db', None)
    if connection is None:
        # Nothing was stored on g for this context.
        return
    connection.close()
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server on port 5001 with
    # debug mode enabled.
    app.run(port=5001, debug=True)