Spaces:
Sleeping
Sleeping
| import openai | |
| import os | |
| from typing import Dict, Any, Optional, Tuple | |
| import re | |
| import json | |
class NaturalLanguageToSQL:
    """Turn natural-language questions into read-only SQLite queries via OpenAI.

    The class describes a fixed business schema (suppliers, customers,
    products, purchases, sales) to a chat model, asks it for a query, and
    offers a conservative validator for the SQL that comes back.
    """

    # Chat models used for SQL generation and for the plain-English
    # explanation. Override on a subclass or instance to switch models.
    SQL_MODEL = "gpt-4o-mini"
    EXPLANATION_MODEL = "gpt-3.5-turbo"

    # Statement keywords that must never appear (as whole words, outside
    # quoted text) in an accepted query.
    _FORBIDDEN_KEYWORDS = (
        "drop", "delete", "truncate", "alter", "create", "insert", "update",
    )

    # Single-quoted literals and double-quoted identifiers, with SQL's
    # doubled-quote escaping ('' and "").
    _QUOTED_TEXT_RE = re.compile(r"'(?:[^']|'')*'" r'|"(?:[^"]|"")*"')

    # Patterns rejected anywhere in the raw query text: a second statement
    # after a semicolon, a line comment, or the start of a block comment.
    _UNSAFE_PATTERNS = (r";\s*\S", r"--", r"/\*")

    # A complete markdown fence (optionally tagged sql/sqlite/sqlite3), and
    # the opening half of one for replies that were cut off before closing.
    _FENCED_BLOCK_RE = re.compile(
        r"```(?:sqlite3?|sql)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE
    )
    _FENCE_OPEN_RE = re.compile(r"^```(?:sqlite3?|sql)?\s*", re.IGNORECASE)

    def __init__(self, api_key: Optional[str] = None):
        """Initialize OpenAI client for natural language to SQL conversion.

        Args:
            api_key: OpenAI API key. When omitted (or empty), the
                ``OPENAI_API_KEY`` environment variable is used instead.
        """
        self.client = openai.OpenAI(
            api_key=api_key or os.getenv('OPENAI_API_KEY')
        )

        # Database schema description for the LLM
        self.schema_description = """
        Database Schema:

        Table: suppliers
        - id (INTEGER PRIMARY KEY)
        - name (VARCHAR(255)) - Supplier company name
        - contact_info (TEXT) - Contact information
        - created_at (TIMESTAMP)

        Table: customers
        - id (INTEGER PRIMARY KEY)
        - name (VARCHAR(255)) - Customer name
        - email (VARCHAR(255))
        - phone (VARCHAR(50))
        - address (TEXT)
        - created_at (TIMESTAMP)

        Table: products
        - id (INTEGER PRIMARY KEY)
        - name (VARCHAR(255)) - Product name
        - description (TEXT)
        - category (VARCHAR(100)) - Product category
        - created_at (TIMESTAMP)

        Table: purchases
        - id (INTEGER PRIMARY KEY)
        - supplier_id (INTEGER) - Foreign key to suppliers table
        - product_id (INTEGER) - Foreign key to products table
        - quantity (INTEGER) - Number of items purchased
        - unit_price (DECIMAL(10,2)) - Price per unit
        - total_cost (DECIMAL(10,2)) - Total purchase cost
        - purchase_date (TIMESTAMP) - When purchase was made
        - notes (TEXT) - Additional notes

        Table: sales
        - id (INTEGER PRIMARY KEY)
        - customer_id (INTEGER) - Foreign key to customers table
        - product_id (INTEGER) - Foreign key to products table
        - quantity (INTEGER) - Number of items sold
        - unit_price (DECIMAL(10,2)) - Price per unit
        - total_amount (DECIMAL(10,2)) - Total sale amount
        - sale_date (TIMESTAMP) - When sale was made
        - notes (TEXT) - Additional notes

        Relationships:
        - purchases.supplier_id → suppliers.id
        - purchases.product_id → products.id
        - sales.customer_id → customers.id
        - sales.product_id → products.id
        """

    @classmethod
    def _clean_sql_response(cls, raw: Optional[str]) -> str:
        """Extract the bare SQL text from a model reply, dropping markdown fences."""
        text = (raw or "").strip()

        # Prefer the contents of a complete fenced block, wherever it sits,
        # so prose before or after the fence is discarded too.
        fenced = cls._FENCED_BLOCK_RE.search(text)
        if fenced:
            return fenced.group(1).strip()

        # Otherwise strip a dangling opening or closing fence (for example a
        # reply truncated by the token limit).
        text = cls._FENCE_OPEN_RE.sub('', text)
        text = re.sub(r'\s*```$', '', text)
        return text.strip()

    def convert_to_sql(self, natural_language_query: str) -> Tuple[str, str]:
        """
        Convert natural language query to SQL.

        Args:
            natural_language_query: The user's question in plain language.

        Returns: (sql_query, explanation). On failure no exception is
        raised; sql_query is an SQL comment of the form
        "-- Error generating SQL: ..." and explanation describes the failure.
        """
        # Nothing to translate: skip the API round trip entirely.
        if not natural_language_query or not natural_language_query.strip():
            reason = "empty query"
            return (
                f"-- Error generating SQL: {reason}",
                f"Failed to convert query: {reason}",
            )

        system_prompt = f"""You are an expert SQL query generator. Given a natural language question about a business database, generate the appropriate SQL query.

{self.schema_description}

Guidelines:
1. Generate valid SQLite syntax
2. Use JOINs when accessing related data across tables
3. Use appropriate WHERE clauses for filtering
4. Use aggregate functions (COUNT, SUM, AVG) when appropriate
5. Use ORDER BY for sorting results
6. Use LIMIT for restricting result count when reasonable
7. Always use proper table aliases for clarity
8. Handle date ranges using DATE() function for SQLite
9. Use LIKE with % wildcards for text searches
10. Return only the SQL query, no explanations unless specifically requested

Example queries:
- "Show all USB drives purchased" → SELECT p.name, pu.quantity, pu.unit_price, s.name as supplier FROM purchases pu JOIN products p ON pu.product_id = p.id JOIN suppliers s ON pu.supplier_id = s.id WHERE p.name LIKE '%USB%'
- "Total sales this month" → SELECT SUM(total_amount) FROM sales WHERE DATE(sale_date) >= DATE('now', 'start of month')
- "Top 5 customers by sales" → SELECT c.name, SUM(s.total_amount) as total FROM sales s JOIN customers c ON s.customer_id = c.id GROUP BY c.id, c.name ORDER BY total DESC LIMIT 5
"""

        user_prompt = f"""Convert this natural language query to SQL:

"{natural_language_query}"

Return ONLY the SQL query, nothing else."""

        try:
            response = self.client.chat.completions.create(
                model=self.SQL_MODEL,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.1,
                max_tokens=500
            )
            # Clean up the SQL query (remove markdown formatting if present)
            sql_query = self._clean_sql_response(
                response.choices[0].message.content
            )
        except Exception as e:
            # Deliberate boundary: callers get an error tuple, never a raise.
            return f"-- Error generating SQL: {e}", f"Failed to convert query: {e}"

        if not sql_query:
            reason = "model returned an empty response"
            return (
                f"-- Error generating SQL: {reason}",
                f"Failed to convert query: {reason}",
            )

        # Generate explanation
        explanation = self._generate_explanation(natural_language_query, sql_query)
        return sql_query, explanation

    def _generate_explanation(self, nl_query: str, sql_query: str) -> str:
        """Generate a human-readable explanation of what the SQL query does.

        Best-effort: if the API call fails or returns nothing, a generic
        one-line description is returned instead of raising.
        """
        fallback = f"Generated SQL query for: {nl_query}"

        system_prompt = """You are a helpful assistant that explains SQL queries in simple terms.
Given a natural language question and the corresponding SQL query, provide a brief explanation of what the SQL query does."""

        user_prompt = f"""Natural language query: "{nl_query}"
SQL query: {sql_query}

Provide a brief explanation of what this SQL query does:"""

        try:
            response = self.client.chat.completions.create(
                model=self.EXPLANATION_MODEL,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.3,
                max_tokens=200
            )
            explanation = (response.choices[0].message.content or "").strip()
        except Exception:
            # The explanation is optional; the generated SQL is still usable.
            return fallback
        return explanation or fallback

    def validate_sql(self, sql_query: str) -> Tuple[bool, str]:
        """
        Basic validation of SQL query structure.

        Accepts only a single SELECT statement with balanced parentheses, no
        comments, and none of the forbidden statement keywords. The check is
        deliberately conservative: a semicolon or comment marker inside a
        string literal is rejected as well.

        Returns: (is_valid, error_message)
        """
        sql_lower = (sql_query or "").lower().strip()

        # Blank out quoted text so literal contents such as '%update%' or
        # '%(%' cannot trip the keyword and parenthesis checks.
        unquoted = self._QUOTED_TEXT_RE.sub("''", sql_lower)

        # Check for dangerous operations. Whole-word matching keeps column
        # names like created_at or last_update from being flagged.
        for keyword in self._FORBIDDEN_KEYWORDS:
            if re.search(rf"\b{keyword}\b", unquoted):
                return False, f"Query contains potentially dangerous keyword: {keyword}"

        # Check if it starts with SELECT (read-only queries only)
        if not sql_lower.startswith('select'):
            return False, "Only SELECT queries are allowed for security"

        # Basic syntax checks
        if unquoted.count('(') != unquoted.count(')'):
            return False, "Unmatched parentheses in query"

        # Check for stacked statements and comments. This runs on the raw
        # text, not the unquoted copy, so quoting tricks cannot hide them.
        for pattern in self._UNSAFE_PATTERNS:
            if re.search(pattern, sql_lower):
                return False, f"Query contains potentially unsafe pattern: {pattern}"

        return True, "Query appears valid"

    def suggest_corrections(self, natural_language_query: str, error_message: str) -> str:
        """Suggest how to rephrase the query if it fails.

        Args:
            natural_language_query: The original question (currently unused).
            error_message: Error text from generating or running the query.

        Returns:
            A hint chosen by the first keyword found in error_message, or a
            generic hint when none matches.
        """
        # Order matters: "ambiguous column name" must hit the ambiguity hint
        # before the more general "column" entry.
        suggestions = {
            "ambiguous": "Be more specific about what data you want to see",
            "table": "Make sure you're asking about purchases, sales, customers, suppliers, or products",
            "column": "Try using terms like 'name', 'quantity', 'price', 'date', 'total'",
            "syntax": "Try rephrasing your question more simply",
        }

        error_lower = (error_message or "").lower()
        for key, suggestion in suggestions.items():
            if key in error_lower:
                return f"Suggestion: {suggestion}"

        return "Try rephrasing your question or ask for help with available data"