"""Chainlit app that answers questions about an uploaded PDF.

On chat start the user is asked for a PDF; its pages are split into chunks,
embedded into a Chroma vector store, and a ConversationalRetrievalChain with
buffer memory is stored in the user session. Each incoming message is then
answered by that chain, with the retrieved chunks attached as side elements.
"""

import os
from io import BytesIO
from typing import List

import chainlit as cl
from langchain import hub
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter

# ----------------------------- Embeddings and model -----------------------------

# These values are read here but are not passed to the clients below.
# NOTE(review): presumably the clients pick up the same environment variables
# on their own -- confirm against the installed library versions.
groq_api_key = os.getenv("GROQ_API_KEY")
embeddings_api_key = os.getenv('GOOGLE_API_KEY')

embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

# ----------------------------- Prompt -------------------------------------------

# Prompt used by the "stuff" combine-documents step of the chain.
_CUSTOM_PROMPT_TEMPLATE = """
Based on the provided context please answer . if you don't know the answer. just say i don't know.

{context}

Question: {question}
"""

_CUSTOM_PROMPT = PromptTemplate(
    template=_CUSTOM_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)


def _build_chain(docsearch):
    """Return a ConversationalRetrievalChain over *docsearch* with fresh memory.

    A new ChatMessageHistory is created on every call, so each chain built here
    starts with an empty conversation.
    """
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        # The chain returns both "answer" and "source_documents"; memory must
        # be told which of the two outputs to record.
        output_key="answer",
        chat_memory=ChatMessageHistory(),
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": _CUSTOM_PROMPT},
    )


def _source_elements(source_documents):
    """Return one side-panel ``cl.Text`` per document, named ``source_<index>``."""
    return [
        cl.Text(
            content=source_doc.page_content,
            name=f"source_{source_idx}",
            display="side",
        )
        for source_idx, source_doc in enumerate(source_documents)
    ]


# ----------------------------- on_chat_start event handler ----------------------


@cl.on_chat_start
async def on_chat_start() -> None:
    """Ask for a PDF, index it, and store a retrieval chain in the session.

    If no text chunks can be produced from the PDF, the status message is
    updated with an explanation and no chain is stored.
    """
    files = None
    # Keep asking until the prompt returns at least one file.
    while not files:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=300,
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}` ...")
    await msg.send()

    # Load the PDF off the event loop, the same way the Chroma build below
    # is offloaded.
    pages = await cl.make_async(PyPDFLoader(file.path).load)()

    # Split the loaded pages into overlapping chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(pages)

    if not chunks:
        # Nothing to index; report it instead of building an empty store.
        msg.content = (
            f"Could not extract any text from `{file.name}`. "
            "Please start a new chat and upload a different PDF."
        )
        await msg.update()
        return

    # Build the Chroma vector store from the chunks.
    docsearch = await cl.make_async(Chroma.from_documents)(
        chunks, embedding_model
    )

    chain = _build_chain(docsearch)

    msg.content = f"Processing `{file.name}` ... Done!✅ You can ask questions now!"
    await msg.update()

    cl.user_session.set("chain", chain)


# ----------------------------- on_message event handler -------------------------


@cl.on_message
async def main(message: cl.Message) -> None:
    """Answer *message* with the session's chain and list the retrieved sources.

    If the session holds no chain, a short notice is sent and nothing else
    happens.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        await cl.Message(
            content="No document has been processed for this chat yet. "
            "Please upload a PDF first."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler()

    # NOTE(review): `acall` is reported as deprecated in newer LangChain
    # releases in favour of `ainvoke` -- confirm against the pinned version
    # before switching.
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res.get("source_documents") or []

    text_elements = _source_elements(source_documents)
    source_names = [text_el.name for text_el in text_elements]

    if source_names:
        answer += f"\nSources: {', '.join(source_names)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()