Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
dcc30f0
1
Parent(s):
eee3f75
style: format code with ruff
Browse files- benchmarks/run.py +14 -9
benchmarks/run.py
CHANGED
|
@@ -28,8 +28,9 @@ class LLMProvider(Protocol):
|
|
| 28 |
|
| 29 |
provider_id: str
|
| 30 |
|
| 31 |
-
def plan(
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
def generate_sql(
|
| 35 |
self,
|
|
@@ -38,18 +39,20 @@ class LLMProvider(Protocol):
|
|
| 38 |
schema_preview: str,
|
| 39 |
plan_text: str,
|
| 40 |
clarify_answers: Optional[Any] = None,
|
| 41 |
-
) -> Tuple[str, str, int, int, float]:
|
| 42 |
-
...
|
| 43 |
|
| 44 |
-
def repair(
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
# ---- fallback: Dummy LLM (so it runs without API keys)
|
| 49 |
class DummyLLM:
|
| 50 |
provider_id = "dummy-llm"
|
| 51 |
|
| 52 |
-
def plan(
|
|
|
|
|
|
|
| 53 |
text = (
|
| 54 |
f"- understand question: {user_query}\n"
|
| 55 |
"- identify tables\n- join if needed\n- filter\n- order/limit"
|
|
@@ -69,7 +72,9 @@ class DummyLLM:
|
|
| 69 |
rationale = "Demo SQL from DummyLLM"
|
| 70 |
return sql, rationale, 0, 0, 0.0
|
| 71 |
|
| 72 |
-
def repair(
|
|
|
|
|
|
|
| 73 |
return sql, 0, 0, 0.0
|
| 74 |
|
| 75 |
|
|
@@ -101,7 +106,7 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
|
|
| 101 |
if use_openai and os.getenv("OPENAI_API_KEY"):
|
| 102 |
llm = OpenAIProvider() # conforms to LLMProvider
|
| 103 |
else:
|
| 104 |
-
llm = DummyLLM()
|
| 105 |
|
| 106 |
# stages
|
| 107 |
detector = AmbiguityDetector()
|
|
|
|
| 28 |
|
| 29 |
provider_id: str
|
| 30 |
|
| 31 |
+
def plan(
|
| 32 |
+
self, *, user_query: str, schema_preview: str
|
| 33 |
+
) -> Tuple[str, int, int, float]: ...
|
| 34 |
|
| 35 |
def generate_sql(
|
| 36 |
self,
|
|
|
|
| 39 |
schema_preview: str,
|
| 40 |
plan_text: str,
|
| 41 |
clarify_answers: Optional[Any] = None,
|
| 42 |
+
) -> Tuple[str, str, int, int, float]: ...
|
|
|
|
| 43 |
|
| 44 |
+
def repair(
|
| 45 |
+
self, *, sql: str, error_msg: str, schema_preview: str
|
| 46 |
+
) -> Tuple[str, int, int, float]: ...
|
| 47 |
|
| 48 |
|
| 49 |
# ---- fallback: Dummy LLM (so it runs without API keys)
|
| 50 |
class DummyLLM:
|
| 51 |
provider_id = "dummy-llm"
|
| 52 |
|
| 53 |
+
def plan(
|
| 54 |
+
self, *, user_query: str, schema_preview: str
|
| 55 |
+
) -> Tuple[str, int, int, float]:
|
| 56 |
text = (
|
| 57 |
f"- understand question: {user_query}\n"
|
| 58 |
"- identify tables\n- join if needed\n- filter\n- order/limit"
|
|
|
|
| 72 |
rationale = "Demo SQL from DummyLLM"
|
| 73 |
return sql, rationale, 0, 0, 0.0
|
| 74 |
|
| 75 |
+
def repair(
|
| 76 |
+
self, *, sql: str, error_msg: str, schema_preview: str
|
| 77 |
+
) -> Tuple[str, int, int, float]:
|
| 78 |
return sql, 0, 0, 0.0
|
| 79 |
|
| 80 |
|
|
|
|
| 106 |
if use_openai and os.getenv("OPENAI_API_KEY"):
|
| 107 |
llm = OpenAIProvider() # conforms to LLMProvider
|
| 108 |
else:
|
| 109 |
+
llm = DummyLLM() # conforms to LLMProvider
|
| 110 |
|
| 111 |
# stages
|
| 112 |
detector = AmbiguityDetector()
|