File size: 5,443 Bytes
6812fa3
eb31690
6812fa3
 
 
 
 
 
 
 
660f1ad
9e62ca3
660f1ad
 
6812fa3
 
 
9e62ca3
6812fa3
 
 
 
 
 
 
 
 
 
 
9e62ca3
 
 
 
 
 
 
 
 
 
6812fa3
9e62ca3
 
 
6812fa3
 
 
 
 
9e62ca3
6812fa3
9e62ca3
 
 
6812fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e62ca3
6812fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import json
import os
import re

import pandas as pd
from flask import Flask, g, render_template, request, jsonify, session
from flask_cors import CORS  # Add this import

from helpers.GROQ import ConversationGROQ
from helpers.postgres import DatabaseConnection
from helpers.prompts import PromptManager

app = Flask(__name__)

# Browsers send the Origin header as scheme://host[:port] with no trailing
# slash, and flask-cors compares plain (non-regex) origin strings literally.
# The allowed origin therefore must not end in "/", or the frontend's requests
# would never match and would be rejected by the browser.
CORS(
    app,
    supports_credentials=True,
    origins="https://db-bot-amber.vercel.app",
    allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"],
    expose_headers=["Set-Cookie"],
)

# Load the base schema prompt once at import time; each /chat request formats
# it with the live schema listing and the user's question.
prompt_manager = PromptManager()
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
base_prompt = prompt_manager.get_prompt('base_schema')

def extract_sql_regex(input_string) -> str | None:
    # Pattern to match SQL query within double quotes after "sql":
    pattern = r'"sql":\s*"(.*?)"'
    
    match = re.search(pattern, input_string)
    if match:
        return match.group(1)
    else:
        return None



@app.route('/chat', methods=['POST', 'OPTIONS'])  # Add OPTIONS method
def chat():
    """Answer a natural-language question about a caller-supplied Postgres database.

    Expects a JSON object body with ``prompt`` plus the connection fields
    ``DB_USER``, ``DB_HOST``, ``DB_PORT``, ``DB_NAME`` and ``DB_PASSWORD``.
    The handler:

    1. lists the database's schema names and the column types of every table
       in the ``public`` schema, and formats them into ``base_prompt``;
    2. asks the Groq conversation for a reply and pulls a ``"sql": "..."``
       query out of it with ``extract_sql_regex``;
    3. if a query was found, runs it and asks the model for a plain-language
       summary of the returned rows.

    Returns:
        200 with ``message``/``response``/``Sql``/``sql`` when the model gave
        no SQL; 200 with ``message``/``df`` (HTML table)/``response``/``sql``
        after running it; 400 when the body is not a JSON object or a required
        field is missing; 500 when ``GROQ_API_KEY`` is not configured.
    """
    # Because OPTIONS is listed in ``methods``, Flask does not generate its
    # automatic OPTIONS response. CORS preflights used to run the code below
    # and fail on the missing JSON body. Answer them with Flask's default
    # 200 response; flask-cors adds the CORS headers in its after-request hook.
    if request.method == 'OPTIONS':
        return app.make_default_options_response()

    # silent=True yields None for a missing/malformed body instead of raising,
    # so the client always receives a JSON error payload.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            "error": "Request body must be a JSON object",
            "format": "json"
        }), 400

    prompt = data.get('prompt', '')
    db_user = data.get('DB_USER', '')
    db_host = data.get('DB_HOST', '')
    db_port = data.get('DB_PORT', '')
    db_name = data.get('DB_NAME', '')
    db_password = data.get('DB_PASSWORD', '')

    # Report every empty/missing field at once, in a fixed order.
    required_fields = ('prompt', 'DB_USER', 'DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_PASSWORD')
    missing_fields = [field for field in required_fields if not data.get(field, '')]
    if missing_fields:
        return jsonify({
            "error": f"Missing credentials: {', '.join(missing_fields)}",
            "format": "json"
        }), 400

    # The Groq key used to be hard-coded here, so it is exposed in version
    # control history and should be revoked. It now comes from the environment.
    groq_api_key = os.environ.get('GROQ_API_KEY')
    if not groq_api_key:
        return jsonify({
            "error": "GROQ_API_KEY is not configured on the server",
            "format": "json"
        }), 500

    db = DatabaseConnection(db_host=db_host, db_port=db_port, db_name=db_name, db_user=db_user, db_password=db_password)
    # Register the connection on ``g`` so the ``close_db`` teardown closes it
    # at the end of the request; otherwise nothing in this module closes it.
    # (Assumes DatabaseConnection exposes close(), as close_db already does.)
    g.db = db

    schema_rows = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
    schema = [row[0] for row in schema_rows]

    # One row per public table: (table_name, {column_name: data_type, ...}).
    tables = db.execute_query('''SELECT 
        table_name,
        json_object_agg(column_name, data_type) AS columns
    FROM 
        information_schema.columns
    WHERE 
        table_schema = 'public'
    GROUP BY 
        table_name
    ORDER BY 
        table_name;''').fetchall()
    table_info = {table[0]: table[1] for table in tables}
    full_prompt = base_prompt.format(schema_list=schema, tables=tables, table_info=table_info, user_question=prompt)

    groq = ConversationGROQ(api_key=groq_api_key)
    groq.create_conversation(full_prompt)
    response = groq.chat(prompt)
    sql_query = extract_sql_regex(response)
    if sql_query is None:
        print("No SQL query found")
        # "Sql" is kept for existing clients; "sql" matches the success branch.
        return jsonify({"message": response, "response": response, "Sql": sql_query, "sql": sql_query, "format": "json"}), 200

    # NOTE(review): this executes model-generated SQL verbatim with the
    # caller's credentials; nothing here restricts it to read-only statements.
    # Consider a read-only role or read-only transaction on the database side.
    result = db.execute_query(sql_query)
    print(sql_query, 'result')
    rows = result.fetchall()
    df = pd.DataFrame(rows, columns=[desc[0] for desc in result.description])
    df = df.reset_index(drop=True)
    # to_string() rather than to_markdown(): the latter requires the optional
    # `tabulate` package and would fail the request just for a debug print.
    print(df.to_string(index=False))

    summary_prompt = """
    A user asked the following question:
    {user_question}

    Based on this question, a query was executed and returned the following data:

    {df}

    Please provide a clear and concise summary of this data in non-technical language. 
    Focus on the key insights and how they relate to the user's question. 
    Avoid using technical terms and present the information in a way that's easy for anyone to understand.

    If there are any notable trends, patterns, or important points in the data, please highlight them.
    If the data includes price or amount information, please also provide a brief comparison. For example, highlight the highest and lowest values, or compare average prices/amounts between different categories if applicable.

    Additionally, if the data spans multiple time periods (e.g., different dates or years), please provide a brief overview of any trends or changes over time.
    If applicable, include any relevant statistics or figures, but explain them in simple terms.

    Your summary should be informative yet accessible to someone without a technical background.
    """.format(user_question=prompt, df=df)
    final_response = groq.chat(summary_prompt)
    print(final_response)
    return jsonify({"message": final_response, "df": df.to_html(), "response": response, "sql": sql_query, "format": "json"}), 200

@app.route('/chat', methods=['POST'])
def query():
    """Echo the received JSON payload back to the caller (placeholder handler).

    NOTE(review): ``chat`` registers ``/chat`` for POST earlier in this module,
    so that rule matches first and this view is effectively unreachable.
    Give it its own path, or delete it, once the intended behavior is decided.
    """
    payload = request.json
    # Query processing is not implemented yet; log and echo the payload.
    print(payload)
    echoed = f"Received: {payload}"
    return jsonify({"response": echoed})

@app.teardown_appcontext
def close_db(error):
    """Close the connection stored on ``g.db`` when the app context ends.

    Args:
        error: Exception that ended the context, if any (unused).
    """
    connection = g.pop('db', None)
    if connection is None:
        # Nothing was registered for this request.
        return
    connection.close()

if __name__ == '__main__':
    # Local development entry point only: debug=True turns on Flask's
    # reloader and interactive debugger, which must never be exposed publicly.
    app.run(port=5001, debug=True)