Spaces:
Running
Running
Melika Kheirieh
commited on
Commit
·
64907d7
1
Parent(s):
787d215
test(router): monkeypatch _pipeline.run instead of nl2sql.Pipeline.run
Browse files- app/routers/nl2sql.py +10 -9
- tests/test_nl2sql_router.py +3 -3
app/routers/nl2sql.py
CHANGED
|
@@ -1,30 +1,31 @@
|
|
| 1 |
from dataclasses import asdict, is_dataclass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from fastapi import APIRouter, HTTPException, UploadFile, File
|
|
|
|
| 3 |
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
|
| 4 |
from nl2sql.pipeline import Pipeline as _Pipeline, FinalResult as _FinalResult
|
| 5 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 6 |
from nl2sql.safety import Safety
|
| 7 |
from nl2sql.planner import Planner
|
| 8 |
from nl2sql.generator import Generator
|
| 9 |
-
from adapters.llm.openai_provider import OpenAIProvider
|
| 10 |
from nl2sql.executor import Executor
|
| 11 |
from nl2sql.verifier import Verifier
|
| 12 |
from nl2sql.repair import Repair
|
|
|
|
| 13 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 14 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 15 |
|
| 16 |
-
import os
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
import time
|
| 19 |
-
import json
|
| 20 |
-
import uuid
|
| 21 |
-
from typing import Union, Optional, Dict, TypedDict, Any, cast
|
| 22 |
-
|
| 23 |
-
# Re-export for tests & public API stability (pytest expects nl2sql.Pipeline)
|
| 24 |
Pipeline = _Pipeline
|
| 25 |
FinalResult = _FinalResult
|
| 26 |
__all__ = ["Pipeline", "FinalResult"]
|
| 27 |
|
|
|
|
| 28 |
router = APIRouter(prefix="/nl2sql")
|
| 29 |
|
| 30 |
# -------------------------------
|
|
|
|
| 1 |
from dataclasses import asdict, is_dataclass
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Union, Optional, Dict, TypedDict, Any, cast
|
| 8 |
+
|
| 9 |
from fastapi import APIRouter, HTTPException, UploadFile, File
|
| 10 |
+
|
| 11 |
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
|
| 12 |
from nl2sql.pipeline import Pipeline as _Pipeline, FinalResult as _FinalResult
|
| 13 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 14 |
from nl2sql.safety import Safety
|
| 15 |
from nl2sql.planner import Planner
|
| 16 |
from nl2sql.generator import Generator
|
|
|
|
| 17 |
from nl2sql.executor import Executor
|
| 18 |
from nl2sql.verifier import Verifier
|
| 19 |
from nl2sql.repair import Repair
|
| 20 |
+
from adapters.llm.openai_provider import OpenAIProvider
|
| 21 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 22 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
Pipeline = _Pipeline
|
| 25 |
FinalResult = _FinalResult
|
| 26 |
__all__ = ["Pipeline", "FinalResult"]
|
| 27 |
|
| 28 |
+
|
| 29 |
router = APIRouter(prefix="/nl2sql")
|
| 30 |
|
| 31 |
# -------------------------------
|
tests/test_nl2sql_router.py
CHANGED
|
@@ -31,7 +31,7 @@ def test_ambiguity_route(monkeypatch):
|
|
| 31 |
traces=[fake_trace("detector")],
|
| 32 |
)
|
| 33 |
|
| 34 |
-
monkeypatch.setattr(nl2sql.
|
| 35 |
|
| 36 |
resp = client.post(
|
| 37 |
path,
|
|
@@ -64,7 +64,7 @@ def test_error_route(monkeypatch):
|
|
| 64 |
traces=[fake_trace("safety")],
|
| 65 |
)
|
| 66 |
|
| 67 |
-
monkeypatch.setattr(nl2sql.
|
| 68 |
|
| 69 |
resp = client.post(
|
| 70 |
path,
|
|
@@ -94,7 +94,7 @@ def test_success_route(monkeypatch):
|
|
| 94 |
traces=[fake_trace("planner"), fake_trace("generator")],
|
| 95 |
)
|
| 96 |
|
| 97 |
-
monkeypatch.setattr(nl2sql.
|
| 98 |
|
| 99 |
resp = client.post(
|
| 100 |
path,
|
|
|
|
| 31 |
traces=[fake_trace("detector")],
|
| 32 |
)
|
| 33 |
|
| 34 |
+
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
| 35 |
|
| 36 |
resp = client.post(
|
| 37 |
path,
|
|
|
|
| 64 |
traces=[fake_trace("safety")],
|
| 65 |
)
|
| 66 |
|
| 67 |
+
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
| 68 |
|
| 69 |
resp = client.post(
|
| 70 |
path,
|
|
|
|
| 94 |
traces=[fake_trace("planner"), fake_trace("generator")],
|
| 95 |
)
|
| 96 |
|
| 97 |
+
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
| 98 |
|
| 99 |
resp = client.post(
|
| 100 |
path,
|