Melika Kheirieh committed on
Commit
5eeca35
·
1 Parent(s): e207f41

Add more advanced metrics

Browse files
benchmarks/evaluate_spider.py CHANGED
@@ -1,57 +1,274 @@
1
  from __future__ import annotations
2
- import time, json, csv
3
  from pathlib import Path
4
  from tqdm import tqdm
5
 
6
- from app import get_schema_preview, on_generate_query, make_sql_chain
7
  from langchain_community.utilities import SQLDatabase
8
  from benchmarks import load_spider_sqlite
9
 
 
 
10
 
11
  LOG_DIR = Path("logs/spider_eval")
12
  LOG_DIR.mkdir(parents=True, exist_ok=True)
13
 
14
- def run_eval(split="dev", limit=20):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  data = load_spider_sqlite(split)
 
 
16
  data = data[:limit]
17
- print(f"Running eval on {len(data)} examples...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- results = []
20
- for ex in tqdm(data):
21
- db_path = str(ex.db_path)
 
 
 
 
22
 
23
- schema = get_schema_preview(str(ex.db_path), 0)
 
24
 
25
- sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
26
- chain = make_sql_chain(sql_db)
27
 
28
- state = {
29
- "db_path": db_path,
30
- "sql_db": sql_db,
31
- "schema_text": schema,
32
- "chain": chain,
33
- }
34
 
35
- msg, sql, output = on_generate_query(ex.question, 1000, state)
 
 
 
 
 
 
36
 
37
- results.append({
38
- "db_id": ex.db_id,
39
- "question": ex.question,
40
- "gold_sql": ex.gold_sql,
41
- "pred_sql": sql,
42
- "status": msg,
43
- "output": output,
44
- })
 
 
 
 
 
 
 
45
 
46
- time.sleep(0.3)
 
47
 
48
- ts = int(time.time())
49
- out_path = LOG_DIR / f"{split}_results_{ts}.jsonl"
50
- with out_path.open("w", encoding="utf-8") as f:
51
- for r in results:
52
- f.write(json.dumps(r, ensure_ascii=False) + "\n")
53
 
54
- print(f"Wrote results → {out_path}")
55
 
56
  if __name__ == "__main__":
57
- run_eval("train", 20)
 
1
  from __future__ import annotations
2
+ import time, json, subprocess
3
  from pathlib import Path
4
  from tqdm import tqdm
5
 
6
+ from app import get_schema_preview, on_generate_query, make_sql_chain
7
  from langchain_community.utilities import SQLDatabase
8
  from benchmarks import load_spider_sqlite
9
 
10
+ from sqlglot import parse_one, exp
11
+ from sqlglot.errors import ParseError
12
 
13
  LOG_DIR = Path("logs/spider_eval")
14
  LOG_DIR.mkdir(parents=True, exist_ok=True)
15
 
16
def normalize_sql(sql: str) -> str:
    """Return a canonical one-line form of *sql* for textual exact-match checks.

    The normalization is purely lexical (no SQL parsing; a parse-and-regenerate
    approach would be stronger):

    * surrounding whitespace and trailing semicolons are removed;
    * outside quoted spans the text is lower-cased, whitespace runs collapse
      to one space, and every comma is rewritten as ``", "``;
    * quoted spans (``'...'`` and ``"..."``) are copied verbatim, so literals
      that differ only in case (``'France'`` vs ``'france'``) stay distinct.
    """
    import re  # local import keeps this helper self-contained

    def _clean(plain: str) -> str:
        collapsed = re.sub(r"\s+", " ", plain.lower())
        return re.sub(r" ?, ?", ", ", collapsed)

    text = sql.strip().rstrip("; \t\r\n")
    # The capturing group makes re.split keep the quoted spans: even indexes
    # hold plain SQL text, odd indexes hold the quoted spans.
    pieces = re.split(r"""('(?:[^']|'')*'|"(?:[^"]|"")*")""", text)
    pieces[::2] = [_clean(plain) for plain in pieces[::2]]
    return "".join(pieces).strip()
19
+
20
def compare_results(pred_rows, gold_rows) -> bool:
    """Return True when two query results contain the same rows, ignoring order.

    Each argument may be a sequence of rows or the textual form of one (for
    example ``"[(6,)]"``). Text is parsed with ``ast.literal_eval`` so that
    rows, not characters, are compared. Rows are compared as a multiset:
    order is ignored but duplicate rows count.

    Returns False if either argument is None. If a textual result cannot be
    parsed into a list/tuple of rows, the two results are compared as
    stripped strings instead.
    """
    import ast
    from collections import Counter

    if pred_rows is None or gold_rows is None:
        return False

    def _to_rows(value):
        """Convert *value* to a row sequence; None means 'could not parse'."""
        if not isinstance(value, str):
            return value
        text = value.strip()
        if not text:
            return []  # an empty string stands for an empty result
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            return None
        return parsed if isinstance(parsed, (list, tuple)) else None

    pred = _to_rows(pred_rows)
    gold = _to_rows(gold_rows)
    if pred is None or gold is None:
        return str(pred_rows).strip() == str(gold_rows).strip()

    try:
        return Counter(pred) == Counter(gold)
    except TypeError:
        # Unhashable rows (e.g. lists): compare their sorted representations.
        return sorted(map(repr, pred)) == sorted(map(repr, gold))
25
+
26
def try_execute_sql(sql_db, sql, timeout: float | None = None):
    """Run *sql* through ``sql_db.run`` and report the outcome without raising.

    Args:
        sql_db: Object exposing a ``run(sql)`` method.
        sql: SQL text to execute.
        timeout: Accepted for interface compatibility; it is not enforced.

    Returns:
        A ``(rows, elapsed_seconds, error)`` tuple. On success ``rows`` is
        whatever ``run`` returned and ``error`` is None. On failure ``rows``
        is None and ``error`` is ``str(exception)``.
    """
    # perf_counter is monotonic, so durations cannot go negative or jump when
    # the system clock is adjusted.
    start = time.perf_counter()
    try:
        rows = sql_db.run(sql)
    except Exception as exc:  # a failing query is benchmark data, not a crash
        return None, time.perf_counter() - start, str(exc)
    return rows, time.perf_counter() - start, None
33
+
34
def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
    """Return True when the two queries are structurally the same statement.

    Both queries are parsed with sqlglot, ``AS`` aliases on expressions are
    removed, and the trees are regenerated as SQL with unquoted identifiers
    lower-cased. The generated texts are then compared, so keyword case,
    unquoted identifier case, whitespace and projection aliases do not
    affect the outcome.

    The order of projections and ORDER BY keys is significant in SQL, so
    list arguments are compared in their written order.

    Returns False when either query cannot be parsed.
    """
    try:
        ast_pred = parse_one(sql_pred)
        ast_gold = parse_one(sql_gold)
    except Exception:  # unparseable input can never be a structural match
        return False
    if ast_pred is None or ast_gold is None:
        return False

    def canonical_sql(tree: exp.Expression) -> str:
        # Materialize the aliases first: the tree is mutated while we iterate.
        for alias in list(tree.find_all(exp.Alias)):
            if alias.parent is not None:
                alias.replace(alias.this)
        return tree.sql(normalize=True)

    return canonical_sql(ast_pred) == canonical_sql(ast_gold)
56
+
57
def get_git_commit_hash() -> str:
    """Return the output of ``git rev-parse HEAD``, or ``"UNKNOWN"`` on any failure."""
    try:
        raw_head = subprocess.check_output(["git", "rev-parse", "HEAD"])
        return raw_head.strip().decode("ascii")
    except Exception:
        return "UNKNOWN"
63
+
64
# Expression types that must not appear anywhere in an evaluated query.
FORBIDDEN_NODES = (
    exp.Insert,
    exp.Delete,
    exp.Update,
    exp.Drop,
    exp.Alter,
    exp.Attach,
    exp.Pragma,
    exp.Create,
)

# Top-level statement types accepted as read-only queries.
_READ_ONLY_ROOTS = (exp.Select, exp.Union, exp.Intersect, exp.Except)


def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
    """Return True only if *sql* is a single read-only query.

    The check fails closed. It returns False when *sql* is empty or not a
    string, cannot be parsed, contains more than one statement, is not a
    SELECT or a set operation (UNION / INTERSECT / EXCEPT), or contains any
    node listed in ``FORBIDDEN_NODES``.

    Args:
        sql: SQL text to check.
        dialect: Optional sqlglot dialect name used for parsing.
    """
    from sqlglot import parse  # parse() exposes every statement in the string

    if not isinstance(sql, str) or not sql.strip():
        return False
    try:
        statements = [stmt for stmt in parse(sql, read=dialect) if stmt is not None]
    except Exception:  # anything we cannot parse cannot be proven safe
        return False

    # Reject stacked statements such as "SELECT 1; DROP TABLE t".
    if len(statements) != 1:
        return False
    root = statements[0]
    if not isinstance(root, _READ_ONLY_ROOTS):
        return False
    # find() returns the first node of any of the given types, or None.
    return root.find(*FORBIDDEN_NODES) is None
86
+
87
def _one_line(sql) -> str:
    """Collapse *sql* onto one line so each pred/gold line holds one query."""
    return " ".join(str(sql).split())


def _find_resumable_run(split: str, limit: int, commit_hash: str):
    """Return ``(run_id, records)`` for the newest run of *split* if resumable.

    A run is resumable when its header line records the same split, limit and
    commit hash, and it holds fewer than *limit* records. Returns None when
    there is no such run.
    """
    newest = None
    for path in LOG_DIR.glob(f"{split}_results_*.jsonl"):
        suffix = path.stem.rsplit("_", 1)[-1]
        if not suffix.isdigit() or path.stem != f"{split}_results_{suffix}":
            continue
        if newest is None or int(suffix) > newest[0]:
            newest = (int(suffix), path)
    if newest is None:
        return None

    run_id, path = newest
    header, records = None, []
    with path.open("r", encoding="utf-8") as fin:
        for line in fin:
            is_header = line.startswith("#")
            try:
                payload = json.loads(line.lstrip("# ") if is_header else line)
            except json.JSONDecodeError:
                continue  # tolerate blank or partially written lines
            if not isinstance(payload, dict):
                continue
            if is_header:
                header = payload
            else:
                records.append(payload)

    expected = {"split": split, "limit": limit, "commit_hash": commit_hash}
    if header is None or any(header.get(k) != v for k, v in expected.items()):
        return None
    if len(records) >= limit:
        return None  # the previous run already finished
    return run_id, records


def _evaluate_example(ex) -> dict:
    """Generate, safety-check, execute and score one example; return its record."""
    db_path = str(ex.db_path)
    schema = get_schema_preview(db_path, 0)
    sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
    chain = make_sql_chain(sql_db)
    state = {
        "db_path": db_path,
        "sql_db": sql_db,
        "schema_text": schema,
        "chain": chain,
    }

    t0 = time.perf_counter()
    msg, sql, output = on_generate_query(ex.question, 1000, state)
    gen_time = time.perf_counter() - t0

    rec = {
        "db_id": ex.db_id,
        "question": ex.question,
        "gold_sql": ex.gold_sql,
        "pred_sql": sql,
        "status": msg,
        "output": output,
        "gen_time": gen_time,
        "exec_time": None,
        "error": None,
        "gold_error": None,
        "pred_rows": None,
        "gold_rows": None,
        "exact_match": False,
        "exact_match_structural": False,
        "execution_accuracy": False,
        "safe_check_failed": False,
    }

    if not is_safe_sql(sql):
        rec.update(status="rejected_safe_check", error="unsafe_sql", safe_check_failed=True)
        return rec

    pred_rows, exec_time, error = try_execute_sql(sql_db, sql)
    gold_rows, _gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
    rec.update(
        exec_time=exec_time,
        error=error,
        gold_error=gold_error,
        pred_rows=pred_rows,
        gold_rows=gold_rows,
    )

    # An example whose gold query fails cannot be scored; leave metrics False.
    if gold_error is None:
        rec["exact_match"] = normalize_sql(sql) == normalize_sql(ex.gold_sql)
        rec["exact_match_structural"] = exact_match_structural(sql, ex.gold_sql)
        rec["execution_accuracy"] = compare_results(pred_rows, gold_rows)
    return rec


def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
    """Evaluate the text-to-SQL pipeline on a Spider split and log the results.

    For each example the generated SQL is safety-checked, executed next to
    the gold SQL, and scored for exact match, structural match and execution
    accuracy. Records are appended to a JSONL file as they are produced, with
    parallel pred/gold text files (one query per line). Aggregate metrics are
    written to a JSON file at the end.

    Args:
        split: Dataset split passed to ``load_spider_sqlite``.
        limit: Maximum number of examples; None means the whole split.
        resume: When True, continue the newest run of this split that has
            the same limit and commit hash and is still incomplete.
            Examples already recorded there are skipped and included in the
            final metrics. Otherwise a new run is started.
        sleep_time: Seconds to pause after each evaluated example.

    Returns:
        None. Results are written under ``LOG_DIR`` and summarized on stdout.
    """
    data = load_spider_sqlite(split)
    if limit is None or len(data) < limit:
        limit = len(data)
    data = data[:limit]
    print(f"Running eval on {len(data)} examples in split={split}...")

    commit_hash = get_git_commit_hash()

    # File names embed the run id, so a resumed run must reuse the id of the
    # run it continues; a fresh timestamp would never match an existing file.
    previous = _find_resumable_run(split, limit, commit_hash) if resume else None
    if previous is None:
        start_ts, agg = int(time.time()), []
    else:
        start_ts, agg = previous
        print(f"Resuming run {start_ts}: {len(agg)} examples already evaluated")
    done = {(r.get("db_id"), r.get("question")) for r in agg}

    pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
    gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
    results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
    metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"

    write_header = not results_fn.exists()
    with results_fn.open("a", encoding="utf-8") as fout, \
            pred_txt.open("a", encoding="utf-8") as fpred, \
            gold_txt.open("a", encoding="utf-8") as fgold:

        if write_header:
            header = {
                "commit_hash": commit_hash,
                "split": split,
                "limit": limit,
                "start_time": start_ts,
            }
            fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
            fout.flush()

        for ex in tqdm(data):
            if (ex.db_id, ex.question) in done:
                continue

            rec = _evaluate_example(ex)

            # Flush after every example so an interrupted run loses nothing.
            fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
            fpred.write(f"{_one_line(rec['pred_sql'])}\t{ex.db_id}\n")
            fgold.write(f"{_one_line(ex.gold_sql)}\t{ex.db_id}\n")
            for handle in (fout, fpred, fgold):
                handle.flush()
            agg.append(rec)

            if sleep_time > 0:
                time.sleep(sleep_time)

    valid = [
        r for r in agg
        if not r.get("safe_check_failed", False) and r.get("gold_error") is None
    ]
    total_valid = len(valid)
    total_all = len(agg)
    if total_valid == 0:
        print("No valid examples to compute metrics")
        return

    em_count = sum(1 for r in valid if r["exact_match"])
    em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
    exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
    # Count over the same population as the denominator so the rate stays in [0, 1].
    error_count = sum(1 for r in valid if r.get("error") is not None)
    safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
    avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
    avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid

    metrics = {
        "commit_hash": commit_hash,
        "split": split,
        "limit": limit,
        "total_examples": total_all,
        "valid_examples": total_valid,
        "exact_match_rate": em_count / total_valid,
        "exact_match_structural_rate": em_struct_count / total_valid,
        "execution_accuracy_rate": exec_acc_count / total_valid,
        "error_rate": error_count / total_valid,
        "safe_check_fail_rate": safe_fail_count / total_all,
        "avg_gen_time": avg_gen_time,
        "avg_exec_time": avg_exec_time,
        "run_id": start_ts,
    }

    with metrics_fn.open("w", encoding="utf-8") as fm:
        json.dump(metrics, fm, ensure_ascii=False, indent=2)

    print("Metrics:", metrics)
    print(f"Wrote results {results_fn}")
    print(f"Wrote pred file → {pred_txt}")
    print(f"Wrote gold file → {gold_txt}")
    print(f"Wrote metrics {metrics_fn}")
271
 
 
272
 
273
if __name__ == "__main__":
    # Script entry point: a small run on the dev split.
    run_eval(split="dev", limit=10, resume=True, sleep_time=0.05)
logs/spider_eval/dev_gold_1760430884.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ SELECT count(*) FROM singer concert_singer
2
+ SELECT count(*) FROM singer concert_singer
3
+ SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
4
+ SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
5
+ SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
6
+ SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
7
+ SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
8
+ SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
9
+ SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
10
+ SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
logs/spider_eval/dev_metrics_1760430884.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b",
3
+ "split": "dev",
4
+ "limit": 10,
5
+ "total_examples": 10,
6
+ "valid_examples": 10,
7
+ "exact_match_rate": 0.2,
8
+ "exact_match_structural_rate": 0.0,
9
+ "execution_accuracy_rate": 0.8,
10
+ "error_rate": 0.0,
11
+ "safe_check_fail_rate": 0.0,
12
+ "avg_gen_time": 1.4374850749969483,
13
+ "avg_exec_time": 0.0007865667343139648,
14
+ "run_id": 1760430884
15
+ }
logs/spider_eval/dev_pred_1760430884.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ SELECT COUNT(*) AS total_singers FROM singer; concert_singer
2
+ SELECT COUNT(*) AS total_singers FROM singer; concert_singer
3
+ SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
4
+ SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
5
+ SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France' concert_singer
6
+ SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'; concert_singer
7
+ SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
8
+ SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
9
+ SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
10
+ SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
logs/spider_eval/dev_results_1760430884.jsonl ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # {"commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b", "split": "dev", "limit": 10, "start_time": 1760430884}
2
+ {"db_id": "concert_singer", "question": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.2182981967926025, "exec_time": 0.0008916854858398438, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
3
+ {"db_id": "concert_singer", "question": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.261944055557251, "exec_time": 0.00044798851013183594, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
4
+ {"db_id": "concert_singer", "question": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.0276496410369873, "exec_time": 0.0006437301635742188, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
5
+ {"db_id": "concert_singer", "question": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.4854280948638916, "exec_time": 0.0007121562957763672, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
6
+ {"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 1.8432340621948242, "exec_time": 0.0009491443634033203, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
7
+ {"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age for all French singers?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France';", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 2.5529589653015137, "exec_time": 0.0012962818145751953, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
8
+ {"db_id": "concert_singer", "question": "Show the name and the release year of the song by the youngest singer.", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.2428169250488281, "exec_time": 0.0007181167602539062, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
9
+ {"db_id": "concert_singer", "question": "What are the names and release years for all the songs of the youngest singer?", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.4568238258361816, "exec_time": 0.0009098052978515625, "error": null, "gold_error": null, "pred_rows": "[('You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
10
+ {"db_id": "concert_singer", "question": "What are all distinct countries where singers above age 20 are from?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 0.9801719188690186, "exec_time": 0.0007050037384033203, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
11
+ {"db_id": "concert_singer", "question": "What are the different countries with singers above age 20?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 1.3055250644683838, "exec_time": 0.0005917549133300781, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
logs/spider_eval/train_results_1760375755.jsonl DELETED
@@ -1,20 +0,0 @@
1
- {"db_id": "department_management", "question": "How many heads of the departments are older than 56 ?", "gold_sql": "SELECT count(*) FROM head WHERE age > 56", "pred_sql": "SELECT COUNT(*) FROM head WHERE age > 56", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 5\n }\n]"}
2
- {"db_id": "department_management", "question": "List the name, born state and age of the heads of departments ordered by age.", "gold_sql": "SELECT name , born_state , age FROM head ORDER BY age", "pred_sql": "SELECT h.name, h.born_state, h.age FROM head h JOIN management m ON h.head_ID = m.head_ID ORDER BY h.age", "status": "Returned 5 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Dudley Hart\",\n \"born_state\": \"California\",\n \"age\": 52.0\n },\n {\n \"name\": \"Jeff Maggert\",\n \"born_state\": \"Delaware\",\n \"age\": 53.0\n },\n {\n \"name\": \"Franklin Langham\",\n \"born_state\": \"Connecticut\",\n \"age\": 67.0\n },\n {\n \"name\": \"Billy Mayfair\",\n \"born_state\": \"California\",\n \"age\": 69.0\n },\n {\n \"name\": \"K. J. Choi\",\n \"born_state\": \"Alabama\",\n \"age\": 69.0\n }\n]"}
3
- {"db_id": "department_management", "question": "List the creation year, name and budget of each department.", "gold_sql": "SELECT creation , name , budget_in_billions FROM department", "pred_sql": "SELECT Creation, Name, Budget_in_Billions FROM department;", "status": "Returned 15 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1789\",\n \"Name\": \"State\",\n \"Budget_in_Billions\": 9.96\n },\n {\n \"Creation\": \"1789\",\n \"Name\": \"Treasury\",\n \"Budget_in_Billions\": 11.1\n },\n {\n \"Creation\": \"1947\",\n \"Name\": \"Defense\",\n \"Budget_in_Billions\": 439.3\n },\n {\n \"Creation\": \"1870\",\n \"Name\": \"Justice\",\n \"Budget_in_Billions\": 23.4\n },\n {\n \"Creation\": \"1849\",\n \"Name\": \"Interior\",\n \"Budget_in_Billions\": 10.7\n },\n {\n \"Creation\": \"1889\",\n \"Name\": \"Agriculture\",\n \"Budget_in_Billions\": 77.6\n },\n {\n \"Creation\": \"1903\",\n \"Name\": \"Commerce\",\n \"Budget_in_Billions\": 6.2\n },\n {\n \"Creation\": \"1913\",\n \"Name\": \"Labor\",\n \"Budget_in_Billions\": 59.7\n },\n {\n \"Creation\": \"1953\",\n \"Name\": \"Health and Human Services\",\n \"Budget_in_Billions\": 543.2\n },\n {\n \"Creation\": \"1965\",\n \"Name\": \"Housing and Urban Development\",\n \"Budget_in_Billions\": 46.2\n },\n {\n \"Creation\": \"1966\",\n \"Name\": \"Transportation\",\n \"Budget_in_Billions\": 58.0\n },\n {\n \"Creation\": \"1977\",\n \"Name\": \"Energy\",\n \"Budget_in_Billions\": 21.5\n },\n {\n \"Creation\": \"1979\",\n \"Name\": \"Education\",\n \"Budget_in_Billions\": 62.8\n },\n {\n \"Creation\": \"1989\",\n \"Name\": \"Veterans Affairs\",\n \"Budget_in_Billions\": 73.2\n },\n {\n \"Creation\": \"2002\",\n \"Name\": \"Homeland Security\",\n \"Budget_in_Billions\": 44.6\n }\n]"}
4
- {"db_id": "department_management", "question": "What are the maximum and minimum budget of the departments?", "gold_sql": "SELECT max(budget_in_billions) , min(budget_in_billions) FROM department", "pred_sql": "SELECT MAX(Budget_in_Billions) AS Max_Budget, MIN(Budget_in_Billions) AS Min_Budget FROM department", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Max_Budget\": 543.2,\n \"Min_Budget\": 6.2\n }\n]"}
5
- {"db_id": "department_management", "question": "What is the average number of employees of the departments whose rank is between 10 and 15?", "gold_sql": "SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15", "pred_sql": "SELECT AVG(Num_Employees) AS average_num_employees FROM department WHERE Ranking BETWEEN 10 AND 15;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_num_employees\": 105468.16666666667\n }\n]"}
6
- {"db_id": "department_management", "question": "What are the names of the heads who are born outside the California state?", "gold_sql": "SELECT name FROM head WHERE born_state != 'California'", "pred_sql": "SELECT name FROM head WHERE born_state <> 'California';", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Tiger Woods\"\n },\n {\n \"name\": \"K. J. Choi\"\n },\n {\n \"name\": \"Jeff Maggert\"\n },\n {\n \"name\": \"Stewart Cink\"\n },\n {\n \"name\": \"P\\u00e1draig Harrington\"\n },\n {\n \"name\": \"Franklin Langham\"\n }\n]"}
7
- {"db_id": "department_management", "question": "What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?", "gold_sql": "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T3.born_state = 'Alabama'", "pred_sql": "SELECT DISTINCT d.Creation FROM department d JOIN management m ON d.Department_ID = m.department_ID JOIN head h ON m.head_ID = h.head_ID WHERE h.born_state = 'Alabama'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1903\"\n }\n]"}
8
- {"db_id": "department_management", "question": "What are the names of the states where at least 3 heads were born?", "gold_sql": "SELECT born_state FROM head GROUP BY born_state HAVING count(*) >= 3", "pred_sql": "SELECT born_state FROM head GROUP BY born_state HAVING COUNT(head_ID) >= 3", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
9
- {"db_id": "department_management", "question": "In which year were most departments established?", "gold_sql": "SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1", "pred_sql": "SELECT strftime('%Y', Creation) AS Year, COUNT(*) AS DepartmentCount FROM department GROUP BY Year ORDER BY DepartmentCount DESC LIMIT 1;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Year\": \"-4707\",\n \"DepartmentCount\": 12\n }\n]"}
10
- {"db_id": "department_management", "question": "Show the name and number of employees for the departments managed by heads whose temporary acting value is 'Yes'?", "gold_sql": "SELECT T1.name , T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT d.Name, d.Num_Employees FROM department d JOIN management m ON d.Department_ID = m.department_ID WHERE m.temporary_acting = 'Yes'", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n },\n {\n \"Name\": \"Homeland Security\",\n \"Num_Employees\": 208000.0\n },\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n }\n]"}
11
- {"db_id": "department_management", "question": "How many acting statuses are there?", "gold_sql": "SELECT count(DISTINCT temporary_acting) FROM management", "pred_sql": "SELECT COUNT(DISTINCT temporary_acting) AS acting_status_count FROM management;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"acting_status_count\": 2\n }\n]"}
12
- {"db_id": "department_management", "question": "How many departments are led by heads who are not mentioned?", "gold_sql": "SELECT count(*) FROM department WHERE department_id NOT IN (SELECT department_id FROM management);", "pred_sql": "SELECT COUNT(*) FROM department d LEFT JOIN management m ON d.Department_ID = m.department_ID WHERE m.head_ID IS NULL;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 11\n }\n]"}
13
- {"db_id": "department_management", "question": "What are the distinct ages of the heads who are acting?", "gold_sql": "SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT DISTINCT h.age FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE m.temporary_acting = 'Yes';", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"age\": 53.0\n },\n {\n \"age\": 52.0\n },\n {\n \"age\": 69.0\n }\n]"}
14
- {"db_id": "department_management", "question": "List the states where both the secretary of 'Treasury' department and the secretary of 'Homeland Security' were born.", "gold_sql": "SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Treasury' INTERSECT SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Homeland Security'", "pred_sql": "SELECT DISTINCT h1.born_state FROM management m1 JOIN head h1 ON m1.head_ID = h1.head_ID JOIN department d1 ON m1.department_ID = d1.Department_ID WHERE d1.Name = 'Treasury' AND h1.born_state IN (SELECT h2.born_state FROM management m2 JOIN head h2 ON m2.head_ID = h2.head_ID JOIN department d2 ON m2.department_ID = d2.Department_ID WHERE d2.Name = 'Homeland Security')", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
15
- {"db_id": "department_management", "question": "Which department has more than 1 head at a time? List the id, name and the number of heads.", "gold_sql": "SELECT T1.department_id , T1.name , count(*) FROM management AS T2 JOIN department AS T1 ON T1.department_id = T2.department_id GROUP BY T1.department_id HAVING count(*) > 1", "pred_sql": "SELECT d.Department_ID, d.Name, COUNT(m.head_ID) AS Number_of_Heads FROM department d JOIN management m ON d.Department_ID = m.department_ID GROUP BY d.Department_ID, d.Name HAVING COUNT(m.head_ID) > 1", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Department_ID\": 2,\n \"Name\": \"Treasury\",\n \"Number_of_Heads\": 2\n }\n]"}
16
- {"db_id": "department_management", "question": "Which head's name has the substring 'Ha'? List the id and name.", "gold_sql": "SELECT head_id , name FROM head WHERE name LIKE '%Ha%'", "pred_sql": "SELECT h.head_ID, h.name FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE h.name LIKE '%Ha%'", "status": "Returned 2 row(s). Showing up to 50.", "output": "[\n {\n \"head_ID\": 10,\n \"name\": \"Franklin Langham\"\n },\n {\n \"head_ID\": 4,\n \"name\": \"Dudley Hart\"\n }\n]"}
17
- {"db_id": "farm", "question": "How many farms are there?", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204552914852769LyjH3RSk)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
18
- {"db_id": "farm", "question": "Count the number of farms.", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204553544953937LKDeBltm)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
19
- {"db_id": "farm", "question": "List the total number of horses on farms in ascending order.", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204554156810277QpyTydIp)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
20
- {"db_id": "farm", "question": "What is the total horses record for each farm, sorted ascending?", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000116 (request id: 202510132045548459300909IPAkckU)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}