Girinath11 committed on
Commit
e56b39d
·
verified ·
1 Parent(s): e789621

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -42
app.py CHANGED
@@ -6,13 +6,17 @@ import docx
6
  from sentence_transformers import SentenceTransformer
7
  import faiss
8
  import numpy as np
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
  import torch
11
  from datetime import datetime
 
 
12
 
13
  # Load models
14
  print("Loading models...")
15
  embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
 
16
  tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
17
  llm_model = AutoModelForCausalLM.from_pretrained(
18
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -20,51 +24,119 @@ llm_model = AutoModelForCausalLM.from_pretrained(
20
  device_map="auto"
21
  )
22
 
23
- # Store documents
 
 
 
 
 
 
 
 
 
24
  documents = []
25
  images = []
 
26
  embeddings_index = None
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def extract_pdf_text(pdf_path):
 
29
  chunks = []
30
  with open(pdf_path, 'rb') as f:
31
  pdf = PyPDF2.PdfReader(f)
32
  for i, page in enumerate(pdf.pages):
33
  text = page.extract_text()
34
  if text.strip():
35
- chunks.append({'text': text, 'page': i+1, 'source': Path(pdf_path).name})
 
 
 
 
36
  return chunks
37
 
38
  def extract_docx_text(docx_path):
 
39
  doc = docx.Document(docx_path)
40
  text = '\n'.join([p.text for p in doc.paragraphs if p.text.strip()])
41
  return [{'text': text, 'source': Path(docx_path).name}]
42
 
43
  def extract_txt_text(txt_path):
 
44
  with open(txt_path, 'r', encoding='utf-8') as f:
45
  text = f.read()
46
  return [{'text': text, 'source': Path(txt_path).name}]
47
 
48
  def chunk_text(text, size=400):
 
49
  words = text.split()
50
  chunks = []
51
  for i in range(0, len(words), size):
52
  chunks.append(' '.join(words[i:i+size]))
53
  return chunks
54
 
55
- def process_files(files):
56
- global documents, images, embeddings_index
 
57
 
58
  if not files:
59
  return "Please upload files first"
60
 
61
  documents = []
62
  images = []
 
 
 
63
 
64
- for file in files:
 
65
  ext = Path(file.name).suffix.lower()
66
 
 
67
  if ext == '.pdf':
 
68
  chunks = extract_pdf_text(file.name)
69
  for chunk in chunks:
70
  for small_chunk in chunk_text(chunk['text']):
@@ -73,34 +145,59 @@ def process_files(files):
73
  'source': chunk['source'],
74
  'page': chunk.get('page', '')
75
  })
 
 
 
 
 
 
 
 
76
 
77
  elif ext == '.docx':
78
  chunks = extract_docx_text(file.name)
79
  for chunk in chunks:
80
  for small_chunk in chunk_text(chunk['text']):
81
- documents.append({'text': small_chunk, 'source': chunk['source']})
 
 
 
82
 
83
  elif ext == '.txt':
84
  chunks = extract_txt_text(file.name)
85
  for chunk in chunks:
86
  for small_chunk in chunk_text(chunk['text']):
87
- documents.append({'text': small_chunk, 'source': chunk['source']})
 
 
 
88
 
89
- elif ext in ['.jpg', '.jpeg', '.png']:
90
- images.append(file.name)
 
 
 
 
 
 
 
91
 
92
- # Create embeddings
 
93
  if documents:
94
  texts = [doc['text'] for doc in documents]
95
- embeddings = embedding_model.encode(texts)
96
 
97
  index = faiss.IndexFlatL2(embeddings.shape[1])
98
  index.add(embeddings.astype('float32'))
99
  embeddings_index = index
100
 
101
- return f"Processed {len(documents)} text chunks and {len(images)} images"
 
 
102
 
103
  def search_documents(query, k=3):
 
104
  if not documents or embeddings_index is None:
105
  return []
106
 
@@ -114,24 +211,59 @@ def search_documents(query, k=3):
114
 
115
  return results
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def generate_answer(question, context_docs):
 
118
  context = '\n\n'.join([doc['text'] for doc in context_docs])
119
 
120
- prompt = f"""Answer the question based on this context:
121
 
122
  {context}
123
 
124
  Question: {question}
125
  Answer:"""
126
 
127
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500)
128
 
129
  with torch.no_grad():
130
  outputs = llm_model.generate(
131
  inputs.input_ids,
132
- max_new_tokens=250,
133
  temperature=0.7,
134
- top_p=0.9,
135
  pad_token_id=tokenizer.eos_token_id
136
  )
137
 
@@ -140,24 +272,27 @@ Answer:"""
140
 
141
  return answer
142
 
143
- def answer_query(question):
 
144
  if not question:
145
- return "Please enter a question", None
146
 
147
  if not documents:
148
- return "Please upload documents first", None
149
 
150
- # Search relevant docs
151
- relevant_docs = search_documents(question)
 
152
 
153
  if not relevant_docs:
154
- return "No relevant info found", None
155
 
156
  # Generate answer
 
157
  answer = generate_answer(question, relevant_docs)
158
 
159
  # Format response
160
- response = f"**Answer:**\n{answer}\n\n**Sources:**\n"
161
  for i, doc in enumerate(relevant_docs, 1):
162
  source = doc['source']
163
  page = doc.get('page', '')
@@ -166,45 +301,85 @@ def answer_query(question):
166
  else:
167
  response += f"{i}. {source}\n"
168
 
169
- # Return images if available
170
- imgs = images[:2] if images else None
 
171
 
172
- return response, imgs
 
 
 
 
 
 
 
 
173
 
174
  # UI
175
- with gr.Blocks(title="DocVision AI") as app:
176
- gr.Markdown("# DocVision AI - Document Q&A System")
177
- gr.Markdown("Upload documents and ask questions to get AI-powered answers")
 
 
178
 
179
  with gr.Row():
180
  with gr.Column():
181
  file_input = gr.File(
182
- label="Upload Files (PDF, DOCX, TXT, Images)",
183
  file_count="multiple",
184
- file_types=[".pdf", ".docx", ".txt", ".jpg", ".png"]
185
  )
186
- process_btn = gr.Button("Process Documents", variant="primary")
187
- status = gr.Textbox(label="Status")
188
 
189
  with gr.Column():
190
- question = gr.Textbox(label="Ask a Question", lines=2)
191
- ask_btn = gr.Button("Get Answer", variant="primary")
 
 
 
 
 
 
192
 
193
- answer = gr.Markdown(label="Answer")
194
- gallery = gr.Gallery(label="Related Images", columns=2)
 
 
 
 
195
 
 
196
  gr.Examples(
197
  examples=[
198
  ["What is this document about?"],
199
  ["Summarize the main points"],
200
- ["What are the key findings?"]
 
201
  ],
202
  inputs=question
203
  )
204
 
205
- process_btn.click(process_files, inputs=[file_input], outputs=[status])
206
- ask_btn.click(answer_query, inputs=[question], outputs=[answer, gallery])
207
- question.submit(answer_query, inputs=[question], outputs=[answer, gallery])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
  app.launch()
 
6
  from sentence_transformers import SentenceTransformer
7
  import faiss
8
  import numpy as np
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BlipProcessor, BlipForConditionalGeneration
10
  import torch
11
  from datetime import datetime
12
+ import fitz # PyMuPDF for better PDF image extraction
13
+ import io
14
 
15
# Load models
print("Loading models...")
# Sentence embedder shared by document-chunk retrieval and caption matching.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

print("Loading LLM...")
# Tokenizer for the TinyLlama chat model that generates the final answers.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
21
  llm_model = AutoModelForCausalLM.from_pretrained(
22
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 
24
  device_map="auto"
25
  )
26
 
27
print("Loading image caption model...")
# BLIP captioner: turns extracted/uploaded images into text so they can be
# matched against questions with the same sentence embedder.
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype=torch.float16
).to("cuda" if torch.cuda.is_available() else "cpu")

print("✅ All models loaded!")

# Store documents and images
# Module-level state, rebuilt on each call to process_files().
documents = []  # text chunks: {'text', 'source'[, 'page']}
images = []  # image records: {'path', 'page', 'source'}
image_captions = []  # one generated caption per entry in `images` (kept in lockstep)
embeddings_index = None  # FAISS index over `documents` embeddings, built by process_files()
41
 
42
def generate_image_caption(image_path):
    """Return a short BLIP-generated caption for the image at *image_path*.

    Any failure (unreadable file, model error) is logged and replaced by a
    generic fallback caption, so callers never need to handle exceptions.
    """
    try:
        picture = Image.open(image_path).convert('RGB')
        model_inputs = caption_processor(picture, return_tensors="pt").to(caption_model.device)
        generated_ids = caption_model.generate(**model_inputs, max_length=50)
        return caption_processor.decode(generated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Caption error: {e}")
        return "Image from document"
53
+
54
def extract_images_from_pdf(pdf_path):
    """Extract every embedded image from a PDF using PyMuPDF (fitz).

    Each image is written to a fresh temporary directory, so files from
    different PDFs (or repeated runs) can never overwrite each other — the
    previous fixed ``/tmp/pdf_img_p{page}_{idx}.png`` names collided and were
    not portable.  Extraction problems are logged and a partial (possibly
    empty) list is returned rather than raising.

    Returns:
        List of dicts: {'path': str, 'page': int, 'source': str}.
    """
    import tempfile  # local import keeps this fix self-contained

    extracted_images = []
    out_dir = Path(tempfile.mkdtemp(prefix="docvision_pdf_img_"))
    source_name = Path(pdf_path).name
    stem = Path(pdf_path).stem
    try:
        doc = fitz.open(pdf_path)
        try:
            for page_num in range(len(doc)):
                page = doc[page_num]
                for img_index, img in enumerate(page.get_images(full=True)):
                    xref = img[0]  # first tuple entry is the image xref
                    base_image = doc.extract_image(xref)
                    # Honour the real encoding instead of labelling every
                    # image as .png regardless of its actual format.
                    img_ext = base_image.get("ext", "png")
                    img_path = out_dir / f"{stem}_p{page_num + 1}_{img_index}.{img_ext}"
                    img_path.write_bytes(base_image["image"])

                    extracted_images.append({
                        'path': str(img_path),
                        'page': page_num + 1,
                        'source': source_name,
                    })
        finally:
            doc.close()  # release the handle even if extraction fails mid-loop
    except Exception as e:
        print(f"PDF image extraction error: {e}")

    return extracted_images
84
+
85
def extract_pdf_text(pdf_path):
    """Read a PDF and return one {'text', 'page', 'source'} dict per non-empty page."""
    source_name = Path(pdf_path).name
    page_entries = []
    with open(pdf_path, 'rb') as handle:
        reader = PyPDF2.PdfReader(handle)
        for page_number, page in enumerate(reader.pages, start=1):
            content = page.extract_text()
            if content.strip():
                page_entries.append({
                    'text': content,
                    'page': page_number,
                    'source': source_name,
                })
    return page_entries
99
 
100
def extract_docx_text(docx_path):
    """Return the whole DOCX body as a single {'text', 'source'} entry.

    Empty or whitespace-only paragraphs are dropped; the rest are joined
    with newlines into one document record.
    """
    paragraphs = docx.Document(docx_path).paragraphs
    kept = [para.text for para in paragraphs if para.text.strip()]
    return [{'text': '\n'.join(kept), 'source': Path(docx_path).name}]
105
 
106
def extract_txt_text(txt_path):
    """Load a UTF-8 text file as a single {'text', 'source'} entry."""
    content = Path(txt_path).read_text(encoding='utf-8')
    return [{'text': content, 'source': Path(txt_path).name}]
111
 
112
def chunk_text(text, size=400):
    """Split *text* into chunks of at most *size* whitespace-separated words."""
    words = text.split()
    return [
        ' '.join(words[start:start + size])
        for start in range(0, len(words), size)
    ]
119
 
120
def process_files(files, progress=gr.Progress()):
    """Ingest uploaded files and (re)build the retrieval state.

    Text is extracted from PDF/DOCX/TXT files, split into word-bounded
    chunks, and indexed in FAISS; images (standalone uploads or ones
    embedded in PDFs) are captioned for later similarity search.  Files
    with unrecognised extensions are silently skipped.

    Args:
        files: list of Gradio file objects (each exposes a ``.name`` path).
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Human-readable status string.
    """
    global documents, images, image_captions, embeddings_index

    if not files:
        return "Please upload files first"

    documents = []
    images = []
    image_captions = []
    # Drop any stale index from a previous upload so an upload with no text
    # can never be searched against old embeddings.
    embeddings_index = None

    def add_chunks(chunks):
        # Split each extracted block into word-bounded pieces and store them.
        # 'page' defaults to '' so every record shares the same schema.
        for chunk in chunks:
            for small_chunk in chunk_text(chunk['text']):
                documents.append({
                    'text': small_chunk,
                    'source': chunk['source'],
                    'page': chunk.get('page', '')
                })

    def add_image(path, source, page):
        # Register an image and its generated caption (lists stay in lockstep).
        images.append({'path': path, 'source': source, 'page': page})
        image_captions.append(generate_image_caption(path))

    total = len(files)
    for idx, file in enumerate(files):
        progress((idx + 1) / total, desc=f"Processing {Path(file.name).name}...")
        ext = Path(file.name).suffix.lower()

        if ext == '.pdf':
            add_chunks(extract_pdf_text(file.name))
            for img in extract_images_from_pdf(file.name):
                add_image(img['path'], img['source'], img['page'])
        elif ext == '.docx':
            add_chunks(extract_docx_text(file.name))
        elif ext == '.txt':
            add_chunks(extract_txt_text(file.name))
        elif ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
            add_image(file.name, Path(file.name).name, '')

    # Create embeddings for text
    progress(0.9, desc="Creating embeddings...")
    if documents:
        texts = [doc['text'] for doc in documents]
        embeddings = embedding_model.encode(texts, show_progress_bar=False)

        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings.astype('float32'))
        embeddings_index = index

    progress(1.0, desc="Done!")

    return f"✅ Processed {len(documents)} text chunks and {len(images)} images"
198
 
199
  def search_documents(query, k=3):
200
+ """Search relevant documents"""
201
  if not documents or embeddings_index is None:
202
  return []
203
 
 
211
 
212
  return results
213
 
214
def find_relevant_images(query, top_k=2):
    """Rank stored images against *query* by similarity of their captions.

    Uses cosine similarity between the query embedding and each caption
    embedding — a raw dot product on unnormalized embeddings would bias the
    ranking toward captions with larger vector norms.

    Args:
        query: the user's question text.
        top_k: maximum number of images to return.

    Returns:
        Tuple ``(paths, explanations)``: image file paths and matching
        Markdown blurbs (source, optional page, caption), best match first.
    """
    if not images or not image_captions:
        return [], []

    # Encode query and captions with the shared sentence embedder.
    query_embedding = embedding_model.encode([query])
    caption_embeddings = embedding_model.encode(image_captions)

    # Cosine similarity: normalize rows (epsilon guards against zero vectors).
    q = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-12)
    c = caption_embeddings / (np.linalg.norm(caption_embeddings, axis=1, keepdims=True) + 1e-12)
    similarities = (c @ q.T).flatten()

    # Indices of the best-matching captions, highest similarity first.
    top_indices = np.argsort(similarities)[::-1][:top_k]

    relevant_images = []
    explanations = []

    for idx in top_indices:
        if idx >= len(images):
            continue  # defensive: captions and images should stay in lockstep

        img_info = images[idx]
        caption = image_captions[idx]

        relevant_images.append(img_info['path'])

        # Create explanation
        explanation = f"**Image from {img_info['source']}"
        if img_info.get('page'):
            explanation += f" (Page {img_info['page']})"
        explanation += f"**\n{caption}"
        explanations.append(explanation)

    return relevant_images, explanations
247
+
248
  def generate_answer(question, context_docs):
249
+ """Generate answer using LLM"""
250
  context = '\n\n'.join([doc['text'] for doc in context_docs])
251
 
252
+ prompt = f"""Answer based on context:
253
 
254
  {context}
255
 
256
  Question: {question}
257
  Answer:"""
258
 
259
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1200)
260
 
261
  with torch.no_grad():
262
  outputs = llm_model.generate(
263
  inputs.input_ids,
264
+ max_new_tokens=200,
265
  temperature=0.7,
266
+ do_sample=True,
267
  pad_token_id=tokenizer.eos_token_id
268
  )
269
 
 
272
 
273
  return answer
274
 
275
+ def answer_query(question, progress=gr.Progress()):
276
+ """Answer query with relevant images"""
277
  if not question:
278
+ return "Please enter a question", None, ""
279
 
280
  if not documents:
281
+ return "Please upload documents first", None, ""
282
 
283
+ # Search documents
284
+ progress(0.3, desc="Searching documents...")
285
+ relevant_docs = search_documents(question, k=3)
286
 
287
  if not relevant_docs:
288
+ return "No relevant info found", None, ""
289
 
290
  # Generate answer
291
+ progress(0.6, desc="Generating answer...")
292
  answer = generate_answer(question, relevant_docs)
293
 
294
  # Format response
295
+ response = f"## πŸ’‘ Answer:\n{answer}\n\n## πŸ“š Sources:\n"
296
  for i, doc in enumerate(relevant_docs, 1):
297
  source = doc['source']
298
  page = doc.get('page', '')
 
301
  else:
302
  response += f"{i}. {source}\n"
303
 
304
+ # Find relevant images
305
+ progress(0.9, desc="Finding relevant images...")
306
+ relevant_imgs, img_explanations = find_relevant_images(question, top_k=2)
307
 
308
+ # Add image explanations to response
309
+ if img_explanations:
310
+ response += f"\n## πŸ–ΌοΈ Related Images:\n"
311
+ for exp in img_explanations:
312
+ response += f"{exp}\n\n"
313
+
314
+ progress(1.0, desc="Done!")
315
+
316
+ return response, relevant_imgs if relevant_imgs else None, ""
317
 
318
# UI
# Gradio layout: file upload + processing on the left, Q&A on the right,
# with an image gallery and example questions below.
with gr.Blocks(title="DocVision AI", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 📚 DocVision AI - Smart Document Q&A
    Upload documents and ask questions to get AI-powered answers with relevant images
    """)

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="📁 Upload Files (PDF, DOCX, TXT, Images)",
                file_count="multiple",
                file_types=[".pdf", ".docx", ".txt", ".jpg", ".png", ".jpeg", ".gif"]
            )
            process_btn = gr.Button("⚡ Process Documents", variant="primary", size="lg")
            status = gr.Textbox(label="Status", lines=2)

        with gr.Column():
            question = gr.Textbox(
                label="❓ Ask a Question",
                placeholder="What would you like to know?",
                lines=3
            )
            ask_btn = gr.Button("🔍 Get Answer", variant="primary", size="lg")

    # Answer markdown and gallery are fed by answer_query below.
    answer = gr.Markdown(label="📝 Answer & Sources")

    with gr.Row():
        gallery = gr.Gallery(
            label="🖼️ Relevant Images with Explanations",
            columns=2,
            height=400
        )

    gr.Markdown("### 📌 Example Questions:")
    gr.Examples(
        examples=[
            ["What is this document about?"],
            ["Summarize the main points"],
            ["What are the key findings?"],
            ["Show me information about diagrams or charts"]
        ],
        inputs=question
    )

    # Hidden textbox: answer_query returns a third value, kept invisible.
    debug_output = gr.Textbox(label="Debug Info", visible=False)

    # Event handlers
    process_btn.click(
        process_files,
        inputs=[file_input],
        outputs=[status]
    )

    ask_btn.click(
        answer_query,
        inputs=[question],
        outputs=[answer, gallery, debug_output]
    )

    # Pressing Enter in the question box triggers the same handler.
    question.submit(
        answer_query,
        inputs=[question],
        outputs=[answer, gallery, debug_output]
    )

if __name__ == "__main__":
    app.launch()