File size: 3,406 Bytes
4c9a5ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from dotenv import load_dotenv
load_dotenv() # Load All the environment variables
import streamlit as st
import os
import sqlite3
import google.generativeai as genai
import pandas as pd
from custom_table import get_dataframe
from utils import remove_triple_quotes
# Configure the Gemini client once, at import time. The API key is read from
# the process environment, which load_dotenv() may have populated from .env.
# NOTE(review): if GOOGLE_API_KEY is unset this passes None and the failure
# only surfaces on the first model call — consider failing fast here.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Function to load Google Gemini Model and provide query as response
def get_gemini_response(question, prompt, *, model_name='gemini-pro'):
    """Ask a Gemini model to translate *question* into a SQL query.

    Args:
        question: The user's natural-language question.
        prompt: A sequence whose first element is the instruction text. Only
            ``prompt[0]`` is sent, followed by *question*.
        model_name: Gemini model identifier passed to
            ``genai.GenerativeModel``. Defaults to ``'gemini-pro'`` (the
            previously hard-coded value); override it if that model ID is not
            available to your API key.

    Returns:
        The text of the model's reply (``response.text``).
    """
    model = genai.GenerativeModel(model_name)
    # Instruction text first, then the question, as a single multi-part request.
    response = model.generate_content([prompt[0], question])
    return response.text
# function to retrieve query from the database
def read_sql_query(sql, db):
    """Execute *sql* against the SQLite database file *db* and return the results.

    The statement runs on a fresh connection that is always closed, even when
    execution fails. Pending changes are committed before the connection is
    closed. Each result row is also printed to stdout.

    Args:
        sql: A single SQL statement to execute.
        db: Path to the SQLite database file.

    Returns:
        A ``(rows, df)`` tuple. ``rows`` is the list returned by
        ``cursor.fetchall()``. ``df`` holds the same rows as a pandas DataFrame,
        with column names taken from ``cursor.description``. For statements
        that produce no result set (``description`` is ``None``), ``df`` is an
        empty DataFrame.

    Raises:
        sqlite3.Error: If the statement cannot be executed.
    """
    # NOTE(review): in this app *sql* is model-generated and therefore
    # untrusted. If writes are never intended, consider opening the database
    # read-only: sqlite3.connect(f"file:{db}?mode=ro", uri=True).
    conn = sqlite3.connect(db)
    try:
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        # description is None for statements without a result set
        # (e.g. UPDATE), so guard before reading column names.
        columns = [col[0] for col in cur.description] if cur.description else []
        sql_output_df = pd.DataFrame(rows, columns=columns)
        conn.commit()
    finally:
        # Close even on failure; otherwise each bad query leaks a connection.
        conn.close()
    for row in rows:
        print(row)
    return rows, sql_output_df
## Define your prompt
# Few-shot instruction text sent to Gemini ahead of every user question.
# Kept as a one-element list because callers index it as prompt[0].
prompt = [
""" You are an expert in converting English questions to SQL query!
The SQL database has the name test_db and the SQL table name is tbl_customer
and has the following columns -
CustomerKey,GeographyKey,CustomerAlternateKey,Title,FirstName,MiddleName,
LastName,NameStyle,BirthDate,MaritalStatus,Suffix,Gender,EmailAddress,
YearlyIncome,TotalChildren,NumberChildrenAtHome,EnglishEducation,
SpanishEducation,FrenchEducation,EnglishOccupation,SpanishOccupation,
FrenchOccupation,HouseOwnerFlag,NumberCarsOwned,AddressLine1,AddressLine2,
Phone,DateFirstPurchase,CommuteDistance.
\n\nFor example, \nExample 1 - How many entries of records are present?,
the SQL command will be something like this SELECT COUNT(*) FROM tbl_customer ;
\nExample 2 - Tell me all the customers who have CustomerKey greater than 11000?,
the SQL command will be something like this SELECT * FROM tbl_customer where CustomerKey > 11000;
also the sql code should NOT have ``` in beginning or end and sql word in output
also there should NOT be sql keyword at the beginning of the output """
]
st.set_page_config(page_title="I can retrieve any SQL query")
st.header("Gemini App to Retrieve SQL data")
# Widget keys must be strings; previously the builtin input() function was passed.
question = st.text_input("Input: ", key="input")
submit = st.button("Ask the Question :")

# Preview the source table so users can see which columns they can ask about.
df = get_dataframe()
st.write('##### This is the sample table...')
st.write(df.head(5))

# If Submit is clicked...
if submit:
    if not question.strip():
        # Avoid sending an empty request to the model.
        st.warning("Please enter a question first.")
    else:
        generated_sql_query = get_gemini_response(question, prompt)
        generated_sql_query = remove_triple_quotes(generated_sql_query)
        print("The generated sql_query is :")
        print(generated_sql_query)
        st.write('##### used SQL query is:')
        # st.code shows the SQL verbatim; st.write would render it as markdown
        # and could swallow characters such as '*'.
        st.code(generated_sql_query, language="sql")
        try:
            response, sql_output_df = read_sql_query(generated_sql_query, "test_db.db")
        except sqlite3.Error as err:
            # The SQL is model-generated and may be invalid; report it rather
            # than crashing the page with a traceback.
            st.error(f"The generated query failed: {err}")
        else:
            st.write("#### The Response is :")
            # Display the DataFrame in Streamlit
            st.write(sql_output_df)
|