Spaces:
Sleeping
Sleeping
| from flask import Flask, g, render_template, request, jsonify, session | |
| from flask_cors import CORS # Add this import | |
| from helpers.GROQ import ConversationGROQ | |
| from helpers.postgres import DatabaseConnection | |
| from helpers.prompts import PromptManager | |
| import re | |
| import pandas as pd | |
# Flask application object that the view functions in this module belong to.
app = Flask(__name__)

# Cross-origin access is restricted to the hosted front end.
# The origin is written WITHOUT a trailing slash: browsers send the Origin
# header as scheme://host[:port], and Flask-CORS compares plain-string origins
# for equality, so "https://host/" would never match and no CORS headers would
# be emitted for the front end.
CORS(
    app,
    supports_credentials=True,
    origins="https://db-bot-amber.vercel.app",
    allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"],
    expose_headers=["Set-Cookie"],
)

# Load the base prompt template once at import time; chat() fills it in per
# request via str.format().
prompt_manager = PromptManager()
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
base_prompt = prompt_manager.get_prompt('base_schema')
| def extract_sql_regex(input_string) -> str | None: | |
| # Pattern to match SQL query within double quotes after "sql": | |
| pattern = r'"sql":\s*"(.*?)"' | |
| match = re.search(pattern, input_string) | |
| if match: | |
| return match.group(1) | |
| else: | |
| return None | |
# Template for the second model call, which turns the query result into a
# plain-language summary. Filled in with str.format() inside chat().
_SUMMARY_PROMPT_TEMPLATE = """
A user asked the following question:
{user_question}
Based on this question, a query was executed and returned the following data:
{df}
Please provide a clear and concise summary of this data in non-technical language.
Focus on the key insights and how they relate to the user's question.
Avoid using technical terms and present the information in a way that's easy for anyone to understand.
If there are any notable trends, patterns, or important points in the data, please highlight them.
If the data includes price or amount information, please also provide a brief comparison. For example, highlight the highest and lowest values, or compare average prices/amounts between different categories if applicable.
Additionally, if the data spans multiple time periods (e.g., different dates or years), please provide a brief overview of any trends or changes over time.
If applicable, include any relevant statistics or figures, but explain them in simple terms.
Your summary should be informative yet accessible to someone without a technical background.
"""

# JSON fields that must be present and non-empty, in the order they are
# reported back to the client when missing.
_REQUIRED_CHAT_FIELDS = ('prompt', 'DB_USER', 'DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_PASSWORD')


# Add OPTIONS method
# NOTE(review): no route decorator is visible above this view in this copy of
# the file -- confirm it is registered with the app (the comment above suggests
# a route accepting POST and OPTIONS).
def chat():
    """Answer a natural-language question about the caller's database.

    Reads a JSON body with ``prompt`` and the connection settings ``DB_USER``,
    ``DB_HOST``, ``DB_PORT``, ``DB_NAME`` and ``DB_PASSWORD``. The handler
    then:

    1. connects to the database and reads its schema names and the columns
       of every table in the ``public`` schema;
    2. asks the model for a reply to the question, given that schema;
    3. if the reply contains a ``"sql"`` entry, runs that SQL and asks the
       model to summarise the resulting rows.

    Returns:
        A ``(response, status)`` pair:

        * 400 with an ``error`` listing the missing fields when the body is
          absent, not a JSON object, or incomplete;
        * 500 with an ``error`` when ``GROQ_API_KEY`` is not set;
        * 200 with ``message``/``response`` (and a null ``Sql``) when the
          model reply contains no SQL;
        * 200 with ``message``, ``df`` (HTML table), ``response`` and ``sql``
          otherwise.
    """
    import os  # local import: only this view needs it

    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so the caller gets this endpoint's own 400 payload below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    fields = {name: data.get(name, '') for name in _REQUIRED_CHAT_FIELDS}
    missing_fields = [name for name in _REQUIRED_CHAT_FIELDS if not fields[name]]
    if missing_fields:
        return jsonify({
            "error": f"Missing credentials: {', '.join(missing_fields)}",
            "format": "json"
        }), 400

    # The API key must come from the environment, never from source control.
    # NOTE(review): a key was previously committed in this file; it should be
    # treated as leaked and rotated.
    api_key = os.environ.get('GROQ_API_KEY')
    if not api_key:
        return jsonify({
            "error": "GROQ_API_KEY is not configured on the server",
            "format": "json"
        }), 500

    user_question = fields['prompt']
    db = DatabaseConnection(
        db_host=fields['DB_HOST'],
        db_port=fields['DB_PORT'],
        db_name=fields['DB_NAME'],
        db_user=fields['DB_USER'],
        db_password=fields['DB_PASSWORD'],
    )
    try:
        schema_rows = db.execute_query(
            'SELECT schema_name FROM information_schema.schemata;'
        ).fetchall()
        schema = [row[0] for row in schema_rows]

        tables = db.execute_query('''SELECT
            table_name,
            json_object_agg(column_name, data_type) AS columns
            FROM
            information_schema.columns
            WHERE
            table_schema = 'public'
            GROUP BY
            table_name
            ORDER BY
            table_name;''').fetchall()
        table_info = {table[0]: table[1] for table in tables}

        full_prompt = base_prompt.format(
            schema_list=schema,
            tables=tables,
            table_info=table_info,
            user_question=user_question,
        )
        groq = ConversationGROQ(api_key=api_key)
        groq.create_conversation(full_prompt)
        response = groq.chat(user_question)

        sql_query = extract_sql_regex(response)
        if sql_query is None:
            print("No SQL query found")
            # NOTE(review): this branch uses the key "Sql" while the success
            # branch uses "sql"; kept as-is so existing clients keep working.
            return jsonify({"message": response, "response": response, "Sql": sql_query, "format": "json"}), 200

        # SECURITY NOTE(review): the SQL below is produced by the model from
        # user input and is executed verbatim with the caller's credentials.
        # Nothing here restricts it to read-only statements.
        result = db.execute_query(sql_query)
        print(sql_query, 'result')
        rows = result.fetchall()
        df = pd.DataFrame(rows, columns=[desc[0] for desc in result.description])
        df = df.reset_index(drop=True)
        print(df.to_markdown(index=False))

        summary_prompt = _SUMMARY_PROMPT_TEMPLATE.format(user_question=user_question, df=df)
        final_response = groq.chat(summary_prompt)
        print(final_response)
        return jsonify({"message": final_response, "df": df.to_html(), "response": response, "sql": sql_query, "format": "json"}), 200
    finally:
        # Release the per-request connection on every exit path. close_db()
        # in this file calls .close() on a stored connection, so the helper
        # presumably exposes it; the guard keeps this safe if it does not.
        close = getattr(db, 'close', None)
        if callable(close):
            close()
# NOTE(review): no route decorator is visible above this view in this copy of
# the file -- confirm how it is registered with the app.
def query():
    """Echo the request's JSON body back to the caller.

    Placeholder implementation: the parsed body is printed to stdout and
    returned inside a ``response`` string; no query is processed yet.
    """
    payload = request.json
    print(payload)
    message = f"Received: {payload}"
    return jsonify({"response": message})
# NOTE(review): presumably meant to run as an app-context teardown handler, but
# no decorator is visible in this copy of the file -- confirm registration.
def close_db(error):
    """Close the database connection stored on ``flask.g``, if there is one.

    Args:
        error: Accepted for the handler signature; not used here.
    """
    connection = g.pop('db', None)
    if connection is None:
        # Nothing was stored under 'db' for this context.
        return
    connection.close()
# Script entry point: start Flask's built-in server on port 5001 with debug
# mode enabled when this file is executed directly.
if __name__ == '__main__':
    app.run(port=5001, debug=True)