dvwn commited on
Commit
a218c12
·
1 Parent(s): 71c851f

Update main entry point for nl2sql application > app/main.py

Browse files

1. Run from the main entry point, main.py
2. Remove local test block for db_manager, hf_engine, and sql_agent

app/main.py CHANGED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Path: app/main.py
2
+ # Main entry point for the NL2SQL application
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from src.nl2sql.sql_agent import nl2sql_agent
6
+ from src.scripts.evaluate_hf import run_evaluation
7
+
8
+ load_dotenv()
9
+
10
+ # User prompt question manually and see the agent's response
11
+ def interactive_mode():
12
+ """Allows user to manually type questions and get agent's response."""
13
+ print("\n========= Interactive NL2SQL Mode =========")
14
+ print("Type 'exit' or 'q' to return to the main menu.\n")
15
+
16
+ while True:
17
+ question = input("\nEnter your question: ")
18
+ if question.lower() in ['exit', 'q']:
19
+ break
20
+ if not question.strip():
21
+ continue
22
+
23
+ print("\nProcessing your question...")
24
+ response = nl2sql_agent(question)
25
+
26
+ print("\n========= Agent Response =========")
27
+ print(f"Status: {response.get('status')}")
28
+ print(f"Generated SQL:\n{response.get('query')}")
29
+
30
+ if response.get('status') == 'success':
31
+ print(f"\nresults (First 5 rows):\n{response.get('results')[:5]}")
32
+ else:
33
+ print(f"\nError Details:\n{response.get('error')}")
34
+ print("==================================\n")
35
+
36
+ def main():
37
+ """Main application entry point and interactive menu"""
38
+ while True:
39
+ print("\n" + "="*30)
40
+ print(" NL2SQL Application Main Menu")
41
+ print("\n" + "="*30)
42
+ print("1. Run Interactive Agent NL2SQL Mode (Ask a single question)")
43
+ print("2. Run Batch Evaluation of NL2SQL Agent (Evaluate on 15 test cases)")
44
+ print("3. Exit")
45
+ print("\n" + "="*30)
46
+
47
+ choice = input("Select an option (1-3): ")
48
+
49
+ if choice == '1':
50
+ interactive_mode()
51
+ elif choice == '2':
52
+ run_evaluation()
53
+ elif choice == '3':
54
+ print("Exiting application. Goodbye!")
55
+ break
56
+ else:
57
+ print("Invalid choice. Please select a valid option (1, 2, or 3).")
58
+
59
+ if __name__ == "__main__":
60
+ main()
src/database/__pycache__/db_manager.cpython-313.pyc CHANGED
Binary files a/src/database/__pycache__/db_manager.cpython-313.pyc and b/src/database/__pycache__/db_manager.cpython-313.pyc differ
 
src/database/db_manager.py CHANGED
@@ -3,14 +3,11 @@
3
  # Include a RAG-based dynamic schema retrieval.
4
  import os
5
  import sqlite3
6
- from dotenv import load_dotenv
7
  from langchain_community.utilities import SQLDatabase
8
  from langchain_chroma import Chroma
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from langchain_core.documents import Document
11
 
12
- load_dotenv()
13
-
14
  # The path to SQLite database
15
  DB_PATH = os.getenv("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
16
  DB_URI = f"sqlite:///{DB_PATH}"
@@ -78,10 +75,4 @@ def get_schema_context(question=None):
78
  db = SQLDatabase.from_uri(DB_URI, include_tables = tables_to_include, sample_rows_in_table_info=5)
79
  return db.get_table_info()
80
  except Exception as e:
81
- return f"Error retrieving database schema: {e}"
82
-
83
- if __name__ == "__main__":
84
- # Example usage: Get schema context for a specific question
85
- question = "What tables contain information about music tracks and their genres?"
86
- schema_context = get_schema_context(question)
87
- print(schema_context)
 
3
  # Include a RAG-based dynamic schema retrieval.
4
  import os
5
  import sqlite3
 
6
  from langchain_community.utilities import SQLDatabase
7
  from langchain_chroma import Chroma
8
  from langchain_huggingface import HuggingFaceEmbeddings
9
  from langchain_core.documents import Document
10
 
 
 
11
  # The path to SQLite database
12
  DB_PATH = os.getenv("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
13
  DB_URI = f"sqlite:///{DB_PATH}"
 
75
  db = SQLDatabase.from_uri(DB_URI, include_tables = tables_to_include, sample_rows_in_table_info=5)
76
  return db.get_table_info()
77
  except Exception as e:
78
+ return f"Error retrieving database schema: {e}"
 
 
 
 
 
 
src/nl2sql/__pycache__/hf_engine.cpython-313.pyc CHANGED
Binary files a/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc and b/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc differ
 
src/nl2sql/hf_engine.py CHANGED
@@ -47,17 +47,4 @@ def get_llm(model_id: str = DEFAULT_MODEL_ID):
47
  client = InferenceClient(api_key=hf_token)
48
  llm = HFChatWrapper(client=client, model_id=model_id)
49
 
50
- return llm
51
-
52
- # Local Test block
53
- if __name__ == "__main__":
54
- from dotenv import load_dotenv
55
- load_dotenv()
56
-
57
- try:
58
- test_llm = get_llm()
59
- print("Model loaded successfully! Running a quick ping...")
60
- response = test_llm.invoke("Write a single SQL statement to count all rows in a table named 'Employee'.")
61
- print(f"\nResponse:\n{response}")
62
- except Exception as e:
63
- print(f"Error during LLM initialization or invocation: {e}")
 
47
  client = InferenceClient(api_key=hf_token)
48
  llm = HFChatWrapper(client=client, model_id=model_id)
49
 
50
+ return llm
 
 
 
 
 
 
 
 
 
 
 
 
 
src/nl2sql/sql_agent.py CHANGED
@@ -90,18 +90,4 @@ def nl2sql_agent(user_question: str) -> dict:
90
  "status": f"error executing SQL: {e}"
91
  }
92
  finally:
93
- connection.close()
94
-
95
- # Local Test block
96
- if __name__ == "__main__":
97
- from dotenv import load_dotenv
98
- load_dotenv()
99
-
100
- test_question = "How many employees are there in the Employee table?"
101
-
102
- print("------Starting NL2SQL Agent Test------")
103
- result = nl2sql_agent(test_question)
104
-
105
- print("------Final Output------")
106
- for key, value in result.items():
107
- print(f"{key.capitalize()}: {value}")
 
90
  "status": f"error executing SQL: {e}"
91
  }
92
  finally:
93
+ connection.close()