""" SQL Generator: load the trained LoRA adapter on top of the standard Qwen 2.5 Coder 7B base. Loaded at module level per ZeroGPU best practice. """ import logging import re from typing import Optional import torch from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel logger = logging.getLogger(__name__) SYSTEM_PROMPT = """You are an expert SQL analyst working with a DuckDB database. Your job: convert a natural-language question into ONE correct SQL query. ## Rules 1. **Map user wording to schema columns.** Users rarely use the exact column name. Infer which column they mean from the schema, sample rows, and any distinct-value hints. Examples: "shoe size" → use the "size" column; "where" → likely a "country" or "city" column; "biggest" → ORDER BY DESC LIMIT. 2. **Use ONLY the exact column names that appear in the CREATE TABLE statements.** Wrap them in double quotes when they contain spaces or reserved words. 3. **DuckDB syntax** (PostgreSQL-flavored): supports CTEs, window functions, LIMIT, OFFSET, regex, date arithmetic, JSON functions. 4. **Aggregations**: alias them descriptively, e.g. `AVG(price) AS avg_price`, `COUNT(*) AS total`. 5. **Default to TOP 10** when the user asks for "top", "best", "most" without specifying a number. 6. **Filter explicitly**: if the user mentions a categorical value (e.g. "active customers"), match it against distinct values in the hints and use the exact spelling. 7. **GROUP BY rule (critical)**: every non-aggregated column in SELECT must appear in GROUP BY, OR be wrapped in an aggregation (`AVG`, `SUM`, `COUNT`, `MAX`, `MIN`). Never SELECT a raw column when GROUP BY is present unless it's a grouping key. 8. Output **only the SQL**. No markdown fences, no explanation. End with a semicolon. ## Examples Schema: CREATE TABLE sales ("id" INT, "product_name" VARCHAR, "revenue" DOUBLE); Question: top earners SQL: SELECT product_name, revenue FROM sales ORDER BY revenue DESC LIMIT 10; Schema: CREATE TABLE coffee ("age" INT, "wait_time_sec" DOUBLE, "method" VARCHAR); -- coffee.method distinct values: 'espresso', 'pour_over', 'french_press' Question: average wait by brewing style SQL: SELECT method, AVG(wait_time_sec) AS avg_wait FROM coffee GROUP BY method ORDER BY avg_wait DESC; Schema: CREATE TABLE companies ("name" VARCHAR, "founded" INT, "country" VARCHAR); Question: how many startups per region SQL: SELECT country, COUNT(*) AS total FROM companies GROUP BY country ORDER BY total DESC; """ BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct" ADAPTER_REPO = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora" class SQLGenerator: def __init__(self, temperature: float = 0.0, max_new_tokens: int = 400) -> None: self.temperature = temperature self.max_new_tokens = max_new_tokens logger.info(f"Loading SQL base: {BASE_MODEL}") self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) base = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cuda", ) logger.info(f"Applying LoRA adapter: {ADAPTER_REPO}") self.model = PeftModel.from_pretrained( base, ADAPTER_REPO, torch_dtype=torch.bfloat16, ) self.model.eval() logger.info("SQL generator ready (LoRA applied on Qwen base)") def generate( self, question: str, schema: str, previous_sql: Optional[str] = None, previous_error: Optional[str] = None, ) -> str: """Generate SQL. If previous_sql + previous_error are provided, the model is told what went wrong so it can self-correct.""" messages = [{"role": "system", "content": SYSTEM_PROMPT}] user_content = f"### Schema\n{schema}\n\n### Question\n{question}" messages.append({"role": "user", "content": user_content}) if previous_sql and previous_error: # Retry context: feed the model its previous attempt + the error messages.append({"role": "assistant", "content": previous_sql}) messages.append({ "role": "user", "content": ( f"That query failed with error:\n{previous_error}\n\n" "Fix the query. Return only the corrected SQL." ), }) input_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(self.model.device) with torch.no_grad(): out = self.model.generate( input_ids, max_new_tokens=self.max_new_tokens, do_sample=self.temperature > 0, temperature=self.temperature if self.temperature > 0 else 1.0, pad_token_id=self.tokenizer.eos_token_id, ) text = self.tokenizer.decode( out[0][input_ids.shape[1]:], skip_special_tokens=True ) return self._clean_sql(text) @staticmethod def _clean_sql(text: str) -> str: text = text.strip() text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE) text = re.sub(r"\s*```\s*$", "", text) if ";" in text: stmt, _, _ = text.partition(";") text = stmt + ";" return text.strip() def narrate( self, question: str, sql: str, results: list[dict], columns: list[dict], ) -> str: """Generate a 2-sentence narrative interpreting the query results. Reuses the same Qwen model that's already loaded for SQL generation.""" if not results: return "" sample = results[:10] col_names = [c["name"] for c in columns] narrate_system = ( "You are a senior data analyst summarizing query results for a " "stakeholder. Write exactly 1-2 short sentences highlighting the " "single most interesting finding (top contributor, sharp " "distribution, surprising gap, etc.). Use specific numbers from " "the data. Do not describe what the chart looks like; describe " "what the data reveals. No preamble like 'Here is...'." ) user_content = ( f"Question asked: {question}\n" f"Columns: {col_names}\n" f"Total rows: {len(results)}\n" f"Top rows: {sample}\n\n" "Write the 1-2 sentence finding now:" ) messages = [ {"role": "system", "content": narrate_system}, {"role": "user", "content": user_content}, ] try: import torch input_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(self.model.device) with torch.no_grad(): out = self.model.generate( input_ids, max_new_tokens=120, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, ) text = self.tokenizer.decode( out[0][input_ids.shape[1]:], skip_special_tokens=True ).strip() # Strip leading quotes/markdown text = re.sub(r"^[\"'`]+|[\"'`]+$", "", text) return text except Exception as e: logger.warning(f"narration failed: {e}") return ""