File size: 4,266 Bytes
4630812
c07889e
4630812
c07889e
 
 
 
4630812
c07889e
 
4630812
 
 
c07889e
 
 
4630812
c07889e
 
 
 
4630812
c07889e
4630812
c07889e
4630812
c07889e
 
 
4630812
c07889e
 
4630812
c07889e
4630812
 
c07889e
4630812
c07889e
 
4630812
c07889e
 
 
 
 
 
4630812
c07889e
 
4630812
c07889e
 
 
 
 
 
 
 
4630812
c07889e
 
 
 
4630812
c07889e
 
 
4630812
c07889e
 
 
 
 
 
 
 
 
4630812
c07889e
 
 
 
4630812
c07889e
 
 
 
4630812
c07889e
 
4630812
c07889e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4630812
c07889e
 
4630812
c07889e
 
4630812
c07889e
4630812
c07889e
4630812
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
import streamlit as st
import os

# Load settings from a local .env file (when one exists) into the process
# environment before the os.getenv lookups further down in this script run.
load_dotenv()

def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
    """Open a MySQL database through LangChain's ``SQLDatabase`` wrapper.

    Args:
        user: Database account name.
        password: Password for ``user``; may contain URL-reserved characters.
        host: Hostname or IP address of the MySQL server.
        port: TCP port of the MySQL server, as a string.
        database: Name of the schema to connect to.

    Returns:
        The object returned by ``SQLDatabase.from_uri`` for the assembled
        ``mysql+mysqlconnector`` URI.
    """
    from urllib.parse import quote

    # Credentials are percent-encoded so characters such as "@", ":", "/" or
    # "#" inside them are not mistaken for URI delimiters when the URL is parsed.
    safe_user = quote(user, safe="")
    safe_password = quote(password, safe="")

    db_uri = f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)

def get_sql_chain(db):
    """Build the runnable that converts a user question into a SQL query string.

    The returned chain expects an input mapping that provides ``question`` and
    ``chat_history``. It adds ``schema`` (taken from ``db.get_table_info()``),
    renders the prompt, calls the Groq chat model, and parses the model output
    to a plain string.

    Args:
        db: Database object exposing ``get_table_info()``.

    Returns:
        A composed runnable: schema injection | prompt | model | string parser.
    """
    sql_prompt_text = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.

    <SCHEMA>{schema}</SCHEMA>

    Conversation History: {chat_history}

    Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.

    Your turn:

    Question: {question}
    SQL Query:
    """

    sql_prompt = ChatPromptTemplate.from_template(sql_prompt_text)
    model = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def load_schema(_inputs):
        # The chain input is ignored; the schema always comes from the database.
        return db.get_table_info()

    with_schema = RunnablePassthrough.assign(schema=load_schema)
    return with_schema | sql_prompt | model | StrOutputParser()

def _strip_sql_fences(sql: str) -> str:
    """Return *sql* without surrounding whitespace or a Markdown code fence."""
    cleaned = sql.strip()
    if not cleaned.startswith("```"):
        return cleaned

    cleaned = cleaned[3:]
    first_line, newline, remainder = cleaned.partition("\n")
    # The opening fence may carry a language tag (```sql); drop that whole line.
    if newline and first_line.strip().lower() in {"", "sql", "mysql"}:
        cleaned = remainder

    cleaned = cleaned.rstrip()
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def get_response(user_query: str, db: SQLDatabase, chat_history: list) -> str:
    """Answer a user question by generating SQL, running it, and summarising the result.

    Steps performed by the chain:
      1. ``get_sql_chain(db)`` produces a SQL query for the question.
      2. Any Markdown code fence around that query is removed.
      3. The query is executed with ``db.run`` and the table schema is fetched.
      4. A second prompt turns schema, history, query and result into a
         natural-language answer in the course-card format.

    Args:
        user_query: The question typed by the user.
        db: Database used both for schema lookup and for running the query.
        chat_history: Prior conversation messages, interpolated into both prompts.

    Returns:
        The model's answer as a string.

    Raises:
        Whatever ``db.run`` or the model client raises; nothing is caught here.
    """
    sql_chain = get_sql_chain(db)

    template = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, question, sql query, and sql response, write a natural language response.the output should be in the format given below
    **Course Title**
    **Course Description**
    **Course Curriculum**
    **Course URL**
    <SCHEMA>{schema}</SCHEMA>

    Conversation History: {chat_history}
    SQL Query: <SQL>{query}</SQL>
    User question: {question}
    SQL Response: {response}"""

    prompt = ChatPromptTemplate.from_template(template)

    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def run_generated_query(inputs):
        # SECURITY NOTE(review): this executes model-written SQL as-is. Nothing
        # here restricts it to read-only statements; consider connecting with
        # a read-only database account.
        return db.run(inputs["query"])

    chain = (
            # Models sometimes wrap SQL in ``` fences despite the prompt; strip
            # them so the statement handed to the database is plain SQL.
            RunnablePassthrough.assign(query=sql_chain | _strip_sql_fences).assign(
                schema=lambda _: db.get_table_info(),
                response=run_generated_query,
            )
            | prompt
            | llm
            | StrOutputParser()
    )

    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })

# --- Streamlit page -----------------------------------------------------------
# Page configuration is issued first so it precedes every other Streamlit call.
st.set_page_config(page_title="Analytical Vidya Course Recommendation", page_icon=":speech_balloon:")
st.title("Analytical Vidya Course Recommendation")

# Seed the conversation with a greeting the first time the session runs.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a Course Recommendation System,Search about Free Analyical Vidya Courses"),
    ]

# Connection settings come from the environment, with local-development fallbacks.
user = os.getenv("DB_USER", "root")
# SECURITY NOTE(review): a real-looking password is hard-coded as the fallback.
# It is kept so existing setups keep working, but it should be moved to the
# environment / .env file and the credential rotated.
password = os.getenv("DB_PASSWORD", "Hayat123")
host = os.getenv("DB_HOST", "localhost")
port = os.getenv("DB_PORT", "3306")
database = os.getenv("DB_NAME", "analytics")

# Connect once per session; the handle is kept in session state across reruns.
if "db" not in st.session_state:
    with st.spinner("Connecting to the database..."):
        try:
            st.session_state.db = init_database(user, password, host, port, database)
        except Exception as exc:  # UI boundary: report the failure instead of a traceback
            st.error(f"Could not connect to the database: {exc}")
            # Without a database handle nothing below can work, so end this run.
            st.stop()
        st.success("Connected")

# Replay the stored conversation; messages of any other type are skipped.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        role = "AI"
    elif isinstance(message, HumanMessage):
        role = "Human"
    else:
        continue
    with st.chat_message(role):
        st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        try:
            response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
        except Exception as exc:  # UI boundary: generated SQL or the model call may fail
            # Keep the history alternating Human/AI by recording the failure as
            # the assistant's reply for this turn.
            response = f"Sorry, I could not answer that question: {exc}"
            st.error(response)
        else:
            st.markdown(response)

    st.session_state.chat_history.append(AIMessage(content=response))