Spaces:
Running
Running
Update main entry point for nl2sql application > app/main.py
Browse files1. Run from the main entry point, main.py
2. Remove local test block for db_manager, hf_engine, and sql_agent
app/main.py
CHANGED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Path: app/main.py
|
| 2 |
+
# Main entry point for the NL2SQL application
|
| 3 |
+
import os
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from src.nl2sql.sql_agent import nl2sql_agent
|
| 6 |
+
from src.scripts.evaluate_hf import run_evaluation
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
# User prompt question manually and see the agent's response
|
| 11 |
+
def interactive_mode():
|
| 12 |
+
"""Allows user to manually type questions and get agent's response."""
|
| 13 |
+
print("\n========= Interactive NL2SQL Mode =========")
|
| 14 |
+
print("Type 'exit' or 'q' to return to the main menu.\n")
|
| 15 |
+
|
| 16 |
+
while True:
|
| 17 |
+
question = input("\nEnter your question: ")
|
| 18 |
+
if question.lower() in ['exit', 'q']:
|
| 19 |
+
break
|
| 20 |
+
if not question.strip():
|
| 21 |
+
continue
|
| 22 |
+
|
| 23 |
+
print("\nProcessing your question...")
|
| 24 |
+
response = nl2sql_agent(question)
|
| 25 |
+
|
| 26 |
+
print("\n========= Agent Response =========")
|
| 27 |
+
print(f"Status: {response.get('status')}")
|
| 28 |
+
print(f"Generated SQL:\n{response.get('query')}")
|
| 29 |
+
|
| 30 |
+
if response.get('status') == 'success':
|
| 31 |
+
print(f"\nresults (First 5 rows):\n{response.get('results')[:5]}")
|
| 32 |
+
else:
|
| 33 |
+
print(f"\nError Details:\n{response.get('error')}")
|
| 34 |
+
print("==================================\n")
|
| 35 |
+
|
| 36 |
+
def main():
|
| 37 |
+
"""Main application entry point and interactive menu"""
|
| 38 |
+
while True:
|
| 39 |
+
print("\n" + "="*30)
|
| 40 |
+
print(" NL2SQL Application Main Menu")
|
| 41 |
+
print("\n" + "="*30)
|
| 42 |
+
print("1. Run Interactive Agent NL2SQL Mode (Ask a single question)")
|
| 43 |
+
print("2. Run Batch Evaluation of NL2SQL Agent (Evaluate on 15 test cases)")
|
| 44 |
+
print("3. Exit")
|
| 45 |
+
print("\n" + "="*30)
|
| 46 |
+
|
| 47 |
+
choice = input("Select an option (1-3): ")
|
| 48 |
+
|
| 49 |
+
if choice == '1':
|
| 50 |
+
interactive_mode()
|
| 51 |
+
elif choice == '2':
|
| 52 |
+
run_evaluation()
|
| 53 |
+
elif choice == '3':
|
| 54 |
+
print("Exiting application. Goodbye!")
|
| 55 |
+
break
|
| 56 |
+
else:
|
| 57 |
+
print("Invalid choice. Please select a valid option (1, 2, or 3).")
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
|
| 60 |
+
main()
|
src/database/__pycache__/db_manager.cpython-313.pyc
CHANGED
|
Binary files a/src/database/__pycache__/db_manager.cpython-313.pyc and b/src/database/__pycache__/db_manager.cpython-313.pyc differ
|
|
|
src/database/db_manager.py
CHANGED
|
@@ -3,14 +3,11 @@
|
|
| 3 |
# Include a RAG-based dynamic schema retrieval.
|
| 4 |
import os
|
| 5 |
import sqlite3
|
| 6 |
-
from dotenv import load_dotenv
|
| 7 |
from langchain_community.utilities import SQLDatabase
|
| 8 |
from langchain_chroma import Chroma
|
| 9 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 10 |
from langchain_core.documents import Document
|
| 11 |
|
| 12 |
-
load_dotenv()
|
| 13 |
-
|
| 14 |
# The path to SQLite database
|
| 15 |
DB_PATH = os.getenv("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
|
| 16 |
DB_URI = f"sqlite:///{DB_PATH}"
|
|
@@ -78,10 +75,4 @@ def get_schema_context(question=None):
|
|
| 78 |
db = SQLDatabase.from_uri(DB_URI, include_tables = tables_to_include, sample_rows_in_table_info=5)
|
| 79 |
return db.get_table_info()
|
| 80 |
except Exception as e:
|
| 81 |
-
return f"Error retrieving database schema: {e}"
|
| 82 |
-
|
| 83 |
-
if __name__ == "__main__":
|
| 84 |
-
# Example usage: Get schema context for a specific question
|
| 85 |
-
question = "What tables contain information about music tracks and their genres?"
|
| 86 |
-
schema_context = get_schema_context(question)
|
| 87 |
-
print(schema_context)
|
|
|
|
| 3 |
# Include a RAG-based dynamic schema retrieval.
|
| 4 |
import os
|
| 5 |
import sqlite3
|
|
|
|
| 6 |
from langchain_community.utilities import SQLDatabase
|
| 7 |
from langchain_chroma import Chroma
|
| 8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 9 |
from langchain_core.documents import Document
|
| 10 |
|
|
|
|
|
|
|
| 11 |
# The path to SQLite database
|
| 12 |
DB_PATH = os.getenv("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
|
| 13 |
DB_URI = f"sqlite:///{DB_PATH}"
|
|
|
|
| 75 |
db = SQLDatabase.from_uri(DB_URI, include_tables = tables_to_include, sample_rows_in_table_info=5)
|
| 76 |
return db.get_table_info()
|
| 77 |
except Exception as e:
|
| 78 |
+
return f"Error retrieving database schema: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc
CHANGED
|
Binary files a/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc and b/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc differ
|
|
|
src/nl2sql/hf_engine.py
CHANGED
|
@@ -47,17 +47,4 @@ def get_llm(model_id: str = DEFAULT_MODEL_ID):
|
|
| 47 |
client = InferenceClient(api_key=hf_token)
|
| 48 |
llm = HFChatWrapper(client=client, model_id=model_id)
|
| 49 |
|
| 50 |
-
return llm
|
| 51 |
-
|
| 52 |
-
# Local Test block
|
| 53 |
-
if __name__ == "__main__":
|
| 54 |
-
from dotenv import load_dotenv
|
| 55 |
-
load_dotenv()
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
test_llm = get_llm()
|
| 59 |
-
print("Model loaded successfully! Running a quick ping...")
|
| 60 |
-
response = test_llm.invoke("Write a single SQL statement to count all rows in a table named 'Employee'.")
|
| 61 |
-
print(f"\nResponse:\n{response}")
|
| 62 |
-
except Exception as e:
|
| 63 |
-
print(f"Error during LLM initialization or invocation: {e}")
|
|
|
|
| 47 |
client = InferenceClient(api_key=hf_token)
|
| 48 |
llm = HFChatWrapper(client=client, model_id=model_id)
|
| 49 |
|
| 50 |
+
return llm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/nl2sql/sql_agent.py
CHANGED
|
@@ -90,18 +90,4 @@ def nl2sql_agent(user_question: str) -> dict:
|
|
| 90 |
"status": f"error executing SQL: {e}"
|
| 91 |
}
|
| 92 |
finally:
|
| 93 |
-
connection.close()
|
| 94 |
-
|
| 95 |
-
# Local Test block
|
| 96 |
-
if __name__ == "__main__":
|
| 97 |
-
from dotenv import load_dotenv
|
| 98 |
-
load_dotenv()
|
| 99 |
-
|
| 100 |
-
test_question = "How many employees are there in the Employee table?"
|
| 101 |
-
|
| 102 |
-
print("------Starting NL2SQL Agent Test------")
|
| 103 |
-
result = nl2sql_agent(test_question)
|
| 104 |
-
|
| 105 |
-
print("------Final Output------")
|
| 106 |
-
for key, value in result.items():
|
| 107 |
-
print(f"{key.capitalize()}: {value}")
|
|
|
|
| 90 |
"status": f"error executing SQL: {e}"
|
| 91 |
}
|
| 92 |
finally:
|
| 93 |
+
connection.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|