Kshitijk20 commited on
Commit
99bbd9b
·
1 Parent(s): da72b93

code push

Browse files
.gitignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .venv/
7
+ env/
8
+ venv/
9
+ ENV/
10
+ build/
11
+ dist/
12
+ *.egg-info/
13
+
14
+ # VS Code
15
+ .vscode/
16
+
17
+ # Environment variables
18
+ .env
19
+
20
+ # OS files
21
+ .DS_Store
22
+ Thumbs.db
23
+ sql_agent.txt
24
+
25
+ # Logs
26
+ *.log
27
+
28
+ # Docker
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.13
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim-bookworm
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies for aws
6
+ # RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ # libpq-dev \
8
+ # gcc \
9
+ # g++ \
10
+ # nginx \
11
+ # && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements file and install dependencies
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy the entire application
18
+ COPY . /app
19
+
20
+ # Nginx configuration
21
+ # RUN echo " \
22
+ # server { \
23
+ # listen 80; \
24
+ # server_name localhost; \
25
+ # location / { \
26
+ # proxy_pass http://localhost:8501; \
27
+ # proxy_set_header Host \$host; \
28
+ # proxy_set_header X-Real-IP \$remote_addr; \
29
+ # proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for; \
30
+ # proxy_set_header X-Forwarded-Proto \$scheme; \
31
+ # } \
32
+ # }" > /etc/nginx/conf.d/default.conf
33
+
34
+ # Expose the ports
35
+ # EXPOSE 8000 8501 80
36
+ # EXPOSE 8000
37
+
38
+ # Start Nginx and then the backend and frontend
39
+ # CMD service nginx start && uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload & streamlit run app/frontend/Talk2SQL.py --server.address=0.0.0.0 --server.port=8501
40
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
app/api/v1/__init__.py ADDED
File without changes
app/api/v1/auth.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fastapi import APIRouter, HTTPException, status, Depends
3
+ from pydantic import BaseModel
4
+ import sqlite3
5
+ from passlib.context import CryptContext
6
+ from fastapi.security import OAuth2PasswordRequestForm
7
+
8
+ router = APIRouter()
9
+
10
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
11
+ DB_PATH = "users.db" # Adjust path if needed
12
+
13
def get_db():
    """Open a connection to the SQLite user database.

    Returns:
        sqlite3.Connection: Connection to ``DB_PATH`` whose rows are
        ``sqlite3.Row`` objects, so columns can be read by name.

    Raises:
        HTTPException: 500 when the connection cannot be opened or configured.
    """
    try:
        connection = sqlite3.connect(DB_PATH)
        connection.row_factory = sqlite3.Row
    except Exception as exc:
        message = f"Database connection error: {str(exc)}"
        raise HTTPException(status_code=500, detail=message)
    return connection
20
+
21
class UserCreate(BaseModel):
    """Request body accepted by the signup endpoint."""

    # Login name; looked up in the users table before insertion.
    username: str
    # Plain-text password as submitted; create_user hashes it before storing.
    password: str
24
+
25
class UserOut(BaseModel):
    """Response body returned by the signup endpoint."""

    # Row id of the user; optional and defaults to None.
    id: int | None = None
    # Login name of the user.
    username: str
28
+
29
def get_user_by_username(conn, username: str):
    """Fetch the user row whose username matches exactly.

    Args:
        conn: Open sqlite3 connection to the user database.
        username: Username to look up (bound as a query parameter).

    Returns:
        The first matching row (shaped by ``conn.row_factory``), or ``None``
        when no row matches or the query fails with a database error such as
        a missing ``users`` table.
    """
    try:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE username = ?", (username,))
        return cur.fetchone()
    except sqlite3.Error:
        # Only database failures mean "no user"; programming errors
        # (e.g. a bad ``conn`` object) must surface instead of being masked.
        return None
36
+
37
def create_user(conn, username: str, password: str):
    """Insert a new user with a hashed password.

    Args:
        conn: Open sqlite3 connection to the user database.
        username: Username to store.
        password: Plain-text password; hashed with ``pwd_context`` before
            it is written.

    Returns:
        The ``lastrowid`` of the inserted row, or ``None`` when the insert
        fails with a database error (including a uniqueness violation).

    Raises:
        Exception: Anything raised by password hashing propagates to the
            caller, so it is reported as a server error rather than as a
            failed insert.
    """
    hashed_password = pwd_context.hash(password)
    try:
        cur = conn.cursor()
        cur.execute(
            "INSERT INTO users (username, password) VALUES (?, ?)",
            (username, hashed_password),
        )
        conn.commit()
        return cur.lastrowid
    except sqlite3.Error:
        # sqlite3.IntegrityError is a subclass of sqlite3.Error, so duplicate
        # usernames and other database failures share this path.
        return None
48
+
49
@router.post("/signup", response_model=UserOut)
def signup(user: UserCreate):
    """Register a new user.

    Args:
        user: Username and plain-text password from the request body.

    Returns:
        dict: ``{"id": <new row id>, "username": <username>}``.

    Raises:
        HTTPException: 400 if the username already exists or the user cannot
            be created; 500 for any other failure.
    """
    conn = None
    try:
        conn = get_db()
        if get_user_by_username(conn, user.username):
            raise HTTPException(status_code=400, detail="Username already exists")
        user_id = create_user(conn, user.username, user.password)
        if not user_id:
            raise HTTPException(status_code=400, detail="Could not create user")
        return {"id": user_id, "username": user.username}
    except HTTPException:
        # Pass deliberate HTTP errors through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Signup error: {str(e)}")
    finally:
        # Always release the connection so requests do not leak handles.
        if conn is not None:
            conn.close()
63
+
64
@router.post("/login")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Verify a username/password pair submitted as an OAuth2 password form.

    Args:
        form_data: Form fields carrying ``username`` and ``password``.

    Returns:
        dict: ``{"id": <row id or None>, "username": <username>}``.

    Raises:
        HTTPException: 401 when the user is unknown or the password does not
            match; 500 when verification or anything else fails.
    """
    conn = None
    try:
        conn = get_db()
        user = get_user_by_username(conn, form_data.username)
        if not user:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        try:
            valid = pwd_context.verify(form_data.password, user["password"])
        except Exception:
            # e.g. the stored value is not a hash pwd_context recognises.
            raise HTTPException(status_code=500, detail="Password verification error")
        if not valid:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        # Defensive: handle missing id column
        user_id = user["id"] if "id" in user.keys() else None
        return {"id": user_id, "username": user["username"]}
    except HTTPException:
        # Pass deliberate HTTP errors through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Login error: {str(e)}")
    finally:
        # Always release the connection so requests do not leak handles.
        if conn is not None:
            conn.close()
app/api/v1/endpoints/__init__.py ADDED
File without changes
app/api/v1/endpoints/database_connection.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import APIRouter, HTTPException
2
+ # from app.models import DatabaseConnectionRequest
3
+ # # from app.services.sql_agent import setup_database_connection
4
+ # from sqlalchemy.exc import OperationalError, DatabaseError
5
+ # from urllib.parse import urlparse
6
+ # from app.services.sql_agent_instance import sql_agent
7
+
8
+ # router = APIRouter()
9
+
10
+ # @router.post("/setup-connection")
11
+ # async def setup_connection(request: DatabaseConnectionRequest):
12
+ # try:
13
+ # # Basic validation of connection string format
14
+ # parsed = urlparse(request.connection_string)
15
+ # if not all([parsed.scheme, parsed.netloc]):
16
+ # raise HTTPException(
17
+ # status_code=400,
18
+ # detail="Invalid connection string format. Expected format: dialect+driver://username:password@host:port/database"
19
+ # )
20
+
21
+ # sql_agent.setup_database_connection(request.connection_string)
22
+ # return {"message": "Database connection established successfully!"}
23
+ # except OperationalError as e:
24
+ # raise HTTPException(
25
+ # status_code=503,
26
+ # detail=f"Failed to connect to database: Connection refused or invalid credentials. Details: {str(e)}"
27
+ # )
28
+ # except DatabaseError as e:
29
+ # raise HTTPException(
30
+ # status_code=500,
31
+ # detail=f"Database error occurred: {str(e)}"
32
+ # )
33
+ # except ValueError as e:
34
+ # raise HTTPException(
35
+ # status_code=400,
36
+ # detail=f"Invalid configuration: {str(e)}"
37
+ # )
38
+ # except Exception as e:
39
+ # raise HTTPException(
40
+ # status_code=500,
41
+ # detail=f"Unexpected error occurred while setting up database connection: {str(e)}"
42
+ # )
43
+
44
+ # app/api/v1/endpoints/database_connection.py
45
+ from fastapi import APIRouter, HTTPException
46
+ from pydantic import BaseModel
47
+ from app.services.sql_agent_instance import sql_agent
48
+ from sqlalchemy.exc import OperationalError, DatabaseError
49
+ from urllib.parse import urlparse
50
+
51
+ router = APIRouter()
52
+
53
class DatabaseConnectionRequest(BaseModel):
    """Request body for the /setup-connection endpoint."""

    # Database URL, expected as dialect+driver://username:password@host:port/database.
    connection_string: str
55
+
56
@router.post("/setup-connection")
async def setup_connection(request: DatabaseConnectionRequest):
    """Validate a connection string and connect the shared SQL agent to it.

    Args:
        request: Body carrying the database ``connection_string``.

    Returns:
        dict: A success message once the agent has connected.

    Raises:
        HTTPException: 400 for a malformed connection string or invalid
            configuration, 503 when the database cannot be reached, 500 for
            other database or unexpected errors.
    """
    try:
        # Basic validation of connection string format
        parsed = urlparse(request.connection_string)
        if not all([parsed.scheme, parsed.netloc]):
            raise HTTPException(
                status_code=400,
                detail="Invalid connection string format. Expected format: dialect+driver://username:password@host:port/database"
            )

        sql_agent.setup_database_connection(request.connection_string)
        return {"message": "Database connection established successfully!"}
    except HTTPException:
        # Without this, the 400 raised above would be caught by the generic
        # handler below and reported as a 500.
        raise
    except OperationalError as e:
        raise HTTPException(
            status_code=503,
            detail=f"Failed to connect to database: Connection refused or invalid credentials. Details: {str(e)}"
        )
    except DatabaseError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Database error occurred: {str(e)}"
        )
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid configuration: {str(e)}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error occurred while setting up database connection: {str(e)}"
        )
app/api/v1/endpoints/sql_query.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import APIRouter, HTTPException
2
+ # from app.models import SQLQueryRequest, SQLQueryResponse
3
+ # from app.services.sql_agent import execute_query
4
+
5
+ # router = APIRouter()
6
+
7
+ # @router.post("/query", response_model=SQLQueryResponse)
8
+ # async def query_database(request: SQLQueryRequest):
9
+ # try:
10
+ # result = execute_query(request.query)
11
+ # return SQLQueryResponse(result=result)
12
+ # except ValueError as e:
13
+ # raise HTTPException(status_code=400, detail=str(e))
14
+ # except Exception as e:
15
+ # raise HTTPException(status_code=500, detail=str(e))
16
+
17
+ # app/api/v1/endpoints/sql_query.py
18
+ from fastapi import APIRouter, HTTPException
19
+ from pydantic import BaseModel
20
+ from app.services.sql_agent_instance import sql_agent
21
+
22
+ router = APIRouter()
23
+
24
class SQLQueryRequest(BaseModel):
    """Request body for the /query endpoint."""

    # Text passed unchanged to sql_agent.execute_query.
    query: str
26
+
27
class SQLQueryResponse(BaseModel):
    """Response body for the /query endpoint."""

    # Value returned by sql_agent.execute_query.
    result: str
29
+
30
@router.post("/query", response_model=SQLQueryResponse)
async def query_database(request: SQLQueryRequest):
    """Run the submitted query through the shared SQL agent.

    Args:
        request: Body carrying the ``query`` text.

    Returns:
        SQLQueryResponse: Wrapper around the agent's result.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised, 500 for any
            other exception; the detail is the exception text.
    """
    try:
        answer = sql_agent.execute_query(request.query)
        return SQLQueryResponse(result=answer)
    except Exception as err:
        # ValueError maps to a client error; everything else is a server error.
        code = 400 if isinstance(err, ValueError) else 500
        raise HTTPException(status_code=code, detail=str(err))
app/core/config.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseSettings, validator
2
+ from typing import Optional, List
3
+ import os
4
+ from pathlib import Path
5
+ from functools import lru_cache
6
+ import logging
7
+ from dotenv import load_dotenv
8
+
9
+ # Load .env file if it exists
10
+ load_dotenv()
11
+
12
class Settings(BaseSettings):
    """Application settings, read from the environment and an optional .env file.

    Field defaults are taken from ``os.getenv`` when the class is defined.
    The ``DATABASE_*`` parts are cleared when ``USE_DB`` is false and fall
    back to local development defaults when ``USE_DB`` is true but a value
    is missing.
    """

    # Project settings
    PROJECT_NAME: str = "Talk2SQL"
    VERSION: str = "1.0.0"
    API_V1_STR: str = "/api/v1"

    # Database settings
    USE_DB: bool = os.getenv("USE_DB", "false").lower() == "true"
    DATABASE_URL: Optional[str] = os.getenv("DATABASE_URL")
    DATABASE_HOST: Optional[str] = os.getenv("DATABASE_HOST")
    DATABASE_PORT: Optional[str] = os.getenv("DATABASE_PORT")
    DATABASE_USER: Optional[str] = os.getenv("DATABASE_USER")
    DATABASE_PASSWORD: Optional[str] = os.getenv("DATABASE_PASSWORD")
    DATABASE_NAME: Optional[str] = os.getenv("DATABASE_NAME")

    # Session management
    SESSION_EXPIRE_MINUTES: int = int(os.getenv("SESSION_EXPIRE_MINUTES", "60"))
    SESSION_SECRET_KEY: str = os.getenv("SESSION_SECRET_KEY")

    # CORS settings
    BACKEND_CORS_ORIGINS: List[str] = os.getenv("BACKEND_CORS_ORIGINS", "http://localhost:3000").split(",")

    # Logging configuration
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # NOTE(review): without ``always=True`` this validator only runs when a
    # value is supplied, so an unset key is presumably not rejected — confirm
    # whether startup should fail when SESSION_SECRET_KEY is missing.
    @validator("SESSION_SECRET_KEY", pre=True)
    def validate_session_secret_key(cls, v: Optional[str]) -> str:
        """Reject an empty session secret."""
        if not v:
            raise ValueError("SESSION_SECRET_KEY must be set in production environment")
        return v

    @validator(
        "DATABASE_HOST",
        "DATABASE_PORT",
        "DATABASE_USER",
        "DATABASE_PASSWORD",
        "DATABASE_NAME",
        pre=True,
        always=True,
    )
    def validate_database_settings(cls, v: Optional[str], values, field) -> Optional[str]:
        """Normalise one database connection part.

        Returns ``None`` when the database is disabled, the supplied value
        when present, and otherwise the development default for the field.
        ``always=True`` makes the validator run for unset values too, so the
        defaults are actually applied.
        """
        # Already-validated fields arrive in ``values``; USE_DB is declared
        # before the DATABASE_* fields, so it is available here.
        if not values.get("USE_DB"):
            return None
        if not v:
            # Default values for development when database is enabled
            defaults = {
                "DATABASE_HOST": "localhost",
                "DATABASE_PORT": "5432",
                "DATABASE_USER": "postgres",
                "DATABASE_PASSWORD": "postgres",
                "DATABASE_NAME": "postgres",
            }
            # ``field`` is the field object; its name is the lookup key.
            return defaults.get(field.name)
        return v

    class Config:
        case_sensitive = True
        env_file = ".env"

    def get_database_url(self) -> Optional[str]:
        """Generate database URL if not explicitly set and database is enabled.

        Returns:
            ``None`` when the database is disabled or any connection part is
            missing; ``DATABASE_URL`` when it is set; otherwise a
            ``postgresql://`` URL assembled from the individual parts.
        """
        if not self.USE_DB:
            return None
        if self.DATABASE_URL:
            return self.DATABASE_URL
        if not all([self.DATABASE_HOST, self.DATABASE_PORT, self.DATABASE_USER,
                    self.DATABASE_PASSWORD, self.DATABASE_NAME]):
            return None
        return f"postgresql://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
74
+
75
# Configure logging
log_level = os.getenv("LOG_LEVEL", "INFO")
# Resolve the level name defensively: an unknown or non-level name falls back
# to INFO instead of raising at import time.
_resolved_level = getattr(logging, log_level.upper(), None)
if not isinstance(_resolved_level, int):
    _resolved_level = logging.INFO
logging.basicConfig(
    level=_resolved_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
82
+
83
@lru_cache()
def get_settings() -> Settings:
    """Build the application settings once and reuse them.

    ``lru_cache`` memoises this zero-argument call, so every caller receives
    the same ``Settings`` instance.
    """
    return Settings()
87
+
88
+ settings = get_settings()
app/frontend/Talk2SQL.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import sqlite3
3
+ import requests
4
+ import hashlib
5
+ import pandas as pd
6
+ # Initialize SQLite database
7
def init_db():
    """Create the ``users`` table in ``users.db`` if it does not exist yet.

    The table has two TEXT columns: ``username`` (primary key) and
    ``password``. Calling this repeatedly is harmless.
    """
    conn = sqlite3.connect('users.db')
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE IF NOT EXISTS users
                     (username TEXT PRIMARY KEY, password TEXT)''')
        conn.commit()
    finally:
        # Close the connection even when table creation fails.
        conn.close()
14
+
15
+ # Hash password
16
def hash_password(password: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``password``."""
    # NOTE(review): unsalted SHA-256 is weak for password storage; switching
    # schemes would invalidate the hashes already stored in users.db.
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()
18
+
19
+ # User authentication
20
def authenticate_user(username: str, password: str) -> bool:
    """Check a username/password pair against ``users.db``.

    Args:
        username: Username to look up.
        password: Plain-text password; compared via ``hash_password``.

    Returns:
        ``True`` when the user exists and the stored hash matches,
        ``False`` otherwise (always a real ``bool``).
    """
    conn = sqlite3.connect('users.db')
    try:
        c = conn.cursor()
        c.execute('SELECT password FROM users WHERE username=?', (username,))
        result = c.fetchone()
    finally:
        # Release the connection even when the query fails.
        conn.close()
    return result is not None and result[0] == hash_password(password)
27
+
28
+ # User registration
29
def register_user(username: str, password: str) -> bool:
    """Insert a new user into ``users.db``.

    Args:
        username: Username to register (primary key of the table).
        password: Plain-text password; stored as ``hash_password(password)``.

    Returns:
        ``True`` on success, ``False`` when the insert violates an integrity
        constraint (for example the username already exists).
    """
    conn = sqlite3.connect('users.db')
    try:
        c = conn.cursor()
        c.execute('INSERT INTO users VALUES (?, ?)', (username, hash_password(password)))
        conn.commit()
        return True
    except sqlite3.IntegrityError:
        return False
    finally:
        # Close on every path, including the IntegrityError path.
        conn.close()
39
+
40
+ # Initialize session state
41
def init_session_state():
    """Seed missing Streamlit session keys with their initial values."""
    initial_values = {
        'logged_in': False,
        'current_page': 'login',
        'username': None,
        'db_connected': False,
    }
    for key, value in initial_values.items():
        # Keys that already exist keep their current value.
        if key not in st.session_state:
            st.session_state[key] = value
50
+
51
+ # Login/Signup page
52
def login_page():
    """Render the combined Login / Sign Up page.

    A successful login stores the username in session state, switches the
    current page to ``db_connection`` and reruns the script. Sign-up
    validates the password confirmation and registers the user.
    """
    st.set_page_config(page_title="Talk2SQL👨🏼‍💻🛢", layout="wide")
    st.header('Talk2SQL👨🏼‍💻🛢')
    st.title('Login / Sign Up')

    login_tab, signup_tab = st.tabs(['Login', 'Sign Up'])

    with login_tab, st.form('login_form'):
        entered_username = st.text_input('Username')
        entered_password = st.text_input('Password', type='password')

        if st.form_submit_button('Login'):
            if not authenticate_user(entered_username, entered_password):
                st.error('Invalid username or password')
            else:
                st.session_state.logged_in = True
                st.session_state.username = entered_username
                st.session_state.current_page = 'db_connection'
                st.rerun()

    with signup_tab, st.form('signup_form'):
        chosen_username = st.text_input('Username')
        chosen_password = st.text_input('Password', type='password')
        repeated_password = st.text_input('Confirm Password', type='password')

        if st.form_submit_button('Sign Up'):
            if chosen_password != repeated_password:
                st.error('Passwords do not match')
            elif register_user(chosen_username, chosen_password):
                st.success('Registration successful! Please login.')
            else:
                st.error('Username already exists')
88
+
89
+ # Database connection page
90
def db_connection_page():
    """Render the database connection page.

    The sidebar shows a sample connection string, a sample table, sample
    questions and a logout button. The main area lets the user pick a
    database type and submit a connection string to the backend
    ``/api/v1/setup-connection`` endpoint; on success the app switches to
    the chat page.
    """
    st.set_page_config(page_title="Talk2SQL👨🏼‍💻🛢", layout="wide")
    st.header('Talk2SQL👨🏼‍💻🛢')
    st.title('Database Connection')

    # Sidebar content
    with st.sidebar:
        st.header("Sample Data")

        # Sample connection string. Use placeholders only: real hosts and
        # credentials must never be shipped in the UI.
        st.subheader("Sample Connection String")
        st.code("mysql+pymysql://user:password@host:3306/database")

        st.subheader("Sample Table")
        sample_data = pd.DataFrame({
            "id": [1, 2, 3, 4],
            "first_name": ["John", "Jane", "Tom", "Jerry"],
            "last_name": ["Doe", "Doe", "Smith", "Jones"],
            "email": ["johnD@abc.com", "JaneD@abc.com", "toms@abc.com", "Jerry@abc.com"],
            "hire_date": ["2020-01-01", "2020-05-01", "2020-03-01", "2020-02-01"],
            "salary": [50000, 60000, 70000, 80000]
        })
        st.dataframe(sample_data)

        # Sample questions
        st.subheader("Sample Questions")
        questions = [
            "What is the email of John?",
            "What is the lastname of Tom?",
            "Hiredate of the Jerry?"
        ]
        for q in questions:
            st.markdown(f"- {q}")

        # Logout button
        st.divider()
        if st.button("Logout", type="primary"):
            logout()

    # Main content
    db_options = ["MySQL", "PostgreSQL"]
    db_type = st.selectbox("Select Database Type", db_options)
    placeholder_text = ""
    if db_type == "PostgreSQL":
        placeholder_text = "postgresql://user:password@host:port/database"
    elif db_type == "MySQL":
        placeholder_text = "mysql+pymysql://user:password@host:port/database"

    with st.form('connection_form'):
        connection_string = st.text_input('Connection String', placeholder=placeholder_text, disabled=not db_type)
        submit = st.form_submit_button('Connect')

        if submit and connection_string:
            try:
                # A timeout keeps the page from hanging forever when the
                # backend is unreachable or stuck.
                response = requests.post(
                    'http://localhost:8000/api/v1/setup-connection',
                    json={'connection_string': connection_string},
                    timeout=30
                )
                if response.status_code == 200:
                    st.success('Database connected successfully!')
                    st.session_state.db_connected = True
                    st.session_state.current_page = 'chat'
                    st.rerun()
                else:
                    st.error(f'Connection failed: {response.text}')
            except requests.RequestException as e:
                st.error(f'Error connecting to backend: {str(e)}')
158
+
159
+ # Chat interface page
160
def chat_page():
    """Render the chat page and forward queries to the backend.

    Chat history is kept in ``st.session_state.chat_history`` as a list of
    ``{"role", "content"}`` dicts. Each new query is posted to the backend
    ``/api/v1/query`` endpoint; on success the answer is appended and the
    script reruns so the history is redrawn. "End Chat" returns to the
    database connection page.
    """
    st.set_page_config(page_title="Talk2SQL👨🏼‍💻🛢", layout="wide")
    st.title('Chat Interface')

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    query = st.chat_input("Enter your query")

    if query:
        st.session_state.chat_history.append({"role": "user", "content": query})

        try:
            # Short connect timeout, generous read timeout (agent answers can
            # be slow); without one a stalled backend blocks the page forever.
            response = requests.post(
                'http://localhost:8000/api/v1/query',
                json={'query': query},
                timeout=(5, 120)
            )

            if response.status_code == 200:
                result = response.json().get("result", "No result")
                st.session_state.chat_history.append({"role": "assistant", "content": result})
                st.rerun()
            else:
                st.error(f'Query failed: {response.text}')
        except requests.RequestException as e:
            st.error(f'Error connecting to backend: {str(e)}')

    if st.button("End Chat"):
        st.session_state.current_page = 'db_connection'
        st.rerun()
194
+
195
+ # Main app
196
def main():
    """Entry point: prepare storage and session state, then route to a page."""
    init_db()
    init_session_state()

    state = st.session_state
    if not state.logged_in:
        login_page()
        return

    page = state.current_page
    if page == 'db_connection':
        db_connection_page()
    elif page == 'chat':
        if not state.db_connected:
            # The chat page needs a connected database; send the user back.
            st.error('Database not connected. Redirecting to Database Connection page')
            state.current_page = 'db_connection'
            st.rerun()
        chat_page()
210
+
211
def logout():
    """Reset the login-related session state and rerun the script."""
    reset_values = (
        ('logged_in', False),
        ('username', None),
        ('current_page', 'login'),
        ('db_connected', False),
    )
    for key, value in reset_values:
        st.session_state[key] = value
    st.rerun()
217
+
218
+ if __name__ == '__main__':
219
+ main()
app/frontend/users.db ADDED
Binary file (12.3 kB). View file
 
app/logging_config.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.handlers
3
+ import sys
4
+ from pathlib import Path
5
+ from app.core.config import settings
6
+
7
# Directory that holds the log files; created on import when missing.
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)

# Formatter with timestamp, logger name, level and source location.
DETAILED_FORMATTER = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Shorter formatter: time of day, level and message only.
CONSOLE_FORMATTER = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%H:%M:%S",
)
21
+
22
def _rotating_file_handler(filename, level):
    """Build a 10 MB, 5-backup, UTF-8 rotating file handler with the detailed format."""
    handler = logging.handlers.RotatingFileHandler(
        filename=filename,
        maxBytes=10 * 1024 * 1024,  # 10MB
        backupCount=5,
        encoding="utf-8",
    )
    handler.setFormatter(DETAILED_FORMATTER)
    handler.setLevel(level)
    return handler


def setup_logging():
    """Configure logging for the application.

    Replaces the root logger's handlers with a console handler writing to
    stdout, a rotating ``talk2sql.log`` at the configured level and a
    rotating ``error.log`` for ERROR and above, then installs an excepthook
    that logs uncaught exceptions.

    Returns:
        logging.Logger: The configured root logger.
    """
    # Get root logger
    root_logger = logging.getLogger()

    # Remove AND close existing handlers so repeated calls do not leave
    # earlier log files open.
    for existing in list(root_logger.handlers):
        root_logger.removeHandler(existing)
        existing.close()

    # Set log level from settings; unknown names fall back to INFO
    log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
    root_logger.setLevel(log_level)

    # Console Handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(CONSOLE_FORMATTER)
    console_handler.setLevel(log_level)
    root_logger.addHandler(console_handler)

    # File Handler with rotation
    root_logger.addHandler(_rotating_file_handler(logs_dir / "talk2sql.log", log_level))

    # Separate error log file for ERROR and CRITICAL
    root_logger.addHandler(_rotating_file_handler(logs_dir / "error.log", logging.ERROR))

    # Capture unhandled exceptions
    def handle_exception(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            # Call the default handler for KeyboardInterrupt
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        root_logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))

    sys.excepthook = handle_exception

    # Log initial configuration
    root_logger.info("Logging configured with level: %s", settings.LOG_LEVEL)
    return root_logger
75
+
76
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under ``name``, e.g. a module's ``__name__``."""
    module_logger = logging.getLogger(name)
    return module_logger
app/main.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from app.api.v1.endpoints import sql_query, database_connection
4
+ from app.api.v1 import auth
5
+
6
app = FastAPI()

# Configure CORS middleware.
# Credentialed requests cannot be combined with wildcard origins, methods or
# headers, so credentials are disabled here. To allow cookies or auth headers
# cross-origin, list the allowed origins explicitly and re-enable
# allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],  # Allow all origins
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

# Mount the API routers under the versioned prefix.
app.include_router(database_connection.router, prefix="/api/v1")
app.include_router(sql_query.router, prefix="/api/v1")
app.include_router(auth.router, prefix="/api/v1/auth")
app/models/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, validator
2
+ from urllib.parse import urlparse
3
+ import re
4
+
5
class DatabaseConnectionRequest(BaseModel):
    """Request model carrying a database connection URL.

    The URL must look like ``dialect+driver://username:password@host:port/database``.
    """

    connection_string: str  # e.g., "mysql+pymysql://user:password@host:port/database"

    @validator('connection_string')
    def validate_connection_string(cls, v):
        """Validate the overall shape and the components of the URL.

        Returns:
            The unchanged connection string when it is valid.

        Raises:
            ValueError: When the string is empty, does not match the
                expected pattern, or lacks a scheme, host, database name or
                valid port.
        """
        if not v:
            raise ValueError("Connection string cannot be empty")

        # Basic format check. Dialect and driver names may contain digits
        # and underscores (e.g. "postgresql+psycopg2").
        if not re.match(r'^[a-zA-Z][a-zA-Z0-9_]*(\+[a-zA-Z0-9_]+)?://[^/]+/.+$', v):
            raise ValueError("Invalid connection string format - must follow pattern: dialect+driver://username:password@host:port/database")

        # Parse URL to validate components. urlparse and the ``port``
        # property raise ValueError for malformed hosts or ports.
        try:
            parsed = urlparse(v)
            port = parsed.port
        except ValueError as e:
            raise ValueError(f"Invalid connection string: {str(e)}")

        # Validate scheme (database type)
        if not parsed.scheme:
            raise ValueError("Database type must be specified")

        # Validate that we have a hostname
        if not parsed.hostname:
            raise ValueError("Host must be specified")

        # Validate that we have a database name
        if not parsed.path or parsed.path == '/':
            raise ValueError("Database name must be specified")

        # Validate port if present. Compare against None so that port 0 is
        # rejected instead of being skipped as falsy.
        if port is not None and not 1 <= port <= 65535:
            raise ValueError("Port number must be between 1 and 65535")

        return v
    # Alternatively, you can break it down into individual fields:
    # db_type: str  # e.g., "mysql", "postgres"
    # host: str
    # port: int
    # database: str
    # username: str
    # password: str
48
class SQLQueryRequest(BaseModel):
    """Request model carrying a single query string."""

    # The query text supplied by the client.
    query: str
50
+
51
class SQLQueryResponse(BaseModel):
    """Response model carrying a query result as a string."""

    # The result text returned to the client.
    result: str
app/requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ langchain
4
+ langgraph
5
+ langchain-groq
6
+ pydantic
7
+ sqlalchemy
8
+ pymysql
9
+ langchain-community
10
+ langchain-core
11
+ streamlit
12
+ pandas
13
+ IPython
14
+ ipykernel
15
+ passlib
16
+ python-multipart
17
+ bcrypt==4.3.0
18
+ psycopg2-binary
app/services/__init__.py ADDED
File without changes
app/services/sql_agent.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from langchain_community.utilities import SQLDatabase
2
+ # from langchain_groq import ChatGroq
3
+ # from langgraph.graph import StateGraph, END, START
4
+ # from langchain_core.messages import AIMessage, ToolMessage, AnyMessage, HumanMessage
5
+ # from langgraph.graph.message import AnyMessage, add_messages
6
+ # from langchain_core.tools import tool
7
+ # from typing import Annotated, Literal, TypedDict, Any
8
+ # from pydantic import BaseModel, Field
9
+ # from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
10
+ # from langgraph.prebuilt import ToolNode
11
+ # from langchain_core.prompts import ChatPromptTemplate
12
+ # from langchain_community.agent_toolkits import SQLDatabaseToolkit
13
+ # from dotenv import load_dotenv
14
+ # import os
15
+ # from IPython.display import display
16
+ # import PIL
17
+ # from langgraph.errors import GraphRecursionError
18
+ # import os
19
+ # import io
20
+ # from typing import Annotated, Any, TypedDict
21
+
22
+ # from IPython.display import Image, display
23
+ # from langchain_core.runnables.graph import MermaidDrawMethod
24
+ # from typing import Optional
25
+
26
+ # class SQLAgent:
27
+ # def __init__(self, model="llama3-70b-8192"):
28
+ # load_dotenv()
29
+ # # Initialize instance variables
30
+ # self.db = None
31
+ # self.toolkit = None
32
+ # self.tools = None
33
+ # self.list_tables_tool = None
34
+ # self.sql_db_query = None
35
+ # self.get_schema_tool = None
36
+ # self.app = None
37
+
38
+ # # Setting up LLM
39
+ # self.llm = ChatGroq(model=model)
40
+
41
+ # # Register the tool method
42
+ # self.query_to_database = self._create_query_tool()
43
+
44
+ # def _create_query_tool(self):
45
+ # """Create the query tool bound to this instance"""
46
+ # print("creating _create_query_tool")
47
+ # @tool
48
+ # def query_to_database(query: str) -> str:
49
+ # """
50
+ # Execute a SQL query against the database and return the result.
51
+ # If the query is invalid or returns no result, an error message will be returned.
52
+ # In case of an error, the user is advised to rewrite the query and try again.
53
+ # """
54
+ # if self.db is None:
55
+ # return "Error: Database connection not established. Please set up the connection first."
56
+ # result = self.db.run_no_throw(query)
57
+ # if not result:
58
+ # return "Error: Query failed. Please rewrite your query and try again."
59
+ # return result
60
+
61
+ # return query_to_database
62
+
63
+ # def setup_database_connection(self, connection_string: str):
64
+ # """Set up database connection and initialize tools"""
65
+ # try:
66
+ # # Initialize database connection
67
+ # self.db = SQLDatabase.from_uri(connection_string)
68
+ # print("Database connection successful!")
69
+
70
+ # try:
71
+ # # Initialize toolkit and tools
72
+ # self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
73
+ # self.tools = self.toolkit.get_tools()
74
+ # for tool in self.tools:
75
+ # print(f"Initialized tool: {tool.name}")
76
+
77
+ # # Create instances of the tools
78
+ # self.list_tables_tool = next((tool for tool in self.tools if tool.name == "sql_db_list_tables"), None)
79
+ # self.sql_db_query = next((tool for tool in self.tools if tool.name == "sql_db_query"), None)
80
+ # self.get_schema_tool = next((tool for tool in self.tools if tool.name == "sql_db_schema"), None)
81
+
82
+ # if not all([self.list_tables_tool, self.sql_db_query, self.get_schema_tool]):
83
+ # raise ValueError("Failed to initialize one or more required database tools")
84
+
85
+ # # Initialize workflow and compile it into an app
86
+ # self.initialize_workflow()
87
+
88
+ # return self.db
89
+
90
+ # except Exception as e:
91
+ # print(f"Error initializing tools and workflow: {str(e)}")
92
+ # raise ValueError(f"Failed to initialize database tools: {str(e)}")
93
+
94
+ # except ImportError as e:
95
+ # print(f"Database driver import error: {str(e)}")
96
+ # raise ValueError(f"Missing database driver or invalid database type: {str(e)}")
97
+ # except ValueError as e:
98
+ # print(f"Invalid connection string or configuration: {str(e)}")
99
+ # raise
100
+ # except Exception as e:
101
+ # print(f"Unexpected error during database connection: {str(e)}")
102
+ # raise ValueError(f"Failed to establish database connection: {str(e)}")
103
+
104
+ # def initialize_workflow(self):
105
+ # """Initialize the workflow graph"""
106
+
107
+ # print("Intializing Workflow....")
108
+ # # Binding tools with LLM
109
+ # llm_to_get_schema = self.llm.bind_tools([self.get_schema_tool]) if self.get_schema_tool else None
110
+ # llm_with_tools = self.llm.bind_tools([self.query_to_database])
111
+
112
+ # class State(TypedDict):
113
+ # messages: Annotated[list[AnyMessage], add_messages]
114
+
115
+ # class SubmitFinalAnswer(BaseModel):
116
+ # final_answer: str = Field(..., description="The final answer to the user")
117
+
118
+ # llm_with_final_answer = self.llm.bind_tools([SubmitFinalAnswer])
119
+
120
+ # def handle_tool_error(state: State):
121
+ # error = state.get("error")
122
+ # tool_calls = state["messages"][-1].tool_calls
123
+ # return {"messages": [ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"],) for tc in tool_calls]}
124
+
125
+ # def create_node_from_tool_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
126
+ # return ToolNode(tools).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
127
+
128
+ # list_tables = create_node_from_tool_with_fallback([self.list_tables_tool]) if self.list_tables_tool else None
129
+ # get_schema = create_node_from_tool_with_fallback([self.get_schema_tool]) if self.get_schema_tool else None
130
+ # query_database = create_node_from_tool_with_fallback([self.query_to_database])
131
+
132
+ # query_check_system = """You are a SQL expert. Carefully review the SQL query for common mistakes, including:
133
+
134
+ # Issues with NULL handling (e.g., NOT IN with NULLs)
135
+ # Improper use of UNION instead of UNION ALL
136
+ # Incorrect use of BETWEEN for exclusive ranges
137
+ # Data type mismatches or incorrect casting
138
+ # Quoting identifiers improperly
139
+ # Incorrect number of arguments in functions
140
+ # Errors in JOIN conditions
141
+
142
+ # If you find any mistakes, rewrite the query to fix them. If it's correct, reproduce it as is."""
143
+ # query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
144
+ # check_generated_query = query_check_prompt | llm_with_tools
145
+
146
+ # def check_the_given_query(state: State):
147
+ # return {"messages": [check_generated_query.invoke({"messages": [state["messages"][-1]]})]}
148
+
149
+ # query_gen_system_prompt = """You are a SQL expert with a strong attention to detail.Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
150
+
151
+ # 1. DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
152
+
153
+ # When generating the query:
154
+
155
+ # 2. Output the SQL query that answers the input question without a tool call.
156
+
157
+ # 3. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
158
+
159
+ # 4. You can order the results by a relevant column to return the most interesting examples in the database.
160
+
161
+ # 5. Never query for all the columns from a specific table, only ask for the relevant columns given the question.
162
+
163
+ # 6. If you get an error while executing a query, rewrite the query and try again.
164
+
165
+ # 7. If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
166
+
167
+ # 8. NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
168
+
169
+ # 9. If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
170
+
171
+ # 10. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""
172
+ # query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system_prompt), ("placeholder", "{messages}")])
173
+ # query_generator = query_gen_prompt | llm_with_final_answer
174
+
175
+ # def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
176
+ # return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
177
+
178
+ # def generation_query(state: State):
179
+ # message = query_generator.invoke(state)
180
+ # tool_messages = []
181
+ # if message.tool_calls:
182
+ # for tc in message.tool_calls:
183
+ # if tc["name"] != "SubmitFinalAnswer":
184
+ # tool_messages.append(
185
+ # ToolMessage(
186
+ # content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
187
+ # tool_call_id=tc["id"],
188
+ # )
189
+ # )
190
+ # else:
191
+ # tool_messages = []
192
+ # return {"messages": [message] + tool_messages}
193
+
194
+ # def should_continue(state: State):
195
+ # messages = state["messages"]
196
+ # last_message = messages[-1]
197
+ # if getattr(last_message, "tool_calls", None):
198
+ # # Check if the tool call is SubmitFinalAnswer
199
+ # if len(last_message.tool_calls) > 0 and last_message.tool_calls[0]["name"] == "SubmitFinalAnswer":
200
+ # return END
201
+ # else:
202
+ # # Wrong tool called, route to error handling (not implemented here)
203
+ # return "query_gen" # Or a dedicated error node
204
+ # elif last_message.content.startswith("Error:"):
205
+ # return "query_gen"
206
+ # else:
207
+ # return "correct_query"
208
+
209
+ # def llm_get_schema(state: State):
210
+ # response = llm_to_get_schema.invoke(state["messages"])
211
+ # return {"messages": [response]}
212
+
213
+ # # Create workflow
214
+ # workflow = StateGraph(State)
215
+ # workflow.add_node("first_tool_call", first_tool_call)
216
+ # workflow.add_node("list_tables_tool", list_tables)
217
+ # workflow.add_node("get_schema_tool", get_schema)
218
+ # workflow.add_node("model_get_schema", llm_get_schema)
219
+ # workflow.add_node("query_gen", generation_query)
220
+ # workflow.add_node("correct_query", check_the_given_query)
221
+ # workflow.add_node("execute_query", query_database)
222
+
223
+ # workflow.add_edge(START, "first_tool_call")
224
+ # workflow.add_edge("first_tool_call", "list_tables_tool")
225
+ # workflow.add_edge("list_tables_tool", "model_get_schema")
226
+ # workflow.add_edge("model_get_schema", "get_schema_tool")
227
+ # workflow.add_edge("get_schema_tool", "query_gen")
228
+ # workflow.add_conditional_edges("query_gen", should_continue, {END: END, "correct_query": "correct_query", "query_gen": "query_gen"})
229
+ # workflow.add_edge("correct_query", "execute_query")
230
+ # workflow.add_edge("execute_query", "query_gen")
231
+
232
+ # # Compile the workflow into an executable app
233
+ # self.app = workflow.compile()
234
+
235
+
236
+ # # # Generate the graph image as bytes
237
+ # # image_bytes = self.app.get_graph().draw_mermaid_png()
238
+
239
+ # # # Convert bytes to an Image object
240
+ # # image = Image.open(io.BytesIO(image_bytes))
241
+
242
+ # # # Save the image to a file
243
+ # # image.save("workflow_graph.png")
244
+ # # print(f"Workflow graph saved")
245
+
246
+ # def is_query_relevant(self, query: str) -> bool:
247
+ # """Check if the query is relevant to the database using the LLM."""
248
+
249
+ # # Retrieve the schema of the relevant tables
250
+ # if self.list_tables_tool:
251
+ # relevant_tables = self.list_tables_tool.invoke("")
252
+ # # print(relevant_tables)
253
+ # table_list= relevant_tables.split(", ")
254
+ # print(table_list)
255
+ # # print(agent.get_schema_tool.invoke(table_list[0]))
256
+ # schema = ""
257
+ # for table in table_list:
258
+ # schema+= self.get_schema_tool.invoke(table)
259
+
260
+ # print(schema)
261
+
262
+ # # if self.get_schema_tool:
263
+ # # schema_response = self.get_schema_tool.invoke({})
264
+ # # table_schema = schema_response.content # Assuming this returns the schema as a string
265
+
266
+ # relevance_check_prompt = (
267
+ # """You are an expert SQL agent which takes user query in Natural language and find out it have releavnce with the given schema or not. Please determine if the following query is related to a database.Here is the schema of the tables present in database:\n{schema}\n\n. If the query related to given schema respond with 'yes'. Here is the query: {query}. Answer with only 'yes' or 'no'."""
268
+ # ).format(schema=relevant_tables, query=query)
269
+
270
+ # response = self.llm.invoke([{"role": "user", "content": relevance_check_prompt}])
271
+
272
+ # # Assuming the LLM returns a simple 'yes' or 'no'
273
+ # return response.content == "yes"
274
+
275
+
276
+ # def execute_query(self, query: str):
277
+ # """Execute a query through the workflow"""
278
+ # if self.db is None:
279
+ # raise ValueError("Database connection not established. Please set up the connection first.")
280
+ # if self.app is None:
281
+ # raise ValueError("Workflow not initialized. Please set up the connection first.")
282
+ # # First, handle simple queries like "list tables" directly
283
+ # query_lower = query.lower()
284
+ # if any(phrase in query_lower for phrase in ["list all the tables", "show tables", "name of tables",
285
+ # "which tables are present", "how many tables", "list all tables"]):
286
+ # if self.list_tables_tool:
287
+ # tables = self.list_tables_tool.invoke("")
288
+ # return f"The tables in the database are: {tables}"
289
+ # else:
290
+ # return "Error: Unable to list tables. The list_tables_tool is not initialized."
291
+
292
+ # # Check if the query is relevant to the database
293
+ # if not self.is_query_relevant(query):
294
+ # print("Not relevent to database.")
295
+ # # If not relevant, let the LLM answer the question directly
296
+ # non_relevant_prompt = (
297
+ # """You are an expert SQL agent created by Kshitij Kumrawat. You can only assist with questions related to databases so repond the user with the following example resonse and Do not answer any questions that are not related to databases.:
298
+ # Please ask a question that pertains to database operations, such as querying tables, retrieving data, or understanding the database schema. """
299
+ # )
300
+
301
+ # # Invoke the LLM with the non-relevant prompt
302
+ # response = self.llm.invoke([{"role": "user", "content": non_relevant_prompt}])
303
+ # # print(response.content)
304
+ # return response.content
305
+
306
+ # # If relevant, proceed with the SQL workflow
307
+ # response = self.app.invoke({"messages": [HumanMessage(content=query, role="user")]})
308
+
309
+ # # More robust final answer extraction
310
+ # if (
311
+ # response
312
+ # and response["messages"]
313
+ # and response["messages"][-1].tool_calls
314
+ # and len(response["messages"][-1].tool_calls) > 0
315
+ # and "args" in response["messages"][-1].tool_calls[0]
316
+ # and "final_answer" in response["messages"][-1].tool_calls[0]["args"]
317
+ # ):
318
+ # return response["messages"][-1].tool_calls[0]["args"]["final_answer"]
319
+ # else:
320
+ # return "Error: Could not extract final answer."
321
+
322
+
323
+
324
+ from langchain_community.utilities import SQLDatabase
325
+ from langchain_groq import ChatGroq
326
+ from langgraph.graph import StateGraph, END, START
327
+ from langchain_core.messages import AIMessage, ToolMessage, AnyMessage, HumanMessage
328
+ from langgraph.graph.message import AnyMessage, add_messages
329
+ from langchain_core.tools import tool
330
+ from typing import Annotated, Literal, TypedDict, Any
331
+ from pydantic import BaseModel, Field
332
+ from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
333
+ from langgraph.prebuilt import ToolNode
334
+ from langchain_core.prompts import ChatPromptTemplate
335
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
336
+ from dotenv import load_dotenv
337
+ import os
338
+ from IPython.display import display
339
+ import PIL
340
+ from langgraph.errors import GraphRecursionError
341
+ import os
342
+ import io
343
+ from typing import Annotated, Any, TypedDict
344
+ from langgraph.graph import StateGraph, END, MessagesState
345
+
346
+ from IPython.display import Image, display
347
+ from langchain_core.runnables.graph import MermaidDrawMethod
348
+ from typing import Optional, Dict
349
+
350
+ from langchain_community.utilities import SQLDatabase
351
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
352
+ from langchain_groq import ChatGroq
353
+ from langchain_core.messages import HumanMessage, AIMessage
354
+ from langchain_core.prompts import ChatPromptTemplate
355
+ from langchain_core.pydantic_v1 import BaseModel, Field
356
+ from langgraph.graph import StateGraph, END, MessagesState
357
+ from typing import TypedDict, Annotated, List, Literal, Dict, Any
358
+ from dotenv import load_dotenv
359
+ load_dotenv()
360
+ import os
361
# load_dotenv() has already been called above, so any key from a .env file is
# in os.environ by now. Only re-export the key when it is actually set:
# assigning None into os.environ raises TypeError and would crash the import
# with an unhelpful message when GROQ_API_KEY is not configured.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is not None:
    os.environ["GROQ_API_KEY"] = _groq_api_key
362
+
363
class SQLAgent:
    """Answer natural-language questions about a SQL database.

    The agent wraps a Groq chat model and a LangGraph workflow. Construct it
    once, call :meth:`setup_database_connection` with a database URI, then call
    :meth:`execute_query` for each user question.
    """

    # Questions containing one of these phrases are answered directly from the
    # list-tables tool, without running the LLM workflow.
    _LIST_TABLES_PHRASES = (
        "list all the tables",
        "show tables",
        "name of tables",
        "which tables are present",
        "how many tables",
        "list all tables",
    )

    def __init__(self, model="llama3-70b-8192"):
        """Create the chat model and declare every attribute used later.

        Args:
            model: Name of the Groq chat model to use.
        """
        # Populated by setup_database_connection().
        self.db = None
        self.toolkit = None
        self.tools = None
        self.list_tables_tool = None
        self.query_tool = None
        self.query_checker_tool = None
        self.get_schema_tool = None
        # Earlier attribute name for the query tool; kept so existing readers
        # of `sql_db_query` keep working.
        self.sql_db_query = None
        # Compiled LangGraph workflow, built by initialize_workflow().
        self.app = None

        # Setting up LLM
        self.llm = ChatGroq(model=model, api_key=os.getenv("GROQ_API_KEY"))

        # Register the tool method
        self.query_to_database = self._create_query_tool()

    def _create_query_tool(self):
        """Create the query tool bound to this instance."""
        print("creating _create_query_tool")

        @tool
        def query_to_database(query: str) -> str:
            """
            Execute a SQL query against the database and return the result.
            If the query is invalid or returns no result, an error message will be returned.
            In case of an error, the user is advised to rewrite the query and try again.
            """
            if self.db is None:
                return "Error: Database connection not established. Please set up the connection first."
            result = self.db.run_no_throw(query)
            if not result:
                return "Error: Query failed. Please rewrite your query and try again."
            return result

        return query_to_database

    def setup_database_connection(self, connection_string: str):
        """Connect to the database, load the SQL tools and compile the workflow.

        Args:
            connection_string: Database URI passed to ``SQLDatabase.from_uri``.

        Returns:
            The connected ``SQLDatabase`` instance.

        Raises:
            ValueError: If the connection, the tool lookup or the workflow
                initialisation fails. The original exception is chained.
        """
        try:
            self.db = SQLDatabase.from_uri(connection_string)
            print("Database connection successful!")
        except ImportError as e:
            print(f"Database driver import error: {str(e)}")
            raise ValueError(f"Missing database driver or invalid database type: {str(e)}") from e
        except ValueError as e:
            print(f"Invalid connection string or configuration: {str(e)}")
            raise
        except Exception as e:
            print(f"Unexpected error during database connection: {str(e)}")
            raise ValueError(f"Failed to establish database connection: {str(e)}") from e

        try:
            self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
            self.tools = self.toolkit.get_tools()
            # `db_tool`, not `tool`, so the imported @tool decorator is not shadowed.
            tools_by_name = {}
            for db_tool in self.tools:
                print(f"Initialized tool: {db_tool.name}")
                tools_by_name.setdefault(db_tool.name, db_tool)

            self.list_tables_tool = tools_by_name.get("sql_db_list_tables")
            self.query_tool = tools_by_name.get("sql_db_query")
            self.get_schema_tool = tools_by_name.get("sql_db_schema")
            self.query_checker_tool = tools_by_name.get("sql_db_query_checker")
            self.sql_db_query = self.query_tool

            required = [
                self.list_tables_tool,
                self.query_tool,
                self.get_schema_tool,
                self.query_checker_tool,
            ]
            if not all(required):
                raise ValueError("Failed to initialize one or more required database tools")

            # Initialize workflow and compile it into an app
            self.initialize_workflow()
        except Exception as e:
            print(f"Error initializing tools and workflow: {str(e)}")
            raise ValueError(f"Failed to initialize database tools: {str(e)}") from e

        return self.db

    def initialize_workflow(self):
        """Build and compile the LangGraph workflow into ``self.app``.

        The graph has one supervisor node (``sql_agent``) and six worker nodes.
        Every node writes ``next_tool`` and the shared ``router`` uses it to
        pick the next node, ending once ``task_complete`` is set or the
        supervisor answers ``DONE``.
        """
        # The module rebinds BaseModel/Field to the langchain_core.pydantic_v1
        # shim after importing them from pydantic; import the pydantic ones
        # locally so the structured-output schema does not rely on that shim.
        from pydantic import BaseModel, Field

        print("Intializing Workflow....")

        class SQLAgentState(MessagesState):
            """State for the agent"""
            next_tool : str = ""
            tables_list: str = ""
            schema_of_table: str = ""
            query_gen : str= ""
            check_query: str = ""
            execute_query : str = ""
            task_complete: bool = False
            response_to_user: str= ""
            current_task: str = ""
            query: str = ""  # the user's original question

        class DBQuery(BaseModel):
            query: str = Field(..., description="The SQL query to execute")

        print("Creating a sql agent chain")
        sql_agent_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.

            Based on the current state, decide which tool should be called next:
            1. list_table_tools - List all tables from the database
            2. get_schema - Get the schema of required tables
            3. generate_query - Generate a SQL query
            4. check_query - Check if the query is correct
            5. execute_query - Execute the query
            6. response - Create response for the user

            Current state:
            - Tables listed: {tables_list}
            - Schema retrieved: {schema_of_table}
            - Query generated: {query_gen}
            - Query checked: {check_query}
            - Query executed: {execute_query}
            - Response created: {response_to_user}

            If no tables are listed, respond with 'list_table_tools'.
            If tables are listed but no schema, respond with 'get_schema'.
            If schema exists but no query generated, respond with 'generate_query'.
            If query generated but not checked, respond with 'check_query'.
            If query checked but not executed, respond with 'execute_query'.
            If query executed but no response, respond with 'response'.
            If everything is complete, respond with 'DONE'.

            Respond with ONLY the tool name or 'DONE'.
            """),
            ("human", "{task}")
        ])
        # Built once here instead of on every supervisor hop.
        supervisor_chain = sql_agent_prompt | self.llm

        # (substring looked for in the supervisor's reply, status message), in
        # the order the checks are applied.
        tool_announcements = (
            ("list_table_tools", "📋 SQL Agent: Listing all tables in database."),
            ("get_schema", "📋 SQL Agent: Getting schema of tables."),
            ("generate_query", "📋 SQL Agent: Generating SQL query."),
            ("check_query", "📋 SQL Agent: Checking SQL query."),
            ("execute_query", "📋 SQL Agent: Executing query."),
            ("response", "📋 SQL Agent: Creating response."),
        )
        progress_keys = (
            "tables_list", "schema_of_table", "query_gen",
            "check_query", "execute_query", "response_to_user",
        )

        def sql_agent(state: SQLAgentState) -> dict:
            """Supervisor: ask the LLM which step to run next."""
            messages = state["messages"]
            # The task is always the user's question. After the first hop the
            # last message is a status message written by a previous node.
            task = state.get("query") or (messages[0].content if messages else "No task")

            # Only whether each step has produced output is sent to the LLM.
            progress = {key: bool(state.get(key, "").strip()) for key in progress_keys}
            print(
                f"State check - Tables: {progress['tables_list']}, Schema: {progress['schema_of_table']}, "
                f"Query: {progress['query_gen']}, Check: {progress['check_query']}, "
                f"Execute: {progress['execute_query']}, Response: {progress['response_to_user']}"
            )

            decision = supervisor_chain.invoke({"task": task, **progress})
            decision_text = decision.content.strip().lower()
            print(f"Agent decision: {decision_text}")

            # Unrecognised replies end the run.
            next_tool, agent_msg = "end", "✅ SQL Agent: Task complete."
            if "done" in decision_text:
                agent_msg = "✅ SQL Agent: All tasks complete!"
            else:
                for tool_name, announcement in tool_announcements:
                    if tool_name in decision_text:
                        next_tool, agent_msg = tool_name, announcement
                        break

            return {
                "messages": [AIMessage(content=agent_msg)],
                "next_tool": next_tool,
                "current_task": task,
                # Returned as an update so the question is stored in the state;
                # assigning into `state` inside the node is not an update.
                "query": task,
            }

        def list_table_tools(state: SQLAgentState) -> dict:
            """List all the tables"""
            tables_list = self.list_tables_tool.invoke("")
            print(f"Tables found: {tables_list}")
            return {
                "messages": [AIMessage(content=f"Tables found: {tables_list}")],
                "tables_list": tables_list,
                "next_tool": "sql_agent"
            }

        def get_schema(state: SQLAgentState) -> dict:
            """Get the schema of required tables"""
            print("📘 Getting schema...")
            tables_list = state.get("tables_list", "")
            if not tables_list:
                tables_list = self.list_tables_tool.invoke("")

            # Skip empty names so an empty table list does not trigger a lookup for "".
            tables = [name.strip() for name in tables_list.split(",") if name.strip()]
            schema_parts = []
            for table in tables:
                try:
                    schema = self.get_schema_tool.invoke(table)
                    schema_parts.append(f"\nTable: {table}\n{schema}\n")
                except Exception as e:
                    # Best effort: one unreadable table must not abort the rest.
                    print(f"Error getting schema for {table}: {e}")
            full_schema = "".join(schema_parts)

            print(f"📘 Schema collected for tables: {tables}")
            return {
                "messages": [AIMessage(content=f"Schema retrieved: {full_schema}")],
                "schema_of_table": full_schema,
                "tables_list": tables_list,
                "next_tool": "sql_agent"
            }

        def generate_query(state: SQLAgentState) -> dict:
            """Generate a SQL Query according to the user query"""
            schema = state.get("schema_of_table", "")
            human_query = state.get("query", "")
            tables = state.get("tables_list", "")

            print(f"Generating query for: {human_query}")

            generate_query_system_prompt = """You are a SQL expert that generates precise SQL queries based on user questions.

            You will be provided with:
            - User's question
            - Available tables
            - Complete schema information

            Generate a SQL query that:
            - Uses correct column names from schema
            - Properly joins tables if needed
            - Includes appropriate WHERE clauses
            - Uses proper aggregation functions when needed

            Respond ONLY with the SQL query. Do not explain."""

            combined_input = f"""
            User Question: {human_query}
            Tables: {tables}
            Schema: {schema}
            """

            generate_query_prompt = ChatPromptTemplate.from_messages([
                ("system", generate_query_system_prompt),
                ("human", "{input}")
            ])

            try:
                formatted_prompt = generate_query_prompt.invoke({"input": combined_input})
                generate_query_llm = self.llm.with_structured_output(DBQuery)
                result = generate_query_llm.invoke(formatted_prompt)

                print(f"✅ Query generated: {result.query}")
                return {
                    "messages": [AIMessage(content=f"Query generated: {result.query}")],
                    "query_gen": result.query,
                    "next_tool": "sql_agent"
                }
            except Exception as e:
                print(f"❌ Failed to generate query: {e}")
                return {
                    "messages": [AIMessage(content="⚠️ Failed to generate SQL query.")],
                    "query_gen": "",
                    "next_tool": "sql_agent"
                }

        def check_query(state: SQLAgentState) -> dict:
            """Check if the query is correct"""
            query = state.get("query_gen", "")
            print(f"Checking query: {query}")

            if not query:
                return {
                    "messages": [AIMessage(content="No query to check")],
                    "check_query": "",
                    "next_tool": "sql_agent"
                }

            try:
                checked_query = self.query_checker_tool.invoke(query)
                print(f"Query checked: {checked_query}")
                return {
                    "messages": [AIMessage(content=f"Query checked: {checked_query}")],
                    "check_query": checked_query if checked_query else query,
                    "next_tool": "sql_agent"
                }
            except Exception as e:
                # Fall back to the unchecked query rather than stalling the workflow.
                print(f"Error checking query: {e}")
                return {
                    "messages": [AIMessage(content="Query check failed, using original query")],
                    "check_query": query,
                    "next_tool": "sql_agent"
                }

        def execute_query_(state: SQLAgentState) -> dict:
            """Execute the SQL query"""
            query = state.get("check_query", "") or state.get("query_gen", "")
            print(f"Executing query: {query}")

            if not query:
                return {
                    "messages": [AIMessage(content="No query to execute")],
                    "execute_query": "",
                    "next_tool": "sql_agent"
                }

            try:
                results = self.query_tool.invoke(query)
                print(f"Query results: {results}")
                # The supervisor treats an empty `execute_query` as "not run
                # yet". NOTE(review): the query tool appears to return "" for
                # an empty result set - confirm. Without a marker such a query
                # would be re-run until the graph's recursion limit.
                recorded = results if results else "No rows returned."
                return {
                    "messages": [AIMessage(content=f"Query executed successfully: {results}")],
                    "execute_query": recorded,
                    "next_tool": "sql_agent"
                }
            except Exception as e:
                print(f"Error executing query: {e}")
                return {
                    "messages": [AIMessage(content=f"Query execution failed: {e}")],
                    "execute_query": "",
                    "next_tool": "sql_agent"
                }

        def create_response(state: SQLAgentState) -> dict:
            """Create a final response for the user"""
            print("Creating final response...")

            query = state.get("check_query", "") or state.get("query_gen", "")
            result = state.get("execute_query", "")
            human_query = state.get("query", "")

            response_prompt = f"""Create a clear, concise response for the user based on:

            User Question: {human_query}
            SQL Query: {query}
            Query Result: {result}

            Provide a natural language answer that directly addresses the user's question. Make sure to provide only answer to human question, no any internal process results and explaination, just answer related to the human query."""

            try:
                response = self.llm.invoke([HumanMessage(content=response_prompt)])
                print(f"Response created: {response.content}")

                return {
                    "messages": [response],
                    "response_to_user": response.content,
                    "next_tool": "sql_agent",
                    "task_complete": True
                }
            except Exception as e:
                print(f"Error creating response: {e}")
                return {
                    "messages": [AIMessage(content="Failed to create response")],
                    "response_to_user": "",
                    "next_tool": "sql_agent",
                    "task_complete": True
                }

        node_names = (
            "sql_agent", "list_table_tools", "get_schema", "generate_query",
            "check_query", "execute_query", "response",
        )

        def router(state: SQLAgentState):
            """Route to the next node"""
            print("🔁 Entering router...")
            next_tool = state.get("next_tool", "")
            print(f"➡️ Next tool: {next_tool}")

            if next_tool == "end" or state.get("task_complete", False):
                return END

            # Anything unexpected goes back to the supervisor.
            return next_tool if next_tool in node_names else "sql_agent"

        # Create workflow
        workflow = StateGraph(SQLAgentState)

        # Add nodes
        workflow.add_node("sql_agent", sql_agent)
        workflow.add_node("list_table_tools", list_table_tools)
        workflow.add_node("get_schema", get_schema)
        workflow.add_node("generate_query", generate_query)
        workflow.add_node("check_query", check_query)
        workflow.add_node("execute_query", execute_query_)
        workflow.add_node("response", create_response)

        # Set entry point
        workflow.set_entry_point("sql_agent")

        # Every node shares the same router and the same set of destinations.
        destinations = {name: name for name in node_names}
        destinations[END] = END
        for node in node_names:
            workflow.add_conditional_edges(node, router, dict(destinations))

        # Compile the graph
        self.app = workflow.compile()

    def is_query_relevant(self, query: str) -> bool:
        """Ask the LLM whether ``query`` relates to the connected database.

        Args:
            query: The user's natural-language question.

        Returns:
            True when the LLM's reply starts with "yes" (case-insensitive,
            surrounding whitespace ignored), otherwise False.
        """
        schema_parts = []
        if self.list_tables_tool:
            listed = self.list_tables_tool.invoke("")
            table_list = [name.strip() for name in listed.split(",") if name.strip()]
            print(table_list)
            if self.get_schema_tool:
                for table in table_list:
                    schema_parts.append(self.get_schema_tool.invoke(table))
        # Empty when the tools are missing; the prompt is still well formed.
        schema = "\n".join(schema_parts)

        relevance_check_prompt = (
            """You are an expert SQL agent which takes user query in Natural language and find out it have releavnce with the given schema or not. Please determine if the following query is related to a database.Here is the schema of the tables present in database:\n{schema}\n\n. If the query related to given schema respond with 'yes'. Here is the query: {query}. Answer with only 'yes' or 'no'."""
        ).format(schema=schema, query=query)

        response = self.llm.invoke([{"role": "user", "content": relevance_check_prompt}])

        # Models often answer "Yes", "yes." or add a trailing newline, so an
        # exact comparison with "yes" would reject relevant questions.
        return response.content.strip().lower().startswith("yes")

    ## called from the fastapi endpoint
    def execute_query(self, query: str):
        """Answer ``query`` and return the reply text for the user.

        Table-listing questions are answered directly from the list-tables
        tool. Questions the LLM judges unrelated to the database get a fixed
        "please ask about the database" style reply. Everything else runs
        through the compiled workflow.

        Args:
            query: The user's natural-language question.

        Returns:
            The answer text, or an "Error: ..." string when the table list
            cannot be produced.

        Raises:
            ValueError: If the database connection or the workflow has not
                been set up yet.
        """
        if self.db is None:
            raise ValueError("Database connection not established. Please set up the connection first.")
        if self.app is None:
            raise ValueError("Workflow not initialized. Please set up the connection first.")

        # First, handle simple queries like "list tables" directly
        query_lower = query.lower()
        if any(phrase in query_lower for phrase in self._LIST_TABLES_PHRASES):
            if self.list_tables_tool:
                tables = self.list_tables_tool.invoke("")
                return f"The tables in the database are: {tables}"
            return "Error: Unable to list tables. The list_tables_tool is not initialized."

        # Check if the query is relevant to the database
        if not self.is_query_relevant(query):
            print("Not relevent to database.")
            non_relevant_prompt = (
                """You are an expert SQL agent created by Kshitij Kumrawat. You can only assist with questions related to databases so repond the user with the following example resonse and Do not answer any questions that are not related to databases.:
                Please ask a question that pertains to database operations, such as querying tables, retrieving data, or understanding the database schema. """
            )

            # Invoke the LLM with the non-relevant prompt
            response = self.llm.invoke([{"role": "user", "content": non_relevant_prompt}])
            return response.content

        # If relevant, proceed with the SQL workflow
        response = self.app.invoke({
            "messages": [HumanMessage(content=query)],
            "query": query
        })

        # Prefer the stored answer. Fall back to the last message when the
        # run ended before a response was written.
        return response.get("response_to_user") or response["messages"][-1].content
858
+
859
+ # # More robust final answer extraction
860
+ # if (
861
+ # response
862
+ # and response["messages"]
863
+ # and response["messages"][-1].tool_calls
864
+ # and len(response["messages"][-1].tool_calls) > 0
865
+ # and "args" in response["messages"][-1].tool_calls[0]
866
+ # and "final_answer" in response["messages"][-1].tool_calls[0]["args"]
867
+ # ):
868
+ # return response["messages"][-1].tool_calls[0]["args"]["final_answer"]
869
+ # else:
870
+ # return "Error: Could not extract final answer."
871
+
app/services/sql_agent_instance.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
"""Shared ``SQLAgent`` instance for the application.

Importing this module builds one ``SQLAgent`` and binds it to the
module-level name ``sql_agent``. Other modules import that name instead
of constructing their own agent.
"""
from app.services.sql_agent import SQLAgent

# Module-level code runs once, on first import, so every importer of
# ``sql_agent`` receives this same object.
sql_agent = SQLAgent()
docker-compose.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
version: "3.8"
services:
  app:
    # Compose only interpolates ${VAR} (from the shell environment or an
    # .env file). The GitHub Actions form ${{ secrets.X }} is expanded only
    # inside workflow files, so the registry and repository are read from
    # environment variables here. The ":?" form aborts with the given message
    # when a variable is unset or empty.
    image: ${ECR_REGISTRY:?ECR_REGISTRY is not set}/${ECR_REPOSITORY:?ECR_REPOSITORY is not set}:latest
    ports:
      # NOTE(review): the command below starts listeners on 8000 and 8501
      # only; confirm something in the image still serves port 80.
      - "80:80"      # Map host port 80 to container port 80
      - "8000:8000"
      - "8501:8501"  # Expose Streamlit port
    environment:
      - PYTHONUNBUFFERED=1
    restart: unless-stopped
    # Compose does not run `command` through a shell, so the "&" that puts
    # uvicorn in the background must be interpreted by an explicit `sh -c`.
    # Streamlit then runs in the foreground.
    command:
      - sh
      - -c
      - "uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload & streamlit run app/frontend/Talk2SQL.py --server.address=0.0.0.0 --server.port=8501"
employee.db ADDED
Binary file (28.7 kB). View file
 
pyproject.toml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "backend"
3
+ version = "0.1.0"
4
+ description = "Backend for Talk2SQL: a FastAPI service exposing a LangGraph-based SQL agent"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "bcrypt>=4.3.0",
9
+ "fastapi>=0.116.1",
10
+ "ipykernel>=6.29.5",
11
+ "ipython>=9.4.0",
12
+ "langchain>=0.3.26",
13
+ "langchain-community>=0.3.27",
14
+ "langchain-core>=0.3.68",
15
+ "langchain-groq>=0.3.6",
16
+ "langgraph>=0.5.3",
17
+ "pandas>=2.3.1",
18
+ "passlib>=1.7.4",
19
+ "psycopg2-binary>=2.9.10",
20
+ "pydantic>=2.11.7",
21
+ "pymysql>=1.1.1",
22
+ "python-multipart>=0.0.20",
23
+ "sqlalchemy>=2.0.41",
24
+ "streamlit>=1.46.1",
25
+ "uvicorn>=0.35.0",
26
+ ]
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ langchain
4
+ langgraph
5
+ langchain-groq
6
+ pydantic
7
+ sqlalchemy
8
+ pymysql
9
+ langchain-community
10
+ langchain-core
11
+ streamlit
12
+ pandas
13
+ IPython
14
+ ipykernel
15
+ passlib
16
+ python-multipart
17
+ bcrypt==4.3.0
18
+ psycopg2-binary
setup.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Packaging script for the ``talk2sql`` distribution."""
from setuptools import setup, find_packages

setup(
    name='talk2sql',
    version='0.1.0',
    # Discover every package (directory with an __init__.py) under the
    # directory containing this script.
    packages=find_packages(),
    install_requires=[
        # Web layer
        'fastapi',
        'uvicorn',
        'streamlit',
        'python-multipart',
        # Data access
        'pydantic',
        'SQLAlchemy',
        'pymysql',
        'psycopg2-binary',
        'pandas',
        # Configuration
        'python-dotenv',
        # LLM / agent stack
        'langchain',
        'langchain_community',
        'langchain-core',
        'langchain_groq',
        'langgraph',
        # Password hashing; the bcrypt pin mirrors requirements.txt.
        'passlib',
        'bcrypt==4.3.0',
        # Notebook tooling, listed in requirements.txt as well.
        'IPython',
        'ipykernel',
        # HTML / XML parsing
        'beautifulsoup4',
        'lxml',
    ],
)
sql_agent_version2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
sql_agent_with_langgraph.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
users.db ADDED
Binary file (12.3 kB). View file
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
workflow_graph.png ADDED