Spaces:
Sleeping
Sleeping
saman shrestha
commited on
Commit
·
9e62ca3
1
Parent(s):
660f1ad
removed session "
Browse files- flask_app.py +21 -55
flask_app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from flask import Flask, g, render_template, request, jsonify, session
|
| 2 |
from flask_cors import CORS # Add this import
|
| 3 |
-
import os
|
| 4 |
|
| 5 |
from helpers.GROQ import ConversationGROQ
|
| 6 |
from helpers.postgres import DatabaseConnection
|
|
@@ -10,15 +9,13 @@ import pandas as pd
|
|
| 10 |
|
| 11 |
app = Flask(__name__)
|
| 12 |
CORS(app, supports_credentials=True, origins="https://db-bot-amber.vercel.app/", allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"], expose_headers=["Set-Cookie"])
|
| 13 |
-
|
| 14 |
-
app.config['SESSION_COOKIE_SAMESITE'] = 'None'
|
| 15 |
-
app.config['SESSION_COOKIE_SECURE'] = False
|
| 16 |
-
app.config['SESSION_COOKIE_HTTPONLY'] = True
|
| 17 |
|
| 18 |
|
| 19 |
prompt_manager = PromptManager()
|
| 20 |
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
|
| 21 |
base_prompt = prompt_manager.get_prompt('base_schema')
|
|
|
|
| 22 |
def extract_sql_regex(input_string) -> str | None:
|
| 23 |
# Pattern to match SQL query within double quotes after "sql":
|
| 24 |
pattern = r'"sql":\s*"(.*?)"'
|
|
@@ -29,27 +26,31 @@ def extract_sql_regex(input_string) -> str | None:
|
|
| 29 |
else:
|
| 30 |
return None
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@app.route('/', methods=['POST'])
|
| 45 |
-
def index():
|
| 46 |
data = request.json
|
|
|
|
|
|
|
|
|
|
| 47 |
db_user = data.get('DB_USER', '')
|
| 48 |
db_host = data.get('DB_HOST', '')
|
| 49 |
db_port = data.get('DB_PORT', '')
|
| 50 |
db_name = data.get('DB_NAME', '')
|
| 51 |
db_password = data.get('DB_PASSWORD', '')
|
|
|
|
| 52 |
missing_fields = []
|
|
|
|
|
|
|
|
|
|
| 53 |
if not db_user:
|
| 54 |
missing_fields.append('DB_USER')
|
| 55 |
if not db_host:
|
|
@@ -67,42 +68,7 @@ def index():
|
|
| 67 |
"format": "json"
|
| 68 |
}), 400
|
| 69 |
|
| 70 |
-
|
| 71 |
-
session['DB_HOST'] = db_host
|
| 72 |
-
session['DB_PORT'] = db_port
|
| 73 |
-
session['DB_NAME'] = db_name
|
| 74 |
-
session['DB_USER'] = db_user
|
| 75 |
-
session['DB_PASSWORD'] = db_password
|
| 76 |
-
|
| 77 |
-
# Test the connection
|
| 78 |
-
try:
|
| 79 |
-
db = get_db()
|
| 80 |
-
if db is None:
|
| 81 |
-
return jsonify({"error": "Database connection failed", "format": "json"}), 500
|
| 82 |
-
return jsonify({"message": "Database connection successful", "format": "json"}), 200
|
| 83 |
-
except Exception as e:
|
| 84 |
-
return jsonify({"error": f"Database connection failed: {str(e)}", "format": "json"}), 500
|
| 85 |
-
|
| 86 |
-
@app.route('/chat', methods=['POST', 'OPTIONS']) # Add OPTIONS method
|
| 87 |
-
def chat():
|
| 88 |
-
if request.method == 'OPTIONS':
|
| 89 |
-
# Respond to preflight request
|
| 90 |
-
response = app.make_default_options_response()
|
| 91 |
-
response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
|
| 92 |
-
response.headers['Access-Control-Allow-Methods'] = 'POST'
|
| 93 |
-
return response
|
| 94 |
-
|
| 95 |
-
data = request.json
|
| 96 |
-
|
| 97 |
-
if 'DB_HOST' not in session:
|
| 98 |
-
return jsonify({"error": "Database connection not established", "format": "json"}), 400
|
| 99 |
-
|
| 100 |
-
prompt = data.get('prompt', '')
|
| 101 |
-
if not prompt:
|
| 102 |
-
return jsonify({"error": "Prompt is required", "format": "json"}), 400
|
| 103 |
-
|
| 104 |
-
db = get_db()
|
| 105 |
-
|
| 106 |
schema = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
|
| 107 |
schema = [schema[0] for schema in schema]
|
| 108 |
|
|
|
|
| 1 |
from flask import Flask, g, render_template, request, jsonify, session
|
| 2 |
from flask_cors import CORS # Add this import
|
|
|
|
| 3 |
|
| 4 |
from helpers.GROQ import ConversationGROQ
|
| 5 |
from helpers.postgres import DatabaseConnection
|
|
|
|
| 9 |
|
| 10 |
app = Flask(__name__)
|
| 11 |
CORS(app, supports_credentials=True, origins="https://db-bot-amber.vercel.app/", allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"], expose_headers=["Set-Cookie"])
|
| 12 |
+
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
prompt_manager = PromptManager()
|
| 16 |
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
|
| 17 |
base_prompt = prompt_manager.get_prompt('base_schema')
|
| 18 |
+
|
| 19 |
def extract_sql_regex(input_string) -> str | None:
|
| 20 |
# Pattern to match SQL query within double quotes after "sql":
|
| 21 |
pattern = r'"sql":\s*"(.*?)"'
|
|
|
|
| 26 |
else:
|
| 27 |
return None
|
| 28 |
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.route('/chat', methods=['POST', 'OPTIONS']) # Add OPTIONS method
|
| 32 |
+
def chat():
|
| 33 |
+
# if request.method == 'OPTIONS':
|
| 34 |
+
# # Respond to preflight request
|
| 35 |
+
# response = app.make_default_options_response()
|
| 36 |
+
# response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
|
| 37 |
+
# response.headers['Access-Control-Allow-Methods'] = 'POST'
|
| 38 |
+
# return response
|
| 39 |
+
|
|
|
|
|
|
|
|
|
|
| 40 |
data = request.json
|
| 41 |
+
|
| 42 |
+
# if 'DB_HOST' not in session:
|
| 43 |
+
# return jsonify({"error": "Database connection not established", "format": "json"}), 400
|
| 44 |
db_user = data.get('DB_USER', '')
|
| 45 |
db_host = data.get('DB_HOST', '')
|
| 46 |
db_port = data.get('DB_PORT', '')
|
| 47 |
db_name = data.get('DB_NAME', '')
|
| 48 |
db_password = data.get('DB_PASSWORD', '')
|
| 49 |
+
prompt = data.get('prompt', '')
|
| 50 |
missing_fields = []
|
| 51 |
+
|
| 52 |
+
if not prompt:
|
| 53 |
+
missing_fields.append('prompt')
|
| 54 |
if not db_user:
|
| 55 |
missing_fields.append('DB_USER')
|
| 56 |
if not db_host:
|
|
|
|
| 68 |
"format": "json"
|
| 69 |
}), 400
|
| 70 |
|
| 71 |
+
db = DatabaseConnection(db_host=db_host, db_port=db_port, db_name=db_name, db_user=db_user, db_password=db_password)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
schema = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
|
| 73 |
schema = [schema[0] for schema in schema]
|
| 74 |
|