import os
from urllib.parse import quote_plus

import streamlit as st
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
load_dotenv()
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
    """Open a connection to the MySQL database via LangChain's SQLDatabase.

    Args:
        user: MySQL username.
        password: MySQL password (may contain special characters).
        host: Database host name or IP address.
        port: Database port as a string (e.g. "3306").
        database: Name of the schema to connect to.

    Returns:
        A ``SQLDatabase`` wrapper around a SQLAlchemy engine for the URI.
    """
    # Percent-encode the credentials: characters such as '@', ':', '/' or '%'
    # in a raw username/password would otherwise corrupt the connection URI.
    db_uri = (
        f"mysql+mysqlconnector://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}"
    )
    return SQLDatabase.from_uri(db_uri)
def get_sql_chain(db):
    """Build a runnable that turns a user question into a bare SQL query.

    The runnable injects the live table schema into the prompt, forwards the
    question and conversation history to the LLM, and parses the model reply
    down to a plain SQL string (no surrounding text, per the prompt).

    Args:
        db: A LangChain ``SQLDatabase`` whose schema grounds the query.

    Returns:
        A runnable accepting ``{"question", "chat_history"}`` and producing SQL text.
    """
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
{schema}
Conversation History: {chat_history}
Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
Your turn:
Question: {question}
SQL Query:
"""
    sql_prompt = ChatPromptTemplate.from_template(template)
    model = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    # Resolve the schema lazily at invoke time so it reflects the live database.
    with_schema = RunnablePassthrough.assign(schema=lambda _: db.get_table_info())
    return with_schema | sql_prompt | model | StrOutputParser()
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    """Answer *user_query* in natural language using SQL results from *db*.

    Pipeline: generate SQL from the question (via ``get_sql_chain``), execute
    it against the database, then ask the LLM to phrase the SQL result using
    the fixed course-card output format.

    Args:
        user_query: The user's latest question.
        db: Connected ``SQLDatabase`` to run the generated query against.
        chat_history: Prior ``AIMessage``/``HumanMessage`` turns for context.

    Returns:
        The model's natural-language answer as a string.
    """
    sql_chain = get_sql_chain(db)
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, question, sql query, and sql response, write a natural language response.the output should be in the format given below
**Course Title**
**Course Description**
**Course Curriculum**
**Course URL**
{schema}
Conversation History: {chat_history}
SQL Query: {query}
User question: {question}
SQL Response: {response}"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def _run_query(inputs: dict) -> str:
        """Execute the generated SQL, tolerating markdown fences the LLM
        sometimes adds despite the prompt's instruction not to."""
        sql = inputs["query"].strip()
        if sql.startswith("```"):
            sql = sql.strip("`").strip()
            # Drop a leading "sql" language tag left over from a ```sql fence.
            if sql.lower().startswith("sql"):
                sql = sql[3:].strip()
        return db.run(sql)

    chain = (
        # Renamed from `vars` in the original lambda: it shadowed the builtin.
        RunnablePassthrough.assign(query=sql_chain).assign(
            schema=lambda _: db.get_table_info(),
            response=_run_query,
        )
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })
# ---------------------------------------------------------------------------
# Streamlit UI (script entry point)
# ---------------------------------------------------------------------------
# Per the Streamlit docs, set_page_config must be the first Streamlit command
# executed on the page, so it now runs before any session-state access.
st.set_page_config(page_title="Analytical Vidya Course Recommendation", page_icon=":speech_balloon:")
st.title("Analytical Vidya Course Recommendation")

# Seed the chat with a one-time greeting on the first run of the session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a Course Recommendation System. Search about Free Analytical Vidya Courses"),
    ]

# Connection settings: environment variables first, development defaults second.
# NOTE(security): a real-looking password is hard-coded as the fallback below —
# remove it and require DB_PASSWORD to be set before deploying.
user = os.getenv("DB_USER", "root")
password = os.getenv("DB_PASSWORD", "Hayat123")
host = os.getenv("DB_HOST", "localhost")
port = os.getenv("DB_PORT", "3306")
database = os.getenv("DB_NAME", "analytics")

# Connect once per browser session; the handle is cached in session state.
if "db" not in st.session_state:
    with st.spinner("Connecting to the database..."):
        st.session_state.db = init_database(user, password, host, port, database)
        st.success("Connected")

# Re-render the full conversation history on every Streamlit rerun.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

# Handle a new user message: echo it, generate the answer, persist both turns.
user_query = st.chat_input("Type a message...")
if user_query is not None and user_query.strip() != "":
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
        st.markdown(response)
    st.session_state.chat_history.append(AIMessage(content=response))