ConvertAudioToJSON / sql_generator.py
VladGeekPro
ChangedDockerFile
82b086c
raw
history blame
11.6 kB
"""
Template-based SQL generator for SQLite.
Deterministic schema-aware NL→SQL with minimal templates and high accuracy.
Focuses only on: users, categories, suppliers, expenses, debts.
"""
import os
import re
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
from dataclasses import dataclass, asdict
# Core business schema only (for reference/documentation)
DEFAULT_DB_SCHEMA = (
"users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | "
"categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | "
"suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | "
"expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
"debts : id int , date date , user_id int , debt_sum numeric , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime"
)
_SQL_GENERATOR: Any | None = None
_MONTHS = {
"january": 1,
"february": 2,
"march": 3,
"april": 4,
"may": 5,
"june": 6,
"july": 7,
"august": 8,
"september": 9,
"october": 10,
"november": 11,
"december": 12,
}
@dataclass(frozen=True)
class SqlGenerationRequest:
    """Immutable input bundle for the SQL builders.

    Attributes:
        question: Natural-language question to translate into SQL.
        limit: Row cap written into the generated LIMIT clause.
    """

    # Raw question text; builders normalize it themselves.
    question: str
    # Default row cap (generate_sql clamps caller input to 1..1000).
    limit: int = 200
def _normalize_text(text: str) -> str:
    """Lower-case *text*, collapse each whitespace run to one space and trim the ends."""
    # str.split() with no separator splits on runs of whitespace and discards
    # leading/trailing runs, so re-joining yields single-spaced, trimmed text.
    return " ".join(text.lower().split())
def _contains_any(text: str, markers: tuple[str, ...]) -> bool:
    """Return True as soon as any marker occurs in *text* as a plain substring.

    Matching is raw substring containment (no word boundaries), so a marker
    like "count" also matches inside "account".
    """
    for marker in markers:
        if marker in text:
            return True
    return False
def _end_of_month(year: int, month: int) -> int:
    """Return the number of days in *month* of *year*, i.e. its last day number.

    Leap years are handled (February 2024 -> 29).

    Args:
        year: Calendar year.
        month: Month number, 1..12.

    Returns:
        The last day of the month (28-31).

    Raises:
        calendar.IllegalMonthError (a ValueError) when *month* is outside 1..12.
    """
    # The previous version called datetime.date without importing it, so every
    # month except December raised NameError; calendar.monthrange covers all.
    import calendar

    return calendar.monthrange(year, month)[1]
def _extract_month_filter(question: str) -> tuple[str, str] | None:
    """Extract an inclusive calendar-month date range from *question*.

    A range is produced only when an English month name appears as a whole
    word and a 20xx year is present. An explicit "<month> <year>" pair is
    preferred, so "may i see june 2024" yields June. Otherwise the first month
    name in calendar order is combined with the first year in the text.
    Whole-word matching keeps words like "maybe" or "mayor" from being read
    as May.

    Returns:
        ("YYYY-MM-01", "YYYY-MM-<last day>") or None when no month/year pair
        is found.
    """
    text = _normalize_text(question)

    month_alternation = "|".join(_MONTHS)
    paired = re.search(r"\b(" + month_alternation + r")\s+(20\d{2})\b", text)
    if paired:
        month_idx = _MONTHS[paired.group(1)]
        year = int(paired.group(2))
    else:
        # The year does not depend on the month, so look it up once.
        year_match = re.search(r"\b(20\d{2})\b", text)
        if not year_match:
            return None
        year = int(year_match.group(1))
        month_idx = next(
            (idx for name, idx in _MONTHS.items() if re.search(rf"\b{name}\b", text)),
            None,
        )
        if month_idx is None:
            return None

    day_end = _end_of_month(year, month_idx)
    start = f"{year:04d}-{month_idx:02d}-01"
    end = f"{year:04d}-{month_idx:02d}-{day_end:02d}"
    return start, end
def _extract_top_limit(question: str) -> int | None:
    """Parse a "top N" phrase (N = 1-4 digits, whole word) from *question*.

    Returns:
        N clamped into 1..1000, or None when no such phrase is present.
    """
    found = re.search(r"\btop\s+(\d{1,4})\b", _normalize_text(question))
    if found is None:
        return None
    requested = int(found.group(1))
    return min(max(requested, 1), 1000)
def _extract_metric(question: str) -> tuple[str, str]:
    """Choose the aggregate expression and result alias for an expenses question.

    Keywords must start at a word boundary, so "discount" or "account" no
    longer select COUNT and "admin" no longer selects MIN. "min" and "max"
    must be whole words, which also lets them match at the end of the text.
    Rules are checked in the original precedence: count, average, minimum,
    maximum.

    Returns:
        (SQL aggregate expression, column alias). Falls back to
        ("SUM(e.sum)", "total_amount") when no keyword matches.
    """
    text = _normalize_text(question)
    rules = (
        (r"\b(?:count|how many|number of)", "COUNT(*)", "items_count"),
        (r"\b(?:average|avg|mean)", "AVG(e.sum)", "avg_amount"),
        (r"\b(?:minimum|lowest|min\b)", "MIN(e.sum)", "min_amount"),
        (r"\b(?:maximum|highest|max\b)", "MAX(e.sum)", "max_amount"),
    )
    for pattern, expression, alias in rules:
        if re.search(pattern, text):
            return expression, alias
    return "SUM(e.sum)", "total_amount"
def _extract_dimension(question: str) -> str | None:
    """Return the grouping dimension mentioned in *question*.

    Checks, in priority order, for category, supplier/vendor, then
    user/person words (substring match) and returns "category", "supplier" or
    "user" for the first hit; None when none is mentioned.
    """
    text = _normalize_text(question)
    priority = (
        ("category", ("category", "categories")),
        ("supplier", ("supplier", "suppliers", "vendor", "vendors")),
        ("user", ("user", "users", "person")),
    )
    return next(
        (dimension for dimension, words in priority if _contains_any(text, words)),
        None,
    )
def _build_expenses_aggregate_sql(payload: SqlGenerationRequest) -> str:
    """Build an aggregate query over ``expenses``.

    The aggregate expression and alias come from ``_extract_metric``. If
    ``_extract_dimension`` names a dimension (category, supplier or user),
    that entity's name column is selected and its table is inner-JOINed. The
    result is grouped by the entity's id and name. ``_extract_month_filter``
    may add a ``WHERE e.date BETWEEN ...`` range. Rows are ordered by the
    aggregate alias, descending unless the normalized question contains
    " asc" or "ascending". A "top N" phrase overrides ``payload.limit`` in
    the LIMIT clause.
    """
    question = _normalize_text(payload.question)
    metric_expr, metric_alias = _extract_metric(question)

    # dimension -> (label column, JOIN clause, GROUP BY columns)
    dimension_sql = {
        "category": (
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
            "c.id, c.name",
        ),
        "supplier": (
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
            "s.id, s.name",
        ),
        "user": (
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
            "u.id, u.name",
        ),
    }
    label_column = join_sql = group_columns = None
    chosen = dimension_sql.get(_extract_dimension(question))
    if chosen is not None:
        label_column, join_sql, group_columns = chosen

    metric_column = f"{metric_expr} AS {metric_alias}"
    select_list = f"{label_column}, {metric_column}" if label_column else metric_column
    clauses = [f"SELECT {select_list}", "FROM expenses AS e"]
    if join_sql:
        clauses.append(join_sql)

    month_range = _extract_month_filter(question)
    if month_range:
        range_start, range_end = month_range
        clauses.append(f"WHERE e.date BETWEEN '{range_start}' AND '{range_end}'")
    if group_columns:
        clauses.append(f"GROUP BY {group_columns}")

    descending = not (" asc" in question or "ascending" in question)
    clauses.append(f"ORDER BY {metric_alias} {'DESC' if descending else 'ASC'}")

    top_limit = _extract_top_limit(question)
    row_cap = payload.limit if top_limit is None else top_limit
    clauses.append(f"LIMIT {row_cap}")
    return " ".join(clauses)
def _build_expenses_detail_sql(payload: SqlGenerationRequest) -> str:
    """Build a row-level listing of expenses, newest first.

    Always selects date, sum and notes. Each of category, supplier/vendor and
    user/person that the question mentions (substring match) adds its name
    column and an inner JOIN; several may apply at once.
    ``_extract_month_filter`` may restrict ``e.date`` to one month. The row
    cap is ``payload.limit``.
    """
    question = _normalize_text(payload.question)

    # (trigger words, extra select column, JOIN clause) in output order.
    optional_joins = (
        (
            ("category", "categories"),
            "c.name AS category_name",
            "JOIN categories AS c ON c.id = e.category_id",
        ),
        (
            ("supplier", "suppliers", "vendor", "vendors"),
            "s.name AS supplier_name",
            "JOIN suppliers AS s ON s.id = e.supplier_id",
        ),
        (
            ("user", "users", "person"),
            "u.name AS user_name",
            "JOIN users AS u ON u.id = e.user_id",
        ),
    )
    active = [
        (column, join)
        for words, column, join in optional_joins
        if _contains_any(question, words)
    ]

    columns = ["e.date", "e.sum", "e.notes"] + [column for column, _ in active]
    sql_parts = [f"SELECT {', '.join(columns)}", "FROM expenses AS e"]
    sql_parts.extend(join for _, join in active)

    month_range = _extract_month_filter(question)
    if month_range:
        range_start, range_end = month_range
        sql_parts.append(f"WHERE e.date BETWEEN '{range_start}' AND '{range_end}'")

    sql_parts.append("ORDER BY e.date DESC")
    sql_parts.append(f"LIMIT {payload.limit}")
    return " ".join(sql_parts)
def _build_debts_sql(payload: SqlGenerationRequest) -> str:
    """Build a listing query over ``debts``, newest first.

    Selects date, debt_sum and payment_status. When the question mentions a
    user, person or name, it also selects the user's name via a LEFT JOIN to
    ``users``. A payment-status filter is derived from keywords, checked
    most-specific first:

    * "unpaid" / "not paid" / "open debt(s)" -> 'unpaid'
    * "partial" / "partially"                -> 'partial'
    * "paid" / "closed debt(s)"              -> 'paid'

    The partial check must precede the paid check because "partially paid"
    also contains "paid". ``_extract_month_filter`` may restrict ``d.date``
    to one month. The row cap is ``payload.limit``.
    """
    question = _normalize_text(payload.question)
    with_user = _contains_any(question, ("user", "users", "person", "name"))
    select_parts = ["d.date", "d.debt_sum", "d.payment_status"]
    joins = []
    if with_user:
        select_parts.append("u.name AS user_name")
        # LEFT JOIN keeps debts even when no matching users row exists.
        joins.append("LEFT JOIN users AS u ON u.id = d.user_id")

    filters = []
    # Most specific phrases first: both "unpaid" and "partially paid" contain
    # the substring "paid".
    if _contains_any(question, ("unpaid", "not paid", "open debt", "open debts")):
        filters.append("d.payment_status = 'unpaid'")
    elif _contains_any(question, ("partial", "partially")):
        filters.append("d.payment_status = 'partial'")
    elif _contains_any(question, ("paid", "closed debt", "closed debts")):
        filters.append("d.payment_status = 'paid'")

    month_filter = _extract_month_filter(question)
    if month_filter:
        start, end = month_filter
        filters.append(f"d.date BETWEEN '{start}' AND '{end}'")

    where_clause = f" WHERE {' AND '.join(filters)}" if filters else ""
    join_clause = f" {' '.join(joins)}" if joins else ""
    order_clause = " ORDER BY d.date DESC"
    return (
        f"SELECT {', '.join(select_parts)} "
        f"FROM debts AS d"
        f"{join_clause}"
        f"{where_clause}"
        f"{order_clause}"
        f" LIMIT {payload.limit}"
    )
def _generate_template_sql(payload: SqlGenerationRequest) -> str:
    """Route the question to one of the three SQL templates.

    Debt vocabulary selects the debts template, aggregate vocabulary selects
    the expenses aggregate template, and everything else gets the expenses
    detail listing. Markers are matched as whole words or phrases, so
    "laptop" is not read as "top", "summer" as "sum", "discount" as "count",
    or "prepaid" as "paid". The aggregate vocabulary includes every phrase
    ``_extract_metric`` recognises ("how many", "number of", "mean", "min",
    "max", "lowest", "highest"), so those questions reach the template that
    honours them.

    Returns:
        A non-empty SQL string from one of the three builders.
    """
    question = _normalize_text(payload.question)
    debt_markers = (
        "debt",
        "debts",
        "payment_status",
        "unpaid",
        "partial",
        "partially",
        "paid",
    )
    aggregate_markers = (
        "sum",
        "sums",
        "summary",
        "summarize",
        "summarise",
        "total",
        "totals",
        "group",
        "groups",
        "grouped",
        "grouping",
        "top",
        "count",
        "counts",
        "how many",
        "number of",
        "average",
        "averages",
        "avg",
        "mean",
        "minimum",
        "min",
        "lowest",
        "maximum",
        "max",
        "highest",
    )

    def mentions(markers: tuple[str, ...]) -> bool:
        # Whole-word / whole-phrase match; markers are literal text.
        pattern = r"\b(?:" + "|".join(re.escape(marker) for marker in markers) + r")\b"
        return re.search(pattern, question) is not None

    if mentions(debt_markers):
        return _build_debts_sql(payload)
    if mentions(aggregate_markers):
        return _build_expenses_aggregate_sql(payload)
    return _build_expenses_detail_sql(payload)
def _get_sql_generator() -> Any:
    """Return the shared transformers text2text pipeline, building it on first call.

    The model id is read from the SQL_MODEL environment variable (default
    "gaussalgo/T5-LM-Large-text2sql-spider"), and the pipeline is created
    with device=-1 (CPU). The instance is cached in the module global
    _SQL_GENERATOR. transformers is imported only when the pipeline is
    first built.
    """
    global _SQL_GENERATOR
    if _SQL_GENERATOR is not None:
        return _SQL_GENERATOR

    from transformers import pipeline

    model_name = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
    _SQL_GENERATOR = pipeline(
        task="text2text-generation",
        model=model_name,
        tokenizer=model_name,
        device=-1,
    )
    return _SQL_GENERATOR
def _build_prompt(payload: SqlGenerationRequest) -> str:
    """Format the optional transformer-fallback prompt.

    The prompt is the question followed by the flattened schema, separated
    by " | ".
    """
    return " | ".join((f"Question: {payload.question}", DEFAULT_DB_SCHEMA))
def _normalize_sql(raw_sql: str, limit: int) -> str:
    """Clean and validate generated SQL, returning a single read-only SELECT.

    Steps:
    1. Strip whitespace.
    2. If Markdown code fences are present, keep the last non-empty fenced
       segment.
    3. Drop everything before the first "SELECT".
    4. Keep only the text before the first ";" (note: a ";" inside a string
       literal also truncates).
    5. Reject write/DDL keywords and require the statement to begin with
       SELECT.
    6. Append "LIMIT <limit>" when the query has neither a LIMIT nor an
       aggregate function.

    Keyword checks are word-bounded and accept any whitespace after the
    keyword. A newline or tab after SELECT, LIMIT or DROP is therefore
    recognised, and identifiers such as ``last_update`` are not mistaken
    for UPDATE.

    Raises:
        ValueError: empty input (including fences with no content), no
            SELECT, a forbidden keyword, or a statement that does not start
            with SELECT.
    """
    sql = (raw_sql or "").strip()
    if not sql:
        raise ValueError("SQL model returned an empty result.")
    if "```" in sql:
        fenced = [part.strip() for part in sql.split("```") if part.strip()]
        if not fenced:
            # Only fence markers, no content: treat as empty, not IndexError.
            raise ValueError("SQL model returned an empty result.")
        sql = fenced[-1]

    sql_start = sql.upper().find("SELECT")
    if sql_start == -1:
        raise ValueError("Generated SQL is not a SELECT query.")
    sql = sql[sql_start:]
    if ";" in sql:
        # Keep only the first statement so nothing can be chained after it.
        sql = sql.split(";", 1)[0].strip()

    upper_sql = sql.upper()
    forbidden = ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "PRAGMA", "ATTACH", "CREATE", "REPLACE")
    if re.search(r"\b(?:" + "|".join(forbidden) + r")\s", upper_sql):
        raise ValueError("Generated SQL contains forbidden statements.")
    if not re.match(r"SELECT\s", upper_sql):
        raise ValueError("Only SELECT queries are allowed.")

    aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
    has_limit = re.search(r"\bLIMIT\b", upper_sql) is not None
    if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
        sql = f"{sql} LIMIT {limit}"
    return sql
def generate_sql(question: str, limit: int = 200) -> str:
    """Translate a natural-language *question* into a single SQLite SELECT.

    Args:
        question: The user's question; surrounding whitespace is ignored.
        limit: Default row cap, coerced with int() and clamped into 1..1000.

    Returns:
        A SELECT statement that has passed _normalize_sql.

    Raises:
        ValueError: blank *question*, SQL that fails validation, a limit that
            int() rejects, or no template with the LLM fallback disabled.
        TypeError: a *limit* of a type int() cannot convert.
    """
    stripped = (question or "").strip()
    if not stripped:
        raise ValueError("Field 'query' is required.")
    request = SqlGenerationRequest(
        question=stripped,
        limit=min(max(int(limit), 1), 1000),
    )

    # Primary path: deterministic template engine for core tables.
    template_sql = _generate_template_sql(request)
    if template_sql:
        return _normalize_sql(template_sql, limit=request.limit)

    # Secondary path: optional model fallback.
    # NOTE(review): _generate_template_sql always returns one of its three
    # non-empty template strings, so this path is currently unreachable.
    enabled_flag = os.getenv("SQL_USE_LLM_FALLBACK", "false").strip().lower()
    if enabled_flag not in {"1", "true", "yes", "on"}:
        raise ValueError("Unable to map query to a supported SQL template.")

    outputs = _get_sql_generator()(
        _build_prompt(request),
        max_new_tokens=512,
        do_sample=False,
        num_beams=4,
        truncation=True,
    )
    generated_text = outputs[0].get("generated_text", "") if outputs else ""
    return _normalize_sql(generated_text, limit=request.limit)