Spaces:
Running
Running
| """ | |
| Template-based SQL generator for SQLite. | |
| Deterministic schema-aware NL→SQL with minimal templates and high accuracy. | |
| Focuses only on: users, categories, suppliers, expenses, debts. | |
| """ | |
| import os | |
| import re | |
| from datetime import datetime | |
| from typing import Any, Dict, Optional, Tuple | |
| from dataclasses import dataclass, asdict | |
# Core business schema only, serialized as one string: each table is written
# as "table : column type , column type , ..." and tables are separated by
# " | ". _build_prompt appends this string to the question when building the
# prompt for the optional model fallback.
DEFAULT_DB_SCHEMA = (
    "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | "
    "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | "
    "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | "
    "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
    "debts : id int , date date , user_id int , debt_sum numeric , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime"
)
# Module-level slot for the text2text pipeline; it stays None until
# _get_sql_generator() creates the pipeline on first use.
_SQL_GENERATOR: Any | None = None
| _MONTHS = { | |
| "january": 1, | |
| "february": 2, | |
| "march": 3, | |
| "april": 4, | |
| "may": 5, | |
| "june": 6, | |
| "july": 7, | |
| "august": 8, | |
| "september": 9, | |
| "october": 10, | |
| "november": 11, | |
| "december": 12, | |
| } | |
@dataclass
class SqlGenerationRequest:
    """Input handed to the SQL builders: the question text plus a row limit.

    Declared as a dataclass so the annotated fields become constructor
    arguments; ``generate_sql`` builds instances with
    ``SqlGenerationRequest(question=..., limit=...)``.
    """

    # Natural-language question to turn into SQL.
    question: str
    # Value placed in the LIMIT clause when the question gives no "top N".
    limit: int = 200
| def _normalize_text(text: str) -> str: | |
| return re.sub(r"\s+", " ", text.lower()).strip() | |
| def _contains_any(text: str, markers: tuple[str, ...]) -> bool: | |
| return any(marker in text for marker in markers) | |
| def _end_of_month(year: int, month: int) -> int: | |
| if month == 12: | |
| return 31 | |
| next_month = date(year, month + 1, 1) | |
| current_month = date(year, month, 1) | |
| return (next_month - current_month).days | |
def _extract_month_filter(question: str) -> tuple[str, str] | None:
    """Turn a "<month name> ... <year>" mention into an inclusive date range.

    Args:
        question: Question text; it is normalized (lowercased, whitespace
            collapsed) before matching.

    Returns:
        ``(start, end)`` as ``YYYY-MM-DD`` strings for the first day and last
        day of the month, or ``None`` when the text lacks either a full
        English month name or a standalone four-digit year starting with 20.
        If several month names occur, the earliest one in calendar order wins.
    """
    text = _normalize_text(question)

    # The year does not depend on which month matched, so resolve it once.
    year_match = re.search(r"\b(20\d{2})\b", text)
    if not year_match:
        return None
    year = int(year_match.group(1))

    for month_name, month_idx in _MONTHS.items():
        # Whole-word match so words such as "maybe" or "dismay" are not
        # mistaken for the month of May.
        if not re.search(rf"\b{re.escape(month_name)}\b", text):
            continue
        day_end = _end_of_month(year, month_idx)
        start = f"{year:04d}-{month_idx:02d}-01"
        end = f"{year:04d}-{month_idx:02d}-{day_end:02d}"
        return start, end
    return None
def _extract_top_limit(question: str) -> int | None:
    """Parse a "top N" phrase (N of 1-4 digits) and return N clamped to 1..1000.

    Returns ``None`` when the question contains no such phrase.
    """
    found = re.search(r"\btop\s+(\d{1,4})\b", _normalize_text(question))
    if found is None:
        return None
    requested = int(found.group(1))
    if requested < 1:
        return 1
    if requested > 1000:
        return 1000
    return requested
| def _extract_metric(question: str) -> tuple[str, str]: | |
| text = _normalize_text(question) | |
| if _contains_any(text, ("count", "how many", "number of")): | |
| return "COUNT(*)", "items_count" | |
| if _contains_any(text, ("average", "avg", "mean")): | |
| return "AVG(e.sum)", "avg_amount" | |
| if _contains_any(text, ("minimum", "lowest", "min ")): | |
| return "MIN(e.sum)", "min_amount" | |
| if _contains_any(text, ("maximum", "highest", "max ")): | |
| return "MAX(e.sum)", "max_amount" | |
| return "SUM(e.sum)", "total_amount" | |
def _extract_dimension(question: str) -> str | None:
    """Return the grouping dimension named in *question*, if any.

    The result is ``"category"``, ``"supplier"`` or ``"user"`` (checked in
    that order, by substring match on the normalized text), or ``None`` when
    no dimension keyword is present.
    """
    text = _normalize_text(question)
    # (dimension, keywords that select it), in priority order.
    dimension_markers = (
        ("category", ("category", "categories")),
        ("supplier", ("supplier", "suppliers", "vendor", "vendors")),
        ("user", ("user", "users", "person")),
    )
    for dimension, markers in dimension_markers:
        if _contains_any(text, markers):
            return dimension
    return None
def _build_expenses_aggregate_sql(payload: SqlGenerationRequest) -> str:
    """Build an aggregate SELECT over ``expenses`` from the question in *payload*.

    The aggregate comes from ``_extract_metric``. If ``_extract_dimension``
    finds a dimension (category, supplier or user), its name column, JOIN and
    GROUP BY are added. A month range from ``_extract_month_filter`` becomes a
    ``BETWEEN`` filter on ``e.date``. Rows are ordered by the metric alias,
    descending unless the question contains " asc" or "ascending". The LIMIT
    is the "top N" value when present, otherwise ``payload.limit``.
    """
    text = _normalize_text(payload.question)
    metric_expr, metric_alias = _extract_metric(text)

    # dimension -> (label column, JOIN that provides it, GROUP BY columns)
    grouping_sql = {
        "category": (
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
            "c.id, c.name",
        ),
        "supplier": (
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
            "s.id, s.name",
        ),
        "user": (
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
            "u.id, u.name",
        ),
    }

    select_parts = []
    joins = []
    group_by = []
    grouping = grouping_sql.get(_extract_dimension(text))
    if grouping is not None:
        label_column, join_sql, group_columns = grouping
        select_parts.append(label_column)
        joins.append(join_sql)
        group_by.append(group_columns)
    # The metric always comes last in the SELECT list.
    select_parts.append(f"{metric_expr} AS {metric_alias}")

    filters = []
    month_range = _extract_month_filter(text)
    if month_range:
        start, end = month_range
        filters.append(f"e.date BETWEEN '{start}' AND '{end}'")

    # Assemble the statement piece by piece; optional clauses carry their own
    # leading space so absent ones leave no gap.
    pieces = [f"SELECT {', '.join(select_parts)} ", "FROM expenses AS e"]
    if joins:
        pieces.append(f" {' '.join(joins)}")
    if filters:
        pieces.append(f" WHERE {' AND '.join(filters)}")
    if group_by:
        pieces.append(f" GROUP BY {', '.join(group_by)}")

    wants_ascending = " asc" in text or "ascending" in text
    order_direction = "ASC" if wants_ascending else "DESC"
    pieces.append(f" ORDER BY {metric_alias} {order_direction}")

    top_limit = _extract_top_limit(text)
    final_limit = payload.limit if top_limit is None else top_limit
    pieces.append(f" LIMIT {final_limit}")
    return "".join(pieces)
def _build_expenses_detail_sql(payload: SqlGenerationRequest) -> str:
    """Build a row-level SELECT over ``expenses`` from the question in *payload*.

    Always selects ``e.date``, ``e.sum`` and ``e.notes``. Category, supplier
    and user names (with their JOINs) are added when the question mentions the
    corresponding keyword. A month range from ``_extract_month_filter``
    becomes a ``BETWEEN`` filter on ``e.date``. Rows are ordered newest first
    and limited to ``payload.limit``.
    """
    text = _normalize_text(payload.question)

    # (keywords, extra column, JOIN providing it), applied in this order.
    optional_joins = (
        (
            ("category", "categories"),
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
        ),
        (
            ("supplier", "suppliers", "vendor", "vendors"),
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
        ),
        (
            ("user", "users", "person"),
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
        ),
    )

    select_parts = ["e.date", "e.sum", "e.notes"]
    joins = []
    for markers, column, join_sql in optional_joins:
        if _contains_any(text, markers):
            select_parts.append(column)
            joins.append(join_sql)

    filters = []
    month_range = _extract_month_filter(text)
    if month_range:
        start, end = month_range
        filters.append(f"e.date BETWEEN '{start}' AND '{end}'")

    pieces = [f"SELECT {', '.join(select_parts)} ", "FROM expenses AS e"]
    if joins:
        pieces.append(f" {' '.join(joins)}")
    if filters:
        pieces.append(f" WHERE {' AND '.join(filters)}")
    pieces.append(" ORDER BY e.date DESC")
    pieces.append(f" LIMIT {payload.limit}")
    return "".join(pieces)
def _build_debts_sql(payload: SqlGenerationRequest) -> str:
    """Build a row-level SELECT over ``debts`` from the question in *payload*.

    Always selects ``d.date``, ``d.debt_sum`` and ``d.payment_status``. The
    user's name is added through a LEFT JOIN when the question mentions
    "user", "users", "person" or "name". At most one payment-status filter is
    applied, and a month range from ``_extract_month_filter`` becomes a
    ``BETWEEN`` filter on ``d.date``. Rows are ordered newest first and
    limited to ``payload.limit``.
    """
    question = _normalize_text(payload.question)
    with_user = _contains_any(question, ("user", "users", "person", "name"))

    select_parts = ["d.date", "d.debt_sum", "d.payment_status"]
    joins = []
    if with_user:
        select_parts.append("u.name AS user_name")
        joins.append("LEFT JOIN users AS u ON u.id = d.user_id")

    filters = []
    # Order matters because the keywords overlap as substrings: "unpaid"
    # contains "paid", and "partially paid" contains "paid". The more specific
    # statuses must therefore be tested before the plain "paid" check.
    if _contains_any(question, ("unpaid", "not paid", "open debt", "open debts")):
        filters.append("d.payment_status = 'unpaid'")
    elif _contains_any(question, ("partial", "partially")):
        filters.append("d.payment_status = 'partial'")
    elif _contains_any(question, ("paid", "closed debt", "closed debts")):
        filters.append("d.payment_status = 'paid'")

    month_filter = _extract_month_filter(question)
    if month_filter:
        start, end = month_filter
        filters.append(f"d.date BETWEEN '{start}' AND '{end}'")

    where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
    join_clause = f" {' '.join(joins)}" if joins else ""
    order_clause = " ORDER BY d.date DESC"
    return (
        f"SELECT {', '.join(select_parts)} "
        f"FROM debts AS d"
        f"{join_clause}"
        f"{where_clause}"
        f"{order_clause}"
        f" LIMIT {payload.limit}"
    )
def _generate_template_sql(payload: SqlGenerationRequest) -> str:
    """Route the question in *payload* to one of the three template builders.

    Checks are substring matches on the normalized question, in this order:

    1. any debt keyword      -> ``_build_debts_sql``
    2. any aggregate keyword -> ``_build_expenses_aggregate_sql``
    3. otherwise             -> ``_build_expenses_detail_sql``

    Returns:
        The SQL text produced by the selected builder.
    """
    question = _normalize_text(payload.question)
    debt_markers = ("debt", "debts", "payment_status", "unpaid", "partial", "paid")
    aggregate_markers = (
        "sum",
        "total",
        "group",
        "grouped",
        "top",
        "count",
        "average",
        "avg",
        "minimum",
        "maximum",
        # Metric phrases that _extract_metric understands; without them a
        # question such as "how many expenses in may 2024" would be answered
        # with a row listing instead of an aggregate. The short keywords
        # ("min", "max", "mean") are left out on purpose: as plain substrings
        # they would also match words like "admin" or "meaning".
        "how many",
        "number of",
        "lowest",
        "highest",
    )
    if _contains_any(question, debt_markers):
        return _build_debts_sql(payload)
    if _contains_any(question, aggregate_markers):
        return _build_expenses_aggregate_sql(payload)
    return _build_expenses_detail_sql(payload)
def _get_sql_generator() -> Any:
    """Return the shared text2text pipeline, creating it on the first call.

    The model id is read from the ``SQL_MODEL`` environment variable, with a
    built-in default. The created pipeline is stored in the module-level
    ``_SQL_GENERATOR`` and reused by later calls.
    """
    global _SQL_GENERATOR
    if _SQL_GENERATOR is not None:
        return _SQL_GENERATOR

    # Imported here rather than at module level, so ``transformers`` is only
    # required when this function is actually called.
    from transformers import pipeline

    model_name = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
    _SQL_GENERATOR = pipeline(
        task="text2text-generation",
        model=model_name,
        tokenizer=model_name,
        device=-1,
    )
    return _SQL_GENERATOR
def _build_prompt(payload: SqlGenerationRequest) -> str:
    """Compose the model prompt: the question, then " | ", then the schema string."""
    # Used by generate_sql only on the optional transformer fallback path.
    prompt = f"Question: {payload.question} | {DEFAULT_DB_SCHEMA}"
    return prompt
| def _normalize_sql(raw_sql: str, limit: int) -> str: | |
| sql = (raw_sql or "").strip() | |
| if not sql: | |
| raise ValueError("SQL model returned an empty result.") | |
| if "```" in sql: | |
| parts = [part.strip() for part in sql.split("```") if part.strip()] | |
| sql = parts[-1] | |
| upper_sql = sql.upper() | |
| sql_start = upper_sql.find("SELECT") | |
| if sql_start == -1: | |
| raise ValueError("Generated SQL is not a SELECT query.") | |
| sql = sql[sql_start:] | |
| if ";" in sql: | |
| sql = sql.split(";", 1)[0].strip() | |
| upper_sql = sql.upper() | |
| forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ") | |
| if any(keyword in upper_sql for keyword in forbidden): | |
| raise ValueError("Generated SQL contains forbidden statements.") | |
| if not upper_sql.startswith("SELECT "): | |
| raise ValueError("Only SELECT queries are allowed.") | |
| aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(") | |
| has_limit = " LIMIT " in upper_sql | |
| if not has_limit and not any(marker in upper_sql for marker in aggregate_markers): | |
| sql = f"{sql} LIMIT {limit}" | |
| return sql | |
def generate_sql(question: str, limit: int = 200) -> str:
    """Translate a natural-language question into a validated SELECT statement.

    The template engine is tried first. The model fallback runs only when the
    template engine returns an empty result and the ``SQL_USE_LLM_FALLBACK``
    environment variable is set to 1/true/yes/on.

    Args:
        question: The user's question; surrounding whitespace is ignored.
        limit: Requested row limit, clamped to the range 1..1000.

    Returns:
        SQL text after it has passed through ``_normalize_sql``.

    Raises:
        ValueError: if the question is empty, if no template applies while
            the fallback is disabled, or if ``_normalize_sql`` rejects the SQL.
    """
    stripped_question = (question or "").strip()
    if not stripped_question:
        raise ValueError("Field 'query' is required.")

    row_cap = max(1, min(1000, int(limit)))
    payload = SqlGenerationRequest(question=stripped_question, limit=row_cap)

    # Primary path: deterministic template engine for core tables.
    template_sql = _generate_template_sql(payload)
    if template_sql:
        return _normalize_sql(template_sql, limit=payload.limit)

    # Secondary path: optional model fallback, enabled by environment flag.
    fallback_flag = os.getenv("SQL_USE_LLM_FALLBACK", "false").strip().lower()
    if fallback_flag not in {"1", "true", "yes", "on"}:
        raise ValueError("Unable to map query to a supported SQL template.")

    generator = _get_sql_generator()
    prompt = _build_prompt(payload)
    outputs = generator(
        prompt,
        max_new_tokens=512,
        do_sample=False,
        num_beams=4,
        truncation=True,
    )
    if outputs:
        generated_text = outputs[0].get("generated_text", "")
    else:
        generated_text = ""
    return _normalize_sql(generated_text, limit=payload.limit)