dvwn committed on
Commit
0731ede
·
1 Parent(s): 97dafec

Trial and Error to Improve Accuracy

Browse files

Trying Codex to improve model accuracy

.env ADDED
@@ -0,0 +1 @@
 
 
1
+ HF_TOKEN='<redacted: credential must not be committed; revoke it and keep .env untracked>'
freeze ADDED
File without changes
hf_evaluation_results.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": 1,
4
+ "question": "List all the artists name in the database.",
5
+ "status": "PASS",
6
+ "generated_sql": "SELECT a.Name FROM Artist a;",
7
+ "gold_sql": "SELECT Name FROM Artist;"
8
+ },
9
+ {
10
+ "id": 2,
11
+ "question": "How many genres are there?",
12
+ "status": "FAIL",
13
+ "generated_sql": "SELECT COUNT(DISTINCT g.GenreId) FROM Genre g;",
14
+ "gold_sql": "SELECT COUNT(*) FROM Genre;"
15
+ },
16
+ {
17
+ "id": 3,
18
+ "question": "List the names of the first 5 tracks.",
19
+ "status": "PASS",
20
+ "generated_sql": "SELECT t.Name FROM Track t ORDER BY t.TrackId LIMIT 5;",
21
+ "gold_sql": "SELECT Name FROM Track LIMIT 5;"
22
+ },
23
+ {
24
+ "id": 4,
25
+ "question": "Count the number of customers located in the USA.",
26
+ "status": "FAIL",
27
+ "generated_sql": "SELECT COUNT(DISTINCT c.CustomerId) FROM Customer c WHERE c.Country = 'USA';",
28
+ "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
29
+ },
30
+ {
31
+ "id": 5,
32
+ "question": "Find all invoices for the customer with ID 1.",
33
+ "status": "FAIL",
34
+ "generated_sql": "SELECT i.InvoiceId, i.InvoiceDate, i.Total FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId WHERE c.CustomerId = 1;",
35
+ "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
36
+ },
37
+ {
38
+ "id": 6,
39
+ "question": "List each album title along with the artist's name.",
40
+ "status": "PASS",
41
+ "generated_sql": "SELECT a.Title, ar.Name FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId;",
42
+ "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
43
+ },
44
+ {
45
+ "id": 7,
46
+ "question": "How many tracks belong to the 'Rock' genre?",
47
+ "status": "FAIL",
48
+ "generated_sql": "SELECT COUNT(t.TrackId) FROM Track t JOIN Genre g ON t.GenreId = g.GenreId WHERE LOWER(g.Name) = 'rock';",
49
+ "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
50
+ },
51
+ {
52
+ "id": 8,
53
+ "question": "Show the total revenue generated from each country.",
54
+ "status": "FAIL",
55
+ "generated_sql": "SELECT i.BillingCountry, SUM(i.Total) AS TotalRevenue FROM Invoice i GROUP BY i.BillingCountry ORDER BY TotalRevenue DESC NULLS LAST;",
56
+ "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
57
+ },
58
+ {
59
+ "id": 9,
60
+ "question": "Find the total number of items sold for each media type.",
61
+ "status": "FAIL",
62
+ "generated_sql": "SELECT mt.Name AS MediaType, SUM(il.Quantity) AS TotalQuantity FROM InvoiceLine il JOIN Track t ON il.TrackId = t.TrackId JOIN MediaType mt ON t.MediaTypeId = mt.MediaTypeId GROUP BY mt.Name ORDER BY TotalQuantity DESC NULLS LAST;",
63
+ "gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
64
+ },
65
+ {
66
+ "id": 10,
67
+ "question": "List the first and last names of all employees who are Sales Support Agents.",
68
+ "status": "PASS",
69
+ "generated_sql": "SELECT e.FirstName, e.LastName FROM Employee e WHERE e.Title = 'Sales Support Agent';",
70
+ "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
71
+ },
72
+ {
73
+ "id": 11,
74
+ "question": "List the top 5 customers who have spent the most money in total.",
75
+ "status": "FAIL",
76
+ "generated_sql": "SELECT c.CustomerId, SUM(i.Total) AS total_spent FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 5;",
77
+ "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
78
+ },
79
+ {
80
+ "id": 12,
81
+ "question": "Which artist has the most tracks in the database? Give the name and count.",
82
+ "status": "ERROR",
83
+ "generated_sql": "SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;",
84
+ "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;",
85
+ "error": "Execution failed on sql 'SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;': no such column: a.Name"
86
+ },
87
+ {
88
+ "id": 13,
89
+ "question": "Which genres have more than 100 tracks? List the genre name and count.",
90
+ "status": "FAIL",
91
+ "generated_sql": "SELECT g.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name HAVING COUNT(t.TrackId) > 100 ORDER BY track_count DESC NULLS LAST;",
92
+ "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
93
+ },
94
+ {
95
+ "id": 14,
96
+ "question": "Calculate the average track length in seconds for each genre.",
97
+ "status": "FAIL",
98
+ "generated_sql": "SELECT g.Name, AVG(t.Milliseconds) AS average_length FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name ORDER BY average_length NULLS LAST;",
99
+ "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
100
+ },
101
+ {
102
+ "id": 15,
103
+ "question": "Identify the artist who has earned the most revenue from customers in Canada.",
104
+ "status": "ERROR",
105
+ "generated_sql": "SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;",
106
+ "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;",
107
+ "error": "Execution failed on sql 'SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;': no such column: a.Name"
108
+ }
109
+ ]
hf_test_bench.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Test the Hugging Face inference
2
+ from src.nl2sql.hf_engine import generate_sql
3
+ from src.database.db_manager import get_db_connection, get_schema_context
4
+ import pandas as pd
5
+
6
def test_single_query():
    """Smoke-test the text-to-SQL pipeline with one hard-coded question.

    Builds the schema context for the question, asks ``generate_sql`` for a
    query, executes that query against the database returned by
    ``get_db_connection`` and prints the resulting DataFrame. Any failure is
    reported on stdout instead of being raised.
    """
    print("Initializing Featherless AI SQL generation test...")
    question = "Identify the artist who has earned the most revenue from customers in Canada."

    try:
        # Fetch the database schema context (ddl) from Chinook.
        # The helper must be *called*: passing the bare function object would
        # put its repr into the prompt instead of the schema text.
        ddl = get_schema_context(question=question)

        generated_sql = generate_sql(question, ddl)
        print(f"\nGenerated SQL:\n{generated_sql}\n")

        # Connect to the database and execute the generated SQL
        connection = get_db_connection()
        if connection is None:
            raise RuntimeError("Unable to connect to the SQLite database.")
        try:
            df = pd.read_sql_query(generated_sql, connection)
        finally:
            # Release the handle even when the generated SQL is invalid.
            connection.close()

        print("\nDatabase Query Result:")
        print(df)
        print("\nTest completed successfully: API connected and SQL is valid.")

    except Exception as e:
        print(f"\nTest failed: {e}")


if __name__ == "__main__":
    test_single_query()
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/database/__pycache__/db_manager.cpython-313.pyc CHANGED
Binary files a/src/database/__pycache__/db_manager.cpython-313.pyc and b/src/database/__pycache__/db_manager.cpython-313.pyc differ
 
src/database/db_manager.py CHANGED
@@ -1,48 +1,234 @@
1
- # This module provides a function to establish a connection to the SQLite database used in the NL2SQL project. It also includes a test block to verify the connection and list the tables in the database.
2
 
3
- import sqlite3
4
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Get the path to the database file
7
- DB_PATH = os.path.join(os.path.dirname(__file__), 'Chinook_Sqlite.sqlite')
8
 
9
  def get_db_connection():
10
- """Establishes a connection to the SQLite database."""
11
  try:
12
  connection = sqlite3.connect(DB_PATH)
 
13
  return connection
14
- except sqlite3.Error as e:
15
- print(f"Error connecting to database: {e}")
16
  return None
17
-
18
- # Test the database connection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  if __name__ == "__main__":
20
  connection = get_db_connection()
21
  if connection:
22
  print("Database connection successful!")
23
  cursor = connection.cursor()
24
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
25
- print("Tables in the database:", cursor.fetchall())
26
  connection.close()
27
  else:
28
  print("Failed to connect to the database.")
29
-
30
- # Extract Schema Information for LLM Prompts
31
- def get_schema_context():
32
- """Extracts the database schema information to be used in LLM prompts."""
33
- connection = get_db_connection()
34
- if not connection:
35
- return "Unable to connect to the database to retrieve schema information."
36
-
37
- cursor = connection.cursor()
38
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
39
- tables = [t[0] for t in cursor.fetchall() if not t[0].startswith('sqlite_')]
40
-
41
- schema_text = ""
42
- for table in tables:
43
- cursor.execute(f"PRAGMA table_info({table});")
44
- columns = [f"{c[1]} ({c[2]})" for c in cursor.fetchall()]
45
- schema_text += f"Table {table}: {', '.join(columns)}\n"
46
- connection.close()
47
- return schema_text
48
-
 
1
+ #"""Database helpers for the NL2SQL project."""
2
 
 
3
  import os
4
+ import re
5
+ import sqlite3
6
+ from typing import Dict, List
7
+
8
+
9
# Location of the Chinook SQLite file, resolved next to this module.
DB_PATH = os.path.join(os.path.dirname(__file__), "Chinook_Sqlite.sqlite")

# Words removed by _tokenize before token overlap is computed.
STOPWORDS = set(
    """
    a all an and are as at by count each find for from give has have how in
    is list many most name names of on show the their there to total what
    which who with
    """.split()
)
48
 
 
 
49
 
50
def get_db_connection():
    """Open a connection to the SQLite database at ``DB_PATH``.

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects, or
        ``None`` when the database file is missing or the connection fails
        (the reason is printed).
    """
    # sqlite3.connect() creates an empty database when the file is absent,
    # which would surface much later as confusing "no such table" errors.
    if not os.path.isfile(DB_PATH):
        print(f"Error connecting to database: file not found: {DB_PATH}")
        return None

    try:
        connection = sqlite3.connect(DB_PATH)
        connection.row_factory = sqlite3.Row
        return connection
    except sqlite3.Error as error:
        print(f"Error connecting to database: {error}")
        return None
59
+
60
+
61
def _tokenize(text: str) -> set[str]:
    """Return the lower-cased alphanumeric words of ``text`` minus STOPWORDS."""
    words = set(re.findall(r"[A-Za-z0-9]+", text.lower()))
    return words - STOPWORDS
64
+
65
+
66
+ def _quote_identifier(identifier: str) -> str:
67
+ escaped_identifier = identifier.replace('"', '""')
68
+ return f'"{escaped_identifier}"'
69
+
70
+
71
def _load_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Dict[str, object]]:
    """Collect DDL, column and foreign-key details for every user table.

    Args:
        connection: An open SQLite connection. Its ``row_factory`` setting
            does not matter; rows are read by name through a cursor-level
            ``sqlite3.Row`` factory.

    Returns:
        A mapping of table name to a dict with the keys ``"ddl"`` (the
        ``sqlite_master.sql`` text, or ``""``), ``"columns"`` (dicts with
        ``name``, ``type``, ``notnull``, ``pk``) and ``"foreign_keys"``
        (dicts with ``from``, ``to_table``, ``to_column``).
    """
    cursor = connection.cursor()
    # The lookups below index rows by column name; do not rely on the caller
    # having configured the connection with sqlite3.Row.
    cursor.row_factory = sqlite3.Row

    tables = cursor.execute(
        """
        SELECT name, sql
        FROM sqlite_master
        WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
        ORDER BY name
        """
    ).fetchall()

    metadata: Dict[str, Dict[str, object]] = {}
    for row in tables:
        table_name = row["name"]
        # PRAGMA arguments cannot be bound as parameters, so quote the name.
        quoted_table = _quote_identifier(table_name)

        columns = cursor.execute(f"PRAGMA table_info({quoted_table})").fetchall()
        foreign_keys = cursor.execute(f"PRAGMA foreign_key_list({quoted_table})").fetchall()

        metadata[table_name] = {
            "ddl": row["sql"] or "",
            "columns": [
                {
                    "name": column["name"],
                    # Columns declared without a type are reported as TEXT.
                    "type": column["type"] or "TEXT",
                    "notnull": bool(column["notnull"]),
                    "pk": bool(column["pk"]),
                }
                for column in columns
            ],
            "foreign_keys": [
                {
                    "from": foreign_key["from"],
                    "to_table": foreign_key["table"],
                    "to_column": foreign_key["to"],
                }
                for foreign_key in foreign_keys
            ],
        }

    return metadata
112
+
113
+
114
+ def _build_table_summary(table_name: str, table_info: Dict[str, object]) -> str:
115
+ column_parts = []
116
+ for column in table_info["columns"]:
117
+ tags = []
118
+ if column["pk"]:
119
+ tags.append("PK")
120
+ if column["notnull"]:
121
+ tags.append("NOT NULL")
122
+
123
+ tag_suffix = f" [{' '.join(tags)}]" if tags else ""
124
+ column_parts.append(f"{column['name']} {column['type']}{tag_suffix}")
125
+
126
+ summary = f"Table {table_name}: {', '.join(column_parts)}"
127
+ if table_info["foreign_keys"]:
128
+ relationships = ", ".join(
129
+ f"{table_name}.{foreign_key['from']} -> "
130
+ f"{foreign_key['to_table']}.{foreign_key['to_column']}"
131
+ for foreign_key in table_info["foreign_keys"]
132
+ )
133
+ summary = f"{summary}\nRelationships: {relationships}"
134
+
135
+ return summary
136
+
137
+
138
def _rank_tables(
    metadata: Dict[str, Dict[str, object]], question: str | None, max_tables: int
) -> List[str]:
    """Choose the tables to show for ``question``.

    With no question, or a question made only of stopwords, every table is
    returned in ``metadata`` order. Otherwise tables are scored by token
    overlap with the question (table name x4, column names x2) plus substring
    bonuses, the top ``max_tables`` positive scorers are kept, and tables
    linked to them by a foreign key in either direction are appended. The
    result may therefore be longer than ``max_tables``.
    """
    if not question:
        return list(metadata.keys())

    wanted = _tokenize(question)
    if not wanted:
        return list(metadata.keys())

    ranking = []
    for name, info in metadata.items():
        column_words = set()
        for column in info["columns"]:
            column_words |= _tokenize(column["name"])

        points = 4 * len(wanted & _tokenize(name)) + 2 * len(wanted & column_words)

        # Bonus when a plural table name appears in the question in singular form.
        if name.lower().endswith("s"):
            singular = name[:-1].lower()
            if singular and singular in question.lower():
                points += 2
        # Bonus when the table name occurs anywhere in the question text.
        if name.lower() in question.lower():
            points += 3

        ranking.append((points, name))

    # Highest score first; ties broken alphabetically by table name.
    ranking = sorted(ranking, key=lambda entry: (-entry[0], entry[1]))

    selected = [name for points, name in ranking if points > 0][:max_tables]
    if not selected:
        selected = [name for _, name in ranking[:max_tables]]

    # Pull in directly related tables so the model sees valid join paths.
    expanded = list(selected)
    for name in selected:
        for link in metadata[name]["foreign_keys"]:
            target = link["to_table"]
            if target in metadata and target not in expanded:
                expanded.append(target)

    # Also add tables whose foreign keys point at a selected table.
    for name, info in metadata.items():
        for link in info["foreign_keys"]:
            if link["to_table"] in selected and name not in expanded:
                expanded.append(name)

    # The original slice bound was never below len(expanded), so nothing is cut.
    return list(expanded)
188
+
189
+
190
def get_schema_context(question: str | None = None, max_tables: int = 7) -> str:
    """Build the schema text that is placed in the model prompt.

    Args:
        question: Optional natural-language question. When given, the tables
            are narrowed by ``_rank_tables``; otherwise all tables are used.
        max_tables: Forwarded to ``_rank_tables``.

    Returns:
        One summary per selected table, followed by a ``Join paths:`` section
        when selected tables reference each other, separated by blank lines.
        If no connection can be opened, a fixed error sentence is returned.
    """
    db = get_db_connection()
    if not db:
        return "Unable to connect to the database to retrieve schema information."

    try:
        metadata = _load_schema_metadata(db)
    finally:
        db.close()

    chosen = _rank_tables(metadata, question, max_tables=max_tables)
    sections = [_build_table_summary(name, metadata[name]) for name in chosen]

    # Only list joins whose two ends are both in the prompt.
    join_paths = {
        f"{name}.{link['from']} = {link['to_table']}.{link['to_column']}"
        for name in chosen
        for link in metadata[name]["foreign_keys"]
        if link["to_table"] in chosen
    }
    if join_paths:
        sections.append("Join paths:\n" + "\n".join(sorted(join_paths)))

    return "\n\n".join(sections)


if __name__ == "__main__":
    # Manual check: connect and list the tables found in the database.
    db = get_db_connection()
    if not db:
        print("Failed to connect to the database.")
    else:
        print("Database connection successful!")
        listing = db.cursor()
        listing.execute("SELECT name FROM sqlite_master WHERE type='table';")
        print("Tables in the database:", [row[0] for row in listing.fetchall()])
        db.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc ADDED
Binary file (4.4 kB). View file
 
src/nl2sql/hf_engine.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #"""Hugging Face inference helpers for SQL generation."""
2
+
3
+ import os
4
+ import re
5
+
6
+ from dotenv import load_dotenv
7
+ from huggingface_hub import InferenceClient
8
+
9
+
10
# Read HF_TOKEN from the environment (after loading a .env file, if any) and
# stop at import time when it is missing or empty.
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("Token Not Found!")

# Model identifier sent with every chat-completion request.
MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
client = InferenceClient(api_key=hf_token)
17
+
18
+
19
+ def _build_messages(question: str, schema_context: str):
20
+ system_content = (
21
+ "You are an expert SQLite assistant that converts natural language into one "
22
+ "executable SQLite query.\n"
23
+ "Rules:\n"
24
+ "1. Use only tables, columns, and join paths present in the provided schema.\n"
25
+ "2. Generate valid SQLite syntax only.\n"
26
+ "3. Prefer exact column names from the schema, never invent columns.\n"
27
+ "4. Use explicit JOIN conditions when multiple tables are required.\n"
28
+ "5. Use GROUP BY for aggregates by entity, HAVING for aggregate filters, "
29
+ "ORDER BY for ranking, and LIMIT for top-N requests.\n"
30
+ "6. Return SQL only. No markdown, explanations, comments, or chain-of-thought.\n"
31
+ "7. If a join is needed, use short aliases that remain readable.\n"
32
+ "8. Produce a single SELECT statement."
33
+ )
34
+
35
+ user_content = f"""Database schema:
36
+ {schema_context}
37
+
38
+ Question:
39
+ {question}
40
+
41
+ Write the SQLite query that answers the question. Return only the SQL query."""
42
+
43
+ return [
44
+ {"role": "system", "content": system_content},
45
+ {"role": "user", "content": user_content},
46
+ ]
47
+
48
+
49
+ def _extract_sql(raw_response: str) -> str:
50
+ text = raw_response.strip()
51
+ fenced_match = re.search(r"```(?:sql)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL)
52
+ if fenced_match:
53
+ text = fenced_match.group(1).strip()
54
+
55
+ statement_match = re.search(
56
+ r"(?is)\b(WITH|SELECT)\b.*?(;|$)",
57
+ text,
58
+ )
59
+ if statement_match:
60
+ text = statement_match.group(0).strip()
61
+
62
+ lines = [
63
+ line.strip()
64
+ for line in text.splitlines()
65
+ if line.strip() and not line.strip().startswith(("--", "#"))
66
+ ]
67
+ sql = " ".join(lines).strip()
68
+ if sql and not sql.endswith(";"):
69
+ sql = f"{sql};"
70
+ return sql
71
+
72
+
73
def generate_sql(question, ddl):
    """Ask the hosted model for a SQL query answering ``question``.

    Args:
        question: Natural-language question.
        ddl: Schema text placed in the prompt.

    Returns:
        The extracted SQL; the stripped raw reply when extraction yields
        nothing; or the string ``"Error: <message>"`` when anything raises.
    """
    try:
        request = {
            "model": MODEL_ID,
            "messages": _build_messages(question, ddl),
            "max_tokens": 220,
            "temperature": 0,
        }
        completion = client.chat.completions.create(**request)

        reply = completion.choices[0].message.content
        if not reply:
            reply = ""

        extracted = _extract_sql(reply)
        if extracted:
            return extracted
        return reply.strip()
    except Exception as error:
        # Failures are reported through the return value, not raised.
        return f"Error: {error}"


if __name__ == "__main__":
    # Manual check against a one-table toy schema.
    my_question = "How many tracks are there in each genre?"
    my_ddl = "CREATE TABLE tracks (id INTEGER PRIMARY KEY, title TEXT, genre TEXT);"

    print("Generating SQL query via Featherless AI...")
    try:
        result = generate_sql(my_question, my_ddl)
        print("-" * 20)
        print(result)
    except Exception as error:
        print(f"An error occurred: {error}")
src/scripts/__pycache__/evaluate_hf.cpython-313.pyc ADDED
Binary file (4.92 kB). View file
 
src/scripts/evaluate_hf.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #"""Evaluation script for Hugging Face SQL generation."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import pandas as pd
7
+
8
+ from src.database.db_manager import get_db_connection, get_schema_context
9
+ from src.nl2sql.hf_engine import generate_sql
10
+
11
+
12
# Both paths are relative, so they resolve against the current working directory.
TEST_CASES_PATH = Path("src") / "scripts" / "test_cases.json"
RESULTS_PATH = Path("hf_evaluation_results.json")
14
+
15
+
16
+ def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
17
+ normalized = dataframe.copy()
18
+ normalized.columns = [str(column).lower() for column in normalized.columns]
19
+
20
+ for column in normalized.columns:
21
+ normalized[column] = normalized[column].map(
22
+ lambda value: round(float(value), 6)
23
+ if isinstance(value, float)
24
+ else value
25
+ )
26
+
27
+ sort_columns = list(normalized.columns)
28
+ if sort_columns:
29
+ normalized = normalized.sort_values(by=sort_columns, kind="mergesort").reset_index(drop=True)
30
+
31
+ return normalized
32
+
33
+
34
def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
    """Report whether two query results are equal after normalisation.

    Returns ``False`` when either argument is ``None`` or when normalising or
    comparing raises; in the latter case the error is printed.
    """
    if df_generated is None:
        return False
    if df_gold is None:
        return False

    try:
        candidate = _normalize_dataframe(df_generated)
        expected = _normalize_dataframe(df_gold)
    except Exception as error:
        print(f"Error comparing results: {error}")
        return False

    try:
        return candidate.equals(expected)
    except Exception as error:
        print(f"Error comparing results: {error}")
        return False
46
+
47
+
48
def _result_record(case, question, generated_sql, status):
    """Build one result entry with the key order used in the results file."""
    return {
        "id": case["id"],
        "question": question,
        "status": status,
        "generated_sql": generated_sql,
        "gold_sql": case["gold_sql"],
    }


def run_evaluation():
    """Run every test case and write the per-case outcomes to RESULTS_PATH.

    For each case the schema context is built, SQL is generated, and both the
    generated and gold queries are executed on a fresh connection. A case is
    PASS when ``compare_results`` accepts the two results, FAIL when it does
    not, and ERROR (with the message stored under ``"error"``) when executing
    or comparing raises. A summary is printed at the end.

    Raises:
        RuntimeError: If a database connection cannot be opened.
    """
    test_cases = json.loads(TEST_CASES_PATH.read_text(encoding="utf-8"))
    total = len(test_cases)

    results = []
    correct_count = 0

    print(f"Running evaluation on {total} test cases...\n")

    for case in test_cases:
        question = case["question"]
        print(f"Testing ID {case['id']}: {question[:50]}...")

        generated_sql = generate_sql(question, get_schema_context(question=question))

        connection = get_db_connection()
        if connection is None:
            raise RuntimeError("Unable to connect to the SQLite database.")

        try:
            df_generated = pd.read_sql_query(generated_sql, connection)
            df_gold = pd.read_sql_query(case["gold_sql"], connection)

            if compare_results(df_generated, df_gold):
                correct_count += 1
                status = "PASS"
            else:
                status = "FAIL"

            results.append(_result_record(case, question, generated_sql, status))
        except Exception as error:
            failure = _result_record(case, question, generated_sql, "ERROR")
            failure["error"] = str(error)
            results.append(failure)
        finally:
            connection.close()

    # Avoid dividing by zero when the test-case file is empty.
    accuracy = (correct_count / total) * 100 if test_cases else 0.0
    print("\nEVALUATION COMPLETE")
    print(f"Total Test Cases: {total}")
    print(f"Correctly Generated SQL: {correct_count} / {total}")
    print(f"Execution Accuracy: {accuracy:.2f}%")

    with RESULTS_PATH.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=4)


if __name__ == "__main__":
    run_evaluation()
src/scripts/test_cases.json CHANGED
@@ -1,77 +1,77 @@
1
  [
2
- {
3
- "id": 1,
4
- "question": "How many tracks are there in each genre? List the genre name and the count.",
5
- "gold_sql": "SELECT t.Genre, COUNT(t.TrackId) AS TrackCount FROM Track t GROUP BY t.Genre;"
6
- },
7
- {
8
- "id": 2,
9
- "question": "Provide a list of all albums and the name of the artist who created them.",
10
- "gold_sql": "SELECT a.Title, ar.Name FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId;"
11
- },
12
- {
13
- "id": 3,
14
- "question": "What is the total revenue generated from each country?",
15
- "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
16
- },
17
- {
18
- "id": 4,
19
- "question": "Show the full names of all employees who are Sales Support Agents.",
20
- "gold_sql": "SELECT e.FirstName, e.LastName FROM Employee e WHERE e.Title = 'Sales Support Agent';"
21
- },
22
- {
23
- "id": 5,
24
- "question": "List the top 5 customers who have spent the most money.",
25
- "gold_sql": "SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId, c.FirstName, c.LastName ORDER BY TotalSpent DESC LIMIT 5;"
26
- },
27
- {
28
- "id": 6,
29
- "question": "List all Rock songs and the artists who performed them.",
30
- "gold_sql": "SELECT t.Name, ar.Name FROM Track t JOIN Genre g ON t.GenreId = g.GenreId JOIN Album a ON t.AlbumId = a.AlbumId JOIN Artist ar ON a.ArtistId = ar.ArtistId WHERE g.Name = 'Rock';"
31
- },
32
- {
33
- "id": 7,
34
- "question": "Find the total number of tracks sold for each media type.",
35
- "gold_sql": "SELECT m.Name, COUNT(il.TrackId) FROM MediaType m JOIN Track t ON m.MediaTypeId = t.MediaTypeId JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY m.Name;"
36
- },
37
- {
38
- "id": 8,
39
- "question": "Show the names of all tracks that appear on the 'TV Shows' playlist.",
40
- "gold_sql": "SELECT t.Name FROM Track t JOIN PlaylistTrack pt ON t.TrackId = pt.TrackId JOIN Playlist p ON pt.PlaylistId = p.PlaylistId WHERE p.Name = 'TV Shows';"
41
- },
42
- {
43
- "id": 9,
44
- "question": "Which artist has the most tracks? Give the name and count.",
45
- "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) FROM Artist ar JOIN Album a ON ar.ArtistId = a.ArtistId JOIN Track t ON a.AlbumId = t.AlbumId GROUP BY ar.Name ORDER BY COUNT(t.TrackId) DESC LIMIT 1;"
46
- },
47
- {
48
- "id": 10,
49
- "question": "Which genres have more than 100 tracks?",
50
- "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.Name HAVING TrackCount > 100;"
51
- },
52
- {
53
- "id": 11,
54
- "question": "Who is the best-selling artist by total revenue? Provide the artist's name and total revenue.",
55
- "gold_sql": "SELECT ar.Name, SUM(i.Total) as TotalRevenue FROM Artist ar JOIN Album a ON ar.ArtistId = a.ArtistId JOIN Track t ON a.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId GROUP BY ar.Name ORDER BY TotalRevenue DESC LIMIT 1;"
56
- },
57
- {
58
- "id": 12,
59
- "question": "Find the average length of tracks in seconds for each album. List the album title and average length.",
60
- "gold_sql": "SELECT a.Title, AVG(t.Milliseconds) as AverageLength FROM Album a JOIN Track t ON a.AlbumId = t.AlbumId GROUP BY a.Title;"
61
- },
62
- {
63
- "id": 13,
64
- "question": "List customers helped by the employee Jane Peacock. Provide the customer's full name and the employee's full name.",
65
- "gold_sql": "SELECT c.FirstName AS CustomerFirstName, c.LastName AS CustomerLastName, e.FirstName AS EmployeeFirstName, e.LastName AS EmployeeLastName FROM Customer c JOIN Employee e ON c.SupportRepId = e.EmployeeId WHERE e.FirstName = 'Jane' AND e.LastName = 'Peacock';"
66
- },
67
- {
68
- "id": 14,
69
- "question": "Which city had the highest number of invoices in 2013?",
70
- "gold_sql": "SELECT BillingCity, COUNT(InvoiceId) FROM Invoice WHERE InvoiceDate LIKE '2013%' GROUP BY BillingCity ORDER BY 2 DESC LIMIT 1;"
71
- },
72
- {
73
- "id": 15,
74
- "question": "List albums with a total price greater than 20 dollars.",
75
- "gold_sql": "SELECT al.Title, SUM(t.UnitPrice) FROM Album al JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY al.Title HAVING SUM(t.UnitPrice) > 20;"
76
- }
77
  ]
 
1
  [
2
+ {
3
+ "id": 1,
4
+ "question": "List all the artists name in the database.",
5
+ "gold_sql": "SELECT Name FROM Artist;"
6
+ },
7
+ {
8
+ "id": 2,
9
+ "question": "How many genres are there?",
10
+ "gold_sql": "SELECT COUNT(*) FROM Genre;"
11
+ },
12
+ {
13
+ "id": 3,
14
+ "question": "List the names of the first 5 tracks.",
15
+ "gold_sql": "SELECT Name FROM Track LIMIT 5;"
16
+ },
17
+ {
18
+ "id": 4,
19
+ "question": "Count the number of customers located in the USA.",
20
+ "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
21
+ },
22
+ {
23
+ "id": 5,
24
+ "question": "Find all invoices for the customer with ID 1.",
25
+ "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
26
+ },
27
+ {
28
+ "id": 6,
29
+ "question": "List each album title along with the artist's name.",
30
+ "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
31
+ },
32
+ {
33
+ "id": 7,
34
+ "question": "How many tracks belong to the 'Rock' genre?",
35
+ "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
36
+ },
37
+ {
38
+ "id": 8,
39
+ "question": "Show the total revenue generated from each country.",
40
+ "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
41
+ },
42
+ {
43
+ "id": 9,
44
+ "question": "Find the total number of items sold for each media type.",
45
+ "gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
46
+ },
47
+ {
48
+ "id": 10,
49
+ "question": "List the first and last names of all employees who are Sales Support Agents.",
50
+ "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
51
+ },
52
+ {
53
+ "id": 11,
54
+ "question": "List the top 5 customers who have spent the most money in total.",
55
+ "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
56
+ },
57
+ {
58
+ "id": 12,
59
+ "question": "Which artist has the most tracks in the database? Give the name and count.",
60
+ "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
61
+ },
62
+ {
63
+ "id": 13,
64
+ "question": "Which genres have more than 100 tracks? List the genre name and count.",
65
+ "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
66
+ },
67
+ {
68
+ "id": 14,
69
+ "question": "Calculate the average track length in seconds for each genre.",
70
+ "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
71
+ },
72
+ {
73
+ "id": 15,
74
+ "question": "Identify the artist who has earned the most revenue from customers in Canada.",
75
+ "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
76
+ }
77
  ]