Spaces:
Sleeping
Sleeping
| import openai | |
| from config import PROJECT_ID, DATASET_ID | |
| from utils.bigquery_utils import get_bigquery_schema_info | |
def _format_sample_data(sample_data):
    """Render per-table sample rows as Markdown tables for the prompt.

    Entries whose value is not a non-empty list are skipped. Column names are
    taken from the first row of each table; a column missing from a later row
    is rendered as an empty cell.
    """
    parts = []
    for table, rows in sample_data.items():
        if not (isinstance(rows, list) and rows):
            continue
        columns = list(rows[0].keys())
        parts.append(f"\n**Sample data from {table}:**\n")
        parts.append("| " + " | ".join(columns) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(columns)) + " |\n")
        parts.extend(
            "| " + " | ".join(str(row.get(col, "")) for col in columns) + " |\n"
            for row in rows
        )
    return "".join(parts)


def _extract_sql(text):
    """Return *text* stripped of a surrounding Markdown code fence, if any."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    fenced = cleaned[3:]
    # The opening fence may carry a language tag in any letter case
    # (```sql, ```SQL) or no tag at all.
    if fenced[:3].lower() == "sql":
        fenced = fenced[3:]
    # Keep only what precedes the closing fence so that any commentary the
    # model appended after the code block does not end up in the query.
    return fenced.partition("```")[0].strip()


def sql_generation_agent(state):
    """Generate a BigQuery SQL query from a natural-language question.

    Args:
        state: Mapping with the following keys:
            - "sql_query" (required): the natural-language question (despite
              the key name, it is inserted into the prompt as the question).
            - "client": object passed to ``get_bigquery_schema_info``;
              a missing key or ``None`` yields an error result.
            - "relevant_tables" (optional): collection of
              ``"<DATASET_ID>.<table>"`` names used to filter the schema.
            - "sample_data" (optional): mapping of table name to a list of
              row mappings, rendered as Markdown tables in the prompt.

    Returns:
        A dict ``{"generated_sql": <str>}``. When no client is available or
        the model returns no content, the value is a SQL comment starting
        with ``-- Error:``.

    Raises:
        KeyError: if ``state`` has no "sql_query" entry.
    """
    natural_language_query = state["sql_query"]
    # ``or`` also covers keys that are present but explicitly set to None.
    relevant_tables = state.get("relevant_tables") or []
    sample_data = state.get("sample_data") or {}

    # A missing "client" key is treated the same as an explicit None.
    client = state.get("client")
    if client is None:
        return {"generated_sql": "-- Error: Failed to connect to BigQuery."}

    # NOTE(review): assumes the helper returns {table_name: [column names]};
    # confirm against utils.bigquery_utils.
    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)

    # Only tables the caller marked as relevant are described in the prompt.
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
        if f"{DATASET_ID}.{table_name}" in relevant_tables
    )
    sample_data_text = _format_sample_data(sample_data)

    prompt = f"""
    Generate a BigQuery SQL query to answer the following question:
    **Question:** "{natural_language_query}"
    **Relevant Tables Schema:**
    {schema_text}
    **Sample Data:**
    {sample_data_text}
    **Rules:**
    - Use only the provided tables with their full dataset.table_name format (e.g., {DATASET_ID}.users).
    - Ensure correct column names as shown in the schema.
    - Use appropriate joins based on the relationships visible in the sample data.
    - Use BigQuery SQL syntax.
    - Return ONLY the SQL query without any explanations or markdown formatting.
    """

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )

    content = response.choices[0].message.content
    if content is None:
        # The message content is optional in the API response; report it in
        # the same style as the connection error instead of crashing.
        return {"generated_sql": "-- Error: The model returned no content."}

    return {"generated_sql": _extract_sql(content)}