# NOTE(review): the lines "Spaces:" / "Runtime error" were Hugging Face Spaces
# page residue pasted into the source, not Python; converted to this comment so
# the file parses.
| from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.document_loaders import ArxivLoader | |
| from faiss import IndexFlatL2 | |
| from langchain_community.docstore.in_memory import InMemoryDocstore | |
| from langchain.document_transformers import LongContextReorder | |
| from langchain_core.runnables import RunnableLambda | |
| from langchain_core.runnables.passthrough import RunnableAssign | |
| from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| import gradio as gr | |
| from functools import partial | |
| from operator import itemgetter | |
| from langchain_community.document_loaders import PyPDFLoader | |
## Recursive splitter: ~1000-char chunks with 100-char overlap, breaking on
## progressively finer separators (paragraph -> line -> sentence -> word -> char).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)
print("Loading Documents")

## Alternative arXiv corpora kept for reference, e.g.:
##   docs = [ArxivLoader(query="1706.03762").load(), ...]   # Attention Is All You Need
## (1810.04805 BERT, 2005.11401 RAG, 2310.06825 Mistral, 2306.05685 LLM-as-a-Judge,
##  2210.03629 ReAct, 2112.10752 Latent Diffusion, 2103.00020 CLIP)

## Load the study PDF; load_and_split() returns one Document per page/split.
reader = PyPDFLoader("Inconsistent_Defect_Labels_Essence_Causes_and_Influence.pdf")
docs = [reader.load_and_split()]

## Strip the bibliography.  BUG FIX: the original only inspected doc[0] (the
## first page), where a "References" heading essentially never appears, so the
## truncation was a no-op.  Scan every page instead: cut the page containing
## the heading at that point and drop all pages after it.
for i, doc in enumerate(docs):
    for j, page in enumerate(doc):
        content = page.page_content
        if "References" in content:
            page.page_content = content[:content.index("References")]
            docs[i] = doc[:j + 1]  # discard pages past the bibliography
            break

## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
## Make some custom chunks that give big-picture details about the corpus.
doc_string = '''
Inconsistent_Defect_Labels_Essence_Causes_and_Influence
'''
## One stringified metadata record per loaded document (metadata of its first chunk).
doc_metadata = [str(getattr(chunks[0], 'metadata', {})) for chunks in docs_chunks]
extra_chunks = [doc_string, *doc_metadata]

## Embedding model shared by every vector store below.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

## Construct series of document vector stores: one for the big-picture chunks,
## then one per document's chunk list.
print("Constructing Vector Stores")
vecstores = [
    FAISS.from_texts(extra_chunks, embedder),
    *(FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks),
]

## Probe the embedder once to learn the embedding dimensionality.
embed_dims = len(embedder.embed_query("test"))
def default_FAISS():
    """Build and return an empty FAISS vector store wired to the module-level embedder."""
    ## A flat L2 index sized to the embedder's output dimensionality.
    empty_index = IndexFlatL2(embed_dims)
    return FAISS(
        embedding_function=embedder,
        index=empty_index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )
def aggregate_vstores(vectorstores):
    """Merge every given FAISS store into one fresh index and return it.

    NOTE: the result is tied to the module-level embedder by reference,
    via default_FAISS().
    """
    merged = default_FAISS()
    for component in vectorstores:
        merged.merge_from(component)
    return merged
# Build the aggregate store only once per interpreter session (e.g. on
# notebook/script re-runs) -- guarded by a globals() membership check.
if 'docstore' not in globals():
    ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
    docstore = aggregate_vstores(vecstores)
# _dict is FAISS's internal id->Document map; its length is the chunk count.
print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
########################################################################
## Utility Runnables/Methods

def RPrint(preface=""):
    """Identity Runnable that prints its (prefixed) input before passing it along."""
    def _print_and_return(x):
        print(f"{preface}{x}")
        return x
    return RunnableLambda(_print_and_return)
def docs2str(docs, title="Document"):
    """Render retrieved chunks into one context string, each quote attributed.

    Falls back to `title` when a chunk has no 'Title' metadata, and to
    str(doc) when an item has no page_content attribute.
    """
    pieces = []
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        attribution = f"[Quote from {doc_name}] " if doc_name else ""
        pieces.append(attribution + getattr(doc, 'page_content', str(doc)) + "\n")
    return "".join(pieces)
## Optional; Reorders longer documents to center of output text
## (mitigates "lost in the middle" retrieval degradation)
long_reorder = RunnableLambda(LongContextReorder().transform_documents)
########################################################################
# Chat LLM piped straight into a string parser (so downstream sees plain str),
# and an empty FAISS store that accumulates conversation memory as chat proceeds.
llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
convstore = default_FAISS()
def save_memory_and_get_output(d, vstore):
    """Record the latest exchange in the conversation store, then return the reply.

    `d` is an {'input': ..., 'output': ...} dictionary; both turns are added as
    texts so they can be retrieved as conversation history later.
    """
    user_turn = f"User previously responded with {d.get('input')}"
    agent_turn = f"Agent previously responded with {d.get('output')}"
    vstore.add_texts([user_turn, agent_turn])
    return d.get('output')
## Greeting shown as the assistant's first message in the Gradio chat window.
initial_msg = (
    "Hello! I am a research paper assistant here to help Ejaz!"
    f" These are the papers you are studying this week! {doc_string}\n\nPlease ask questions Ejaz !"
)

## System prompt: the retrieval chain fills {input}, {history}, and {context}.
## FIX: the original read "User messaged just asked: ..." -- garbled grammar in
## the instruction given to the model; corrected to "The user just asked:".
chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " The user just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])
# Retrieval pipeline: wrap the raw user message into {'input': ...}, then
# attach 'history' (conversation-memory hits) and 'context' (document hits);
# each retrieval is reordered (longer chunks toward the middle) and rendered
# to a quote-attributed string before prompting.
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
    | RPrint()  # debug: print the fully-assembled state dict
)
# Generation pipeline: filled prompt -> LLM (string-parsed at construction).
stream_chain = chat_prompt | llm
def chat_gen(message, history=None, return_buffer=True):
    """Generator powering the chat: retrieve, stream the LLM reply, save memory.

    Args:
        message: the latest user message.
        history: prior [user, bot] turns from Gradio (unused here; retrieval
            history comes from convstore instead).
        return_buffer: when True, yield the growing full response (Gradio
            style); when False, yield individual tokens with manual
            line-wrapping for plain print() consumption.
    """
    ## BUG FIX: the original declared `history=[]`, a mutable default argument
    ## shared across calls; use None and normalize inside the function.
    history = history if history is not None else []
    buffer = ""

    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if (len(line_buffer) > 84 and token and token[0] == " ") or len(line_buffer) > 100:
                line_buffer = ""
                yield "\n"
                token = " " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
## Gradio UI: seed the chat window with the greeting, then launch.
chatbot = gr.Chatbot(value=[[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

## FIX: the original duplicated demo.close() in both the try and except
## branches; a single `finally` guarantees the server port is released on
## every path.  Bare `raise` preserves the original traceback (vs `raise e`).
try:
    demo.launch(debug=True, share=True, show_api=False)
except Exception as e:
    print(e)
    raise
finally:
    demo.close()