# codeflow-ai / nl_parser.py
# Initial commit: CodeFlow AI - NL to SQL Generator (7814c1f)
import anthropic
import json
import os
import re
from dotenv import load_dotenv
# Populate os.environ from a local .env file (supplies ANTHROPIC_API_KEY
# read in NaturalLanguageParser.__init__) before anything else runs.
load_dotenv()
class NaturalLanguageParser:
    """Advanced Natural Language to SQL Engine using Claude.

    Builds a heavily-instructed prompt from the user's natural-language
    request plus a database schema, sends it to the Anthropic Messages API,
    and returns a structured dict: {sql, explanation, query_type, warnings,
    optimizations}.  Degrades gracefully when the model's reply is not
    valid JSON: fields are recovered by regex and the problem is surfaced
    in "warnings" rather than raised.
    """

    def __init__(self):
        """Create the Anthropic client.

        Raises:
            ValueError: if ANTHROPIC_API_KEY is not present in the
                environment (normally loaded from a .env file at import).
        """
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in .env file!")
        self.client = anthropic.Anthropic(api_key=api_key)

    def generate_sql(self, description, schema, dialect="PostgreSQL"):
        """
        Generate complete, optimized SQL directly from natural language.
        This is the main engine that handles all SQL generation logic.

        Args:
            description: the user's natural-language request.
            schema: mapping of table name -> list of column dicts
                (keys: 'name', 'type', optional 'nullable'/'primary_key').
            dialect: SQL dialect the model should target.

        Returns:
            dict with keys "sql", "explanation", "query_type", "warnings",
            "optimizations".  Never raises: API/parse failures come back
            in-band with "sql" set to a "-- ERROR: ..." comment.
        """
        schema_text = self._format_schema(schema)
        prompt = f"""You are an advanced Natural-Language-to-SQL Engine.
Your job is to convert user instructions into correct, executable, efficient SQL queries, based solely on the provided database schema.
🔥 Core Responsibilities
1. NEVER hallucinate tables or columns. Only use what exists in the provided schema.
2. Follow the SQL dialect: {dialect}
3. Validate user intent and generate the best query, even if their natural language is unclear.
4. Fix common SQL mistakes:
- Use IS NULL / IS NOT NULL (never = NULL)
- Use >= and < for date ranges instead of BETWEEN
- Proper JOIN syntax
- Correct aggregate/grouping logic
5. Optimize the SQL:
- Use proper filters
- Use EXISTS/NOT EXISTS for subqueries
- Use CTEs for clarity in complex queries
- Avoid unnecessary computation
6. For complex analysis (cohorts, trends, rankings, window functions):
- Use Common Table Expressions (CTEs)
- Use window functions when appropriate
🗃️ Available Database Schema:
{schema_text}
👤 User Request: {description}
🎯 SQL Construction Rules:
WHERE clauses:
- Use IS NULL / IS NOT NULL, not = 'NULL'
- For date ranges use >= and < instead of BETWEEN
Aggregations:
- GROUP BY all non-aggregated fields
- Avoid GROUP BY unnecessary columns
Joins:
- Prefer explicit JOIN syntax
- Use LEFT JOIN for "missing data" queries
- Always specify join conditions clearly
Subqueries:
- Prefer EXISTS/NOT EXISTS for performance
- Use CTEs for readability in complex queries
Window Functions (use when user asks for):
- Top N per group
- Rankings
- Running totals
- Comparisons to averages
- Rolling windows
📋 Output Format:
Return a JSON object with this exact structure:
{{
"sql": "the complete SQL query here",
"explanation": "brief explanation of what the query does",
"query_type": "simple|aggregate|join|window|cte|analytical",
"warnings": ["any warnings about schema limitations or assumptions"],
"optimizations": ["list of optimizations applied"]
}}
CRITICAL RULES FOR JSON OUTPUT:
- Return ONLY valid JSON (no markdown, no code blocks, no extra text)
- Escape ALL special characters in the SQL string:
* Newlines must be \\n
* Quotes must be \\"
* Backslashes must be \\\\
- The "sql" field must be a single-line string with \\n for line breaks
- Always end SQL with semicolon
- If the request is impossible with the given schema, set "sql" to "-- ERROR: <explanation>" and explain in "warnings"
Generate the SQL query now:"""
        try:
            response = self.client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=4000,
                messages=[{"role": "user", "content": prompt}]
            )
            content = response.content[0].text.strip()
            # Remove a wrapping markdown code fence if present.  Only the
            # leading/trailing fence is stripped: a blanket
            # replace("```", "") would also corrupt backticks that appear
            # inside the generated SQL or explanation.
            content = self._strip_code_fences(content)
            try:
                return self._normalize_result(json.loads(content))
            except json.JSONDecodeError as e:
                # If JSON parsing fails, try to extract components manually.
                print(f"JSON parsing failed: {e}, attempting manual extraction...")
                return self._recover_from_malformed_json(content)
        except Exception as e:
            # API/auth/network failures: report in-band so callers always
            # receive the full result shape instead of an exception.
            return {
                "sql": f"-- ERROR: {str(e)}",
                "explanation": "An error occurred during SQL generation",
                "query_type": "error",
                "warnings": [str(e)],
                "optimizations": []
            }

    @staticmethod
    def _strip_code_fences(text):
        """Strip a leading ```/```json fence and a trailing ``` fence only."""
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)
        return text.strip()

    @staticmethod
    def _normalize_result(result):
        """Fill in any fields the model omitted so callers get the full shape."""
        if not result.get("sql"):
            result["sql"] = "-- ERROR: No SQL generated"
        result.setdefault("explanation", "SQL query generated")
        result.setdefault("query_type", "select")
        result.setdefault("warnings", [])
        result.setdefault("optimizations", [])
        return result

    @staticmethod
    def _recover_from_malformed_json(content):
        """Best-effort field extraction when the model reply is not valid JSON."""
        # Try to extract SQL between "sql": " and its closing quote.
        sql_match = re.search(r'"sql"\s*:\s*"((?:[^"\\]|\\.|\\n)*)"', content, re.DOTALL)
        if sql_match:
            raw = sql_match.group(1)
            try:
                # Decode JSON escape sequences with the real JSON decoder.
                # (Chained .replace() calls process escapes in the wrong
                # order and corrupt strings containing an escaped
                # backslash followed by 'n'.)
                sql = json.loads(f'"{raw}"')
            except json.JSONDecodeError:
                sql = raw.replace('\\n', '\n').replace('\\"', '"').replace('\\\\', '\\')
        else:
            # No recognizable "sql" field; fall back to the entire reply.
            sql = content
        # Try to extract explanation and query type the same way.
        expl_match = re.search(r'"explanation"\s*:\s*"((?:[^"\\]|\\.)*)"', content, re.DOTALL)
        explanation = expl_match.group(1) if expl_match else "SQL generated (with parsing issues)"
        type_match = re.search(r'"query_type"\s*:\s*"([^"]*)"', content)
        query_type = type_match.group(1) if type_match else "select"
        return {
            "sql": sql,
            "explanation": explanation,
            "query_type": query_type,
            "warnings": ["JSON parsing issue - SQL may need review"],
            "optimizations": []
        }

    def _format_schema(self, schema):
        """Format schema for prompt: one section per table, one line per column."""
        lines = ["Tables and Columns:"]
        for table, columns in schema.items():
            lines.append(f"\n📊 {table}")
            for col in columns:
                # Columns default to nullable / non-primary-key when unspecified.
                nullable = "NULL" if col.get('nullable', True) else "NOT NULL"
                pk = " (PRIMARY KEY)" if col.get('primary_key', False) else ""
                lines.append(f" - {col['name']}: {col['type']} {nullable}{pk}")
        return "\n".join(lines)

    # Legacy method for backward compatibility
    def parse(self, description, schema):
        """Legacy method - redirects to generate_sql"""
        result = self.generate_sql(description, schema)
        # Return just the SQL for backward compatibility
        return {"raw_sql": result.get("sql", ""), "metadata": result}