Spaces:
Build error
Build error
| import gradio as gr | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from sentence_transformers import SentenceTransformer | |
| import fitz # PyMuPDF for PDF handling | |
| import faiss | |
| import numpy as np | |
# Sentence-embedding model used to vectorise passages and questions.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Seq2seq model and matching tokenizer used to generate the answers.
qa_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

# Module-level state shared by the upload, indexing and retrieval functions.
documents = {}      # file name -> full extracted text
passages = []       # text chunks, in indexing order
embeddings = None   # embeddings of ``passages`` (set when indexing runs)
file_names = []     # source file name of each passage
indexes = []        # position of each passage within its source file
index = None        # FAISS index over ``embeddings`` (set when indexing runs)
# Function to extract text from uploaded PDFs
def upload_and_extract_text(files):
    """Extract the plain text of each uploaded PDF into the global ``documents``.

    ``documents`` is reset on every call and then filled with one entry per
    file, keyed by the file's path.

    Args:
        files: Iterable of uploaded files. Each item is either a path string
            or an object exposing the path as ``.name``. ``None`` or an empty
            iterable means that nothing was uploaded.

    Returns:
        A human-readable status message.
    """
    global documents
    documents = {}
    if not files:
        # Nothing uploaded: report it instead of failing on ``for ... in None``.
        return "No PDF files were provided."
    for file in files:
        # Depending on the Gradio version/configuration the upload arrives as
        # a plain path string or as a wrapper whose ``name`` holds the path.
        path = file if isinstance(file, str) else file.name
        with fitz.open(path) as pdf:
            # Join once instead of growing a string page by page.
            documents[path] = "".join(page.get_text("text") for page in pdf)
    return "PDF content extracted and indexed successfully."
# Function to embed documents and create FAISS index
def embed_and_index_documents(chunk_size=300):
    """Chunk the extracted documents, embed the chunks and build a FAISS index.

    Reads the global ``documents`` mapping and rebuilds the globals
    ``passages``, ``file_names``, ``indexes``, ``embeddings`` and ``index``.

    Args:
        chunk_size: Number of characters per chunk; must be at least 1.

    Returns:
        A human-readable status message.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    global passages, embeddings, file_names, indexes, index
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    passages, file_names, indexes = [], [], []
    # Drop the previous index up front so that an empty or failed run never
    # leaves stale search state pointing at passages that no longer exist.
    embeddings, index = None, None

    for file_name, text in documents.items():
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        passages.extend(chunks)
        file_names.extend([file_name] * len(chunks))
        indexes.extend(range(len(chunks)))

    if not passages:
        # No documents, or only documents without extractable text: there is
        # no embedding dimension to size an index with.
        return "No text was found to index."

    # Create embeddings
    embeddings = embedding_model.encode(passages, convert_to_tensor=False)
    # FAISS works on C-contiguous float32 matrices.
    embedding_matrix = np.ascontiguousarray(embeddings, dtype=np.float32)

    # Build FAISS index
    index = faiss.IndexFlatL2(embedding_matrix.shape[1])
    index.add(embedding_matrix)
    return "Documents embedded and indexed successfully."
# Function to retrieve relevant passages
def retrieve_relevant_passages(question, top_k=3):
    """Return up to ``top_k`` indexed passages closest to ``question``.

    Uses the global FAISS ``index`` and the global ``passages`` list.

    Args:
        question: The question text to embed and search for.
        top_k: Maximum number of passages to return.

    Returns:
        A list of passage strings, nearest first. The list is empty when
        nothing has been indexed yet or when ``top_k`` is smaller than 1.
    """
    if index is None or not passages or top_k < 1:
        return []
    # Never request more neighbours than there are vectors: FAISS pads the
    # missing slots with -1, which Python indexing would silently turn into
    # "the last passage".
    neighbour_count = min(top_k, len(passages))
    question_embedding = np.asarray(
        embedding_model.encode([question]), dtype=np.float32
    )
    _, retrieved_indices = index.search(question_embedding, neighbour_count)
    return [passages[i] for i in retrieved_indices[0] if i >= 0]
# Function to answer questions using retrieved passages
def answer_question(question, top_k=3):
    """Answer ``question`` from the passages retrieved out of the indexed PDFs.

    Args:
        question: The user's question.
        top_k: Number of passages to retrieve as context.

    Returns:
        The generated answer, or a short instruction message when the
        question is empty or no documents have been indexed yet.
    """
    if not question or not question.strip():
        return "Please enter a question."
    if index is None or not passages:
        # Asked before any PDF was processed: there is nothing to search.
        return "Please upload and process PDF files before asking a question."

    retrieved_passages = retrieve_relevant_passages(question, top_k)
    context = " ".join(retrieved_passages)
    input_text = f"Answer the question based on this content: {context}. Question: {question}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    output_ids = qa_model.generate(input_ids, max_length=150)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Gradio interface functions
def handle_file_upload(files):
    """Gradio callback: extract the uploaded PDFs, then index them.

    Returns the extraction status and the indexing status, one per line.
    """
    extraction_status = upload_and_extract_text(files)
    return f"{extraction_status}\n{embed_and_index_documents()}"
def chat_with_pdfs(question):
    """Gradio callback: return the answer produced for ``question``."""
    return answer_question(question)
# Define Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot using RAG (Retrieval-Augmented Generation)")

    # Tab 1: upload PDFs and build the search index.
    with gr.Tab("Upload PDF(s)"):
        file_upload = gr.File(label="Upload PDF files", file_types=[".pdf"], file_count="multiple")
        upload_button = gr.Button("Process PDFs")
        upload_output = gr.Textbox(label="Status")
        upload_button.click(fn=handle_file_upload, inputs=file_upload, outputs=upload_output)

    # Tab 2: ask questions; pressing Enter in the textbox submits.
    with gr.Tab("Chat with PDFs"):
        question_input = gr.Textbox(label="Ask a question about the uploaded PDFs")
        answer_output = gr.Textbox(label="Answer")
        question_input.submit(fn=chat_with_pdfs, inputs=question_input, outputs=answer_output)

# Launch the app only when run as a script, so that importing this module
# (for tests or reuse of ``demo``) does not start a server.
if __name__ == "__main__":
    demo.launch()