TEXT-2-SQL-CHATBOT / backend_app.py
ViditOstwal's picture
Final Commit
7a8984c
import json
import sqlite3
from pydantic import BaseModel
from openai import OpenAI
from system_prompt import SYSTEM_PROMPT
import re
from csv_to_sqlite import load_csv_to_sqlite
try:
client = OpenAI()
except Exception as e:
print(f"Error initializing OpenAI client: {e}")
print("Please ensure OPENAI_API_KEY environment variable is set.")
DB_FILE = "pharma_data.db"
GPT_MODEL = "gpt-4o-mini"
load_csv_to_sqlite()
MAX_RETRIES = 5
class QueryRequest(BaseModel):
"""Model for the incoming user query."""
user_question: str
def build_sql_prompt(user_question: str) -> str:
"""Builds the full prompt used to generate the SQL query."""
return f"""
-- ROLE: Data Analyst SQL Expert
-- TASK: You are a highly accurate SQLite database expert. Your sole function is to translate a user's natural language question into a single, valid, and executable SQLite query.
-- OUTPUT FORMAT (required, exact):
-- 1) A short, high-level, non-sensitive explanation enclosed in <explanation>...</explanation>.
-- - This should be a concise rationale describing the approach and any key assumptions (reveal internal chain-of-thought or step-by-step hidden reasoning).
-- 2) The final SQL enclosed in <sql>...</sql>.
-- - The <sql> section must contain only the single, executable SQLite query (no surrounding backticks, no commentary).
-- Example correct output:
-- <explanation>Use date_dim.year filter for 2024; aggregate prescriptions by doctor; exclude null names.</explanation>
-- <sql>SELECT ...;</sql>
-- DATABASE CONTEXT:
-- The database is named 'pharma_data.db' and uses the SQLite dialect.
-- SCHEMA DEFINITION:
{SYSTEM_PROMPT}
-- INSTRUCTIONS & RULES:
-- 1. **SQL DIALECT:** Use standard SQLite syntax.
-- 2. **Aliasing:** Always use table aliases (e.g., T1, T2) for readability.
-- 3. **Aggregation:** Use appropriate aggregate functions (SUM, AVG, COUNT) and GROUP BY clauses when needed.
-- 4. **HCP/Account Names:** Names are typically full strings. Use LIKE '%%' or '=' as appropriate.
-- 5. **Date Filtering:** Use the appropriate columns in the `date_dim` table for filtering by year, quarter, etc.
-- 6. **Single Query:** Return exactly one valid, executable SQL statement inside <sql>...</sql>.
-- 7. **No Extra Output:** Do NOT include any other text, markup, or commentary outside the two tags.
-- 8. **Safety:** Do NOT output internal chain-of-thought or verbatim reasoning. Only provide the short, high-level explanation in <explanation> as described above.
-- USER QUESTION START:
{user_question}
-- USER QUESTION END
"""
def parse_response(text: str):
expl = re.search(r"<explanation>(.*?)</explanation>", text, re.DOTALL)
sql = re.search(r"<sql>(.*?)</sql>", text, re.DOTALL)
if not expl or not sql:
raise ValueError("Missing <explanation> or <sql> tags in model output")
return {
"explanation": expl.group(1).strip(),
"sql": sql.group(1).strip()
}
def generate_sql_query(messages: list = [], max_retry: int = 3) -> dict | None:
"""Generates a SQL query from the user question using the LLM with retry logic."""
retry_count = 0
while retry_count < max_retry:
try:
response = client.chat.completions.create(
model=GPT_MODEL,
messages=messages,
temperature=0.0
)
response = response.choices[0].message.content.strip()
parsed = parse_response(response)
return parsed
except Exception:
retry_count += 1
print(f"Failed to generate valid SQL after {max_retry} retries.")
return {}
def generate_final_answer(question: str, sql_query: str, sql_result: str) -> str:
"""Translates the raw SQL result into a natural language answer."""
prompt = f"""
You are a friendly, concise data analyst.
The user asked the question: "{question}"
The SQL query executed was: "{sql_query}"
The raw result from the database was: "{sql_result}"
Please provide the final answer in a single, concise, human-readable sentence.
If the result is an empty set, state that no data was found.
"""
# Simple single-attempt LLM call for translation
try:
response = client.chat.completions.create(
model=GPT_MODEL,
messages=[
{"role": "user", "content": prompt}
],
temperature=0.0
)
return response.choices[0].message.content.strip()
except Exception as e:
return f"Error translating result: {e}"
def execute_sql(sql_query: str) -> str:
"""Executes the SQL query against the SQLite database and formats the result."""
conn = None
try:
# Connect to the database
conn = sqlite3.connect(DB_FILE)
cursor = conn.cursor()
cursor.execute(sql_query)
# Fetch results
rows = cursor.fetchall()
column_names = [description[0] for description in cursor.description]
if not rows:
return "[]" # Return an empty JSON string for no data
# Format results as a list of dictionaries (JSON format)
result_list = []
for row in rows:
result_list.append(dict(zip(column_names, row)))
# Use json.dumps to handle complex types and structure for LLM context
return json.dumps(result_list, indent=2)
except sqlite3.OperationalError as e:
# This is a critical error (e.g., bad syntax, misspelled table/column)
print(f"SQL Execution Error: {e}")
return f"SQL_ERROR: {e}"
except Exception as e:
print(f"General Execution Error: {e}")
return f"GENERAL_ERROR: {e}"
finally:
if conn:
conn.close()
def process_query(user_question: str) -> dict:
"""
Processes a natural language query, generates SQL, executes it, and provides a final answer.
"""
messages = []
user_question = user_question.strip()
prompt = build_sql_prompt(user_question)
messages.append({"role": "user", "content": prompt})
final_answer = ""
sql_query = ""
sql_result = ""
explanation = ""
for attempt in range(MAX_RETRIES):
print(f"Attempt {attempt + 1} of {MAX_RETRIES} to generate valid SQL query...")
generated_response = generate_sql_query(messages)
sql_query = generated_response.get('sql')
explanation = generated_response.get('explanation')
error_message = None
if sql_query is None:
error_message = "Failed to generate valid SQL query."
else:
sql_result = execute_sql(sql_query)
if sql_result.startswith("SQL_ERROR") or sql_result.startswith("GENERAL_ERROR"):
error_message = f"SQL generation was successful, but execution failed. The error was: {sql_result.split(': ', 1)[1]}"
elif sql_result == "[]": # Check for exact empty array string
error_message = "No data found for the given query, try another query."
else:
# If no error and data found, break the retry loop
break
# If an error occurred, append messages for retry and continue loop
if error_message:
messages.append({"role": "assistant", "content": generated_response})
messages.append({"role": "user", "content": f"The previous attempt resulted in an error: {error_message}. Please try again."})
print(f"Error encountered: {error_message}. Retrying...")
continue
if sql_query is None or error_message:
# If loop finished without success due to persistent errors or max retries
final_answer = "I was unable to generate a valid SQL query or execute it successfully. Please try rephrasing your question."
return {
"status": "failure",
"final_answer": final_answer,
"generated_sql": sql_query if sql_query else "N/A",
"sql_result": sql_result if sql_result else "N/A",
"model_used": GPT_MODEL,
"explanation": explanation if explanation else "N/A"
}
final_answer = generate_final_answer(user_question, sql_query, sql_result)
return {
"status": "success",
"final_answer": final_answer,
"generated_sql": sql_query,
"sql_result": sql_result,
"model_used": GPT_MODEL,
"explanation": explanation
}