Spaces:
Sleeping
Sleeping
| import json | |
| import openai | |
| from config import PROJECT_ID, DATASET_ID | |
| from utils.bigquery_utils import get_bigquery_schema_info | |
def _format_schema(schema_info, dataset_id):
    """Render the schema mapping as one markdown bullet line per table.

    Args:
        schema_info: Mapping of table name -> iterable of column-name strings
            (as returned by ``get_bigquery_schema_info``).
        dataset_id: Dataset prefix shown before each table name.

    Returns:
        A string with one ``- **<dataset>.<table>** (col1, col2, ...)`` line per
        table, each terminated by a newline.
    """
    return "".join(
        f"- **{dataset_id}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )


def _build_table_selection_prompt(natural_language_query, schema_text, dataset_id):
    """Return the LLM prompt asking which tables are needed for the query."""
    return f"""
Based on the following natural language query and BigQuery schema, identify the tables that would be needed to answer the query.
**Query:** "{natural_language_query}"
**BigQuery Schema:**
{schema_text}
Analyze the query and determine which tables contain the necessary information.
IMPORTANT: Return ONLY a raw JSON array of table names without any markdown formatting, code blocks, or explanations.
Example of correct response format:
["{dataset_id}.users", "{dataset_id}.orders"]
Example of INCORRECT response format:
```json
["{dataset_id}.users", "{dataset_id}.orders"]
```
DO NOT use code blocks, backticks, or any other formatting. Return ONLY the raw JSON array.
"""


def _extract_json_array_text(content):
    """Return the substring of *content* most likely to contain a JSON array.

    Despite the prompt, models sometimes wrap the array in a ```json fence
    (any case) or surround it with explanatory prose. Slicing from the first
    "[" to the last "]" recovers the array in all of those cases. When no
    bracket pair is present, the stripped content is returned unchanged so
    that ``json.loads`` reports a meaningful error.
    """
    text = content.strip()
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def table_selection_agent(state):
    """Identify the tables needed to answer a natural-language query.

    Fetches the dataset schema from BigQuery and asks ``gpt-4o`` (temperature
    0) which tables are relevant. It then parses the model's reply as a JSON
    array of table-name strings.

    Args:
        state: Agent state dict. Reads ``state["sql_query"]`` (the user's
            natural-language question, despite the key name) and
            ``state["client"]`` (the BigQuery client, or ``None`` when the
            connection could not be made).

    Returns:
        ``{"relevant_tables": [...]}`` on success. Each entry is a table-name
        string exactly as returned by the model; the prompt asks for the
        ``"<DATASET_ID>.<table>"`` form, but this is not enforced.
        On any failure (no client, schema lookup error, API error, unparsable
        or wrongly-shaped reply) returns ``{"relevant_tables": [], "error": msg}``
        instead of raising.
    """
    natural_language_query = state["sql_query"]
    client = state["client"]
    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    try:
        # Schema lookup is inside the try so its failures are reported via the
        # same error-dict convention as every other failure in this agent.
        schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
        prompt = _build_table_selection_prompt(
            natural_language_query,
            _format_schema(schema_info, DATASET_ID),
            DATASET_ID,
        )
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        # The SDK types message.content as optional; map None to "" so it is
        # reported as invalid JSON below rather than as an AttributeError.
        raw_content = response.choices[0].message.content or ""
    except Exception as exc:
        print(f"Unexpected error: {exc}")
        return {"relevant_tables": [], "error": f"Error: {str(exc)}"}

    try:
        relevant_tables = json.loads(_extract_json_array_text(raw_content))
    except json.JSONDecodeError as exc:
        print(f"JSON Decode Error: {exc}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}

    # Valid JSON is not enough: downstream code expects a list of names.
    if not isinstance(relevant_tables, list) or not all(
        isinstance(table, str) for table in relevant_tables
    ):
        print(f"Response content: {raw_content}")
        return {
            "relevant_tables": [],
            "error": "OpenAI response was not a JSON array of table names",
        }

    print(f"Parsed relevant tables: {relevant_tables}")
    return {"relevant_tables": relevant_tables}