stvnnnnnn committed on
Commit
9a6eae1
·
verified ·
1 Parent(s): abeea5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -67
app.py CHANGED
@@ -8,7 +8,7 @@ import sqlite3
8
  import uuid
9
  from typing import List, Optional, Dict, Any
10
 
11
- from fastapi import FastAPI, UploadFile, File, HTTPException, Form
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from pydantic import BaseModel
14
 
@@ -18,6 +18,14 @@ from langdetect import detect
18
  from transformers import MarianMTModel, MarianTokenizer
19
  from openai import OpenAI
20
 
 
 
 
 
 
 
 
 
21
  # ======================================================
22
  # 0) Configuración general de paths
23
  # ======================================================
@@ -26,14 +34,10 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
  UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
27
  os.makedirs(UPLOAD_DIR, exist_ok=True)
28
 
29
- # Modelo NL→SQL entrenado por ti en Hugging Face
30
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
31
- DEVICE = torch.device("cpu") # inferencia en CPU
32
 
33
- # Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
34
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
35
- if not OPENAI_API_KEY:
36
- print("⚠️ OPENAI_API_KEY no está definido. El endpoint /speech-infer no funcionará hasta configurarlo.")
37
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
38
 
39
  # ======================================================
@@ -200,19 +204,13 @@ sql_manager = SQLManager()
200
  # ======================================================
201
 
202
  app = FastAPI(
203
- title="NL2SQL T5-large Backend (SQLite engine por ahora)",
204
- description=(
205
- "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
206
- "El usuario sube sus dumps .sql (o ZIP con .sql) y se crean "
207
- "bases dinámicas (actualmente SQLite, futuro Postgres/MySQL)."
208
- ),
209
- version="2.0.0",
210
  )
211
 
212
  app.add_middleware(
213
  CORSMiddleware,
214
  allow_origins=["*"],
215
- allow_credentials=True,
216
  allow_methods=["*"],
217
  allow_headers=["*"],
218
  )
@@ -423,6 +421,13 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
423
  f"note: use JOIN when foreign keys link tables"
424
  )
425
 
 
 
 
 
 
 
 
426
 
427
  def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
428
  if conn_id not in sql_manager.connections:
@@ -540,6 +545,7 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
540
  "best_rows_preview": best.get("rows_preview"),
541
  "best_columns": best.get("columns", []),
542
  "candidates": candidates,
 
543
  }
544
 
545
 
@@ -649,73 +655,54 @@ async def startup_event():
649
 
650
 
651
  @app.post("/upload", response_model=UploadResponse)
652
- async def upload_database(db_file: UploadFile = File(...)):
653
- """
654
- Subida de BD basada en dumps:
655
- - .sql → dump (schema + data) → BD dinámica (SQLite por ahora)
656
- - .zip debe contener uno o varios .sql (se concatenan)
657
- """
658
- filename = db_file.filename
659
- if not filename:
660
- raise HTTPException(status_code=400, detail="Archivo sin nombre.")
 
 
661
 
 
662
  fname_lower = filename.lower()
663
  contents = await db_file.read()
664
 
665
- note: Optional[str] = None
 
666
 
667
- # Caso 1: dump .sql
668
  if fname_lower.endswith(".sql"):
669
  sql_text = contents.decode("utf-8", errors="ignore")
670
- try:
671
- conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
672
- except Exception as e:
673
- raise HTTPException(
674
- status_code=400,
675
- detail=f"No se pudo crear la BD desde el dump SQL: {e}",
676
- )
677
- meta = sql_manager.connections[conn_id]
678
- engine = meta["engine"]
679
- db_name = meta["db_name"]
680
- note = f"SQL dump imported into {engine.upper()} database '{db_name}'."
681
-
682
- # Caso 2: ZIP con uno o varios .sql
683
  elif fname_lower.endswith(".zip"):
684
- try:
685
- sql_text = _combine_sql_files_from_zip(contents)
686
- except ValueError as ve:
687
- raise HTTPException(status_code=400, detail=str(ve))
688
-
689
- try:
690
- conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
691
- except Exception as e:
692
- raise HTTPException(
693
- status_code=400,
694
- detail=f"No se pudo crear la BD desde los .sql dentro del ZIP: {e}",
695
- )
696
- meta = sql_manager.connections[conn_id]
697
- engine = meta["engine"]
698
- db_name = meta["db_name"]
699
- note = f"ZIP with SQL dumps imported into {engine.upper()} database '{db_name}'."
700
-
701
  else:
702
- raise HTTPException(
703
- status_code=400,
704
- detail="Formato no soportado. Usa: .sql o .zip (con archivos .sql dentro).",
705
- )
 
 
 
706
 
707
  meta = sql_manager.connections[conn_id]
708
- engine = meta["engine"]
709
- db_name = meta["db_name"]
710
 
711
- # db_path pseudo para mantener compatibilidad
712
- db_path = f"{engine}://{db_name}"
 
 
 
 
 
713
 
714
  return UploadResponse(
715
  connection_id=conn_id,
716
- label=meta["label"],
717
- db_path=db_path,
718
- note=note,
719
  )
720
 
721
 
@@ -772,8 +759,44 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
772
 
773
 
774
  @app.post("/infer", response_model=InferResponse)
775
- async def infer_sql(req: InferRequest):
 
 
 
 
 
 
 
 
 
 
 
776
  result = nl2sql_with_rerank(req.question, req.connection_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
777
  return InferResponse(**result)
778
 
779
 
@@ -826,6 +849,38 @@ async def health():
826
  "device": str(DEVICE),
827
  }
828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829
 
830
  @app.get("/")
831
  async def root():
 
8
  import uuid
9
  from typing import List, Optional, Dict, Any
10
 
11
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Header
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from pydantic import BaseModel
14
 
 
18
  from transformers import MarianMTModel, MarianTokenizer
19
  from openai import OpenAI
20
 
21
# ---- Supabase ----
from supabase import create_client, Client

# NOTE(review): these credentials were previously hard-coded in source control.
# Read them from the environment; the fallbacks keep existing deployments
# working, but the committed anon key should be rotated in the Supabase
# dashboard and the defaults removed once SUPABASE_URL / SUPABASE_ANON_KEY
# are configured in the runtime environment.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://bnvmqgjawtaslczewqyd.supabase.co")
SUPABASE_ANON_KEY = os.getenv(
    "SUPABASE_ANON_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1xZ2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI",
)

# Shared Supabase client used for auth checks and history/database persistence.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
29
  # ======================================================
30
  # 0) Configuración general de paths
31
  # ======================================================
 
34
  UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
35
  os.makedirs(UPLOAD_DIR, exist_ok=True)
36
 
 
37
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
38
+ DEVICE = torch.device("cpu")
39
 
 
40
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
 
41
  openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
42
 
43
  # ======================================================
 
204
  # ======================================================
205
 
206
  app = FastAPI(
207
+ title="NL2SQL Backend (with Supabase Auth + History)",
208
+ version="2.1.0",
 
 
 
 
 
209
  )
210
 
211
  app.add_middleware(
212
  CORSMiddleware,
213
  allow_origins=["*"],
 
214
  allow_methods=["*"],
215
  allow_headers=["*"],
216
  )
 
421
  f"note: use JOIN when foreign keys link tables"
422
  )
423
 
424
def normalize_score(raw: float) -> float:
    """Map a raw beam-search logit score onto a 0-100 percentage.

    Beam-search logits for this model typically fall in [-20, +5]; that
    window is shifted/scaled into [0, 1], clamped, and expressed as a
    percentage rounded to two decimals.
    """
    fraction = (raw + 20) / 25
    if fraction < 0:
        fraction = 0
    elif fraction > 1:
        fraction = 1
    return round(fraction * 100, 2)
431
 
432
  def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
433
  if conn_id not in sql_manager.connections:
 
545
  "best_rows_preview": best.get("rows_preview"),
546
  "best_columns": best.get("columns", []),
547
  "candidates": candidates,
548
+ "score_percent": normalize_score(best["score"]),
549
  }
550
 
551
 
 
655
 
656
 
657
@app.post("/upload", response_model=UploadResponse)
async def upload_database(
    db_file: UploadFile = File(...),
    authorization: Optional[str] = Header(None)
):
    """Import a .sql dump (or a .zip containing .sql dumps) as a dynamic DB.

    Requires a Supabase JWT in the Authorization header. The created
    database is registered in the Supabase `databases` table for the user.

    Raises:
        HTTPException 401: missing/invalid Supabase token.
        HTTPException 400: missing filename, unsupported format, bad ZIP,
            or a dump that fails to import.
    """
    if authorization is None:
        raise HTTPException(401, "Missing Authorization header")

    # removeprefix only strips a *leading* "Bearer ", unlike str.replace,
    # which would also mangle a token that happened to contain the substring.
    jwt = authorization.removeprefix("Bearer ")
    user = supabase.auth.get_user(jwt)
    if not user or not user.user:
        raise HTTPException(401, "Invalid Supabase token")

    filename = db_file.filename
    # BUG FIX: validate the filename BEFORE calling .lower() on it; a missing
    # filename previously raised AttributeError (HTTP 500) instead of 400.
    if not filename:
        raise HTTPException(400, "Archivo sin nombre.")

    fname_lower = filename.lower()
    contents = await db_file.read()

    # --- procesar SQL ---
    if fname_lower.endswith(".sql"):
        sql_text = contents.decode("utf-8", errors="ignore")
    elif fname_lower.endswith(".zip"):
        try:
            sql_text = _combine_sql_files_from_zip(contents)
        except ValueError as ve:
            # A malformed/empty ZIP is a client error, not a server crash.
            raise HTTPException(400, str(ve))
    else:
        raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")

    # --- crear BD dinámica (SQLite temporal) ---
    try:
        conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
    except Exception as e:
        raise HTTPException(400, f"Error creando BD: {e}")

    meta = sql_manager.connections[conn_id]

    # --- guardar en Supabase ---
    supabase.table("databases").insert({
        "user_id": user.user.id,
        "filename": filename,
        "engine": meta["engine"],
        "connection_id": conn_id
    }).execute()

    return UploadResponse(
        connection_id=conn_id,
        label=filename,
        db_path=f"{meta['engine']}://{meta['db_name']}",
        note="Database stored and indexed in Supabase."
    )
707
 
708
 
 
759
 
760
 
761
@app.post("/infer", response_model=InferResponse)
async def infer_sql(
    req: InferRequest,
    authorization: Optional[str] = Header(None)
):
    """Run NL→SQL inference for an authenticated user and log the query.

    Raises:
        HTTPException 401: missing/invalid Supabase token.
    """
    if authorization is None:
        raise HTTPException(401, "Missing Authorization header")

    jwt = authorization.removeprefix("Bearer ")
    user = supabase.auth.get_user(jwt)
    if not user or not user.user:
        raise HTTPException(401, "Invalid Supabase token")

    result = nl2sql_with_rerank(req.question, req.connection_id)

    # BUG FIX: prefer the score already computed for the re-ranked *best*
    # candidate; candidates[0] is not guaranteed to be the winner, so
    # recomputing from it could silently overwrite the correct value.
    score = result.get("score_percent")
    if score is None:
        score = normalize_score(result["candidates"][0]["score"])

    # History persistence is best-effort: a Supabase hiccup must not turn a
    # successful inference into a 500 for the caller.
    try:
        # buscar db_id en supabase
        db_row = (
            supabase.table("databases")
            .select("id")
            .eq("connection_id", req.connection_id)
            .eq("user_id", user.user.id)
            .execute()
        )
        db_id = db_row.data[0]["id"] if db_row.data else None

        # guardar historial
        supabase.table("queries").insert({
            "user_id": user.user.id,
            "db_id": db_id,
            "nl": result["question_original"],
            "sql_generated": result["best_sql"],
            "sql_repaired": result["candidates"][0]["sql"],
            "execution_ok": result["best_exec_ok"],
            "error": result["best_exec_error"],
            "rows_preview": result["best_rows_preview"],
            "score": score
        }).execute()
    except Exception as e:
        print(f"⚠️ No se pudo guardar el historial en Supabase: {e}")

    result["score_percent"] = score
    return InferResponse(**result)
801
 
802
 
 
849
  "device": str(DEVICE),
850
  }
851
 
852
@app.get("/history")
def get_history(authorization: Optional[str] = Header(None)):
    """Return the authenticated user's query history, newest first.

    Raises:
        HTTPException 401: missing/invalid Supabase token.
    """
    if authorization is None:
        raise HTTPException(401, "Missing Authorization")

    jwt = authorization.removeprefix("Bearer ")
    user = supabase.auth.get_user(jwt)
    # BUG FIX: an invalid/expired token previously fell through and crashed
    # on user.user.id (HTTP 500); reject it explicitly, matching /infer.
    if not user or not user.user:
        raise HTTPException(401, "Invalid Supabase token")

    rows = (
        supabase.table("queries")
        .select("*")
        .eq("user_id", user.user.id)
        .order("created_at", desc=True)
        .execute()
    )
    return rows.data
+
868
+
869
@app.get("/my-databases")
def get_my_databases(authorization: Optional[str] = Header(None)):
    """Return the databases the authenticated user has uploaded.

    Raises:
        HTTPException 401: missing/invalid Supabase token.
    """
    if authorization is None:
        raise HTTPException(401, "Missing Authorization")

    jwt = authorization.removeprefix("Bearer ")
    user = supabase.auth.get_user(jwt)
    # BUG FIX: reject invalid tokens explicitly instead of crashing on
    # user.user.id with an AttributeError (HTTP 500).
    if not user or not user.user:
        raise HTTPException(401, "Invalid Supabase token")

    rows = (
        supabase.table("databases")
        .select("*")
        .eq("user_id", user.user.id)
        .execute()
    )
    return rows.data
884
 
885
  @app.get("/")
886
  async def root():