# NOTE(review): removed Hugging Face Spaces page residue (the "Spaces:" /
# "Runtime error" banner, file-size line, commit-hash row, and line-number
# gutter) that was scraped into the file and is not valid Python.
import os
import openai
import json
import yaml
from dotenv import load_dotenv
from db_logging import save_query_to_local_db
import logging
#############LangSmith & OpenAI####################
from langsmith import traceable
from langsmith.wrappers import wrap_openai

# Load environment variables from .env BEFORE constructing the OpenAI client:
# openai.Client() reads OPENAI_API_KEY from the environment at construction
# time, so creating the client first (as the original code did) fails — or
# uses a stale key — whenever the key only lives in the .env file.
load_dotenv()

# Fail fast with a clear message if the key is still missing.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Error: OPENAI_API_KEY not found in environment variables")

# LangSmith-instrumented OpenAI client: traces every chat completion.
client = wrap_openai(openai.Client())

# Model used for all completions in this module.
GPT_MODEL = "gpt-4o-mini"
##################################################
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Function to calculate cost based on token usage
# Function to calculate cost based on token usage
def ydcoza_cost(tokens_out, rate_per_million=0.2):
    """Estimate the USD cost of an OpenAI call from its token count.

    Args:
        tokens_out: Total number of tokens consumed by the call.
        rate_per_million: Price in USD per 1M tokens. Defaults to 0.2,
            the original rough blended estimate (gpt-4o: $5.00 / 1M input
            tokens; gpt-4o-mini: $0.150 / 1M input tokens).

    Returns:
        Estimated cost in USD as a float.
    """
    return tokens_out * rate_per_million / 1000000
# Function to load schema from schema.json
# Function to load schema from a JSON file
def load_schema_from_json(path="schema.json"):
    """Load the database schema description from a JSON file.

    Args:
        path: Path to the schema JSON file. Defaults to "schema.json",
            preserving the original hard-coded behavior.

    Returns:
        The parsed schema dict on success, or a dict of the form
        {"error": "<message>"} when the file is missing or unparsable.
    """
    try:
        with open(path, "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info(f"Schema loaded from {path}")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from {path}: {e}")
        # Callers detect failure via the "error" key.
        return {"error": str(e)}
# Function to build the schema description with foreign keys
# Function to build the schema description with foreign keys
def build_schema_description(schema_info):
    """Render the schema dict as a human-readable text description.

    Lists each table with its comment, every column (data type + comment),
    and any foreign-key relationships declared for the table.
    """
    parts = []
    for tbl_name, tbl in schema_info.items():
        parts.append(f"Table {tbl_name} (Comment: {tbl['comment']}):\n")
        for col in tbl["columns"]:
            parts.append(f" - {col['name']} ({col['data_type']}) (Comment: {col['comment']})\n")
        if tbl.get("foreign_keys"):
            parts.append(" Foreign Keys:\n")
            for fk in tbl["foreign_keys"]:
                ref = fk["references"]
                parts.append(f" - {fk['column']} references {ref['table']}({ref['column']})\n")
    return "".join(parts)
# Function to load examples from YAML
# Function to load examples from YAML
def load_examples_from_yaml(file_path):
    """Parse the few-shot examples YAML file.

    Returns whatever ``yaml.safe_load`` produces on success; on any
    failure (missing file, malformed YAML, ...) logs the error and
    returns an empty list so prompt building can proceed without examples.
    """
    try:
        with open(file_path, 'r') as fh:
            loaded = yaml.safe_load(fh)
        logging.info("Examples loaded from examples.yaml")
        return loaded
    except Exception as exc:
        logging.error(f"Error loading examples: {exc}")
        return []
# Function to build examples string for the prompt
# Function to build examples string for the prompt
def build_examples_string(examples_list):
    """Format few-shot examples into a single prompt-ready string.

    Args:
        examples_list: Iterable of example dicts, each providing the keys
            'input', 'reformulated_query' and 'sql_query' (as loaded from
            examples.yaml). Examples are numbered from 1.

    Returns:
        The concatenated, numbered examples as one string ("" when empty).
    """
    # str.join instead of repeated += — avoids quadratic string
    # concatenation as the example set grows.
    chunks = []
    for idx, example in enumerate(examples_list, start=1):
        chunks.append(
            f"Example {idx}:\n"
            f"Input:\n{example['input']}\n"
            f"Reformulated Query:\n{example['reformulated_query']}\n"
            f"SQL Query:\n{example['sql_query']}\n\n"
        )
    return "".join(chunks)
# Load examples from YAML
# Resolve examples.yaml relative to this source file so the script works
# regardless of the current working directory.
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
# Pre-render the few-shot examples once at import time; reused by every
# prompt built in generate_sql_single_call(). Empty string if loading failed.
examples = build_examples_string(examples_list)
# logging.info(f"Examples From YAML: {examples}")
def generate_sql_single_call(nl_query):
    """Translate a natural-language question into SQL with one LLM call.

    Loads the schema, builds a combined prompt (schema description +
    few-shot examples + guidelines), asks the model for a JSON response
    containing a reformulated query and a SQL query, persists the result
    via save_query_to_local_db, and reports token usage/cost.

    Args:
        nl_query: The user's natural-language question.

    Returns:
        A 3-tuple (reformulated_query, sql_query, cost_html). On any
        error the first element is an error message and the other two
        are empty strings, so callers can always unpack three values.
    """
    try:
        # Load the schema from schema.json
        schema_info = load_schema_from_json()
        # load_schema_from_json() signals failure with an {"error": ...}
        # dict; the previous `isinstance(schema_info, str)` check could
        # never match, so schema failures fell through to prompt building.
        if isinstance(schema_info, dict) and "error" in schema_info:
            return "Error fetching schema", "", ""
        # Build the schema description once and reuse it
        schema_description = build_schema_description(schema_info)
        logging.info(f"Schema Description: {schema_description}")
        # Single combined prompt: reformulation + SQL generation in one call.
        prompt = (
            f"Database Schema:\n{schema_description}\n\n"
            f"Your task is to:\n"
            f"1. Reformulate the user's natural language query to align precisely with the database schema. Also, indicate the tables to use. These tables **MUST** be present in the schema provided.\n"
            f"2. Generate the corresponding SQL query based on the reformulated query and the provided schema.\n\n"
            f"User's Query:\n\"{nl_query}\"\n\n"
            f"Examples:\n{examples}\n"
            f"Response Format (in JSON):\n"
            f"{{\n"
            f' "reformulated_query": "<your reformulated query>",\n'
            f' "sql_query": "<your SQL query>"\n'
            f"}}\n\n"
            f"Important Guidelines:\n"
            f"- Use only the tables and columns provided in the schema.\n"
            f"- Ensure the SQL query is syntactically correct.\n"
            f"- Do not include any additional text or explanations.\n"
            f"- Do not include any text outside the JSON format.\n"
            f"- Ensure the JSON is valid and properly formatted.\n"
            f"- Avoid assumptions about data not present in the schema.\n"
            f"- Double-check your response for accuracy."
        )
        logging.info(f"Full Prompt: {prompt}")
        logging.info("Sending combined request to OpenAI...")
        response = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are an expert data analyst and SQL specialist. "
                        "Your role is to reformulate natural language queries to align with a given database schema "
                        # "and then generate accurate SQL queries based on the reformulated queries."
                        "and then create a syntactically correct PostgreSQL query to run based on the reformulated queries."
                        "Ensure that you only use tables and columns present in the provided schema. "
                        "Indicate the tables to use in your reformulated query."
                        "Ensure that your responses are precise, concise, and follow the provided guidelines."
                        "Provide your response in the specified JSON format."
                    )
                },
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.3  # Adjusted temperature
        )
        # Process the assistant's response
        assistant_response = response.choices[0].message.content.strip()
        logging.info(f"Assistant Response:\n{assistant_response}")
        # Calculate Tokens
        tokens_used = response.usage.total_tokens
        cost_per_call = ydcoza_cost(tokens_used)
        total_cost_per_call = f"<p class='cost_per_call'>Tokens Consumed: {tokens_used} ; Cost per call: ${cost_per_call:.6f}</p>"
        # Parse the assistant's response
        try:
            response_json = json.loads(assistant_response)
            reformulated_query = response_json.get("reformulated_query", "")
            sql_query = response_json.get("sql_query", "")
        except json.JSONDecodeError:
            logging.error("Could not parse assistant response as JSON.")
            # 3-tuple shape kept on every return path (was a 2-tuple here).
            return "Error parsing assistant response.", "", ""
        logging.info(f"Reformulated Query: {reformulated_query}")
        logging.info(f"SQL Query generated: {sql_query}")
        # Save the query to the local database
        save_query_to_local_db(nl_query, reformulated_query, sql_query)
        return reformulated_query, sql_query, total_cost_per_call
    except openai.OpenAIError as e:
        # openai>=1.0 (required by the openai.Client() usage above) removed
        # the `openai.error` module; exception types now live directly on
        # the package, so `openai.error.OpenAIError` raised AttributeError.
        logging.error(f"OpenAI API error: {e}")
        return "Error during interaction with OpenAI API.", "", ""
    except Exception as e:
        logging.error(f"General error during SQL process: {e}")
        return "General error during SQL processing.", "", ""