Spaces:
Sleeping
Sleeping
Trial and Error to Improve Accuracy
Browse files · Trying Codex to improve model accuracy
- .env +1 -0
- freeze +0 -0
- hf_evaluation_results.json +109 -0
- hf_test_bench.py +29 -0
- requirements.txt +0 -0
- src/database/__pycache__/db_manager.cpython-313.pyc +0 -0
- src/database/db_manager.py +216 -30
- src/nl2sql/__pycache__/hf_engine.cpython-313.pyc +0 -0
- src/nl2sql/hf_engine.py +98 -0
- src/scripts/__pycache__/evaluate_hf.cpython-313.pyc +0 -0
- src/scripts/evaluate_hf.py +110 -0
- src/scripts/test_cases.json +75 -75
.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
HF_TOKEN='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'
|
freeze
ADDED
|
File without changes
|
hf_evaluation_results.json
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": 1,
|
| 4 |
+
"question": "List all the artists name in the database.",
|
| 5 |
+
"status": "PASS",
|
| 6 |
+
"generated_sql": "SELECT a.Name FROM Artist a;",
|
| 7 |
+
"gold_sql": "SELECT Name FROM Artist;"
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"id": 2,
|
| 11 |
+
"question": "How many genres are there?",
|
| 12 |
+
"status": "FAIL",
|
| 13 |
+
"generated_sql": "SELECT COUNT(DISTINCT g.GenreId) FROM Genre g;",
|
| 14 |
+
"gold_sql": "SELECT COUNT(*) FROM Genre;"
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"id": 3,
|
| 18 |
+
"question": "List the names of the first 5 tracks.",
|
| 19 |
+
"status": "PASS",
|
| 20 |
+
"generated_sql": "SELECT t.Name FROM Track t ORDER BY t.TrackId LIMIT 5;",
|
| 21 |
+
"gold_sql": "SELECT Name FROM Track LIMIT 5;"
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": 4,
|
| 25 |
+
"question": "Count the number of customers located in the USA.",
|
| 26 |
+
"status": "FAIL",
|
| 27 |
+
"generated_sql": "SELECT COUNT(DISTINCT c.CustomerId) FROM Customer c WHERE c.Country = 'USA';",
|
| 28 |
+
"gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"id": 5,
|
| 32 |
+
"question": "Find all invoices for the customer with ID 1.",
|
| 33 |
+
"status": "FAIL",
|
| 34 |
+
"generated_sql": "SELECT i.InvoiceId, i.InvoiceDate, i.Total FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId WHERE c.CustomerId = 1;",
|
| 35 |
+
"gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"id": 6,
|
| 39 |
+
"question": "List each album title along with the artist's name.",
|
| 40 |
+
"status": "PASS",
|
| 41 |
+
"generated_sql": "SELECT a.Title, ar.Name FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId;",
|
| 42 |
+
"gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"id": 7,
|
| 46 |
+
"question": "How many tracks belong to the 'Rock' genre?",
|
| 47 |
+
"status": "FAIL",
|
| 48 |
+
"generated_sql": "SELECT COUNT(t.TrackId) FROM Track t JOIN Genre g ON t.GenreId = g.GenreId WHERE LOWER(g.Name) = 'rock';",
|
| 49 |
+
"gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": 8,
|
| 53 |
+
"question": "Show the total revenue generated from each country.",
|
| 54 |
+
"status": "FAIL",
|
| 55 |
+
"generated_sql": "SELECT i.BillingCountry, SUM(i.Total) AS TotalRevenue FROM Invoice i GROUP BY i.BillingCountry ORDER BY TotalRevenue DESC NULLS LAST;",
|
| 56 |
+
"gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"id": 9,
|
| 60 |
+
"question": "Find the total number of items sold for each media type.",
|
| 61 |
+
"status": "FAIL",
|
| 62 |
+
"generated_sql": "SELECT mt.Name AS MediaType, SUM(il.Quantity) AS TotalQuantity FROM InvoiceLine il JOIN Track t ON il.TrackId = t.TrackId JOIN MediaType mt ON t.MediaTypeId = mt.MediaTypeId GROUP BY mt.Name ORDER BY TotalQuantity DESC NULLS LAST;",
|
| 63 |
+
"gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"id": 10,
|
| 67 |
+
"question": "List the first and last names of all employees who are Sales Support Agents.",
|
| 68 |
+
"status": "PASS",
|
| 69 |
+
"generated_sql": "SELECT e.FirstName, e.LastName FROM Employee e WHERE e.Title = 'Sales Support Agent';",
|
| 70 |
+
"gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"id": 11,
|
| 74 |
+
"question": "List the top 5 customers who have spent the most money in total.",
|
| 75 |
+
"status": "FAIL",
|
| 76 |
+
"generated_sql": "SELECT c.CustomerId, SUM(i.Total) AS total_spent FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 5;",
|
| 77 |
+
"gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"id": 12,
|
| 81 |
+
"question": "Which artist has the most tracks in the database? Give the name and count.",
|
| 82 |
+
"status": "ERROR",
|
| 83 |
+
"generated_sql": "SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;",
|
| 84 |
+
"gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;",
|
| 85 |
+
"error": "Execution failed on sql 'SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;': no such column: a.Name"
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"id": 13,
|
| 89 |
+
"question": "Which genres have more than 100 tracks? List the genre name and count.",
|
| 90 |
+
"status": "FAIL",
|
| 91 |
+
"generated_sql": "SELECT g.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name HAVING COUNT(t.TrackId) > 100 ORDER BY track_count DESC NULLS LAST;",
|
| 92 |
+
"gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"id": 14,
|
| 96 |
+
"question": "Calculate the average track length in seconds for each genre.",
|
| 97 |
+
"status": "FAIL",
|
| 98 |
+
"generated_sql": "SELECT g.Name, AVG(t.Milliseconds) AS average_length FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name ORDER BY average_length NULLS LAST;",
|
| 99 |
+
"gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"id": 15,
|
| 103 |
+
"question": "Identify the artist who has earned the most revenue from customers in Canada.",
|
| 104 |
+
"status": "ERROR",
|
| 105 |
+
"generated_sql": "SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;",
|
| 106 |
+
"gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;",
|
| 107 |
+
"error": "Execution failed on sql 'SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;': no such column: a.Name"
|
| 108 |
+
}
|
| 109 |
+
]
|
hf_test_bench.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test the Hugging Face inference
|
| 2 |
+
from src.nl2sql.hf_engine import generate_sql
|
| 3 |
+
from src.database.db_manager import get_db_connection, get_schema_context
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
def test_single_query():
    """Smoke-test SQL generation end to end for one hard-coded question.

    Builds the schema context for the question, asks ``generate_sql`` for a
    query, runs that query against the database returned by
    ``get_db_connection`` and prints the resulting DataFrame. Any failure is
    caught and printed rather than raised.
    """
    print("Initializing Featherless AI SQL generation test...")
    question = "Identify the artist who has earned the most revenue from customers in Canada."

    connection = None
    try:
        # get_schema_context must be *called*; passing the bare function
        # object would put its repr into the prompt instead of the schema.
        ddl = get_schema_context(question=question)

        generated_sql = generate_sql(question, ddl)
        print(f"\nGenerated SQL:\n{generated_sql}\n")

        # Connect to the database and execute the generated SQL.
        connection = get_db_connection()
        if connection is None:
            # get_db_connection reports failure by returning None.
            raise RuntimeError("Unable to connect to the SQLite database.")
        df = pd.read_sql_query(generated_sql, connection)

        print("\nDatabase Query Result:")
        print(df)
        print("\nTest completed successfully: API connected and SQL is valid.")

    except Exception as e:
        print(f"\nTest failed: {e}")
    finally:
        # Close the connection even when the generated SQL fails to execute.
        if connection is not None:
            connection.close()
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
test_single_query()
|
requirements.txt
CHANGED
|
Binary files a/requirements.txt and b/requirements.txt differ
|
|
|
src/database/__pycache__/db_manager.cpython-313.pyc
CHANGED
|
Binary files a/src/database/__pycache__/db_manager.cpython-313.pyc and b/src/database/__pycache__/db_manager.cpython-313.pyc differ
|
|
|
src/database/db_manager.py
CHANGED
|
@@ -1,48 +1,234 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
import sqlite3
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
# Get the path to the database file
|
| 7 |
-
DB_PATH = os.path.join(os.path.dirname(__file__), 'Chinook_Sqlite.sqlite')
|
| 8 |
|
| 9 |
def get_db_connection():
|
| 10 |
-
"""
|
| 11 |
try:
|
| 12 |
connection = sqlite3.connect(DB_PATH)
|
|
|
|
| 13 |
return connection
|
| 14 |
-
except sqlite3.Error as
|
| 15 |
-
print(f"Error connecting to database: {
|
| 16 |
return None
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
if __name__ == "__main__":
|
| 20 |
connection = get_db_connection()
|
| 21 |
if connection:
|
| 22 |
print("Database connection successful!")
|
| 23 |
cursor = connection.cursor()
|
| 24 |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
| 25 |
-
print("Tables in the database:", cursor.fetchall())
|
| 26 |
connection.close()
|
| 27 |
else:
|
| 28 |
print("Failed to connect to the database.")
|
| 29 |
-
|
| 30 |
-
# Extract Schema Information for LLM Prompts
|
| 31 |
-
def get_schema_context():
|
| 32 |
-
"""Extracts the database schema information to be used in LLM prompts."""
|
| 33 |
-
connection = get_db_connection()
|
| 34 |
-
if not connection:
|
| 35 |
-
return "Unable to connect to the database to retrieve schema information."
|
| 36 |
-
|
| 37 |
-
cursor = connection.cursor()
|
| 38 |
-
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
| 39 |
-
tables = [t[0] for t in cursor.fetchall() if not t[0].startswith('sqlite_')]
|
| 40 |
-
|
| 41 |
-
schema_text = ""
|
| 42 |
-
for table in tables:
|
| 43 |
-
cursor.execute(f"PRAGMA table_info({table});")
|
| 44 |
-
columns = [f"{c[1]} ({c[2]})" for c in cursor.fetchall()]
|
| 45 |
-
schema_text += f"Table {table}: {', '.join(columns)}\n"
|
| 46 |
-
connection.close()
|
| 47 |
-
return schema_text
|
| 48 |
-
|
|
|
|
| 1 |
+
#"""Database helpers for the NL2SQL project."""
|
| 2 |
|
|
|
|
| 3 |
import os
|
| 4 |
+
import re
|
| 5 |
+
import sqlite3
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
DB_PATH = os.path.join(os.path.dirname(__file__), "Chinook_Sqlite.sqlite")
|
| 10 |
+
STOPWORDS = {
|
| 11 |
+
"a",
|
| 12 |
+
"all",
|
| 13 |
+
"an",
|
| 14 |
+
"and",
|
| 15 |
+
"are",
|
| 16 |
+
"as",
|
| 17 |
+
"at",
|
| 18 |
+
"by",
|
| 19 |
+
"count",
|
| 20 |
+
"each",
|
| 21 |
+
"find",
|
| 22 |
+
"for",
|
| 23 |
+
"from",
|
| 24 |
+
"give",
|
| 25 |
+
"has",
|
| 26 |
+
"have",
|
| 27 |
+
"how",
|
| 28 |
+
"in",
|
| 29 |
+
"is",
|
| 30 |
+
"list",
|
| 31 |
+
"many",
|
| 32 |
+
"most",
|
| 33 |
+
"name",
|
| 34 |
+
"names",
|
| 35 |
+
"of",
|
| 36 |
+
"on",
|
| 37 |
+
"show",
|
| 38 |
+
"the",
|
| 39 |
+
"their",
|
| 40 |
+
"there",
|
| 41 |
+
"to",
|
| 42 |
+
"total",
|
| 43 |
+
"what",
|
| 44 |
+
"which",
|
| 45 |
+
"who",
|
| 46 |
+
"with",
|
| 47 |
+
}
|
| 48 |
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def get_db_connection():
    """Open a connection to the SQLite database at ``DB_PATH``.

    Returns:
        A ``sqlite3.Connection`` whose ``row_factory`` is ``sqlite3.Row``, or
        ``None`` when the database file is missing or the connection fails.
        The reason is printed in both failure cases.
    """
    # sqlite3.connect() creates a new, empty database file when the path does
    # not exist, which would surface later as confusing "no such table"
    # errors. Refuse to connect in that case instead.
    if not os.path.isfile(DB_PATH):
        print(f"Error connecting to database: file not found: {DB_PATH}")
        return None

    try:
        connection = sqlite3.connect(DB_PATH)
    except sqlite3.Error as error:
        print(f"Error connecting to database: {error}")
        return None

    # Rows become addressable by column name as well as by index.
    connection.row_factory = sqlite3.Row
    return connection
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _tokenize(text: str) -> set[str]:
    """Return the distinct lowercase alphanumeric words of ``text``.

    Words listed in the module-level ``STOPWORDS`` set are removed.
    """
    words = re.findall(r"[A-Za-z0-9]+", text.lower())
    return set(words) - STOPWORDS
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _quote_identifier(identifier: str) -> str:
|
| 67 |
+
escaped_identifier = identifier.replace('"', '""')
|
| 68 |
+
return f'"{escaped_identifier}"'
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _load_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Dict[str, object]]:
    """Read table, column and foreign-key metadata from a SQLite database.

    Args:
        connection: An open SQLite connection. Its own ``row_factory`` is
            left untouched.

    Returns:
        A dict keyed by table name. Each value has:
          * ``"ddl"``: the table's ``sql`` text from ``sqlite_master``
            (empty string when that value is NULL),
          * ``"columns"``: a list of dicts with ``name``, ``type``
            (``"TEXT"`` when the declared type is empty), ``notnull`` and
            ``pk`` taken from ``PRAGMA table_info``,
          * ``"foreign_keys"``: a list of dicts with ``from``, ``to_table``
            and ``to_column`` taken from ``PRAGMA foreign_key_list``.
    """
    cursor = connection.cursor()
    # Rows are read by column name below. Set the factory on this cursor so
    # the function also works with a connection that still uses the default
    # tuple rows.
    cursor.row_factory = sqlite3.Row

    table_rows = cursor.execute(
        """
        SELECT name, sql
        FROM sqlite_master
        WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
        ORDER BY name
        """
    ).fetchall()

    metadata: Dict[str, Dict[str, object]] = {}
    for row in table_rows:
        table_name = row["name"]
        # PRAGMA arguments cannot be bound as parameters, so the table name
        # is quoted as an identifier before interpolation.
        quoted_table = _quote_identifier(table_name)

        columns = cursor.execute(f"PRAGMA table_info({quoted_table})").fetchall()
        foreign_keys = cursor.execute(f"PRAGMA foreign_key_list({quoted_table})").fetchall()

        metadata[table_name] = {
            "ddl": row["sql"] or "",
            "columns": [
                {
                    "name": column["name"],
                    "type": column["type"] or "TEXT",
                    "notnull": bool(column["notnull"]),
                    "pk": bool(column["pk"]),
                }
                for column in columns
            ],
            "foreign_keys": [
                {
                    "from": foreign_key["from"],
                    "to_table": foreign_key["table"],
                    "to_column": foreign_key["to"],
                }
                for foreign_key in foreign_keys
            ],
        }

    return metadata
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _build_table_summary(table_name: str, table_info: Dict[str, object]) -> str:
|
| 115 |
+
column_parts = []
|
| 116 |
+
for column in table_info["columns"]:
|
| 117 |
+
tags = []
|
| 118 |
+
if column["pk"]:
|
| 119 |
+
tags.append("PK")
|
| 120 |
+
if column["notnull"]:
|
| 121 |
+
tags.append("NOT NULL")
|
| 122 |
+
|
| 123 |
+
tag_suffix = f" [{' '.join(tags)}]" if tags else ""
|
| 124 |
+
column_parts.append(f"{column['name']} {column['type']}{tag_suffix}")
|
| 125 |
+
|
| 126 |
+
summary = f"Table {table_name}: {', '.join(column_parts)}"
|
| 127 |
+
if table_info["foreign_keys"]:
|
| 128 |
+
relationships = ", ".join(
|
| 129 |
+
f"{table_name}.{foreign_key['from']} -> "
|
| 130 |
+
f"{foreign_key['to_table']}.{foreign_key['to_column']}"
|
| 131 |
+
for foreign_key in table_info["foreign_keys"]
|
| 132 |
+
)
|
| 133 |
+
summary = f"{summary}\nRelationships: {relationships}"
|
| 134 |
+
|
| 135 |
+
return summary
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _rank_tables(
    metadata: Dict[str, Dict[str, object]], question: str | None, max_tables: int
) -> List[str]:
    """Choose the tables most relevant to ``question``.

    Args:
        metadata: Table metadata in the shape produced by
            ``_load_schema_metadata``.
        question: Natural-language question. When it is empty, or contains
            only stopwords, every table is returned in ``metadata`` order.
        max_tables: Upper bound on the number of *directly matched* tables.

    Returns:
        The matched tables (best score first, ties broken by name), followed
        by tables linked to them through a foreign key in either direction.
        The linked tables are added on top of ``max_tables`` so join paths
        stay complete; the result can therefore be longer than ``max_tables``.
    """
    table_names = list(metadata.keys())
    if not question:
        return table_names

    question_tokens = _tokenize(question)
    if not question_tokens:
        return table_names

    # Lowercase once; the substring checks below run for every table.
    question_lower = question.lower()

    scored_tables = []
    for table_name, table_info in metadata.items():
        table_lower = table_name.lower()

        column_tokens = set()
        for column in table_info["columns"]:
            column_tokens.update(_tokenize(column["name"]))

        # Table-name token hits weigh twice as much as column-name hits.
        score = 4 * len(question_tokens & _tokenize(table_name))
        score += 2 * len(question_tokens & column_tokens)

        # A table name ending in "s" also earns credit when the name minus
        # that final letter occurs anywhere in the question.
        singular_name = table_lower[:-1] if table_lower.endswith("s") else ""
        if singular_name and singular_name in question_lower:
            score += 2
        if table_lower in question_lower:
            score += 3

        scored_tables.append((score, table_name))

    scored_tables.sort(key=lambda item: (-item[0], item[1]))
    selected = [table_name for score, table_name in scored_tables if score > 0][:max_tables]

    if not selected:
        # Nothing matched: fall back to the first tables in sorted order.
        selected = [table_name for _, table_name in scored_tables[:max_tables]]

    # Pull in directly related tables so the model sees valid join paths.
    expanded = list(selected)
    for table_name in selected:
        for foreign_key in metadata[table_name]["foreign_keys"]:
            related_table = foreign_key["to_table"]
            if related_table in metadata and related_table not in expanded:
                expanded.append(related_table)

    for table_name, table_info in metadata.items():
        for foreign_key in table_info["foreign_keys"]:
            if foreign_key["to_table"] in selected and table_name not in expanded:
                expanded.append(table_name)

    # The previous trailing slice, expanded[: max(max_tables, len(expanded))],
    # could never shorten the list; return it directly.
    return expanded
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_schema_context(question: str | None = None, max_tables: int = 7) -> str:
    """Build the schema text that is placed into the LLM prompt.

    When ``question`` is given, the schema is narrowed to the tables ranked as
    relevant plus their immediate foreign-key neighbours. The result is one
    summary section per table, followed by a ``Join paths:`` section listing
    the foreign-key equalities between the chosen tables. Sections are
    separated by blank lines.

    If no database connection can be opened, a fixed explanatory sentence is
    returned instead of a schema.
    """
    connection = get_db_connection()
    if not connection:
        return "Unable to connect to the database to retrieve schema information."

    try:
        metadata = _load_schema_metadata(connection)
    finally:
        connection.close()

    chosen = _rank_tables(metadata, question, max_tables=max_tables)
    sections = [_build_table_summary(name, metadata[name]) for name in chosen]

    # A set drops duplicate equalities; they are sorted for stable output.
    join_paths = {
        f"{name}.{link['from']} = {link['to_table']}.{link['to_column']}"
        for name in chosen
        for link in metadata[name]["foreign_keys"]
        if link["to_table"] in chosen
    }
    if join_paths:
        sections.append("Join paths:\n" + "\n".join(sorted(join_paths)))

    return "\n\n".join(sections)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
if __name__ == "__main__":
|
| 226 |
connection = get_db_connection()
|
| 227 |
if connection:
|
| 228 |
print("Database connection successful!")
|
| 229 |
cursor = connection.cursor()
|
| 230 |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
| 231 |
+
print("Tables in the database:", [row[0] for row in cursor.fetchall()])
|
| 232 |
connection.close()
|
| 233 |
else:
|
| 234 |
print("Failed to connect to the database.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
src/nl2sql/hf_engine.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#"""Hugging Face inference helpers for SQL generation."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from huggingface_hub import InferenceClient
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 12 |
+
if not hf_token:
|
| 13 |
+
raise ValueError("Token Not Found!")
|
| 14 |
+
|
| 15 |
+
client = InferenceClient(api_key=hf_token)
|
| 16 |
+
MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build_messages(question: str, schema_context: str):
|
| 20 |
+
system_content = (
|
| 21 |
+
"You are an expert SQLite assistant that converts natural language into one "
|
| 22 |
+
"executable SQLite query.\n"
|
| 23 |
+
"Rules:\n"
|
| 24 |
+
"1. Use only tables, columns, and join paths present in the provided schema.\n"
|
| 25 |
+
"2. Generate valid SQLite syntax only.\n"
|
| 26 |
+
"3. Prefer exact column names from the schema, never invent columns.\n"
|
| 27 |
+
"4. Use explicit JOIN conditions when multiple tables are required.\n"
|
| 28 |
+
"5. Use GROUP BY for aggregates by entity, HAVING for aggregate filters, "
|
| 29 |
+
"ORDER BY for ranking, and LIMIT for top-N requests.\n"
|
| 30 |
+
"6. Return SQL only. No markdown, explanations, comments, or chain-of-thought.\n"
|
| 31 |
+
"7. If a join is needed, use short aliases that remain readable.\n"
|
| 32 |
+
"8. Produce a single SELECT statement."
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
user_content = f"""Database schema:
|
| 36 |
+
{schema_context}
|
| 37 |
+
|
| 38 |
+
Question:
|
| 39 |
+
{question}
|
| 40 |
+
|
| 41 |
+
Write the SQLite query that answers the question. Return only the SQL query."""
|
| 42 |
+
|
| 43 |
+
return [
|
| 44 |
+
{"role": "system", "content": system_content},
|
| 45 |
+
{"role": "user", "content": user_content},
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _extract_sql(raw_response: str) -> str:
|
| 50 |
+
text = raw_response.strip()
|
| 51 |
+
fenced_match = re.search(r"```(?:sql)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL)
|
| 52 |
+
if fenced_match:
|
| 53 |
+
text = fenced_match.group(1).strip()
|
| 54 |
+
|
| 55 |
+
statement_match = re.search(
|
| 56 |
+
r"(?is)\b(WITH|SELECT)\b.*?(;|$)",
|
| 57 |
+
text,
|
| 58 |
+
)
|
| 59 |
+
if statement_match:
|
| 60 |
+
text = statement_match.group(0).strip()
|
| 61 |
+
|
| 62 |
+
lines = [
|
| 63 |
+
line.strip()
|
| 64 |
+
for line in text.splitlines()
|
| 65 |
+
if line.strip() and not line.strip().startswith(("--", "#"))
|
| 66 |
+
]
|
| 67 |
+
sql = " ".join(lines).strip()
|
| 68 |
+
if sql and not sql.endswith(";"):
|
| 69 |
+
sql = f"{sql};"
|
| 70 |
+
return sql
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def generate_sql(question, ddl):
    """Ask the hosted model for a SQL query answering ``question``.

    ``ddl`` is the schema text placed into the prompt. The model reply is
    cleaned with ``_extract_sql``; if cleaning leaves nothing, the stripped
    raw reply is returned instead. Failures are never raised: any exception
    is returned as the string ``"Error: <message>"``.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_ID,
            messages=_build_messages(question, ddl),
            max_tokens=220,
            temperature=0,
        )
        reply = response.choices[0].message.content or ""
        extracted = _extract_sql(reply)
        if extracted:
            return extracted
        return reply.strip()
    except Exception as failure:
        return f"Error: {failure}"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
my_ddl = "CREATE TABLE tracks (id INTEGER PRIMARY KEY, title TEXT, genre TEXT);"
|
| 90 |
+
my_question = "How many tracks are there in each genre?"
|
| 91 |
+
|
| 92 |
+
print("Generating SQL query via Featherless AI...")
|
| 93 |
+
try:
|
| 94 |
+
result = generate_sql(my_question, my_ddl)
|
| 95 |
+
print("-" * 20)
|
| 96 |
+
print(result)
|
| 97 |
+
except Exception as error:
|
| 98 |
+
print(f"An error occurred: {error}")
|
src/scripts/__pycache__/evaluate_hf.cpython-313.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
src/scripts/evaluate_hf.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#"""Evaluation script for Hugging Face SQL generation."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from src.database.db_manager import get_db_connection, get_schema_context
|
| 9 |
+
from src.nl2sql.hf_engine import generate_sql
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
|
| 13 |
+
RESULTS_PATH = Path("hf_evaluation_results.json")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
| 17 |
+
normalized = dataframe.copy()
|
| 18 |
+
normalized.columns = [str(column).lower() for column in normalized.columns]
|
| 19 |
+
|
| 20 |
+
for column in normalized.columns:
|
| 21 |
+
normalized[column] = normalized[column].map(
|
| 22 |
+
lambda value: round(float(value), 6)
|
| 23 |
+
if isinstance(value, float)
|
| 24 |
+
else value
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
sort_columns = list(normalized.columns)
|
| 28 |
+
if sort_columns:
|
| 29 |
+
normalized = normalized.sort_values(by=sort_columns, kind="mergesort").reset_index(drop=True)
|
| 30 |
+
|
| 31 |
+
return normalized
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
    """Report whether the generated and gold query results match.

    Both frames are passed through ``_normalize_dataframe`` and compared with
    ``DataFrame.equals``. Returns ``False`` when either frame is ``None`` or
    when normalisation/comparison raises; in the latter case the error is
    printed.
    """
    if any(frame is None for frame in (df_generated, df_gold)):
        return False

    try:
        generated_view, gold_view = (
            _normalize_dataframe(frame) for frame in (df_generated, df_gold)
        )
        return generated_view.equals(gold_view)
    except Exception as error:
        print(f"Error comparing results: {error}")
        return False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_evaluation():
    """Run every test case through the model and record execution accuracy.

    For each case in ``TEST_CASES_PATH`` the function builds the schema
    context, generates SQL, executes both the generated and the gold query,
    and compares their results. A case is recorded as ``PASS``, ``FAIL``, or
    ``ERROR`` (with an ``error`` message). A summary is printed and all
    records are written to ``RESULTS_PATH`` as JSON.

    Raises:
        RuntimeError: if a database connection cannot be opened.
    """
    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        test_cases = json.load(handle)

    results = []
    correct_count = 0

    print(f"Running evaluation on {len(test_cases)} test cases...\n")

    for case in test_cases:
        question = case["question"]
        print(f"Testing ID {case['id']}: {question[:50]}...")

        schema_context = get_schema_context(question=question)
        generated_sql = generate_sql(question, schema_context)

        # generate_sql reports failures as an "Error: ..." string instead of
        # raising. Record it directly; executing it as SQL would only produce
        # a misleading syntax error.
        if generated_sql.startswith("Error:"):
            results.append(
                {
                    "id": case["id"],
                    "question": question,
                    "status": "ERROR",
                    "generated_sql": generated_sql,
                    "gold_sql": case["gold_sql"],
                    "error": generated_sql,
                }
            )
            continue

        connection = get_db_connection()
        if connection is None:
            raise RuntimeError("Unable to connect to the SQLite database.")

        try:
            df_generated = pd.read_sql_query(generated_sql, connection)
            df_gold = pd.read_sql_query(case["gold_sql"], connection)

            is_correct = compare_results(df_generated, df_gold)
            if is_correct:
                correct_count += 1

            results.append(
                {
                    "id": case["id"],
                    "question": question,
                    "status": "PASS" if is_correct else "FAIL",
                    "generated_sql": generated_sql,
                    "gold_sql": case["gold_sql"],
                }
            )
        except Exception as error:
            results.append(
                {
                    "id": case["id"],
                    "question": question,
                    "status": "ERROR",
                    "generated_sql": generated_sql,
                    "gold_sql": case["gold_sql"],
                    "error": str(error),
                }
            )
        finally:
            connection.close()

    # Guard against an empty test file.
    accuracy = (correct_count / len(test_cases)) * 100 if test_cases else 0.0
    print("\nEVALUATION COMPLETE")
    print(f"Total Test Cases: {len(test_cases)}")
    print(f"Correctly Generated SQL: {correct_count} / {len(test_cases)}")
    print(f"Execution Accuracy: {accuracy:.2f}%")

    with RESULTS_PATH.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=4)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
run_evaluation()
|
src/scripts/test_cases.json
CHANGED
|
@@ -1,77 +1,77 @@
|
|
| 1 |
[
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
]
|
|
|
|
| 1 |
[
|
| 2 |
+
{
|
| 3 |
+
"id": 1,
|
| 4 |
+
"question": "List all the artist names in the database.",
|
| 5 |
+
"gold_sql": "SELECT Name FROM Artist;"
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"id": 2,
|
| 9 |
+
"question": "How many genres are there?",
|
| 10 |
+
"gold_sql": "SELECT COUNT(*) FROM Genre;"
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"id": 3,
|
| 14 |
+
"question": "List the names of the first 5 tracks.",
|
| 15 |
+
"gold_sql": "SELECT Name FROM Track LIMIT 5;"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"id": 4,
|
| 19 |
+
"question": "Count the number of customers located in the USA.",
|
| 20 |
+
"gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"id": 5,
|
| 24 |
+
"question": "Find all invoices for the customer with ID 1.",
|
| 25 |
+
"gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"id": 6,
|
| 29 |
+
"question": "List each album title along with the artist's name.",
|
| 30 |
+
"gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"id": 7,
|
| 34 |
+
"question": "How many tracks belong to the 'Rock' genre?",
|
| 35 |
+
"gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"id": 8,
|
| 39 |
+
"question": "Show the total revenue generated from each country.",
|
| 40 |
+
"gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"id": 9,
|
| 44 |
+
"question": "Find the total number of items sold for each media type.",
|
| 45 |
+
"gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"id": 10,
|
| 49 |
+
"question": "List the first and last names of all employees who are Sales Support Agents.",
|
| 50 |
+
"gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"id": 11,
|
| 54 |
+
"question": "List the top 5 customers who have spent the most money in total.",
|
| 55 |
+
"gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"id": 12,
|
| 59 |
+
"question": "Which artist has the most tracks in the database? Give the name and count.",
|
| 60 |
+
"gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"id": 13,
|
| 64 |
+
"question": "Which genres have more than 100 tracks? List the genre name and count.",
|
| 65 |
+
"gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"id": 14,
|
| 69 |
+
"question": "Calculate the average track length in seconds for each genre.",
|
| 70 |
+
"gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"id": 15,
|
| 74 |
+
"question": "Identify the artist who has earned the most revenue from customers in Canada.",
|
| 75 |
+
"gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
|
| 76 |
+
}
|
| 77 |
]
|