File size: 7,743 Bytes
909cddd
 
 
2cb3f69
909cddd
 
 
2cb3f69
 
 
 
 
 
909cddd
2cb3f69
 
 
 
 
 
909cddd
 
 
 
2cb3f69
 
 
 
909cddd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cb3f69
 
 
 
 
 
 
909cddd
2cb3f69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909cddd
2cb3f69
 
 
 
 
909cddd
2cb3f69
 
 
 
 
909cddd
 
2cb3f69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909cddd
2cb3f69
 
909cddd
 
 
2cb3f69
 
 
 
 
 
 
 
 
 
 
 
 
909cddd
 
2cb3f69
 
909cddd
 
2cb3f69
 
 
909cddd
 
 
2cb3f69
 
 
 
909cddd
 
2cb3f69
 
 
 
 
 
 
 
909cddd
2cb3f69
 
909cddd
 
 
 
2cb3f69
909cddd
 
 
 
 
 
 
 
2cb3f69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import openai
import json
import yaml
from dotenv import load_dotenv
from db_logging import save_query_to_local_db
import logging
#############LangSmith & OpenAI####################
from langsmith import traceable
from langsmith.wrappers import wrap_openai

# Load environment variables BEFORE constructing the OpenAI client, so a key
# that only lives in a .env file is visible at client-creation time.
# (Previously the client was built first, which failed in that setup.)
load_dotenv()

# Fail fast if the key is still missing after loading the environment.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Error: OPENAI_API_KEY not found in environment variables")

# Initialize OpenAI client, wrapped for LangSmith tracing.
client = wrap_openai(openai.Client())

GPT_MODEL = "gpt-4o-mini"
##################################################

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Function to calculate cost based on token usage
def ydcoza_cost(tokens_out):
    """Estimate the dollar cost for the given number of tokens.

    Rough estimate for gpt-4o ($5.00 / 1M input tokens); mini is
    $0.150 / 1M input tokens — a flat $0.20 per 1M tokens is used here.
    """
    return tokens_out * 0.2 / 1000000

# Function to load schema from schema.json
def load_schema_from_json():
    """Read the database schema from schema.json in the working directory.

    Returns the parsed schema dict on success, or ``{"error": <message>}``
    when the file is missing or cannot be parsed.
    """
    try:
        with open("schema.json", "r") as schema_file:
            parsed_schema = json.load(schema_file)
        logging.info("Schema loaded from schema.json")
    except Exception as e:
        logging.error(f"Error loading schema from schema.json: {e}")
        return {"error": str(e)}
    return parsed_schema

# Function to build the schema description with foreign keys
def build_schema_description(schema_info):
    """Render the schema dict as a human-readable text block for the prompt.

    Lists each table with its comment, every column with data type and
    comment, and any foreign-key relationships the table declares.
    """
    lines = []
    for table_name, table_details in schema_info.items():
        lines.append(f"Table {table_name} (Comment: {table_details['comment']}):\n")
        lines.extend(
            f"  - {column['name']} ({column['data_type']}) (Comment: {column['comment']})\n"
            for column in table_details["columns"]
        )
        if table_details.get("foreign_keys"):
            lines.append("  Foreign Keys:\n")
            lines.extend(
                f"    - {fk['column']} references {fk['references']['table']}({fk['references']['column']})\n"
                for fk in table_details["foreign_keys"]
            )
    return "".join(lines)


# Function to load examples from YAML
def load_examples_from_yaml(file_path):
    """Parse the few-shot examples from a YAML file.

    Returns the parsed examples on success, or an empty list when the file
    cannot be read or parsed.
    """
    try:
        with open(file_path, 'r') as file:
            loaded_examples = yaml.safe_load(file)
        logging.info("Examples loaded from examples.yaml")
    except Exception as e:
        logging.error(f"Error loading examples: {e}")
        return []
    return loaded_examples

# Function to build examples string for the prompt
def build_examples_string(examples_list):
    """Format the example records as numbered few-shot blocks for the prompt."""
    blocks = [
        f"Example {idx}:\n"
        f"Input:\n{example['input']}\n"
        f"Reformulated Query:\n{example['reformulated_query']}\n"
        f"SQL Query:\n{example['sql_query']}\n\n"
        for idx, example in enumerate(examples_list, start=1)
    ]
    return "".join(blocks)

# Load examples from YAML
# Resolve examples.yaml relative to this module (not the CWD), so the file is
# found regardless of where the process is launched from.
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
# Module-level globals built once at import time; `examples` is read by
# generate_sql_single_call() when assembling the prompt.
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
examples = build_examples_string(examples_list)
# logging.info(f"Examples From YAML: {examples}")

def generate_sql_single_call(nl_query):
    """Reformulate a natural-language query and generate SQL in one LLM call.

    Args:
        nl_query: The user's natural-language question about the database.

    Returns:
        A 3-tuple ``(reformulated_query, sql_query, cost_html)``. On failure
        the first element carries an error message and the other two are "".
        (All paths now return three elements; previously the error paths
        returned 2-tuples, which broke callers unpacking three values.)
    """
    try:
        # Load the schema from schema.json
        schema_info = load_schema_from_json()

        # load_schema_from_json() reports failure as {"error": ...}; the old
        # string-based check could never match, letting errors slip through.
        if isinstance(schema_info, dict) and "error" in schema_info:
            return "Error fetching schema", "", ""

        # Build the schema description once and reuse it
        schema_description = build_schema_description(schema_info)
        logging.info(f"Schema Description: {schema_description}")

        # Single combined prompt: reformulation + SQL generation in one call.
        prompt = (
            f"Database Schema:\n{schema_description}\n\n"
            f"Your task is to:\n"
            f"1. Reformulate the user's natural language query to align precisely with the database schema. Also, indicate the tables to use. These tables **MUST** be present in the schema provided.\n"
            f"2. Generate the corresponding SQL query based on the reformulated query and the provided schema.\n\n"
            f"User's Query:\n\"{nl_query}\"\n\n"
            f"Examples:\n{examples}\n"
            f"Response Format (in JSON):\n"
            f"{{\n"
            f'  "reformulated_query": "<your reformulated query>",\n'
            f'  "sql_query": "<your SQL query>"\n'
            f"}}\n\n"
            f"Important Guidelines:\n"
            f"- Use only the tables and columns provided in the schema.\n"
            f"- Ensure the SQL query is syntactically correct.\n"
            f"- Do not include any additional text or explanations.\n"
            f"- Do not include any text outside the JSON format.\n"
            f"- Ensure the JSON is valid and properly formatted.\n"
            f"- Avoid assumptions about data not present in the schema.\n"
            f"- Double-check your response for accuracy."
        )
        logging.info(f"Full Prompt: {prompt}")
        logging.info("Sending combined request to OpenAI...")
        response = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are an expert data analyst and SQL specialist. "
                        "Your role is to reformulate natural language queries to align with a given database schema "
                        "and then create a syntactically correct PostgreSQL query to run based on the reformulated queries."
                        "Ensure that you only use tables and columns present in the provided schema. "
                        "Indicate the tables to use in your reformulated query."
                        "Ensure that your responses are precise, concise, and follow the provided guidelines."
                        "Provide your response in the specified JSON format."
                    )
                },
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.3  # Adjusted temperature
        )

        # Process the assistant's response
        assistant_response = response.choices[0].message.content.strip()
        logging.info(f"Assistant Response:\n{assistant_response}")

        # Calculate tokens and a cost estimate rendered as an HTML snippet.
        tokens_used = response.usage.total_tokens
        cost_per_call = ydcoza_cost(tokens_used)
        total_cost_per_call = f"<p class='cost_per_call'>Tokens Consumed: {tokens_used} ; Cost per call: ${cost_per_call:.6f}</p>"

        # Parse the assistant's response (expected to be a JSON object).
        try:
            response_json = json.loads(assistant_response)
            reformulated_query = response_json.get("reformulated_query", "")
            sql_query = response_json.get("sql_query", "")
        except json.JSONDecodeError:
            logging.error("Could not parse assistant response as JSON.")
            return "Error parsing assistant response.", "", ""

        logging.info(f"Reformulated Query: {reformulated_query}")
        logging.info(f"SQL Query generated: {sql_query}")

        # Save the query to the local database
        save_query_to_local_db(nl_query, reformulated_query, sql_query)

        return reformulated_query, sql_query, total_cost_per_call

    # openai v1 SDK (used here via openai.Client()) exposes OpenAIError at the
    # top level; openai.error no longer exists and raised AttributeError.
    except openai.OpenAIError as e:
        logging.error(f"OpenAI API error: {e}")
        return "Error during interaction with OpenAI API.", "", ""

    except Exception as e:
        logging.error(f"General error during SQL process: {e}")
        return "General error during SQL processing.", "", ""