"""
SQL Generator: load the trained LoRA adapter on top of the standard Qwen
2.5 Coder 7B base. Loaded at module level per ZeroGPU best practice.
"""
import logging
import re
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
logger = logging.getLogger(__name__)
# System prompt sent with every generation request: schema-mapping rules,
# DuckDB dialect notes, output contract (SQL only, one statement, semicolon),
# and few-shot examples. NOTE: this is a runtime string the model consumes —
# edits here change model behavior.
SYSTEM_PROMPT = """You are an expert SQL analyst working with a DuckDB database.
Your job: convert a natural-language question into ONE correct SQL query.
## Rules
1. **Map user wording to schema columns.** Users rarely use the exact column name. Infer which column they mean from the schema, sample rows, and any distinct-value hints. Examples: "shoe size" → use the "size" column; "where" → likely a "country" or "city" column; "biggest" → ORDER BY DESC LIMIT.
2. **Use ONLY the exact column names that appear in the CREATE TABLE statements.** Wrap them in double quotes when they contain spaces or reserved words.
3. **DuckDB syntax** (PostgreSQL-flavored): supports CTEs, window functions, LIMIT, OFFSET, regex, date arithmetic, JSON functions.
4. **Aggregations**: alias them descriptively, e.g. `AVG(price) AS avg_price`, `COUNT(*) AS total`.
5. **Default to TOP 10** when the user asks for "top", "best", "most" without specifying a number.
6. **Filter explicitly**: if the user mentions a categorical value (e.g. "active customers"), match it against distinct values in the hints and use the exact spelling.
7. **GROUP BY rule (critical)**: every non-aggregated column in SELECT must appear in GROUP BY, OR be wrapped in an aggregation (`AVG`, `SUM`, `COUNT`, `MAX`, `MIN`). Never SELECT a raw column when GROUP BY is present unless it's a grouping key.
8. Output **only the SQL**. No markdown fences, no explanation. End with a semicolon.
## Examples
Schema: CREATE TABLE sales ("id" INT, "product_name" VARCHAR, "revenue" DOUBLE);
Question: top earners
SQL: SELECT product_name, revenue FROM sales ORDER BY revenue DESC LIMIT 10;
Schema: CREATE TABLE coffee ("age" INT, "wait_time_sec" DOUBLE, "method" VARCHAR);
-- coffee.method distinct values: 'espresso', 'pour_over', 'french_press'
Question: average wait by brewing style
SQL: SELECT method, AVG(wait_time_sec) AS avg_wait FROM coffee GROUP BY method ORDER BY avg_wait DESC;
Schema: CREATE TABLE companies ("name" VARCHAR, "founded" INT, "country" VARCHAR);
Question: how many startups per region
SQL: SELECT country, COUNT(*) AS total FROM companies GROUP BY country ORDER BY total DESC;
"""
# Base checkpoint and the trained LoRA adapter stacked on top of it.
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
ADAPTER_REPO = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"
class SQLGenerator:
    """Text-to-SQL generator: Qwen 2.5 Coder 7B base with a LoRA adapter.

    The model is loaded eagerly at construction time (module-level loading
    is the ZeroGPU best practice per the module docstring). The same loaded
    model is reused both for SQL generation and for result narration.
    """

    def __init__(self, temperature: float = 0.0, max_new_tokens: int = 400) -> None:
        """Load tokenizer, base model, and LoRA adapter.

        Args:
            temperature: sampling temperature; 0 selects greedy decoding.
            max_new_tokens: generation budget for SQL queries.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        # Lazy %-style args: logging formats only if the record is emitted.
        logger.info("Loading SQL base: %s", BASE_MODEL)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        logger.info("Applying LoRA adapter: %s", ADAPTER_REPO)
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTER_REPO,
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()
        logger.info("SQL generator ready (LoRA applied on Qwen base)")

    def _chat(
        self,
        messages: list[dict],
        max_new_tokens: int,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> str:
        """Run one chat-completion turn; return only the newly generated text.

        Shared by generate() and narrate() so the template/generate/decode
        pipeline lives in exactly one place.
        """
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Slice off the prompt tokens so only the completion is decoded.
        return self.tokenizer.decode(
            out[0][input_ids.shape[1]:], skip_special_tokens=True
        )

    def generate(
        self,
        question: str,
        schema: str,
        previous_sql: Optional[str] = None,
        previous_error: Optional[str] = None,
    ) -> str:
        """Generate SQL for `question` against `schema`.

        If previous_sql + previous_error are provided, the model is shown its
        prior attempt and the resulting error so it can self-correct.

        Returns:
            A cleaned, single-statement SQL string (see _clean_sql).
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        user_content = f"### Schema\n{schema}\n\n### Question\n{question}"
        messages.append({"role": "user", "content": user_content})
        if previous_sql and previous_error:
            # Retry context: feed the model its previous attempt + the error
            messages.append({"role": "assistant", "content": previous_sql})
            messages.append({
                "role": "user",
                "content": (
                    f"That query failed with error:\n{previous_error}\n\n"
                    "Fix the query. Return only the corrected SQL."
                ),
            })
        text = self._chat(
            messages,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.temperature > 0,
            # temperature must stay positive even when unused (do_sample=False).
            temperature=self.temperature if self.temperature > 0 else 1.0,
        )
        return self._clean_sql(text)

    @staticmethod
    def _clean_sql(text: str) -> str:
        """Strip markdown fences and truncate to the first SQL statement."""
        text = text.strip()
        text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```\s*$", "", text)
        # Keep only the first statement; discard any trailing chatter.
        if ";" in text:
            stmt, _, _ = text.partition(";")
            text = stmt + ";"
        return text.strip()

    def narrate(
        self,
        question: str,
        sql: str,
        results: list[dict],
        columns: list[dict],
    ) -> str:
        """Generate a 1-2 sentence narrative interpreting the query results.

        Reuses the same Qwen model that's already loaded for SQL generation.
        Best-effort: returns "" on empty results or any generation failure.
        """
        if not results:
            return ""
        sample = results[:10]
        col_names = [c["name"] for c in columns]
        narrate_system = (
            "You are a senior data analyst summarizing query results for a "
            "stakeholder. Write exactly 1-2 short sentences highlighting the "
            "single most interesting finding (top contributor, sharp "
            "distribution, surprising gap, etc.). Use specific numbers from "
            "the data. Do not describe what the chart looks like; describe "
            "what the data reveals. No preamble like 'Here is...'."
        )
        user_content = (
            f"Question asked: {question}\n"
            f"Columns: {col_names}\n"
            f"Total rows: {len(results)}\n"
            f"Top rows: {sample}\n\n"
            "Write the 1-2 sentence finding now:"
        )
        messages = [
            {"role": "system", "content": narrate_system},
            {"role": "user", "content": user_content},
        ]
        try:
            # torch is already imported at module level; no local import needed.
            text = self._chat(messages, max_new_tokens=120, do_sample=False).strip()
            # Strip leading quotes/markdown
            text = re.sub(r"^[\"'`]+|[\"'`]+$", "", text)
            return text
        except Exception as e:
            # Narration is decorative; never let it break the query flow.
            logger.warning("narration failed: %s", e)
            return ""
|