# Natural-language-to-SQL engine built on the Anthropic API.
import json
import os
import re

import anthropic
from dotenv import load_dotenv

# Load ANTHROPIC_API_KEY (and any other settings) from a local .env file
# at import time so NaturalLanguageParser can pick it up via os.getenv.
load_dotenv()
class NaturalLanguageParser:
    """Natural-language-to-SQL engine backed by the Anthropic Messages API.

    Given a plain-English request and a database schema, asks Claude to
    produce a dialect-specific SQL query and returns a structured dict with
    keys: ``sql``, ``explanation``, ``query_type``, ``warnings``,
    ``optimizations``.
    """

    def __init__(self):
        """Create the Anthropic client.

        Raises:
            ValueError: if ANTHROPIC_API_KEY is not present in the
                environment (typically loaded from a .env file).
        """
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in .env file!")
        self.client = anthropic.Anthropic(api_key=api_key)

    def generate_sql(self, description, schema, dialect="PostgreSQL"):
        """Generate complete, optimized SQL directly from natural language.

        Args:
            description: the user's natural-language request.
            schema: mapping of table name -> list of column dicts
                (each with ``name``, ``type`` and optional ``nullable``,
                ``primary_key`` keys).
            dialect: target SQL dialect name (default "PostgreSQL").

        Returns:
            dict with ``sql``, ``explanation``, ``query_type``,
            ``warnings`` and ``optimizations`` keys. Never raises: any
            failure is reported via an ``-- ERROR: ...`` sql string and a
            ``query_type`` of "error".
        """
        prompt = self._build_prompt(description, schema, dialect)
        try:
            response = self.client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=4000,
                messages=[{"role": "user", "content": prompt}],
            )
            content = response.content[0].text.strip()
            return self._parse_response(content)
        except Exception as e:
            # Surface any API/parsing failure as an error result rather
            # than propagating, so callers always get the same dict shape.
            return {
                "sql": f"-- ERROR: {str(e)}",
                "explanation": "An error occurred during SQL generation",
                "query_type": "error",
                "warnings": [str(e)],
                "optimizations": [],
            }

    def _build_prompt(self, description, schema, dialect):
        """Assemble the full instruction prompt sent to the model."""
        schema_text = self._format_schema(schema)
        return f"""You are an advanced Natural-Language-to-SQL Engine.
Your job is to convert user instructions into correct, executable, efficient SQL queries, based solely on the provided database schema.

Core Responsibilities:
1. NEVER hallucinate tables or columns. Only use what exists in the provided schema.
2. Follow the SQL dialect: {dialect}
3. Validate user intent and generate the best query, even if their natural language is unclear.
4. Fix common SQL mistakes:
   - Use IS NULL / IS NOT NULL (never = NULL)
   - Use >= and < for date ranges instead of BETWEEN
   - Proper JOIN syntax
   - Correct aggregate/grouping logic
5. Optimize the SQL:
   - Use proper filters
   - Use EXISTS/NOT EXISTS for subqueries
   - Use CTEs for clarity in complex queries
   - Avoid unnecessary computation
6. For complex analysis (cohorts, trends, rankings, window functions):
   - Use Common Table Expressions (CTEs)
   - Use window functions when appropriate

Available Database Schema:
{schema_text}

User Request: {description}

SQL Construction Rules:
WHERE clauses:
- Use IS NULL / IS NOT NULL, not = 'NULL'
- For date ranges use >= and < instead of BETWEEN
Aggregations:
- GROUP BY all non-aggregated fields
- Avoid GROUP BY unnecessary columns
Joins:
- Prefer explicit JOIN syntax
- Use LEFT JOIN for "missing data" queries
- Always specify join conditions clearly
Subqueries:
- Prefer EXISTS/NOT EXISTS for performance
- Use CTEs for readability in complex queries
Window Functions (use when user asks for):
- Top N per group
- Rankings
- Running totals
- Comparisons to averages
- Rolling windows

Output Format:
Return a JSON object with this exact structure:
{{
  "sql": "the complete SQL query here",
  "explanation": "brief explanation of what the query does",
  "query_type": "simple|aggregate|join|window|cte|analytical",
  "warnings": ["any warnings about schema limitations or assumptions"],
  "optimizations": ["list of optimizations applied"]
}}

CRITICAL RULES FOR JSON OUTPUT:
- Return ONLY valid JSON (no markdown, no code blocks, no extra text)
- Escape ALL special characters in the SQL string:
  * Newlines must be \\n
  * Quotes must be \\"
  * Backslashes must be \\\\
- The "sql" field must be a single-line string with \\n for line breaks
- Always end SQL with semicolon
- If the request is impossible with the given schema, set "sql" to "-- ERROR: <explanation>" and explain in "warnings"

Generate the SQL query now:"""

    def _parse_response(self, content):
        """Parse the model's reply into a result dict, tolerating bad JSON."""
        # The model sometimes wraps output in markdown fences despite
        # instructions; strip them before parsing.
        content = content.replace("```json", "").replace("```", "").strip()
        try:
            result = json.loads(content)
        except json.JSONDecodeError as e:
            # Fall back to best-effort regex extraction of the fields.
            print(f"JSON parsing failed: {e}, attempting manual extraction...")
            return self._extract_fields(content)
        # Guarantee all required fields exist with sensible defaults.
        if "sql" not in result or not result["sql"]:
            result["sql"] = "-- ERROR: No SQL generated"
        result.setdefault("explanation", "SQL query generated")
        result.setdefault("query_type", "select")
        result.setdefault("warnings", [])
        result.setdefault("optimizations", [])
        return result

    def _extract_fields(self, content):
        """Salvage sql/explanation/query_type from malformed JSON via regex."""
        sql_match = re.search(r'"sql"\s*:\s*"((?:[^"\\]|\\.|\\n)*)"', content, re.DOTALL)
        if sql_match:
            # Decode JSON string escapes in a single pass. Sequential
            # str.replace calls are order-sensitive and corrupt inputs
            # like '\\n' (a literal backslash followed by 'n').
            sql = re.sub(
                r'\\(.)',
                lambda m: {"n": "\n", "t": "\t", "r": "\r"}.get(m.group(1), m.group(1)),
                sql_match.group(1),
            )
        else:
            # No recognizable "sql" field; pass the whole reply through.
            sql = content
        expl_match = re.search(r'"explanation"\s*:\s*"((?:[^"\\]|\\.)*)"', content, re.DOTALL)
        explanation = expl_match.group(1) if expl_match else "SQL generated (with parsing issues)"
        type_match = re.search(r'"query_type"\s*:\s*"([^"]*)"', content)
        query_type = type_match.group(1) if type_match else "select"
        return {
            "sql": sql,
            "explanation": explanation,
            "query_type": query_type,
            "warnings": ["JSON parsing issue - SQL may need review"],
            "optimizations": [],
        }

    def _format_schema(self, schema):
        """Render the schema mapping as indented text for the prompt.

        Columns missing 'nullable' are treated as nullable; columns missing
        'primary_key' are treated as non-key.
        """
        lines = ["Tables and Columns:"]
        for table, columns in schema.items():
            lines.append(f"\nTable: {table}")
            for col in columns:
                nullable = "NULL" if col.get("nullable", True) else "NOT NULL"
                pk = " (PRIMARY KEY)" if col.get("primary_key", False) else ""
                lines.append(f"  - {col['name']}: {col['type']} {nullable}{pk}")
        return "\n".join(lines)

    # Legacy method for backward compatibility
    def parse(self, description, schema):
        """Legacy entry point - redirects to generate_sql.

        Returns a dict with 'raw_sql' (the SQL string) and 'metadata'
        (the full generate_sql result) for older callers.
        """
        result = self.generate_sql(description, schema)
        return {"raw_sql": result.get("sql", ""), "metadata": result}