trohith89 commited on
Commit
996808d
·
verified ·
1 Parent(s): 7219fc6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import sqlite3
3
+ import uuid
4
+ import time
5
+ from langchain_google_genai import GoogleGenerativeAI
6
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_community.chat_message_histories import SQLChatMessageHistory
9
+ from langchain_core.runnables.history import RunnableWithMessageHistory
10
+
11
+ # Load API key
12
+ GOOGLE_API_KEY = st.secrets.get("GOOGLE_API_KEY")
13
+
14
+ # Set up the Gemini 1.5 Pro model
15
+ llm = GoogleGenerativeAI(api_key=GOOGLE_API_KEY, model="gemini-1.5-pro")
16
+
17
+ # Initialize SQLite database
18
+ conn = sqlite3.connect("chat_history.db", check_same_thread=False)
19
+ cursor = conn.cursor()
20
+ cursor.execute("""
21
+ CREATE TABLE IF NOT EXISTS chat (
22
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
23
+ session_id TEXT,
24
+ role TEXT,
25
+ content TEXT
26
+ )
27
+ """)
28
+ conn.commit()
29
+
30
+ # Function to save messages
31
+ def save_message(session_id, role, content):
32
+ cursor.execute("INSERT INTO chat (session_id, role, content) VALUES (?, ?, ?)", (session_id, role, content))
33
+ conn.commit()
34
+
35
+ # Function to load chat history
36
+ def load_chat_history(session_id):
37
+ cursor.execute("SELECT role, content FROM chat WHERE session_id = ?", (session_id,))
38
+ return cursor.fetchall()
39
+
40
+ # Chat history instance
41
+ def chat_history(session_id):
42
+ return SQLChatMessageHistory(
43
+ session_id=session_id,
44
+ connection="sqlite:///chat_history.db"
45
+ )
46
+
47
+ # Generate unique session ID
48
+ if "session_id" not in st.session_state:
49
+ st.session_state.session_id = str(uuid.uuid4())
50
+
51
+ # Custom CSS for UI enhancements
52
+ st.markdown("""
53
+ <style>
54
+ body {
55
+ background-color: #E3F2FD;
56
+ }
57
+ .title-text {
58
+ text-align: center;
59
+ font-size: 28px;
60
+ font-weight: bold;
61
+ color: #1976D2;
62
+ padding: 15px;
63
+ }
64
+ .stTextInput {
65
+ position: fixed;
66
+ bottom: 10px;
67
+ width: 80%;
68
+ left: 10%;
69
+ z-index: 999;
70
+ }
71
+ .chat-container {
72
+ background-color: white;
73
+ padding: 20px;
74
+ border-radius: 10px;
75
+ box-shadow: 2px 2px 10px rgba(0,0,0,0.1);
76
+ }
77
+ </style>
78
+ """, unsafe_allow_html=True)
79
+
80
+ # Display title
81
+ st.markdown("""<h1 class='title-text'>💬 AI Data Science Tutor</h1>""", unsafe_allow_html=True)
82
+
83
+ # New Chat Button
84
+ if st.button("🆕 New Chat"):
85
+ st.session_state.session_id = str(uuid.uuid4()) # Generate new session
86
+ st.session_state.messages = [] # Clear chat history
87
+ st.rerun() # Refresh the app
88
+
89
+ # Get session ID
90
+ session_id = st.session_state.session_id
91
+ chat_history_instance = chat_history(session_id)
92
+
93
+ # Define Chat Prompt Template
94
+ chat_prompt = ChatPromptTemplate(
95
+ messages=[
96
+ ('system', """You are an AI assistant specialized in Data Science tutoring.
97
+ You will only answer questions related to Data Science.
98
+ If asked anything outside this topic, politely decline and request a Data Science-related question.
99
+ """),
100
+ MessagesPlaceholder(variable_name="history", optional=True),
101
+ ('human', '{prompt}')
102
+ ]
103
+ )
104
+
105
+ # Define output parser
106
+ out_parser = StrOutputParser()
107
+
108
+ # Create a chain
109
+ chain = chat_prompt | llm | out_parser
110
+
111
+ # Define Runnable with message history
112
+ chat = RunnableWithMessageHistory(
113
+ chain,
114
+ lambda session: SQLChatMessageHistory(session, "sqlite:///chat_history.db"),
115
+ input_messages_key="prompt",
116
+ history_messages_key="history"
117
+ )
118
+
119
+ # Chat History Container
120
+ st.markdown("### Chat History")
121
+ chat_container = st.container()
122
+
123
+ # Load chat history and display it
124
+ if "messages" not in st.session_state:
125
+ st.session_state.messages = load_chat_history(session_id)
126
+
127
+ with chat_container:
128
+ for role, content in st.session_state.messages:
129
+ with st.chat_message(role):
130
+ st.markdown(content)
131
+
132
+ # User input at the bottom
133
+ user_input = st.text_input("Type your message here:", key="user_message")
134
+
135
+ # If user submits a message
136
+ if user_input:
137
+ save_message(session_id, "user", user_input)
138
+ st.session_state.messages.append(("user", user_input))
139
+
140
+ # Invoke AI model
141
+ config = {'configurable': {'session_id': session_id}}
142
+ response = chat.invoke({'prompt': user_input}, config)
143
+
144
+ save_message(session_id, "assistant", response)
145
+ st.session_state.messages.append(("assistant", response))
146
+
147
+ # Display AI response
148
+ with chat_container:
149
+ with st.chat_message("assistant"):
150
+ st.markdown(response)
151
+
152
+ # Clear the input field
153
+ st.session_state.pop("user_message")
154
+ st.session_state["user_message"] = ""
155
+ st.rerun() # Refresh the app