Spaces:
Sleeping
Sleeping
| """ | |
| Prompt Engine for SQL Generation | |
| Constructs intelligent prompts for SQL generation using retrieved examples and best practices. | |
| """ | |
| import json | |
| from typing import List, Dict, Any, Optional | |
| from pathlib import Path | |
| from loguru import logger | |
class PromptEngine:
    """Intelligent prompt construction for SQL generation.

    Loads prompt templates from a directory (writing sensible defaults on
    first run) and assembles them into prompts for SQL generation,
    few-shot learning, and error correction.
    """

    def __init__(self, prompts_dir: str = "./prompts"):
        """
        Initialize the prompt engine.

        Args:
            prompts_dir: Directory containing prompt templates. Created
                (including parents) if it does not exist.
        """
        self.prompts_dir = Path(prompts_dir)
        self.prompts_dir.mkdir(parents=True, exist_ok=True)

        # Load prompt templates (defaults are written to disk on first run).
        self.templates = self._load_prompt_templates()

        # Default system prompt used verbatim by construct_enhanced_prompt().
        # NOTE(review): this string contains {table_schema}/{examples}/{question}
        # placeholders, but construct_enhanced_prompt() never .format()s it, so
        # those placeholders appear literally in the enhanced prompt — confirm
        # whether that is intended before changing it.
        self.default_system_prompt = """You are an expert SQL developer. Your task is to convert natural language questions into accurate SQL queries.
Key Guidelines:
1. Always use the exact table column names provided
2. Generate standard SQL syntax (compatible with most databases)
3. Use appropriate JOINs when multiple tables are involved
4. Apply proper WHERE clauses for filtering
5. Use GROUP BY for aggregations when needed
6. Ensure queries are efficient and readable
7. Handle edge cases appropriately
Table Schema: {table_schema}
Retrieved Examples:
{examples}
Question: {question}
Generate the SQL query:"""

    def _load_prompt_templates(self) -> Dict[str, str]:
        """Load prompt templates from files, creating defaults when absent.

        Returns:
            Mapping of template name (filename without the ``.txt`` suffix)
            to the template's text content.
        """
        templates: Dict[str, str] = {}

        # Default templates, written to disk only if no file exists yet, so
        # user-customized templates are never overwritten.
        default_templates = {
            "sql_generation.txt": self._get_default_sql_prompt(),
            "few_shot_examples.txt": self._get_default_few_shot_prompt(),
            "error_correction.txt": self._get_default_error_correction_prompt(),
        }

        for filename, content in default_templates.items():
            template_path = self.prompts_dir / filename
            if not template_path.exists():
                template_path.write_text(content, encoding='utf-8')
                # BUG FIX: previously logged a literal "(unknown)" instead of
                # identifying which template file was created.
                logger.info(f"Created default template: {template_path}")

            # Load the template (whether freshly created or pre-existing).
            templates[filename.replace('.txt', '')] = template_path.read_text(encoding='utf-8')

        return templates

    def _get_default_sql_prompt(self) -> str:
        """Get default SQL generation prompt template."""
        return """You are an expert SQL developer. Convert the natural language question to SQL.
Table Schema: {table_schema}
Examples:
{examples}
Question: {question}
Generate SQL:"""

    def _get_default_few_shot_prompt(self) -> str:
        """Get default few-shot learning prompt template."""
        return """Given these examples, generate SQL for the new question:
Examples:
{examples}
New Question: {question}
Table Schema: {table_schema}
SQL Query:"""

    def _get_default_error_correction_prompt(self) -> str:
        """Get default error correction prompt template."""
        return """The following SQL query has an error. Please correct it:
Original Question: {question}
Table Schema: {table_schema}
Incorrect SQL: {incorrect_sql}
Error: {error_message}
Corrected SQL:"""

    def construct_sql_prompt(self,
                             question: str,
                             table_headers: List[str],
                             retrieved_examples: List[Dict[str, Any]],
                             prompt_type: str = "sql_generation") -> str:
        """
        Construct a prompt for SQL generation.

        Args:
            question: Natural language question
            table_headers: List of table column names
            retrieved_examples: List of retrieved relevant examples; each dict
                is expected to carry 'question', 'sql' and 'table_headers' keys
            prompt_type: Name of the template to use; falls back to
                "sql_generation" for unknown names

        Returns:
            Constructed prompt string
        """
        table_schema = self._format_table_schema(table_headers)
        examples_text = self._format_examples(retrieved_examples)

        # Unknown prompt types silently fall back to the base SQL template.
        template = self.templates.get(prompt_type, self.templates["sql_generation"])

        return template.format(
            question=question,
            table_schema=table_schema,
            examples=examples_text
        )

    def construct_enhanced_prompt(self,
                                  question: str,
                                  table_headers: List[str],
                                  retrieved_examples: List[Dict[str, Any]],
                                  additional_context: Optional[Dict[str, Any]] = None) -> str:
        """
        Construct an enhanced prompt with additional context and examples.

        Args:
            question: Natural language question
            table_headers: List of table column names
            retrieved_examples: List of retrieved relevant examples, assumed to
                be ordered most-relevant first; only the top 3 are included
            additional_context: Optional extra key/value context appended to
                the prompt

        Returns:
            Enhanced prompt string
        """
        # The raw system prompt is included as-is (see NOTE in __init__).
        prompt_parts = [self.default_system_prompt]

        table_schema = self._format_table_schema(table_headers)
        prompt_parts.append(f"Table Schema: {table_schema}\n")

        # Include the top 3 retrieved examples with their relevance scores.
        if retrieved_examples:
            prompt_parts.append("Relevant Examples (ordered by relevance):")
            for i, example in enumerate(retrieved_examples[:3], 1):
                # Prefer the reranked score; fall back to raw similarity.
                relevance = example.get("final_score", example.get("similarity_score", 0))
                prompt_parts.append(f"\nExample {i} (Relevance: {relevance:.2f}):")
                prompt_parts.append(f"Question: {example['question']}")
                prompt_parts.append(f"SQL: {example['sql']}")
                prompt_parts.append(f"Table: {example['table_headers']}")

        if additional_context:
            prompt_parts.append("\nAdditional Context:")
            for key, value in additional_context.items():
                prompt_parts.append(f"{key}: {value}")

        prompt_parts.append(f"\nCurrent Question: {question}")
        prompt_parts.append("\nGenerate the SQL query:")

        return "\n".join(prompt_parts)

    def construct_few_shot_prompt(self,
                                  question: str,
                                  table_headers: List[str],
                                  examples: List[Dict[str, Any]]) -> str:
        """
        Construct a few-shot learning prompt.

        Args:
            question: Natural language question
            table_headers: List of table column names
            examples: List of examples for few-shot learning; only the first
                5 are used

        Returns:
            Few-shot prompt string
        """
        template = self.templates["few_shot_examples"]

        # Format up to 5 examples in a structured, delimited layout.
        examples_text = ""
        for i, example in enumerate(examples[:5], 1):
            examples_text += f"\n--- Example {i} ---\n"
            examples_text += f"Question: {example['question']}\n"
            examples_text += f"Table: {example['table_headers']}\n"
            examples_text += f"SQL: {example['sql']}\n"

        table_schema = self._format_table_schema(table_headers)

        return template.format(
            examples=examples_text,
            question=question,
            table_schema=table_schema
        )

    def construct_error_correction_prompt(self,
                                          question: str,
                                          table_headers: List[str],
                                          incorrect_sql: str,
                                          error_message: str) -> str:
        """
        Construct a prompt for error correction.

        Args:
            question: Natural language question
            table_headers: List of table column names
            incorrect_sql: The incorrect SQL query
            error_message: Error message or description

        Returns:
            Error correction prompt string
        """
        template = self.templates["error_correction"]
        table_schema = self._format_table_schema(table_headers)

        return template.format(
            question=question,
            table_schema=table_schema,
            incorrect_sql=incorrect_sql,
            error_message=error_message
        )

    def _format_table_schema(self, table_headers: List[str]) -> str:
        """Format table headers into a readable, keyword-grouped schema.

        Headers are bucketed by substring heuristics on the (lowercased)
        header name. A header matching several keyword groups is listed in
        each of them; headers matching none land in "Other Fields".

        Args:
            table_headers: List of table column names

        Returns:
            Newline-joined "Label: col, col" lines, or a placeholder string
            when no headers are given.
        """
        if not table_headers:
            return "No table schema provided"

        # (label, substrings) pairs, checked in display order.
        categories = [
            ("Primary Keys", ['id', 'key']),
            ("Text Fields", ['name', 'title', 'description', 'text']),
            ("Numeric Fields", ['age', 'count', 'price', 'salary', 'amount', 'number']),
            ("Date Fields", ['date', 'time', 'created', 'updated', 'birth']),
            ("Boolean Fields", ['is_', 'has_', 'active', 'enabled', 'status']),
        ]

        schema_parts = []
        matched_headers = set()
        for label, keywords in categories:
            hits = [h for h in table_headers
                    if any(word in h.lower() for word in keywords)]
            if hits:
                schema_parts.append(f"{label}: {', '.join(hits)}")
            matched_headers.update(hits)

        # Anything that matched no keyword group.
        other_headers = [h for h in table_headers if h not in matched_headers]
        if other_headers:
            schema_parts.append(f"Other Fields: {', '.join(other_headers)}")

        return "\n".join(schema_parts)

    def _format_examples(self, examples: List[Dict[str, Any]]) -> str:
        """Format retrieved examples (top 3) for prompt inclusion.

        Args:
            examples: Retrieved examples, assumed most-relevant first; each
                dict is expected to carry 'question', 'sql' and
                'table_headers' keys.

        Returns:
            Formatted example text, or a placeholder when the list is empty.
        """
        if not examples:
            return "No relevant examples found."

        formatted_examples = []
        for i, example in enumerate(examples[:3], 1):
            # Prefer the reranked score; fall back to raw similarity.
            relevance = example.get("final_score", example.get("similarity_score", 0))
            formatted_examples.append(f"Example {i} (Relevance: {relevance:.2f}):")
            formatted_examples.append(f"  Question: {example['question']}")
            formatted_examples.append(f"  SQL: {example['sql']}")
            formatted_examples.append(f"  Table: {example['table_headers']}")

        return "\n".join(formatted_examples)

    def get_prompt_statistics(self) -> Dict[str, Any]:
        """Get statistics about the prompt engine.

        Returns:
            Dict with the available template names, the prompts directory
            path, and the number of loaded templates.
        """
        return {
            "available_templates": list(self.templates.keys()),
            "prompts_directory": str(self.prompts_dir),
            "template_count": len(self.templates)
        }