simran40 committed
Commit c0cb811 · verified · 1 Parent(s): 630d618

Update app.py

Files changed (1):
  1. app.py +115 -268
app.py CHANGED
@@ -1,324 +1,171 @@
  import gradio as gr
- import fitz # PyMuPDF
  import re
  import faiss
  import numpy as np
- import time
  from sentence_transformers import SentenceTransformer
- from transformers import pipeline
-
- # --- Global State and Initialization ---
- # These variables will hold the processed document data
- qa_index = None
- qa_chunks = []
- summarizer_chunks = []
- is_initialized = False
-
- # =================================================
- # MODEL LOADING (ONCE)
- # WARNING: This step is the primary cause of slow startup.
- # =================================================
-
- try:
-     # Embedding model for semantic retrieval
-     print("Loading Sentence Transformer model...")
-     embedding_model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
-
-     # Extractive QA model (accurate answers)
-     print("Loading Extractive QA model...")
-     qa_pipeline = pipeline(
-         "question-answering",
-         model="deepset/roberta-base-squad2",
-         tokenizer="deepset/roberta-base-squad2"
-     )

-     # Summarization model (clean summary)
-     print("Loading Summarization model...")
-     summarizer = pipeline(
-         "summarization",
-         model="facebook/bart-large-cnn",
-         tokenizer="facebook/bart-large-cnn"
-     )
-     is_initialized = True
-     print("All models loaded successfully.")
-
- except Exception as e:
-     print(f"ERROR: Failed to load required models. Please check dependencies (requirements.txt). Error: {e}")
-     # Set initialized to False so functions return an error message
-     is_initialized = False

- # =================================================
- # PDF PROCESSING UTILITIES
- # =================================================

  def extract_text_from_pdf(pdf_path):
-     """Extracts raw text content from a PDF file using PyMuPDF."""
      doc = fitz.open(pdf_path)
      text = ""
      for page in doc:
-         text += page.get_text() + "\n\n"
      return text


  def clean_text(text):
-     """Performs common cleanup on raw PDF text."""
-     # Remove excessive whitespace
-     text = re.sub(r"\s+", " ", text)
-     # Attempt to remove table of contents, headers, footers (often document-specific)
-     text = re.sub(r"Table of Contents.*?Introduction", "", text, flags=re.I | re.DOTALL)
-     text = re.sub(r"\bPage \d+ of \d+\b|\bPage \d+\b", "", text)
-     return text.strip()


- def chunk_text(text, chunk_size=400, overlap=100):
-     """Chunks text for QA retrieval (smaller chunks for better context focus)."""
      chunks = []
      start = 0
      while start < len(text):
          end = start + chunk_size
          chunks.append(text[start:end])
-         start = end - overlap if end < len(text) else len(text)
      return chunks


- def chunk_text_for_summary(text, chunk_size=1024, overlap=150):
-     """Chunks text for summarization (larger chunks to maintain context flow)."""
-     chunks = []
-     start = 0
-     while start < len(text):
-         end = start + chunk_size
-         chunks.append(text[start:end])
-         start = end - overlap if end < len(text) else len(text)
-     return chunks
-
-
- # =================================================
- # FAISS AND CONTEXT RETRIEVAL
- # =================================================

  def build_faiss_index(chunks):
-     """Builds a FAISS Index from text chunks."""
-     print(f"Encoding {len(chunks)} chunks...")
-     embeddings = embedding_model.encode(chunks, show_progress_bar=False)
      embeddings = np.array(embeddings).astype("float32")
-
-     # Initialize FAISS Index (L2 distance for 'multi-qa-MiniLM-L6-cos-v1')
      index = faiss.IndexFlatL2(embeddings.shape[1])
      index.add(embeddings)
-     print("FAISS Index built.")
      return index, chunks


- def retrieve_relevant_chunks(question, index, chunks, top_k=5):
-     """Retrieves the most relevant chunks for a given question."""
-     # Ensure FAISS index is ready
-     if index is None:
-         return []
-
-     # Encode the query
-     query_embedding = embedding_model.encode([question]).astype("float32")
-
-     # Search the index
-     distances, indices = index.search(query_embedding, top_k)
-
-     results = []
-     for i, idx in enumerate(indices[0]):
-         # Higher score (smaller distance) is better in L2
-         results.append((chunks[idx], distances[0][i]))
-
-     # Sort by distance (smallest distance first)
-     results.sort(key=lambda x: x[1])
-     return [r[0] for r in results]
-
-
- # =================================================
- # HANDLERS FOR GRADIO INPUT
- # =================================================
-
- def process_pdf(pdf_file):
-     """
-     Initial PDF processing step: extracts text, cleans it, chunks it,
-     and builds the FAISS index for retrieval. Updates global state.
-     """
-     global qa_index, qa_chunks, summarizer_chunks
-
-     if not is_initialized:
-         return "ERROR: AI models failed to load. Please check console for details."
-
-     if pdf_file is None:
-         # Clear state if no file is provided
-         qa_index = None
-         qa_chunks = []
-         summarizer_chunks = []
-         return "Please upload a PDF document."
-
-     try:
-         start_time = time.time()
-         print("Starting PDF processing...")
-
-         # 1. Extraction and Cleaning
-         raw_text = extract_text_from_pdf(pdf_file.name)
-         cleaned_text = clean_text(raw_text)
-
-         # 2. Chunking for QA and Summary
-         qa_chunks = chunk_text(cleaned_text)
-         # Summarizer chunks might be larger to keep sequential context
-         summarizer_chunks = chunk_text_for_summary(cleaned_text)
-
-         # 3. Building FAISS Index for QA
-         qa_index, qa_chunks = build_faiss_index(qa_chunks)
-
-         end_time = time.time()
-
-         return (f"Document successfully processed and indexed! "
-                 f"Total chunks: {len(qa_chunks)}. "
-                 f"Ready for Q&A and Summary. (Processing time: {end_time - start_time:.2f} seconds)")
-
-     except Exception as e:
-         return f"An error occurred during PDF processing: {e}"
-
-
- def get_answer(question):
-     """Handles the Question Answering functionality."""
-     if not is_initialized:
-         return "ERROR: AI models failed to load. Cannot answer questions."
-
-     if qa_index is None:
-         return "Please upload and process a document first."
-
-     if not question or question.strip() == "":
-         return "Please enter a question to get an answer."
-
-     try:
-         start_time = time.time()
-         # 1. Retrieval (RAG component)
-         relevant_chunks = retrieve_relevant_chunks(question, qa_index, qa_chunks)
-
-         # Combine the retrieved chunks into a single context
-         context = " ".join(relevant_chunks)
-
-         # 2. Generation (Extractive QA component)
-         # Pass the question and the combined, relevant context to the QA model
-         result = qa_pipeline(
-             question=question,
-             context=context,
-             # Set minimum answer length to avoid single-word outputs
-             max_answer_len=256,
          )

-         answer = result["answer"]
-         score = result["score"]
-
-         # Set a confidence threshold for a valid answer
-         if score < 0.4 or answer.strip() == "":
-             return "Information not found in the most relevant sections of the document (confidence too low)."
-
-         end_time = time.time()
-         return (f"Answer: {answer}\n\n"
-                 f"Confidence Score: {score:.2f}\n"
-                 f"Time taken: {end_time - start_time:.2f} seconds")
-
-     except Exception as e:
-         return f"An error occurred during Q&A generation: {e}"
-
-
- def get_summary():
-     """Handles the Summarization functionality."""
-     if not is_initialized:
-         return "ERROR: AI models failed to load. Cannot generate summary."
-
-     if not summarizer_chunks:
-         return "Please upload and process a document first."
-
-     try:
-         start_time = time.time()
-         summaries = []
-
-         # Summarize each chunk sequentially
-         for i, chunk in enumerate(summarizer_chunks):
-             print(f"Summarizing chunk {i+1}/{len(summarizer_chunks)}")
-             summary_output = summarizer(
-                 chunk,
-                 max_length=150,
-                 min_length=50,
-                 do_sample=False,
-                 truncation=True # Crucial to handle inputs slightly over the model's max length
-             )[0]["summary_text"]
-             summaries.append(summary_output)
-
-         # Join the sequential summaries and run a final merge summary
-         merged_summary_text = " ".join(summaries)
-
-         # If the merged summary is still too long, run a final summary pass
-         if len(merged_summary_text) > 1024:
-             print("Running final merge summary...")
-             final_summary_output = summarizer(
-                 merged_summary_text,
-                 max_length=400,
-                 min_length=150,
-                 do_sample=False,
-                 truncation=True
-             )[0]["summary_text"]
-         else:
-             final_summary_output = merged_summary_text
-
-         end_time = time.time()
-         return (f"--- Document Summary ---\n\n{final_summary_output}\n\n"
-                 f"Time taken: {end_time - start_time:.2f} seconds")
-
-     except Exception as e:
-         return f"An error occurred during summarization: {e}"
-
-
- # =================================================
- # GRADIO UI
- # =================================================

  with gr.Blocks() as demo:

      gr.Markdown("""
-     # 📄 Open-Source RAG Document Analysis System (Python/Gradio)
-
-     This system uses three best-in-class open-source models for **Retrieval-Augmented Generation (RAG)**:
-     1. **`multi-qa-MiniLM-L6-cos-v1`**: for fast, accurate context retrieval.
-     2. **`deepset/roberta-base-squad2`**: for highly accurate, extractive Question Answering.
-     3. **`facebook/bart-large-cnn`**: for multi-step, high-quality Summarization.
-
-     ⚠️ **Warning**: Initial model loading is very slow. Please be patient after the app starts.
-     """)

-     with gr.Row():
-         pdf_input = gr.File(label="📤 Upload PDF Document", file_types=[".pdf"])
-         process_status = gr.Textbox(label="Processing Status", interactive=False, value="Upload a PDF to begin.")
-
-     process_btn = gr.Button("1. Process & Index Document", variant="primary")
-     process_btn.click(process_pdf, [pdf_input], process_status)

-     gr.Markdown("---")
-
      with gr.Row():
          with gr.Column(scale=1):
              question_input = gr.Textbox(
-                 label="❓ Step 2: Ask a Question",
-                 placeholder="e.g. What were the Q4 revenue figures?",
                  lines=2
              )
-             qa_btn = gr.Button("🔍 Get Accurate Answer", variant="secondary")

-         with gr.Column(scale=1):
-             summary_btn = gr.Button("📝 Step 2: Generate Full Summary", variant="secondary")

-     output_box = gr.Textbox(label="📌 Output / Result", lines=10, interactive=False)

-     # Bind events
-     qa_btn.click(get_answer, [question_input], output_box)
-     summary_btn.click(get_summary, [], output_box)

      gr.Markdown("""
      ---
-     *Disclaimer: Due to the size of the models, expect longer processing times for Q&A and Summarization than API-based solutions.*
      """)

- # To run the Gradio application
- demo.launch()

  import gradio as gr
+ import fitz
  import re
  import faiss
+ import torch
  import numpy as np
+
  from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM


+ # ===============================
+ # MODEL LOADING
+ # ===============================
+
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+ LLM_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
+ llm = AutoModelForCausalLM.from_pretrained(
+     LLM_NAME,
+     torch_dtype=torch.float32
+ )
+
+ llm.eval()

+
+ # ===============================
+ # PDF PROCESSING
+ # ===============================

  def extract_text_from_pdf(pdf_path):
      doc = fitz.open(pdf_path)
      text = ""
      for page in doc:
+         text += page.get_text()
      return text


  def clean_text(text):
+     return re.sub(r"\s+", " ", text).strip()


+ def chunk_text(text, chunk_size=500, overlap=50):
      chunks = []
      start = 0
      while start < len(text):
          end = start + chunk_size
          chunks.append(text[start:end])
+         start = end - overlap
      return chunks


+ # ===============================
+ # VECTOR DB (FAISS)
+ # ===============================

  def build_faiss_index(chunks):
+     embeddings = embedding_model.encode(chunks)
      embeddings = np.array(embeddings).astype("float32")
      index = faiss.IndexFlatL2(embeddings.shape[1])
      index.add(embeddings)
      return index, chunks


+ def retrieve_relevant_chunks(query, index, chunks, top_k=3):
+     query_embedding = embedding_model.encode([query]).astype("float32")
+     _, indices = index.search(query_embedding, top_k)
+     return [chunks[i] for i in indices[0]]
+
+
+ # ===============================
+ # LLM ANSWER
+ # ===============================
+
+ def generate_answer(question, context_chunks):
+     context = "\n\n".join(context_chunks)
+
+     prompt = f"""
+ Answer the question strictly using the given context.
+ If the answer is not found, say:
+ "Information not found in the document."
+
+ Context:
+ {context}
+
+ Question:
+ {question}
+
+ Answer:
+ """
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+
+     with torch.no_grad():
+         output = llm.generate(
+             **inputs,
+             max_new_tokens=200,
+             do_sample=True,  # temperature is ignored under greedy decoding, so sampling must be enabled
+             temperature=0.2
          )

+     decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+     return decoded.split("Answer:")[-1].strip()
+
+
+ # ===============================
+ # MAIN PIPELINE
+ # ===============================
+
+ def pdf_rag_chat(pdf_file, question):
+     if pdf_file is None or question.strip() == "":
+         return "Please upload a PDF and enter a question."
+
+     text = extract_text_from_pdf(pdf_file.name)
+     text = clean_text(text)
+
+     chunks = chunk_text(text)
+     index, chunks = build_faiss_index(chunks)
+     context = retrieve_relevant_chunks(question, index, chunks)
+
+     return generate_answer(question, context)
+
+
+ # ===============================
+ # GRADIO UI (GRADIO 6 SAFE)
+ # ===============================

  with gr.Blocks() as demo:

      gr.Markdown("""
+     # 📄 PDF RAG Chatbot (Open-Source AI)

+     Upload a **PDF** and ask questions based **only on its content**.
+     Built using **Retrieval Augmented Generation (RAG)** and
+     **open-source Hugging Face models**, running on **free CPU**.
+     """)

      with gr.Row():
          with gr.Column(scale=1):
+             pdf_input = gr.File(
+                 label="📤 Upload PDF",
+                 file_types=[".pdf"]
+             )
+
              question_input = gr.Textbox(
+                 label="❓ Ask a question",
+                 placeholder="e.g. What is the objective of the project?",
                  lines=2
              )

+             submit_btn = gr.Button("🔍 Get Answer")

+         with gr.Column(scale=2):
+             answer_output = gr.Textbox(
+                 label="📌 Answer",
+                 lines=10
+             )

+     submit_btn.click(
+         fn=pdf_rag_chat,
+         inputs=[pdf_input, question_input],
+         outputs=answer_output
+     )

      gr.Markdown("""
      ---
+     **© Simranpreet Kaur**
+     **NIELIT Ropar | AIML Six Months Training | 2026**
      """)

+ demo.launch()
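
For anyone reviewing the rewrite, here is a minimal standalone sketch of its retrieval half (not part of the commit; it mirrors build_faiss_index and retrieve_relevant_chunks above, the toy chunks are invented for illustration, and it assumes faiss-cpu and sentence-transformers are installed):

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Same embedding model and index type as the new app.py.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Invented stand-in chunks; in the app these come from chunk_text().
chunks = [
    "The objective of the project is a PDF question-answering chatbot.",
    "FAISS stores one embedding per 500-character text chunk.",
    "TinyLlama generates the final answer from the retrieved context.",
]

embeddings = np.array(model.encode(chunks)).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

query = np.array(model.encode(["What is the objective of the project?"])).astype("float32")
_, indices = index.search(query, 2)   # top_k=2 is plenty for three chunks
print([chunks[i] for i in indices[0]])  # nearest (smallest L2 distance) chunk first

IndexFlatL2 performs an exact brute-force search, which is fine at this scale; the app then joins the retrieved chunks with blank lines and passes them to TinyLlama as the prompt context.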