# Hugging Face Spaces app: RAG question answering over uploaded PDFs.
# NOTE(review): the Space reported "Build error" — the missing `torch` and
# `PromptTemplate` imports below are the likely runtime failures.
import os

import gradio as gr
import torch  # added: used for the CUDA availability check in initialize_llm()
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate  # added: used in create_rag_chain()
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Each page's text is followed by a newline; pages with no extractable
    text are skipped. On any failure the function does not raise — it
    returns a string beginning with "Error reading PDF:", which callers
    test for with ``startswith("Error")``.
    """
    try:
        pages = PdfReader(pdf_file).pages
        extracted = (page.extract_text() for page in pages)
        return "".join(chunk + "\n" for chunk in extracted if chunk)
    except Exception as e:
        return f"Error reading PDF: {e}"
def process_pdfs(pdf_files):
    """Chunk the text of *pdf_files* and index the chunks in a FAISS store.

    Parameters
    ----------
    pdf_files : iterable of file paths / file-like objects accepted by PdfReader.

    Returns
    -------
    A FAISS vector store over ~1000-character chunks (150-char overlap),
    embedded with sentence-transformers/all-MiniLM-L6-v2.

    Raises
    ------
    ValueError
        If no text could be extracted from any of the uploaded files
        (e.g. scanned/image-only PDFs). Previously this fell through to
        ``FAISS.from_texts([])``, which fails with an opaque error.
    """
    documents = []
    for pdf_file in pdf_files:
        text = extract_text_from_pdf(pdf_file)
        # extract_text_from_pdf signals failure with an "Error ..." string.
        if text and not text.startswith("Error"):
            documents.append(text)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,
    )
    chunks = [chunk for doc in documents for chunk in text_splitter.split_text(doc)]
    if not chunks:
        # Fail early with a clear message instead of letting FAISS choke
        # on an empty text list.
        raise ValueError("No extractable text found in the uploaded PDFs.")

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    return FAISS.from_texts(chunks, embeddings)
def initialize_llm():
    """Load google/flan-t5-base and wrap it as a LangChain LLM.

    Returns
    -------
    HuggingFacePipeline
        A text2text-generation pipeline (max_length=512, temperature=0.7)
        placed on the first CUDA device when available, otherwise CPU.

    Note: requires ``import torch`` at module level — the original file
    referenced ``torch`` here without importing it, which raised NameError
    at runtime.
    """
    model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        temperature=0.7,
        # transformers pipeline(): device=0 -> first CUDA GPU, -1 -> CPU.
        device=0 if torch.cuda.is_available() else -1,
    )
    return HuggingFacePipeline(pipeline=pipe)
def create_rag_chain(vector_store, llm):
    """Build a "stuff"-type RetrievalQA chain over *vector_store*.

    Parameters
    ----------
    vector_store : a LangChain vector store (FAISS here); queried with k=5.
    llm : the LangChain LLM used to answer from the retrieved context.

    Returns
    -------
    RetrievalQA chain that returns source documents and records the
    conversation in a ConversationBufferMemory attached to the inner
    combine-documents chain.

    Requires ``PromptTemplate`` to be imported at module level — the
    original file used it here without any import (NameError at runtime).
    """
    prompt_template = """Use the following pieces of context to answer the question. If you don't know the answer, say so. Do not make up information.
{context}
Question: {question}
Answer: """
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    # fixed: ConversationBufferMemory defines no `max_len` field, and
    # LangChain's pydantic models generally forbid unknown fields, so the
    # original `max_len=4` kwarg failed at construction time. The buffer
    # is now unbounded — trim it explicitly if history growth matters.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
    )
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
        return_source_documents=True,
        # The memory rides on the inner LLMChain, keyed by the prompt's
        # "question" variable.
        chain_type_kwargs={"prompt": prompt, "memory": memory},
    )
    return chain
def rag_interface(pdf_files, question):
    """Gradio callback: answer *question* from the uploaded *pdf_files*.

    Parameters
    ----------
    pdf_files : list of uploaded files (or falsy when nothing uploaded).
    question : the user's question string.

    Returns
    -------
    (answer, history_text) : tuple of two strings. On bad input a
    human-readable message is returned in the answer slot instead of
    letting an exception surface as a Gradio traceback.
    """
    if not pdf_files:
        return "Please upload at least one PDF file.", ""
    # fixed: a blank question previously went straight into the chain.
    if not question or not question.strip():
        return "Please enter a question.", ""

    # NOTE(review): the vector store, model, and chain are rebuilt on every
    # click — correct but slow; consider caching them across calls.
    try:
        vector_store = process_pdfs(pdf_files)
    except ValueError as e:
        # e.g. image-only PDFs with no extractable text.
        return str(e), ""
    llm = initialize_llm()
    rag_chain = create_rag_chain(vector_store, llm)

    result = rag_chain({"query": question})
    answer = result["result"]
    # The memory lives on the inner combine-documents chain (see
    # chain_type_kwargs in create_rag_chain).
    chat_history = rag_chain.combine_documents_chain.memory.chat_memory.messages

    # Messages alternate human/AI; render them as Q/A pairs.
    history_text = ""
    for i in range(0, len(chat_history), 2):
        if i + 1 < len(chat_history):
            history_text += (
                f"Q: {chat_history[i].content}\nA: {chat_history[i+1].content}\n\n"
            )
    return answer, history_text
# --- UI wiring ---------------------------------------------------------
# Minimal Blocks layout: PDF upload + question in, answer + chat history out.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Question Answering System")
    pdfs = gr.File(label="Upload PDFs", file_count="multiple", file_types=[".pdf"])
    query_box = gr.Textbox(label="Ask a question")
    answer_box = gr.Textbox(label="Answer")
    history_box = gr.Textbox(label="Chat History")
    ask = gr.Button("Submit")
    ask.click(
        fn=rag_interface,
        inputs=[pdfs, query_box],
        outputs=[answer_box, history_box],
    )

# share=True exposes a temporary public URL in addition to the local server.
demo.launch(share=True)