dvwn commited on
Commit
dfa643b
·
1 Parent(s): f160d1e

Update evaluation mode

Browse files

- To evaluate which categories causes the low ex and esm.
- Adding new method for model registry on hf_engine.py file
- Adding models for testing

src/nl2sql/__pycache__/hf_engine.cpython-313.pyc CHANGED
Binary files a/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc and b/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc differ
 
src/nl2sql/hf_engine.py CHANGED
@@ -2,12 +2,24 @@
2
  # This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
3
  import os
4
  from huggingface_hub import InferenceClient
 
5
  from langchain_core.language_models.llms import LLM
6
  from typing import Any, List, Optional
7
 
8
  # Default Model
9
  # DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
10
- DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Custom LangChain wrapper for HuggingFace Inference API
13
  class HFChatWrapper(LLM):
@@ -33,8 +45,9 @@ class HFChatWrapper(LLM):
33
  return "huggingface_inference_client"
34
 
35
  # Initialize the HuggingFace endpoint using the InferenceClient
36
- def get_llm(model_id: str = DEFAULT_MODEL_ID):
37
  """
 
38
  Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
39
  """
40
  # Load HuggingFace API token from environment variable
@@ -42,10 +55,40 @@ def get_llm(model_id: str = DEFAULT_MODEL_ID):
42
  if not hf_token:
43
  raise ValueError("HuggingFace API token not found!")
44
 
 
45
  print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Initialize the HuggingFace InferenceClient
48
- client = InferenceClient(api_key=hf_token)
49
- llm = HFChatWrapper(client=client, model_id=model_id)
 
 
 
 
 
 
50
 
51
- return llm
 
 
 
 
 
 
 
2
  # This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
3
  import os
4
  from huggingface_hub import InferenceClient
5
+ from langchain_huggingface import HuggingFaceEndpoint
6
  from langchain_core.language_models.llms import LLM
7
  from typing import Any, List, Optional
8
 
9
  # Default Model
10
  # DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
11
+ # DEFAULT_MODEL_ID = "defog/sqlcoder-7b-2"
12
+ # DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
13
+ # Model Registry: Add several model to be tested
14
+ MODEL_REGISTRY = {
15
+ "defog/sqlcoder-7b-2": "text",
16
+ "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai": "chat",
17
+ "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai": "chat",
18
+ "defog/llama-3-sqlcoder-8b:featherless-ai": "chat"
19
+ #"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat"
20
+ }
21
+
22
+ ACTIVE_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai"
23
 
24
  # Custom LangChain wrapper for HuggingFace Inference API
25
  class HFChatWrapper(LLM):
 
45
  return "huggingface_inference_client"
46
 
47
  # Initialize the HuggingFace endpoint using the InferenceClient
48
+ def get_llm(model_id: str = ACTIVE_MODEL_ID):
49
  """
50
+ Automatically detects the model type and returns the correct LangChain interface.
51
  Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
52
  """
53
  # Load HuggingFace API token from environment variable
 
55
  if not hf_token:
56
  raise ValueError("HuggingFace API token not found!")
57
 
58
+ model_type = MODEL_REGISTRY.get(model_id, "chat")
59
  print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
60
 
61
+ if model_type == "chat":
62
+ client = InferenceClient(api_key=hf_token)
63
+ return HFChatWrapper(client=client, model_id=model_id)
64
+ elif model_type == "text":
65
+ # Route to standard Text Generation API
66
+ return HuggingFaceEndpoint(
67
+ repo_id=model_id,
68
+ task="text-generation",
69
+ max_new_tokens=512,
70
+ temperature=0.0,
71
+ huggingfacehub_api_token=hf_token,
72
+ do_sample=False,
73
+ return_full_text=False
74
+ )
75
+ else:
76
+ raise ValueError(f"Unknown model type: {model_type}")
77
+
78
  # Initialize the HuggingFace InferenceClient
79
+ #client = InferenceClient(api_key=hf_token)
80
+ #llm = HFChatWrapper(client=client, model_id=model_id)
81
+
82
+ #return llm
83
+
84
+ if __name__=="__main__":
85
+ from dotenv import load_dotenv
86
+ load_dotenv()
87
 
88
+ try:
89
+ test_llm = get_llm()
90
+ print("Model loaded successfully! Running a quick ping...")
91
+ response = test_llm.invoke("write a single SQL statement to count all rows in a table name 'Employee'.")
92
+ print(f"\nResponse:\n{response}")
93
+ except Exception as e:
94
+ print(f"Error during LLM initialization: {e}")
src/scripts/evaluation_mode.py CHANGED
@@ -1,44 +1,75 @@
1
  # Path: src/scripts/evaluation_mode.py
2
  # Evaluation script for Hugging Face SQL generation.
3
  import json
 
4
  from pathlib import Path
5
  import pandas as pd
6
  from src.database.db_manager import get_db_connection
7
  from src.nl2sql.sql_agent import nl2sql_agent
 
8
 
9
  TEST_CASES_PATH = Path("src/scripts/test_cases.json")
10
  RESULTS_PATH = Path("hf_evaluation_results.json")
11
 
12
  def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
13
  # Normalize dataframe to ensure accurate comparison
 
 
 
 
 
14
  normalized = dataframe.copy()
15
- normalized.columns = [str(column).lower() for column in normalized.columns]
16
 
17
  for column in normalized.columns:
18
  normalized[column] = normalized[column].map(
19
  lambda value: round(float(value), 6)
20
- if isinstance(value, float)
21
  else value
22
  )
23
 
24
  sort_columns = list(normalized.columns)
25
  if sort_columns:
26
- normalized = normalized.sort_values(by=sort_columns, kind="mergesort").reset_index(drop=True)
27
 
28
  return normalized
29
 
30
- # Compare generated SQL results with expected results
31
- def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
32
- """Compare generated and expected query results."""
 
 
 
33
  if df_generated is None or df_gold is None:
34
  return False
35
 
36
  try:
37
  normalized_generated = _normalize_dataframe(df_generated)
38
  normalized_gold = _normalize_dataframe(df_gold)
39
- return normalized_generated.equals(normalized_gold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as error:
41
- print(f"Error comparing results: {error}")
42
  return False
43
 
44
  def run_evaluation():
@@ -50,58 +81,67 @@ def run_evaluation():
50
  test_cases = json.load(handle)
51
 
52
  results = []
53
- correct_count = 0
 
54
 
55
  print(f"Running evaluation on {len(test_cases)} test cases...\n")
56
 
57
  for case in test_cases:
58
- question = case["question"]
59
- print(f"Testing ID {case['id']}: {question[:50]}...")
 
 
 
60
 
61
  # Implement agent to handle RAG retrieval and SQL generation
62
  agent_response = nl2sql_agent(user_question=question)
63
  generated_sql = agent_response.get("query", "")
64
 
 
 
 
 
 
 
 
65
  connection = get_db_connection()
66
  if connection is None:
67
  raise RuntimeError("Unable to connect to the SQLite database.")
68
 
69
  try:
70
  df_generated = pd.read_sql_query(generated_sql, connection)
71
- df_gold = pd.read_sql_query(case["gold_sql"], connection)
72
-
73
- is_correct = compare_results(df_generated, df_gold)
74
- if is_correct:
75
- correct_count += 1
76
-
77
- results.append(
78
- {
79
- "id": case["id"],
80
- "question": question,
81
- "status": "PASS" if is_correct else "FAIL",
82
- "generated_sql": generated_sql,
83
- "gold_sql": case["gold_sql"],
84
- }
85
- )
86
  except Exception as error:
87
- results.append(
88
- {
89
- "id": case["id"],
90
- "question": question,
91
- "status": "ERROR",
92
- "generated_sql": generated_sql,
93
- "gold_sql": case["gold_sql"],
94
- "error": str(error),
95
- }
96
- )
97
  finally:
98
  connection.close()
99
 
100
- accuracy = (correct_count / len(test_cases)) * 100 if test_cases else 0.0
101
- print("\nEVALUATION COMPLETE")
102
- print(f"Total Test Cases: {len(test_cases)}")
103
- print(f"Correctly Generated SQL: {correct_count} / {len(test_cases)}")
104
- print(f"Execution Accuracy: {accuracy:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  with RESULTS_PATH.open("w", encoding="utf-8") as handle:
107
- json.dump(results, handle, indent=4)
 
 
 
1
  # Path: src/scripts/evaluation_mode.py
2
  # Evaluation script for Hugging Face SQL generation.
3
  import json
4
+ import sqlglot
5
  from pathlib import Path
6
  import pandas as pd
7
  from src.database.db_manager import get_db_connection
8
  from src.nl2sql.sql_agent import nl2sql_agent
9
+ from src.scripts.taxonomy_report import print_taxonomyReport
10
 
11
  TEST_CASES_PATH = Path("src/scripts/test_cases.json")
12
  RESULTS_PATH = Path("hf_evaluation_results.json")
13
 
14
  def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
15
  # Normalize dataframe to ensure accurate comparison
16
+ """
17
+ Standardize dataframes for Execution Accuracy (EX).
18
+ - Ensures Order Agnoticism by sorting all values.
19
+ - Prepares for Column Agnoticism by focuing on value comparison rather than column names.
20
+ """
21
  normalized = dataframe.copy()
22
+ #normalized.columns = [str(column).lower() for column in normalized.columns]
23
 
24
  for column in normalized.columns:
25
  normalized[column] = normalized[column].map(
26
  lambda value: round(float(value), 6)
27
+ if isinstance(value, (float, int))
28
  else value
29
  )
30
 
31
  sort_columns = list(normalized.columns)
32
  if sort_columns:
33
+ normalized = normalized.sort_values(by=sort_columns).reset_index(drop=True)
34
 
35
  return normalized
36
 
37
+ # EX: Compare generated SQL results with expected results
38
+ def calculate_ex(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
39
+ """
40
+ Execution Accuracy (EX): Compare generated SQL results with expected results.
41
+ - Column Name Agnostic: Use .values to ignore header differences.
42
+ """
43
  if df_generated is None or df_gold is None:
44
  return False
45
 
46
  try:
47
  normalized_generated = _normalize_dataframe(df_generated)
48
  normalized_gold = _normalize_dataframe(df_gold)
49
+
50
+ if normalized_generated.shape != normalized_gold.shape:
51
+ return False
52
+
53
+ return bool((normalized_generated.values == normalized_gold.values).all())
54
+ # return normalized_generated.equals(normalized_gold)
55
+ except Exception as error:
56
+ print(f"EX Evaluation Error: {error}")
57
+ return False
58
+
59
+ def calculate_esm(generated_sql: str, gold_sql: str) -> bool:
60
+ """
61
+ Exact Set Match (ESM): Compare AST structure using sqlglot.
62
+ - Ignores formatting, capitalization, and minor syntactic sugar.
63
+ """
64
+ try:
65
+ # Parse both SQL queries into expressions
66
+ generated_exp = sqlglot.parse_one(generated_sql, read=None)
67
+ gold_exp = sqlglot.parse_one(gold_sql, read=None)
68
+
69
+ # Compare the expressions for structural equivalence
70
+ return generated_exp == gold_exp
71
  except Exception as error:
72
+ print(f"ESM Evaluation Error: {error}")
73
  return False
74
 
75
  def run_evaluation():
 
81
  test_cases = json.load(handle)
82
 
83
  results = []
84
+ ex_count = 0
85
+ esm_count = 0
86
 
87
  print(f"Running evaluation on {len(test_cases)} test cases...\n")
88
 
89
  for case in test_cases:
90
+ id = case.get("id")
91
+ question = case.get("question")
92
+ gold_sql = case.get("gold_sql")
93
+ taxonomy = case.get("taxonomy", "Unknown")
94
+ # print(f"Testing ID {id}: {question[:50]}...")
95
 
96
  # Implement agent to handle RAG retrieval and SQL generation
97
  agent_response = nl2sql_agent(user_question=question)
98
  generated_sql = agent_response.get("query", "")
99
 
100
+ # ESM Evaluation
101
+ esm_result = calculate_esm(generated_sql, gold_sql)
102
+ if esm_result:
103
+ esm_count += 1
104
+
105
+ # EX Evaluation
106
+ ex_result = False
107
  connection = get_db_connection()
108
  if connection is None:
109
  raise RuntimeError("Unable to connect to the SQLite database.")
110
 
111
  try:
112
  df_generated = pd.read_sql_query(generated_sql, connection)
113
+ df_gold = pd.read_sql_query(gold_sql, connection)
114
+
115
+ ex_result = calculate_ex(df_generated, df_gold)
116
+ if ex_result:
117
+ ex_count += 1
 
 
 
 
 
 
 
 
 
 
118
  except Exception as error:
119
+ print(f"Error executing SQL for ID {id}: {error}")
 
 
 
 
 
 
 
 
 
120
  finally:
121
  connection.close()
122
 
123
+ results.append({
124
+ "id": id,
125
+ "question": question,
126
+ "taxonomy": taxonomy,
127
+ "ex_pass": ex_result,
128
+ "esm_pass": esm_result,
129
+ "generated_sql": generated_sql,
130
+ "gold_sql": gold_sql
131
+ })
132
+
133
+ # Summary Statistics
134
+ total = len(test_cases)
135
+ ex_accuracy = (ex_count / total) * 100 if total > 0 else 0
136
+ esm_accuracy = (esm_count / total) * 100 if total > 0 else 0
137
+
138
+ print("\nEVALUATION SUMMARY")
139
+ print("-" * 40)
140
+ print(f"Total Test Cases: {total}")
141
+ print(f"Execution Accuracy (EX): {ex_accuracy:.2f}% ({ex_count}/{total})")
142
+ print(f"Exact Set Match (ESM): {esm_accuracy:.2f}% ({esm_count}/{total})")
143
 
144
  with RESULTS_PATH.open("w", encoding="utf-8") as handle:
145
+ json.dump(results, handle, indent=4)
146
+
147
+ print_taxonomyReport(results)
src/scripts/taxonomy_report.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Path: src/scripts/taxonomy_report.py
2
+ # Generate a taxonomy report to identify which taxonomy tags model struggles with
3
+ import json
4
+ import pandas as pd
5
+ from pathlib import Path
6
+
7
+ def print_taxonomyReport(results_data):
8
+ """
9
+ Generates and prints taxonomy breakdown.
10
+ Accepts either a list of dictionaries (from memory) or reads from the default JSON
11
+ """
12
+ if not results_data:
13
+ results_path = Path("hf_evaluation_results.json")
14
+ if results_path.exists():
15
+ with open(results_path, "r", encoding="utf-8") as f:
16
+ results_data = json.load(f)
17
+ else:
18
+ print("No data provided and results file not found.")
19
+ return
20
+
21
+ if not results_data:
22
+ return
23
+
24
+ df = pd.DataFrame(results_data)
25
+ df['taxonomy'] = df['taxonomy'].fillna("Unknown").astype(str)
26
+ df['taxonomy'] = df['taxonomy'].str.split(', ')
27
+ df_exploded = df.explode('taxonomy')
28
+
29
+ # Calculate Accuract per Taxonomy Tag
30
+ taxonomy_summary = df_exploded.groupby('taxonomy').agg(
31
+ total_cases = ('id', 'count'),
32
+ ex_passed = ('ex_pass', 'sum'),
33
+ esm_passed = ('esm_pass', 'sum')
34
+ )
35
+
36
+ taxonomy_summary['ex_acc'] = (taxonomy_summary['ex_passed'] / taxonomy_summary['total_cases']) * 100
37
+ taxonomy_summary['esm_acc'] = (taxonomy_summary['esm_passed'] / taxonomy_summary['total_cases']) * 100
38
+
39
+ print("\n" + "="*50)
40
+ print("TAXONOMY PERFORMANCE REPORT SUMMARY")
41
+ print("-"*50)
42
+
43
+ # Sort by execution accuracy
44
+ final_report = taxonomy_summary.sort_values(by='ex_acc', ascending=False)
45
+ print(final_report.to_string())
46
+
47
+ # To run the script on its own manually
48
+ if __name__ == "__main__":
49
+ print_taxonomyReport(None)
src/scripts/test_cases.json CHANGED
@@ -1,76 +1,106 @@
1
  [
2
  {
3
  "id": 1,
 
 
4
  "question": "List all the artists name in the database.",
5
  "gold_sql": "SELECT Name FROM Artist;"
6
  },
7
  {
8
  "id": 2,
 
 
9
  "question": "How many genres are there?",
10
  "gold_sql": "SELECT COUNT(*) FROM Genre;"
11
  },
12
  {
13
  "id": 3,
 
 
14
  "question": "List the names of the first 5 tracks.",
15
  "gold_sql": "SELECT Name FROM Track LIMIT 5;"
16
  },
17
  {
18
  "id": 4,
 
 
19
  "question": "Count the number of customers located in the USA.",
20
  "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
21
  },
22
  {
23
  "id": 5,
 
 
24
  "question": "Find all invoices for the customer with ID 1.",
25
  "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
26
  },
27
  {
28
  "id": 6,
 
 
29
  "question": "List each album title along with the artist's name.",
30
  "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
31
  },
32
  {
33
  "id": 7,
 
 
34
  "question": "How many tracks belong to the 'Rock' genre?",
35
  "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
36
  },
37
  {
38
  "id": 8,
 
 
39
  "question": "Show the total revenue generated from each country.",
40
  "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
41
  },
42
  {
43
  "id": 9,
 
 
44
  "question": "Find the total number of items sold for each media type.",
45
  "gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
46
  },
47
  {
48
  "id": 10,
 
 
49
  "question": "List the first and last names of all employees who are Sales Support Agents.",
50
  "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
51
  },
52
  {
53
  "id": 11,
 
 
54
  "question": "List the top 5 customers who have spent the most money in total.",
55
  "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
56
  },
57
  {
58
  "id": 12,
 
 
59
  "question": "Which artist has the most tracks in the database? Give the name and count.",
60
  "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
61
  },
62
  {
63
  "id": 13,
 
 
64
  "question": "Which genres have more than 100 tracks? List the genre name and count.",
65
  "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
66
  },
67
  {
68
  "id": 14,
 
 
69
  "question": "Calculate the average track length in seconds for each genre.",
70
  "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
71
  },
72
  {
73
  "id": 15,
 
 
74
  "question": "Identify the artist who has earned the most revenue from customers in Canada.",
75
  "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
76
  }
 
1
  [
2
  {
3
  "id": 1,
4
+ "difficulty": "easy",
5
+ "taxonomy": "Selection",
6
  "question": "List all the artists name in the database.",
7
  "gold_sql": "SELECT Name FROM Artist;"
8
  },
9
  {
10
  "id": 2,
11
+ "difficulty": "easy",
12
+ "taxonomy": "Aggregation",
13
  "question": "How many genres are there?",
14
  "gold_sql": "SELECT COUNT(*) FROM Genre;"
15
  },
16
  {
17
  "id": 3,
18
+ "difficulty": "easy",
19
+ "taxonomy": "Selection, Limit",
20
  "question": "List the names of the first 5 tracks.",
21
  "gold_sql": "SELECT Name FROM Track LIMIT 5;"
22
  },
23
  {
24
  "id": 4,
25
+ "difficulty": "easy",
26
+ "taxonomy": "Aggregation, Filtering",
27
  "question": "Count the number of customers located in the USA.",
28
  "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
29
  },
30
  {
31
  "id": 5,
32
+ "difficulty": "easy",
33
+ "taxonomy": "Selection, Filtering",
34
  "question": "Find all invoices for the customer with ID 1.",
35
  "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
36
  },
37
  {
38
  "id": 6,
39
+ "difficulty": "medium",
40
+ "taxonomy": "Simple Join",
41
  "question": "List each album title along with the artist's name.",
42
  "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
43
  },
44
  {
45
  "id": 7,
46
+ "difficulty": "medium",
47
+ "taxonomy": "Simple Join, Filtering, Aggregation",
48
  "question": "How many tracks belong to the 'Rock' genre?",
49
  "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
50
  },
51
  {
52
  "id": 8,
53
+ "difficulty": "medium",
54
+ "taxonomy": "Aggregation, Grouping",
55
  "question": "Show the total revenue generated from each country.",
56
  "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
57
  },
58
  {
59
  "id": 9,
60
+ "difficulty": "medium",
61
+ "taxonomy": "Multi-Join, Aggregation, Grouping",
62
  "question": "Find the total number of items sold for each media type.",
63
  "gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
64
  },
65
  {
66
  "id": 10,
67
+ "difficulty": "easy",
68
+ "taxonomy": "Selection, Filtering",
69
  "question": "List the first and last names of all employees who are Sales Support Agents.",
70
  "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
71
  },
72
  {
73
  "id": 11,
74
+ "difficulty": "medium",
75
+ "taxonomy": "Simple Join, Aggregation, Grouping, Ordering, Limit",
76
  "question": "List the top 5 customers who have spent the most money in total.",
77
  "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
78
  },
79
  {
80
  "id": 12,
81
+ "difficulty": "hard",
82
+ "taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
83
  "question": "Which artist has the most tracks in the database? Give the name and count.",
84
  "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
85
  },
86
  {
87
  "id": 13,
88
+ "difficulty": "medium",
89
+ "taxonomy": "Simple Join, Aggregation, Grouping, Having",
90
  "question": "Which genres have more than 100 tracks? List the genre name and count.",
91
  "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
92
  },
93
  {
94
  "id": 14,
95
+ "difficulty": "medium",
96
+ "taxonomy": "Simple Join, Aggregation, Arithmetic, Grouping",
97
  "question": "Calculate the average track length in seconds for each genre.",
98
  "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
99
  },
100
  {
101
  "id": 15,
102
+ "difficulty": "hard",
103
+ "taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
104
  "question": "Identify the artist who has earned the most revenue from customers in Canada.",
105
  "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
106
  }