# ai_eee_sql_gen / openai_integration.py
import os
import openai
import json
import yaml
from dotenv import load_dotenv
from db_logging import save_query_to_local_db
import logging
#############LangSmith & OpenAI####################
from langsmith import traceable
from langsmith.wrappers import wrap_openai
# Load environment variables (.env) BEFORE constructing any OpenAI client,
# so that OPENAI_API_KEY is visible to openai.Client(); the previous order
# built the client first and could miss a key supplied only via .env.
load_dotenv()

# Load the OpenAI API key and fail fast if it is missing.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Error: OPENAI_API_KEY not found in environment variables")

# Initialize the LangSmith-instrumented OpenAI client (traces each call).
client = wrap_openai(openai.Client())

GPT_MODEL = "gpt-4o-mini"
##################################################
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Function to calculate cost based on token usage
def ydcoza_cost(tokens_out, rate_per_million=0.2):
    """Estimate the USD cost of an OpenAI call from its token count.

    Args:
        tokens_out: Total tokens consumed by the call.
        rate_per_million: Price in USD per 1M tokens. Defaults to 0.2,
            the original rough estimate (gpt-4o: $5.00 / 1M input tokens;
            gpt-4o-mini: $0.150 / 1M input tokens).

    Returns:
        Estimated cost in USD as a float.
    """
    return tokens_out * rate_per_million / 1000000
# Function to load schema from a JSON file
def load_schema_from_json(path="schema.json"):
    """Load the database schema description from a JSON file.

    Args:
        path: Path to the schema file. Defaults to "schema.json" in the
            current working directory (the original hard-coded location).

    Returns:
        The parsed schema dict on success, or ``{"error": <message>}``
        when the file is missing or unparseable — callers must check for
        the "error" key rather than a string return.
    """
    try:
        with open(path, "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info(f"Schema loaded from {path}")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from {path}: {e}")
        return {"error": str(e)}
# Function to build the schema description with foreign keys
def build_schema_description(schema_info):
    """Render the schema dict as a prompt-ready text description.

    Emits one "Table ..." header per table, followed by its columns and,
    when present, its foreign-key relationships.
    """
    parts = []
    for name, details in schema_info.items():
        parts.append(f"Table {name} (Comment: {details['comment']}):\n")
        parts.extend(
            f" - {col['name']} ({col['data_type']}) (Comment: {col['comment']})\n"
            for col in details["columns"]
        )
        fks = details.get("foreign_keys")
        if fks:
            parts.append(" Foreign Keys:\n")
            parts.extend(
                f" - {fk['column']} references {fk['references']['table']}({fk['references']['column']})\n"
                for fk in fks
            )
    return "".join(parts)
# Function to load examples from YAML
def load_examples_from_yaml(file_path):
    """Read few-shot prompt examples from a YAML file.

    Returns the parsed YAML content (expected to be a list of example
    dicts), or an empty list when the file cannot be read or parsed.
    """
    try:
        with open(file_path, 'r') as fh:
            loaded = yaml.safe_load(fh)
        logging.info("Examples loaded from examples.yaml")
    except Exception as e:
        logging.error(f"Error loading examples: {e}")
        return []
    return loaded
# Function to build examples string for the prompt
def build_examples_string(examples_list):
    """Format few-shot examples into a single prompt-ready string.

    Each example becomes an "Example N:" section with its input,
    reformulated query, and SQL query, separated by a blank line.
    """
    sections = [
        f"Example {idx}:\n"
        f"Input:\n{example['input']}\n"
        f"Reformulated Query:\n{example['reformulated_query']}\n"
        f"SQL Query:\n{example['sql_query']}\n\n"
        for idx, example in enumerate(examples_list, start=1)
    ]
    return "".join(sections)
# Load few-shot examples at import time from examples.yaml, located next to
# this module (not the process CWD), and pre-render them into the `examples`
# string that generate_sql_single_call() embeds in every prompt.
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
examples = build_examples_string(examples_list)
# logging.info(f"Examples From YAML: {examples}")
def generate_sql_single_call(nl_query):
    """Reformulate a natural-language query and generate SQL in one LLM call.

    Args:
        nl_query: The user's natural-language question about the database.

    Returns:
        A 3-tuple ``(reformulated_query, sql_query, cost_html)``. On error,
        the first element carries an error message and the other two are
        empty strings — previously error paths returned 2-tuples, which
        broke callers unpacking three values exactly when an error occurred.
    """
    try:
        # Load the schema from schema.json
        schema_info = load_schema_from_json()
        # load_schema_from_json() signals failure with an {"error": ...} dict;
        # the old isinstance(schema_info, str) check could never match it.
        if isinstance(schema_info, dict) and "error" in schema_info:
            return "Error fetching schema", "", ""
        # Build the schema description once and reuse it
        schema_description = build_schema_description(schema_info)
        logging.info(f"Schema Description: {schema_description}")
        # Single combined prompt: reformulation + SQL generation in one call.
        prompt = (
            f"Database Schema:\n{schema_description}\n\n"
            f"Your task is to:\n"
            f"1. Reformulate the user's natural language query to align precisely with the database schema. Also, indicate the tables to use. These tables **MUST** be present in the schema provided.\n"
            f"2. Generate the corresponding SQL query based on the reformulated query and the provided schema.\n\n"
            f"User's Query:\n\"{nl_query}\"\n\n"
            f"Examples:\n{examples}\n"
            f"Response Format (in JSON):\n"
            f"{{\n"
            f'  "reformulated_query": "<your reformulated query>",\n'
            f'  "sql_query": "<your SQL query>"\n'
            f"}}\n\n"
            f"Important Guidelines:\n"
            f"- Use only the tables and columns provided in the schema.\n"
            f"- Ensure the SQL query is syntactically correct.\n"
            f"- Do not include any additional text or explanations.\n"
            f"- Do not include any text outside the JSON format.\n"
            f"- Ensure the JSON is valid and properly formatted.\n"
            f"- Avoid assumptions about data not present in the schema.\n"
            f"- Double-check your response for accuracy."
        )
        logging.info(f"Full Prompt: {prompt}")
        logging.info("Sending combined request to OpenAI...")
        response = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are an expert data analyst and SQL specialist. "
                        "Your role is to reformulate natural language queries to align with a given database schema "
                        "and then create a syntactically correct PostgreSQL query to run based on the reformulated queries."
                        "Ensure that you only use tables and columns present in the provided schema. "
                        "Indicate the tables to use in your reformulated query."
                        "Ensure that your responses are precise, concise, and follow the provided guidelines."
                        "Provide your response in the specified JSON format."
                    )
                },
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.3  # Low temperature for deterministic SQL output
        )
        # Process the assistant's response
        assistant_response = response.choices[0].message.content.strip()
        logging.info(f"Assistant Response:\n{assistant_response}")
        # Calculate token usage and render a per-call cost snippet for the UI.
        tokens_used = response.usage.total_tokens
        cost_per_call = ydcoza_cost(tokens_used)
        total_cost_per_call = f"<p class='cost_per_call'>Tokens Consumed: {tokens_used} ; Cost per call: ${cost_per_call:.6f}</p>"
        # Parse the assistant's JSON response.
        try:
            response_json = json.loads(assistant_response)
            reformulated_query = response_json.get("reformulated_query", "")
            sql_query = response_json.get("sql_query", "")
        except json.JSONDecodeError:
            logging.error("Could not parse assistant response as JSON.")
            return "Error parsing assistant response.", "", ""
        logging.info(f"Reformulated Query: {reformulated_query}")
        logging.info(f"SQL Query generated: {sql_query}")
        # Save the query to the local database
        save_query_to_local_db(nl_query, reformulated_query, sql_query)
        return reformulated_query, sql_query, total_cost_per_call
    # openai>=1.0 (used here via openai.Client()) exposes OpenAIError at the
    # top level; the old openai.error.OpenAIError no longer exists and would
    # itself raise AttributeError inside this handler.
    except openai.OpenAIError as e:
        logging.error(f"OpenAI API error: {e}")
        return "Error during interaction with OpenAI API.", "", ""
    except Exception as e:
        logging.error(f"General error during SQL process: {e}")
        return "General error during SQL processing.", "", ""