Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ import sqlite3
|
|
| 8 |
import uuid
|
| 9 |
from typing import List, Optional, Dict, Any
|
| 10 |
|
| 11 |
-
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
from pydantic import BaseModel
|
| 14 |
|
|
@@ -18,6 +18,14 @@ from langdetect import detect
|
|
| 18 |
from transformers import MarianMTModel, MarianTokenizer
|
| 19 |
from openai import OpenAI
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ======================================================
|
| 22 |
# 0) Configuración general de paths
|
| 23 |
# ======================================================
|
|
@@ -26,14 +34,10 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
| 26 |
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
|
| 27 |
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 28 |
|
| 29 |
-
# Modelo NL→SQL entrenado por ti en Hugging Face
|
| 30 |
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
|
| 31 |
-
DEVICE = torch.device("cpu")
|
| 32 |
|
| 33 |
-
# Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
|
| 34 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 35 |
-
if not OPENAI_API_KEY:
|
| 36 |
-
print("⚠️ OPENAI_API_KEY no está definido. El endpoint /speech-infer no funcionará hasta configurarlo.")
|
| 37 |
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
|
| 38 |
|
| 39 |
# ======================================================
|
|
@@ -200,19 +204,13 @@ sql_manager = SQLManager()
|
|
| 200 |
# ======================================================
|
| 201 |
|
| 202 |
app = FastAPI(
|
| 203 |
-
title="NL2SQL
|
| 204 |
-
|
| 205 |
-
"Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
|
| 206 |
-
"El usuario sube sus dumps .sql (o ZIP con .sql) y se crean "
|
| 207 |
-
"bases dinámicas (actualmente SQLite, futuro Postgres/MySQL)."
|
| 208 |
-
),
|
| 209 |
-
version="2.0.0",
|
| 210 |
)
|
| 211 |
|
| 212 |
app.add_middleware(
|
| 213 |
CORSMiddleware,
|
| 214 |
allow_origins=["*"],
|
| 215 |
-
allow_credentials=True,
|
| 216 |
allow_methods=["*"],
|
| 217 |
allow_headers=["*"],
|
| 218 |
)
|
|
@@ -423,6 +421,13 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
|
|
| 423 |
f"note: use JOIN when foreign keys link tables"
|
| 424 |
)
|
| 425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
| 428 |
if conn_id not in sql_manager.connections:
|
|
@@ -540,6 +545,7 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
|
| 540 |
"best_rows_preview": best.get("rows_preview"),
|
| 541 |
"best_columns": best.get("columns", []),
|
| 542 |
"candidates": candidates,
|
|
|
|
| 543 |
}
|
| 544 |
|
| 545 |
|
|
@@ -649,73 +655,54 @@ async def startup_event():
|
|
| 649 |
|
| 650 |
|
| 651 |
@app.post("/upload", response_model=UploadResponse)
|
| 652 |
-
async def upload_database(
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
|
|
|
|
|
|
| 661 |
|
|
|
|
| 662 |
fname_lower = filename.lower()
|
| 663 |
contents = await db_file.read()
|
| 664 |
|
| 665 |
-
|
|
|
|
| 666 |
|
| 667 |
-
#
|
| 668 |
if fname_lower.endswith(".sql"):
|
| 669 |
sql_text = contents.decode("utf-8", errors="ignore")
|
| 670 |
-
try:
|
| 671 |
-
conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
|
| 672 |
-
except Exception as e:
|
| 673 |
-
raise HTTPException(
|
| 674 |
-
status_code=400,
|
| 675 |
-
detail=f"No se pudo crear la BD desde el dump SQL: {e}",
|
| 676 |
-
)
|
| 677 |
-
meta = sql_manager.connections[conn_id]
|
| 678 |
-
engine = meta["engine"]
|
| 679 |
-
db_name = meta["db_name"]
|
| 680 |
-
note = f"SQL dump imported into {engine.upper()} database '{db_name}'."
|
| 681 |
-
|
| 682 |
-
# Caso 2: ZIP con uno o varios .sql
|
| 683 |
elif fname_lower.endswith(".zip"):
|
| 684 |
-
|
| 685 |
-
sql_text = _combine_sql_files_from_zip(contents)
|
| 686 |
-
except ValueError as ve:
|
| 687 |
-
raise HTTPException(status_code=400, detail=str(ve))
|
| 688 |
-
|
| 689 |
-
try:
|
| 690 |
-
conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
|
| 691 |
-
except Exception as e:
|
| 692 |
-
raise HTTPException(
|
| 693 |
-
status_code=400,
|
| 694 |
-
detail=f"No se pudo crear la BD desde los .sql dentro del ZIP: {e}",
|
| 695 |
-
)
|
| 696 |
-
meta = sql_manager.connections[conn_id]
|
| 697 |
-
engine = meta["engine"]
|
| 698 |
-
db_name = meta["db_name"]
|
| 699 |
-
note = f"ZIP with SQL dumps imported into {engine.upper()} database '{db_name}'."
|
| 700 |
-
|
| 701 |
else:
|
| 702 |
-
raise HTTPException(
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
|
|
|
|
|
|
|
|
|
| 706 |
|
| 707 |
meta = sql_manager.connections[conn_id]
|
| 708 |
-
engine = meta["engine"]
|
| 709 |
-
db_name = meta["db_name"]
|
| 710 |
|
| 711 |
-
#
|
| 712 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
return UploadResponse(
|
| 715 |
connection_id=conn_id,
|
| 716 |
-
label=
|
| 717 |
-
db_path=
|
| 718 |
-
note=
|
| 719 |
)
|
| 720 |
|
| 721 |
|
|
@@ -772,8 +759,44 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
|
|
| 772 |
|
| 773 |
|
| 774 |
@app.post("/infer", response_model=InferResponse)
|
| 775 |
-
async def infer_sql(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
result = nl2sql_with_rerank(req.question, req.connection_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
return InferResponse(**result)
|
| 778 |
|
| 779 |
|
|
@@ -826,6 +849,38 @@ async def health():
|
|
| 826 |
"device": str(DEVICE),
|
| 827 |
}
|
| 828 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
@app.get("/")
|
| 831 |
async def root():
|
|
|
|
| 8 |
import uuid
|
| 9 |
from typing import List, Optional, Dict, Any
|
| 10 |
|
| 11 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Header
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
from pydantic import BaseModel
|
| 14 |
|
|
|
|
| 18 |
from transformers import MarianMTModel, MarianTokenizer
|
| 19 |
from openai import OpenAI
|
| 20 |
|
| 21 |
+
# ---- Supabase ----
|
| 22 |
+
from supabase import create_client, Client
|
| 23 |
+
|
| 24 |
+
SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co"
|
| 25 |
+
SUPABASE_ANON_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1xZ2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI"
|
| 26 |
+
|
| 27 |
+
supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
|
| 28 |
+
|
| 29 |
# ======================================================
|
| 30 |
# 0) Configuración general de paths
|
| 31 |
# ======================================================
|
|
|
|
| 34 |
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_dbs")
|
| 35 |
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 36 |
|
|
|
|
| 37 |
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
|
| 38 |
+
DEVICE = torch.device("cpu")
|
| 39 |
|
|
|
|
| 40 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
|
|
| 41 |
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
|
| 42 |
|
| 43 |
# ======================================================
|
|
|
|
| 204 |
# ======================================================
|
| 205 |
|
| 206 |
app = FastAPI(
|
| 207 |
+
title="NL2SQL Backend (with Supabase Auth + History)",
|
| 208 |
+
version="2.1.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
)
|
| 210 |
|
| 211 |
app.add_middleware(
|
| 212 |
CORSMiddleware,
|
| 213 |
allow_origins=["*"],
|
|
|
|
| 214 |
allow_methods=["*"],
|
| 215 |
allow_headers=["*"],
|
| 216 |
)
|
|
|
|
| 421 |
f"note: use JOIN when foreign keys link tables"
|
| 422 |
)
|
| 423 |
|
| 424 |
+
def normalize_score(raw: float) -> float:
|
| 425 |
+
"""Normaliza el score logit del modelo a un porcentaje 0-100."""
|
| 426 |
+
# Rango típico de logits beam-search: -20 a +5
|
| 427 |
+
norm = (raw + 20) / 25
|
| 428 |
+
norm = max(0, min(1, norm))
|
| 429 |
+
return round(norm * 100, 2)
|
| 430 |
+
|
| 431 |
|
| 432 |
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
|
| 433 |
if conn_id not in sql_manager.connections:
|
|
|
|
| 545 |
"best_rows_preview": best.get("rows_preview"),
|
| 546 |
"best_columns": best.get("columns", []),
|
| 547 |
"candidates": candidates,
|
| 548 |
+
"score_percent": normalize_score(best["score"]),
|
| 549 |
}
|
| 550 |
|
| 551 |
|
|
|
|
| 655 |
|
| 656 |
|
| 657 |
@app.post("/upload", response_model=UploadResponse)
|
| 658 |
+
async def upload_database(
|
| 659 |
+
db_file: UploadFile = File(...),
|
| 660 |
+
authorization: Optional[str] = Header(None)
|
| 661 |
+
):
|
| 662 |
+
if authorization is None:
|
| 663 |
+
raise HTTPException(401, "Missing Authorization header")
|
| 664 |
+
|
| 665 |
+
jwt = authorization.replace("Bearer ", "")
|
| 666 |
+
user = supabase.auth.get_user(jwt)
|
| 667 |
+
if not user or not user.user:
|
| 668 |
+
raise HTTPException(401, "Invalid Supabase token")
|
| 669 |
|
| 670 |
+
filename = db_file.filename
|
| 671 |
fname_lower = filename.lower()
|
| 672 |
contents = await db_file.read()
|
| 673 |
|
| 674 |
+
if not filename:
|
| 675 |
+
raise HTTPException(400, "Archivo sin nombre.")
|
| 676 |
|
| 677 |
+
# --- procesar SQL ---
|
| 678 |
if fname_lower.endswith(".sql"):
|
| 679 |
sql_text = contents.decode("utf-8", errors="ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 680 |
elif fname_lower.endswith(".zip"):
|
| 681 |
+
sql_text = _combine_sql_files_from_zip(contents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
else:
|
| 683 |
+
raise HTTPException(400, "Formato no soportado. Usa .sql o .zip.")
|
| 684 |
+
|
| 685 |
+
# --- crear BD dinámica (SQLite temporal) ---
|
| 686 |
+
try:
|
| 687 |
+
conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
|
| 688 |
+
except Exception as e:
|
| 689 |
+
raise HTTPException(400, f"Error creando BD: {e}")
|
| 690 |
|
| 691 |
meta = sql_manager.connections[conn_id]
|
|
|
|
|
|
|
| 692 |
|
| 693 |
+
# --- guardar en Supabase ---
|
| 694 |
+
supabase.table("databases").insert({
|
| 695 |
+
"user_id": user.user.id,
|
| 696 |
+
"filename": filename,
|
| 697 |
+
"engine": meta["engine"],
|
| 698 |
+
"connection_id": conn_id
|
| 699 |
+
}).execute()
|
| 700 |
|
| 701 |
return UploadResponse(
|
| 702 |
connection_id=conn_id,
|
| 703 |
+
label=filename,
|
| 704 |
+
db_path=f"{meta['engine']}://{meta['db_name']}",
|
| 705 |
+
note="Database stored and indexed in Supabase."
|
| 706 |
)
|
| 707 |
|
| 708 |
|
|
|
|
| 759 |
|
| 760 |
|
| 761 |
@app.post("/infer", response_model=InferResponse)
|
| 762 |
+
async def infer_sql(
|
| 763 |
+
req: InferRequest,
|
| 764 |
+
authorization: Optional[str] = Header(None)
|
| 765 |
+
):
|
| 766 |
+
if authorization is None:
|
| 767 |
+
raise HTTPException(401, "Missing Authorization header")
|
| 768 |
+
|
| 769 |
+
jwt = authorization.replace("Bearer ", "")
|
| 770 |
+
user = supabase.auth.get_user(jwt)
|
| 771 |
+
if not user or not user.user:
|
| 772 |
+
raise HTTPException(401, "Invalid Supabase token")
|
| 773 |
+
|
| 774 |
result = nl2sql_with_rerank(req.question, req.connection_id)
|
| 775 |
+
score = normalize_score(result["candidates"][0]["score"])
|
| 776 |
+
|
| 777 |
+
# buscar db_id en supabase
|
| 778 |
+
db_row = supabase.table("databases") \
|
| 779 |
+
.select("id") \
|
| 780 |
+
.eq("connection_id", req.connection_id) \
|
| 781 |
+
.eq("user_id", user.user.id) \
|
| 782 |
+
.execute()
|
| 783 |
+
|
| 784 |
+
db_id = db_row.data[0]["id"] if db_row.data else None
|
| 785 |
+
|
| 786 |
+
# guardar historial
|
| 787 |
+
supabase.table("queries").insert({
|
| 788 |
+
"user_id": user.user.id,
|
| 789 |
+
"db_id": db_id,
|
| 790 |
+
"nl": result["question_original"],
|
| 791 |
+
"sql_generated": result["best_sql"],
|
| 792 |
+
"sql_repaired": result["candidates"][0]["sql"],
|
| 793 |
+
"execution_ok": result["best_exec_ok"],
|
| 794 |
+
"error": result["best_exec_error"],
|
| 795 |
+
"rows_preview": result["best_rows_preview"],
|
| 796 |
+
"score": score
|
| 797 |
+
}).execute()
|
| 798 |
+
|
| 799 |
+
result["score_percent"] = score
|
| 800 |
return InferResponse(**result)
|
| 801 |
|
| 802 |
|
|
|
|
| 849 |
"device": str(DEVICE),
|
| 850 |
}
|
| 851 |
|
| 852 |
+
@app.get("/history")
|
| 853 |
+
def get_history(authorization: Optional[str] = Header(None)):
|
| 854 |
+
if authorization is None:
|
| 855 |
+
raise HTTPException(401, "Missing Authorization")
|
| 856 |
+
|
| 857 |
+
jwt = authorization.replace("Bearer ", "")
|
| 858 |
+
user = supabase.auth.get_user(jwt)
|
| 859 |
+
|
| 860 |
+
rows = supabase.table("queries") \
|
| 861 |
+
.select("*") \
|
| 862 |
+
.eq("user_id", user.user.id) \
|
| 863 |
+
.order("created_at", desc=True) \
|
| 864 |
+
.execute()
|
| 865 |
+
|
| 866 |
+
return rows.data
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
@app.get("/my-databases")
|
| 870 |
+
def get_my_databases(authorization: Optional[str] = Header(None)):
|
| 871 |
+
if authorization is None:
|
| 872 |
+
raise HTTPException(401, "Missing Authorization")
|
| 873 |
+
|
| 874 |
+
jwt = authorization.replace("Bearer ", "")
|
| 875 |
+
user = supabase.auth.get_user(jwt)
|
| 876 |
+
|
| 877 |
+
rows = supabase.table("databases") \
|
| 878 |
+
.select("*") \
|
| 879 |
+
.eq("user_id", user.user.id) \
|
| 880 |
+
.execute()
|
| 881 |
+
|
| 882 |
+
return rows.data
|
| 883 |
+
|
| 884 |
|
| 885 |
@app.get("/")
|
| 886 |
async def root():
|