stvnnnnnn commited on
Commit
869f1a1
·
verified ·
1 Parent(s): 32750dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -156
app.py CHANGED
@@ -4,7 +4,6 @@ import zipfile
4
  import re
5
  import difflib
6
  import tempfile
7
- import sqlite3
8
  import uuid
9
  from typing import List, Optional, Dict, Any
10
 
@@ -18,21 +17,27 @@ from langdetect import detect
18
  from transformers import MarianMTModel, MarianTokenizer
19
  from openai import OpenAI
20
 
 
 
 
 
21
  # ---- Supabase ----
22
  from supabase import create_client, Client
23
 
24
  SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co"
25
- SUPABASE_ANON_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1xZ2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI"
 
 
 
 
26
 
27
  supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
28
 
29
  # ======================================================
30
- # 0) Configuración general de paths
31
  # ======================================================
32
 
33
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
34
- UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
35
- os.makedirs(UPLOAD_DIR, exist_ok=True)
36
 
37
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
38
  DEVICE = torch.device("cpu")
@@ -40,24 +45,32 @@ DEVICE = torch.device("cpu")
40
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
41
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
42
 
 
 
 
 
 
 
 
 
 
 
43
  # ======================================================
44
- # 1) SQLManager (versión actual: SQLite local)
45
  # ======================================================
46
 
47
- class SQLManager:
48
  """
49
- Gestor de "conexiones" a bases dinámicas.
50
- Versión actual: cada conexión es un archivo SQLite en UPLOAD_DIR.
51
- API pensada para poder cambiar después a Postgres/MySQL (Railway).
 
 
 
52
  """
53
 
54
- def __init__(self):
55
- # connections[connection_id] = {
56
- # "label": str,
57
- # "engine": "sqlite",
58
- # "db_name": str,
59
- # "db_path": str
60
- # }
61
  self.connections: Dict[str, Dict[str, Any]] = {}
62
 
63
  # ---------- utilidades internas ----------
@@ -70,51 +83,73 @@ class SQLManager:
70
  raise KeyError(f"connection_id '{connection_id}' no registrado")
71
  return self.connections[connection_id]
72
 
 
 
 
 
 
73
  # ---------- creación de BD desde dump ----------
74
 
75
  def create_database_from_dump(self, label: str, sql_text: str) -> str:
76
  """
77
- Crea un archivo SQLite, ejecuta el dump SQL y
78
- registra la conexión. Por ahora el dump debe ser
79
- razonablemente compatible con SQLite.
80
  """
81
  connection_id = self._new_connection_id()
82
- db_name = connection_id # nombre lógico
83
- db_path = os.path.join(UPLOAD_DIR, f"{db_name}.sqlite")
84
 
85
- # Ejecutar todo el script. Si falla, borramos el archivo.
86
- conn = sqlite3.connect(db_path)
87
  try:
88
- conn.executescript(sql_text)
89
- conn.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
91
  conn.close()
92
- if os.path.exists(db_path):
93
- os.remove(db_path)
94
- raise RuntimeError(f"Error ejecutando dump SQL en SQLite: {e}")
95
  finally:
96
  conn.close()
97
 
98
  self.connections[connection_id] = {
99
  "label": label,
100
- "engine": "sqlite",
101
- "db_name": db_name,
102
- "db_path": db_path,
103
  }
104
  return connection_id
105
 
106
  # ---------- ejecución segura de SQL ----------
107
 
108
- def execute_sql(self, connection_id: str, sql: str) -> Dict[str, Any]:
109
  """
110
- Ejecuta un SELECT sobre la BD asociada al connection_id.
111
  Bloquea operaciones destructivas por seguridad.
112
  """
113
  info = self._get_info(connection_id)
114
- db_path = info["db_path"]
115
 
116
  forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
117
- sql_low = sql.lower()
118
  if any(tok in sql_low for tok in forbidden):
119
  return {
120
  "ok": False,
@@ -123,89 +158,152 @@ class SQLManager:
123
  "columns": [],
124
  }
125
 
 
126
  try:
127
- conn = sqlite3.connect(db_path)
128
- cur = conn.cursor()
129
- cur.execute(sql)
130
- rows = cur.fetchall()
131
- cols = [d[0] for d in cur.description] if cur.description else []
132
- conn.close()
133
- return {"ok": True, "error": None, "rows": [list(r) for r in rows], "columns": cols}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  except Exception as e:
135
  return {"ok": False, "error": str(e), "rows": None, "columns": []}
 
 
136
 
137
  # ---------- introspección de esquema ----------
138
 
139
  def get_schema(self, connection_id: str) -> Dict[str, Any]:
140
  info = self._get_info(connection_id)
141
- db_path = info["db_path"]
142
-
143
- if not os.path.exists(db_path):
144
- raise RuntimeError(f"SQLite no encontrado: {db_path}")
145
-
146
- conn = sqlite3.connect(db_path)
147
- cur = conn.cursor()
148
-
149
- cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
150
- tables = [row[0] for row in cur.fetchall()]
151
-
152
- tables_info: Dict[str, Dict[str, Any]] = {}
153
- foreign_keys: List[Dict[str, Any]] = []
154
-
155
- for t in tables:
156
- cur.execute(f"PRAGMA table_info('{t}');")
157
- rows = cur.fetchall()
158
- cols = [r[1] for r in rows]
159
- tables_info[t] = {"columns": cols}
160
 
161
- cur.execute(f"PRAGMA foreign_key_list('{t}');")
162
- fks = cur.fetchall()
163
- for (id_, seq, ref_table, from_col, to_col, on_update, on_delete, match) in fks:
164
- foreign_keys.append({
165
- "from_table": t,
166
- "from_column": from_col,
167
- "to_table": ref_table,
168
- "to_column": to_col,
169
- })
170
-
171
- conn.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- return {
174
- "tables": tables_info,
175
- "foreign_keys": foreign_keys,
176
- }
 
 
177
 
178
  # ---------- preview de tabla ----------
179
 
180
- def get_preview(self, connection_id: str, table: str, limit: int = 20) -> Dict[str, Any]:
 
 
181
  info = self._get_info(connection_id)
182
- db_path = info["db_path"]
183
 
184
- conn = sqlite3.connect(db_path)
185
- cur = conn.cursor()
186
  try:
187
- cur.execute(f'SELECT * FROM "{table}" LIMIT {int(limit)};')
188
- rows = cur.fetchall()
189
- cols = [d[0] for d in cur.description] if cur.description else []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  finally:
191
  conn.close()
192
 
193
- return {
194
- "columns": cols,
195
- "rows": [list(r) for r in rows],
196
- }
197
 
198
-
199
- # Instancia global de SQLManager
200
- sql_manager = SQLManager()
201
 
202
  # ======================================================
203
  # 2) Inicialización de FastAPI
204
  # ======================================================
205
 
206
  app = FastAPI(
207
- title="NL2SQL Backend (with Supabase Auth + History)",
208
- version="2.1.0",
209
  )
210
 
211
  app.add_middleware(
@@ -232,7 +330,9 @@ def load_nl2sql_model():
232
  return
233
  print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}")
234
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
235
- t5_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float32)
 
 
236
  t5_model.to(DEVICE)
237
  t5_model.eval()
238
  print("✅ Modelo NL→SQL listo en memoria.")
@@ -281,14 +381,16 @@ def translate_es_to_en(text: str) -> str:
281
 
282
  def _normalize_name_for_match(name: str) -> str:
283
  s = name.lower()
284
- s = s.replace('"', '').replace("`", "")
285
  s = s.replace("_", "")
286
  if s.endswith("s") and len(s) > 3:
287
  s = s[:-1]
288
  return s
289
 
290
 
291
- def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
 
 
292
  table_index: Dict[str, List[str]] = {}
293
  column_index: Dict[str, List[str]] = {}
294
 
@@ -349,6 +451,8 @@ DOMAIN_SYNONYMS_COLUMN = {
349
  def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
350
  """
351
  Intenta reparar nombres de tablas/columnas basándose en el esquema real.
 
 
352
  """
353
  tables_info = schema_meta["tables"]
354
  idx = _build_schema_indexes(tables_info)
@@ -361,13 +465,13 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
361
  missing_table = None
362
  missing_column = None
363
 
364
- m_t = re.search(r"relation \"([\w\.]+)\" does not exist", error, re.IGNORECASE)
365
  if not m_t:
366
  m_t = re.search(r"no such table: ([\w\.]+)", error)
367
  if m_t:
368
  missing_table = m_t.group(1)
369
 
370
- m_c = re.search(r"column \"([\w\.]+)\" does not exist", error, re.IGNORECASE)
371
  if not m_c:
372
  m_c = re.search(r"no such column: ([\w\.]+)", error)
373
  if m_c:
@@ -411,7 +515,7 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
411
 
412
 
413
  # ======================================================
414
- # 5) Construcción de prompt y NL→SQL + re-ranking
415
  # ======================================================
416
 
417
  def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
@@ -421,9 +525,9 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
421
  f"note: use JOIN when foreign keys link tables"
422
  )
423
 
 
424
  def normalize_score(raw: float) -> float:
425
  """Normaliza el score logit del modelo a un porcentaje 0-100."""
426
- # Rango típico de logits beam-search: -20 a +5
427
  norm = (raw + 20) / 25
428
  norm = max(0, min(1, norm))
429
  return round(norm * 100, 2)
@@ -431,9 +535,10 @@ def normalize_score(raw: float) -> float:
431
 
432
  def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
433
  if conn_id not in sql_manager.connections:
434
- raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")
 
 
435
 
436
- # Obtener esquema real desde SQLite (futuro: Postgres/MySQL)
437
  meta = sql_manager.get_schema(conn_id)
438
  tables_info = meta["tables"]
439
 
@@ -451,7 +556,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
451
  if t5_model is None:
452
  load_nl2sql_model()
453
 
454
- inputs = t5_tokenizer([prompt], return_tensors="pt", truncation=True, max_length=768).to(DEVICE)
 
 
455
  num_beams = 6
456
  num_return = 6
457
 
@@ -478,7 +585,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
478
  best_score = -1e9
479
 
480
  for i in range(sequences.size(0)):
481
- raw_sql = t5_tokenizer.decode(sequences[i], skip_special_tokens=True).strip()
 
 
482
  cand: Dict[str, Any] = {
483
  "sql": raw_sql,
484
  "score": float(scores[i]),
@@ -489,7 +598,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
489
 
490
  exec_info = sql_manager.execute_sql(conn_id, raw_sql)
491
 
492
- # Intentar reparación solo si es error por tabla/columna
493
  err_lower = (exec_info["error"] or "").lower()
494
  if (not exec_info["ok"]) and (
495
  "no such table" in err_lower
@@ -503,7 +611,9 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
503
  if not repaired_sql or repaired_sql == current_sql:
504
  break
505
  exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql)
506
- cand["repaired_from"] = current_sql if cand["repaired_from"] is None else cand["repaired_from"]
 
 
507
  cand["repair_note"] = f"auto-repair (table/column name, step {step})"
508
  cand["sql"] = repaired_sql
509
  exec_info = exec_info2
@@ -556,7 +666,7 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
556
  class UploadResponse(BaseModel):
557
  connection_id: str
558
  label: str
559
- db_path: str # pseudo-path (engine://db_name o similar)
560
  note: Optional[str] = None
561
 
562
 
@@ -564,7 +674,7 @@ class ConnectionInfo(BaseModel):
564
  connection_id: str
565
  label: str
566
  engine: Optional[str] = None
567
- db_name: Optional[str] = None
568
 
569
 
570
  class SchemaResponse(BaseModel):
@@ -649,7 +759,7 @@ def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
649
  @app.on_event("startup")
650
  async def startup_event():
651
  load_nl2sql_model()
652
- print("✅ Backend NL2SQL inicializado (engine SQLite por ahora).")
653
  print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
654
  print(f"Conexiones activas al inicio: {len(sql_manager.connections)}")
655
 
@@ -657,7 +767,7 @@ async def startup_event():
657
  @app.post("/upload", response_model=UploadResponse)
658
  async def upload_database(
659
  db_file: UploadFile = File(...),
660
- authorization: Optional[str] = Header(None)
661
  ):
662
  if authorization is None:
663
  raise HTTPException(401, "Missing Authorization header")
@@ -667,7 +777,7 @@ async def upload_database(
667
  if not user or not user.user:
668
  raise HTTPException(401, "Invalid Supabase token")
669
 
670
- filename = db_file.filename
671
  fname_lower = filename.lower()
672
  contents = await db_file.read()
673
 
@@ -682,7 +792,7 @@ async def upload_database(
682
  else:
683
  raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")
684
 
685
- # --- crear BD dinámica (SQLite temporal) ---
686
  try:
687
  conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
688
  except Exception as e:
@@ -690,19 +800,21 @@ async def upload_database(
690
 
691
  meta = sql_manager.connections[conn_id]
692
 
693
- # --- guardar en Supabase ---
694
- supabase.table("databases").insert({
695
- "user_id": user.user.id,
696
- "filename": filename,
697
- "engine": meta["engine"],
698
- "connection_id": conn_id
699
- }).execute()
 
 
700
 
701
  return UploadResponse(
702
  connection_id=conn_id,
703
  label=filename,
704
- db_path=f"{meta['engine']}://{meta['db_name']}",
705
- note="Database stored and indexed in Supabase."
706
  )
707
 
708
 
@@ -713,7 +825,7 @@ async def list_connections():
713
  connection_id=cid,
714
  label=meta.get("label", ""),
715
  engine=meta.get("engine"),
716
- db_name=meta.get("db_name"),
717
  )
718
  for cid, meta in sql_manager.connections.items()
719
  ]
@@ -748,7 +860,9 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
748
  try:
749
  preview = sql_manager.get_preview(connection_id, table, limit)
750
  except Exception as e:
751
- raise HTTPException(status_code=400, detail=f"Error al leer tabla '{table}': {e}")
 
 
752
 
753
  return PreviewResponse(
754
  connection_id=connection_id,
@@ -761,7 +875,7 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
761
  @app.post("/infer", response_model=InferResponse)
762
  async def infer_sql(
763
  req: InferRequest,
764
- authorization: Optional[str] = Header(None)
765
  ):
766
  if authorization is None:
767
  raise HTTPException(401, "Missing Authorization header")
@@ -774,27 +888,28 @@ async def infer_sql(
774
  result = nl2sql_with_rerank(req.question, req.connection_id)
775
  score = normalize_score(result["candidates"][0]["score"])
776
 
777
- # buscar db_id en supabase
778
- db_row = supabase.table("databases") \
779
- .select("id") \
780
- .eq("connection_id", req.connection_id) \
781
- .eq("user_id", user.user.id) \
782
  .execute()
783
-
784
  db_id = db_row.data[0]["id"] if db_row.data else None
785
 
786
- # guardar historial
787
- supabase.table("queries").insert({
788
- "user_id": user.user.id,
789
- "db_id": db_id,
790
- "nl": result["question_original"],
791
- "sql_generated": result["best_sql"],
792
- "sql_repaired": result["candidates"][0]["sql"],
793
- "execution_ok": result["best_exec_ok"],
794
- "error": result["best_exec_error"],
795
- "rows_preview": result["best_rows_preview"],
796
- "score": score
797
- }).execute()
 
798
 
799
  result["score_percent"] = score
800
  return InferResponse(**result)
@@ -803,12 +918,12 @@ async def infer_sql(
803
  @app.post("/speech-infer", response_model=SpeechInferResponse)
804
  async def speech_infer(
805
  connection_id: str = Form(...),
806
- audio: UploadFile = File(...)
807
  ):
808
  if openai_client is None:
809
  raise HTTPException(
810
  status_code=500,
811
- detail="OPENAI_API_KEY no está configurado en el backend."
812
  )
813
 
814
  if audio.content_type is None:
@@ -819,7 +934,9 @@ async def speech_infer(
819
  tmp.write(await audio.read())
820
  tmp_path = tmp.name
821
  except Exception:
822
- raise HTTPException(status_code=500, detail="No se pudo procesar el audio recibido.")
 
 
823
 
824
  try:
825
  with open(tmp_path, "rb") as f:
@@ -847,8 +964,10 @@ async def health():
847
  "model_loaded": t5_model is not None,
848
  "connections": len(sql_manager.connections),
849
  "device": str(DEVICE),
 
850
  }
851
 
 
852
  @app.get("/history")
853
  def get_history(authorization: Optional[str] = Header(None)):
854
  if authorization is None:
@@ -857,11 +976,13 @@ def get_history(authorization: Optional[str] = Header(None)):
857
  jwt = authorization.replace("Bearer ", "")
858
  user = supabase.auth.get_user(jwt)
859
 
860
- rows = supabase.table("queries") \
861
- .select("*") \
862
- .eq("user_id", user.user.id) \
863
- .order("created_at", desc=True) \
 
864
  .execute()
 
865
 
866
  return rows.data
867
 
@@ -874,10 +995,12 @@ def get_my_databases(authorization: Optional[str] = Header(None)):
874
  jwt = authorization.replace("Bearer ", "")
875
  user = supabase.auth.get_user(jwt)
876
 
877
- rows = supabase.table("databases") \
878
- .select("*") \
879
- .eq("user_id", user.user.id) \
 
880
  .execute()
 
881
 
882
  return rows.data
883
 
@@ -885,14 +1008,16 @@ def get_my_databases(authorization: Optional[str] = Header(None)):
885
  @app.get("/")
886
  async def root():
887
  return {
888
- "message": "NL2SQL T5-large backend is running (engine SQLite, ready to upgrade to Postgres/MySQL).",
889
  "endpoints": [
890
- "POST /upload (subir .sql o .zip con .sql → crear BD dinámica)",
891
- "GET /connections (listar BDs subidas)",
892
  "GET /schema/{id} (esquema resumido)",
893
  "GET /preview/{id}/{t} (preview de tabla)",
894
  "POST /infer (NL→SQL + ejecución en BD)",
895
- "POST /speech-infer (NL por voz → SQL + ejecución)",
 
 
896
  "GET /health (estado del backend)",
897
  "GET /docs (OpenAPI UI)",
898
  ],
 
4
  import re
5
  import difflib
6
  import tempfile
 
7
  import uuid
8
  from typing import List, Optional, Dict, Any
9
 
 
17
  from transformers import MarianMTModel, MarianTokenizer
18
  from openai import OpenAI
19
 
20
+ # ---- Postgres (Neon) ----
21
+ import psycopg2
22
+ from psycopg2 import sql as pgsql
23
+
24
  # ---- Supabase ----
25
  from supabase import create_client, Client
26
 
27
  SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co"
28
+ SUPABASE_ANON_KEY = (
29
+ "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1x"
30
+ "Z2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4"
31
+ "MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI"
32
+ )
33
 
34
  supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
35
 
36
  # ======================================================
37
+ # 0) Configuración general de paths / modelo / OpenAI
38
  # ======================================================
39
 
40
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
41
 
42
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
43
  DEVICE = torch.device("cpu")
 
45
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
46
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
47
 
48
+ # DSN de Neon (Postgres) – EJEMPLO:
49
+ # postgres://user:pass@host/neondb?sslmode=require
50
+ POSTGRES_DSN = os.getenv("POSTGRES_DSN")
51
+
52
+ if not POSTGRES_DSN:
53
+ raise RuntimeError(
54
+ "⚠️ POSTGRES_DSN no está definido. "
55
+ "Configúralo en los secrets del Space con la cadena de conexión de Neon."
56
+ )
57
+
58
  # ======================================================
59
+ # 1) Gestor de conexiones dinámicas: Postgres (Neon)
60
  # ======================================================
61
 
62
+ class PostgresManager:
63
  """
64
+ Cada upload crea un *schema* aislado en Neon.
65
+ connections[connection_id] = {
66
+ "label": str, # nombre de archivo original
67
+ "engine": "postgres",
68
+ "schema": str # nombre del schema en Neon
69
+ }
70
  """
71
 
72
+ def __init__(self, dsn: str):
73
+ self.dsn = dsn
 
 
 
 
 
74
  self.connections: Dict[str, Dict[str, Any]] = {}
75
 
76
  # ---------- utilidades internas ----------
 
83
  raise KeyError(f"connection_id '{connection_id}' no registrado")
84
  return self.connections[connection_id]
85
 
86
+ def _get_conn(self):
87
+ conn = psycopg2.connect(self.dsn)
88
+ conn.autocommit = True
89
+ return conn
90
+
91
  # ---------- creación de BD desde dump ----------
92
 
93
  def create_database_from_dump(self, label: str, sql_text: str) -> str:
94
  """
95
+ Crea un schema en Neon, fija search_path a ese schema
96
+ y ejecuta el dump SQL dentro de él.
 
97
  """
98
  connection_id = self._new_connection_id()
99
+ schema_name = f"sess_{uuid.uuid4().hex[:8]}"
 
100
 
101
+ conn = self._get_conn()
 
102
  try:
103
+ with conn.cursor() as cur:
104
+ # Crear schema aislado
105
+ cur.execute(
106
+ pgsql.SQL("CREATE SCHEMA {}").format(
107
+ pgsql.Identifier(schema_name)
108
+ )
109
+ )
110
+ # Usar ese schema por defecto
111
+ cur.execute(
112
+ pgsql.SQL("SET search_path TO {}").format(
113
+ pgsql.Identifier(schema_name)
114
+ )
115
+ )
116
+ # Ejecutar dump completo (puede tener múltiples sentencias)
117
+ cur.execute(sql_text)
118
  except Exception as e:
119
+ # Si falla, intentar limpiar el schema
120
+ try:
121
+ with conn.cursor() as cur:
122
+ cur.execute(
123
+ pgsql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format(
124
+ pgsql.Identifier(schema_name)
125
+ )
126
+ )
127
+ except Exception:
128
+ pass
129
  conn.close()
130
+ raise RuntimeError(f"Error ejecutando dump SQL en Postgres: {e}")
 
 
131
  finally:
132
  conn.close()
133
 
134
  self.connections[connection_id] = {
135
  "label": label,
136
+ "engine": "postgres",
137
+ "schema": schema_name,
 
138
  }
139
  return connection_id
140
 
141
  # ---------- ejecución segura de SQL ----------
142
 
143
+ def execute_sql(self, connection_id: str, sql_text: str) -> Dict[str, Any]:
144
  """
145
+ Ejecuta un SELECT dentro del schema asociado al connection_id.
146
  Bloquea operaciones destructivas por seguridad.
147
  """
148
  info = self._get_info(connection_id)
149
+ schema = info["schema"]
150
 
151
  forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
152
+ sql_low = sql_text.lower()
153
  if any(tok in sql_low for tok in forbidden):
154
  return {
155
  "ok": False,
 
158
  "columns": [],
159
  }
160
 
161
+ conn = self._get_conn()
162
  try:
163
+ with conn.cursor() as cur:
164
+ # usar el schema de la sesión
165
+ cur.execute(
166
+ pgsql.SQL("SET search_path TO {}").format(
167
+ pgsql.Identifier(schema)
168
+ )
169
+ )
170
+ cur.execute(sql_text)
171
+
172
+ if cur.description:
173
+ rows = cur.fetchall()
174
+ cols = [d[0] for d in cur.description]
175
+ else:
176
+ rows, cols = [], []
177
+
178
+ return {
179
+ "ok": True,
180
+ "error": None,
181
+ "rows": [list(r) for r in rows],
182
+ "columns": cols,
183
+ }
184
  except Exception as e:
185
  return {"ok": False, "error": str(e), "rows": None, "columns": []}
186
+ finally:
187
+ conn.close()
188
 
189
  # ---------- introspección de esquema ----------
190
 
191
  def get_schema(self, connection_id: str) -> Dict[str, Any]:
192
  info = self._get_info(connection_id)
193
+ schema = info["schema"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ conn = self._get_conn()
196
+ try:
197
+ tables_info: Dict[str, Dict[str, Any]] = {}
198
+ foreign_keys: List[Dict[str, Any]] = []
199
+
200
+ with conn.cursor() as cur:
201
+ # Tablas básicas
202
+ cur.execute(
203
+ """
204
+ SELECT table_name
205
+ FROM information_schema.tables
206
+ WHERE table_schema = %s
207
+ AND table_type = 'BASE TABLE'
208
+ ORDER BY table_name;
209
+ """,
210
+ (schema,),
211
+ )
212
+ tables = [r[0] for r in cur.fetchall()]
213
+
214
+ # Columnas por tabla
215
+ for t in tables:
216
+ cur.execute(
217
+ """
218
+ SELECT column_name
219
+ FROM information_schema.columns
220
+ WHERE table_schema = %s
221
+ AND table_name = %s
222
+ ORDER BY ordinal_position;
223
+ """,
224
+ (schema, t),
225
+ )
226
+ cols = [r[0] for r in cur.fetchall()]
227
+ tables_info[t] = {"columns": cols}
228
+
229
+ # Foreign keys
230
+ cur.execute(
231
+ """
232
+ SELECT
233
+ tc.table_name AS from_table,
234
+ kcu.column_name AS from_column,
235
+ ccu.table_name AS to_table,
236
+ ccu.column_name AS to_column
237
+ FROM information_schema.table_constraints AS tc
238
+ JOIN information_schema.key_column_usage AS kcu
239
+ ON tc.constraint_name = kcu.constraint_name
240
+ AND tc.table_schema = kcu.table_schema
241
+ JOIN information_schema.constraint_column_usage AS ccu
242
+ ON ccu.constraint_name = tc.constraint_name
243
+ AND ccu.table_schema = tc.table_schema
244
+ WHERE tc.constraint_type = 'FOREIGN KEY'
245
+ AND tc.table_schema = %s;
246
+ """,
247
+ (schema,),
248
+ )
249
+ for ft, fc, tt, tc2 in cur.fetchall():
250
+ foreign_keys.append(
251
+ {
252
+ "from_table": ft,
253
+ "from_column": fc,
254
+ "to_table": tt,
255
+ "to_column": tc2,
256
+ }
257
+ )
258
 
259
+ return {
260
+ "tables": tables_info,
261
+ "foreign_keys": foreign_keys,
262
+ }
263
+ finally:
264
+ conn.close()
265
 
266
  # ---------- preview de tabla ----------
267
 
268
+ def get_preview(
269
+ self, connection_id: str, table: str, limit: int = 20
270
+ ) -> Dict[str, Any]:
271
  info = self._get_info(connection_id)
272
+ schema = info["schema"]
273
 
274
+ conn = self._get_conn()
 
275
  try:
276
+ with conn.cursor() as cur:
277
+ cur.execute(
278
+ pgsql.SQL("SET search_path TO {}").format(
279
+ pgsql.Identifier(schema)
280
+ )
281
+ )
282
+ query = pgsql.SQL("SELECT * FROM {} LIMIT %s").format(
283
+ pgsql.Identifier(table)
284
+ )
285
+ cur.execute(query, (int(limit),))
286
+ rows = cur.fetchall()
287
+ cols = [d[0] for d in cur.description] if cur.description else []
288
+
289
+ return {
290
+ "columns": cols,
291
+ "rows": [list(r) for r in rows],
292
+ }
293
  finally:
294
  conn.close()
295
 
 
 
 
 
296
 
297
+ # Instancia global de PostgresManager
298
+ sql_manager = PostgresManager(POSTGRES_DSN)
 
299
 
300
  # ======================================================
301
  # 2) Inicialización de FastAPI
302
  # ======================================================
303
 
304
  app = FastAPI(
305
+ title="NL2SQL Backend (Supabase + Postgres/Neon)",
306
+ version="3.0.0",
307
  )
308
 
309
  app.add_middleware(
 
330
  return
331
  print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}")
332
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
333
+ t5_model = AutoModelForSeq2SeqLM.from_pretrained(
334
+ MODEL_DIR, torch_dtype=torch.float32
335
+ )
336
  t5_model.to(DEVICE)
337
  t5_model.eval()
338
  print("✅ Modelo NL→SQL listo en memoria.")
 
381
 
382
  def _normalize_name_for_match(name: str) -> str:
383
  s = name.lower()
384
+ s = s.replace('"', "").replace("`", "")
385
  s = s.replace("_", "")
386
  if s.endswith("s") and len(s) > 3:
387
  s = s[:-1]
388
  return s
389
 
390
 
391
+ def _build_schema_indexes(
392
+ tables_info: Dict[str, Dict[str, List[str]]]
393
+ ) -> Dict[str, Dict[str, List[str]]]:
394
  table_index: Dict[str, List[str]] = {}
395
  column_index: Dict[str, List[str]] = {}
396
 
 
451
  def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
452
  """
453
  Intenta reparar nombres de tablas/columnas basándose en el esquema real.
454
+ Compatible con mensajes de Postgres y también con los de SQLite
455
+ (por si algún día reusamos la lógica).
456
  """
457
  tables_info = schema_meta["tables"]
458
  idx = _build_schema_indexes(tables_info)
 
465
  missing_table = None
466
  missing_column = None
467
 
468
+ m_t = re.search(r'relation "([\w\.]+)" does not exist', error, re.IGNORECASE)
469
  if not m_t:
470
  m_t = re.search(r"no such table: ([\w\.]+)", error)
471
  if m_t:
472
  missing_table = m_t.group(1)
473
 
474
+ m_c = re.search(r'column "([\w\.]+)" does not exist', error, re.IGNORECASE)
475
  if not m_c:
476
  m_c = re.search(r"no such column: ([\w\.]+)", error)
477
  if m_c:
 
515
 
516
 
517
  # ======================================================
518
+ # 5) Prompt NL→SQL + re-ranking
519
  # ======================================================
520
 
521
  def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
 
525
  f"note: use JOIN when foreign keys link tables"
526
  )
527
 
528
+
529
  def normalize_score(raw: float) -> float:
530
  """Normaliza el score logit del modelo a un porcentaje 0-100."""
 
531
  norm = (raw + 20) / 25
532
  norm = max(0, min(1, norm))
533
  return round(norm * 100, 2)
 
535
 
536
  def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
537
  if conn_id not in sql_manager.connections:
538
+ raise HTTPException(
539
+ status_code=404, detail=f"connection_id '{conn_id}' no registrado"
540
+ )
541
 
 
542
  meta = sql_manager.get_schema(conn_id)
543
  tables_info = meta["tables"]
544
 
 
556
  if t5_model is None:
557
  load_nl2sql_model()
558
 
559
+ inputs = t5_tokenizer(
560
+ [prompt], return_tensors="pt", truncation=True, max_length=768
561
+ ).to(DEVICE)
562
  num_beams = 6
563
  num_return = 6
564
 
 
585
  best_score = -1e9
586
 
587
  for i in range(sequences.size(0)):
588
+ raw_sql = t5_tokenizer.decode(
589
+ sequences[i], skip_special_tokens=True
590
+ ).strip()
591
  cand: Dict[str, Any] = {
592
  "sql": raw_sql,
593
  "score": float(scores[i]),
 
598
 
599
  exec_info = sql_manager.execute_sql(conn_id, raw_sql)
600
 
 
601
  err_lower = (exec_info["error"] or "").lower()
602
  if (not exec_info["ok"]) and (
603
  "no such table" in err_lower
 
611
  if not repaired_sql or repaired_sql == current_sql:
612
  break
613
  exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql)
614
+ cand["repaired_from"] = (
615
+ current_sql if cand["repaired_from"] is None else cand["repaired_from"]
616
+ )
617
  cand["repair_note"] = f"auto-repair (table/column name, step {step})"
618
  cand["sql"] = repaired_sql
619
  exec_info = exec_info2
 
666
  class UploadResponse(BaseModel):
667
  connection_id: str
668
  label: str
669
+ db_path: str
670
  note: Optional[str] = None
671
 
672
 
 
674
  connection_id: str
675
  label: str
676
  engine: Optional[str] = None
677
+ db_name: Optional[str] = None # ya no usamos archivo, pero mantenemos campo
678
 
679
 
680
  class SchemaResponse(BaseModel):
 
759
  @app.on_event("startup")
760
  async def startup_event():
761
  load_nl2sql_model()
762
+ print("✅ Backend NL2SQL inicializado (engine Postgres/Neon).")
763
  print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
764
  print(f"Conexiones activas al inicio: {len(sql_manager.connections)}")
765
 
 
767
  @app.post("/upload", response_model=UploadResponse)
768
  async def upload_database(
769
  db_file: UploadFile = File(...),
770
+ authorization: Optional[str] = Header(None),
771
  ):
772
  if authorization is None:
773
  raise HTTPException(401, "Missing Authorization header")
 
777
  if not user or not user.user:
778
  raise HTTPException(401, "Invalid Supabase token")
779
 
780
+ filename = db_file.filename or ""
781
  fname_lower = filename.lower()
782
  contents = await db_file.read()
783
 
 
792
  else:
793
  raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")
794
 
795
+ # --- crear schema dinámico en Postgres ---
796
  try:
797
  conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
798
  except Exception as e:
 
800
 
801
  meta = sql_manager.connections[conn_id]
802
 
803
+ # --- guardar en Supabase (metadatos) ---
804
+ supabase.table("databases").insert(
805
+ {
806
+ "user_id": user.user.id,
807
+ "filename": filename,
808
+ "engine": meta["engine"],
809
+ "connection_id": conn_id,
810
+ }
811
+ ).execute()
812
 
813
  return UploadResponse(
814
  connection_id=conn_id,
815
  label=filename,
816
+ db_path=f"{meta['engine']}://schema/{meta['schema']}",
817
+ note="Database schema created in Neon and indexed in Supabase.",
818
  )
819
 
820
 
 
825
  connection_id=cid,
826
  label=meta.get("label", ""),
827
  engine=meta.get("engine"),
828
+ db_name=meta.get("schema"), # usamos schema como "nombre"
829
  )
830
  for cid, meta in sql_manager.connections.items()
831
  ]
 
860
  try:
861
  preview = sql_manager.get_preview(connection_id, table, limit)
862
  except Exception as e:
863
+ raise HTTPException(
864
+ status_code=400, detail=f"Error al leer tabla '{table}': {e}"
865
+ )
866
 
867
  return PreviewResponse(
868
  connection_id=connection_id,
 
875
  @app.post("/infer", response_model=InferResponse)
876
  async def infer_sql(
877
  req: InferRequest,
878
+ authorization: Optional[str] = Header(None),
879
  ):
880
  if authorization is None:
881
  raise HTTPException(401, "Missing Authorization header")
 
888
  result = nl2sql_with_rerank(req.question, req.connection_id)
889
  score = normalize_score(result["candidates"][0]["score"])
890
 
891
+ db_row = (
892
+ supabase.table("databases")
893
+ .select("id")
894
+ .eq("connection_id", req.connection_id)
895
+ .eq("user_id", user.user.id)
896
  .execute()
897
+ )
898
  db_id = db_row.data[0]["id"] if db_row.data else None
899
 
900
+ supabase.table("queries").insert(
901
+ {
902
+ "user_id": user.user.id,
903
+ "db_id": db_id,
904
+ "nl": result["question_original"],
905
+ "sql_generated": result["best_sql"],
906
+ "sql_repaired": result["candidates"][0]["sql"],
907
+ "execution_ok": result["best_exec_ok"],
908
+ "error": result["best_exec_error"],
909
+ "rows_preview": result["best_rows_preview"],
910
+ "score": score,
911
+ }
912
+ ).execute()
913
 
914
  result["score_percent"] = score
915
  return InferResponse(**result)
 
918
  @app.post("/speech-infer", response_model=SpeechInferResponse)
919
  async def speech_infer(
920
  connection_id: str = Form(...),
921
+ audio: UploadFile = File(...),
922
  ):
923
  if openai_client is None:
924
  raise HTTPException(
925
  status_code=500,
926
+ detail="OPENAI_API_KEY no está configurado en el backend.",
927
  )
928
 
929
  if audio.content_type is None:
 
934
  tmp.write(await audio.read())
935
  tmp_path = tmp.name
936
  except Exception:
937
+ raise HTTPException(
938
+ status_code=500, detail="No se pudo procesar el audio recibido."
939
+ )
940
 
941
  try:
942
  with open(tmp_path, "rb") as f:
 
964
  "model_loaded": t5_model is not None,
965
  "connections": len(sql_manager.connections),
966
  "device": str(DEVICE),
967
+ "engine": "postgres",
968
  }
969
 
970
+
971
  @app.get("/history")
972
  def get_history(authorization: Optional[str] = Header(None)):
973
  if authorization is None:
 
976
  jwt = authorization.replace("Bearer ", "")
977
  user = supabase.auth.get_user(jwt)
978
 
979
+ rows = (
980
+ supabase.table("queries")
981
+ .select("*")
982
+ .eq("user_id", user.user.id)
983
+ .order("created_at", desc=True)
984
  .execute()
985
+ )
986
 
987
  return rows.data
988
 
 
995
  jwt = authorization.replace("Bearer ", "")
996
  user = supabase.auth.get_user(jwt)
997
 
998
+ rows = (
999
+ supabase.table("databases")
1000
+ .select("*")
1001
+ .eq("user_id", user.user.id)
1002
  .execute()
1003
+ )
1004
 
1005
  return rows.data
1006
 
 
1008
  @app.get("/")
1009
  async def root():
1010
  return {
1011
+ "message": "NL2SQL T5-large backend running on Postgres/Neon (no SQLite).",
1012
  "endpoints": [
1013
+ "POST /upload (subir .sql o .zip con .sql → crea schema en Neon)",
1014
+ "GET /connections (listar BDs subidas en esta instancia)",
1015
  "GET /schema/{id} (esquema resumido)",
1016
  "GET /preview/{id}/{t} (preview de tabla)",
1017
  "POST /infer (NL→SQL + ejecución en BD)",
1018
+ "POST /speech-infer (voz → NL→SQL + ejecución)",
1019
+ "GET /history (historial de consultas en Supabase)",
1020
+ "GET /my-databases (BDs del usuario en Supabase)",
1021
  "GET /health (estado del backend)",
1022
  "GET /docs (OpenAPI UI)",
1023
  ],