Melika Kheirieh committed on
Commit
5cbfffe
·
1 Parent(s): 8224e04

style: format code with ruff

Browse files
Files changed (4) hide show
  1. Makefile +14 -8
  2. app/main.py +3 -3
  3. app/routers/nl2sql.py +28 -6
  4. demo/app.py +21 -6
Makefile CHANGED
@@ -1,9 +1,12 @@
1
-
2
  # ---------- Config ----------
3
  VENV_DIR ?= .venv
4
  PY ?= $(if $(wildcard $(VENV_DIR)/bin/python),$(VENV_DIR)/bin/python,python3)
5
  PIP ?= $(if $(wildcard $(VENV_DIR)/bin/pip),$(VENV_DIR)/bin/pip,pip)
6
  UVICORN ?= $(if $(wildcard $(VENV_DIR)/bin/uvicorn),$(VENV_DIR)/bin/uvicorn,uvicorn)
 
 
 
 
7
  DOCKER_IMG ?= nl2sql-copilot
8
  PORT ?= 8000
9
 
@@ -30,29 +33,32 @@ dev-install: ## Install dev tools (ruff, mypy, pytest, coverage, uvicorn, etc.)
30
  $(PIP) install -U pip wheel
31
  $(PIP) install ruff mypy pytest pytest-cov uvicorn
32
 
 
 
 
33
  # ---------- Quality ----------
34
  .PHONY: format
35
  format: ## Auto-format & fix with ruff
36
- $(VENV_DIR)/bin/ruff format .
37
- $(VENV_DIR)/bin/ruff check . --fix
38
 
39
  .PHONY: lint
40
  lint: ## Run linting and type checking
41
- $(VENV_DIR)/bin/ruff check .
42
- $(VENV_DIR)/bin/mypy .
43
 
44
  .PHONY: typecheck
45
  typecheck: ## Run type checking only
46
- $(VENV_DIR)/bin/mypy .
47
 
48
  # ---------- Tests ----------
49
  .PHONY: test
50
  test: ## Run pytest quietly
51
- PYTHONPATH=$$PWD $(VENV_DIR)/bin/pytest -q
52
 
53
  .PHONY: cov
54
  cov: ## Run tests with coverage
55
- PYTHONPATH=$$PWD $(VENV_DIR)/bin/pytest --cov=nl2sql --cov-report=term-missing
56
 
57
  # ---------- Run ----------
58
  .PHONY: run
 
 
1
  # ---------- Config ----------
2
  VENV_DIR ?= .venv
3
  PY ?= $(if $(wildcard $(VENV_DIR)/bin/python),$(VENV_DIR)/bin/python,python3)
4
  PIP ?= $(if $(wildcard $(VENV_DIR)/bin/pip),$(VENV_DIR)/bin/pip,pip)
5
  UVICORN ?= $(if $(wildcard $(VENV_DIR)/bin/uvicorn),$(VENV_DIR)/bin/uvicorn,uvicorn)
6
+ RUFF ?= $(if $(wildcard $(VENV_DIR)/bin/ruff),$(VENV_DIR)/bin/ruff,ruff)
7
+ MYPY ?= $(if $(wildcard $(VENV_DIR)/bin/mypy),$(VENV_DIR)/bin/mypy,mypy)
8
+ PYTEST ?= $(if $(wildcard $(VENV_DIR)/bin/pytest),$(VENV_DIR)/bin/pytest,pytest)
9
+
10
  DOCKER_IMG ?= nl2sql-copilot
11
  PORT ?= 8000
12
 
 
33
  $(PIP) install -U pip wheel
34
  $(PIP) install ruff mypy pytest pytest-cov uvicorn
35
 
36
+ .PHONY: bootstrap
37
+ bootstrap: venv dev-install ## Create venv and install dev tools
38
+
39
  # ---------- Quality ----------
40
  .PHONY: format
41
  format: ## Auto-format & fix with ruff
42
+ $(RUFF) format .
43
+ $(RUFF) check . --fix
44
 
45
  .PHONY: lint
46
  lint: ## Run linting and type checking
47
+ $(RUFF) check .
48
+ $(MYPY) .
49
 
50
  .PHONY: typecheck
51
  typecheck: ## Run type checking only
52
+ $(MYPY) .
53
 
54
  # ---------- Tests ----------
55
  .PHONY: test
56
  test: ## Run pytest quietly
57
+ PYTHONPATH=$$PWD $(PYTEST) -q
58
 
59
  .PHONY: cov
60
  cov: ## Run tests with coverage
61
+ PYTHONPATH=$$PWD $(PYTEST) --cov=nl2sql --cov-report=term-missing
62
 
63
  # ---------- Run ----------
64
  .PHONY: run
app/main.py CHANGED
@@ -1,9 +1,9 @@
1
  from dotenv import load_dotenv
2
- load_dotenv()
3
 
4
- from fastapi import FastAPI # noqa: E402
5
- from app.routers import nl2sql # noqa: E402
6
 
 
 
7
 
8
 
9
  app = FastAPI(
 
1
  from dotenv import load_dotenv
 
2
 
3
+ load_dotenv()
 
4
 
5
+ from fastapi import FastAPI # noqa: E402
6
+ from app.routers import nl2sql # noqa: E402
7
 
8
 
9
  app = FastAPI(
app/routers/nl2sql.py CHANGED
@@ -36,12 +36,17 @@ _DB_MAP: Dict[str, Dict[str, object]] = {}
36
  # -------------------------------
37
  DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
38
  POSTGRES_DSN = os.getenv("POSTGRES_DSN")
39
- DEFAULT_SQLITE_DB = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db") # keep your current default
 
 
 
40
 
41
  def _cleanup_db_map() -> None:
42
  """Remove expired uploaded DB files (best-effort)."""
43
  now = time.time()
44
- expired = [k for k, v in _DB_MAP.items() if now - float(v.get("ts", 0)) > _DB_TTL_SECONDS]
 
 
45
  for k in expired:
46
  path = _DB_MAP[k].get("path")
47
  try:
@@ -51,6 +56,7 @@ def _cleanup_db_map() -> None:
51
  pass
52
  _DB_MAP.pop(k, None)
53
 
 
54
  def _resolve_sqlite_path(db_id: Optional[str]) -> str:
55
  """Resolve a SQLite file path from db_id or fallback to default."""
56
  _cleanup_db_map()
@@ -58,6 +64,7 @@ def _resolve_sqlite_path(db_id: Optional[str]) -> str:
58
  return str(_DB_MAP[db_id]["path"])
59
  return DEFAULT_SQLITE_DB
60
 
 
61
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
62
  """
63
  Build a DB adapter for this request.
@@ -66,7 +73,9 @@ def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapte
66
  """
67
  if DB_MODE == "postgres":
68
  if not POSTGRES_DSN:
69
- raise HTTPException(status_code=500, detail="POSTGRES_DSN is not configured")
 
 
70
  return PostgresAdapter(POSTGRES_DSN)
71
 
72
  # sqlite mode
@@ -75,12 +84,14 @@ def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapte
75
  # If not, ensure your adapter enforces PRAGMA query_only=ON and prevents DDL/DML.
76
  return SQLiteAdapter(sqlite_path)
77
 
 
78
  # -------------------------------
79
  # LLM providers & shared components (stateless)
80
  # -------------------------------
81
  def get_llm():
82
  return OpenAIProvider()
83
 
 
84
  _detector = AmbiguityDetector()
85
  _planner = Planner(get_llm())
86
  _generator = Generator(get_llm())
@@ -88,6 +99,7 @@ _safety = Safety()
88
  _verifier = Verifier()
89
  _repair = Repair(get_llm())
90
 
 
91
  def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
92
  """Build a fresh Pipeline with a per-request Executor bound to the chosen adapter."""
93
  executor = Executor(adapter)
@@ -101,6 +113,7 @@ def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
101
  repair=_repair,
102
  )
103
 
 
104
  # -------------------------------
105
  # Helpers
106
  # -------------------------------
@@ -108,6 +121,7 @@ def _to_dict(obj):
108
  """Safely convert dataclass → dict."""
109
  return asdict(obj) if is_dataclass(obj) else obj
110
 
 
111
  def _round_trace(t: dict) -> dict:
112
  """Round float fields to keep responses tidy and stable."""
113
  if t.get("cost_usd") is not None:
@@ -116,6 +130,7 @@ def _round_trace(t: dict) -> dict:
116
  t["duration_ms"] = round(t["duration_ms"], 2)
117
  return t
118
 
 
119
  # -------------------------------
120
  # Upload endpoint (SQLite only)
121
  # Path will be /api/nl2sql/upload_db if your root APIRouter is mounted at /api
@@ -130,16 +145,22 @@ async def upload_db(file: UploadFile = File(...)):
130
  - Files are stored under /tmp and cleaned by TTL.
131
  """
132
  if DB_MODE != "sqlite":
133
- raise HTTPException(status_code=400, detail="DB upload is only supported in sqlite mode")
 
 
134
 
135
  filename = file.filename or "db.sqlite"
136
  if not (filename.endswith(".db") or filename.endswith(".sqlite")):
137
- raise HTTPException(status_code=400, detail="Only .db or .sqlite files are allowed")
 
 
138
 
139
  data = await file.read()
140
  max_bytes = int(os.getenv("UPLOAD_MAX_BYTES", str(20 * 1024 * 1024))) # 20 MB
141
  if len(data) > max_bytes:
142
- raise HTTPException(status_code=400, detail=f"File too large (> {max_bytes} bytes)")
 
 
143
 
144
  db_id = str(uuid.uuid4())
145
  out_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
@@ -152,6 +173,7 @@ async def upload_db(file: UploadFile = File(...)):
152
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
153
  return {"db_id": db_id}
154
 
 
155
  # -------------------------------
156
  # Main NL2SQL endpoint
157
  # Path will be /api/nl2sql if your root APIRouter is mounted at /api
 
36
  # -------------------------------
37
  DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
38
  POSTGRES_DSN = os.getenv("POSTGRES_DSN")
39
+ DEFAULT_SQLITE_DB = os.getenv(
40
+ "DEFAULT_SQLITE_DB", "data/chinook.db"
41
+ ) # keep your current default
42
+
43
 
44
  def _cleanup_db_map() -> None:
45
  """Remove expired uploaded DB files (best-effort)."""
46
  now = time.time()
47
+ expired = [
48
+ k for k, v in _DB_MAP.items() if now - float(v.get("ts", 0)) > _DB_TTL_SECONDS
49
+ ]
50
  for k in expired:
51
  path = _DB_MAP[k].get("path")
52
  try:
 
56
  pass
57
  _DB_MAP.pop(k, None)
58
 
59
+
60
  def _resolve_sqlite_path(db_id: Optional[str]) -> str:
61
  """Resolve a SQLite file path from db_id or fallback to default."""
62
  _cleanup_db_map()
 
64
  return str(_DB_MAP[db_id]["path"])
65
  return DEFAULT_SQLITE_DB
66
 
67
+
68
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
69
  """
70
  Build a DB adapter for this request.
 
73
  """
74
  if DB_MODE == "postgres":
75
  if not POSTGRES_DSN:
76
+ raise HTTPException(
77
+ status_code=500, detail="POSTGRES_DSN is not configured"
78
+ )
79
  return PostgresAdapter(POSTGRES_DSN)
80
 
81
  # sqlite mode
 
84
  # If not, ensure your adapter enforces PRAGMA query_only=ON and prevents DDL/DML.
85
  return SQLiteAdapter(sqlite_path)
86
 
87
+
88
  # -------------------------------
89
  # LLM providers & shared components (stateless)
90
  # -------------------------------
91
  def get_llm():
92
  return OpenAIProvider()
93
 
94
+
95
  _detector = AmbiguityDetector()
96
  _planner = Planner(get_llm())
97
  _generator = Generator(get_llm())
 
99
  _verifier = Verifier()
100
  _repair = Repair(get_llm())
101
 
102
+
103
  def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
104
  """Build a fresh Pipeline with a per-request Executor bound to the chosen adapter."""
105
  executor = Executor(adapter)
 
113
  repair=_repair,
114
  )
115
 
116
+
117
  # -------------------------------
118
  # Helpers
119
  # -------------------------------
 
121
  """Safely convert dataclass → dict."""
122
  return asdict(obj) if is_dataclass(obj) else obj
123
 
124
+
125
  def _round_trace(t: dict) -> dict:
126
  """Round float fields to keep responses tidy and stable."""
127
  if t.get("cost_usd") is not None:
 
130
  t["duration_ms"] = round(t["duration_ms"], 2)
131
  return t
132
 
133
+
134
  # -------------------------------
135
  # Upload endpoint (SQLite only)
136
  # Path will be /api/nl2sql/upload_db if your root APIRouter is mounted at /api
 
145
  - Files are stored under /tmp and cleaned by TTL.
146
  """
147
  if DB_MODE != "sqlite":
148
+ raise HTTPException(
149
+ status_code=400, detail="DB upload is only supported in sqlite mode"
150
+ )
151
 
152
  filename = file.filename or "db.sqlite"
153
  if not (filename.endswith(".db") or filename.endswith(".sqlite")):
154
+ raise HTTPException(
155
+ status_code=400, detail="Only .db or .sqlite files are allowed"
156
+ )
157
 
158
  data = await file.read()
159
  max_bytes = int(os.getenv("UPLOAD_MAX_BYTES", str(20 * 1024 * 1024))) # 20 MB
160
  if len(data) > max_bytes:
161
+ raise HTTPException(
162
+ status_code=400, detail=f"File too large (> {max_bytes} bytes)"
163
+ )
164
 
165
  db_id = str(uuid.uuid4())
166
  out_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
 
173
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
174
  return {"db_id": db_id}
175
 
176
+
177
  # -------------------------------
178
  # Main NL2SQL endpoint
179
  # Path will be /api/nl2sql if your root APIRouter is mounted at /api
demo/app.py CHANGED
@@ -3,7 +3,7 @@ import requests
3
  import gradio as gr
4
 
5
  API_UPLOAD = "http://localhost:8000/api/v1/nl2sql/upload_db"
6
- API_QUERY = "http://localhost:8000/api/v1/nl2sql"
7
 
8
 
9
  def upload_db(file_obj):
@@ -44,8 +44,12 @@ def query_to_sql(user_query, db_id, debug):
44
 
45
  # Flags summary
46
  ambiguous = "Yes" if d.get("ambiguous") else "No"
47
- safety = ("Allowed" if d.get("safety", {}).get("allowed", True) else f"Blocked: {d.get('safety', {}).get('blocked_reason')}")
48
- verification = ("Passed" if d.get("verification", {}).get("passed") else "Failed")
 
 
 
 
49
  repair = d.get("repair", {})
50
  repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
51
 
@@ -70,7 +74,9 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
70
  db_state = gr.State(value=None)
71
 
72
  with gr.Row():
73
- db_file = gr.File(label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"])
 
 
74
  upload_btn = gr.Button("Upload DB")
75
  db_msg = gr.Markdown()
76
  upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])
@@ -100,9 +106,18 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
100
  run.click(
101
  query_to_sql,
102
  inputs=[q, db_state, debug],
103
- outputs=[badges, sql_out, exp_out, res_out, trace, repair_candidates, repair_diff, timings],
 
 
 
 
 
 
 
 
 
104
  )
105
 
106
  if __name__ == "__main__":
107
  # Let Gradio pick a free port by default to avoid collisions
108
- demo.launch()
 
3
  import gradio as gr
4
 
5
  API_UPLOAD = "http://localhost:8000/api/v1/nl2sql/upload_db"
6
+ API_QUERY = "http://localhost:8000/api/v1/nl2sql"
7
 
8
 
9
  def upload_db(file_obj):
 
44
 
45
  # Flags summary
46
  ambiguous = "Yes" if d.get("ambiguous") else "No"
47
+ safety = (
48
+ "Allowed"
49
+ if d.get("safety", {}).get("allowed", True)
50
+ else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
51
+ )
52
+ verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
53
  repair = d.get("repair", {})
54
  repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
55
 
 
74
  db_state = gr.State(value=None)
75
 
76
  with gr.Row():
77
+ db_file = gr.File(
78
+ label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
79
+ )
80
  upload_btn = gr.Button("Upload DB")
81
  db_msg = gr.Markdown()
82
  upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])
 
106
  run.click(
107
  query_to_sql,
108
  inputs=[q, db_state, debug],
109
+ outputs=[
110
+ badges,
111
+ sql_out,
112
+ exp_out,
113
+ res_out,
114
+ trace,
115
+ repair_candidates,
116
+ repair_diff,
117
+ timings,
118
+ ],
119
  )
120
 
121
  if __name__ == "__main__":
122
  # Let Gradio pick a free port by default to avoid collisions
123
+ demo.launch()