stvnnnnnn committed on
Commit
1a1c26f
·
verified ·
1 Parent(s): ef7fd10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -44
app.py CHANGED
@@ -4,6 +4,8 @@ import zipfile
4
  import re
5
  import difflib
6
  import tempfile
 
 
7
  from typing import List, Optional, Dict, Any
8
 
9
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
@@ -15,26 +17,19 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
  from langdetect import detect
16
  from transformers import MarianMTModel, MarianTokenizer
17
  from openai import OpenAI
18
- import sys
19
-
20
- # --- forzar que el directorio actual (donde está app.py y sqlmanager.py) esté en sys.path ---
21
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
22
- if BASE_DIR not in sys.path:
23
- sys.path.append(BASE_DIR)
24
-
25
- from sqlmanager import SQLManager
26
 
27
  # ======================================================
28
- # 0) Configuración general
29
  # ======================================================
30
 
 
 
 
 
31
  # Modelo NL→SQL entrenado por ti en Hugging Face
32
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
33
  DEVICE = torch.device("cpu") # inferencia en CPU
34
 
35
- # Gestor de conexiones reales (MySQL/PostgreSQL)
36
- sql_manager = SQLManager()
37
-
38
  # Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
39
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
40
  if not OPENAI_API_KEY:
@@ -42,29 +37,188 @@ if not OPENAI_API_KEY:
42
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
43
 
44
  # ======================================================
45
- # 1) Inicialización de FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ======================================================
47
 
48
  app = FastAPI(
49
- title="NL2SQL T5-large Backend (MySQL/PostgreSQL)",
50
  description=(
51
  "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
52
- "El usuario sube sus dumps .sql (o ZIP con .sql) y se levantan "
53
- "bases reales en MySQL/PostgreSQL; las consultas se ejecutan ahí."
54
  ),
55
  version="2.0.0",
56
  )
57
 
58
  app.add_middleware(
59
  CORSMiddleware,
60
- allow_origins=["*"], # en producción puedes acotar a tu dominio
61
  allow_credentials=True,
62
  allow_methods=["*"],
63
  allow_headers=["*"],
64
  )
65
 
66
  # ======================================================
67
- # 2) Modelo NL→SQL y traductor ES→EN
68
  # ======================================================
69
 
70
  t5_tokenizer = None
@@ -124,7 +278,7 @@ def translate_es_to_en(text: str) -> str:
124
 
125
 
126
  # ======================================================
127
- # 3) Capa de reparación de SQL (usa el schema real)
128
  # ======================================================
129
 
130
  def _normalize_name_for_match(name: str) -> str:
@@ -259,7 +413,7 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
259
 
260
 
261
  # ======================================================
262
- # 4) Construcción de prompt y NL→SQL + re-ranking
263
  # ======================================================
264
 
265
  def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
@@ -274,7 +428,7 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
274
  if conn_id not in sql_manager.connections:
275
  raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")
276
 
277
- # Obtener esquema real desde MySQL/Postgres
278
  meta = sql_manager.get_schema(conn_id)
279
  tables_info = meta["tables"]
280
 
@@ -331,10 +485,11 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
331
  exec_info = sql_manager.execute_sql(conn_id, raw_sql)
332
 
333
  # Intentar reparación solo si es error por tabla/columna
 
334
  if (not exec_info["ok"]) and (
335
- "no such table" in (exec_info["error"] or "").lower()
336
- or "no such column" in (exec_info["error"] or "").lower()
337
- or "does not exist" in (exec_info["error"] or "").lower()
338
  ):
339
  current_sql = raw_sql
340
  last_error = exec_info["error"] or ""
@@ -389,13 +544,13 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
389
 
390
 
391
  # ======================================================
392
- # 5) Schemas Pydantic
393
  # ======================================================
394
 
395
  class UploadResponse(BaseModel):
396
  connection_id: str
397
  label: str
398
- db_path: str # ahora será un pseudo-path (engine://db_name)
399
  note: Optional[str] = None
400
 
401
 
@@ -444,7 +599,7 @@ class SpeechInferResponse(BaseModel):
444
 
445
 
446
  # ======================================================
447
- # 6) Helpers para /upload (.sql y .zip)
448
  # ======================================================
449
 
450
  def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
@@ -482,13 +637,13 @@ def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
482
 
483
 
484
  # ======================================================
485
- # 7) Endpoints FastAPI
486
  # ======================================================
487
 
488
  @app.on_event("startup")
489
  async def startup_event():
490
  load_nl2sql_model()
491
- print("✅ Backend NL2SQL inicializado (MySQL/PostgreSQL).")
492
  print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
493
  print(f"Conexiones activas al inicio: {len(sql_manager.connections)}")
494
 
@@ -497,7 +652,7 @@ async def startup_event():
497
  async def upload_database(db_file: UploadFile = File(...)):
498
  """
499
  Subida de BD basada en dumps:
500
- - .sql → dump MySQL/PostgreSQL (schema + data) → BD real
501
  - .zip → debe contener uno o varios .sql (se concatenan)
502
  """
503
  filename = db_file.filename
@@ -553,7 +708,7 @@ async def upload_database(db_file: UploadFile = File(...)):
553
  engine = meta["engine"]
554
  db_name = meta["db_name"]
555
 
556
- # db_path ahora es un pseudo-path para mantener compatibilidad
557
  db_path = f"{engine}://{db_name}"
558
 
559
  return UploadResponse(
@@ -568,18 +723,12 @@ async def upload_database(db_file: UploadFile = File(...)):
568
  async def list_connections():
569
  return [
570
  ConnectionInfo(
571
- connection_id=c["connection_id"],
572
- label=c.get("label", ""),
573
- engine=c.get("engine"),
574
- db_name=c.get("db_name"),
575
  )
576
- for c in [
577
- {
578
- "connection_id": cid,
579
- **meta,
580
- }
581
- for cid, meta in sql_manager.connections.items()
582
- ]
583
  ]
584
 
585
 
@@ -681,13 +830,13 @@ async def health():
681
  @app.get("/")
682
  async def root():
683
  return {
684
- "message": "NL2SQL T5-large backend is running with real MySQL/PostgreSQL engines.",
685
  "endpoints": [
686
  "POST /upload (subir .sql o .zip con .sql → crear BD dinámica)",
687
  "GET /connections (listar BDs subidas)",
688
  "GET /schema/{id} (esquema resumido)",
689
  "GET /preview/{id}/{t} (preview de tabla)",
690
- "POST /infer (NL→SQL + ejecución en BD real)",
691
  "POST /speech-infer (NL por voz → SQL + ejecución)",
692
  "GET /health (estado del backend)",
693
  "GET /docs (OpenAPI UI)",
 
4
  import re
5
  import difflib
6
  import tempfile
7
+ import sqlite3
8
+ import uuid
9
  from typing import List, Optional, Dict, Any
10
 
11
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
 
17
  from langdetect import detect
18
  from transformers import MarianMTModel, MarianTokenizer
19
  from openai import OpenAI
 
 
 
 
 
 
 
 
20
 
21
  # ======================================================
22
+ # 0) Configuración general de paths
23
  # ======================================================
24
 
25
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
27
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
28
+
29
  # Modelo NL→SQL entrenado por ti en Hugging Face
30
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
31
  DEVICE = torch.device("cpu") # inferencia en CPU
32
 
 
 
 
33
  # Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
34
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
35
  if not OPENAI_API_KEY:
 
37
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
38
 
39
  # ======================================================
40
+ # 1) SQLManager (versión actual: SQLite local)
41
+ # ======================================================
42
+
43
class SQLManager:
    """
    Manager for dynamic database "connections".

    Current version: each connection is a SQLite file stored under
    UPLOAD_DIR. The API is deliberately engine-agnostic so a later
    switch to Postgres/MySQL (Railway) only has to swap the internals.
    """

    # Destructive SQL verbs rejected by execute_sql(). Matched on word
    # boundaries (the old "keyword + space" substring test could be
    # bypassed with e.g. "DELETE\nFROM t", and it false-positived less
    # predictably than \b...\b does).
    _FORBIDDEN_RE = re.compile(
        r"\b(drop|delete|update|insert|alter|replace)\b",
        re.IGNORECASE,
    )

    def __init__(self) -> None:
        # connections[connection_id] = {
        #     "label": str,       # user-facing label (original filename)
        #     "engine": "sqlite", # fixed for now
        #     "db_name": str,     # logical name (== connection_id)
        #     "db_path": str,     # absolute path of the SQLite file
        # }
        self.connections: Dict[str, Dict[str, Any]] = {}

    # ---------- internal helpers ----------

    def _new_connection_id(self) -> str:
        """Generate a short, collision-unlikely connection id."""
        return f"db_{uuid.uuid4().hex[:8]}"

    def _get_info(self, connection_id: str) -> Dict[str, Any]:
        """Return the metadata dict for connection_id, or raise KeyError."""
        if connection_id not in self.connections:
            raise KeyError(f"connection_id '{connection_id}' no registrado")
        return self.connections[connection_id]

    # ---------- database creation from a dump ----------

    def create_database_from_dump(self, label: str, sql_text: str) -> str:
        """
        Create a SQLite file, run the dump script against it and register
        the connection.

        The dump must be reasonably SQLite-compatible. Returns the new
        connection_id; raises RuntimeError (after removing the partial
        file) if the script fails.
        """
        connection_id = self._new_connection_id()
        db_name = connection_id  # logical name
        db_path = os.path.join(UPLOAD_DIR, f"{db_name}.sqlite")

        # Execute the whole script. On failure, delete the half-built file
        # so a failed upload leaves no residue on disk.
        conn = sqlite3.connect(db_path)
        try:
            conn.executescript(sql_text)
            conn.commit()
        except Exception as e:
            conn.close()
            if os.path.exists(db_path):
                os.remove(db_path)
            raise RuntimeError(f"Error ejecutando dump SQL en SQLite: {e}")
        finally:
            # Closing an already-closed sqlite3 connection is a no-op,
            # so this is safe on both the success and the failure path.
            conn.close()

        self.connections[connection_id] = {
            "label": label,
            "engine": "sqlite",
            "db_name": db_name,
            "db_path": db_path,
        }
        return connection_id

    # ---------- safe SQL execution ----------

    def execute_sql(self, connection_id: str, sql: str) -> Dict[str, Any]:
        """
        Run a read-only query against the connection's database.

        Destructive verbs are rejected up front. Always returns a dict:
        {"ok": bool, "error": str|None, "rows": list|None, "columns": list}.
        """
        info = self._get_info(connection_id)
        db_path = info["db_path"]

        if self._FORBIDDEN_RE.search(sql):
            return {
                "ok": False,
                "error": "Query bloqueada por seguridad (operación destructiva).",
                "rows": None,
                "columns": [],
            }

        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cur = conn.cursor()
            cur.execute(sql)
            rows = cur.fetchall()
            cols = [d[0] for d in cur.description] if cur.description else []
            return {"ok": True, "error": None, "rows": [list(r) for r in rows], "columns": cols}
        except Exception as e:
            return {"ok": False, "error": str(e), "rows": None, "columns": []}
        finally:
            # The previous version returned from the except branch without
            # closing, leaking a file handle per failed query.
            if conn is not None:
                conn.close()

    # ---------- schema introspection ----------

    def get_schema(self, connection_id: str) -> Dict[str, Any]:
        """
        Introspect the SQLite schema.

        Returns {"tables": {name: {"columns": [...]}},
                 "foreign_keys": [{"from_table", "from_column",
                                   "to_table", "to_column"}, ...]}.
        NOTE(review): internal tables such as sqlite_sequence are not
        filtered out, matching previous behavior — confirm whether the
        NL→SQL prompt should exclude them.
        """
        info = self._get_info(connection_id)
        db_path = info["db_path"]

        if not os.path.exists(db_path):
            raise RuntimeError(f"SQLite no encontrado: {db_path}")

        conn = sqlite3.connect(db_path)
        try:
            cur = conn.cursor()
            cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [row[0] for row in cur.fetchall()]

            tables_info: Dict[str, Dict[str, Any]] = {}
            foreign_keys: List[Dict[str, Any]] = []

            for t in tables:
                # Table names come from sqlite_master, not user input, so
                # plain interpolation into PRAGMA is acceptable here.
                cur.execute(f"PRAGMA table_info('{t}');")
                tables_info[t] = {"columns": [r[1] for r in cur.fetchall()]}

                cur.execute(f"PRAGMA foreign_key_list('{t}');")
                for (_id, _seq, ref_table, from_col, to_col, _on_upd, _on_del, _match) in cur.fetchall():
                    foreign_keys.append({
                        "from_table": t,
                        "from_column": from_col,
                        "to_table": ref_table,
                        "to_column": to_col,
                    })
        finally:
            # Close even if a PRAGMA raises (the old code left it open).
            conn.close()

        return {
            "tables": tables_info,
            "foreign_keys": foreign_keys,
        }

    # ---------- table preview ----------

    def get_preview(self, connection_id: str, table: str, limit: int = 20) -> Dict[str, Any]:
        """
        Return up to `limit` rows of `table` as {"columns": [...], "rows": [...]}.

        Raises (propagates sqlite3 errors) if the table does not exist.
        """
        info = self._get_info(connection_id)
        db_path = info["db_path"]

        # Escape embedded double quotes so a crafted table name cannot
        # break out of the quoted identifier.
        safe_table = table.replace('"', '""')

        conn = sqlite3.connect(db_path)
        try:
            cur = conn.cursor()
            cur.execute(f'SELECT * FROM "{safe_table}" LIMIT {int(limit)};')
            rows = cur.fetchall()
            cols = [d[0] for d in cur.description] if cur.description else []
        finally:
            conn.close()

        return {
            "columns": cols,
            "rows": [list(r) for r in rows],
        }
193
+
194
+
195
+ # Instancia global de SQLManager
196
+ sql_manager = SQLManager()
197
+
198
+ # ======================================================
199
+ # 2) Inicialización de FastAPI
200
  # ======================================================
201
 
202
# FastAPI application instance; title/description are shown in /docs.
app = FastAPI(
    title="NL2SQL T5-large Backend (SQLite engine por ahora)",
    description=(
        "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
        "El usuario sube sus dumps .sql (o ZIP con .sql) y se crean "
        "bases dinámicas (actualmente SQLite, futuro Postgres/MySQL)."
    ),
    version="2.0.0",
)

# Permissive CORS so any frontend origin can call the API.
# NOTE(review): "*" origins with allow_credentials=True is wide open —
# restrict allow_origins to the real frontend domain in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
219
 
220
  # ======================================================
221
+ # 3) Modelo NL→SQL y traductor ES→EN
222
  # ======================================================
223
 
224
  t5_tokenizer = None
 
278
 
279
 
280
  # ======================================================
281
+ # 4) Capa de reparación de SQL (usa el schema real)
282
  # ======================================================
283
 
284
  def _normalize_name_for_match(name: str) -> str:
 
413
 
414
 
415
  # ======================================================
416
+ # 5) Construcción de prompt y NL→SQL + re-ranking
417
  # ======================================================
418
 
419
  def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
 
428
  if conn_id not in sql_manager.connections:
429
  raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")
430
 
431
+ # Obtener esquema real desde SQLite (futuro: Postgres/MySQL)
432
  meta = sql_manager.get_schema(conn_id)
433
  tables_info = meta["tables"]
434
 
 
485
  exec_info = sql_manager.execute_sql(conn_id, raw_sql)
486
 
487
  # Intentar reparación solo si es error por tabla/columna
488
+ err_lower = (exec_info["error"] or "").lower()
489
  if (not exec_info["ok"]) and (
490
+ "no such table" in err_lower
491
+ or "no such column" in err_lower
492
+ or "does not exist" in err_lower
493
  ):
494
  current_sql = raw_sql
495
  last_error = exec_info["error"] or ""
 
544
 
545
 
546
  # ======================================================
547
+ # 6) Schemas Pydantic
548
  # ======================================================
549
 
550
  class UploadResponse(BaseModel):
551
  connection_id: str
552
  label: str
553
+ db_path: str # pseudo-path (engine://db_name o similar)
554
  note: Optional[str] = None
555
 
556
 
 
599
 
600
 
601
  # ======================================================
602
+ # 7) Helpers para /upload (.sql y .zip)
603
  # ======================================================
604
 
605
  def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
 
637
 
638
 
639
  # ======================================================
640
+ # 8) Endpoints FastAPI
641
  # ======================================================
642
 
643
  @app.on_event("startup")
644
  async def startup_event():
645
  load_nl2sql_model()
646
+ print("✅ Backend NL2SQL inicializado (engine SQLite por ahora).")
647
  print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
648
  print(f"Conexiones activas al inicio: {len(sql_manager.connections)}")
649
 
 
652
  async def upload_database(db_file: UploadFile = File(...)):
653
  """
654
  Subida de BD basada en dumps:
655
+ - .sql → dump (schema + data) → BD dinámica (SQLite por ahora)
656
  - .zip → debe contener uno o varios .sql (se concatenan)
657
  """
658
  filename = db_file.filename
 
708
  engine = meta["engine"]
709
  db_name = meta["db_name"]
710
 
711
+ # db_path pseudo para mantener compatibilidad
712
  db_path = f"{engine}://{db_name}"
713
 
714
  return UploadResponse(
 
723
async def list_connections():
    """Return one ConnectionInfo summary per registered dynamic database."""
    infos = []
    for cid, meta in sql_manager.connections.items():
        infos.append(
            ConnectionInfo(
                connection_id=cid,
                label=meta.get("label", ""),
                engine=meta.get("engine"),
                db_name=meta.get("db_name"),
            )
        )
    return infos
733
 
734
 
 
830
  @app.get("/")
831
  async def root():
832
  return {
833
+ "message": "NL2SQL T5-large backend is running (engine SQLite, ready to upgrade to Postgres/MySQL).",
834
  "endpoints": [
835
  "POST /upload (subir .sql o .zip con .sql → crear BD dinámica)",
836
  "GET /connections (listar BDs subidas)",
837
  "GET /schema/{id} (esquema resumido)",
838
  "GET /preview/{id}/{t} (preview de tabla)",
839
+ "POST /infer (NL→SQL + ejecución en BD)",
840
  "POST /speech-infer (NL por voz → SQL + ejecución)",
841
  "GET /health (estado del backend)",
842
  "GET /docs (OpenAPI UI)",