Spaces:
Build error (reported twice in the original log).
| import streamlit as st | |
| import sqlite3 | |
| import uuid | |
| from langchain_google_genai import GoogleGenerativeAI | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_community.chat_message_histories import SQLChatMessageHistory | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
# Load the Gemini API key from Streamlit secrets. Fail fast with a clear
# message if it is missing — st.secrets.get returns None, which would
# otherwise surface only as a cryptic auth error on the first model call.
GOOGLE_API_KEY = st.secrets.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    st.error("GOOGLE_API_KEY is missing from Streamlit secrets.")
    st.stop()

# Set up the Gemini 1.5 Pro model
llm = GoogleGenerativeAI(api_key=GOOGLE_API_KEY, model="gemini-1.5-pro")
# Open (or create) the SQLite file that persists chat transcripts.
# check_same_thread=False: Streamlit reruns the script on worker threads,
# so the module-level connection may be touched from more than one thread.
conn = sqlite3.connect("chat_history.db", check_same_thread=False)
cursor = conn.cursor()

# Make sure the transcript table exists before any reads or writes.
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS chat (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        session_id TEXT,
        role TEXT,
        content TEXT
    )
    """
)
conn.commit()
| # Function to save messages | |
def save_message(session_id, role, content):
    """Append one chat message to the persistent SQLite history.

    Args:
        session_id: Opaque per-user session identifier.
        role: Message author, e.g. "user" or "assistant".
        content: The message text.
    """
    # `with conn` wraps the insert in a transaction: commits on success and
    # rolls back on error, instead of the original unconditional
    # cursor.execute(...) + conn.commit().
    with conn:
        conn.execute(
            "INSERT INTO chat (session_id, role, content) VALUES (?, ?, ?)",
            (session_id, role, content),
        )
| # Function to load chat history | |
| def load_chat_history(session_id): | |
| cursor.execute("SELECT role, content FROM chat WHERE session_id = ?", (session_id,)) | |
| return cursor.fetchall() | |
| # Chat history instance | |
| def chat_history(session_id): | |
| return SQLChatMessageHistory( | |
| session_id=session_id, | |
| connection="sqlite:///chat_history.db" # ✅ FIXED: Use a connection string | |
| ) | |
# Give each browser session its own stable UUID so user histories never mix.
# setdefault stores a fresh UUID only on the first run of a session and
# returns the existing one thereafter.
session_id = st.session_state.setdefault("session_id", str(uuid.uuid4()))
chat_history_instance = chat_history(session_id)
# System instruction that pins the assistant to the Data Science domain.
SYSTEM_INSTRUCTIONS = """You are an AI assistant specialized in Data Science tutoring.
You will only answer questions related to Data Science.
If asked anything outside this topic, politely decline and request a Data Science-related question.
"""

# Prompt layout: fixed system message, prior turns (optional), new question.
chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_INSTRUCTIONS),
        MessagesPlaceholder(variable_name="history", optional=True),
        ("human", "{prompt}"),
    ]
)
# Pipeline: rendered prompt -> Gemini model -> plain string output.
out_parser = StrOutputParser()
chain = chat_prompt | llm | out_parser
# Wrap the chain so every invocation loads prior turns from SQLite and
# appends the new exchange, keyed by the session_id passed via
# config={'configurable': {'session_id': ...}}.
#
# Reuse the chat_history() helper defined above instead of rebuilding the
# store inline: the inline lambda passed the connection string as the second
# positional argument, which lands in SQLChatMessageHistory's deprecated
# `connection_string` parameter rather than `connection`.
chat = RunnableWithMessageHistory(
    chain,
    chat_history,
    input_messages_key="prompt",
    history_messages_key="history",
)
# --- Streamlit UI ---
st.title("Conversational AI Data Science Tutor")
st.write("Ask me anything about Data Science!")

# Seed the in-memory transcript from the database once per session.
if "messages" not in st.session_state:
    st.session_state.messages = load_chat_history(session_id)

# Replay the conversation so far.
for author, text in st.session_state.messages:
    with st.chat_message(author):
        st.markdown(text)
# User input
user_input = st.text_input("You:", "", key="user_input")
if user_input:
    # Persist and render the user's turn immediately. The original only
    # appended it to session state, so the current run displayed the
    # assistant's reply without the question until the next rerun.
    save_message(session_id, "user", user_input)
    st.session_state.messages.append(("user", user_input))
    with st.chat_message("user"):
        st.markdown(user_input)

    # Invoke the model; RunnableWithMessageHistory injects this session's
    # prior turns as `history` and records the new exchange afterwards.
    config = {'configurable': {'session_id': session_id}}
    response = chat.invoke({'prompt': user_input}, config)

    save_message(session_id, "assistant", response)
    st.session_state.messages.append(("assistant", response))
    with st.chat_message("assistant"):
        st.markdown(response)