"""Generate SQL from natural-language queries via a single OpenAI chat call.

Loads the database schema from ``schema.json`` and few-shot examples from
``examples.yaml``, builds one combined prompt, and asks the model to return
both a schema-aligned reformulation of the user's query and the matching
PostgreSQL statement, as JSON.
"""

import os
import openai
import json
import yaml
from dotenv import load_dotenv
from db_logging import save_query_to_local_db
import logging

############# LangSmith & OpenAI ####################
from langsmith import traceable
from langsmith.wrappers import wrap_openai

# Load environment variables FIRST so the API key is visible to the client.
# (Bug fix: the client was previously created before load_dotenv() ran.)
load_dotenv()

# Load the OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Error: OPENAI_API_KEY not found in environment variables")

# Initialize the LangSmith-traced OpenAI client (after the key is loaded).
client = wrap_openai(openai.Client())

GPT_MODEL = "gpt-4o-mini"
##################################################

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def ydcoza_cost(tokens_out):
    """Return a rough USD cost estimate for *tokens_out* tokens.

    Rough estimate for gpt-4o-mini: $0.150 / 1M input tokens
    (gpt-4o: $5.00 / 1M input tokens).
    """
    t_cost = tokens_out * 0.2 / 1000000
    return t_cost


def load_schema_from_json():
    """Load the database schema from ``schema.json``.

    Returns:
        dict: the parsed schema on success, or ``{"error": <message>}``
        when the file cannot be read/parsed.
    """
    try:
        with open("schema.json", "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info("Schema loaded from schema.json")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from schema.json: {e}")
        return {"error": str(e)}


def build_schema_description(schema_info):
    """Render *schema_info* as plain text (tables, columns, foreign keys).

    Expects the shape produced by ``schema.json``: a mapping of table name
    to ``{"comment": ..., "columns": [...], "foreign_keys": [...]}``.
    """
    schema_description = ""
    for table_name, table_details in schema_info.items():
        schema_description += f"Table {table_name} (Comment: {table_details['comment']}):\n"
        for column in table_details["columns"]:
            schema_description += f" - {column['name']} ({column['data_type']}) (Comment: {column['comment']})\n"
        if table_details.get("foreign_keys"):
            schema_description += " Foreign Keys:\n"
            for fk in table_details["foreign_keys"]:
                schema_description += f" - {fk['column']} references {fk['references']['table']}({fk['references']['column']})\n"
    return schema_description


def load_examples_from_yaml(file_path):
    """Load few-shot examples from a YAML file; return ``[]`` on failure."""
    try:
        with open(file_path, 'r') as file:
            examples = yaml.safe_load(file)
        logging.info("Examples loaded from examples.yaml")
        return examples
    except Exception as e:
        logging.error(f"Error loading examples: {e}")
        return []


def build_examples_string(examples_list):
    """Format loaded examples into the few-shot section of the prompt.

    Each example dict must provide ``input``, ``reformulated_query`` and
    ``sql_query`` keys.
    """
    examples_string = ""
    for idx, example in enumerate(examples_list, start=1):
        examples_string += f"Example {idx}:\n"
        examples_string += f"Input:\n{example['input']}\n"
        examples_string += f"Reformulated Query:\n{example['reformulated_query']}\n"
        examples_string += f"SQL Query:\n{example['sql_query']}\n\n"
    return examples_string


# Load examples from YAML once, at import time.
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
examples = build_examples_string(examples_list)
# logging.info(f"Examples From YAML: {examples}")


def generate_sql_single_call(nl_query):
    """Reformulate *nl_query* against the schema and generate SQL in one call.

    Args:
        nl_query: the user's natural-language query.

    Returns:
        tuple[str, str, str]: ``(reformulated_query, sql_query, cost_info)``.
        On any failure the first element is an error message and the other
        elements are empty strings, so the arity is always 3 (the original
        code returned 2-tuples on error, breaking 3-way unpacking).
    """
    try:
        # Load the schema from schema.json
        schema_info = load_schema_from_json()
        # load_schema_from_json() signals failure with an {"error": ...} dict,
        # never a string — the original startswith("Error") check could not fire.
        if not isinstance(schema_info, dict) or "error" in schema_info:
            return "Error fetching schema", "", ""

        # Build the schema description once and reuse it
        schema_description = build_schema_description(schema_info)
        logging.info(f"Schema Description: {schema_description}")

        # Combined prompt: reformulation + SQL generation in a single request.
        prompt = (
            f"Database Schema:\n{schema_description}\n\n"
            f"Your task is to:\n"
            f"1. Reformulate the user's natural language query to align precisely with the database schema. Also, indicate the tables to use. These tables **MUST** be present in the schema provided.\n"
            f"2. Generate the corresponding SQL query based on the reformulated query and the provided schema.\n\n"
            f"User's Query:\n\"{nl_query}\"\n\n"
            f"Examples:\n{examples}\n"
            f"Response Format (in JSON):\n"
            f"{{\n"
            f'  "reformulated_query": "",\n'
            f'  "sql_query": ""\n'
            f"}}\n\n"
            f"Important Guidelines:\n"
            f"- Use only the tables and columns provided in the schema.\n"
            f"- Ensure the SQL query is syntactically correct.\n"
            f"- Do not include any additional text or explanations.\n"
            f"- Do not include any text outside the JSON format.\n"
            f"- Ensure the JSON is valid and properly formatted.\n"
            f"- Avoid assumptions about data not present in the schema.\n"
            f"- Double-check your response for accuracy."
        )

        logging.info(f"Full Prompt: {prompt}")
        logging.info("Sending combined request to OpenAI...")

        response = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are an expert data analyst and SQL specialist. "
                        "Your role is to reformulate natural language queries to align with a given database schema "
                        "and then create a syntactically correct PostgreSQL query to run based on the reformulated queries."
                        "Ensure that you only use tables and columns present in the provided schema. "
                        "Indicate the tables to use in your reformulated query."
                        "Ensure that your responses are precise, concise, and follow the provided guidelines."
                        "Provide your response in the specified JSON format."
                    )
                },
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.3  # Adjusted temperature
        )

        # Process the assistant's response
        assistant_response = response.choices[0].message.content.strip()
        logging.info(f"Assistant Response:\n{assistant_response}")

        # Calculate token usage and rough cost for this call.
        tokens_used = response.usage.total_tokens
        cost_per_call = ydcoza_cost(tokens_used)
        total_cost_per_call = (
            f"\nTokens Consumed: {tokens_used} ; Cost per call: ${cost_per_call:.6f}\n"
        )

        # Parse the assistant's response
        try:
            response_json = json.loads(assistant_response)
            reformulated_query = response_json.get("reformulated_query", "")
            sql_query = response_json.get("sql_query", "")
        except json.JSONDecodeError:
            logging.error("Could not parse assistant response as JSON.")
            return "Error parsing assistant response.", "", ""

        logging.info(f"Reformulated Query: {reformulated_query}")
        logging.info(f"SQL Query generated: {sql_query}")

        # Save the query to the local database
        save_query_to_local_db(nl_query, reformulated_query, sql_query)

        return reformulated_query, sql_query, total_cost_per_call

    # v1 SDK: error classes live on the top-level module (openai.OpenAIError);
    # the original `openai.error.OpenAIError` raised AttributeError under v1.
    except openai.OpenAIError as e:
        logging.error(f"OpenAI API error: {e}")
        return "Error during interaction with OpenAI API.", "", ""
    except Exception as e:
        logging.error(f"General error during SQL process: {e}")
        return "General error during SQL processing.", "", ""