|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a local .env file into the process environment before
# anything reads them (GOOGLE_API_KEY is read below via os.getenv).
load_dotenv()
|
|
import streamlit as st |
|
|
import os |
|
|
import sqlite3 |
|
|
import google.generativeai as genai |
|
|
import pandas as pd |
|
|
from custom_table import get_dataframe |
|
|
from utils import remove_triple_quotes |
|
|
|
|
|
|
|
|
|
|
|
# Authenticate the Gemini client with the API key taken from the environment
# (None is passed through unchanged when GOOGLE_API_KEY is not set).
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY'))
|
|
|
|
|
|
|
|
|
|
|
def get_gemini_response(question, prompt, model_name='gemini-pro'):
    """Ask a Gemini model to translate *question* into SQL.

    Args:
        question: The user's natural-language question.
        prompt: A sequence whose first element is the instruction/few-shot
            prompt; only ``prompt[0]`` is sent.
        model_name: Gemini model identifier. Defaults to ``'gemini-pro'`` so
            existing callers keep their previous behaviour.

    Returns:
        The text of the model's response (expected to be a bare SQL query).

    NOTE(review): ``response.text`` may raise (google-generativeai raises
    ValueError when the response has no text part, e.g. when it was
    blocked) — confirm against the installed library version.
    """
    model = genai.GenerativeModel(model_name)
    # The instruction prompt goes first so the model reads the schema and
    # examples before the actual question.
    response = model.generate_content([prompt[0], question])
    return response.text
|
|
|
|
|
|
|
|
|
|
|
def read_sql_query(sql, db):
    """Execute *sql* against the SQLite database file *db*.

    The statement is executed verbatim. Callers that pass model-generated or
    user-supplied SQL are responsible for deciding whether that is safe.

    Args:
        sql: A single SQL statement.
        db: Path to the SQLite database file.

    Returns:
        A ``(rows, dataframe)`` tuple. ``rows`` is the list of result tuples.
        ``dataframe`` holds the same rows with column names taken from the
        cursor. For statements that produce no result set (INSERT, UPDATE,
        DDL, ...) ``rows`` is empty and the DataFrame is empty.

    Raises:
        sqlite3.Error: if the statement fails. The connection is still closed.
    """
    conn = sqlite3.connect(db)
    try:
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        # cursor.description is None when the statement returns no result
        # set; iterating it would raise TypeError and skip the commit below.
        if cur.description is None:
            columns = []
        else:
            columns = [description[0] for description in cur.description]
        sql_output_df = pd.DataFrame(rows, columns=columns)
        # Commit so any data-modifying statement is persisted.
        conn.commit()
    finally:
        # Always release the connection, even when execute() raises.
        conn.close()

    # Debug echo of the raw result rows to the server console.
    for row in rows:
        print(row)

    return rows, sql_output_df
|
|
|
|
|
|
|
|
|
|
|
# Few-shot instruction prompt sent to Gemini ahead of the user's question.
# It describes the single table the model may query and shows the expected
# output: a bare SQL statement. Every character here is part of the model
# input — a stray quote in an example is likely to be copied into generated
# SQL.
prompt = [
    """ You are an expert in converting English questions to SQL query!
The SQL database has the name test_db and the SQL table name is tbl_customer
and has the following columns -
CustomerKey,GeographyKey,CustomerAlternateKey,Title,FirstName,MiddleName,
LastName,NameStyle,BirthDate,MaritalStatus,Suffix,Gender,EmailAddress,
YearlyIncome,TotalChildren,NumberChildrenAtHome,EnglishEducation,
SpanishEducation,FrenchEducation,EnglishOccupation,SpanishOccupation,
FrenchOccupation,HouseOwnerFlag,NumberCarsOwned,AddressLine1,AddressLine2,
Phone,DateFirstPurchase,CommuteDistance.
\n\nFor example, \nExample 1 - How many entries of records are present?,
the SQL command will be something like this SELECT COUNT(*) FROM tbl_customer ;
\nExample 2 - Tell me all the customers who have CustomerKey greater than 11000?,
the SQL command will be something like this SELECT * FROM tbl_customer where CustomerKey > 11000;
also the sql code should NOT have ``` in beginning or end and sql word in output
also there should NOT be sql keyword at the beginning of the output """
]
|
|
|
|
|
# set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="I can retrieve any SQL query")
st.header("Gemini App to Retrieve SQL data")

# Widget keys are meant to be strings. The original passed the built-in
# ``input`` function object here, which only worked by being stringified.
question = st.text_input("Input: ", key="input")
submit = st.button("Ask the Question :")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fetch the reference table and show its first five rows so the user can see
# which columns are available to ask about.
df = get_dataframe()
preview_rows = df.head(5)
st.write('##### This is the sample table...')
st.write(preview_rows)
|
|
|
|
|
|
|
|
|
|
|
# Handle a click on "Ask the Question": translate the question to SQL with
# Gemini, run it against the local SQLite file and show the result.
if submit:
    if not question or not question.strip():
        # Nothing to translate; avoid a pointless model call.
        st.warning("Please enter a question first.")
    else:
        generated_sql_query = get_gemini_response(question, prompt)
        # Strip any Markdown code fences the model may have added.
        generated_sql_query = remove_triple_quotes(generated_sql_query)

        print("The generated sql_query is :")
        print(generated_sql_query)

        # Show the query before running it so a failing query is visible.
        st.write('##### used SQL query is:')
        st.write(generated_sql_query)

        # NOTE(review): the model's output is executed verbatim. A crafted
        # question could yield a data-modifying statement (DELETE/DROP).
        # Consider a read-only connection or an allow-list of SELECT queries.
        try:
            response, sql_output_df = read_sql_query(generated_sql_query, "test_db.db")
        except sqlite3.Error as err:
            # Model-generated SQL is often malformed; report it instead of
            # crashing the page with a traceback.
            st.error(f"The generated SQL could not be executed: {err}")
        else:
            st.write("#### The Response is :")
            st.write(sql_output_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|