Spaces:
Sleeping
Sleeping
Update evaluation script and test cases json
Browse files
1. Update the evaluation script to use the SQL agent (`nl2sql_agent`) instead of calling the engine LLM (`generate_sql`) directly
2. Tweak questions and SQL in the test cases JSON
src/scripts/__pycache__/evaluate_hf.cpython-313.pyc
CHANGED
|
Binary files a/src/scripts/__pycache__/evaluate_hf.cpython-313.pyc and b/src/scripts/__pycache__/evaluate_hf.cpython-313.pyc differ
|
|
|
src/scripts/evaluate_hf.py
CHANGED
|
@@ -1,19 +1,17 @@
|
|
| 1 |
-
#
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
from pathlib import Path
|
| 5 |
-
|
| 6 |
import pandas as pd
|
| 7 |
-
|
| 8 |
-
from src.
|
| 9 |
-
from src.nl2sql.hf_engine import generate_sql
|
| 10 |
-
|
| 11 |
|
| 12 |
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
|
| 13 |
RESULTS_PATH = Path("hf_evaluation_results.json")
|
| 14 |
|
| 15 |
-
|
| 16 |
def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
|
|
| 17 |
normalized = dataframe.copy()
|
| 18 |
normalized.columns = [str(column).lower() for column in normalized.columns]
|
| 19 |
|
|
@@ -30,7 +28,7 @@ def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
| 30 |
|
| 31 |
return normalized
|
| 32 |
|
| 33 |
-
|
| 34 |
def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
|
| 35 |
"""Compare generated and expected query results."""
|
| 36 |
if df_generated is None or df_gold is None:
|
|
@@ -44,8 +42,11 @@ def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
|
|
| 44 |
print(f"Error comparing results: {error}")
|
| 45 |
return False
|
| 46 |
|
| 47 |
-
|
| 48 |
def run_evaluation():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
|
| 50 |
test_cases = json.load(handle)
|
| 51 |
|
|
@@ -58,8 +59,9 @@ def run_evaluation():
|
|
| 58 |
question = case["question"]
|
| 59 |
print(f"Testing ID {case['id']}: {question[:50]}...")
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
connection = get_db_connection()
|
| 65 |
if connection is None:
|
|
@@ -103,8 +105,4 @@ def run_evaluation():
|
|
| 103 |
print(f"Execution Accuracy: {accuracy:.2f}%")
|
| 104 |
|
| 105 |
with RESULTS_PATH.open("w", encoding="utf-8") as handle:
|
| 106 |
-
json.dump(results, handle, indent=4)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
if __name__ == "__main__":
|
| 110 |
-
run_evaluation()
|
|
|
|
| 1 |
+
# src/scripts/evaluate_hf.py
|
| 2 |
+
# Evaluation script for Hugging Face SQL generation.
|
| 3 |
|
| 4 |
import json
|
| 5 |
from pathlib import Path
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
+
from src.database.db_manager import get_db_connection
|
| 8 |
+
from src.nl2sql.sql_agent import nl2sql_agent
|
|
|
|
|
|
|
| 9 |
|
| 10 |
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
|
| 11 |
RESULTS_PATH = Path("hf_evaluation_results.json")
|
| 12 |
|
|
|
|
| 13 |
def _normalize_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
|
| 14 |
+
# Normalize dataframe to ensure accurate comparison
|
| 15 |
normalized = dataframe.copy()
|
| 16 |
normalized.columns = [str(column).lower() for column in normalized.columns]
|
| 17 |
|
|
|
|
| 28 |
|
| 29 |
return normalized
|
| 30 |
|
| 31 |
+
# Compare generated SQL results with expected results
|
| 32 |
def compare_results(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
|
| 33 |
"""Compare generated and expected query results."""
|
| 34 |
if df_generated is None or df_gold is None:
|
|
|
|
| 42 |
print(f"Error comparing results: {error}")
|
| 43 |
return False
|
| 44 |
|
|
|
|
| 45 |
def run_evaluation():
|
| 46 |
+
if not TEST_CASES_PATH.exists():
|
| 47 |
+
print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
|
| 51 |
test_cases = json.load(handle)
|
| 52 |
|
|
|
|
| 59 |
question = case["question"]
|
| 60 |
print(f"Testing ID {case['id']}: {question[:50]}...")
|
| 61 |
|
| 62 |
+
# Implement agent to handle RAG retrieval and SQL generation
|
| 63 |
+
agent_response = nl2sql_agent(user_question=question)
|
| 64 |
+
generated_sql = agent_response.get("query", "")
|
| 65 |
|
| 66 |
connection = get_db_connection()
|
| 67 |
if connection is None:
|
|
|
|
| 105 |
print(f"Execution Accuracy: {accuracy:.2f}%")
|
| 106 |
|
| 107 |
with RESULTS_PATH.open("w", encoding="utf-8") as handle:
|
| 108 |
+
json.dump(results, handle, indent=4)
|
|
|
|
|
|
|
|
|
|
|
|
src/scripts/test_cases.json
CHANGED
|
@@ -42,7 +42,7 @@
|
|
| 42 |
{
|
| 43 |
"id": 9,
|
| 44 |
"question": "Find the total number of items sold for each media type.",
|
| 45 |
-
"gold_sql": "SELECT MediaType.Name,
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"id": 10,
|
|
|
|
| 42 |
{
|
| 43 |
"id": 9,
|
| 44 |
"question": "Find the total number of items sold for each media type.",
|
| 45 |
+
"gold_sql": "SELECT MediaType.Name, SUM(InvoiceLine.Quantity) FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN MediaType ON Track.MediaTypeId = MediaType.MediaTypeId GROUP BY MediaType.Name;"
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"id": 10,
|