File size: 7,573 Bytes
a067ada
730b25d
 
a067ada
 
 
 
 
 
c2ac226
 
730b25d
a067ada
 
 
 
a7fa6d3
 
 
 
 
 
 
 
 
 
 
e6681c9
 
a7fa6d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a067ada
730b25d
 
a067ada
 
c2ac226
730b25d
 
a067ada
 
 
730b25d
 
 
 
c2ac226
 
a067ada
730b25d
 
 
 
 
 
a067ada
730b25d
a067ada
e6681c9
 
 
 
 
 
 
 
 
 
a067ada
e6681c9
 
 
 
 
 
 
 
 
 
 
 
 
a067ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ac226
a067ada
 
 
 
 
 
 
 
 
 
5ae5227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
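"""
SQL generator: applies the trained LoRA adapter on top of the standard
Qwen2.5-Coder-7B-Instruct base model. The same model is used both to write SQL
and to narrate query results. Intended to be instantiated once at module level
(ZeroGPU best practice) so the weights are loaded only a single time.
"""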

import logging
import re
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

logger = logging.getLogger(__name__)


# System prompt sent as the first chat message of every SQL-generation call
# (SQLGenerator.generate). It states the output contract that
# SQLGenerator._clean_sql post-processes: a single statement ending in ';'.
# The few-shot examples prefix answers with "SQL:", so model output may echo
# that prefix.
SYSTEM_PROMPT = """You are an expert SQL analyst working with a DuckDB database.

Your job: convert a natural-language question into ONE correct SQL query.

## Rules
1. **Map user wording to schema columns.** Users rarely use the exact column name. Infer which column they mean from the schema, sample rows, and any distinct-value hints. Examples: "shoe size" → use the "size" column; "where" → likely a "country" or "city" column; "biggest" → ORDER BY DESC LIMIT.
2. **Use ONLY the exact column names that appear in the CREATE TABLE statements.** Wrap them in double quotes when they contain spaces or reserved words.
3. **DuckDB syntax** (PostgreSQL-flavored): supports CTEs, window functions, LIMIT, OFFSET, regex, date arithmetic, JSON functions.
4. **Aggregations**: alias them descriptively, e.g. `AVG(price) AS avg_price`, `COUNT(*) AS total`.
5. **Default to TOP 10** when the user asks for "top", "best", "most" without specifying a number.
6. **Filter explicitly**: if the user mentions a categorical value (e.g. "active customers"), match it against distinct values in the hints and use the exact spelling.
7. **GROUP BY rule (critical)**: every non-aggregated column in SELECT must appear in GROUP BY, OR be wrapped in an aggregation (`AVG`, `SUM`, `COUNT`, `MAX`, `MIN`). Never SELECT a raw column when GROUP BY is present unless it's a grouping key.
8. Output **only the SQL**. No markdown fences, no explanation. End with a semicolon.

## Examples
Schema: CREATE TABLE sales ("id" INT, "product_name" VARCHAR, "revenue" DOUBLE);
Question: top earners
SQL: SELECT product_name, revenue FROM sales ORDER BY revenue DESC LIMIT 10;

Schema: CREATE TABLE coffee ("age" INT, "wait_time_sec" DOUBLE, "method" VARCHAR);
-- coffee.method distinct values: 'espresso', 'pour_over', 'french_press'
Question: average wait by brewing style
SQL: SELECT method, AVG(wait_time_sec) AS avg_wait FROM coffee GROUP BY method ORDER BY avg_wait DESC;

Schema: CREATE TABLE companies ("name" VARCHAR, "founded" INT, "country" VARCHAR);
Question: how many startups per region
SQL: SELECT country, COUNT(*) AS total FROM companies GROUP BY country ORDER BY total DESC;
"""

# Hugging Face model ids: the instruct base model, and the LoRA adapter that
# SQLGenerator.__init__ applies on top of it with PeftModel.from_pretrained.
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
ADAPTER_REPO = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"


class SQLGenerator:
    """Text-to-SQL generator: Qwen2.5-Coder-7B-Instruct plus a LoRA adapter.

    One loaded model serves two tasks:

    * :meth:`generate` turns a natural-language question and a schema string
      into a single SQL statement, optionally self-correcting from a previous
      failed attempt.
    * :meth:`narrate` writes a short natural-language finding about query
      results.

    Constructing an instance loads the base model (bf16, ``device_map="cuda"``)
    and the adapter via ``from_pretrained``. This is expensive, so create it
    once and reuse it.
    """

    def __init__(self, temperature: float = 0.0, max_new_tokens: int = 400) -> None:
        """Load the tokenizer, the base model on CUDA, and the LoRA adapter.

        Args:
            temperature: Sampling temperature for :meth:`generate`. Any value
                ``<= 0`` selects greedy decoding.
            max_new_tokens: Generation budget for :meth:`generate`.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

        logger.info("Loading SQL base: %s", BASE_MODEL)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        logger.info("Applying LoRA adapter: %s", ADAPTER_REPO)
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTER_REPO,
            torch_dtype=torch.bfloat16,
        )
        # Inference only: switch off training-mode behavior such as dropout.
        self.model.eval()
        logger.info("SQL generator ready (LoRA applied on Qwen base)")

    def _chat_generate(
        self,
        messages: list[dict],
        max_new_tokens: int,
        temperature: float = 0.0,
    ) -> str:
        """Run one chat completion and return only the newly generated text.

        Shared by :meth:`generate` and :meth:`narrate`. The returned text is
        decoded with special tokens skipped but is not otherwise cleaned.
        """
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        # Single, unpadded sequence: every position is real. Passing the mask
        # explicitly stops transformers from guessing it, which it cannot do
        # reliably because pad_token_id is set to the eos token below.
        attention_mask = torch.ones_like(input_ids)
        do_sample = temperature > 0

        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                # Greedy decoding ignores temperature; 1.0 is the neutral value.
                temperature=temperature if do_sample else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Drop the prompt tokens; decode only what the model produced.
        return self.tokenizer.decode(
            out[0][input_ids.shape[1]:], skip_special_tokens=True
        )

    def generate(
        self,
        question: str,
        schema: str,
        previous_sql: Optional[str] = None,
        previous_error: Optional[str] = None,
    ) -> str:
        """Generate one SQL statement answering *question* against *schema*.

        If both *previous_sql* and *previous_error* are provided, the failed
        attempt is added as an assistant turn and the error as a follow-up
        user turn, so the model is told what went wrong and can self-correct.

        Args:
            question: Natural-language question from the user.
            schema: Schema text inserted verbatim into the prompt (presumably
                CREATE TABLE statements plus value hints, as SYSTEM_PROMPT
                describes).
            previous_sql: SQL from an earlier failed attempt, when retrying.
            previous_error: Error message produced by *previous_sql*.

        Returns:
            The model output after :meth:`_clean_sql` post-processing.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": f"### Schema\n{schema}\n\n### Question\n{question}",
            },
        ]

        if previous_sql and previous_error:
            # Retry context: feed the model its previous attempt + the error.
            messages.append({"role": "assistant", "content": previous_sql})
            messages.append({
                "role": "user",
                "content": (
                    f"That query failed with error:\n{previous_error}\n\n"
                    "Fix the query. Return only the corrected SQL."
                ),
            })

        text = self._chat_generate(messages, self.max_new_tokens, self.temperature)
        return self._clean_sql(text)

    @staticmethod
    def _first_statement(text: str) -> str:
        """Return *text* up to and including its first top-level ``;``.

        Semicolons inside single-quoted literals, double-quoted identifiers or
        ``--`` line comments do not end the statement. A doubled quote
        (``'it''s'``) toggles the state twice and so stays correct. If no
        top-level ``;`` exists, *text* is returned unchanged.
        """
        quote = None
        i, n = 0, len(text)
        while i < n:
            ch = text[i]
            if quote is not None:
                if ch == quote:
                    quote = None
            elif ch in "'\"":
                quote = ch
            elif ch == "-" and text.startswith("--", i):
                newline = text.find("\n", i)
                if newline == -1:
                    break  # comment runs to end of text
                i = newline
            elif ch == ";":
                return text[: i + 1]
            i += 1
        return text

    @staticmethod
    def _clean_sql(text: str) -> str:
        """Extract a single SQL statement from raw model output.

        The following steps are applied in order:

        * Use the body of the first fenced code block if one exists, even
          when prose precedes it. Otherwise strip a dangling leading or
          trailing fence, e.g. when generation was cut off.
        * Drop a leading ``SQL:`` prefix echoed from the few-shot examples.
        * Keep only the first statement, splitting on ``;`` outside quotes.
        """
        text = text.strip()
        fenced = re.search(
            r"```(?:sql)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL
        )
        if fenced:
            text = fenced.group(1)
        else:
            text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
            text = re.sub(r"\s*```\s*$", "", text)
        text = re.sub(r"^\s*SQL\s*:\s*", "", text, flags=re.IGNORECASE)
        text = SQLGenerator._first_statement(text)
        return text.strip()

    @staticmethod
    def _strip_wrapping_quotes(text: str) -> str:
        """Strip whitespace and matching quote/backtick pairs wrapping *text*.

        Only pairs where the first and last characters are the same quote
        character are removed. A sentence that merely ends with a quoted value
        (``... is 'espresso'``) is left intact.
        """
        text = text.strip()
        while len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'`":
            text = text[1:-1].strip()
        return text

    def narrate(
        self,
        question: str,
        sql: str,
        results: list[dict],
        columns: list[dict],
    ) -> str:
        """Generate a 1-2 sentence narrative interpreting the query results.

        Reuses the already-loaded model with greedy decoding. Narration is
        best-effort: any generation failure is logged and yields ``""``.

        Args:
            question: The user's original question.
            sql: The executed SQL. It is not currently included in the prompt.
            results: Result rows; only the first 10 are shown to the model.
            columns: Column descriptors; each must have a ``"name"`` key.

        Returns:
            The narrative text, or ``""`` when *results* is empty or generation
            fails.
        """
        if not results:
            return ""

        sample = results[:10]
        col_names = [c["name"] for c in columns]
        narrate_system = (
            "You are a senior data analyst summarizing query results for a "
            "stakeholder. Write exactly 1-2 short sentences highlighting the "
            "single most interesting finding (top contributor, sharp "
            "distribution, surprising gap, etc.). Use specific numbers from "
            "the data. Do not describe what the chart looks like; describe "
            "what the data reveals. No preamble like 'Here is...'."
        )
        user_content = (
            f"Question asked: {question}\n"
            f"Columns: {col_names}\n"
            f"Total rows: {len(results)}\n"
            f"Top rows: {sample}\n\n"
            "Write the 1-2 sentence finding now:"
        )
        messages = [
            {"role": "system", "content": narrate_system},
            {"role": "user", "content": user_content},
        ]
        try:
            text = self._chat_generate(messages, max_new_tokens=120)
        except Exception as e:  # deliberate best-effort: narration is optional
            logger.warning("narration failed: %s", e)
            return ""
        return self._strip_wrapping_quotes(text)