# GenAI chatbot over GitLab's Handbook and Direction pages:
# Streamlit UI + Pinecone retrieval + Gemini (via LangChain) for answers.
| import streamlit as st | |
| from langchain_community.vectorstores import Pinecone as PineconeVS | |
| from pinecone import Pinecone | |
| import time | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain.chains import RetrievalQAWithSourcesChain | |
| from langchain.chains.conversation.memory import ConversationBufferMemory | |
| from dotenv import load_dotenv | |
# Load environment variables (API keys for Pinecone / Google) from .env.
load_dotenv()

# --- Configuration constants ---
INDEX_NAME = 'gitlab'                      # Pinecone index holding the GitLab docs
TEXT_FIELD = "content"                     # Metadata field that stores the raw text
MODEL_NAME = 'gemini-1.5-flash'            # Gemini chat model used for answering
EMBEDDING_MODEL = "multilingual-e5-large"  # Pinecone-hosted embedding model
def init_pinecone():
    """Connect to Pinecone and open the target index.

    Returns:
        tuple: (Pinecone client, index handle) for INDEX_NAME.
    """
    client = Pinecone()  # Credentials are read from the environment
    idx = client.Index(INDEX_NAME)
    time.sleep(1)  # Allow time for index to initialize
    idx.describe_index_stats()  # Touch the index so a bad name fails early
    return client, idx
def embed(pc, data):
    """Embed a batch of texts via Pinecone's hosted inference API.

    Args:
        pc: Pinecone client whose ``inference.embed`` endpoint is used.
        data: List of strings to embed.

    Returns:
        list: One embedding vector (list of floats) per input string.
    """
    response = pc.inference.embed(
        model=EMBEDDING_MODEL,
        inputs=data,
        parameters={"input_type": "passage", "truncate": "END"},
    )
    return [item['values'] for item in response]
def create_vectorstore(index, embed_query_fn):
    """Wrap a Pinecone index as a LangChain vector store.

    Args:
        index: Pinecone index handle.
        embed_query_fn: Callable mapping a query string to its embedding.

    Returns:
        PineconeVS: Vector store reading document text from TEXT_FIELD.
    """
    store = PineconeVS(index, embed_query_fn, TEXT_FIELD)
    return store
def init_llm():
    """Build the Gemini chat model (temperature 0 for deterministic answers)."""
    llm = ChatGoogleGenerativeAI(model=MODEL_NAME, temperature=0.0)
    return llm
def init_memory():
    """Create conversation memory for multi-turn Q&A.

    ``output_key="answer"`` tells the memory which chain output to record,
    since the sources-aware chain returns both "answer" and "sources".
    """
    return ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )
def create_qa_chain(llm, vectorstore, memory):
    """Assemble the retrieval QA chain with conversational memory.

    Args:
        llm: Chat model used to generate answers.
        vectorstore: Vector store that supplies the retriever.
        memory: Conversation memory shared across turns.

    Returns:
        RetrievalQAWithSourcesChain: "stuff"-type chain over the retriever.
    """
    retriever = vectorstore.as_retriever()
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        memory=memory,
    )
def init_chatbot():
    """Wire Pinecone, embeddings, the LLM and memory into one QA chain."""
    client, index = init_pinecone()  # Pinecone client + index handle

    def embed_query(query):
        # Embed a single query string and unwrap the lone vector.
        return embed(client, [query])[0]

    vectorstore = create_vectorstore(index, embed_query)
    return create_qa_chain(init_llm(), vectorstore, init_memory())
def display_conversation():
    """Render the app title and the introductory prompt."""
    st.title("GenAI Chatbot")
    st.write("Explore GitLab's Handbook and Direction! What would you like to know?")
def format_sources_with_links(sources: str) -> list:
    """Turn a comma-separated source string into numbered markdown links.

    Args:
        sources: Comma-separated URLs, e.g. ``"url1, url2"`` (chains commonly
            join sources with ", ").

    Returns:
        list[str]: Markdown links like ``"[[1]](url1)"``, ``"[[2]](url2)"``.

    Fixes over the naive split: surrounding whitespace is stripped from each
    URL (a leading space inside ``](...)`` breaks the markdown link), and
    empty entries are skipped, so an empty ``sources`` yields ``[]`` instead
    of a dangling ``[[1]]()``.
    """
    urls = [u.strip() for u in sources.split(',') if u.strip()]
    # Format as [1](url), [2](url), etc.
    return [f"[[{i}]]({url})" for i, url in enumerate(urls, start=1)]
def handle_chat(user_input, qa_with_memory):
    """Record the user's message, query the chain, and render the reply.

    Args:
        user_input: Raw text the user typed.
        qa_with_memory: QA chain that answers and tracks chat history.
    """
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        # Get response from the model; surface failures in the chat
        # instead of crashing the app.
        try:
            result = qa_with_memory.invoke(user_input)
            links = " ".join(format_sources_with_links(result["sources"]))
            reply = result["answer"] + "\n\n**Sources:** " + links
        except Exception as exc:
            reply = "Error:- " + str(exc)
        st.markdown(reply)

    st.session_state.messages.append({"role": "assistant", "content": reply})
def main():
    """Run the Streamlit chat app.

    Streamlit re-executes this script on every user interaction. The original
    code called ``init_chatbot()`` unconditionally, rebuilding the Pinecone
    and LLM clients — and, critically, recreating ``ConversationBufferMemory``
    — on every rerun, so conversational memory was wiped each turn. Caching
    the chain in ``st.session_state`` builds it once per browser session and
    lets memory persist across turns.
    """
    if "qa_with_memory" not in st.session_state:
        st.session_state["qa_with_memory"] = init_chatbot()
    qa_with_memory = st.session_state["qa_with_memory"]

    # Display the UI
    display_conversation()

    # Initialize conversation history if not already present
    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    # Replay the conversation so far
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Get user input and process it
    if prompt := st.chat_input("Ask something about GitLab:"):
        if prompt.strip() == "":
            st.warning("Please enter a valid input.")
            return
        handle_chat(prompt, qa_with_memory)


if __name__ == "__main__":
    main()