SimranShaikh committed on
Commit
43812ec
·
verified ·
1 Parent(s): 99810bc
Files changed (1) hide show
  1. src/streamlit_app.py +605 -374
src/streamlit_app.py CHANGED
@@ -1,428 +1,659 @@
1
- # app.py - Main Hugging Face Spaces Application
2
- import gradio as gr
3
- import PyPDF2
4
- import pdfplumber
5
- import fitz # PyMuPDF
6
- import pandas as pd
7
- import re
8
- import logging
9
  import os
 
 
 
10
  import tempfile
11
- from typing import Dict, List, Tuple, Optional
12
- from pathlib import Path
13
  import json
 
 
 
14
 
15
- # Set up logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
class PDFProcessorError(Exception):
    """Raised by the extraction helpers when a PDF backend fails."""
 
 
 
 
 
 
 
 
 
22
 
23
def enhanced_pdf_processor(file_path: str) -> Dict:
    """Extract text from a PDF, trying several backends in order.

    Backends are attempted as PyMuPDF, then pdfplumber, then PyPDF2; the
    first one yielding more than 10 non-whitespace-trimmed characters wins.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        A dict with the keys ``text``, ``tables``, ``metadata``,
        ``extraction_method``, ``success``, ``error``, ``file_info`` and
        ``summary``. ``success`` is False and ``error`` is set when the file
        is missing, every backend fails, or an unexpected error occurs.
    """
    results = dict(
        text='',
        tables=[],
        metadata={},
        extraction_method='unknown',
        success=False,
        error=None,
        file_info={},
        summary='',
    )

    try:
        # Bail out early on a missing file.
        if not os.path.exists(file_path):
            results['error'] = f"File does not exist: {file_path}"
            return results

        results['file_info'] = get_file_info(file_path)

        # (label, extractor, results key receiving the extractor's second
        # return value, or None when the extractor returns text only)
        backends = (
            ('PyMuPDF', extract_with_pymupdf, 'metadata'),
            ('pdfplumber', extract_with_pdfplumber, 'tables'),
            ('PyPDF2', extract_with_pypdf2, None),
        )

        for label, extractor, extra_key in backends:
            try:
                logger.info(f"Trying extraction method: {label}")

                if extra_key is None:
                    body, extra = extractor(file_path), None
                else:
                    body, extra = extractor(file_path)

                # Too little text counts as a failed attempt; try the next one.
                if not (body and len(body.strip()) > 10):
                    continue

                results['text'] = body
                if extra_key is not None:
                    results[extra_key] = extra
                results['extraction_method'] = label
                results['success'] = True
                break

            except Exception as exc:
                logger.warning(f"{label} failed: {str(exc)}")

        if results['success']:
            results['summary'] = generate_document_summary(results['text'])
        else:
            results['error'] = "All extraction methods failed"

    except Exception as exc:
        results['error'] = f"Processing error: {str(exc)}"
        logger.error(f"PDF processing error: {exc}")

    return results
99
-
100
def extract_with_pypdf2(file_path: str) -> str:
    """Extract text from a PDF using PyPDF2.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        The cleaned text of all pages, each prefixed by a
        ``--- Page N ---`` marker. Pages that fail to extract are skipped.

    Raises:
        PDFProcessorError: If the file cannot be opened or parsed, or if it
            is encrypted and cannot be unlocked with an empty password.
    """
    try:
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)

            if reader.is_encrypted:
                try:
                    unlocked = reader.decrypt("")
                except Exception as e:
                    raise PDFProcessorError("PDF is encrypted") from e
                # decrypt() signals a wrong password through a falsy return
                # value rather than an exception, so check it explicitly.
                if not unlocked:
                    raise PDFProcessorError("PDF is encrypted")

            page_blocks = []
            for page_num, page in enumerate(reader.pages, start=1):
                try:
                    page_text = page.extract_text()
                except Exception as e:
                    logger.warning(f"Failed to extract page {page_num}: {e}")
                    continue
                if page_text:
                    page_blocks.append(f"\n--- Page {page_num} ---\n{page_text}\n")

        return clean_text("".join(page_blocks))

    except Exception as e:
        raise PDFProcessorError(f"PyPDF2 extraction failed: {e}") from e
125
-
126
def extract_with_pdfplumber(file_path: str) -> Tuple[str, List[Dict]]:
    """Pull page text and tables out of a PDF with pdfplumber.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        ``(text, tables)`` where ``text`` is the cleaned page text with
        ``--- Page N ---`` markers and ``tables`` is a list of dicts holding
        ``page``, ``table_number``, ``data`` and ``text_representation`` for
        every table that has more than one row.

    Raises:
        PDFProcessorError: If the document cannot be opened or processed.
    """
    page_chunks = []
    found_tables = []

    try:
        with pdfplumber.open(file_path) as pdf:
            for page_no, page in enumerate(pdf.pages, start=1):
                try:
                    body = page.extract_text()
                    if body:
                        page_chunks.append(f"\n--- Page {page_no} ---\n{body}\n")

                    for table_no, grid in enumerate(page.extract_tables(), start=1):
                        # Single-row (or empty) grids are not treated as tables.
                        if not grid or len(grid) <= 1:
                            continue
                        found_tables.append({
                            'page': page_no,
                            'table_number': table_no,
                            'data': grid,
                            'text_representation': table_to_text(grid)
                        })

                except Exception as e:
                    logger.warning(f"Failed to process page {page_no}: {e}")

        return clean_text("".join(page_chunks)), found_tables

    except Exception as e:
        raise PDFProcessorError(f"pdfplumber extraction failed: {e}")
 
 
 
 
 
158
 
159
def extract_with_pymupdf(file_path: str) -> Tuple[str, Dict]:
    """Extract text and document metadata from a PDF using PyMuPDF.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        ``(text, metadata)`` where ``text`` is the cleaned page text with
        ``--- Page N ---`` markers and ``metadata`` holds ``page_count`` plus,
        when readable, ``title``, ``author``, ``subject``, ``creator`` and
        ``creation_date``.

    Raises:
        PDFProcessorError: If the document cannot be opened or processed.
    """
    try:
        doc = fitz.open(file_path)
        try:
            # Metadata is best-effort: fall back to the page count alone.
            try:
                doc_metadata = doc.metadata or {}
                metadata = {
                    'page_count': doc.page_count,
                    'title': doc_metadata.get('title', ''),
                    'author': doc_metadata.get('author', ''),
                    'subject': doc_metadata.get('subject', ''),
                    'creator': doc_metadata.get('creator', ''),
                    'creation_date': doc_metadata.get('creationDate', '')
                }
            except Exception:
                metadata = {'page_count': doc.page_count}

            page_blocks = []
            for page_num in range(doc.page_count):
                try:
                    page_text = doc[page_num].get_text()
                    if page_text:
                        page_blocks.append(f"\n--- Page {page_num + 1} ---\n{page_text}\n")
                except Exception as e:
                    logger.warning(f"Failed to extract page {page_num + 1}: {e}")
        finally:
            # Always release the document handle, even when extraction fails.
            doc.close()

        return clean_text("".join(page_blocks)), metadata

    except Exception as e:
        raise PDFProcessorError(f"PyMuPDF extraction failed: {e}") from e
196
 
197
def clean_text(text: str) -> str:
    """Normalize whitespace in extracted text and drop junk characters.

    Blank-line runs collapse to a single empty line, runs of spaces collapse
    to one space, and U+FFFD, NUL and zero-width-space characters are removed.
    Falsy input yields an empty string.
    """
    if not text:
        return ""

    # Whitespace is collapsed first; junk characters are stripped afterwards.
    tidy = re.sub(r' +', ' ', re.sub(r'\n\s*\n', '\n\n', text))

    for unwanted in ('\ufffd', '\x00', '\u200b'):
        tidy = tidy.replace(unwanted, '')

    return tidy.strip()
212
-
213
def table_to_text(table: List[List]) -> str:
    """Render a table (list of rows) as pipe-separated lines.

    Falsy cells become empty strings, other cells are stringified and
    stripped. Empty rows and rows whose cells are all empty are omitted.
    """
    if not table:
        return ""

    normalized_rows = (
        [str(cell).strip() if cell else "" for cell in row]
        for row in table
        if row
    )
    return "\n".join(" | ".join(cells) for cells in normalized_rows if any(cells))
226
-
227
def get_file_info(file_path: str) -> Dict:
    """Return the name and size of a file.

    Returns:
        A dict with ``name``, ``size`` (bytes) and ``size_mb`` (rounded to two
        decimals), or an empty dict when the file cannot be inspected.
    """
    try:
        target = Path(file_path)
        byte_count = target.stat().st_size
        return {
            'name': target.name,
            'size': byte_count,
            'size_mb': round(byte_count / (1024 * 1024), 2)
        }
    except Exception:
        return {}
239
 
240
def generate_document_summary(text: str) -> str:
    """Build a short statistics-plus-preview summary of extracted text.

    The summary reports character, word and line counts and a preview made of
    the first three sentence fragments, capped at 300 characters.
    Falsy input yields ``"No text extracted"``.
    """
    if not text:
        return "No text extracted"

    chars = len(text)
    words = len(text.split())
    lines = len(text.split('\n'))

    # First three fragments between sentence terminators form the preview.
    leading_fragments = re.split(r'[.!?]+', text)[:3]
    preview = '. '.join(leading_fragments).strip()
    if len(preview) > 300:
        preview = preview[:300] + "..."

    return f"""
Document Statistics:
- Characters: {chars:,}
- Words: {words:,}
- Lines: {lines:,}

Preview:
{preview}
"""
265
-
266
def process_pdf_file(file) -> Tuple[str, str, str, str]:
    """Process an uploaded PDF for the Gradio interface.

    Args:
        file: The upload handed over by the file component: either a
            file-like object exposing ``read()`` or a filesystem path.
            ``None`` means nothing was uploaded.

    Returns:
        ``(status, file_info, summary, display_text)``. On failure only the
        status is populated and the other three strings are empty.
    """
    if file is None:
        return "No file uploaded", "", "", ""

    tmp_file_path = None
    try:
        if isinstance(file, (str, os.PathLike)):
            # Already on disk: process it in place, no temporary copy needed.
            source_path = os.fspath(file)
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file_path = tmp_file.name
                tmp_file.write(file.read())
            source_path = tmp_file_path

        result = enhanced_pdf_processor(source_path)

        if not result['success']:
            error_msg = result.get('error', 'Unknown error')
            return f"❌ Processing failed: {error_msg}", "", "", ""

        status = f"✅ Successfully processed using {result['extraction_method']}"

        file_info = result.get('file_info', {})
        info = f"""
File: {file_info.get('name', 'Unknown')}
Size: {file_info.get('size_mb', 0)} MB
Pages: {result.get('metadata', {}).get('page_count', 'Unknown')}
"""

        summary = result.get('summary', 'No summary available')

        # Full text, truncated for display.
        full_text = result['text']
        if len(full_text) > 5000:
            display_text = full_text[:5000] + f"\n\n... (Text truncated. Total length: {len(full_text)} characters)"
        else:
            display_text = full_text

        # Append a short rendering of at most the first three tables.
        if result['tables']:
            tables_info = f"\n\nTables found: {len(result['tables'])}"
            for i, table in enumerate(result['tables'][:3]):
                tables_info += f"\n\nTable {i+1} (Page {table['page']}):\n"
                tables_info += table['text_representation'][:500]
                if len(table['text_representation']) > 500:
                    tables_info += "..."
            display_text += tables_info

        return status, info, summary, display_text

    except Exception as e:
        return f"❌ Error: {str(e)}", "", "", ""

    finally:
        # Remove the temporary copy on every path, including failures.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the result.
                pass
 
325
 
326
def answer_question(text: str, question: str) -> str:
    """Answer a question with simple keyword matching over the text.

    Keywords are the distinct words of the question longer than three
    characters, with punctuation removed. Each sentence of ``text`` is scored
    by how many keywords it contains (case-insensitive substring match) and
    the three best-scoring sentences are returned.

    Args:
        text: The document text to search.
        question: The user's question.

    Returns:
        A message listing up to three relevant sentences, a "couldn't find"
        message when nothing matches, or a prompt for input when either
        argument is empty.
    """
    if not text or not question:
        return "Please provide both text and a question."

    # Tokenize on word characters so trailing punctuation ("budget?") does not
    # prevent a match; a set avoids counting a repeated keyword twice.
    keywords = {word for word in re.findall(r'\w+', question.lower()) if len(word) > 3}

    relevant_sentences = []
    for sentence in re.split(r'[.!?]+', text):
        candidate = sentence.strip()
        if not candidate:
            continue
        lowered = candidate.lower()
        score = sum(1 for keyword in keywords if keyword in lowered)
        if score > 0:
            relevant_sentences.append((candidate, score))

    # Stable sort: ties keep their document order.
    relevant_sentences.sort(key=lambda item: item[1], reverse=True)
    top_sentences = [item[0] for item in relevant_sentences[:3]]

    if top_sentences:
        return "Based on the document, here are the most relevant sections:\n\n" + "\n\n".join(top_sentences)
    return "I couldn't find information related to your question in the document."
358
-
359
# Module-level holder for the most recently extracted document text.
# NOTE(review): being a module global, this value is shared by every caller
# of this module — confirm that is acceptable for multi-user deployments.
extracted_text = ""


def update_extracted_text(status, info, summary, full_text):
    """Remember ``full_text`` globally and pass all four values through."""
    global extracted_text
    extracted_text = full_text
    passthrough = (status, info, summary, full_text)
    return passthrough
367
-
368
def qa_interface(question):
    """Answer ``question`` against the globally stored extracted text."""
    # Reading a module global needs no ``global`` declaration.
    document_text = extracted_text
    return answer_question(document_text, question)
372
-
373
- # Create Gradio interface
374
- with gr.Blocks(title="PDF Processor & Q&A System") as app:
375
- gr.Markdown("# πŸ“„ PDF Processor & Question Answering System")
376
- gr.Markdown("Upload a PDF file to extract text and ask questions about its content.")
377
 
378
- with gr.Tab("PDF Processing"):
379
- with gr.Row():
380
- with gr.Column():
381
- file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
382
- process_btn = gr.Button("Process PDF", variant="primary")
 
 
 
 
 
 
 
 
 
 
383
 
384
- with gr.Column():
385
- status_output = gr.Textbox(label="Status", lines=2)
386
- info_output = gr.Textbox(label="File Information", lines=4)
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
- summary_output = gr.Textbox(label="Document Summary", lines=8)
389
- text_output = gr.Textbox(label="Extracted Text", lines=15, max_lines=20)
390
-
391
- with gr.Tab("Question & Answer"):
392
- gr.Markdown("Ask questions about the processed PDF content.")
393
- with gr.Row():
394
- question_input = gr.Textbox(label="Your Question", placeholder="What is this document about?")
395
- ask_btn = gr.Button("Ask Question", variant="primary")
396
 
397
- answer_output = gr.Textbox(label="Answer", lines=8)
 
 
 
 
398
 
399
- # Event handlers
400
- process_btn.click(
401
- fn=process_pdf_file,
402
- inputs=[file_input],
403
- outputs=[status_output, info_output, summary_output, text_output]
404
- ).then(
405
- fn=update_extracted_text,
406
- inputs=[status_output, info_output, summary_output, text_output],
407
- outputs=[status_output, info_output, summary_output, text_output]
408
- )
409
 
410
- ask_btn.click(
411
- fn=qa_interface,
412
- inputs=[question_input],
413
- outputs=[answer_output]
414
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
- # Example
417
- gr.Examples(
418
- examples=[
419
- ["What is the main topic of this document?"],
420
- ["What are the key findings?"],
421
- ["Who are the authors?"],
422
- ["What is the conclusion?"]
423
- ],
424
- inputs=[question_input]
425
  )
426
 
427
  if __name__ == "__main__":
428
- app.launch()
 
1
+ # Enterprise AI Assistant with RAG using IBM Granite Models
2
+ # Complete implementation with Streamlit interface
3
+
4
+ import streamlit as st
 
 
 
 
5
  import os
6
+ import pandas as pd
7
+ import numpy as np
8
+ from typing import List, Dict, Any, Optional
9
  import tempfile
 
 
10
  import json
11
+ from datetime import datetime
12
+ import logging
13
+ from pathlib import Path
14
 
15
+ # Core libraries for RAG
16
+ import chromadb
17
+ from sentence_transformers import SentenceTransformer
18
+ import torch
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ pipeline,
23
+ BitsAndBytesConfig
24
+ )
25
+
26
+ # Document processing
27
+ import PyPDF2
28
+ from docx import Document
29
+ import openpyxl
30
+ from bs4 import BeautifulSoup
31
+ import email
32
+ from email.mime.text import MIMEText
33
+ import chardet
34
+
35
+ # Additional utilities
36
+ import re
37
+ from urllib.parse import urlparse
38
+ import hashlib
39
+ import pickle
40
+
41
+ # Configure logging
42
  logging.basicConfig(level=logging.INFO)
43
  logger = logging.getLogger(__name__)
44
 
45
+ # Configuration
46
class Config:
    """Central settings for models, chunking, generation and file support."""

    # Model identifiers
    GRANITE_MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"
    GRANITE_GUARDIAN_MODEL = "ibm-granite/granite-guardian-3.2-5b"
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    # Chunking (sizes in characters, as used by TextChunker)
    CHUNK_SIZE = 1000
    CHUNK_OVERLAP = 200

    # Generation
    MAX_CONTEXT_LENGTH = 4000  # prompt is truncated to this many tokens
    TEMPERATURE = 0.7
    MAX_NEW_TOKENS = 512

    # Retrieval: default number of chunks returned per query
    TOP_K = 5

    # File extensions accepted for upload
    SUPPORTED_FORMATS = [
        '.pdf', '.docx', '.xlsx', '.txt', '.csv', '.html', '.json', '.md',
    ]
57
 
58
class DocumentProcessor:
    """Handles document processing and text extraction.

    Every ``extract_text_from_*`` helper takes a path and returns the
    extracted text, or an empty string (after logging the error) when the
    file cannot be read. ``process_document`` dispatches on file extension.
    """

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Extract text from PDF files, one newline-terminated block per page."""
        try:
            with open(file_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                # Guard against pages with no extractable text so one such
                # page cannot discard the whole document.
                return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
        except Exception as e:
            logger.error(f"Error extracting text from PDF: {e}")
            return ""

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Extract paragraph text from DOCX files."""
        try:
            doc = Document(file_path)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            logger.error(f"Error extracting text from DOCX: {e}")
            return ""

    @staticmethod
    def extract_text_from_xlsx(file_path: str) -> str:
        """Extract text from Excel files as pipe-separated rows per sheet."""
        try:
            workbook = openpyxl.load_workbook(file_path)
            try:
                parts = []
                for sheet_name in workbook.sheetnames:
                    sheet = workbook[sheet_name]
                    parts.append(f"Sheet: {sheet_name}\n")
                    for row in sheet.iter_rows(values_only=True):
                        # Only None marks an empty cell; 0 and False are data.
                        cells = ["" if cell is None else str(cell) for cell in row]
                        # Skip fully blank rows instead of emitting bare separators.
                        if any(cell.strip() for cell in cells):
                            parts.append(" | ".join(cells) + "\n")
                    parts.append("\n")
                return "".join(parts)
            finally:
                workbook.close()
        except Exception as e:
            logger.error(f"Error extracting text from XLSX: {e}")
            return ""

    @staticmethod
    def extract_text_from_csv(file_path: str) -> str:
        """Extract text from CSV files via a pandas DataFrame rendering."""
        try:
            df = pd.read_csv(file_path)
            return df.to_string()
        except Exception as e:
            logger.error(f"Error extracting text from CSV: {e}")
            return ""

    @staticmethod
    def extract_text_from_html(file_path: str) -> str:
        """Extract visible text from HTML files."""
        try:
            # Hand BeautifulSoup the raw bytes so it can detect the document's
            # encoding instead of assuming UTF-8.
            with open(file_path, 'rb') as file:
                soup = BeautifulSoup(file.read(), 'html.parser')
            return soup.get_text()
        except Exception as e:
            logger.error(f"Error extracting text from HTML: {e}")
            return ""

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Extract text from plain-text files, detecting the encoding."""
        try:
            with open(file_path, 'rb') as file:
                raw_data = file.read()
            # chardet reports None when it cannot decide (e.g. empty input).
            encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'

            # errors='replace': a wrong guess degrades a few characters rather
            # than losing the whole document.
            with open(file_path, 'r', encoding=encoding, errors='replace') as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error extracting text from TXT: {e}")
            return ""

    @staticmethod
    def extract_text_from_json(file_path: str) -> str:
        """Extract text from JSON files as indented JSON."""
        try:
            # JSON text is UTF-8 by specification; do not rely on the locale.
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
            return json.dumps(data, indent=2)
        except Exception as e:
            logger.error(f"Error extracting text from JSON: {e}")
            return ""

    def process_document(self, file_path: str) -> str:
        """Process a document based on its file extension.

        Returns the extracted text, or an empty string for unsupported
        extensions or failed extraction.
        """
        file_extension = Path(file_path).suffix.lower()

        extractors = {
            '.pdf': self.extract_text_from_pdf,
            '.docx': self.extract_text_from_docx,
            '.xlsx': self.extract_text_from_xlsx,
            '.csv': self.extract_text_from_csv,
            '.html': self.extract_text_from_html,
            '.txt': self.extract_text_from_txt,
            '.md': self.extract_text_from_txt,
            '.json': self.extract_text_from_json,
        }

        extractor = extractors.get(file_extension)
        if extractor is None:
            logger.warning(f"Unsupported file format: {file_extension}")
            return ""
        return extractor(file_path)
175
 
176
class TextChunker:
    """Handles text chunking for RAG.

    Text is split into sentences, which are packed greedily into chunks of
    at most roughly ``chunk_size`` characters. Each new chunk starts with the
    last ``chunk_overlap`` characters of the previous one.
    """

    def __init__(self, chunk_size: int = Config.CHUNK_SIZE, chunk_overlap: int = Config.CHUNK_OVERLAP):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @staticmethod
    def _build_chunk(text: str, document_name: str, chunk_id: int) -> Dict[str, Any]:
        """Wrap chunk text with its metadata record."""
        return {
            'text': text,
            'metadata': {
                'document_name': document_name,
                'chunk_id': chunk_id,
                'timestamp': datetime.now().isoformat()
            }
        }

    def _split_units(self, text: str) -> List[str]:
        """Split text into sentences, hard-wrapping any longer than chunk_size."""
        units = []
        # Split *after* terminal punctuation so sentences keep their ".", "!"
        # or "?" and decimals such as "3.14" are not torn apart.
        for sentence in re.split(r'(?<=[.!?])\s+', text):
            sentence = sentence.strip()
            if not sentence:
                continue
            if self.chunk_size > 0 and len(sentence) > self.chunk_size:
                # Text without sentence breaks (tables, JSON) would otherwise
                # produce a single unbounded chunk.
                units.extend(
                    sentence[start:start + self.chunk_size]
                    for start in range(0, len(sentence), self.chunk_size)
                )
            else:
                units.append(sentence)
        return units

    def chunk_text(self, text: str, document_name: str = "") -> List[Dict[str, Any]]:
        """Split text into chunks with metadata.

        Args:
            text: The text to split. Empty or ``None`` yields no chunks.
            document_name: Stored in each chunk's metadata.

        Returns:
            A list of dicts with ``text`` and ``metadata`` (``document_name``,
            ``chunk_id``, ``timestamp``).
        """
        if not text:
            return []

        chunks = []
        # A zero overlap must carry nothing over; note that s[-0:] is all of s.
        overlap = max(self.chunk_overlap, 0)
        current_chunk = ""

        for unit in self._split_units(text):
            # +1 accounts for the space that joins the unit to the chunk.
            if current_chunk and len(current_chunk) + 1 + len(unit) > self.chunk_size:
                chunks.append(self._build_chunk(current_chunk, document_name, len(chunks)))
                tail = current_chunk[-overlap:] if overlap else ""
                current_chunk = f"{tail} {unit}".strip()
            else:
                current_chunk = f"{current_chunk} {unit}".strip()

        if current_chunk:
            chunks.append(self._build_chunk(current_chunk, document_name, len(chunks)))

        return chunks
 
229
 
230
class VectorStore:
    """Handles vector storage and retrieval using ChromaDB.

    Embeddings are computed locally with the sentence-transformers model
    named in ``Config.EMBEDDING_MODEL_NAME`` and persisted under
    ``./chroma_db``.
    """

    def __init__(self, collection_name: str = "enterprise_documents"):
        self.client = chromadb.PersistentClient(path="./chroma_db")
        self.collection_name = collection_name
        self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL_NAME)

        # One call covers both cases, without a bare except that would also
        # hide unrelated failures.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "Enterprise document embeddings"}
        )

    @staticmethod
    def _chunk_identifier(position: int, text: str, metadata: Dict[str, Any]) -> str:
        """Derive a stable ID from the chunk's document, index and text."""
        # The document name is part of the hash so identical text at the same
        # position in two different documents does not collide. MD5 is used
        # only as a fingerprint here, not for security.
        fingerprint = "\x00".join((
            str(metadata.get('document_name', '')),
            str(metadata.get('chunk_id', position)),
            text,
        ))
        digest = hashlib.md5(fingerprint.encode()).hexdigest()[:16]
        return f"doc_{position}_{digest}"

    def add_documents(self, chunks: List[Dict[str, Any]]) -> None:
        """Add document chunks to the vector store.

        Re-adding an identical chunk overwrites its existing entry.
        An empty ``chunks`` list is a no-op.
        """
        if not chunks:
            return

        texts = [chunk['text'] for chunk in chunks]
        metadatas = [chunk['metadata'] for chunk in chunks]

        embeddings = self.embedding_model.encode(texts).tolist()

        ids = [
            self._chunk_identifier(position, text, metadata)
            for position, (text, metadata) in enumerate(zip(texts, metadatas))
        ]

        # upsert keeps re-processing the same file idempotent.
        self.collection.upsert(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids
        )

        logger.info(f"Added {len(chunks)} chunks to vector store")

    def similarity_search(self, query: str, k: int = Config.TOP_K) -> List[Dict[str, Any]]:
        """Search for the chunks most similar to ``query``.

        Returns:
            Up to ``k`` dicts with ``text``, ``metadata`` and ``distance``;
            an empty list when the collection is empty or ``k`` is not
            positive.
        """
        available = self.collection.count()
        if available == 0 or k <= 0:
            return []

        query_embedding = self.embedding_model.encode([query]).tolist()

        results = self.collection.query(
            query_embeddings=query_embedding,
            # Never ask for more neighbours than there are stored chunks.
            n_results=min(k, available),
            include=['documents', 'metadatas', 'distances']
        )

        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0] if results['distances'] else [0] * len(documents)

        return [
            {'text': text, 'metadata': metadata, 'distance': distance}
            for text, metadata, distance in zip(documents, metadatas, distances)
        ]

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get statistics about the collection.

        ``total_documents`` is the number of stored chunks.
        """
        return {
            'total_documents': self.collection.count(),
            'collection_name': self.collection_name
        }
 
 
296
 
297
class GraniteModel:
    """Handles IBM Granite model loading and inference."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.guardian_pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    @st.cache_resource
    def _load_resources(device: str):
        """Load tokenizer, model and optional Guardian pipeline once per device.

        The loaded objects themselves are what gets cached, so every
        ``GraniteModel`` instance can attach to them. Failures raise and are
        therefore not cached, allowing a later retry.

        Returns:
            ``(tokenizer, model, guardian_pipeline)``; the pipeline is
            ``None`` when the Guardian model could not be loaded.
        """
        use_cuda = device == "cuda"

        # 4-bit quantization is only configured when running on CUDA.
        quantization_config = None
        if use_cuda:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

        tokenizer = AutoTokenizer.from_pretrained(
            Config.GRANITE_MODEL_NAME,
            trust_remote_code=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            Config.GRANITE_MODEL_NAME,
            quantization_config=quantization_config,
            device_map="auto" if use_cuda else None,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            trust_remote_code=True
        )

        # The safety model is optional: generation still works without it.
        try:
            guardian = pipeline(
                "text-classification",
                model=Config.GRANITE_GUARDIAN_MODEL,
                device=0 if use_cuda else -1
            )
            logger.info("Granite Guardian model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load Guardian model: {e}")
            guardian = None

        return tokenizer, model, guardian

    def load_model(self):
        """Attach the (cached) Granite resources to this instance.

        Returns:
            True when the tokenizer and model are available, False otherwise.
        """
        try:
            self.tokenizer, self.model, self.guardian_pipeline = self._load_resources(self.device)
            logger.info(f"Granite model loaded successfully on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Error loading Granite model: {e}")
            return False

    def check_safety(self, text: str) -> bool:
        """Check if text is safe using the Guardian model.

        Returns True when no Guardian pipeline is loaded or when the check
        itself fails, so the check never blocks a response on its own error.
        """
        if not self.guardian_pipeline:
            return True  # If no guardian model, assume safe

        try:
            result = self.guardian_pipeline(text)
            # NOTE(review): assumes the pipeline emits a 'safe' label — confirm
            # against the Guardian model's actual output format.
            return result[0]['label'].lower() == 'safe'
        except Exception as e:
            logger.warning(f"Error in safety check: {e}")
            return True  # Default to safe if error

    def generate_response(self, prompt: str, context: str = "") -> str:
        """Generate a response using the Granite model.

        Args:
            prompt: The user's question.
            context: Optional retrieved text placed before the question.

        Returns:
            The generated answer, or an explanatory message when the model
            cannot be loaded, the prompt fails the safety check, or
            generation raises.
        """
        if not self.model or not self.tokenizer:
            if not self.load_model():
                return "Error: Could not load the model. Please check your setup."

        if not self.check_safety(prompt):
            return "I cannot provide a response to that query due to safety concerns."

        system_prompt = """You are an Enterprise AI Assistant with access to company documents and policies.
Provide helpful, accurate, and professional responses based on the provided context.
If you cannot answer based on the context, say so clearly."""

        if context:
            full_prompt = f"{system_prompt}\n\nContext:\n{context}\n\nUser Question: {prompt}\n\nAssistant:"
        else:
            full_prompt = f"{system_prompt}\n\nUser Question: {prompt}\n\nAssistant:"

        try:
            inputs = self.tokenizer.encode(full_prompt, return_tensors='pt')

            # Keep the tail of an over-long prompt: that is where the question is.
            if inputs.shape[1] > Config.MAX_CONTEXT_LENGTH:
                inputs = inputs[:, -Config.MAX_CONTEXT_LENGTH:]

            inputs = inputs.to(self.device)
            # Single unpadded sequence: every position is a real token. An
            # explicit mask is needed because pad and EOS share one token id.
            attention_mask = torch.ones_like(inputs)

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    attention_mask=attention_mask,
                    max_new_tokens=Config.MAX_NEW_TOKENS,
                    temperature=Config.TEMPERATURE,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            return response.strip()

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"
417
 
418
class EnterpriseRAGAssistant:
    """Main RAG Assistant class.

    Ties together document extraction, chunking, vector storage and the
    Granite model.
    """

    def __init__(self):
        self.doc_processor = DocumentProcessor()
        self.text_chunker = TextChunker()
        self.vector_store = VectorStore()
        self.granite_model = GraniteModel()

    def process_and_store_documents(self, uploaded_files) -> Dict[str, Any]:
        """Process uploaded files and store their chunks in the vector database.

        Args:
            uploaded_files: Iterable of upload objects exposing ``name`` and
                ``read()``. ``None`` is treated as no files.

        Returns:
            A dict with ``processed_files`` (name, chunk count and text length
            per file), ``errors`` (one message per failed file) and
            ``total_chunks``.
        """
        results = {
            'processed_files': [],
            'errors': [],
            'total_chunks': 0
        }

        for uploaded_file in uploaded_files or []:
            tmp_file_path = None
            try:
                # The extractors work on paths, so spill the upload to disk,
                # keeping the extension they dispatch on.
                suffix = Path(uploaded_file.name).suffix
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                    tmp_file_path = tmp_file.name
                    tmp_file.write(uploaded_file.read())

                text = self.doc_processor.process_document(tmp_file_path)

                if text:
                    chunks = self.text_chunker.chunk_text(text, uploaded_file.name)
                    self.vector_store.add_documents(chunks)

                    results['processed_files'].append({
                        'name': uploaded_file.name,
                        'chunks': len(chunks),
                        'text_length': len(text)
                    })
                    results['total_chunks'] += len(chunks)
                else:
                    results['errors'].append(f"Could not extract text from {uploaded_file.name}")

            except Exception as e:
                results['errors'].append(f"Error processing {uploaded_file.name}: {str(e)}")

            finally:
                # Remove the temporary copy on every path, including failures.
                if tmp_file_path is not None:
                    try:
                        os.unlink(tmp_file_path)
                    except OSError as cleanup_error:
                        logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")

        return results

    def answer_query(self, query: str) -> Dict[str, Any]:
        """Answer a user query using RAG.

        Returns:
            A dict with ``response`` (model output), ``sources`` (the
            retrieved chunks) and ``context_used`` (whether any context text
            was passed to the model).
        """
        search_results = self.vector_store.similarity_search(query)

        context = "\n\n".join(result['text'] for result in search_results)

        response = self.granite_model.generate_response(query, context)

        return {
            'response': response,
            'sources': search_results,
            'context_used': bool(context)
        }
485
+
486
def _truncate(text, limit):
    """Return ``text`` cut to ``limit`` characters.

    An ellipsis is appended only when the text was actually shortened, so
    short strings are shown unchanged.
    """
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def _inject_custom_css():
    """Emit the page-wide CSS used by the header, buttons and status boxes."""
    st.markdown("""
    <style>
    .main-header {
        font-size: 2.5rem;
        color: #1f4e79;
        text-align: center;
        margin-bottom: 2rem;
    }
    .stButton > button {
        background-color: #0f62fe;
        color: white;
        font-weight: bold;
    }
    .success-box {
        padding: 1rem;
        border-radius: 0.5rem;
        background-color: #d4edda;
        border: 1px solid #c3e6cb;
        color: #155724;
    }
    .error-box {
        padding: 1rem;
        border-radius: 0.5rem;
        background-color: #f8d7da;
        border: 1px solid #f5c6cb;
        color: #721c24;
    }
    </style>
    """, unsafe_allow_html=True)


def _render_sidebar():
    """Render the sidebar: document upload, knowledge-base stats, model info."""
    with st.sidebar:
        st.header("📁 Document Management")

        # File upload
        uploaded_files = st.file_uploader(
            "Upload Enterprise Documents",
            type=['pdf', 'docx', 'xlsx', 'txt', 'csv', 'html', 'json', 'md'],
            accept_multiple_files=True,
            help="Upload documents to build your knowledge base"
        )

        if uploaded_files and st.button("Process Documents", type="primary"):
            with st.spinner("Processing documents..."):
                results = st.session_state.rag_assistant.process_and_store_documents(uploaded_files)

            # NOTE(review): Streamlit emits each call as its own element, so
            # the opening/closing <div> markers below probably do not wrap
            # the widgets between them -- confirm visually.
            if results['processed_files']:
                st.markdown('<div class="success-box">', unsafe_allow_html=True)
                st.success(f"Successfully processed {len(results['processed_files'])} files!")
                st.write(f"Total chunks created: {results['total_chunks']}")

                for file_info in results['processed_files']:
                    st.write(f"✓ {file_info['name']}: {file_info['chunks']} chunks")
                st.markdown('</div>', unsafe_allow_html=True)

            if results['errors']:
                st.markdown('<div class="error-box">', unsafe_allow_html=True)
                st.error("Some files had errors:")
                for error in results['errors']:
                    st.write(f"✗ {error}")
                st.markdown('</div>', unsafe_allow_html=True)

        # Database stats
        st.header("📊 Knowledge Base Stats")
        try:
            stats = st.session_state.rag_assistant.vector_store.get_collection_stats()
            total_documents = stats['total_documents']
        except Exception:
            # Best effort: the sidebar must still render when the store is
            # unavailable, but the failure is logged instead of being hidden.
            logger.warning("Could not read knowledge base stats", exc_info=True)
            total_documents = 0
        st.metric("Total Documents", total_documents)

        # Model info
        st.header("🤖 Model Information")
        st.info(f"**Main Model**: {Config.GRANITE_MODEL_NAME}")
        st.info(f"**Safety Model**: {Config.GRANITE_GUARDIAN_MODEL}")
        st.info(f"**Embedding Model**: {Config.EMBEDDING_MODEL_NAME}")


def _render_chat_column():
    """Render the query box and the chat history (newest entry first)."""
    st.header("💬 Chat with Your Documents")

    # Chat interface
    query = st.text_input(
        "Ask a question about your documents:",
        placeholder="e.g., What is our company's policy on remote work?",
        key="user_query"
    )

    if st.button("Send Query", type="primary") and query:
        with st.spinner("Generating response..."):
            result = st.session_state.rag_assistant.answer_query(query)

        # Add to chat history
        st.session_state.chat_history.append({
            'query': query,
            'response': result['response'],
            'sources': result['sources'],
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })

    # Display chat history
    if st.session_state.chat_history:
        st.header("📜 Chat History")

        for i, chat in enumerate(reversed(st.session_state.chat_history)):
            title = f"Q: {_truncate(chat['query'], 50)} ({chat['timestamp']})"
            # Only the most recent exchange starts expanded.
            with st.expander(title, expanded=i == 0):
                st.markdown("**Question:**")
                st.write(chat['query'])

                st.markdown("**Answer:**")
                st.write(chat['response'])

                if chat['sources']:
                    st.markdown("**Sources:**")
                    for j, source in enumerate(chat['sources'][:3]):
                        st.markdown(f"**Source {j+1}** (from {source['metadata']['document_name']}):")
                        st.text(_truncate(source['text'], 200))


def _render_sources_column():
    """Render the sources of the latest query and the quick-action buttons."""
    st.header("🔍 Search Results")

    if st.session_state.chat_history:
        latest_chat = st.session_state.chat_history[-1]

        st.subheader("Latest Query Sources")
        for i, source in enumerate(latest_chat['sources']):
            with st.expander(f"Source {i+1}: {source['metadata']['document_name']}"):
                st.write(f"**Relevance Score**: {1 - source['distance']:.3f}")
                st.write(f"**Document**: {source['metadata']['document_name']}")
                st.write(f"**Chunk ID**: {source['metadata']['chunk_id']}")
                # An explicit key keeps the widgets distinct even when two
                # sources carry identical text.
                st.text_area(
                    "Content",
                    source['text'],
                    height=150,
                    disabled=True,
                    key=f"latest_source_content_{i}"
                )

    # Quick actions
    st.header("⚡ Quick Actions")
    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.rerun()

    # Streamlit has no `st.confirm`; a button is only True for the single run
    # triggered by its click, so the pending confirmation is remembered in
    # session state across reruns.
    if st.button("Reset Knowledge Base"):
        st.session_state.reset_pending = True

    if st.session_state.get("reset_pending", False):
        st.warning("Are you sure you want to reset the knowledge base? This cannot be undone.")
        if st.button("Confirm Reset"):
            st.session_state.reset_pending = False
            # This would require implementing a reset method
            st.warning("Knowledge base reset functionality would be implemented here")
        elif st.button("Cancel"):
            st.session_state.reset_pending = False
            st.rerun()


def main():
    """Main Streamlit application.

    Configures the page, creates the per-session ``rag_assistant`` and
    ``chat_history`` entries in ``st.session_state`` on first run, then draws
    the sidebar, the chat column, the sources column and the footer.
    """
    st.set_page_config(
        page_title="Enterprise AI Assistant with RAG",
        page_icon="🏢",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Custom CSS
    _inject_custom_css()

    # Initialize session state
    if 'rag_assistant' not in st.session_state:
        st.session_state.rag_assistant = EnterpriseRAGAssistant()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Header
    st.markdown('<h1 class="main-header">🏢 Enterprise AI Assistant with RAG</h1>', unsafe_allow_html=True)
    st.markdown("**Powered by IBM Granite Models | Intelligent Document Processing & Q&A**")

    # Sidebar
    _render_sidebar()

    # Main content area
    col1, col2 = st.columns([2, 1])

    with col1:
        _render_chat_column()

    with col2:
        _render_sources_column()

    # Footer
    st.markdown("---")
    st.markdown(
        "Built with ❤️ using IBM Granite Models, Streamlit, and ChromaDB | "
        "Enterprise-grade AI Assistant for document processing and intelligent Q&A"
    )
657
 
658
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()