Melika Kheirieh commited on
Commit
72e96d1
·
1 Parent(s): f8b2087

fix: resolve SQLite path mismatch and ambiguity false positives

Browse files
adapters/db/sqlite_adapter.py CHANGED
@@ -17,8 +17,9 @@ class SQLiteAdapter(DBAdapter):
17
  log.info("SQLiteAdapter initialized with DB path: %s", self.path)
18
 
19
  def preview_schema(self, limit_per_table: int = 0) -> str:
20
- uri = self.path.as_uri()
21
- with sqlite3.connect(f"{uri}?mode=ro", uri=True) as conn:
 
22
  cur = conn.cursor()
23
  cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
24
  tables = [t[0] for t in cur.fetchall()]
@@ -30,12 +31,12 @@ class SQLiteAdapter(DBAdapter):
30
  return "\n".join(lines)
31
 
32
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
33
- # enforce read-only connection
34
- uri = self.path.as_uri()
35
- log.info("SQLiteAdapter opening read-only connection to: %s", uri)
36
  if not self.path.exists():
37
  raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
38
- with sqlite3.connect(f"{uri}?mode=ro", uri=True, timeout=3) as conn:
 
 
 
39
  cur = conn.cursor()
40
  log.debug("Executing SQL: %s", sql.strip().replace("\n", " "))
41
  cur.execute(sql)
 
17
  log.info("SQLiteAdapter initialized with DB path: %s", self.path)
18
 
19
  def preview_schema(self, limit_per_table: int = 0) -> str:
20
+ if not self.path.exists():
21
+ raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
22
+ with sqlite3.connect(f"file:{self.path}?mode=ro", uri=True) as conn:
23
  cur = conn.cursor()
24
  cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
25
  tables = [t[0] for t in cur.fetchall()]
 
31
  return "\n".join(lines)
32
 
33
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
 
 
 
34
  if not self.path.exists():
35
  raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
36
+ # use proper SQLite URI (not .as_uri())
37
+ uri = f"file:{self.path}?mode=ro"
38
+ log.info("SQLiteAdapter opening read-only connection to: %s", uri)
39
+ with sqlite3.connect(uri, uri=True, timeout=3) as conn:
40
  cur = conn.cursor()
41
  log.debug("Executing SQL: %s", sql.strip().replace("\n", " "))
42
  cur.execute(sql)
app/routers/nl2sql.py CHANGED
@@ -17,7 +17,7 @@ from prometheus_client import Counter
17
 
18
  # --- Local ---
19
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
20
- from app.state import get_db_path, cleanup_stale_dbs, register_db
21
  from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
22
  from adapters.llm.openai_provider import OpenAIProvider
23
  from adapters.db.sqlite_adapter import SQLiteAdapter
@@ -145,7 +145,7 @@ _PIPELINE = pipeline_from_config(CONFIG_PATH)
145
  # -------------------------------
146
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
147
  """
148
- Resolve a DB adapter based on module-level DB_MODE and an optional db_id.
149
  """
150
  if DB_MODE == "postgres":
151
  dsn = os.environ.get("POSTGRES_DSN")
@@ -153,24 +153,28 @@ def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapte
153
  raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
154
  return PostgresAdapter(dsn)
155
 
156
- # sqlite mode
157
  if db_id:
158
  cleanup_stale_dbs()
159
- path = get_db_path(db_id)
160
 
161
- if not path or not os.path.exists(path):
162
- tmp_path = Path("/tmp/nl2sql_dbs") / f"{db_id}.sqlite"
163
- if tmp_path.exists():
164
- path = str(tmp_path)
165
 
166
- if path and os.path.exists(path):
167
- return SQLiteAdapter(str(path))
 
 
 
 
 
 
168
 
169
- raise HTTPException(
170
- status_code=404, detail=f"db_id not found or expired: {db_id}"
171
- )
 
 
 
172
 
173
- # default sqlite fallback
174
  default_path = Path(DEFAULT_SQLITE_PATH)
175
  if not default_path.exists():
176
  raise HTTPException(status_code=500, detail="default SQLite DB not found")
 
17
 
18
  # --- Local ---
19
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
20
+ from app.state import cleanup_stale_dbs, register_db
21
  from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
22
  from adapters.llm.openai_provider import OpenAIProvider
23
  from adapters.db.sqlite_adapter import SQLiteAdapter
 
145
  # -------------------------------
146
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
147
  """
148
+ Resolve DB adapter path for SQLite or Postgres.
149
  """
150
  if DB_MODE == "postgres":
151
  dsn = os.environ.get("POSTGRES_DSN")
 
153
  raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
154
  return PostgresAdapter(dsn)
155
 
 
156
  if db_id:
157
  cleanup_stale_dbs()
158
+ import logging
159
 
160
+ log = logging.getLogger(__name__)
 
 
 
161
 
162
+ candidates = [
163
+ Path("/tmp/nl2sql_dbs") / f"{db_id}.sqlite",
164
+ Path("/tmp/nl2sql_dbs") / f"{db_id}.db",
165
+ Path("data/uploads") / f"{db_id}.sqlite",
166
+ Path("data/uploads") / f"{db_id}.db",
167
+ Path("data") / f"{db_id}.sqlite",
168
+ Path("data") / f"{db_id}.db",
169
+ ]
170
 
171
+ for candidate in candidates:
172
+ if candidate.exists():
173
+ log.info(f"✅ Using DB file: {candidate}")
174
+ return SQLiteAdapter(str(candidate))
175
+
176
+ raise HTTPException(status_code=404, detail=f"db_id not found: {db_id}")
177
 
 
178
  default_path = Path(DEFAULT_SQLITE_PATH)
179
  if not default_path.exists():
180
  raise HTTPException(status_code=500, detail="default SQLite DB not found")
app/state.py CHANGED
@@ -1,24 +1,72 @@
1
  import os
2
  import time
3
- from typing import Dict, Any
 
 
4
 
5
- DB_TTL_SECONDS = int(os.getenv("NL2SQL_DB_TTL_SEC", "86400"))
6
- DB_MAP: Dict[str, Dict[str, Any]] = {}
7
 
 
 
 
8
 
9
- def register_db(db_id: str, path: str) -> None:
10
- DB_MAP[db_id] = {"path": path, "created_at": time.time()}
 
11
 
 
 
12
 
13
- def get_db_path(db_id: str) -> str | None:
14
- entry = DB_MAP.get(db_id)
15
- return entry["path"] if entry else None
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  def cleanup_stale_dbs() -> None:
 
19
  now = time.time()
20
- stale = [
21
- k for k, v in DB_MAP.items() if now - v.get("created_at", now) > DB_TTL_SECONDS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ]
23
- for k in stale:
24
- DB_MAP.pop(k, None)
 
 
 
 
 
 
 
1
  import os
2
  import time
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Optional
6
 
7
+ log = logging.getLogger(__name__)
 
8
 
9
+ # ------------------------------
10
+ # Config
11
+ # ------------------------------
12
 
13
+ # default upload directory (can override via .env)
14
+ _DB_UPLOAD_DIR = Path(os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs"))
15
+ _DB_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
16
 
17
+ # in-memory map: {db_id: {"path": str, "ts": float}}
18
+ DB_MAP: dict[str, dict[str, str | float]] = {}
19
 
20
+ # cleanup threshold (hours)
21
+ DB_TTL_HOURS = 6
22
+
23
+
24
+ # ------------------------------
25
+ # Helpers
26
+ # ------------------------------
27
+
28
+
29
+ def register_db(db_id: str, path: str) -> None:
30
+ """Register new DB in memory (and ensure dir exists)."""
31
+ _DB_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
32
+ DB_MAP[db_id] = {"path": path, "ts": time.time()}
33
+ log.info(f"📦 Registered DB {db_id} -> {path}")
34
 
35
 
36
  def cleanup_stale_dbs() -> None:
37
+ """Remove expired DBs from /tmp/nl2sql_dbs and memory map."""
38
  now = time.time()
39
+ cutoff = DB_TTL_HOURS * 3600
40
+ stale_ids = [db_id for db_id, entry in DB_MAP.items() if now - entry["ts"] > cutoff]
41
+ for db_id in stale_ids:
42
+ path = DB_MAP[db_id]["path"]
43
+ try:
44
+ os.remove(path)
45
+ log.info(f"🧹 Deleted stale DB: {path}")
46
+ except FileNotFoundError:
47
+ pass
48
+ DB_MAP.pop(db_id, None)
49
+
50
+
51
+ def get_db_path(db_id: str) -> Optional[str]:
52
+ """Return full path of an uploaded DB (persistent lookup)."""
53
+ # ⃣ in-memory lookup
54
+ entry = DB_MAP.get(db_id)
55
+ if entry and Path(entry["path"]).exists():
56
+ return entry["path"]
57
+
58
+ # ⃣ persistent fallback scan
59
+ candidates = [
60
+ _DB_UPLOAD_DIR / f"{db_id}.sqlite",
61
+ _DB_UPLOAD_DIR / f"{db_id}.db",
62
+ Path("data/uploads") / f"{db_id}.sqlite",
63
+ Path("data/uploads") / f"{db_id}.db",
64
  ]
65
+ for p in candidates:
66
+ if p.exists():
67
+ log.info(f"🔍 Recovered DB path for {db_id}: {p}")
68
+ return str(p)
69
+
70
+ # ⃣ not found
71
+ log.warning(f"⚠️ DB file not found for id={db_id}")
72
+ return None
nl2sql/ambiguity_detector.py CHANGED
@@ -1,16 +1,43 @@
 
1
  import re
 
 
 
2
 
3
 
4
  class AmbiguityDetector:
5
- """Lightweight AmbiSQL-style ambiguity detection."""
 
 
 
 
6
 
7
- AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
 
 
 
 
 
 
 
 
 
 
8
 
9
  def detect(self, query: str, schema_preview: str) -> list[str]:
10
- hits = []
11
  q_lower = query.lower()
 
 
 
 
 
 
 
12
  for term in self.AMBIGUOUS_TERMS:
13
- if re.search(rf"\b{term}\b", q_lower):
14
- hits.append(f"The term '{term}' is ambiguous in this query.'")
 
 
15
 
16
  return hits
 
1
+ import os
2
  import re
3
+ import logging
4
+
5
+ log = logging.getLogger(__name__)
6
 
7
 
8
  class AmbiguityDetector:
9
+ """Improved AmbiSQL-style ambiguity detection.
10
+
11
+ - Skips detection entirely in DEV_MODE.
12
+ - Ignores qualified references like 'artist.name'.
13
+ """
14
 
15
+ AMBIGUOUS_TERMS = [
16
+ "recent",
17
+ "top",
18
+ "name",
19
+ "rank",
20
+ "latest",
21
+ "id",
22
+ "title",
23
+ "date",
24
+ "type",
25
+ ]
26
 
27
  def detect(self, query: str, schema_preview: str) -> list[str]:
28
+ # Normalize query
29
  q_lower = query.lower()
30
+
31
+ # Skip ambiguity checks entirely in dev mode
32
+ if os.getenv("DEV_MODE") == "1":
33
+ log.warning("Skipping ambiguity detection (DEV_MODE=1).")
34
+ return []
35
+
36
+ hits = []
37
  for term in self.AMBIGUOUS_TERMS:
38
+ # Match only standalone words, not qualified like 'artist.name'
39
+ pattern = rf"(?<!\.)\b{term}\b"
40
+ if re.search(pattern, q_lower):
41
+ hits.append(f"The term '{term}' is ambiguous in this query.")
42
 
43
  return hits