| from dotenv import load_dotenv | |
| from langchain_core.messages import AIMessage, HumanMessage | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_openai import ChatOpenAI | |
| from langchain_groq import ChatGroq | |
| import streamlit as st | |
| import os | |
# Load variables from a local .env file (if present) into the process
# environment so the environment lookups further down can see them.
load_dotenv()
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
    """Open a MySQL connection and wrap it in a LangChain ``SQLDatabase``.

    Args:
        user: Database account name, given as the raw (un-encoded) value.
        password: Account password, given as the raw (un-encoded) value.
        host: Hostname or IP address of the MySQL server.
        port: TCP port of the MySQL server.
        database: Name of the schema to connect to.

    Returns:
        A ``SQLDatabase`` built from the ``mysql+mysqlconnector`` URI.
    """
    from urllib.parse import quote

    # Credentials may contain characters that are reserved in a URI
    # ("@", ":", "/", "#", "%", ...). Percent-encode them so they cannot be
    # mistaken for URI delimiters. ``safe=""`` also encodes "/", and spaces
    # become "%20" rather than "+", which decodes unambiguously.
    safe_user = quote(user, safe="")
    safe_password = quote(password, safe="")
    db_uri = f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)
def get_sql_chain(db):
    """Build a runnable that turns a user question into a SQL query string.

    The returned chain takes a mapping with ``question`` and ``chat_history``
    keys, adds the database's table info under ``schema``, renders the prompt,
    sends it to the chat model and parses the reply into a plain string.

    Args:
        db: Database wrapper exposing ``get_table_info()``.

    Returns:
        A LangChain runnable whose output is the generated SQL text.
    """
    sql_prompt_text = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
Your turn:
Question: {question}
SQL Query:
"""
    sql_prompt = ChatPromptTemplate.from_template(sql_prompt_text)
    model = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # The schema is fetched when the chain runs, not when it is built; the
    # callable ignores the chain input it is handed.
    with_schema = RunnablePassthrough.assign(schema=lambda _: db.get_table_info())
    return with_schema | sql_prompt | model | StrOutputParser()
def _strip_sql_fences(sql: str) -> str:
    """Return *sql* without surrounding Markdown code fences or backticks."""
    cleaned = sql.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        cleaned = cleaned[3:-3]
        # The opening fence may carry a language tag (```sql); drop that line.
        first_line, newline, remainder = cleaned.partition("\n")
        if newline and first_line.strip().lower() in {"", "sql", "mysql"}:
            cleaned = remainder
    elif len(cleaned) >= 2 and cleaned.startswith("`") and cleaned.endswith("`"):
        # A SQL statement never begins with a backtick, so a leading one can
        # only be wrapping; backtick-quoted identifiers inside are untouched.
        cleaned = cleaned[1:-1]
    return cleaned.strip()


def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    """Answer *user_query* by generating SQL, running it and summarising it.

    The question is first turned into a SQL query by ``get_sql_chain``. The
    query is executed against *db*, and the question, query, schema and query
    result are then passed to the chat model to produce the final
    natural-language answer.

    Args:
        user_query: The question typed by the user.
        db: Database wrapper used for schema lookup and query execution.
        chat_history: Prior conversation messages, rendered into both prompts.

    Returns:
        The model's answer as a string.

    Raises:
        Whatever ``db.run`` or the model client raises, e.g. when the
        generated SQL is invalid; errors are not caught here.
    """
    sql_chain = get_sql_chain(db)
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, question, sql query, and sql response, write a natural language response.the output should be in the format given below
**Course Title**
**Course Description**
**Course Curriculum**
**Course URL**
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User question: {question}
SQL Response: {response}"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # Models sometimes wrap the query in Markdown fences despite the prompt's
    # instruction; strip them so the database does not receive backticks.
    clean_sql_chain = sql_chain | _strip_sql_fences

    # NOTE(review): the statement executed below is written by the LLM and is
    # run verbatim. Connect with a read-only database account so a generated
    # UPDATE/DELETE/DROP cannot modify data.
    chain = (
        RunnablePassthrough.assign(query=clean_sql_chain).assign(
            schema=lambda _: db.get_table_info(),
            response=lambda inputs: db.run(inputs["query"]),
        )
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })
# First run of a session: seed the transcript with the assistant's greeting.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = [
        AIMessage(
            content="Hello! I'm a Course Recommendation System,Search about Free Analyical Vidya Courses"
        ),
    ]

# Browser-tab metadata, then the heading shown at the top of the page.
st.set_page_config(
    page_title="Analytical Vidya Course Recommendation",
    page_icon=":speech_balloon:",
)
st.title("Analytical Vidya Course Recommendation")
# Connection settings come from the environment, with development fallbacks.
# NOTE(review): the DB_PASSWORD fallback is a literal credential committed in
# source. Consider requiring the variable instead and rotating this password.
user = os.environ.get("DB_USER", "root")
password = os.environ.get("DB_PASSWORD", "Hayat123")
host = os.environ.get("DB_HOST", "localhost")
port = os.environ.get("DB_PORT", "3306")
database = os.environ.get("DB_NAME", "analytics")

# Connect once per session; later reruns reuse the handle in session state.
if "db" not in st.session_state:
    with st.spinner("Connecting to the database..."):
        st.session_state["db"] = init_database(user, password, host, port, database)
        st.success("Connected")
# Replay the stored transcript so earlier turns stay visible after a rerun.
for past_message in st.session_state.chat_history:
    if isinstance(past_message, AIMessage):
        speaker = "AI"
    elif isinstance(past_message, HumanMessage):
        speaker = "Human"
    else:
        # Any other message type is not rendered.
        continue
    with st.chat_message(speaker):
        st.markdown(past_message.content)
# Read the next user message; ``chat_input`` yields None until one is sent.
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        try:
            response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
        except Exception as exc:
            # UI boundary: a failed model call or an invalid generated query
            # should show a readable message, not abort the page render.
            st.error(f"Sorry, I couldn't answer that question: {exc}")
        else:
            st.markdown(response)
            # Only successful answers are stored in the transcript.
            st.session_state.chat_history.append(AIMessage(content=response))