Spaces:
Build error
Build error
alejandro
commited on
Commit
·
bda4e9b
1
Parent(s):
411b037
feat: update prompt template && add groq LLM
Browse files- requirements.txt +2 -0
- src/app.py +13 -5
requirements.txt
CHANGED
|
@@ -4,3 +4,5 @@ langchain-community==0.0.21
|
|
| 4 |
langchain-core==0.1.24
|
| 5 |
langchain-openai==0.0.6
|
| 6 |
mysql-connector-python==8.3.0
|
|
|
|
|
|
|
|
|
| 4 |
langchain-core==0.1.24
|
| 5 |
langchain-openai==0.0.6
|
| 6 |
mysql-connector-python==8.3.0
|
| 7 |
+
groq==0.4.2
|
| 8 |
+
langchain-groq==0.0.1
|
src/app.py
CHANGED
|
@@ -3,6 +3,7 @@ from langchain_community.utilities import SQLDatabase
|
|
| 3 |
from langchain_core.output_parsers import StrOutputParser
|
| 4 |
from langchain_core.runnables import RunnablePassthrough
|
| 5 |
from langchain_openai import ChatOpenAI
|
|
|
|
| 6 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 7 |
from langchain_core.prompts import ChatPromptTemplate
|
| 8 |
from dotenv import load_dotenv
|
|
@@ -16,6 +17,11 @@ def get_sql_chain(db):
|
|
| 16 |
Based on the table schema below, write a SQL query that would answer the user's question.
|
| 17 |
{schema}
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
Question: {question}
|
| 20 |
SQL Query:
|
| 21 |
"""
|
|
@@ -23,6 +29,7 @@ def get_sql_chain(db):
|
|
| 23 |
prompt = ChatPromptTemplate.from_template(template)
|
| 24 |
|
| 25 |
llm = ChatOpenAI()
|
|
|
|
| 26 |
|
| 27 |
def get_schema(_):
|
| 28 |
return db.get_table_info()
|
|
@@ -49,6 +56,7 @@ def get_response(user_query, chat_history, db):
|
|
| 49 |
|
| 50 |
prompt = ChatPromptTemplate.from_template(template)
|
| 51 |
|
|
|
|
| 52 |
llm = ChatOpenAI()
|
| 53 |
|
| 54 |
def get_schema(_):
|
|
@@ -85,11 +93,11 @@ with st.sidebar:
|
|
| 85 |
st.title("Chat with a MySQL Database")
|
| 86 |
st.write("This is a simple chat application allows you to chat with a MySQL database.")
|
| 87 |
|
| 88 |
-
st.text_input("Host", key="name")
|
| 89 |
-
st.text_input("Port", key="port")
|
| 90 |
-
st.text_input("Username", key="username")
|
| 91 |
-
st.text_input("Password", key="password")
|
| 92 |
-
st.text_input("Database", key="database")
|
| 93 |
|
| 94 |
if st.button("Connect"):
|
| 95 |
with st.spinner("Connecting to the database..."):
|
|
|
|
| 3 |
from langchain_core.output_parsers import StrOutputParser
|
| 4 |
from langchain_core.runnables import RunnablePassthrough
|
| 5 |
from langchain_openai import ChatOpenAI
|
| 6 |
+
from langchain_groq import ChatGroq
|
| 7 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
from dotenv import load_dotenv
|
|
|
|
| 17 |
Based on the table schema below, write a SQL query that would answer the user's question.
|
| 18 |
{schema}
|
| 19 |
|
| 20 |
+
Write only the SQL query and nothing else. For example:
|
| 21 |
+
Question: which 3 artists have the most tracks?
|
| 22 |
+
SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
|
| 23 |
+
Question: Name 10 artists
|
| 24 |
+
SQL Query: SELECT Name FROM Artist LIMIT 10;
|
| 25 |
Question: {question}
|
| 26 |
SQL Query:
|
| 27 |
"""
|
|
|
|
| 29 |
prompt = ChatPromptTemplate.from_template(template)
|
| 30 |
|
| 31 |
llm = ChatOpenAI()
|
| 32 |
+
# llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
|
| 33 |
|
| 34 |
def get_schema(_):
|
| 35 |
return db.get_table_info()
|
|
|
|
| 56 |
|
| 57 |
prompt = ChatPromptTemplate.from_template(template)
|
| 58 |
|
| 59 |
+
# llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
|
| 60 |
llm = ChatOpenAI()
|
| 61 |
|
| 62 |
def get_schema(_):
|
|
|
|
| 93 |
st.title("Chat with a MySQL Database")
|
| 94 |
st.write("This is a simple chat application allows you to chat with a MySQL database.")
|
| 95 |
|
| 96 |
+
st.text_input("Host", key="name", value="localhost")
|
| 97 |
+
st.text_input("Port", key="port", value="3306")
|
| 98 |
+
st.text_input("Username", key="username", value="root")
|
| 99 |
+
st.text_input("Password", key="password", type="password", value="admin")
|
| 100 |
+
st.text_input("Database", key="database", value="Chinook")
|
| 101 |
|
| 102 |
if st.button("Connect"):
|
| 103 |
with st.spinner("Connecting to the database..."):
|