Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import zipfile
|
|
| 4 |
import re
|
| 5 |
import difflib
|
| 6 |
import tempfile
|
| 7 |
-
import sqlite3
|
| 8 |
import uuid
|
| 9 |
from typing import List, Optional, Dict, Any
|
| 10 |
|
|
@@ -18,21 +17,27 @@ from langdetect import detect
|
|
| 18 |
from transformers import MarianMTModel, MarianTokenizer
|
| 19 |
from openai import OpenAI
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ---- Supabase ----
|
| 22 |
from supabase import create_client, Client
|
| 23 |
|
| 24 |
SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co"
|
| 25 |
-
SUPABASE_ANON_KEY =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
|
| 28 |
|
| 29 |
# ======================================================
|
| 30 |
-
# 0) Configuración general de paths
|
| 31 |
# ======================================================
|
| 32 |
|
| 33 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 34 |
-
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
|
| 35 |
-
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 36 |
|
| 37 |
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
|
| 38 |
DEVICE = torch.device("cpu")
|
|
@@ -40,24 +45,32 @@ DEVICE = torch.device("cpu")
|
|
| 40 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 41 |
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# ======================================================
|
| 44 |
-
# 1)
|
| 45 |
# ======================================================
|
| 46 |
|
| 47 |
-
class
|
| 48 |
"""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
"""
|
| 53 |
|
| 54 |
-
def __init__(self):
|
| 55 |
-
|
| 56 |
-
# "label": str,
|
| 57 |
-
# "engine": "sqlite",
|
| 58 |
-
# "db_name": str,
|
| 59 |
-
# "db_path": str
|
| 60 |
-
# }
|
| 61 |
self.connections: Dict[str, Dict[str, Any]] = {}
|
| 62 |
|
| 63 |
# ---------- utilidades internas ----------
|
|
@@ -70,51 +83,73 @@ class SQLManager:
|
|
| 70 |
raise KeyError(f"connection_id '{connection_id}' no registrado")
|
| 71 |
return self.connections[connection_id]
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
# ---------- creación de BD desde dump ----------
|
| 74 |
|
| 75 |
def create_database_from_dump(self, label: str, sql_text: str) -> str:
|
| 76 |
"""
|
| 77 |
-
Crea un
|
| 78 |
-
|
| 79 |
-
razonablemente compatible con SQLite.
|
| 80 |
"""
|
| 81 |
connection_id = self._new_connection_id()
|
| 82 |
-
|
| 83 |
-
db_path = os.path.join(UPLOAD_DIR, f"{db_name}.sqlite")
|
| 84 |
|
| 85 |
-
|
| 86 |
-
conn = sqlite3.connect(db_path)
|
| 87 |
try:
|
| 88 |
-
conn.
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
conn.close()
|
| 92 |
-
|
| 93 |
-
os.remove(db_path)
|
| 94 |
-
raise RuntimeError(f"Error ejecutando dump SQL en SQLite: {e}")
|
| 95 |
finally:
|
| 96 |
conn.close()
|
| 97 |
|
| 98 |
self.connections[connection_id] = {
|
| 99 |
"label": label,
|
| 100 |
-
"engine": "
|
| 101 |
-
"
|
| 102 |
-
"db_path": db_path,
|
| 103 |
}
|
| 104 |
return connection_id
|
| 105 |
|
| 106 |
# ---------- ejecución segura de SQL ----------
|
| 107 |
|
| 108 |
-
def execute_sql(self, connection_id: str,
|
| 109 |
"""
|
| 110 |
-
Ejecuta un SELECT
|
| 111 |
Bloquea operaciones destructivas por seguridad.
|
| 112 |
"""
|
| 113 |
info = self._get_info(connection_id)
|
| 114 |
-
|
| 115 |
|
| 116 |
forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
|
| 117 |
-
sql_low =
|
| 118 |
if any(tok in sql_low for tok in forbidden):
|
| 119 |
return {
|
| 120 |
"ok": False,
|
|
@@ -123,89 +158,152 @@ class SQLManager:
|
|
| 123 |
"columns": [],
|
| 124 |
}
|
| 125 |
|
|
|
|
| 126 |
try:
|
| 127 |
-
conn
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
except Exception as e:
|
| 135 |
return {"ok": False, "error": str(e), "rows": None, "columns": []}
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# ---------- introspección de esquema ----------
|
| 138 |
|
| 139 |
def get_schema(self, connection_id: str) -> Dict[str, Any]:
|
| 140 |
info = self._get_info(connection_id)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
if not os.path.exists(db_path):
|
| 144 |
-
raise RuntimeError(f"SQLite no encontrado: {db_path}")
|
| 145 |
-
|
| 146 |
-
conn = sqlite3.connect(db_path)
|
| 147 |
-
cur = conn.cursor()
|
| 148 |
-
|
| 149 |
-
cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
| 150 |
-
tables = [row[0] for row in cur.fetchall()]
|
| 151 |
-
|
| 152 |
-
tables_info: Dict[str, Dict[str, Any]] = {}
|
| 153 |
-
foreign_keys: List[Dict[str, Any]] = []
|
| 154 |
-
|
| 155 |
-
for t in tables:
|
| 156 |
-
cur.execute(f"PRAGMA table_info('{t}');")
|
| 157 |
-
rows = cur.fetchall()
|
| 158 |
-
cols = [r[1] for r in rows]
|
| 159 |
-
tables_info[t] = {"columns": cols}
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
| 178 |
# ---------- preview de tabla ----------
|
| 179 |
|
| 180 |
-
def get_preview(
|
|
|
|
|
|
|
| 181 |
info = self._get_info(connection_id)
|
| 182 |
-
|
| 183 |
|
| 184 |
-
conn =
|
| 185 |
-
cur = conn.cursor()
|
| 186 |
try:
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
finally:
|
| 191 |
conn.close()
|
| 192 |
|
| 193 |
-
return {
|
| 194 |
-
"columns": cols,
|
| 195 |
-
"rows": [list(r) for r in rows],
|
| 196 |
-
}
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
sql_manager = SQLManager()
|
| 201 |
|
| 202 |
# ======================================================
|
| 203 |
# 2) Inicialización de FastAPI
|
| 204 |
# ======================================================
|
| 205 |
|
| 206 |
app = FastAPI(
|
| 207 |
-
title="NL2SQL Backend (
|
| 208 |
-
version="
|
| 209 |
)
|
| 210 |
|
| 211 |
app.add_middleware(
|
|
@@ -232,7 +330,9 @@ def load_nl2sql_model():
|
|
| 232 |
return
|
| 233 |
print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}")
|
| 234 |
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
|
| 235 |
-
t5_model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
|
|
|
|
|
| 236 |
t5_model.to(DEVICE)
|
| 237 |
t5_model.eval()
|
| 238 |
print("✅ Modelo NL→SQL listo en memoria.")
|
|
@@ -281,14 +381,16 @@ def translate_es_to_en(text: str) -> str:
|
|
| 281 |
|
| 282 |
def _normalize_name_for_match(name: str) -> str:
|
| 283 |
s = name.lower()
|
| 284 |
-
s = s.replace('"',
|
| 285 |
s = s.replace("_", "")
|
| 286 |
if s.endswith("s") and len(s) > 3:
|
| 287 |
s = s[:-1]
|
| 288 |
return s
|
| 289 |
|
| 290 |
|
| 291 |
-
def _build_schema_indexes(
|
|
|
|
|
|
|
| 292 |
table_index: Dict[str, List[str]] = {}
|
| 293 |
column_index: Dict[str, List[str]] = {}
|
| 294 |
|
|
@@ -349,6 +451,8 @@ DOMAIN_SYNONYMS_COLUMN = {
|
|
| 349 |
def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
|
| 350 |
"""
|
| 351 |
Intenta reparar nombres de tablas/columnas basándose en el esquema real.
|
|
|
|
|
|
|
| 352 |
"""
|
| 353 |
tables_info = schema_meta["tables"]
|
| 354 |
idx = _build_schema_indexes(tables_info)
|
|
@@ -361,13 +465,13 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
|
|
| 361 |
missing_table = None
|
| 362 |
missing_column = None
|
| 363 |
|
| 364 |
-
m_t = re.search(r
|
| 365 |
if not m_t:
|
| 366 |
m_t = re.search(r"no such table: ([\w\.]+)", error)
|
| 367 |
if m_t:
|
| 368 |
missing_table = m_t.group(1)
|
| 369 |
|
| 370 |
-
m_c = re.search(r
|
| 371 |
if not m_c:
|
| 372 |
m_c = re.search(r"no such column: ([\w\.]+)", error)
|
| 373 |
if m_c:
|
|
@@ -411,7 +515,7 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
|
|
| 411 |
|
| 412 |
|
| 413 |
# ======================================================
|
| 414 |
-
# 5)
|
| 415 |
# ======================================================
|
| 416 |
|
| 417 |
def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
|
|
@@ -421,9 +525,9 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
|
|
| 421 |
f"note: use JOIN when foreign keys link tables"
|
| 422 |
)
|
| 423 |
|
|
|
|
| 424 |
def normalize_score(raw: float) -> float:
|
| 425 |
"""Normaliza el score logit del modelo a un porcentaje 0-100."""
|
| 426 |
-
# Rango típico de logits beam-search: -20 a +5
|
| 427 |
norm = (raw + 20) / 25
|
| 428 |
norm = max(0, min(1, norm))
|
| 429 |
return round(norm * 100, 2)
|
|
@@ -431,9 +535,10 @@ def normalize_score(raw: float) -> float:
|
|
| 431 |
|
| 432 |
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
| 433 |
if conn_id not in sql_manager.connections:
|
| 434 |
-
raise HTTPException(
|
|
|
|
|
|
|
| 435 |
|
| 436 |
-
# Obtener esquema real desde SQLite (futuro: Postgres/MySQL)
|
| 437 |
meta = sql_manager.get_schema(conn_id)
|
| 438 |
tables_info = meta["tables"]
|
| 439 |
|
|
@@ -451,7 +556,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 451 |
if t5_model is None:
|
| 452 |
load_nl2sql_model()
|
| 453 |
|
| 454 |
-
inputs = t5_tokenizer(
|
|
|
|
|
|
|
| 455 |
num_beams = 6
|
| 456 |
num_return = 6
|
| 457 |
|
|
@@ -478,7 +585,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 478 |
best_score = -1e9
|
| 479 |
|
| 480 |
for i in range(sequences.size(0)):
|
| 481 |
-
raw_sql = t5_tokenizer.decode(
|
|
|
|
|
|
|
| 482 |
cand: Dict[str, Any] = {
|
| 483 |
"sql": raw_sql,
|
| 484 |
"score": float(scores[i]),
|
|
@@ -489,7 +598,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 489 |
|
| 490 |
exec_info = sql_manager.execute_sql(conn_id, raw_sql)
|
| 491 |
|
| 492 |
-
# Intentar reparación solo si es error por tabla/columna
|
| 493 |
err_lower = (exec_info["error"] or "").lower()
|
| 494 |
if (not exec_info["ok"]) and (
|
| 495 |
"no such table" in err_lower
|
|
@@ -503,7 +611,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 503 |
if not repaired_sql or repaired_sql == current_sql:
|
| 504 |
break
|
| 505 |
exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql)
|
| 506 |
-
cand["repaired_from"] =
|
|
|
|
|
|
|
| 507 |
cand["repair_note"] = f"auto-repair (table/column name, step {step})"
|
| 508 |
cand["sql"] = repaired_sql
|
| 509 |
exec_info = exec_info2
|
|
@@ -556,7 +666,7 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 556 |
class UploadResponse(BaseModel):
|
| 557 |
connection_id: str
|
| 558 |
label: str
|
| 559 |
-
db_path: str
|
| 560 |
note: Optional[str] = None
|
| 561 |
|
| 562 |
|
|
@@ -564,7 +674,7 @@ class ConnectionInfo(BaseModel):
|
|
| 564 |
connection_id: str
|
| 565 |
label: str
|
| 566 |
engine: Optional[str] = None
|
| 567 |
-
db_name: Optional[str] = None
|
| 568 |
|
| 569 |
|
| 570 |
class SchemaResponse(BaseModel):
|
|
@@ -649,7 +759,7 @@ def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
|
|
| 649 |
@app.on_event("startup")
|
| 650 |
async def startup_event():
|
| 651 |
load_nl2sql_model()
|
| 652 |
-
print("✅ Backend NL2SQL inicializado (engine
|
| 653 |
print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
|
| 654 |
print(f"Conexiones activas al inicio: {len(sql_manager.connections)}")
|
| 655 |
|
|
@@ -657,7 +767,7 @@ async def startup_event():
|
|
| 657 |
@app.post("/upload", response_model=UploadResponse)
|
| 658 |
async def upload_database(
|
| 659 |
db_file: UploadFile = File(...),
|
| 660 |
-
authorization: Optional[str] = Header(None)
|
| 661 |
):
|
| 662 |
if authorization is None:
|
| 663 |
raise HTTPException(401, "Missing Authorization header")
|
|
@@ -667,7 +777,7 @@ async def upload_database(
|
|
| 667 |
if not user or not user.user:
|
| 668 |
raise HTTPException(401, "Invalid Supabase token")
|
| 669 |
|
| 670 |
-
filename = db_file.filename
|
| 671 |
fname_lower = filename.lower()
|
| 672 |
contents = await db_file.read()
|
| 673 |
|
|
@@ -682,7 +792,7 @@ async def upload_database(
|
|
| 682 |
else:
|
| 683 |
raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")
|
| 684 |
|
| 685 |
-
# --- crear
|
| 686 |
try:
|
| 687 |
conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
|
| 688 |
except Exception as e:
|
|
@@ -690,19 +800,21 @@ async def upload_database(
|
|
| 690 |
|
| 691 |
meta = sql_manager.connections[conn_id]
|
| 692 |
|
| 693 |
-
# --- guardar en Supabase ---
|
| 694 |
-
supabase.table("databases").insert(
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
|
|
|
|
|
|
| 700 |
|
| 701 |
return UploadResponse(
|
| 702 |
connection_id=conn_id,
|
| 703 |
label=filename,
|
| 704 |
-
db_path=f"{meta['engine']}://{meta['
|
| 705 |
-
note="Database
|
| 706 |
)
|
| 707 |
|
| 708 |
|
|
@@ -713,7 +825,7 @@ async def list_connections():
|
|
| 713 |
connection_id=cid,
|
| 714 |
label=meta.get("label", ""),
|
| 715 |
engine=meta.get("engine"),
|
| 716 |
-
db_name=meta.get("
|
| 717 |
)
|
| 718 |
for cid, meta in sql_manager.connections.items()
|
| 719 |
]
|
|
@@ -748,7 +860,9 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
|
|
| 748 |
try:
|
| 749 |
preview = sql_manager.get_preview(connection_id, table, limit)
|
| 750 |
except Exception as e:
|
| 751 |
-
raise HTTPException(
|
|
|
|
|
|
|
| 752 |
|
| 753 |
return PreviewResponse(
|
| 754 |
connection_id=connection_id,
|
|
@@ -761,7 +875,7 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
|
|
| 761 |
@app.post("/infer", response_model=InferResponse)
|
| 762 |
async def infer_sql(
|
| 763 |
req: InferRequest,
|
| 764 |
-
authorization: Optional[str] = Header(None)
|
| 765 |
):
|
| 766 |
if authorization is None:
|
| 767 |
raise HTTPException(401, "Missing Authorization header")
|
|
@@ -774,27 +888,28 @@ async def infer_sql(
|
|
| 774 |
result = nl2sql_with_rerank(req.question, req.connection_id)
|
| 775 |
score = normalize_score(result["candidates"][0]["score"])
|
| 776 |
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
.select("id")
|
| 780 |
-
.eq("connection_id", req.connection_id)
|
| 781 |
-
.eq("user_id", user.user.id)
|
| 782 |
.execute()
|
| 783 |
-
|
| 784 |
db_id = db_row.data[0]["id"] if db_row.data else None
|
| 785 |
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
|
|
|
| 798 |
|
| 799 |
result["score_percent"] = score
|
| 800 |
return InferResponse(**result)
|
|
@@ -803,12 +918,12 @@ async def infer_sql(
|
|
| 803 |
@app.post("/speech-infer", response_model=SpeechInferResponse)
|
| 804 |
async def speech_infer(
|
| 805 |
connection_id: str = Form(...),
|
| 806 |
-
audio: UploadFile = File(...)
|
| 807 |
):
|
| 808 |
if openai_client is None:
|
| 809 |
raise HTTPException(
|
| 810 |
status_code=500,
|
| 811 |
-
detail="OPENAI_API_KEY no está configurado en el backend."
|
| 812 |
)
|
| 813 |
|
| 814 |
if audio.content_type is None:
|
|
@@ -819,7 +934,9 @@ async def speech_infer(
|
|
| 819 |
tmp.write(await audio.read())
|
| 820 |
tmp_path = tmp.name
|
| 821 |
except Exception:
|
| 822 |
-
raise HTTPException(
|
|
|
|
|
|
|
| 823 |
|
| 824 |
try:
|
| 825 |
with open(tmp_path, "rb") as f:
|
|
@@ -847,8 +964,10 @@ async def health():
|
|
| 847 |
"model_loaded": t5_model is not None,
|
| 848 |
"connections": len(sql_manager.connections),
|
| 849 |
"device": str(DEVICE),
|
|
|
|
| 850 |
}
|
| 851 |
|
|
|
|
| 852 |
@app.get("/history")
|
| 853 |
def get_history(authorization: Optional[str] = Header(None)):
|
| 854 |
if authorization is None:
|
|
@@ -857,11 +976,13 @@ def get_history(authorization: Optional[str] = Header(None)):
|
|
| 857 |
jwt = authorization.replace("Bearer ", "")
|
| 858 |
user = supabase.auth.get_user(jwt)
|
| 859 |
|
| 860 |
-
rows =
|
| 861 |
-
.
|
| 862 |
-
.
|
| 863 |
-
.
|
|
|
|
| 864 |
.execute()
|
|
|
|
| 865 |
|
| 866 |
return rows.data
|
| 867 |
|
|
@@ -874,10 +995,12 @@ def get_my_databases(authorization: Optional[str] = Header(None)):
|
|
| 874 |
jwt = authorization.replace("Bearer ", "")
|
| 875 |
user = supabase.auth.get_user(jwt)
|
| 876 |
|
| 877 |
-
rows =
|
| 878 |
-
.
|
| 879 |
-
.
|
|
|
|
| 880 |
.execute()
|
|
|
|
| 881 |
|
| 882 |
return rows.data
|
| 883 |
|
|
@@ -885,14 +1008,16 @@ def get_my_databases(authorization: Optional[str] = Header(None)):
|
|
| 885 |
@app.get("/")
|
| 886 |
async def root():
|
| 887 |
return {
|
| 888 |
-
"message": "NL2SQL T5-large backend
|
| 889 |
"endpoints": [
|
| 890 |
-
"POST /upload (subir .sql o .zip con .sql →
|
| 891 |
-
"GET /connections (listar BDs subidas)",
|
| 892 |
"GET /schema/{id} (esquema resumido)",
|
| 893 |
"GET /preview/{id}/{t} (preview de tabla)",
|
| 894 |
"POST /infer (NL→SQL + ejecución en BD)",
|
| 895 |
-
"POST /speech-infer (
|
|
|
|
|
|
|
| 896 |
"GET /health (estado del backend)",
|
| 897 |
"GET /docs (OpenAPI UI)",
|
| 898 |
],
|
|
|
|
| 4 |
import re
|
| 5 |
import difflib
|
| 6 |
import tempfile
|
|
|
|
| 7 |
import uuid
|
| 8 |
from typing import List, Optional, Dict, Any
|
| 9 |
|
|
|
|
| 17 |
from transformers import MarianMTModel, MarianTokenizer
|
| 18 |
from openai import OpenAI
|
| 19 |
|
| 20 |
+
# ---- Postgres (Neon) ----
|
| 21 |
+
import psycopg2
|
| 22 |
+
from psycopg2 import sql as pgsql
|
| 23 |
+
|
| 24 |
# ---- Supabase ----
|
| 25 |
from supabase import create_client, Client
|
| 26 |
|
| 27 |
SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co"
|
| 28 |
+
SUPABASE_ANON_KEY = (
|
| 29 |
+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1x"
|
| 30 |
+
"Z2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4"
|
| 31 |
+
"MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI"
|
| 32 |
+
)
|
| 33 |
|
| 34 |
supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
|
| 35 |
|
| 36 |
# ======================================================
|
| 37 |
+
# 0) Configuración general de paths / modelo / OpenAI
|
| 38 |
# ======================================================
|
| 39 |
|
| 40 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
|
| 41 |
|
| 42 |
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
|
| 43 |
DEVICE = torch.device("cpu")
|
|
|
|
| 45 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 46 |
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
|
| 47 |
|
| 48 |
+
# DSN de Neon (Postgres) – EJEMPLO:
|
| 49 |
+
# postgres://user:pass@host/neondb?sslmode=require
|
| 50 |
+
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
|
| 51 |
+
|
| 52 |
+
if not POSTGRES_DSN:
|
| 53 |
+
raise RuntimeError(
|
| 54 |
+
"⚠️ POSTGRES_DSN no está definido. "
|
| 55 |
+
"Configúralo en los secrets del Space con la cadena de conexión de Neon."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
# ======================================================
|
| 59 |
+
# 1) Gestor de conexiones dinámicas: Postgres (Neon)
|
| 60 |
# ======================================================
|
| 61 |
|
| 62 |
+
class PostgresManager:
|
| 63 |
"""
|
| 64 |
+
Cada upload crea un *schema* aislado en Neon.
|
| 65 |
+
connections[connection_id] = {
|
| 66 |
+
"label": str, # nombre de archivo original
|
| 67 |
+
"engine": "postgres",
|
| 68 |
+
"schema": str # nombre del schema en Neon
|
| 69 |
+
}
|
| 70 |
"""
|
| 71 |
|
| 72 |
+
def __init__(self, dsn: str):
|
| 73 |
+
self.dsn = dsn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
self.connections: Dict[str, Dict[str, Any]] = {}
|
| 75 |
|
| 76 |
# ---------- utilidades internas ----------
|
|
|
|
| 83 |
raise KeyError(f"connection_id '{connection_id}' no registrado")
|
| 84 |
return self.connections[connection_id]
|
| 85 |
|
| 86 |
+
def _get_conn(self):
|
| 87 |
+
conn = psycopg2.connect(self.dsn)
|
| 88 |
+
conn.autocommit = True
|
| 89 |
+
return conn
|
| 90 |
+
|
| 91 |
# ---------- creación de BD desde dump ----------
|
| 92 |
|
| 93 |
def create_database_from_dump(self, label: str, sql_text: str) -> str:
|
| 94 |
"""
|
| 95 |
+
Crea un schema en Neon, fija search_path a ese schema
|
| 96 |
+
y ejecuta el dump SQL dentro de él.
|
|
|
|
| 97 |
"""
|
| 98 |
connection_id = self._new_connection_id()
|
| 99 |
+
schema_name = f"sess_{uuid.uuid4().hex[:8]}"
|
|
|
|
| 100 |
|
| 101 |
+
conn = self._get_conn()
|
|
|
|
| 102 |
try:
|
| 103 |
+
with conn.cursor() as cur:
|
| 104 |
+
# Crear schema aislado
|
| 105 |
+
cur.execute(
|
| 106 |
+
pgsql.SQL("CREATE SCHEMA {}").format(
|
| 107 |
+
pgsql.Identifier(schema_name)
|
| 108 |
+
)
|
| 109 |
+
)
|
| 110 |
+
# Usar ese schema por defecto
|
| 111 |
+
cur.execute(
|
| 112 |
+
pgsql.SQL("SET search_path TO {}").format(
|
| 113 |
+
pgsql.Identifier(schema_name)
|
| 114 |
+
)
|
| 115 |
+
)
|
| 116 |
+
# Ejecutar dump completo (puede tener múltiples sentencias)
|
| 117 |
+
cur.execute(sql_text)
|
| 118 |
except Exception as e:
|
| 119 |
+
# Si falla, intentar limpiar el schema
|
| 120 |
+
try:
|
| 121 |
+
with conn.cursor() as cur:
|
| 122 |
+
cur.execute(
|
| 123 |
+
pgsql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format(
|
| 124 |
+
pgsql.Identifier(schema_name)
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
except Exception:
|
| 128 |
+
pass
|
| 129 |
conn.close()
|
| 130 |
+
raise RuntimeError(f"Error ejecutando dump SQL en Postgres: {e}")
|
|
|
|
|
|
|
| 131 |
finally:
|
| 132 |
conn.close()
|
| 133 |
|
| 134 |
self.connections[connection_id] = {
|
| 135 |
"label": label,
|
| 136 |
+
"engine": "postgres",
|
| 137 |
+
"schema": schema_name,
|
|
|
|
| 138 |
}
|
| 139 |
return connection_id
|
| 140 |
|
| 141 |
# ---------- ejecución segura de SQL ----------
|
| 142 |
|
| 143 |
+
def execute_sql(self, connection_id: str, sql_text: str) -> Dict[str, Any]:
|
| 144 |
"""
|
| 145 |
+
Ejecuta un SELECT dentro del schema asociado al connection_id.
|
| 146 |
Bloquea operaciones destructivas por seguridad.
|
| 147 |
"""
|
| 148 |
info = self._get_info(connection_id)
|
| 149 |
+
schema = info["schema"]
|
| 150 |
|
| 151 |
forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
|
| 152 |
+
sql_low = sql_text.lower()
|
| 153 |
if any(tok in sql_low for tok in forbidden):
|
| 154 |
return {
|
| 155 |
"ok": False,
|
|
|
|
| 158 |
"columns": [],
|
| 159 |
}
|
| 160 |
|
| 161 |
+
conn = self._get_conn()
|
| 162 |
try:
|
| 163 |
+
with conn.cursor() as cur:
|
| 164 |
+
# usar el schema de la sesión
|
| 165 |
+
cur.execute(
|
| 166 |
+
pgsql.SQL("SET search_path TO {}").format(
|
| 167 |
+
pgsql.Identifier(schema)
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
cur.execute(sql_text)
|
| 171 |
+
|
| 172 |
+
if cur.description:
|
| 173 |
+
rows = cur.fetchall()
|
| 174 |
+
cols = [d[0] for d in cur.description]
|
| 175 |
+
else:
|
| 176 |
+
rows, cols = [], []
|
| 177 |
+
|
| 178 |
+
return {
|
| 179 |
+
"ok": True,
|
| 180 |
+
"error": None,
|
| 181 |
+
"rows": [list(r) for r in rows],
|
| 182 |
+
"columns": cols,
|
| 183 |
+
}
|
| 184 |
except Exception as e:
|
| 185 |
return {"ok": False, "error": str(e), "rows": None, "columns": []}
|
| 186 |
+
finally:
|
| 187 |
+
conn.close()
|
| 188 |
|
| 189 |
# ---------- introspección de esquema ----------
|
| 190 |
|
| 191 |
def get_schema(self, connection_id: str) -> Dict[str, Any]:
|
| 192 |
info = self._get_info(connection_id)
|
| 193 |
+
schema = info["schema"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
conn = self._get_conn()
|
| 196 |
+
try:
|
| 197 |
+
tables_info: Dict[str, Dict[str, Any]] = {}
|
| 198 |
+
foreign_keys: List[Dict[str, Any]] = []
|
| 199 |
+
|
| 200 |
+
with conn.cursor() as cur:
|
| 201 |
+
# Tablas básicas
|
| 202 |
+
cur.execute(
|
| 203 |
+
"""
|
| 204 |
+
SELECT table_name
|
| 205 |
+
FROM information_schema.tables
|
| 206 |
+
WHERE table_schema = %s
|
| 207 |
+
AND table_type = 'BASE TABLE'
|
| 208 |
+
ORDER BY table_name;
|
| 209 |
+
""",
|
| 210 |
+
(schema,),
|
| 211 |
+
)
|
| 212 |
+
tables = [r[0] for r in cur.fetchall()]
|
| 213 |
+
|
| 214 |
+
# Columnas por tabla
|
| 215 |
+
for t in tables:
|
| 216 |
+
cur.execute(
|
| 217 |
+
"""
|
| 218 |
+
SELECT column_name
|
| 219 |
+
FROM information_schema.columns
|
| 220 |
+
WHERE table_schema = %s
|
| 221 |
+
AND table_name = %s
|
| 222 |
+
ORDER BY ordinal_position;
|
| 223 |
+
""",
|
| 224 |
+
(schema, t),
|
| 225 |
+
)
|
| 226 |
+
cols = [r[0] for r in cur.fetchall()]
|
| 227 |
+
tables_info[t] = {"columns": cols}
|
| 228 |
+
|
| 229 |
+
# Foreign keys
|
| 230 |
+
cur.execute(
|
| 231 |
+
"""
|
| 232 |
+
SELECT
|
| 233 |
+
tc.table_name AS from_table,
|
| 234 |
+
kcu.column_name AS from_column,
|
| 235 |
+
ccu.table_name AS to_table,
|
| 236 |
+
ccu.column_name AS to_column
|
| 237 |
+
FROM information_schema.table_constraints AS tc
|
| 238 |
+
JOIN information_schema.key_column_usage AS kcu
|
| 239 |
+
ON tc.constraint_name = kcu.constraint_name
|
| 240 |
+
AND tc.table_schema = kcu.table_schema
|
| 241 |
+
JOIN information_schema.constraint_column_usage AS ccu
|
| 242 |
+
ON ccu.constraint_name = tc.constraint_name
|
| 243 |
+
AND ccu.table_schema = tc.table_schema
|
| 244 |
+
WHERE tc.constraint_type = 'FOREIGN KEY'
|
| 245 |
+
AND tc.table_schema = %s;
|
| 246 |
+
""",
|
| 247 |
+
(schema,),
|
| 248 |
+
)
|
| 249 |
+
for ft, fc, tt, tc2 in cur.fetchall():
|
| 250 |
+
foreign_keys.append(
|
| 251 |
+
{
|
| 252 |
+
"from_table": ft,
|
| 253 |
+
"from_column": fc,
|
| 254 |
+
"to_table": tt,
|
| 255 |
+
"to_column": tc2,
|
| 256 |
+
}
|
| 257 |
+
)
|
| 258 |
|
| 259 |
+
return {
|
| 260 |
+
"tables": tables_info,
|
| 261 |
+
"foreign_keys": foreign_keys,
|
| 262 |
+
}
|
| 263 |
+
finally:
|
| 264 |
+
conn.close()
|
| 265 |
|
| 266 |
# ---------- preview de tabla ----------
|
| 267 |
|
| 268 |
+
def get_preview(
|
| 269 |
+
self, connection_id: str, table: str, limit: int = 20
|
| 270 |
+
) -> Dict[str, Any]:
|
| 271 |
info = self._get_info(connection_id)
|
| 272 |
+
schema = info["schema"]
|
| 273 |
|
| 274 |
+
conn = self._get_conn()
|
|
|
|
| 275 |
try:
|
| 276 |
+
with conn.cursor() as cur:
|
| 277 |
+
cur.execute(
|
| 278 |
+
pgsql.SQL("SET search_path TO {}").format(
|
| 279 |
+
pgsql.Identifier(schema)
|
| 280 |
+
)
|
| 281 |
+
)
|
| 282 |
+
query = pgsql.SQL("SELECT * FROM {} LIMIT %s").format(
|
| 283 |
+
pgsql.Identifier(table)
|
| 284 |
+
)
|
| 285 |
+
cur.execute(query, (int(limit),))
|
| 286 |
+
rows = cur.fetchall()
|
| 287 |
+
cols = [d[0] for d in cur.description] if cur.description else []
|
| 288 |
+
|
| 289 |
+
return {
|
| 290 |
+
"columns": cols,
|
| 291 |
+
"rows": [list(r) for r in rows],
|
| 292 |
+
}
|
| 293 |
finally:
|
| 294 |
conn.close()
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
# Instancia global de PostgresManager
|
| 298 |
+
sql_manager = PostgresManager(POSTGRES_DSN)
|
|
|
|
| 299 |
|
| 300 |
# ======================================================
|
| 301 |
# 2) Inicialización de FastAPI
|
| 302 |
# ======================================================
|
| 303 |
|
| 304 |
app = FastAPI(
|
| 305 |
+
title="NL2SQL Backend (Supabase + Postgres/Neon)",
|
| 306 |
+
version="3.0.0",
|
| 307 |
)
|
| 308 |
|
| 309 |
app.add_middleware(
|
|
|
|
| 330 |
return
|
| 331 |
print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}")
|
| 332 |
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
|
| 333 |
+
t5_model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 334 |
+
MODEL_DIR, torch_dtype=torch.float32
|
| 335 |
+
)
|
| 336 |
t5_model.to(DEVICE)
|
| 337 |
t5_model.eval()
|
| 338 |
print("✅ Modelo NL→SQL listo en memoria.")
|
|
|
|
| 381 |
|
| 382 |
def _normalize_name_for_match(name: str) -> str:
|
| 383 |
s = name.lower()
|
| 384 |
+
s = s.replace('"', "").replace("`", "")
|
| 385 |
s = s.replace("_", "")
|
| 386 |
if s.endswith("s") and len(s) > 3:
|
| 387 |
s = s[:-1]
|
| 388 |
return s
|
| 389 |
|
| 390 |
|
| 391 |
+
def _build_schema_indexes(
|
| 392 |
+
tables_info: Dict[str, Dict[str, List[str]]]
|
| 393 |
+
) -> Dict[str, Dict[str, List[str]]]:
|
| 394 |
table_index: Dict[str, List[str]] = {}
|
| 395 |
column_index: Dict[str, List[str]] = {}
|
| 396 |
|
|
|
|
| 451 |
def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
|
| 452 |
"""
|
| 453 |
Intenta reparar nombres de tablas/columnas basándose en el esquema real.
|
| 454 |
+
Compatible con mensajes de Postgres y también con los de SQLite
|
| 455 |
+
(por si algún día reusamos la lógica).
|
| 456 |
"""
|
| 457 |
tables_info = schema_meta["tables"]
|
| 458 |
idx = _build_schema_indexes(tables_info)
|
|
|
|
| 465 |
missing_table = None
|
| 466 |
missing_column = None
|
| 467 |
|
| 468 |
+
m_t = re.search(r'relation "([\w\.]+)" does not exist', error, re.IGNORECASE)
|
| 469 |
if not m_t:
|
| 470 |
m_t = re.search(r"no such table: ([\w\.]+)", error)
|
| 471 |
if m_t:
|
| 472 |
missing_table = m_t.group(1)
|
| 473 |
|
| 474 |
+
m_c = re.search(r'column "([\w\.]+)" does not exist', error, re.IGNORECASE)
|
| 475 |
if not m_c:
|
| 476 |
m_c = re.search(r"no such column: ([\w\.]+)", error)
|
| 477 |
if m_c:
|
|
|
|
| 515 |
|
| 516 |
|
| 517 |
# ======================================================
|
| 518 |
+
# 5) Prompt NL→SQL + re-ranking
|
| 519 |
# ======================================================
|
| 520 |
|
| 521 |
def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
|
|
|
|
| 525 |
f"note: use JOIN when foreign keys link tables"
|
| 526 |
)
|
| 527 |
|
| 528 |
+
|
| 529 |
def normalize_score(raw: float) -> float:
|
| 530 |
"""Normaliza el score logit del modelo a un porcentaje 0-100."""
|
|
|
|
| 531 |
norm = (raw + 20) / 25
|
| 532 |
norm = max(0, min(1, norm))
|
| 533 |
return round(norm * 100, 2)
|
|
|
|
| 535 |
|
| 536 |
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
| 537 |
if conn_id not in sql_manager.connections:
|
| 538 |
+
raise HTTPException(
|
| 539 |
+
status_code=404, detail=f"connection_id '{conn_id}' no registrado"
|
| 540 |
+
)
|
| 541 |
|
|
|
|
| 542 |
meta = sql_manager.get_schema(conn_id)
|
| 543 |
tables_info = meta["tables"]
|
| 544 |
|
|
|
|
| 556 |
if t5_model is None:
|
| 557 |
load_nl2sql_model()
|
| 558 |
|
| 559 |
+
inputs = t5_tokenizer(
|
| 560 |
+
[prompt], return_tensors="pt", truncation=True, max_length=768
|
| 561 |
+
).to(DEVICE)
|
| 562 |
num_beams = 6
|
| 563 |
num_return = 6
|
| 564 |
|
|
|
|
| 585 |
best_score = -1e9
|
| 586 |
|
| 587 |
for i in range(sequences.size(0)):
|
| 588 |
+
raw_sql = t5_tokenizer.decode(
|
| 589 |
+
sequences[i], skip_special_tokens=True
|
| 590 |
+
).strip()
|
| 591 |
cand: Dict[str, Any] = {
|
| 592 |
"sql": raw_sql,
|
| 593 |
"score": float(scores[i]),
|
|
|
|
| 598 |
|
| 599 |
exec_info = sql_manager.execute_sql(conn_id, raw_sql)
|
| 600 |
|
|
|
|
| 601 |
err_lower = (exec_info["error"] or "").lower()
|
| 602 |
if (not exec_info["ok"]) and (
|
| 603 |
"no such table" in err_lower
|
|
|
|
| 611 |
if not repaired_sql or repaired_sql == current_sql:
|
| 612 |
break
|
| 613 |
exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql)
|
| 614 |
+
cand["repaired_from"] = (
|
| 615 |
+
current_sql if cand["repaired_from"] is None else cand["repaired_from"]
|
| 616 |
+
)
|
| 617 |
cand["repair_note"] = f"auto-repair (table/column name, step {step})"
|
| 618 |
cand["sql"] = repaired_sql
|
| 619 |
exec_info = exec_info2
|
|
|
|
| 666 |
class UploadResponse(BaseModel):
    """Response payload for POST /upload after a database dump is imported.

    Describes the newly registered connection so the client can reference
    it in subsequent schema/preview/infer calls.
    """

    # Server-generated identifier used to address this connection later.
    connection_id: str
    # Original filename of the uploaded dump, used as a human-readable label.
    label: str
    # Locator string for the created schema
    # (built as f"{engine}://schema/{schema}" by the /upload handler).
    db_path: str
    # Optional informational message about the import outcome.
    note: Optional[str] = None
|
| 671 |
|
| 672 |
|
|
|
|
| 674 |
connection_id: str
|
| 675 |
label: str
|
| 676 |
engine: Optional[str] = None
|
| 677 |
+
db_name: Optional[str] = None # ya no usamos archivo, pero mantenemos campo
|
| 678 |
|
| 679 |
|
| 680 |
class SchemaResponse(BaseModel):
|
|
|
|
| 759 |
@app.on_event("startup")
async def startup_event():
    """Load the NL2SQL model once at boot and log the runtime configuration."""
    load_nl2sql_model()
    banner = (
        "✅ Backend NL2SQL inicializado (engine Postgres/Neon).",
        f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}",
        f"Conexiones activas al inicio: {len(sql_manager.connections)}",
    )
    for line in banner:
        print(line)
|
| 765 |
|
|
|
|
| 767 |
@app.post("/upload", response_model=UploadResponse)
|
| 768 |
async def upload_database(
|
| 769 |
db_file: UploadFile = File(...),
|
| 770 |
+
authorization: Optional[str] = Header(None),
|
| 771 |
):
|
| 772 |
if authorization is None:
|
| 773 |
raise HTTPException(401, "Missing Authorization header")
|
|
|
|
| 777 |
if not user or not user.user:
|
| 778 |
raise HTTPException(401, "Invalid Supabase token")
|
| 779 |
|
| 780 |
+
filename = db_file.filename or ""
|
| 781 |
fname_lower = filename.lower()
|
| 782 |
contents = await db_file.read()
|
| 783 |
|
|
|
|
| 792 |
else:
|
| 793 |
raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")
|
| 794 |
|
| 795 |
+
# --- crear schema dinámico en Postgres ---
|
| 796 |
try:
|
| 797 |
conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
|
| 798 |
except Exception as e:
|
|
|
|
| 800 |
|
| 801 |
meta = sql_manager.connections[conn_id]
|
| 802 |
|
| 803 |
+
# --- guardar en Supabase (metadatos) ---
|
| 804 |
+
supabase.table("databases").insert(
|
| 805 |
+
{
|
| 806 |
+
"user_id": user.user.id,
|
| 807 |
+
"filename": filename,
|
| 808 |
+
"engine": meta["engine"],
|
| 809 |
+
"connection_id": conn_id,
|
| 810 |
+
}
|
| 811 |
+
).execute()
|
| 812 |
|
| 813 |
return UploadResponse(
|
| 814 |
connection_id=conn_id,
|
| 815 |
label=filename,
|
| 816 |
+
db_path=f"{meta['engine']}://schema/{meta['schema']}",
|
| 817 |
+
note="Database schema created in Neon and indexed in Supabase.",
|
| 818 |
)
|
| 819 |
|
| 820 |
|
|
|
|
| 825 |
connection_id=cid,
|
| 826 |
label=meta.get("label", ""),
|
| 827 |
engine=meta.get("engine"),
|
| 828 |
+
db_name=meta.get("schema"), # usamos schema como "nombre"
|
| 829 |
)
|
| 830 |
for cid, meta in sql_manager.connections.items()
|
| 831 |
]
|
|
|
|
| 860 |
try:
|
| 861 |
preview = sql_manager.get_preview(connection_id, table, limit)
|
| 862 |
except Exception as e:
|
| 863 |
+
raise HTTPException(
|
| 864 |
+
status_code=400, detail=f"Error al leer tabla '{table}': {e}"
|
| 865 |
+
)
|
| 866 |
|
| 867 |
return PreviewResponse(
|
| 868 |
connection_id=connection_id,
|
|
|
|
| 875 |
@app.post("/infer", response_model=InferResponse)
|
| 876 |
async def infer_sql(
|
| 877 |
req: InferRequest,
|
| 878 |
+
authorization: Optional[str] = Header(None),
|
| 879 |
):
|
| 880 |
if authorization is None:
|
| 881 |
raise HTTPException(401, "Missing Authorization header")
|
|
|
|
| 888 |
result = nl2sql_with_rerank(req.question, req.connection_id)
|
| 889 |
score = normalize_score(result["candidates"][0]["score"])
|
| 890 |
|
| 891 |
+
db_row = (
|
| 892 |
+
supabase.table("databases")
|
| 893 |
+
.select("id")
|
| 894 |
+
.eq("connection_id", req.connection_id)
|
| 895 |
+
.eq("user_id", user.user.id)
|
| 896 |
.execute()
|
| 897 |
+
)
|
| 898 |
db_id = db_row.data[0]["id"] if db_row.data else None
|
| 899 |
|
| 900 |
+
supabase.table("queries").insert(
|
| 901 |
+
{
|
| 902 |
+
"user_id": user.user.id,
|
| 903 |
+
"db_id": db_id,
|
| 904 |
+
"nl": result["question_original"],
|
| 905 |
+
"sql_generated": result["best_sql"],
|
| 906 |
+
"sql_repaired": result["candidates"][0]["sql"],
|
| 907 |
+
"execution_ok": result["best_exec_ok"],
|
| 908 |
+
"error": result["best_exec_error"],
|
| 909 |
+
"rows_preview": result["best_rows_preview"],
|
| 910 |
+
"score": score,
|
| 911 |
+
}
|
| 912 |
+
).execute()
|
| 913 |
|
| 914 |
result["score_percent"] = score
|
| 915 |
return InferResponse(**result)
|
|
|
|
| 918 |
@app.post("/speech-infer", response_model=SpeechInferResponse)
|
| 919 |
async def speech_infer(
|
| 920 |
connection_id: str = Form(...),
|
| 921 |
+
audio: UploadFile = File(...),
|
| 922 |
):
|
| 923 |
if openai_client is None:
|
| 924 |
raise HTTPException(
|
| 925 |
status_code=500,
|
| 926 |
+
detail="OPENAI_API_KEY no está configurado en el backend.",
|
| 927 |
)
|
| 928 |
|
| 929 |
if audio.content_type is None:
|
|
|
|
| 934 |
tmp.write(await audio.read())
|
| 935 |
tmp_path = tmp.name
|
| 936 |
except Exception:
|
| 937 |
+
raise HTTPException(
|
| 938 |
+
status_code=500, detail="No se pudo procesar el audio recibido."
|
| 939 |
+
)
|
| 940 |
|
| 941 |
try:
|
| 942 |
with open(tmp_path, "rb") as f:
|
|
|
|
| 964 |
"model_loaded": t5_model is not None,
|
| 965 |
"connections": len(sql_manager.connections),
|
| 966 |
"device": str(DEVICE),
|
| 967 |
+
"engine": "postgres",
|
| 968 |
}
|
| 969 |
|
| 970 |
+
|
| 971 |
@app.get("/history")
|
| 972 |
def get_history(authorization: Optional[str] = Header(None)):
|
| 973 |
if authorization is None:
|
|
|
|
| 976 |
jwt = authorization.replace("Bearer ", "")
|
| 977 |
user = supabase.auth.get_user(jwt)
|
| 978 |
|
| 979 |
+
rows = (
|
| 980 |
+
supabase.table("queries")
|
| 981 |
+
.select("*")
|
| 982 |
+
.eq("user_id", user.user.id)
|
| 983 |
+
.order("created_at", desc=True)
|
| 984 |
.execute()
|
| 985 |
+
)
|
| 986 |
|
| 987 |
return rows.data
|
| 988 |
|
|
|
|
| 995 |
jwt = authorization.replace("Bearer ", "")
|
| 996 |
user = supabase.auth.get_user(jwt)
|
| 997 |
|
| 998 |
+
rows = (
|
| 999 |
+
supabase.table("databases")
|
| 1000 |
+
.select("*")
|
| 1001 |
+
.eq("user_id", user.user.id)
|
| 1002 |
.execute()
|
| 1003 |
+
)
|
| 1004 |
|
| 1005 |
return rows.data
|
| 1006 |
|
|
|
|
| 1008 |
@app.get("/")
|
| 1009 |
async def root():
|
| 1010 |
return {
|
| 1011 |
+
"message": "NL2SQL T5-large backend running on Postgres/Neon (no SQLite).",
|
| 1012 |
"endpoints": [
|
| 1013 |
+
"POST /upload (subir .sql o .zip con .sql → crea schema en Neon)",
|
| 1014 |
+
"GET /connections (listar BDs subidas en esta instancia)",
|
| 1015 |
"GET /schema/{id} (esquema resumido)",
|
| 1016 |
"GET /preview/{id}/{t} (preview de tabla)",
|
| 1017 |
"POST /infer (NL→SQL + ejecución en BD)",
|
| 1018 |
+
"POST /speech-infer (voz → NL→SQL + ejecución)",
|
| 1019 |
+
"GET /history (historial de consultas en Supabase)",
|
| 1020 |
+
"GET /my-databases (BDs del usuario en Supabase)",
|
| 1021 |
"GET /health (estado del backend)",
|
| 1022 |
"GET /docs (OpenAPI UI)",
|
| 1023 |
],
|