Jamal811 committed on
Commit
ef42c0d
·
verified ·
1 Parent(s): 457ed71

create .py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from sentence_transformers import SentenceTransformer
4
+ import fitz # PyMuPDF for PDF handling
5
+ import faiss
6
+ import numpy as np
7
+
8
+ # Load models for embeddings and QA
9
# Checkpoint names live in one place so the tokenizer and the seq2seq model
# are always loaded from the same checkpoint.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "google/flan-t5-large"

# Sentence encoder used for both passages and questions.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
# Tokenizer and generator used to produce the final answer text.
tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
12
+
13
+ # Global variables to store documents and index
14
# Source file path -> full text extracted from that PDF.
documents = {}
# Text chunks of all documents, in the same order as the rows added to `index`.
passages = []
# Embeddings returned by the sentence encoder for `passages` (None until built).
embeddings = None
# Source file path of each chunk, parallel to `passages`.
file_names = []
# Position of each chunk within its own document, parallel to `passages`.
indexes = []
# FAISS index over the passage embeddings (None until built).
index = None
15
+
16
+ # Function to extract text from uploaded PDFs
17
def upload_and_extract_text(files):
    """Extract the plain text of every uploaded PDF into ``documents``.

    Any text kept from a previous upload is discarded first.

    Args:
        files: Uploaded files from the Gradio file component. Each item may
            be a path string or an object whose ``name`` attribute holds the
            path. ``None`` or an empty list is accepted.

    Returns:
        A status message for display in the UI.
    """
    global documents
    # Start from a clean slate so text from an earlier upload is not mixed in.
    documents = {}

    # Pressing the button without selecting files passes None, not a list.
    if not files:
        return "No PDF files were uploaded."

    for file in files:
        # NOTE(review): depending on the Gradio version, items are either
        # plain path strings or objects exposing the path via ``.name``.
        path = getattr(file, "name", file)
        with fitz.open(path) as pdf:
            documents[path] = "".join(page.get_text("text") for page in pdf)
    return "PDF content extracted and indexed successfully."
28
+
29
+ # Function to embed documents and create FAISS index
30
def embed_and_index_documents(chunk_size=300):
    """Split the extracted documents into chunks and index their embeddings.

    Rebuilds the module-level ``passages``, ``file_names``, ``indexes``,
    ``embeddings`` and ``index`` from the current contents of ``documents``.

    Args:
        chunk_size: Length of each chunk in characters; must be positive.

    Returns:
        A status message for display in the UI.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    global passages, embeddings, file_names, indexes, index

    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    passages, file_names, indexes = [], [], []
    # Drop the previous index up front so an empty or failed rebuild cannot
    # leave vectors behind that no longer line up with ``passages``.
    embeddings, index = None, None

    for file_name, text in documents.items():
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        passages.extend(chunks)
        file_names.extend([file_name] * len(chunks))
        indexes.extend(range(len(chunks)))

    if not passages:
        # Nothing was uploaded, or the PDFs carry no text layer; an empty
        # embedding matrix has no second dimension to size the index with.
        return "No extractable text found; nothing was indexed."

    # Create embeddings
    embeddings = embedding_model.encode(passages, convert_to_tensor=False)
    # FAISS expects a contiguous float32 matrix.
    embedding_matrix = np.ascontiguousarray(embeddings, dtype="float32")

    # Build FAISS index
    index = faiss.IndexFlatL2(embedding_matrix.shape[1])
    index.add(embedding_matrix)
    return "Documents embedded and indexed successfully."
49
+
50
+ # Function to retrieve relevant passages
51
def retrieve_relevant_passages(question, top_k=3):
    """Return the indexed passages that best match ``question``.

    Args:
        question: The user's question as plain text.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` passages in the order reported by the index search,
        or an empty list when nothing has been indexed yet or ``top_k`` is
        not positive.
    """
    if index is None or not passages or top_k <= 0:
        return []

    # Never request more neighbours than there are passages.
    k = min(top_k, len(passages))
    question_embedding = np.asarray(embedding_model.encode([question]), dtype="float32")
    _distances, retrieved_indices = index.search(question_embedding, k)
    # FAISS pads missing results with -1; without this filter, negative
    # indexing would silently return the last passage.
    return [passages[i] for i in retrieved_indices[0] if i >= 0]
56
+
57
+ # Function to answer questions using retrieved passages
58
def answer_question(question, top_k=3):
    """Answer ``question`` from the most relevant indexed passages.

    The retrieved passages are joined into one context string, placed in a
    prompt together with the question, and passed to the seq2seq model.

    Args:
        question: The user's question as plain text.
        top_k: Number of passages to retrieve as context.

    Returns:
        The generated answer, or a short instruction for the user when the
        question is empty or no documents have been indexed yet.
    """
    if not question or not question.strip():
        return "Please enter a question."
    # Without an index there is nothing to search; report that instead of
    # failing inside the retrieval step.
    if index is None:
        return "Please upload and process PDF files before asking a question."

    retrieved_passages = retrieve_relevant_passages(question, top_k)
    context = " ".join(retrieved_passages)
    input_text = f"Answer the question based on this content: {context}. Question: {question}"

    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    output_ids = qa_model.generate(input_ids, max_length=150)
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return answer
68
+
69
+ # Gradio interface functions
70
def handle_file_upload(files):
    """Gradio callback: extract text from the uploaded PDFs, then index it.

    Returns the extraction status and the indexing status, one per line.
    """
    extraction_status = upload_and_extract_text(files)
    indexing_status = embed_and_index_documents()
    return f"{extraction_status}\n{indexing_status}"
74
+
75
def chat_with_pdfs(question):
    """Gradio callback: return the answer to ``question`` from the indexed PDFs."""
    return answer_question(question)
78
+
79
+ # Define Gradio UI
80
with gr.Blocks() as demo:
    gr.Markdown("# PDF Chatbot using RAG (Retrieval-Augmented Generation)")

    # First tab: choose PDFs, then extract and index them on button press.
    with gr.Tab("Upload PDF(s)"):
        file_upload = gr.File(
            label="Upload PDF files",
            file_types=[".pdf"],
            file_count="multiple",
        )
        upload_button = gr.Button("Process PDFs")
        upload_output = gr.Textbox(label="Status")
        upload_button.click(
            fn=handle_file_upload,
            inputs=file_upload,
            outputs=upload_output,
        )

    # Second tab: submitting the question box runs the question-answering callback.
    with gr.Tab("Chat with PDFs"):
        question_input = gr.Textbox(label="Ask a question about the uploaded PDFs")
        answer_output = gr.Textbox(label="Answer")
        question_input.submit(
            fn=chat_with_pdfs,
            inputs=question_input,
            outputs=answer_output,
        )

# Launch the app
demo.launch()