Invicto69 commited on
Commit
fb30fd7
·
verified ·
1 Parent(s): 355116c

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (5) hide show
  1. app.py +124 -0
  2. readme.md +79 -0
  3. requirements.txt +6 -3
  4. utils.py +154 -0
  5. var.py +62 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generator
2
+ from utils import validate_uri, extract_code_blocks, get_info_sqlalchemy
3
+ from langchain_community.utilities import SQLDatabase
4
+ from var import system_prompt, markdown_info, query_output
5
+ import streamlit as st
6
+ from openai import OpenAI
7
+
8
+ st.set_page_config(layout="wide")
9
+
10
+ # Initialize chat history and selected model
11
+ if "messages" not in st.session_state:
12
+ st.session_state.messages = []
13
+ st.session_state.sql_result = []
14
+
15
+ if "selected_model" not in st.session_state:
16
+ st.session_state.selected_model = None
17
+
18
+ st.markdown("# SQL Chat")
19
+
20
+ st.sidebar.title("Settings")
21
+ base_url = st.sidebar.text_input("Base URL", help="OpenAI compatible API")
22
+ api_key = st.sidebar.text_input("API Key")
23
+ model = st.sidebar.text_input("Model ID")
24
+
25
+ if st.session_state.selected_model != model:
26
+ st.session_state.messages = []
27
+ st.session_state.sql_result = []
28
+ st.session_state.selected_model = model
29
+
30
+ uri = st.sidebar.text_input("Enter SQL Database URI")
31
+
32
+ if not validate_uri(uri):
33
+ st.sidebar.error("Enter valid URI")
34
+ else:
35
+ st.sidebar.success("URI is valid")
36
+ db_info = get_info_sqlalchemy(uri)
37
+ markdown_info = markdown_info.format(**db_info)
38
+ with st.expander("SQL Database Info"):
39
+ st.markdown(markdown_info)
40
+ system_prompt = system_prompt.format(markdown_info = markdown_info)
41
+
42
+ if base_url and api_key and model and uri:
43
+ client = OpenAI(
44
+ base_url=base_url,
45
+ api_key=api_key,
46
+ )
47
+
48
+ db = SQLDatabase.from_uri(uri)
49
+
50
+ avatar = {"user": '👨‍💻', "assistant": '🤖', "executor": '🛢'}
51
+
52
+ # Display chat messages from history on app rerun
53
+ for i, message in enumerate(st.session_state.messages):
54
+ with st.chat_message(message["role"], avatar=avatar[message["role"]]):
55
+ st.markdown(message["content"])
56
+ if (i+1)%2 == 0:
57
+ with st.chat_message("SQL Executor", avatar=avatar["executor"]):
58
+ st.markdown(st.session_state.sql_result[i//2])
59
+
60
+
61
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
62
+ """Yield chat response content from the Groq API response."""
63
+ for chunk in chat_completion:
64
+ if chunk.choices[0].delta.content:
65
+ yield chunk.choices[0].delta.content
66
+
67
+
68
+ if prompt := st.chat_input("Enter your prompt here..."):
69
+ st.session_state.messages.append({"role": "user", "content": prompt})
70
+
71
+ with st.chat_message("user", avatar=avatar["user"]):
72
+ st.markdown(prompt)
73
+
74
+ # Fetch response from Groq API
75
+ try:
76
+ chat_completion = client.chat.completions.create(
77
+ model=model,
78
+ messages=[{
79
+ "role": "system",
80
+ "content": system_prompt
81
+ },
82
+ ]+
83
+ [
84
+ {
85
+ "role": m["role"],
86
+ "content": m["content"]
87
+ }
88
+ for m in st.session_state.messages[-8:]
89
+ ],
90
+ max_tokens=3000,
91
+ stream=True
92
+ )
93
+
94
+ # Use the generator function with st.write_stream
95
+ with st.chat_message("SQL Assistant", avatar=avatar["assistant"]):
96
+ chat_responses_generator = generate_chat_responses(chat_completion)
97
+ llm_response = st.write_stream(chat_responses_generator)
98
+
99
+ with st.chat_message("SQL Executor", avatar=avatar["executor"]):
100
+ query = extract_code_blocks(llm_response)
101
+ result = db.run(query[0])
102
+ query_response = st.write(query_output.format(result=result))
103
+
104
+ except Exception as e:
105
+ st.error(e, icon="🚨")
106
+
107
+ if len(str(result)) > 1000:
108
+ query_output_truncated = query_output.format(result=result)[:500]+query_output.format(result=result)[-500:]
109
+ else:
110
+ query_output_truncated = query_output.format(result=result)
111
+
112
+ st.session_state.sql_result.append(query_output_truncated)
113
+
114
+ # Append the llm response to session_state.messages
115
+ if isinstance(llm_response, str):
116
+ st.session_state.messages.append(
117
+ {"role": "assistant", "content": llm_response})
118
+ else:
119
+ # Handle the case where llm_response is not a string
120
+ combined_response = "\n".join(str(item) for item in llm_response)
121
+ st.session_state.messages.append(
122
+ {"role": "assistant", "content": combined_response})
123
+
124
+ st.sidebar.button("Clear Chat History", on_click=lambda: st.session_state.messages.clear() and st.session_state.sql_result.clear())
readme.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQLchat
2
+
3
+ This project is a **SQL Chatbot** built with **LangChain** and **Streamlit**, designed to generate SQL queries and execute queries
4
+ based on database table schemas and structure. The chatbot can interact with users to understand their requirements
5
+ and translate them into SQL queries, leveraging relational database information provided via URI and schema definitions.
6
+
7
+ ## Features
8
+
9
+ - **SQL Query Generator**: Automatically generates SQL queries based on user inputs and database structure.
10
+ - **SQL Query Execution**: Automatically executes SQL queries generated by chatbot.
11
+ - **Interactive Chat Interface**: Built with Streamlit for a user-friendly conversational experience.
12
+ - **Database Schema Integration**: Parses table schemas from a database URI to provide accurate SQL generation capabilities.
13
+ - **Customizable LLM Configuration**: Supports various large language models (LLMs) for generating responses.
14
+
15
+ ## Installation
16
+
17
+ 1. Clone the repository:
18
+
19
+ ```bash
20
+ git clone https://github.com/arthiondaena/SQLchat.git
21
+ cd SQLchat
22
+ ```
23
+
24
+ 2. Set up a virtual environment:
25
+
26
+ ```bash
27
+ python -m venv venv
28
+ source venv/bin/activate # On Windows: venv\Scripts\activate
29
+ ```
30
+
31
+ 3. Install dependencies:
32
+
33
+ ```bash
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ ## Usage
38
+
39
+ Run the application using Streamlit:
40
+
41
+ ```bash
42
+ streamlit run app.py
43
+ ```
44
+
45
+ This will launch the chatbot interface in your default web browser. The chatbot can then process user inputs and generate SQL queries based on the database schema.
46
+
47
+ ## Setup
48
+
49
+ 1. **Configure Database Connection**:
50
+ - Set up the `URI` configuration in the streamlit app to connect to your relational database.
51
+ - Ensure the database has the necessary permissions to allow schema queries.
52
+
53
+ 2. **Table Schemas**:
54
+ - The chatbot extracts table structures and schemas from the database for generating SQL queries. Make sure the database contains valid schema definitions.
55
+
56
+ 3. **API Key Configuration**:
57
+ - Provide your Groq API key for LLM integration within the script.
58
+
59
+ 4. **System Prompt Customization**:
60
+ - Adjust the instructions as per your specific SQL generation use case.
61
+ - The chatbot can remember upto last 4 conversations.
62
+
63
+ ## Features in Detail
64
+
65
+ 1. **SQL Query Generation**:
66
+ - The chatbot uses relational database schemas to intelligently generate SQL queries.
67
+ - Supports basic and complex queries tailored to the provided database structure.
68
+
69
+ 2. **Database Schema Utilization**:
70
+ - Extracts table information (columns, types, relationships) from the connected database.
71
+ - Leverages this knowledge to produce highly precise SQL queries.
72
+
73
+ 3. **Customizable Model Prompts**:
74
+ - Custom system prompts and instructions can be added to suit diverse database use cases.
75
+
76
+ ## Example Workflow
77
+ 1. Connect the chatbot to your database by specifying the database URI.
78
+ 2. Provide the chatbot with your SQL query requirement in plain language (e.g., "Fetch the top 10 customers by revenue").
79
+ 3. The chatbot generates and returns an accurate SQL query based on the schema.
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ groq
2
+ langchain
3
+ langchain[groq]
4
+ streamlit
5
+ langchain_community
6
+ psycopg2
utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from langchain_community.utilities import SQLDatabase
3
+ from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
4
+ from sqlalchemy import (
5
+ create_engine,
6
+ MetaData,
7
+ inspect,
8
+ Table,
9
+ select,
10
+ distinct
11
+ )
12
+ from sqlalchemy.schema import CreateTable
13
+ from sqlalchemy.exc import ProgrammingError
14
+ from sqlalchemy.engine import Engine
15
+ import re
16
+
17
+ def get_all_groq_model(api_key:str=None) -> list:
18
+ """Uses Groq API to fetch all the available models."""
19
+ if api_key is None:
20
+ raise ValueError("API key is required")
21
+ url = "https://api.groq.com/openai/v1/models"
22
+
23
+ headers = {
24
+ "Authorization": f"Bearer {api_key}",
25
+ "Content-Type": "application/json"
26
+ }
27
+
28
+ response = requests.get(url, headers=headers)
29
+
30
+ data = response.json()['data']
31
+ model_ids = [model['id'] for model in data]
32
+
33
+ return model_ids
34
+
35
+ def validate_api_key(api_key:str) -> bool:
36
+ """Validates the Groq API key using the get_all_groq_model function."""
37
+ if len(api_key) == 0:
38
+ return False
39
+ try:
40
+ get_all_groq_model(api_key=api_key)
41
+ return True
42
+ except Exception as e:
43
+ return False
44
+
45
+ def validate_uri(uri:str) -> bool:
46
+ """Validates the SQL Database URI using the SQLDatabase.from_uri function."""
47
+ try:
48
+ SQLDatabase.from_uri(uri)
49
+ return True
50
+ except Exception as e:
51
+ return False
52
+
53
+ def get_info(uri:str) -> dict[str, str] | None:
54
+ """Gets the dialect name, accessible tables and table schemas using the SQLDatabase toolkit"""
55
+ db = SQLDatabase.from_uri(uri)
56
+ dialect = db.dialect
57
+ # List all the tables accessible to the user.
58
+ access_tables = ListSQLDatabaseTool(db=db).invoke("")
59
+ # List the table schemas of all the accessible tables.
60
+ tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
61
+ return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
62
+
63
+ def get_sample_rows(engine:Engine, table:Table, row_count: int = 3) -> str:
64
+ """Gets the sample rows of a table using the SQLAlchemy engine"""
65
+ # build the select command
66
+ command = select(table).limit(row_count)
67
+
68
+ # save the columns in string format
69
+ columns_str = "\t".join([col.name for col in table.columns])
70
+
71
+ try:
72
+ # get the sample rows
73
+ with engine.connect() as connection:
74
+ sample_rows_result = connection.execute(command) # type: ignore
75
+ # shorten values in the sample rows
76
+ sample_rows = list(
77
+ map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
78
+ )
79
+
80
+ # save the sample rows in string format
81
+ sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
82
+
83
+ # in some dialects when there are no rows in the table a
84
+ # 'ProgrammingError' is returned
85
+ except ProgrammingError:
86
+ sample_rows_str = ""
87
+
88
+ return (
89
+ f"{row_count} rows from {table.name} table:\n"
90
+ f"{columns_str}\n"
91
+ f"{sample_rows_str}"
92
+ )
93
+
94
+ def get_unique_values(engine:Engine, table:Table) -> str:
95
+ """Gets the unique values of each column in a table using the SQLAlchemy engine"""
96
+ unique_values = {}
97
+ for column in table.c:
98
+ command = select(distinct(column))
99
+
100
+ try:
101
+ # get the sample rows
102
+ with engine.connect() as connection:
103
+ result = connection.execute(command) # type: ignore
104
+ # shorten values in the sample rows
105
+ unique_values[column.name] = [str(u) for u in result]
106
+
107
+ # save the sample rows in string format
108
+ # sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
109
+ # in some dialects when there are no rows in the table a
110
+ # 'ProgrammingError' is returned
111
+ except ProgrammingError:
112
+ sample_rows_str = ""
113
+
114
+ output_str = f"Unique values of each column in {table.name}: \n"
115
+ for column, values in unique_values.items():
116
+ output_str += f"{column} has {len(values)} unique values: {' '.join(values[:20])}"
117
+ if len(values) > 20:
118
+ output_str += ", ...."
119
+ output_str += "\n"
120
+
121
+ return output_str
122
+
123
+ def get_info_sqlalchemy(uri:str) -> dict[str, str] | None:
124
+ """Gets the dialect name, accessible tables and table schemas using the SQLAlchemy engine"""
125
+ engine = create_engine(uri)
126
+ # Get dialect name using inspector
127
+ inspector = inspect(engine)
128
+ dialect = inspector.dialect.name
129
+ # Metadata for tables and columns
130
+ m = MetaData()
131
+ m.reflect(engine)
132
+
133
+ tables = {}
134
+ for table in m.tables.values():
135
+ tables[table.name] = str(CreateTable(table).compile(engine)).rstrip()
136
+ tables[table.name] += "\n\n/*"
137
+ tables[table.name] += "\n" + get_sample_rows(engine, table)+"\n"
138
+ tables[table.name] += "\n" + get_unique_values(engine, table)+"\n"
139
+ tables[table.name] += "*/"
140
+
141
+ return {'sql_dialect': dialect, 'tables': ", ".join(tables.keys()), 'tables_schema': "\n\n".join(tables.values())}
142
+
143
+ def extract_code_blocks(text):
144
+ pattern = r"```(?:\w+)?\n(.*?)\n```"
145
+ matches = re.findall(pattern, text, re.DOTALL)
146
+ return matches
147
+
148
+ if __name__ == "__main__":
149
+ from dotenv import load_dotenv
150
+ import os
151
+ load_dotenv()
152
+
153
+ uri = os.getenv("POSTGRES_URI")
154
+ print(get_info_sqlalchemy(uri))
var.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ groq_models = ['llama-3.3-70b-versatile', 'gemma2-9b-it', 'llama-3.2-3b-preview', 'deepseek-r1-distill-llama-70b', 'qwen-2.5-coder-32b',
2
+ 'mixtral-8x7b-32768', 'llama-3.1-8b-instant', 'llama-3.2-1b-preview', 'allam-2-7b', 'qwen-qwq-32b', 'llama3-70b-8192',
3
+ 'mistral-saba-24b', 'deepseek-r1-distill-qwen-32b', 'qwen-2.5-32b', 'llama-3.3-70b-specdec', 'llama3-8b-8192', 'llama-guard-3-8b']
4
+
5
+ db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
6
+
7
+ markdown_info = """
8
+ **SQL Dialect**: {sql_dialect}\n
9
+ **Tables**: {tables}\n
10
+ **Tables Schema**:
11
+ ```sql
12
+ {tables_schema}
13
+ ```
14
+ """
15
+
16
+ system_prompt = """
17
+ You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
18
+ You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
19
+ correctness, efficiency, and security in your SQL queries.\
20
+
21
+ ## SQL Database Info
22
+ {markdown_info}
23
+
24
+ ---
25
+
26
+ ## Query Generation Guidelines
27
+ 1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
28
+ 2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
29
+ 3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
30
+ 4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
31
+ 5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
32
+ 6. **Commenting**: Include comments in complex queries to explain logic when needed.
33
+ 7. **Result**: Don't return the result of the query, return only the SQL query.
34
+ 8. **Optimal**: Try to generate query which is optimal and not brute force.
35
+ 9. **Single query**: Generate a best single SQL query for the user input.'
36
+ 10. **Comment**: Include comments in the query to explain the logic behind it.
37
+
38
+ ---
39
+
40
+ ## Expected Output Format
41
+
42
+ The SQL query should be returned as a formatted code block:
43
+
44
+ ```sql
45
+ -- Get all completed orders with user details
46
+ -- Comment explaining the logic.
47
+ SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
48
+ FROM orders
49
+ JOIN users ON orders.user_id = users.id
50
+ WHERE orders.status = 'completed'
51
+ ORDER BY orders.created_at DESC;
52
+ ```
53
+
54
+ If the user's request is ambiguous, ask clarifying questions before generating the query.
55
+ """
56
+
57
+ query_output = """
58
+ **The result of query execution:**
59
+ ```sql
60
+ {result}
61
+ ```
62
+ """