dvwn commited on
Commit
71c851f
·
1 Parent(s): dc6fe88

Update evaluation script and test cases json

Browse files

1. Update evaluation scripts to utilise agent sql instead of directly use engine LLM
2. Tweak questions and sql in test cases json

src/scripts/__pycache__/evaluate_hf.cpython-313.pyc CHANGED
Binary files a/src/scripts/__pycache__/evaluate_hf.cpython-313.pyc and b/src/scripts/__pycache__/evaluate_hf.cpython-313.pyc differ
 
src/scripts/evaluate_hf.py CHANGED
@@ -1,19 +1,17 @@
1
- #"""Evaluation script for Hugging Face SQL generation."""
 
2
 
3
  import json
4
  from pathlib import Path
5
-
6
  import pandas as pd
7
-
8
- from src.database.db_manager import get_db_connection, get_schema_context
9
- from src.nl2sql.hf_engine import generate_sql
10
-
11
 
12
  TEST_CASES_PATH = Path("src/scripts/test_cases.json")
13
  RESULTS_PATH = Path("hf_evaluation_results.json")
14
 
15
-
16
  def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
 
17
  normalized = dataframe.copy()
18
  normalized.columns = [str(column).lower() for column in normalized.columns]
19
 
@@ -30,7 +28,7 @@ def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
30
 
31
  return normalized
32
 
33
-
34
  def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
35
  """Compare generated and expected query results."""
36
  if df_generated is None or df_gold is None:
@@ -44,8 +42,11 @@ def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
44
  print(f"Error comparing results: {error}")
45
  return False
46
 
47
-
48
  def run_evaluation():
 
 
 
 
49
  with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
50
  test_cases = json.load(handle)
51
 
@@ -58,8 +59,9 @@ def run_evaluation():
58
  question = case["question"]
59
  print(f"Testing ID {case['id']}: {question[:50]}...")
60
 
61
- schema_context = get_schema_context(question=question)
62
- generated_sql = generate_sql(question, schema_context)
 
63
 
64
  connection = get_db_connection()
65
  if connection is None:
@@ -103,8 +105,4 @@ def run_evaluation():
103
  print(f"Execution Accuracy: {accuracy:.2f}%")
104
 
105
  with RESULTS_PATH.open("w", encoding="utf-8") as handle:
106
- json.dump(results, handle, indent=4)
107
-
108
-
109
- if __name__ == "__main__":
110
- run_evaluation()
 
1
+ # src/scripts/evaluate_hf.py
2
+ # Evaluation script for Hugging Face SQL generation.
3
 
4
  import json
5
  from pathlib import Path
 
6
  import pandas as pd
7
+ from src.database.db_manager import get_db_connection
8
+ from src.nl2sql.sql_agent import nl2sql_agent
 
 
9
 
10
  TEST_CASES_PATH = Path("src/scripts/test_cases.json")
11
  RESULTS_PATH = Path("hf_evaluation_results.json")
12
 
 
13
  def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
14
+ # Normalize dataframe to ensure accurate comparison
15
  normalized = dataframe.copy()
16
  normalized.columns = [str(column).lower() for column in normalized.columns]
17
 
 
28
 
29
  return normalized
30
 
31
+ # Compare generated SQL results with expected results
32
  def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
33
  """Compare generated and expected query results."""
34
  if df_generated is None or df_gold is None:
 
42
  print(f"Error comparing results: {error}")
43
  return False
44
 
 
45
  def run_evaluation():
46
+ if not TEST_CASES_PATH.exists():
47
+ print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
48
+ return
49
+
50
  with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
51
  test_cases = json.load(handle)
52
 
 
59
  question = case["question"]
60
  print(f"Testing ID {case['id']}: {question[:50]}...")
61
 
62
+ # Implement agent to handle RAG retrieval and SQL generation
63
+ agent_response = nl2sql_agent(user_question=question)
64
+ generated_sql = agent_response.get("query", "")
65
 
66
  connection = get_db_connection()
67
  if connection is None:
 
105
  print(f"Execution Accuracy: {accuracy:.2f}%")
106
 
107
  with RESULTS_PATH.open("w", encoding="utf-8") as handle:
108
+ json.dump(results, handle, indent=4)
 
 
 
 
src/scripts/test_cases.json CHANGED
@@ -42,7 +42,7 @@
42
  {
43
  "id": 9,
44
  "question": "Find the total number of items sold for each media type.",
45
- "gold_sql": "SELECT MediaType.Name, COUNT(InvoiceLine.TrackId) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
46
  },
47
  {
48
  "id": 10,
 
42
  {
43
  "id": 9,
44
  "question": "Find the total number of items sold for each media type.",
45
+ "gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
46
  },
47
  {
48
  "id": 10,