File size: 5,667 Bytes
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8441ef
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7562827
 
 
 
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7562827
f9ad313
 
 
 
 
7562827
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""

Text-to-SQL Generator - Multi-Database Support.



Uses LLM to generate SQL queries from natural language,

with dynamic schema context. Supports MySQL, PostgreSQL, and SQLite.

"""

import logging
from typing import Optional, Dict, Any, List, Tuple
import re

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def get_sql_dialect(db_type: str) -> str:
    """Return the human-readable SQL dialect name for a database type.

    Args:
        db_type: Database type key: ``"mysql"``, ``"postgresql"`` or
            ``"sqlite"``.

    Returns:
        The display name of the dialect (``"MySQL"``, ``"PostgreSQL"`` or
        ``"SQLite"``), or the generic ``"SQL"`` when ``db_type`` is not one
        of the known keys.
    """
    dialects = {
        "mysql": "MySQL",
        "postgresql": "PostgreSQL",
        # SQLite is a supported backend with its own dialect notes; without
        # this entry its prompts were labelled with the generic "SQL".
        "sqlite": "SQLite",
    }
    return dialects.get(db_type, "SQL")


def get_dialect_specific_hints(db_type: str) -> str:
    """Return the prompt notes describing one database engine's SQL quirks.

    ``"postgresql"`` and ``"sqlite"`` each have dedicated notes. Every other
    value, ``"mysql"`` included, receives the MySQL notes.
    """
    postgresql_notes = """

PostgreSQL-SPECIFIC NOTES:

- Use ILIKE for case-insensitive pattern matching (instead of LIKE)

- String concatenation uses || operator

- Use LIMIT at the end of queries

- Boolean values are TRUE/FALSE (not 1/0)

- Use double quotes for identifiers with special chars, single quotes for strings

"""
    sqlite_notes = """

SQLite-SPECIFIC NOTES:

- LIKE is case-insensitive for ASCII characters by default

- Use || for string concatenation

- No ILIKE - use LIKE (case-insensitive) or GLOB (case-sensitive)

- Use LIMIT at the end of queries

- Boolean values are 1/0

- Uses strftime() for date functions instead of DATE_FORMAT

"""
    mysql_notes = """

MySQL-SPECIFIC NOTES:

- LIKE is case-insensitive for non-binary strings

- Use CONCAT() for string concatenation

- Use LIMIT at the end of queries

- Boolean values are 1/0

- Use backticks for identifiers with special chars, single quotes for strings

"""
    # Engines with dedicated notes; anything else falls back to MySQL.
    notes_by_engine = {
        "postgresql": postgresql_notes,
        "sqlite": sqlite_notes,
    }
    return notes_by_engine.get(db_type, mysql_notes)


class SQLGenerator:
    """Generates SQL queries from natural language using LLM.

    The generator fills ``SYSTEM_PROMPT_TEMPLATE`` with the dialect name,
    dialect notes and schema text, sends that prompt plus recent chat history
    and the question to the configured LLM client, and extracts the SQL
    statement from the reply.

    The LLM client must expose ``chat(messages)``, where ``messages`` is a
    list of ``{"role": ..., "content": ...}`` dicts; its return value is
    treated as the reply text.
    """

    SYSTEM_PROMPT_TEMPLATE = """You are a SQL expert. Generate {dialect} SELECT queries based on user questions.



RULES:

1. ONLY generate SELECT statements.

2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.

3. Always include a LIMIT clause (max 50 rows unless specified).

4. Use table and column names EXACTLY as shown in the schema.

5. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:

   - Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).

   - Use pattern matching for flexibility.

   - Use `OR` to combine multiple column checks.

6. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.

7. Return ONLY the SQL query, no explanations.

8. PAGINATION: If the user asks to "show more", "show other", "see remaining", or similar follow-up:

   - Look at the previous conversation for the original query conditions.

   - Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).

   - Keep the same WHERE conditions from the previous query.



{dialect_hints}



DATABASE SCHEMA:

{schema}



Generate a single {dialect} SELECT query to answer the user's question."""

    # Fenced code block with an optional SQL-flavoured language tag. Longer
    # tags are listed first so that e.g. "sqlite" is consumed whole instead of
    # matching its "sql" prefix and leaking "ite" into the query.
    _FENCED_BLOCK_RE = re.compile(
        r'```(?:postgresql|postgres|pgsql|mysql|sqlite3|sqlite|sql)?\s*(.*?)```',
        re.DOTALL | re.IGNORECASE,
    )

    # First SELECT statement, up to the first semicolon or the end of text.
    _SELECT_RE = re.compile(r'(SELECT\s+.+?(?:;|$))', re.DOTALL | re.IGNORECASE)

    def __init__(self, llm_client=None, db_type: str = "mysql"):
        """Store the LLM client and the target database type.

        Args:
            llm_client: Object exposing ``chat(messages)``; may be supplied
                later through ``set_llm_client``.
            db_type: Database type key used to select the dialect name and
                dialect notes for the prompt.
        """
        self.llm_client = llm_client
        self.db_type = db_type

    def set_llm_client(self, llm_client):
        """Replace the LLM client used by ``generate``."""
        self.llm_client = llm_client

    def set_db_type(self, db_type: str):
        """Set the database type for SQL generation."""
        self.db_type = db_type

    def generate(
        self,
        question: str,
        schema_context: str,
        chat_history: Optional[List[Dict[str, str]]] = None
    ) -> Tuple[str, str]:
        """Generate SQL from natural language.

        Args:
            question: The user's question, sent as the final user message.
            schema_context: Schema description inserted into the system
                prompt.
            chat_history: Earlier chat messages; only the last three are
                forwarded to the LLM.

        Returns:
            Tuple of ``(sql_query, raw_response)``: the SQL extracted from the
            LLM reply and the unmodified reply itself.

        Raises:
            ValueError: If no LLM client has been configured.
        """
        if not self.llm_client:
            raise ValueError("LLM client not configured")

        system_prompt = self.SYSTEM_PROMPT_TEMPLATE.format(
            dialect=get_sql_dialect(self.db_type),
            dialect_hints=get_dialect_specific_hints(self.db_type),
            schema=schema_context,
        )

        messages = [{"role": "system", "content": system_prompt}]

        if chat_history:
            # The slice keeps the last 3 *messages*, not 3 user/assistant
            # exchanges.
            messages.extend(chat_history[-3:])

        messages.append({"role": "user", "content": question})

        response = self.llm_client.chat(messages)

        return self._extract_sql(response), response

    @staticmethod
    def _strip_statement_end(text: str) -> str:
        """Trim surrounding whitespace and any trailing semicolons."""
        return text.strip().rstrip(';').rstrip()

    def _extract_sql(self, response: str) -> str:
        """Extract SQL query from LLM response.

        Candidates are tried in this order: the first fenced code block
        (untagged, or tagged ``sql`` or a supported dialect name such as
        ``mysql``, ``postgresql`` or ``sqlite``), then the first bare
        ``SELECT`` statement, then the whole response. Whichever candidate is
        chosen is returned without surrounding whitespace or a trailing
        semicolon.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            The extracted SQL text, or an empty string for an empty or
            ``None`` response.
        """
        if not response:
            return ""

        fenced = self._FENCED_BLOCK_RE.search(response)
        if fenced:
            return self._strip_statement_end(fenced.group(1))

        select_match = self._SELECT_RE.search(response)
        if select_match:
            return self._strip_statement_end(select_match.group(1))

        return self._strip_statement_end(response)


# Process-wide generator instance, created on the first call to
# get_sql_generator().
_generator: Optional[SQLGenerator] = None


def get_sql_generator(db_type: str = "mysql") -> SQLGenerator:
    """Return the shared SQLGenerator, pointed at ``db_type``.

    The first call builds the instance for ``db_type``. Every later call
    reuses that instance and overwrites its database type with ``db_type``,
    so calling with the default resets it to ``"mysql"``.
    """
    global _generator
    if _generator is not None:
        _generator.set_db_type(db_type)
        return _generator
    _generator = SQLGenerator(db_type=db_type)
    return _generator