""" Template-based SQL generator for SQLite. Deterministic schema-aware NL→SQL with minimal templates and high accuracy. Focuses only on: users, categories, suppliers, expenses, debts. """ import os import re from datetime import datetime from typing import Any, Dict, Optional, Tuple from dataclasses import dataclass, asdict # Core business schema only (for reference/documentation) DEFAULT_DB_SCHEMA = ( "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | " "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | " "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | " "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | " "debts : id int , date date , user_id int , debt_sum numeric , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime" ) _SQL_GENERATOR: Any | None = None _MONTHS = { "january": 1, "february": 2, "march": 3, "april": 4, "may": 5, "june": 6, "july": 7, "august": 8, "september": 9, "october": 10, "november": 11, "december": 12, } @dataclass(frozen=True) class SqlGenerationRequest: question: str limit: int = 200 def _normalize_text(text: str) -> str: return re.sub(r"\s+", " ", text.lower()).strip() def _contains_any(text: str, markers: tuple[str, ...]) -> bool: return any(marker in text for marker in markers) def _end_of_month(year: int, month: int) -> int: if month == 12: return 31 next_month = date(year, month + 1, 1) current_month = date(year, month, 1) return (next_month - current_month).days def _extract_month_filter(question: str) -> tuple[str, str] | None: text = _normalize_text(question) for month_name, month_idx in _MONTHS.items(): if month_name in text: year_match = re.search(r"\b(20\d{2})\b", text) if not year_match: continue year = 
int(year_match.group(1)) day_end = _end_of_month(year, month_idx) start = f"{year:04d}-{month_idx:02d}-01" end = f"{year:04d}-{month_idx:02d}-{day_end:02d}" return start, end return None def _extract_top_limit(question: str) -> int | None: match = re.search(r"\btop\s+(\d{1,4})\b", _normalize_text(question)) if not match: return None return max(1, min(1000, int(match.group(1)))) def _extract_metric(question: str) -> tuple[str, str]: text = _normalize_text(question) if _contains_any(text, ("count", "how many", "number of")): return "COUNT(*)", "items_count" if _contains_any(text, ("average", "avg", "mean")): return "AVG(e.sum)", "avg_amount" if _contains_any(text, ("minimum", "lowest", "min ")): return "MIN(e.sum)", "min_amount" if _contains_any(text, ("maximum", "highest", "max ")): return "MAX(e.sum)", "max_amount" return "SUM(e.sum)", "total_amount" def _extract_dimension(question: str) -> str | None: text = _normalize_text(question) if _contains_any(text, ("category", "categories")): return "category" if _contains_any(text, ("supplier", "suppliers", "vendor", "vendors")): return "supplier" if _contains_any(text, ("user", "users", "person")): return "user" return None def _build_expenses_aggregate_sql(payload: SqlGenerationRequest) -> str: question = _normalize_text(payload.question) metric_expr, metric_alias = _extract_metric(question) dimension = _extract_dimension(question) select_parts = [] joins = [] group_by = [] if dimension == "category": select_parts.append("c.name AS category_name") joins.append("JOIN categories AS c ON c.id = e.category_id") group_by.append("c.id, c.name") elif dimension == "supplier": select_parts.append("s.name AS supplier_name") joins.append("JOIN suppliers AS s ON s.id = e.supplier_id") group_by.append("s.id, s.name") elif dimension == "user": select_parts.append("u.name AS user_name") joins.append("JOIN users AS u ON u.id = e.user_id") group_by.append("u.id, u.name") select_parts.append(f"{metric_expr} AS {metric_alias}") filters = 
[] month_filter = _extract_month_filter(question) if month_filter: start, end = month_filter filters.append(f"e.date BETWEEN '{start}' AND '{end}'") where_clause = f" WHERE {' AND '.join(filters)}" if filters else "" join_clause = f" {' '.join(joins)}" if joins else "" group_clause = f" GROUP BY {', '.join(group_by)}" if group_by else "" order_direction = "ASC" if " asc" in question or "ascending" in question else "DESC" order_clause = f" ORDER BY {metric_alias} {order_direction}" top_limit = _extract_top_limit(question) final_limit = top_limit if top_limit is not None else payload.limit return ( f"SELECT {', '.join(select_parts)} " f"FROM expenses AS e" f"{join_clause}" f"{where_clause}" f"{group_clause}" f"{order_clause}" f" LIMIT {final_limit}" ) def _build_expenses_detail_sql(payload: SqlGenerationRequest) -> str: question = _normalize_text(payload.question) include_category = _contains_any(question, ("category", "categories")) include_supplier = _contains_any(question, ("supplier", "suppliers", "vendor", "vendors")) include_user = _contains_any(question, ("user", "users", "person")) select_parts = ["e.date", "e.sum", "e.notes"] joins = [] if include_category: select_parts.append("c.name AS category_name") joins.append("JOIN categories AS c ON c.id = e.category_id") if include_supplier: select_parts.append("s.name AS supplier_name") joins.append("JOIN suppliers AS s ON s.id = e.supplier_id") if include_user: select_parts.append("u.name AS user_name") joins.append("JOIN users AS u ON u.id = e.user_id") filters = [] month_filter = _extract_month_filter(question) if month_filter: start, end = month_filter filters.append(f"e.date BETWEEN '{start}' AND '{end}'") where_clause = f" WHERE {' AND '.join(filters)}" if filters else "" join_clause = f" {' '.join(joins)}" if joins else "" order_clause = " ORDER BY e.date DESC" return ( f"SELECT {', '.join(select_parts)} " f"FROM expenses AS e" f"{join_clause}" f"{where_clause}" f"{order_clause}" f" LIMIT {payload.limit}" ) 
def _build_debts_sql(payload: SqlGenerationRequest) -> str:
    """Build a listing query over ``debts``, newest first.

    Adds the debtor's name when the question mentions a user/person/name, and
    filters by payment status and by month when the question asks for it.
    """
    question = _normalize_text(payload.question)
    with_user = _contains_any(question, ("user", "users", "person", "name"))

    select_parts = ["d.date", "d.debt_sum", "d.payment_status"]
    joins = []
    if with_user:
        select_parts.append("u.name AS user_name")
        # LEFT JOIN: debts without a matching user row are still listed.
        joins.append("LEFT JOIN users AS u ON u.id = d.user_id")

    # NOTE(review): assumes payment_status stores exactly 'unpaid', 'partial'
    # and 'paid' — confirm against the data.
    # Order matters: "unpaid" and "partially paid" both contain "paid", so the
    # more specific statuses are checked before the generic "paid" branch.
    filters = []
    if _contains_any(question, ("unpaid", "not paid", "open debt", "open debts")):
        filters.append("d.payment_status = 'unpaid'")
    elif _contains_any(question, ("partial", "partially")):
        filters.append("d.payment_status = 'partial'")
    elif _contains_any(question, ("paid", "closed debt", "closed debts")):
        filters.append("d.payment_status = 'paid'")

    month_filter = _extract_month_filter(question)
    if month_filter:
        start, end = month_filter
        filters.append(f"d.date BETWEEN '{start}' AND '{end}'")

    where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
    join_clause = f" {' '.join(joins)}" if joins else ""
    order_clause = " ORDER BY d.date DESC"
    return (
        f"SELECT {', '.join(select_parts)} "
        f"FROM debts AS d"
        f"{join_clause}"
        f"{where_clause}"
        f"{order_clause}"
        f" LIMIT {payload.limit}"
    )


def _generate_template_sql(payload: SqlGenerationRequest) -> str:
    """Route the question to the debts, expense-aggregate or expense-detail template.

    Debt cues win over aggregate cues; anything else becomes an expense
    listing, so this function always returns a non-empty SQL string.
    """
    question = _normalize_text(payload.question)
    debt_markers = ("debt", "debts", "payment_status", "unpaid", "partial", "paid")
    # Includes the count/average phrasings that _extract_metric understands, so
    # e.g. "how many expenses ..." yields a COUNT rather than a row listing.
    aggregate_markers = (
        "sum",
        "total",
        "group",
        "grouped",
        "top",
        "count",
        "how many",
        "number of",
        "average",
        "avg",
        "mean",
        "minimum",
        "maximum",
    )
    if _contains_any(question, debt_markers):
        return _build_debts_sql(payload)
    if _contains_any(question, aggregate_markers):
        return _build_expenses_aggregate_sql(payload)
    return _build_expenses_detail_sql(payload)


def _get_sql_generator() -> Any:
    """Return the cached transformers text2text pipeline, creating it on first call.

    The model id comes from the ``SQL_MODEL`` environment variable (default
    ``gaussalgo/T5-LM-Large-text2sql-spider``). ``transformers`` is imported
    here, lazily, so it is only needed when the LLM fallback is actually used.
    """
    global _SQL_GENERATOR
    if _SQL_GENERATOR is None:
        from transformers import pipeline

        model_id = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
        _SQL_GENERATOR = pipeline(
            task="text2text-generation",
            model=model_id,
            tokenizer=model_id,
            device=-1,  # run on CPU
        )
    return _SQL_GENERATOR


def _build_prompt(payload: SqlGenerationRequest) -> str:
    """Build the text2sql prompt for the optional transformer fallback."""
    return f"Question: {payload.question} | {DEFAULT_DB_SCHEMA}"


def _normalize_sql(raw_sql: str, limit: int) -> str:
    """Validate generated SQL and reduce it to a single SELECT statement.

    Steps:
    - strip Markdown code fences;
    - drop any text before the first ``SELECT``;
    - keep only the first statement (up to ``;``);
    - reject write/DDL keywords;
    - append ``LIMIT <limit>`` when the query has neither a LIMIT nor an
      aggregate function.

    Raises:
        ValueError: if the input is empty, contains no SELECT, or contains a
            forbidden keyword.
    """
    sql = (raw_sql or "").strip()
    if not sql:
        raise ValueError("SQL model returned an empty result.")

    if "```" in sql:
        # Model output may be wrapped in Markdown fences; keep the last
        # non-empty chunk.
        parts = [part.strip() for part in sql.split("```") if part.strip()]
        if not parts:
            # The input consisted only of fences.
            raise ValueError("SQL model returned an empty result.")
        sql = parts[-1]

    # Search case-insensitively on the original string so the offset is valid
    # for it (str.upper() can change the length of some Unicode text).
    select_match = re.search(r"select", sql, re.IGNORECASE)
    if select_match is None:
        raise ValueError("Generated SQL is not a SELECT query.")
    sql = sql[select_match.start():]
    if ";" in sql:
        # Keep only the first statement.
        sql = sql.split(";", 1)[0].strip()

    # Checks run on an upper-cased copy with whitespace runs collapsed to one
    # space, so keywords separated by newlines or tabs are caught as well.
    check_sql = re.sub(r"\s+", " ", sql.upper())

    forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ")
    if any(keyword in check_sql for keyword in forbidden):
        raise ValueError("Generated SQL contains forbidden statements.")
    if not check_sql.startswith("SELECT "):
        raise ValueError("Only SELECT queries are allowed.")

    aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
    has_limit = " LIMIT " in check_sql
    if not has_limit and not any(marker in check_sql for marker in aggregate_markers):
        sql = f"{sql} LIMIT {limit}"
    return sql


def generate_sql(question: str, limit: int = 200) -> str:
    """Translate a natural-language question into a single SQLite SELECT statement.

    Args:
        question: Free-text question; must not be blank.
        limit: Default row cap, converted with ``int`` and clamped to 1..1000.
            For aggregate queries, a "top N" phrase in the question overrides it.

    Returns:
        A validated SELECT statement.

    Raises:
        ValueError: if the question is blank, or if the SQL fails validation.
    """
    clean_question = (question or "").strip()
    if not clean_question:
        raise ValueError("Field 'query' is required.")

    payload = SqlGenerationRequest(
        question=clean_question,
        limit=max(1, min(1000, int(limit))),
    )

    # Primary path: deterministic template engine for core tables.
    template_sql = _generate_template_sql(payload)
    if template_sql:
        return _normalize_sql(template_sql, limit=payload.limit)

    # Secondary path: optional model fallback, enabled via SQL_USE_LLM_FALLBACK.
    # NOTE(review): _generate_template_sql always returns a non-empty string,
    # so this path is currently unreachable.
    if os.getenv("SQL_USE_LLM_FALLBACK", "false").strip().lower() not in {"1", "true", "yes", "on"}:
        raise ValueError("Unable to map query to a supported SQL template.")

    generator = _get_sql_generator()
    prompt = _build_prompt(payload)
    result = generator(
        prompt,
        max_new_tokens=512,
        do_sample=False,
        num_beams=4,
        truncation=True,
    )
    generated_text = result[0].get("generated_text", "") if result else ""
    return _normalize_sql(generated_text, limit=payload.limit)