Spaces:
Running
Running
Update evaluation mode
Browse files- To evaluate which categories causes the low ex and esm.
- Adding new method for model registry on hf_engine.py file
- Adding models for testing
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc
CHANGED
|
Binary files a/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc and b/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc differ
|
|
|
src/nl2sql/hf_engine.py
CHANGED
|
@@ -2,12 +2,24 @@
|
|
| 2 |
# This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
|
| 3 |
import os
|
| 4 |
from huggingface_hub import InferenceClient
|
|
|
|
| 5 |
from langchain_core.language_models.llms import LLM
|
| 6 |
from typing import Any, List, Optional
|
| 7 |
|
| 8 |
# Default Model
|
| 9 |
# DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
|
| 10 |
-
DEFAULT_MODEL_ID = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Custom LangChain wrapper for HuggingFace Inference API
|
| 13 |
class HFChatWrapper(LLM):
|
|
@@ -33,8 +45,9 @@ class HFChatWrapper(LLM):
|
|
| 33 |
return "huggingface_inference_client"
|
| 34 |
|
| 35 |
# Initialize the HuggingFace endpoint using the InferenceClient
|
| 36 |
-
def get_llm(model_id: str =
|
| 37 |
"""
|
|
|
|
| 38 |
Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
|
| 39 |
"""
|
| 40 |
# Load HuggingFace API token from environment variable
|
|
@@ -42,10 +55,40 @@ def get_llm(model_id: str = DEFAULT_MODEL_ID):
|
|
| 42 |
if not hf_token:
|
| 43 |
raise ValueError("HuggingFace API token not found!")
|
| 44 |
|
|
|
|
| 45 |
print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# Initialize the HuggingFace InferenceClient
|
| 48 |
-
client = InferenceClient(api_key=hf_token)
|
| 49 |
-
llm = HFChatWrapper(client=client, model_id=model_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
# This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
|
| 3 |
import os
|
| 4 |
from huggingface_hub import InferenceClient
|
| 5 |
+
from langchain_huggingface import HuggingFaceEndpoint
|
| 6 |
from langchain_core.language_models.llms import LLM
|
| 7 |
from typing import Any, List, Optional
|
| 8 |
|
| 9 |
# Default Model
|
| 10 |
# DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
|
| 11 |
+
# DEFAULT_MODEL_ID = "defog/sqlcoder-7b-2"
|
| 12 |
+
# DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
|
| 13 |
+
# Model Registry: Add several model to be tested
|
| 14 |
+
MODEL_REGISTRY = {
|
| 15 |
+
"defog/sqlcoder-7b-2": "text",
|
| 16 |
+
"Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai": "chat",
|
| 17 |
+
"Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai": "chat",
|
| 18 |
+
"defog/llama-3-sqlcoder-8b:featherless-ai": "chat"
|
| 19 |
+
#"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat"
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
ACTIVE_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai"
|
| 23 |
|
| 24 |
# Custom LangChain wrapper for HuggingFace Inference API
|
| 25 |
class HFChatWrapper(LLM):
|
|
|
|
| 45 |
return "huggingface_inference_client"
|
| 46 |
|
| 47 |
# Initialize the HuggingFace endpoint using the InferenceClient
|
| 48 |
+
def get_llm(model_id: str = ACTIVE_MODEL_ID):
|
| 49 |
"""
|
| 50 |
+
Automatically detects the model type and returns the correct LangChain interface.
|
| 51 |
Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
|
| 52 |
"""
|
| 53 |
# Load HuggingFace API token from environment variable
|
|
|
|
| 55 |
if not hf_token:
|
| 56 |
raise ValueError("HuggingFace API token not found!")
|
| 57 |
|
| 58 |
+
model_type = MODEL_REGISTRY.get(model_id, "chat")
|
| 59 |
print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
|
| 60 |
|
| 61 |
+
if model_type == "chat":
|
| 62 |
+
client = InferenceClient(api_key=hf_token)
|
| 63 |
+
return HFChatWrapper(client=client, model_id=model_id)
|
| 64 |
+
elif model_type == "text":
|
| 65 |
+
# Route to standard Text Generation API
|
| 66 |
+
return HuggingFaceEndpoint(
|
| 67 |
+
repo_id=model_id,
|
| 68 |
+
task="text-generation",
|
| 69 |
+
max_new_tokens=512,
|
| 70 |
+
temperature=0.0,
|
| 71 |
+
huggingfacehub_api_token=hf_token,
|
| 72 |
+
do_sample=False,
|
| 73 |
+
return_full_text=False
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
| 77 |
+
|
| 78 |
# Initialize the HuggingFace InferenceClient
|
| 79 |
+
#client = InferenceClient(api_key=hf_token)
|
| 80 |
+
#llm = HFChatWrapper(client=client, model_id=model_id)
|
| 81 |
+
|
| 82 |
+
#return llm
|
| 83 |
+
|
| 84 |
+
if __name__=="__main__":
|
| 85 |
+
from dotenv import load_dotenv
|
| 86 |
+
load_dotenv()
|
| 87 |
|
| 88 |
+
try:
|
| 89 |
+
test_llm = get_llm()
|
| 90 |
+
print("Model loaded successfully! Running a quick ping...")
|
| 91 |
+
response = test_llm.invoke("write a single SQL statement to count all rows in a table name 'Employee'.")
|
| 92 |
+
print(f"\nResponse:\n{response}")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print(f"Error during LLM initialization: {e}")
|
src/scripts/evaluation_mode.py
CHANGED
|
@@ -1,44 +1,75 @@
|
|
| 1 |
# Path: src/scripts/evaluation_mode.py
|
| 2 |
# Evaluation script for Hugging Face SQL generation.
|
| 3 |
import json
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
import pandas as pd
|
| 6 |
from src.database.db_manager import get_db_connection
|
| 7 |
from src.nl2sql.sql_agent import nl2sql_agent
|
|
|
|
| 8 |
|
| 9 |
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
|
| 10 |
RESULTS_PATH = Path("hf_evaluation_results.json")
|
| 11 |
|
| 12 |
def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
| 13 |
# Normalize dataframe to ensure accurate comparison
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
normalized = dataframe.copy()
|
| 15 |
-
normalized.columns = [str(column).lower() for column in normalized.columns]
|
| 16 |
|
| 17 |
for column in normalized.columns:
|
| 18 |
normalized[column] = normalized[column].map(
|
| 19 |
lambda value: round(float(value), 6)
|
| 20 |
-
if isinstance(value, float)
|
| 21 |
else value
|
| 22 |
)
|
| 23 |
|
| 24 |
sort_columns = list(normalized.columns)
|
| 25 |
if sort_columns:
|
| 26 |
-
normalized = normalized.sort_values(by=sort_columns
|
| 27 |
|
| 28 |
return normalized
|
| 29 |
|
| 30 |
-
# Compare generated SQL results with expected results
|
| 31 |
-
def
|
| 32 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 33 |
if df_generated is None or df_gold is None:
|
| 34 |
return False
|
| 35 |
|
| 36 |
try:
|
| 37 |
normalized_generated = _normalize_dataframe(df_generated)
|
| 38 |
normalized_gold = _normalize_dataframe(df_gold)
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
except Exception as error:
|
| 41 |
-
print(f"
|
| 42 |
return False
|
| 43 |
|
| 44 |
def run_evaluation():
|
|
@@ -50,58 +81,67 @@ def run_evaluation():
|
|
| 50 |
test_cases = json.load(handle)
|
| 51 |
|
| 52 |
results = []
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
print(f"Running evaluation on {len(test_cases)} test cases...\n")
|
| 56 |
|
| 57 |
for case in test_cases:
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Implement agent to handle RAG retrieval and SQL generation
|
| 62 |
agent_response = nl2sql_agent(user_question=question)
|
| 63 |
generated_sql = agent_response.get("query", "")
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
connection = get_db_connection()
|
| 66 |
if connection is None:
|
| 67 |
raise RuntimeError("Unable to connect to the SQLite database.")
|
| 68 |
|
| 69 |
try:
|
| 70 |
df_generated = pd.read_sql_query(generated_sql, connection)
|
| 71 |
-
df_gold = pd.read_sql_query(
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
if
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
results.append(
|
| 78 |
-
{
|
| 79 |
-
"id": case["id"],
|
| 80 |
-
"question": question,
|
| 81 |
-
"status": "PASS" if is_correct else "FAIL",
|
| 82 |
-
"generated_sql": generated_sql,
|
| 83 |
-
"gold_sql": case["gold_sql"],
|
| 84 |
-
}
|
| 85 |
-
)
|
| 86 |
except Exception as error:
|
| 87 |
-
|
| 88 |
-
{
|
| 89 |
-
"id": case["id"],
|
| 90 |
-
"question": question,
|
| 91 |
-
"status": "ERROR",
|
| 92 |
-
"generated_sql": generated_sql,
|
| 93 |
-
"gold_sql": case["gold_sql"],
|
| 94 |
-
"error": str(error),
|
| 95 |
-
}
|
| 96 |
-
)
|
| 97 |
finally:
|
| 98 |
connection.close()
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
with RESULTS_PATH.open("w", encoding="utf-8") as handle:
|
| 107 |
-
json.dump(results, handle, indent=4)
|
|
|
|
|
|
|
|
|
| 1 |
# Path: src/scripts/evaluation_mode.py
|
| 2 |
# Evaluation script for Hugging Face SQL generation.
|
| 3 |
import json
|
| 4 |
+
import sqlglot
|
| 5 |
from pathlib import Path
|
| 6 |
import pandas as pd
|
| 7 |
from src.database.db_manager import get_db_connection
|
| 8 |
from src.nl2sql.sql_agent import nl2sql_agent
|
| 9 |
+
from src.scripts.taxonomy_report import print_taxonomyReport
|
| 10 |
|
| 11 |
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
|
| 12 |
RESULTS_PATH = Path("hf_evaluation_results.json")
|
| 13 |
|
| 14 |
def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
| 15 |
# Normalize dataframe to ensure accurate comparison
|
| 16 |
+
"""
|
| 17 |
+
Standardize dataframes for Execution Accuracy (EX).
|
| 18 |
+
- Ensures Order Agnoticism by sorting all values.
|
| 19 |
+
- Prepares for Column Agnoticism by focuing on value comparison rather than column names.
|
| 20 |
+
"""
|
| 21 |
normalized = dataframe.copy()
|
| 22 |
+
#normalized.columns = [str(column).lower() for column in normalized.columns]
|
| 23 |
|
| 24 |
for column in normalized.columns:
|
| 25 |
normalized[column] = normalized[column].map(
|
| 26 |
lambda value: round(float(value), 6)
|
| 27 |
+
if isinstance(value, (float, int))
|
| 28 |
else value
|
| 29 |
)
|
| 30 |
|
| 31 |
sort_columns = list(normalized.columns)
|
| 32 |
if sort_columns:
|
| 33 |
+
normalized = normalized.sort_values(by=sort_columns).reset_index(drop=True)
|
| 34 |
|
| 35 |
return normalized
|
| 36 |
|
| 37 |
+
# EX: Compare generated SQL results with expected results
|
| 38 |
+
def calculate_ex(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
|
| 39 |
+
"""
|
| 40 |
+
Execution Accuracy (EX): Compare generated SQL results with expected results.
|
| 41 |
+
- Column Name Agnostic: Use .values to ignore header differences.
|
| 42 |
+
"""
|
| 43 |
if df_generated is None or df_gold is None:
|
| 44 |
return False
|
| 45 |
|
| 46 |
try:
|
| 47 |
normalized_generated = _normalize_dataframe(df_generated)
|
| 48 |
normalized_gold = _normalize_dataframe(df_gold)
|
| 49 |
+
|
| 50 |
+
if normalized_generated.shape != normalized_gold.shape:
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
return bool((normalized_generated.values == normalized_gold.values).all())
|
| 54 |
+
# return normalized_generated.equals(normalized_gold)
|
| 55 |
+
except Exception as error:
|
| 56 |
+
print(f"EX Evaluation Error: {error}")
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
def calculate_esm(generated_sql: str, gold_sql: str) -> bool:
|
| 60 |
+
"""
|
| 61 |
+
Exact Set Match (ESM): Compare AST structure using sqlglot.
|
| 62 |
+
- Ignores formatting, capitalization, and minor syntactic sugar.
|
| 63 |
+
"""
|
| 64 |
+
try:
|
| 65 |
+
# Parse both SQL queries into expressions
|
| 66 |
+
generated_exp = sqlglot.parse_one(generated_sql, read=None)
|
| 67 |
+
gold_exp = sqlglot.parse_one(gold_sql, read=None)
|
| 68 |
+
|
| 69 |
+
# Compare the expressions for structural equivalence
|
| 70 |
+
return generated_exp == gold_exp
|
| 71 |
except Exception as error:
|
| 72 |
+
print(f"ESM Evaluation Error: {error}")
|
| 73 |
return False
|
| 74 |
|
| 75 |
def run_evaluation():
|
|
|
|
| 81 |
test_cases = json.load(handle)
|
| 82 |
|
| 83 |
results = []
|
| 84 |
+
ex_count = 0
|
| 85 |
+
esm_count = 0
|
| 86 |
|
| 87 |
print(f"Running evaluation on {len(test_cases)} test cases...\n")
|
| 88 |
|
| 89 |
for case in test_cases:
|
| 90 |
+
id = case.get("id")
|
| 91 |
+
question = case.get("question")
|
| 92 |
+
gold_sql = case.get("gold_sql")
|
| 93 |
+
taxonomy = case.get("taxonomy", "Unknown")
|
| 94 |
+
# print(f"Testing ID {id}: {question[:50]}...")
|
| 95 |
|
| 96 |
# Implement agent to handle RAG retrieval and SQL generation
|
| 97 |
agent_response = nl2sql_agent(user_question=question)
|
| 98 |
generated_sql = agent_response.get("query", "")
|
| 99 |
|
| 100 |
+
# ESM Evaluation
|
| 101 |
+
esm_result = calculate_esm(generated_sql, gold_sql)
|
| 102 |
+
if esm_result:
|
| 103 |
+
esm_count += 1
|
| 104 |
+
|
| 105 |
+
# EX Evaluation
|
| 106 |
+
ex_result = False
|
| 107 |
connection = get_db_connection()
|
| 108 |
if connection is None:
|
| 109 |
raise RuntimeError("Unable to connect to the SQLite database.")
|
| 110 |
|
| 111 |
try:
|
| 112 |
df_generated = pd.read_sql_query(generated_sql, connection)
|
| 113 |
+
df_gold = pd.read_sql_query(gold_sql, connection)
|
| 114 |
+
|
| 115 |
+
ex_result = calculate_ex(df_generated, df_gold)
|
| 116 |
+
if ex_result:
|
| 117 |
+
ex_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
except Exception as error:
|
| 119 |
+
print(f"Error executing SQL for ID {id}: {error}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
finally:
|
| 121 |
connection.close()
|
| 122 |
|
| 123 |
+
results.append({
|
| 124 |
+
"id": id,
|
| 125 |
+
"question": question,
|
| 126 |
+
"taxonomy": taxonomy,
|
| 127 |
+
"ex_pass": ex_result,
|
| 128 |
+
"esm_pass": esm_result,
|
| 129 |
+
"generated_sql": generated_sql,
|
| 130 |
+
"gold_sql": gold_sql
|
| 131 |
+
})
|
| 132 |
+
|
| 133 |
+
# Summary Statistics
|
| 134 |
+
total = len(test_cases)
|
| 135 |
+
ex_accuracy = (ex_count / total) * 100 if total > 0 else 0
|
| 136 |
+
esm_accuracy = (esm_count / total) * 100 if total > 0 else 0
|
| 137 |
+
|
| 138 |
+
print("\nEVALUATION SUMMARY")
|
| 139 |
+
print("-" * 40)
|
| 140 |
+
print(f"Total Test Cases: {total}")
|
| 141 |
+
print(f"Execution Accuracy (EX): {ex_accuracy:.2f}% ({ex_count}/{total})")
|
| 142 |
+
print(f"Exact Set Match (ESM): {esm_accuracy:.2f}% ({esm_count}/{total})")
|
| 143 |
|
| 144 |
with RESULTS_PATH.open("w", encoding="utf-8") as handle:
|
| 145 |
+
json.dump(results, handle, indent=4)
|
| 146 |
+
|
| 147 |
+
print_taxonomyReport(results)
|
src/scripts/taxonomy_report.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Path: src/scripts/taxonomy_report.py
|
| 2 |
+
# Generate a taxonomy report to identify which taxonomy tags model struggles with
|
| 3 |
+
import json
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
def print_taxonomyReport(results_data):
|
| 8 |
+
"""
|
| 9 |
+
Generates and prints taxonomy breakdown.
|
| 10 |
+
Accepts either a list of dictionaries (from memory) or reads from the default JSON
|
| 11 |
+
"""
|
| 12 |
+
if not results_data:
|
| 13 |
+
results_path = Path("hf_evaluation_results.json")
|
| 14 |
+
if results_path.exists():
|
| 15 |
+
with open(results_path, "r", encoding="utf-8") as f:
|
| 16 |
+
results_data = json.load(f)
|
| 17 |
+
else:
|
| 18 |
+
print("No data provided and results file not found.")
|
| 19 |
+
return
|
| 20 |
+
|
| 21 |
+
if not results_data:
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
df = pd.DataFrame(results_data)
|
| 25 |
+
df['taxonomy'] = df['taxonomy'].fillna("Unknown").astype(str)
|
| 26 |
+
df['taxonomy'] = df['taxonomy'].str.split(', ')
|
| 27 |
+
df_exploded = df.explode('taxonomy')
|
| 28 |
+
|
| 29 |
+
# Calculate Accuract per Taxonomy Tag
|
| 30 |
+
taxonomy_summary = df_exploded.groupby('taxonomy').agg(
|
| 31 |
+
total_cases = ('id', 'count'),
|
| 32 |
+
ex_passed = ('ex_pass', 'sum'),
|
| 33 |
+
esm_passed = ('esm_pass', 'sum')
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
taxonomy_summary['ex_acc'] = (taxonomy_summary['ex_passed'] / taxonomy_summary['total_cases']) * 100
|
| 37 |
+
taxonomy_summary['esm_acc'] = (taxonomy_summary['esm_passed'] / taxonomy_summary['total_cases']) * 100
|
| 38 |
+
|
| 39 |
+
print("\n" + "="*50)
|
| 40 |
+
print("TAXONOMY PERFORMANCE REPORT SUMMARY")
|
| 41 |
+
print("-"*50)
|
| 42 |
+
|
| 43 |
+
# Sort by execution accuracy
|
| 44 |
+
final_report = taxonomy_summary.sort_values(by='ex_acc', ascending=False)
|
| 45 |
+
print(final_report.to_string())
|
| 46 |
+
|
| 47 |
+
# To run the script on its own manually
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
print_taxonomyReport(None)
|
src/scripts/test_cases.json
CHANGED
|
@@ -1,76 +1,106 @@
|
|
| 1 |
[
|
| 2 |
{
|
| 3 |
"id": 1,
|
|
|
|
|
|
|
| 4 |
"question": "List all the artists name in the database.",
|
| 5 |
"gold_sql": "SELECT Name FROM Artist;"
|
| 6 |
},
|
| 7 |
{
|
| 8 |
"id": 2,
|
|
|
|
|
|
|
| 9 |
"question": "How many genres are there?",
|
| 10 |
"gold_sql": "SELECT COUNT(*) FROM Genre;"
|
| 11 |
},
|
| 12 |
{
|
| 13 |
"id": 3,
|
|
|
|
|
|
|
| 14 |
"question": "List the names of the first 5 tracks.",
|
| 15 |
"gold_sql": "SELECT Name FROM Track LIMIT 5;"
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"id": 4,
|
|
|
|
|
|
|
| 19 |
"question": "Count the number of customers located in the USA.",
|
| 20 |
"gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
|
| 21 |
},
|
| 22 |
{
|
| 23 |
"id": 5,
|
|
|
|
|
|
|
| 24 |
"question": "Find all invoices for the customer with ID 1.",
|
| 25 |
"gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
|
| 26 |
},
|
| 27 |
{
|
| 28 |
"id": 6,
|
|
|
|
|
|
|
| 29 |
"question": "List each album title along with the artist's name.",
|
| 30 |
"gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
|
| 31 |
},
|
| 32 |
{
|
| 33 |
"id": 7,
|
|
|
|
|
|
|
| 34 |
"question": "How many tracks belong to the 'Rock' genre?",
|
| 35 |
"gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
|
| 36 |
},
|
| 37 |
{
|
| 38 |
"id": 8,
|
|
|
|
|
|
|
| 39 |
"question": "Show the total revenue generated from each country.",
|
| 40 |
"gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
|
| 41 |
},
|
| 42 |
{
|
| 43 |
"id": 9,
|
|
|
|
|
|
|
| 44 |
"question": "Find the total number of items sold for each media type.",
|
| 45 |
"gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"id": 10,
|
|
|
|
|
|
|
| 49 |
"question": "List the first and last names of all employees who are Sales Support Agents.",
|
| 50 |
"gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"id": 11,
|
|
|
|
|
|
|
| 54 |
"question": "List the top 5 customers who have spent the most money in total.",
|
| 55 |
"gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
|
| 56 |
},
|
| 57 |
{
|
| 58 |
"id": 12,
|
|
|
|
|
|
|
| 59 |
"question": "Which artist has the most tracks in the database? Give the name and count.",
|
| 60 |
"gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
|
| 61 |
},
|
| 62 |
{
|
| 63 |
"id": 13,
|
|
|
|
|
|
|
| 64 |
"question": "Which genres have more than 100 tracks? List the genre name and count.",
|
| 65 |
"gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
|
| 66 |
},
|
| 67 |
{
|
| 68 |
"id": 14,
|
|
|
|
|
|
|
| 69 |
"question": "Calculate the average track length in seconds for each genre.",
|
| 70 |
"gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
|
| 71 |
},
|
| 72 |
{
|
| 73 |
"id": 15,
|
|
|
|
|
|
|
| 74 |
"question": "Identify the artist who has earned the most revenue from customers in Canada.",
|
| 75 |
"gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
|
| 76 |
}
|
|
|
|
| 1 |
[
|
| 2 |
{
|
| 3 |
"id": 1,
|
| 4 |
+
"difficulty": "easy",
|
| 5 |
+
"taxonomy": "Selection",
|
| 6 |
"question": "List all the artists name in the database.",
|
| 7 |
"gold_sql": "SELECT Name FROM Artist;"
|
| 8 |
},
|
| 9 |
{
|
| 10 |
"id": 2,
|
| 11 |
+
"difficulty": "easy",
|
| 12 |
+
"taxonomy": "Aggregation",
|
| 13 |
"question": "How many genres are there?",
|
| 14 |
"gold_sql": "SELECT COUNT(*) FROM Genre;"
|
| 15 |
},
|
| 16 |
{
|
| 17 |
"id": 3,
|
| 18 |
+
"difficulty": "easy",
|
| 19 |
+
"taxonomy": "Selection, Limit",
|
| 20 |
"question": "List the names of the first 5 tracks.",
|
| 21 |
"gold_sql": "SELECT Name FROM Track LIMIT 5;"
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"id": 4,
|
| 25 |
+
"difficulty": "easy",
|
| 26 |
+
"taxonomy": "Aggregation, Filtering",
|
| 27 |
"question": "Count the number of customers located in the USA.",
|
| 28 |
"gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
|
| 29 |
},
|
| 30 |
{
|
| 31 |
"id": 5,
|
| 32 |
+
"difficulty": "easy",
|
| 33 |
+
"taxonomy": "Selection, Filtering",
|
| 34 |
"question": "Find all invoices for the customer with ID 1.",
|
| 35 |
"gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
|
| 36 |
},
|
| 37 |
{
|
| 38 |
"id": 6,
|
| 39 |
+
"difficulty": "medium",
|
| 40 |
+
"taxonomy": "Simple Join",
|
| 41 |
"question": "List each album title along with the artist's name.",
|
| 42 |
"gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"id": 7,
|
| 46 |
+
"difficulty": "medium",
|
| 47 |
+
"taxonomy": "Simple Join, Filtering, Aggregation",
|
| 48 |
"question": "How many tracks belong to the 'Rock' genre?",
|
| 49 |
"gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"id": 8,
|
| 53 |
+
"difficulty": "medium",
|
| 54 |
+
"taxonomy": "Aggregation, Grouping",
|
| 55 |
"question": "Show the total revenue generated from each country.",
|
| 56 |
"gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"id": 9,
|
| 60 |
+
"difficulty": "medium",
|
| 61 |
+
"taxonomy": "Multi-Join, Aggregation, Grouping",
|
| 62 |
"question": "Find the total number of items sold for each media type.",
|
| 63 |
"gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
|
| 64 |
},
|
| 65 |
{
|
| 66 |
"id": 10,
|
| 67 |
+
"difficulty": "easy",
|
| 68 |
+
"taxonomy": "Selection, Filtering",
|
| 69 |
"question": "List the first and last names of all employees who are Sales Support Agents.",
|
| 70 |
"gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
|
| 71 |
},
|
| 72 |
{
|
| 73 |
"id": 11,
|
| 74 |
+
"difficulty": "medium",
|
| 75 |
+
"taxonomy": "Simple Join, Aggregation, Grouping, Ordering, Limit",
|
| 76 |
"question": "List the top 5 customers who have spent the most money in total.",
|
| 77 |
"gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
|
| 78 |
},
|
| 79 |
{
|
| 80 |
"id": 12,
|
| 81 |
+
"difficulty": "hard",
|
| 82 |
+
"taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
|
| 83 |
"question": "Which artist has the most tracks in the database? Give the name and count.",
|
| 84 |
"gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"id": 13,
|
| 88 |
+
"difficulty": "medium",
|
| 89 |
+
"taxonomy": "Simple Join, Aggregation, Grouping, Having",
|
| 90 |
"question": "Which genres have more than 100 tracks? List the genre name and count.",
|
| 91 |
"gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
|
| 92 |
},
|
| 93 |
{
|
| 94 |
"id": 14,
|
| 95 |
+
"difficulty": "medium",
|
| 96 |
+
"taxonomy": "Simple Join, Aggregation, Arithmetic, Grouping",
|
| 97 |
"question": "Calculate the average track length in seconds for each genre.",
|
| 98 |
"gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
|
| 99 |
},
|
| 100 |
{
|
| 101 |
"id": 15,
|
| 102 |
+
"difficulty": "hard",
|
| 103 |
+
"taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
|
| 104 |
"question": "Identify the artist who has earned the most revenue from customers in Canada.",
|
| 105 |
"gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
|
| 106 |
}
|