Melika Kheirieh commited on
Commit
dcc30f0
·
1 Parent(s): eee3f75

style: format code with ruff

Browse files
Files changed (1) hide show
  1. benchmarks/run.py +14 -9
benchmarks/run.py CHANGED
@@ -28,8 +28,9 @@ class LLMProvider(Protocol):
28
 
29
  provider_id: str
30
 
31
- def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
32
- ...
 
33
 
34
  def generate_sql(
35
  self,
@@ -38,18 +39,20 @@ class LLMProvider(Protocol):
38
  schema_preview: str,
39
  plan_text: str,
40
  clarify_answers: Optional[Any] = None,
41
- ) -> Tuple[str, str, int, int, float]:
42
- ...
43
 
44
- def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
45
- ...
 
46
 
47
 
48
  # ---- fallback: Dummy LLM (so it runs without API keys)
49
  class DummyLLM:
50
  provider_id = "dummy-llm"
51
 
52
- def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
 
 
53
  text = (
54
  f"- understand question: {user_query}\n"
55
  "- identify tables\n- join if needed\n- filter\n- order/limit"
@@ -69,7 +72,9 @@ class DummyLLM:
69
  rationale = "Demo SQL from DummyLLM"
70
  return sql, rationale, 0, 0, 0.0
71
 
72
- def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
 
 
73
  return sql, 0, 0, 0.0
74
 
75
 
@@ -101,7 +106,7 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
101
  if use_openai and os.getenv("OPENAI_API_KEY"):
102
  llm = OpenAIProvider() # conforms to LLMProvider
103
  else:
104
- llm = DummyLLM() # conforms to LLMProvider
105
 
106
  # stages
107
  detector = AmbiguityDetector()
 
28
 
29
  provider_id: str
30
 
31
+ def plan(
32
+ self, *, user_query: str, schema_preview: str
33
+ ) -> Tuple[str, int, int, float]: ...
34
 
35
  def generate_sql(
36
  self,
 
39
  schema_preview: str,
40
  plan_text: str,
41
  clarify_answers: Optional[Any] = None,
42
+ ) -> Tuple[str, str, int, int, float]: ...
 
43
 
44
+ def repair(
45
+ self, *, sql: str, error_msg: str, schema_preview: str
46
+ ) -> Tuple[str, int, int, float]: ...
47
 
48
 
49
  # ---- fallback: Dummy LLM (so it runs without API keys)
50
  class DummyLLM:
51
  provider_id = "dummy-llm"
52
 
53
+ def plan(
54
+ self, *, user_query: str, schema_preview: str
55
+ ) -> Tuple[str, int, int, float]:
56
  text = (
57
  f"- understand question: {user_query}\n"
58
  "- identify tables\n- join if needed\n- filter\n- order/limit"
 
72
  rationale = "Demo SQL from DummyLLM"
73
  return sql, rationale, 0, 0, 0.0
74
 
75
+ def repair(
76
+ self, *, sql: str, error_msg: str, schema_preview: str
77
+ ) -> Tuple[str, int, int, float]:
78
  return sql, 0, 0, 0.0
79
 
80
 
 
106
  if use_openai and os.getenv("OPENAI_API_KEY"):
107
  llm = OpenAIProvider() # conforms to LLMProvider
108
  else:
109
+ llm = DummyLLM() # conforms to LLMProvider
110
 
111
  # stages
112
  detector = AmbiguityDetector()