# Streamlit "Chat with GPT" app: chats over a document's OCR text stored in
# MySQL, using LangChain retrieval over a FAISS vectorstore and OpenAI models.
import os

import mysql.connector
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# from htmlTemplates import css, bot_template, user_template
# from PyPDF2 import PdfReader
def get_pdf_text(slug):
    """
    Fetches the OCR text for a document from the MySQL database.

    Parameters:
    - slug (str): Identifier of the document row(s) in birdseye_temp.

    Returns:
    - str: Concatenated OCR text of all matching rows; empty string if
      nothing matched or a database error occurred (the error is shown
      in the Streamlit UI rather than raised).
    """
    load_dotenv()
    conn = None
    cursor = None
    parts = []
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        # Parameterized query: slug comes from the URL and must not be
        # string-interpolated into the SQL.
        cursor.execute("SELECT ocr_text FROM birdseye_temp WHERE slug = %s", (slug,))
        for row in cursor.fetchall():
            if row[0]:
                parts.append(row[0])
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # conn/cursor may never have been assigned if connect() itself
        # failed; guard against UnboundLocalError masking the real error.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
    # join() instead of repeated += avoids quadratic string building.
    return "".join(parts)
def get_text_chunks(text):
    """
    Split the given text into overlapping chunks for embedding.

    Parameters:
    - text (str): The text to be split into chunks.

    Returns:
    - list: A list of text chunks (1000 chars each, 200-char overlap,
      split on newlines).
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
def get_vectorstore(text_chunks):
    """
    Build a FAISS vector store from a list of text chunks.

    Parameters:
    - text_chunks (list of str): Text segments to convert into vector
      embeddings.

    Returns:
    - FAISS: A FAISS vector store containing OpenAI embeddings of the
      text chunks.
    """
    return FAISS.from_texts(
        texts=text_chunks,
        embedding=OpenAIEmbeddings(),
    )
def get_conversation_chain(vectorstore):
    """
    Initializes a conversational retrieval chain backed by GPT-4o that
    retrieves context from the provided vector store.

    Parameters:
    - vectorstore (FAISS): Vector store used for retrieving relevant content.

    Returns:
    - ConversationalRetrievalChain: An initialized conversational chain
      with buffered chat memory.

    Raises:
    - Any exception from the underlying LangChain/OpenAI constructors
      (propagated unchanged; the original `except Exception: raise` was a
      no-op and has been removed).
    """
    llm = ChatOpenAI(model_name="gpt-4o", temperature=0.5, top_p=0.5)
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectorstore.as_retriever(), memory=memory
    )
def handle_userinput(user_question):
    """
    Sends *user_question* through the session's conversation chain and
    renders the resulting chat history with HTML templates.

    NOTE(review): `user_template` and `bot_template` are undefined in this
    module — the `from htmlTemplates import ...` line at the top of the file
    is commented out — so calling this function raises NameError. The chat()
    flow below renders messages via st.chat_message and does not call this;
    either restore the import or delete this function.
    """
    response = st.session_state.conversation(
        {
            "question": f"Based on the memory and the provided document, answer the following user question: {user_question}. If the question is unrelated to memory or the document, just mention that you cannot provide an answer."
        }
    )
    st.session_state.chat_history = response["chat_history"]
    # Render newest exchange first; even indices are user turns, odd are
    # assistant turns (ConversationBufferMemory alternates messages).
    for i, message in reversed(list(enumerate(st.session_state.chat_history))):
        if i % 2 == 0:
            st.write(
                user_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )
        else:
            st.write(
                bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True
            )
def get_user_chat_count(user_id):
    """
    Retrieves the chat count for the user from the MySQL database.

    Creates a zero-count row for the user if none exists yet.

    Parameters:
    - user_id: Primary key of the user in birdseye_chat.

    Returns:
    - int: The user's current chat count (0 for a newly created row), or
      None if a database error occurred (shown via st.error).
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute("SELECT count FROM birdseye_chat WHERE user_id = %s", (user_id,))
        result = cursor.fetchone()
        if result:
            return result[0]
        # First chat for this user: seed a row so later increments succeed.
        cursor.execute(
            "INSERT INTO birdseye_chat (user_id, count) VALUES (%s, %s)",
            (user_id, 0),
        )
        conn.commit()
        return 0
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return None
    finally:
        # conn/cursor may be unassigned if connect() itself failed.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def increment_user_chat_count(user_id):
    """
    Increments the chat count for the user in the MySQL database.

    Parameters:
    - user_id: Primary key of the user in birdseye_chat. A row is assumed
      to exist (get_user_chat_count creates one); if it does not, the
      UPDATE silently affects zero rows.

    Returns:
    - None. Database errors are reported via st.error, not raised.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE birdseye_chat SET count = count + 1 WHERE user_id = %s ", (user_id,)
        )
        conn.commit()
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # conn/cursor may be unassigned if connect() itself failed.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def is_user_in_unlimited_chat_group(user_id):
    """
    Checks if the user belongs to the 'Unlimited Chat' group.

    Parameters:
    - user_id: Primary key of the user in auth_user_groups.

    Returns:
    - bool: True if the user is in the group; False otherwise, including
      when a database error occurred (reported via st.error).
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT 1
            FROM auth_user_groups
            JOIN auth_group ON auth_user_groups.group_id = auth_group.id
            WHERE auth_user_groups.user_id = %s AND auth_group.name = 'Unlimited Chat'
            """,
            (user_id,),
        )
        return cursor.fetchone() is not None
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return False
    finally:
        # conn/cursor may be unassigned if connect() itself failed.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def chat(slug, user_id):
    """
    Manages the chat interface in the Streamlit application, handling the
    conversation flow and displaying the chat history.

    Restricts chat based on user group and chat count: users outside the
    'Unlimited Chat' group are capped at 20 messages.

    Parameters:
    - slug: Document identifier used to load OCR text from the database.
    - user_id: User identifier used for group membership and chat counting.
    """
    st.write(
        "**Please note:** Due to processing limitations, the chat may not fully comprehend the whole document."
    )
    # NOTE: the vectorstore and chain are rebuilt on every Streamlit rerun;
    # acceptable for small documents, a candidate for st.cache_resource.
    text_chunks = get_text_chunks(get_pdf_text(slug))
    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)

    # Check if the user can chat.
    if not is_user_in_unlimited_chat_group(user_id):
        user_chat_count = get_user_chat_count(user_id)
        if user_chat_count is None or user_chat_count >= 20:
            st.write("You have reached your chat limit.")
            return

    # Replay the conversation so far. (The original special-cased a
    # single-message history, but both branches rendered identically.)
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.prompts = prompt
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a reply only when the last message came from the user.
    if st.session_state.messages[-1]["role"] != "ai":
        with st.spinner("Generating response..."):
            response = st.session_state.conversation.invoke(
                {"question": st.session_state.prompts}
            )
        with st.chat_message("ai"):
            message_content = response["chat_history"][-1].content
            st.session_state.messages.append({"role": "ai", "content": message_content})
            st.write(message_content)
        if not is_user_in_unlimited_chat_group(user_id):
            increment_user_chat_count(user_id)  # Increment count after response
def init():
    """
    Initializes the session-state variables used by the Streamlit app.

    Ensures 'pdf', 'conversation', 'chat_history', and the seed 'messages'
    entry exist before the rest of the app reads them.
    """
    defaults = {
        "pdf": False,
        "conversation": None,
        "chat_history": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {
                "role": "ai",
                "content": "What do you want to learn about the document? Ask me a question!",
            }
        ]
def main():
    """
    Entry point: reads the document slug and user id from the URL query
    parameters, loads environment variables, and launches the chat UI.
    """
    init()
    load_dotenv()
    params = st.query_params
    slug = params.get("slug")
    user_id = params.get("user_id")
    st.title("Chat with GPT :books:")
    if slug and user_id:
        chat(slug, user_id)
    else:
        st.error("Please return to Birdseye and select a document.")


if __name__ == "__main__":
    main()