dvwn committed on
Commit
a4607e0
·
1 Parent(s): f3f9320

Debugging

Browse files

- Remove duplicate & unnecessary files

app/main.py DELETED
@@ -1,38 +0,0 @@
1
- # Path: app/main.py
2
- # Main entry point for the NL2SQL application
3
- import os
4
- from dotenv import load_dotenv
5
- from src.scripts.interactive_mode import run_interactiveMode
6
- from src.scripts.evaluation_mode import run_evaluation
7
-
8
- load_dotenv()
9
- # Load HuggingFace API token from environment variable
10
- hf_token = os.getenv("HF_TOKEN")
11
- if not hf_token:
12
- raise ValueError("HuggingFace API token not found!")
13
-
14
- def main():
15
- """Main application entry point and interactive menu"""
16
- while True:
17
- print("\n" + "="*30)
18
- print(" NL2SQL Application Main Menu")
19
- print("\n" + "="*30)
20
- print("1. Run Interactive Agent NL2SQL Mode (Ask a single question)")
21
- print("2. Run Batch Evaluation of NL2SQL Agent (Evaluate on 15 test cases)")
22
- print("3. Exit")
23
- print("\n" + "="*30)
24
-
25
- choice = input("Select an option (1-3): ")
26
-
27
- if choice == '1':
28
- run_interactiveMode()
29
- elif choice == '2':
30
- run_evaluation()
31
- elif choice == '3':
32
- print("Exiting application. Goodbye!")
33
- break
34
- else:
35
- print("Invalid choice. Please select a valid option (1, 2, or 3).")
36
-
37
- if __name__ == "__main__":
38
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/app/main.py CHANGED
@@ -2,8 +2,8 @@
2
  # Main entry point for the NL2SQL application
3
  import os
4
  from dotenv import load_dotenv
5
- from backend.src.scripts.interactive_mode import run_interactiveMode
6
- from backend.src.scripts.evaluation_mode import run_evaluation
7
 
8
  load_dotenv()
9
  # Load HuggingFace API token from environment variable
 
2
  # Main entry point for the NL2SQL application
3
  import os
4
  from dotenv import load_dotenv
5
+ from src.scripts.interactive_mode import run_interactiveMode
6
+ from src.scripts.evaluation_mode import run_evaluation
7
 
8
  load_dotenv()
9
  # Load HuggingFace API token from environment variable
backend/hf_evaluation_results.json CHANGED
@@ -2,108 +2,136 @@
2
  {
3
  "id": 1,
4
  "question": "List all the artists name in the database.",
5
- "status": "PASS",
6
- "generated_sql": "SELECT a.Name FROM Artist a;",
 
 
7
  "gold_sql": "SELECT Name FROM Artist;"
8
  },
9
  {
10
  "id": 2,
11
  "question": "How many genres are there?",
12
- "status": "FAIL",
13
- "generated_sql": "SELECT COUNT(DISTINCT g.GenreId) FROM Genre g;",
 
 
14
  "gold_sql": "SELECT COUNT(*) FROM Genre;"
15
  },
16
  {
17
  "id": 3,
18
  "question": "List the names of the first 5 tracks.",
19
- "status": "PASS",
20
- "generated_sql": "SELECT t.Name FROM Track t ORDER BY t.TrackId LIMIT 5;",
 
 
21
  "gold_sql": "SELECT Name FROM Track LIMIT 5;"
22
  },
23
  {
24
  "id": 4,
25
  "question": "Count the number of customers located in the USA.",
26
- "status": "FAIL",
27
- "generated_sql": "SELECT COUNT(DISTINCT c.CustomerId) FROM Customer c WHERE c.Country = 'USA';",
 
 
28
  "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
29
  },
30
  {
31
  "id": 5,
32
  "question": "Find all invoices for the customer with ID 1.",
33
- "status": "FAIL",
34
- "generated_sql": "SELECT i.InvoiceId, i.InvoiceDate, i.Total FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId WHERE c.CustomerId = 1;",
 
 
35
  "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
36
  },
37
  {
38
  "id": 6,
39
  "question": "List each album title along with the artist's name.",
40
- "status": "PASS",
41
- "generated_sql": "SELECT a.Title, ar.Name FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId;",
 
 
42
  "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
43
  },
44
  {
45
  "id": 7,
46
  "question": "How many tracks belong to the 'Rock' genre?",
47
- "status": "FAIL",
48
- "generated_sql": "SELECT COUNT(t.TrackId) FROM Track t JOIN Genre g ON t.GenreId = g.GenreId WHERE LOWER(g.Name) = 'rock';",
 
 
49
  "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
50
  },
51
  {
52
  "id": 8,
53
  "question": "Show the total revenue generated from each country.",
54
- "status": "FAIL",
55
- "generated_sql": "SELECT i.BillingCountry, SUM(i.Total) AS TotalRevenue FROM Invoice i GROUP BY i.BillingCountry ORDER BY TotalRevenue DESC NULLS LAST;",
 
 
56
  "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
57
  },
58
  {
59
  "id": 9,
60
  "question": "Find the total number of items sold for each media type.",
61
- "status": "FAIL",
62
- "generated_sql": "SELECT mt.Name AS MediaType, SUM(il.Quantity) AS TotalQuantity FROM InvoiceLine il JOIN Track t ON il.TrackId = t.TrackId JOIN MediaType mt ON t.MediaTypeId = mt.MediaTypeId GROUP BY mt.Name ORDER BY TotalQuantity DESC NULLS LAST;",
63
- "gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
 
 
64
  },
65
  {
66
  "id": 10,
67
  "question": "List the first and last names of all employees who are Sales Support Agents.",
68
- "status": "PASS",
69
- "generated_sql": "SELECT e.FirstName, e.LastName FROM Employee e WHERE e.Title = 'Sales Support Agent';",
 
 
70
  "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
71
  },
72
  {
73
  "id": 11,
74
  "question": "List the top 5 customers who have spent the most money in total.",
75
- "status": "FAIL",
76
- "generated_sql": "SELECT c.CustomerId, SUM(i.Total) AS total_spent FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 5;",
 
 
77
  "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
78
  },
79
  {
80
  "id": 12,
81
  "question": "Which artist has the most tracks in the database? Give the name and count.",
82
- "status": "ERROR",
83
- "generated_sql": "SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;",
84
- "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;",
85
- "error": "Execution failed on sql 'SELECT a.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId GROUP BY a.Name ORDER BY track_count DESC LIMIT 1;': no such column: a.Name"
 
86
  },
87
  {
88
  "id": 13,
89
  "question": "Which genres have more than 100 tracks? List the genre name and count.",
90
- "status": "FAIL",
91
- "generated_sql": "SELECT g.Name, COUNT(t.TrackId) AS track_count FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name HAVING COUNT(t.TrackId) > 100 ORDER BY track_count DESC NULLS LAST;",
 
 
92
  "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
93
  },
94
  {
95
  "id": 14,
96
  "question": "Calculate the average track length in seconds for each genre.",
97
- "status": "FAIL",
98
- "generated_sql": "SELECT g.Name, AVG(t.Milliseconds) AS average_length FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name ORDER BY average_length NULLS LAST;",
 
 
99
  "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
100
  },
101
  {
102
  "id": 15,
103
  "question": "Identify the artist who has earned the most revenue from customers in Canada.",
104
- "status": "ERROR",
105
- "generated_sql": "SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;",
106
- "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;",
107
- "error": "Execution failed on sql 'SELECT a.Name, SUM(i.Total) AS TotalRevenue FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId JOIN Album a ON c.SupportRepId = a.ArtistId WHERE c.Country = 'Canada' GROUP BY a.Name ORDER BY TotalRevenue DESC LIMIT 1;': no such column: a.Name"
 
108
  }
109
  ]
 
2
  {
3
  "id": 1,
4
  "question": "List all the artists name in the database.",
5
+ "taxonomy": "Selection",
6
+ "ex_pass": true,
7
+ "esm_pass": true,
8
+ "generated_sql": "SELECT Name FROM Artist",
9
  "gold_sql": "SELECT Name FROM Artist;"
10
  },
11
  {
12
  "id": 2,
13
  "question": "How many genres are there?",
14
+ "taxonomy": "Aggregation",
15
+ "ex_pass": true,
16
+ "esm_pass": true,
17
+ "generated_sql": "SELECT COUNT(*) FROM Genre",
18
  "gold_sql": "SELECT COUNT(*) FROM Genre;"
19
  },
20
  {
21
  "id": 3,
22
  "question": "List the names of the first 5 tracks.",
23
+ "taxonomy": "Selection, Limit",
24
+ "ex_pass": true,
25
+ "esm_pass": true,
26
+ "generated_sql": "SELECT Name FROM Track LIMIT 5",
27
  "gold_sql": "SELECT Name FROM Track LIMIT 5;"
28
  },
29
  {
30
  "id": 4,
31
  "question": "Count the number of customers located in the USA.",
32
+ "taxonomy": "Aggregation, Filtering",
33
+ "ex_pass": true,
34
+ "esm_pass": true,
35
+ "generated_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA'",
36
  "gold_sql": "SELECT COUNT(*) FROM Customer WHERE Country = 'USA';"
37
  },
38
  {
39
  "id": 5,
40
  "question": "Find all invoices for the customer with ID 1.",
41
+ "taxonomy": "Selection, Filtering",
42
+ "ex_pass": true,
43
+ "esm_pass": true,
44
+ "generated_sql": "SELECT * FROM Invoice WHERE CustomerId = 1",
45
  "gold_sql": "SELECT * FROM Invoice WHERE CustomerId = 1;"
46
  },
47
  {
48
  "id": 6,
49
  "question": "List each album title along with the artist's name.",
50
+ "taxonomy": "Simple Join",
51
+ "ex_pass": true,
52
+ "esm_pass": true,
53
+ "generated_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId",
54
  "gold_sql": "SELECT Album.Title, Artist.Name FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId;"
55
  },
56
  {
57
  "id": 7,
58
  "question": "How many tracks belong to the 'Rock' genre?",
59
+ "taxonomy": "Simple Join, Filtering, Aggregation",
60
+ "ex_pass": true,
61
+ "esm_pass": false,
62
+ "generated_sql": "SELECT COUNT(*) FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock')",
63
  "gold_sql": "SELECT COUNT(*) FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId WHERE Genre.Name = 'Rock';"
64
  },
65
  {
66
  "id": 8,
67
  "question": "Show the total revenue generated from each country.",
68
+ "taxonomy": "Aggregation, Grouping",
69
+ "ex_pass": true,
70
+ "esm_pass": false,
71
+ "generated_sql": "SELECT BillingCountry, SUM(Total) AS TotalRevenue FROM Invoice GROUP BY BillingCountry",
72
  "gold_sql": "SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry;"
73
  },
74
  {
75
  "id": 9,
76
  "question": "Find the total number of items sold for each media type.",
77
+ "taxonomy": "Multi-Join, Aggregation, Grouping",
78
+ "ex_pass": true,
79
+ "esm_pass": false,
80
+ "generated_sql": "SELECT M.Name, COUNT(T.TrackId) AS TotalItemsSold\nFROM MediaType M\nJOIN Track T ON M.MediaTypeId = T.MediaTypeId\nJOIN InvoiceLine IL ON T.TrackId = IL.TrackId\nGROUP BY M.Name",
81
+ "gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
82
  },
83
  {
84
  "id": 10,
85
  "question": "List the first and last names of all employees who are Sales Support Agents.",
86
+ "taxonomy": "Selection, Filtering",
87
+ "ex_pass": true,
88
+ "esm_pass": true,
89
+ "generated_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent'",
90
  "gold_sql": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Support Agent';"
91
  },
92
  {
93
  "id": 11,
94
  "question": "List the top 5 customers who have spent the most money in total.",
95
+ "taxonomy": "Simple Join, Aggregation, Grouping, Ordering, Limit",
96
+ "ex_pass": false,
97
+ "esm_pass": false,
98
+ "generated_sql": "SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSpent\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName\nORDER BY TotalSpent DESC\nLIMIT 5",
99
  "gold_sql": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5;"
100
  },
101
  {
102
  "id": 12,
103
  "question": "Which artist has the most tracks in the database? Give the name and count.",
104
+ "taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
105
+ "ex_pass": true,
106
+ "esm_pass": false,
107
+ "generated_sql": "SELECT T2.Name, COUNT(T3.TrackId) AS TrackCount FROM Album AS T1 JOIN Artist AS T2 ON T1.ArtistId = T2.ArtistId JOIN Track AS T3 ON T1.AlbumId = T3.AlbumId GROUP BY T2.Name ORDER BY TrackCount DESC LIMIT 1",
108
+ "gold_sql": "SELECT ar.Name, COUNT(t.TrackId) as TrackCount FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY TrackCount DESC LIMIT 1;"
109
  },
110
  {
111
  "id": 13,
112
  "question": "Which genres have more than 100 tracks? List the genre name and count.",
113
+ "taxonomy": "Simple Join, Aggregation, Grouping, Having",
114
+ "ex_pass": true,
115
+ "esm_pass": false,
116
+ "generated_sql": "SELECT G.Name, COUNT(T.TrackId) AS TrackCount FROM Genre G JOIN Track T ON G.GenreId = T.GenreId GROUP BY G.Name HAVING COUNT(T.TrackId) > 100",
117
  "gold_sql": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId HAVING TrackCount > 100;"
118
  },
119
  {
120
  "id": 14,
121
  "question": "Calculate the average track length in seconds for each genre.",
122
+ "taxonomy": "Simple Join, Aggregation, Arithmetic, Grouping",
123
+ "ex_pass": false,
124
+ "esm_pass": false,
125
+ "generated_sql": "SELECT GenreId, AVG(Milliseconds / 1000.0) AS AverageLengthInSeconds FROM Track GROUP BY GenreId",
126
  "gold_sql": "SELECT g.Name, AVG(t.Milliseconds) / 1000.0 as AvgSeconds FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId;"
127
  },
128
  {
129
  "id": 15,
130
  "question": "Identify the artist who has earned the most revenue from customers in Canada.",
131
+ "taxonomy": "Multi-Join, Aggregation, Grouping, Ordering, Limit",
132
+ "ex_pass": false,
133
+ "esm_pass": false,
134
+ "generated_sql": "SELECT T1.Name FROM Artist AS T1 JOIN Album AS T2 ON T1.ArtistId = T2.ArtistId JOIN Track AS T5 ON T2.AlbumId = T5.AlbumId JOIN InvoiceLine AS T3 ON T5.TrackId = T3.TrackId JOIN Invoice AS T4 ON T3.InvoiceId = T4.InvoiceId WHERE T4.BillingCountry = 'Canada' GROUP BY T1.ArtistId ORDER BY SUM(T3.UnitPrice * T3.Quantity) DESC LIMIT 1",
135
+ "gold_sql": "SELECT ar.Name, SUM(il.UnitPrice * il.Quantity) AS Revenue FROM Artist ar JOIN Album al ON ar.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId JOIN Invoice i ON il.InvoiceId = i.InvoiceId WHERE i.BillingCountry = 'Canada' GROUP BY ar.ArtistId ORDER BY Revenue DESC LIMIT 1;"
136
  }
137
  ]
backend/requirements.txt CHANGED
Binary files a/backend/requirements.txt and b/backend/requirements.txt differ
 
backend/src/nl2sql/sql_agent.py CHANGED
@@ -1,8 +1,8 @@
1
  # Path: src/nl2sql/sql_agent.py
2
  # SQL Agent for handling NL2SQL conversion with Auto-Correct functionality
3
- from backend.src.database.db_manager import get_db_connection, get_schema_context
4
  from langchain_core.prompts import PromptTemplate
5
- from backend.src.nl2sql.hf_engine import get_llm
6
 
7
  # Craft the Prompt Template to instruct LLM on its persona
8
  SQL_PROMPT_TEMPLATE = """You are an expert SQLite developer.
 
1
  # Path: src/nl2sql/sql_agent.py
2
  # SQL Agent for handling NL2SQL conversion with Auto-Correct functionality
3
+ from src.database.db_manager import get_db_connection, get_schema_context
4
  from langchain_core.prompts import PromptTemplate
5
+ from src.nl2sql.hf_engine import get_llm
6
 
7
  # Craft the Prompt Template to instruct LLM on its persona
8
  SQL_PROMPT_TEMPLATE = """You are an expert SQLite developer.
src/database/__pycache__/db_manager.cpython-313.pyc DELETED
Binary file (3.88 kB)
 
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc DELETED
Binary file (3.71 kB)