SimranShaikh committed on
Commit
5f45822
·
verified ·
1 Parent(s): db372b0
Files changed (1) hide show
  1. src/streamlit_app.py +388 -550
src/streamlit_app.py CHANGED
@@ -1,659 +1,497 @@
1
- # Enterprise AI Assistant with RAG using IBM Granite Models
2
- # Complete implementation with Streamlit interface
3
 
4
  import streamlit as st
5
  import os
6
- import pandas as pd
7
- import numpy as np
8
- from typing import List, Dict, Any, Optional
9
  import tempfile
10
- import json
11
- from datetime import datetime
12
- import logging
13
  from pathlib import Path
 
14
 
15
- # Core libraries for RAG
16
- import chromadb
17
- from sentence_transformers import SentenceTransformer
18
- import torch
19
- from transformers import (
20
- AutoTokenizer,
21
- AutoModelForCausalLM,
22
- pipeline,
23
- BitsAndBytesConfig
24
  )
25
-
26
- # Document processing
27
- import PyPDF2
28
- from docx import Document
29
- import openpyxl
30
- from bs4 import BeautifulSoup
31
- import email
32
- from email.mime.text import MIMEText
33
- import chardet
34
-
35
- # Additional utilities
36
- import re
37
- from urllib.parse import urlparse
38
- import hashlib
39
- import pickle
40
-
41
- # Configure logging
42
- logging.basicConfig(level=logging.INFO)
43
  logger = logging.getLogger(__name__)
44
 
45
- # Configuration
46
class Config:
    """Static settings shared by the document, retrieval and generation layers."""

    # Hugging Face model identifiers.
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    GRANITE_GUARDIAN_MODEL = "ibm-granite/granite-guardian-3.2-5b"
    GRANITE_MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"

    # Chunking parameters.
    CHUNK_OVERLAP = 200
    CHUNK_SIZE = 1000

    # Number of chunks retrieved per query.
    TOP_K = 5

    # Generation limits and sampling.
    MAX_CONTEXT_LENGTH = 4000
    MAX_NEW_TOKENS = 512
    TEMPERATURE = 0.7

    # File suffixes the pipeline accepts.
    SUPPORTED_FORMATS = ['.pdf', '.docx', '.xlsx', '.txt', '.csv', '.html', '.json', '.md']
57
-
58
class DocumentProcessor:
    """Extract plain text from uploaded documents.

    Every ``extract_text_from_*`` helper is best-effort: a failure is logged
    and an empty string is returned, so one unreadable file does not abort
    the rest of a batch.
    """

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Return the text of all pages of a PDF, each page followed by a newline."""
        try:
            with open(file_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                # Guard against a page yielding no text (None): previously the
                # resulting TypeError discarded the text of the whole document.
                return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)
        except Exception as e:
            logger.error(f"Error extracting text from PDF: {e}")
            return ""

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Return the paragraphs of a DOCX file, one per line."""
        try:
            doc = Document(file_path)
            return "".join(f"{paragraph.text}\n" for paragraph in doc.paragraphs)
        except Exception as e:
            logger.error(f"Error extracting text from DOCX: {e}")
            return ""

    @staticmethod
    def extract_text_from_xlsx(file_path: str) -> str:
        """Return every sheet of a workbook as ``" | "``-separated rows."""
        try:
            workbook = openpyxl.load_workbook(file_path)
            parts = []
            for sheet_name in workbook.sheetnames:
                sheet = workbook[sheet_name]
                parts.append(f"Sheet: {sheet_name}\n")
                for row in sheet.iter_rows(values_only=True):
                    # Only empty cells become ""; a truthiness test would also
                    # erase legitimate values such as 0 or False.
                    row_text = " | ".join("" if cell is None else str(cell) for cell in row)
                    if row_text.strip():
                        parts.append(row_text + "\n")
                parts.append("\n")
            return "".join(parts)
        except Exception as e:
            logger.error(f"Error extracting text from XLSX: {e}")
            return ""

    @staticmethod
    def extract_text_from_csv(file_path: str) -> str:
        """Return a CSV file rendered through ``DataFrame.to_string``."""
        try:
            df = pd.read_csv(file_path)
            return df.to_string()
        except Exception as e:
            logger.error(f"Error extracting text from CSV: {e}")
            return ""

    @staticmethod
    def extract_text_from_html(file_path: str) -> str:
        """Return the visible text of an HTML file (read as UTF-8)."""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file.read(), 'html.parser')
            return soup.get_text()
        except Exception as e:
            logger.error(f"Error extracting text from HTML: {e}")
            return ""

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Return the content of a text file, guessing its encoding with chardet."""
        try:
            # Detect encoding
            with open(file_path, 'rb') as file:
                raw_data = file.read()
            # chardet reports None when it cannot decide (e.g. an empty file);
            # fall back to UTF-8 rather than the platform's locale encoding.
            encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'

            # A wrong guess should cost a few replacement characters, not the file.
            with open(file_path, 'r', encoding=encoding, errors='replace') as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error extracting text from TXT: {e}")
            return ""

    @staticmethod
    def extract_text_from_json(file_path: str) -> str:
        """Return a JSON file re-serialised with two-space indentation."""
        try:
            # JSON text is UTF-8 by specification; do not depend on the locale.
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
            return json.dumps(data, indent=2)
        except Exception as e:
            logger.error(f"Error extracting text from JSON: {e}")
            return ""

    def process_document(self, file_path: str) -> str:
        """Extract text from ``file_path`` based on its (case-insensitive) suffix.

        Returns an empty string for unsupported suffixes or failed extraction.
        """
        file_extension = Path(file_path).suffix.lower()

        extractors = {
            '.pdf': self.extract_text_from_pdf,
            '.docx': self.extract_text_from_docx,
            '.xlsx': self.extract_text_from_xlsx,
            '.csv': self.extract_text_from_csv,
            '.html': self.extract_text_from_html,
            '.txt': self.extract_text_from_txt,
            '.md': self.extract_text_from_txt,
            '.json': self.extract_text_from_json,
        }

        extractor = extractors.get(file_extension)
        if extractor is None:
            logger.warning(f"Unsupported file format: {file_extension}")
            return ""
        return extractor(file_path)
175
-
176
class TextChunker:
    """Split document text into overlapping, sentence-aligned chunks for RAG."""

    def __init__(self, chunk_size: int = Config.CHUNK_SIZE, chunk_overlap: int = Config.CHUNK_OVERLAP):
        """Store the target chunk length and the overlap, both in characters."""
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @staticmethod
    def _make_chunk(text: str, document_name: str, chunk_id: int) -> Dict[str, Any]:
        """Build one chunk record with its metadata."""
        return {
            'text': text.strip(),
            'metadata': {
                'document_name': document_name,
                'chunk_id': chunk_id,
                'timestamp': datetime.now().isoformat()
            }
        }

    def chunk_text(self, text: str, document_name: str = "") -> List[Dict[str, Any]]:
        """Split ``text`` into chunks with metadata.

        Sentences are accumulated until adding the next one would exceed
        ``chunk_size``; the following chunk then starts with the last
        ``chunk_overlap`` characters of the previous one. A single sentence
        longer than ``chunk_size`` is kept whole.

        Returns a list of ``{'text', 'metadata'}`` dicts; empty input yields
        an empty list.
        """
        chunks = []
        current_chunk = ""

        # Split *after* sentence terminators so the punctuation stays in the
        # text (splitting on the terminators themselves removed it).
        for sentence in re.split(r'(?<=[.!?])\s+', text):
            sentence = sentence.strip()
            if not sentence:
                continue

            candidate = f"{current_chunk} {sentence}" if current_chunk else sentence

            # Measure the real joined length, including separating spaces.
            if current_chunk and len(candidate) > self.chunk_size:
                chunks.append(self._make_chunk(current_chunk, document_name, len(chunks)))

                # Start new chunk with overlap. An overlap of 0 must mean "no
                # overlap": slicing with [-0:] would copy the whole chunk.
                overlap_text = current_chunk[-self.chunk_overlap:] if self.chunk_overlap > 0 else ""
                current_chunk = f"{overlap_text} {sentence}".strip()
            else:
                current_chunk = candidate

        # Add the last chunk
        if current_chunk.strip():
            chunks.append(self._make_chunk(current_chunk, document_name, len(chunks)))

        return chunks
229
 
230
class VectorStore:
    """Handles vector storage and retrieval using ChromaDB"""

    def __init__(self, collection_name: str = "enterprise_documents"):
        """Open the persistent Chroma database and the embedding model.

        The collection is fetched if it exists and created otherwise.
        """
        self.client = chromadb.PersistentClient(path="./chroma_db")
        self.collection_name = collection_name
        self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL_NAME)

        # One call instead of get_collection guarded by a bare ``except:``,
        # which also swallowed unrelated errors (even KeyboardInterrupt).
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "Enterprise document embeddings"}
        )

    def add_documents(self, chunks: List[Dict[str, Any]]) -> None:
        """Embed ``chunks`` (dicts with 'text' and 'metadata') and store them.

        IDs are derived from document name, chunk id and text, and rows are
        upserted, so re-uploading a file replaces its chunks instead of
        producing duplicate-ID conflicts.
        """
        if not chunks:
            # Nothing to embed or store.
            logger.info("Added 0 chunks to vector store")
            return

        texts = [chunk['text'] for chunk in chunks]
        metadatas = [chunk['metadata'] for chunk in chunks]

        # Generate embeddings
        embeddings = self.embedding_model.encode(texts).tolist()

        # Generate IDs that are stable per (document, chunk, text).
        ids = [
            "doc_" + hashlib.md5(
                f"{meta.get('document_name', '')}|{meta.get('chunk_id', i)}|{text}".encode("utf-8")
            ).hexdigest()
            for i, (text, meta) in enumerate(zip(texts, metadatas))
        ]

        self.collection.upsert(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids
        )

        logger.info(f"Added {len(chunks)} chunks to vector store")

    def similarity_search(self, query: str, k: int = Config.TOP_K) -> List[Dict[str, Any]]:
        """Return up to ``k`` stored chunks closest to ``query``.

        Each result is a dict with 'text', 'metadata' and 'distance'. An empty
        collection (or ``k <= 0``) yields an empty list.
        """
        available = self.collection.count()
        if available == 0 or k <= 0:
            return []

        query_embedding = self.embedding_model.encode([query]).tolist()

        results = self.collection.query(
            query_embeddings=query_embedding,
            # Never ask for more neighbours than there are stored rows.
            n_results=min(k, available),
            include=['documents', 'metadatas', 'distances']
        )

        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0] if results['distances'] else [0] * len(documents)

        return [
            {'text': text, 'metadata': meta, 'distance': dist}
            for text, meta, dist in zip(documents, metadatas, distances)
        ]

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return the number of stored chunks and the collection name."""
        return {
            'total_documents': self.collection.count(),
            'collection_name': self.collection_name
        }
296
-
297
class GraniteModel:
    """Handles IBM Granite model loading and inference"""

    def __init__(self):
        """Start unloaded; weights are fetched lazily by ``load_model``."""
        self.model = None
        self.tokenizer = None
        self.guardian_pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    @st.cache_resource
    def _load_resources(device: str):
        """Load and cache ``(tokenizer, model, guardian_pipeline)`` for ``device``.

        The cache holds the loaded objects themselves, so every instance can
        attach to them. Raises if the main model or tokenizer cannot be
        loaded; the Guardian pipeline is optional and is ``None`` on failure.
        """
        # Configure for efficient loading
        if device == "cuda":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            quantization_config = None

        tokenizer = AutoTokenizer.from_pretrained(
            Config.GRANITE_MODEL_NAME,
            trust_remote_code=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            Config.GRANITE_MODEL_NAME,
            quantization_config=quantization_config,
            device_map="auto" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True
        )

        # Load Guardian model for safety
        # NOTE(review): the Guardian checkpoint is loaded through a
        # "text-classification" pipeline - confirm it supports that task.
        try:
            guardian = pipeline(
                "text-classification",
                model=Config.GRANITE_GUARDIAN_MODEL,
                device=0 if device == "cuda" else -1
            )
            logger.info("Granite Guardian model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load Guardian model: {e}")
            guardian = None

        return tokenizer, model, guardian

    def load_model(self):
        """Attach the (cached) Granite resources to this instance.

        Returns True on success, False if loading failed. Previously the
        cached value was only the boolean, so a second instance received
        ``True`` while its ``model``/``tokenizer`` were still ``None``, and a
        cached ``False`` prevented any retry.
        """
        try:
            self.tokenizer, self.model, self.guardian_pipeline = self._load_resources(self.device)
        except Exception as e:
            logger.error(f"Error loading Granite model: {e}")
            return False

        logger.info(f"Granite model loaded successfully on {self.device}")
        return True

    def check_safety(self, text: str) -> bool:
        """Check if text is safe using Guardian model.

        Returns True when no Guardian pipeline is available or when the check
        itself fails (fail-open).
        """
        if not self.guardian_pipeline:
            return True  # If no guardian model, assume safe

        try:
            result = self.guardian_pipeline(text)
            # Assuming Guardian returns safety classification
            return result[0]['label'].lower() == 'safe'
        except Exception as e:
            logger.warning(f"Error in safety check: {e}")
            return True  # Default to safe if error

    def generate_response(self, prompt: str, context: str = "") -> str:
        """Generate an answer to ``prompt``, optionally grounded in ``context``.

        Loads the model on first use. Returns the generated text, or a
        human-readable error/refusal message; never raises.
        """
        if not self.model or not self.tokenizer:
            if not self.load_model():
                return "Error: Could not load the model. Please check your setup."

        # Safety check
        if not self.check_safety(prompt):
            return "I cannot provide a response to that query due to safety concerns."

        # Construct the full prompt
        system_prompt = """You are an Enterprise AI Assistant with access to company documents and policies.
Provide helpful, accurate, and professional responses based on the provided context.
If you cannot answer based on the context, say so clearly."""

        if context:
            full_prompt = f"{system_prompt}\n\nContext:\n{context}\n\nUser Question: {prompt}\n\nAssistant:"
        else:
            full_prompt = f"{system_prompt}\n\nUser Question: {prompt}\n\nAssistant:"

        try:
            # Tokenize input
            inputs = self.tokenizer.encode(full_prompt, return_tensors='pt')

            # Truncate if too long; the tail is kept so the user question,
            # which sits at the end of the prompt, survives.
            if inputs.shape[1] > Config.MAX_CONTEXT_LENGTH:
                inputs = inputs[:, -Config.MAX_CONTEXT_LENGTH:]

            inputs = inputs.to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    # Single unpadded sequence: every position is a real token.
                    attention_mask=torch.ones_like(inputs),
                    max_new_tokens=Config.MAX_NEW_TOKENS,
                    temperature=Config.TEMPERATURE,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            return response.strip()

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"
417
-
418
class EnterpriseRAGAssistant:
    """Main RAG Assistant class: ingests documents and answers queries."""

    def __init__(self):
        """Create the processing, chunking, storage and generation components."""
        self.doc_processor = DocumentProcessor()
        self.text_chunker = TextChunker()
        self.vector_store = VectorStore()
        self.granite_model = GraniteModel()

    def process_and_store_documents(self, uploaded_files) -> Dict[str, Any]:
        """Process uploaded files and store in vector database.

        Each upload is written to a temporary file, converted to text,
        chunked and stored. Failures are collected per file rather than
        raised.

        Returns a dict with 'processed_files' (name, chunks, text_length per
        file), 'errors' (messages) and 'total_chunks'.
        """
        results = {
            'processed_files': [],
            'errors': [],
            'total_chunks': 0
        }

        for uploaded_file in uploaded_files:
            tmp_file_path = None
            try:
                # Save uploaded file temporarily; the suffix is kept because
                # the extractor is selected from the file extension.
                with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp_file:
                    tmp_file_path = tmp_file.name
                    tmp_file.write(uploaded_file.read())

                # Extract text
                text = self.doc_processor.process_document(tmp_file_path)

                if text:
                    chunks = self.text_chunker.chunk_text(text, uploaded_file.name)
                    self.vector_store.add_documents(chunks)

                    results['processed_files'].append({
                        'name': uploaded_file.name,
                        'chunks': len(chunks),
                        'text_length': len(text)
                    })
                    results['total_chunks'] += len(chunks)
                else:
                    results['errors'].append(f"Could not extract text from {uploaded_file.name}")

            except Exception as e:
                results['errors'].append(f"Error processing {uploaded_file.name}: {str(e)}")
            finally:
                # Clean up the temporary file on every path; previously it was
                # leaked whenever extraction, chunking or storage raised.
                if tmp_file_path is not None:
                    try:
                        os.unlink(tmp_file_path)
                    except OSError:
                        # Best effort: a leftover temp file must not fail the batch.
                        pass

        return results

    def answer_query(self, query: str) -> Dict[str, Any]:
        """Answer user query using RAG.

        Returns a dict with 'response' (generated text), 'sources' (retrieved
        chunks) and 'context_used' (whether any context text was found).
        """
        # Retrieve relevant documents
        search_results = self.vector_store.similarity_search(query)

        # Prepare context
        context = "\n\n".join(result['text'] for result in search_results)

        # Generate response
        response = self.granite_model.generate_response(query, context)

        return {
            'response': response,
            'sources': search_results,
            'context_used': bool(context)
        }
485
 
486
  def main():
487
  """Main Streamlit application"""
488
  st.set_page_config(
489
- page_title="Enterprise AI Assistant with RAG",
490
  page_icon="🏒",
491
- layout="wide",
492
- initial_sidebar_state="expanded"
493
  )
494
 
495
- # Custom CSS
496
- st.markdown("""
497
- <style>
498
- .main-header {
499
- font-size: 2.5rem;
500
- color: #1f4e79;
501
- text-align: center;
502
- margin-bottom: 2rem;
503
- }
504
- .stButton > button {
505
- background-color: #0f62fe;
506
- color: white;
507
- font-weight: bold;
508
- }
509
- .success-box {
510
- padding: 1rem;
511
- border-radius: 0.5rem;
512
- background-color: #d4edda;
513
- border: 1px solid #c3e6cb;
514
- color: #155724;
515
- }
516
- .error-box {
517
- padding: 1rem;
518
- border-radius: 0.5rem;
519
- background-color: #f8d7da;
520
- border: 1px solid #f5c6cb;
521
- color: #721c24;
522
- }
523
- </style>
524
- """, unsafe_allow_html=True)
525
 
526
  # Initialize session state
527
  if 'rag_assistant' not in st.session_state:
528
- st.session_state.rag_assistant = EnterpriseRAGAssistant()
 
529
 
530
  if 'chat_history' not in st.session_state:
531
  st.session_state.chat_history = []
532
 
533
- # Header
534
- st.markdown('<h1 class="main-header">🏒 Enterprise AI Assistant with RAG</h1>', unsafe_allow_html=True)
535
- st.markdown("**Powered by IBM Granite Models | Intelligent Document Processing & Q&A**")
536
-
537
  # Sidebar
538
  with st.sidebar:
539
- st.header("πŸ“ Document Management")
540
 
541
- # File upload
542
  uploaded_files = st.file_uploader(
543
- "Upload Enterprise Documents",
544
- type=['pdf', 'docx', 'xlsx', 'txt', 'csv', 'html', 'json', 'md'],
545
- accept_multiple_files=True,
546
- help="Upload documents to build your knowledge base"
547
  )
548
 
549
- if uploaded_files and st.button("Process Documents", type="primary"):
550
  with st.spinner("Processing documents..."):
551
- results = st.session_state.rag_assistant.process_and_store_documents(uploaded_files)
552
 
553
- if results['processed_files']:
554
- st.markdown('<div class="success-box">', unsafe_allow_html=True)
555
- st.success(f"Successfully processed {len(results['processed_files'])} files!")
556
- st.write(f"Total chunks created: {results['total_chunks']}")
557
-
558
- for file_info in results['processed_files']:
559
- st.write(f"βœ“ {file_info['name']}: {file_info['chunks']} chunks")
560
- st.markdown('</div>', unsafe_allow_html=True)
561
 
562
  if results['errors']:
563
- st.markdown('<div class="error-box">', unsafe_allow_html=True)
564
- st.error("Some files had errors:")
565
  for error in results['errors']:
566
- st.write(f"βœ— {error}")
567
- st.markdown('</div>', unsafe_allow_html=True)
568
-
569
- # Database stats
570
- st.header("πŸ“Š Knowledge Base Stats")
571
- try:
572
- stats = st.session_state.rag_assistant.vector_store.get_collection_stats()
573
- st.metric("Total Documents", stats['total_documents'])
574
- except:
575
- st.metric("Total Documents", 0)
576
-
577
- # Model info
578
- st.header("πŸ€– Model Information")
579
- st.info(f"**Main Model**: {Config.GRANITE_MODEL_NAME}")
580
- st.info(f"**Safety Model**: {Config.GRANITE_GUARDIAN_MODEL}")
581
- st.info(f"**Embedding Model**: {Config.EMBEDDING_MODEL_NAME}")
 
582
 
583
- # Main content area
584
- col1, col2 = st.columns([2, 1])
585
 
586
- with col1:
587
- st.header("πŸ’¬ Chat with Your Documents")
588
-
589
- # Chat interface
590
- query = st.text_input(
591
- "Ask a question about your documents:",
592
- placeholder="e.g., What is our company's policy on remote work?",
593
- key="user_query"
594
- )
595
-
596
- if st.button("Send Query", type="primary") and query:
597
- with st.spinner("Generating response..."):
598
- result = st.session_state.rag_assistant.answer_query(query)
599
-
600
- # Add to chat history
601
- st.session_state.chat_history.append({
602
- 'query': query,
603
- 'response': result['response'],
604
- 'sources': result['sources'],
605
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
606
- })
607
-
608
- # Display chat history
609
- if st.session_state.chat_history:
610
- st.header("πŸ“œ Chat History")
611
-
612
- for i, chat in enumerate(reversed(st.session_state.chat_history)):
613
- with st.expander(f"Q: {chat['query'][:50]}... ({chat['timestamp']})", expanded=i==0):
614
- st.markdown("**Question:**")
615
- st.write(chat['query'])
616
-
617
- st.markdown("**Answer:**")
618
- st.write(chat['response'])
619
-
620
- if chat['sources']:
621
- st.markdown("**Sources:**")
622
- for j, source in enumerate(chat['sources'][:3]):
623
- st.markdown(f"**Source {j+1}** (from {source['metadata']['document_name']}):")
624
- st.text(source['text'][:200] + "...")
625
 
626
- with col2:
627
- st.header("πŸ” Search Results")
628
-
629
- if st.session_state.chat_history:
630
- latest_chat = st.session_state.chat_history[-1]
 
 
 
 
 
 
 
 
 
 
631
 
632
- st.subheader("Latest Query Sources")
633
- for i, source in enumerate(latest_chat['sources']):
634
- with st.expander(f"Source {i+1}: {source['metadata']['document_name']}"):
635
- st.write(f"**Relevance Score**: {1 - source['distance']:.3f}")
636
- st.write(f"**Document**: {source['metadata']['document_name']}")
637
- st.write(f"**Chunk ID**: {source['metadata']['chunk_id']}")
638
- st.text_area("Content", source['text'], height=150, disabled=True)
639
-
640
- # Quick actions
641
- st.header("⚑ Quick Actions")
642
- if st.button("Clear Chat History"):
643
- st.session_state.chat_history = []
644
- st.rerun()
645
-
646
- if st.button("Reset Knowledge Base"):
647
- if st.confirm("Are you sure you want to reset the knowledge base? This cannot be undone."):
648
- # This would require implementing a reset method
649
- st.warning("Knowledge base reset functionality would be implemented here")
650
 
651
  # Footer
652
  st.markdown("---")
653
- st.markdown(
654
- "Built with ❀️ using IBM Granite Models, Streamlit, and ChromaDB | "
655
- "Enterprise-grade AI Assistant for document processing and intelligent Q&A"
656
- )
657
 
658
  if __name__ == "__main__":
659
  main()
 
1
+ # HF Spaces Optimized Enterprise AI Assistant with RAG
2
+ # Handles permission issues and environment constraints
3
 
4
  import streamlit as st
5
  import os
6
+ import sys
 
 
7
  import tempfile
8
+ import shutil
 
 
9
  from pathlib import Path
10
+ import logging
11
 
12
+ # Configure logging for HF Spaces
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
16
+ handlers=[logging.StreamHandler(sys.stdout)]
 
 
 
 
17
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Environment detection and configuration
21
class HFSpacesConfig:
    """Configuration optimized for Hugging Face Spaces"""

    # True when the SPACE_ID environment variable is present.
    IS_HF_SPACES = os.getenv("SPACE_ID") is not None

    # Model configurations (optimized for HF Spaces)
    GRANITE_MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    # Reduced parameters for HF Spaces constraints
    CHUNK_SIZE = 512
    MAX_CONTEXT_LENGTH = 2048
    MAX_NEW_TOKENS = 256
    TOP_K = 3

    # Local database directory, created on first request and then reused.
    _temp_dir = None

    @staticmethod
    def get_temp_dir():
        """Return the directory used for the ChromaDB files.

        On HF Spaces this is the fixed path ``/tmp/chroma_db``. Elsewhere a
        temporary directory is created on the first call and the same path is
        returned afterwards; creating a new directory on every call meant the
        "persistent" store was never reopened and stale directories piled up.
        """
        if HFSpacesConfig.IS_HF_SPACES:
            # Use /tmp in HF Spaces
            return "/tmp/chroma_db"

        if HFSpacesConfig._temp_dir is None:
            HFSpacesConfig._temp_dir = tempfile.mkdtemp(prefix="chroma_db_")
        return HFSpacesConfig._temp_dir

    @staticmethod
    def setup_environment():
        """Set the environment variables used to quiet the ML libraries.

        Must run before the ML libraries are imported for the settings to be
        picked up at import time.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        os.environ["TRANSFORMERS_VERBOSITY"] = "error"
52
+
53
+ # Initialize environment
54
+ HFSpacesConfig.setup_environment()
55
+
56
+ # Import ML libraries after environment setup
57
+ try:
58
+ import chromadb
59
+ from sentence_transformers import SentenceTransformer
60
+ import torch
61
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
62
+ import pandas as pd
63
+ import numpy as np
64
+ DEPENDENCIES_AVAILABLE = True
65
+ except ImportError as e:
66
+ logger.error(f"Missing dependencies: {e}")
67
+ DEPENDENCIES_AVAILABLE = False
68
+
69
+ # Lightweight document processors
70
+ import PyPDF2
71
+ from docx import Document
72
+ import json
73
+ import csv
74
+
75
class SimpleVectorStore:
    """Vector store backed by ChromaDB, with an in-memory list fallback.

    If ChromaDB cannot be initialised, documents, embeddings and metadata are
    kept in parallel Python lists and searched by cosine similarity.
    """

    def __init__(self):
        """Create empty fallback storage and try to initialise the backends."""
        self.documents = []
        self.embeddings = []
        self.metadata = []
        self.embedding_model = None
        self.chroma_client = None
        self.chroma_collection = None

        self._initialize_storage()

    def _initialize_storage(self):
        """Initialize storage with fallback options.

        Returns True when both the embedding model and ChromaDB are ready,
        False otherwise (the store then uses the in-memory lists, or is
        unusable if the embedding model itself failed to load).
        """
        # Try to initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(HFSpacesConfig.EMBEDDING_MODEL_NAME)
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            return False

        # Try to initialize ChromaDB
        try:
            # Try in-memory first for HF Spaces
            if HFSpacesConfig.IS_HF_SPACES:
                self.chroma_client = chromadb.Client()
                logger.info("Using in-memory ChromaDB for HF Spaces")
            else:
                # Try persistent storage locally
                db_path = HFSpacesConfig.get_temp_dir()
                Path(db_path).mkdir(parents=True, exist_ok=True)
                self.chroma_client = chromadb.PersistentClient(path=db_path)
                logger.info(f"Using persistent ChromaDB at: {db_path}")

            # Create collection
            self.chroma_collection = self.chroma_client.create_collection(
                name="enterprise_docs",
                get_or_create=True
            )
            return True

        except Exception as e:
            logger.warning(f"ChromaDB initialization failed: {e}, using simple storage")
            # Make sure a half-initialised backend is never mistaken for usable.
            self.chroma_collection = None
            return False

    def add_documents(self, texts, metadatas):
        """Embed ``texts`` and store them with their ``metadatas``.

        Returns True on success (including an empty batch), False if the
        embedding model is missing or storing failed.
        """
        if not self.embedding_model:
            logger.error("Embedding model not available")
            return False

        if not texts:
            return True

        try:
            # Generate embeddings
            embeddings = self.embedding_model.encode(texts)

            if self.chroma_collection:
                # Local import: the module's import block does not provide uuid.
                import uuid

                # Random IDs cannot collide between uploads; the previous
                # ``hash(text) % 10000`` scheme repeated IDs across files.
                ids = [f"doc_{uuid.uuid4().hex}" for _ in texts]
                self.chroma_collection.add(
                    documents=texts,
                    embeddings=embeddings.tolist(),
                    metadatas=metadatas,
                    ids=ids
                )
            else:
                # Use simple storage
                self.documents.extend(texts)
                self.embeddings.extend(embeddings)
                self.metadata.extend(metadatas)

            logger.info(f"Added {len(texts)} documents to vector store")
            return True

        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def search(self, query, k=3):
        """Return up to ``k`` stored chunks most similar to ``query``.

        Each result is a dict with 'text', 'metadata' and 'score'. Returns an
        empty list when nothing is stored, ``k <= 0``, the embedding model is
        missing, or the search fails.
        """
        if not self.embedding_model or k <= 0:
            return []

        try:
            query_embedding = self.embedding_model.encode([query])

            if self.chroma_collection:
                available = self.chroma_collection.count()
                if available == 0:
                    return []

                results = self.chroma_collection.query(
                    query_embeddings=query_embedding.tolist(),
                    # Never request more neighbours than there are stored rows.
                    n_results=min(k, available)
                )

                distances = results['distances']
                # NOTE(review): "1 - distance" is only a similarity if the
                # collection uses a cosine metric - confirm the configured space.
                return [
                    {
                        'text': text,
                        'metadata': results['metadatas'][0][i],
                        'score': (1 - distances[0][i]) if distances else 0.5
                    }
                    for i, text in enumerate(results['documents'][0])
                ]

            # Use simple cosine similarity
            if not self.embeddings:
                return []

            # Computed with NumPy (already imported by this module) instead of
            # scikit-learn, which the module never declares as a dependency.
            matrix = np.asarray(self.embeddings, dtype=float)
            query_vector = np.asarray(query_embedding, dtype=float)[0]
            norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vector)
            norms[norms == 0] = 1.0  # zero vectors score 0 instead of dividing by zero
            similarities = (matrix @ query_vector) / norms

            # Get top k results, best first
            top_indices = np.argsort(similarities)[::-1][:k]

            return [
                {
                    'text': self.documents[idx],
                    'metadata': self.metadata[idx],
                    'score': float(similarities[idx])
                }
                for idx in top_indices
            ]

        except Exception as e:
            logger.error(f"Search error: {e}")
            return []

    def get_stats(self):
        """Return ``{'count': int, 'type': str}`` describing the active backend."""
        if self.chroma_collection:
            try:
                count = self.chroma_collection.count()
                return {'count': count, 'type': 'ChromaDB'}
            except Exception:
                return {'count': 0, 'type': 'ChromaDB (Error)'}
        else:
            return {'count': len(self.documents), 'type': 'Simple Storage'}
212
+
213
class SimpleDocumentProcessor:
    """Extract text from uploaded file objects (binary, file-like).

    All processors are best-effort: on failure the error is logged and an
    empty string is returned.
    """

    @staticmethod
    def _rewind(file):
        """Move ``file`` back to its start if it supports seeking.

        An upload that was already read once would otherwise yield no data.
        """
        try:
            file.seek(0)
        except (AttributeError, OSError, ValueError):
            pass

    @staticmethod
    def _read_text(file):
        """Read all bytes from ``file`` and decode them as text.

        UTF-8 is tried first (a leading BOM is dropped); undecodable bytes are
        replaced rather than discarding the whole file.
        """
        SimpleDocumentProcessor._rewind(file)
        raw = file.read()
        try:
            return raw.decode('utf-8-sig')
        except UnicodeDecodeError:
            return raw.decode('utf-8', errors='replace')

    @staticmethod
    def process_pdf(file):
        """Process PDF file: return the text of all pages, newline-separated."""
        try:
            SimpleDocumentProcessor._rewind(file)
            reader = PyPDF2.PdfReader(file)
            # Guard against a page yielding no text (None): previously the
            # resulting TypeError discarded the text of the whole document.
            return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return ""

    @staticmethod
    def process_docx(file):
        """Process DOCX file: return its paragraphs joined by newlines."""
        try:
            SimpleDocumentProcessor._rewind(file)
            doc = Document(file)
            return "\n".join(para.text for para in doc.paragraphs)
        except Exception as e:
            logger.error(f"DOCX processing error: {e}")
            return ""

    @staticmethod
    def process_txt(file):
        """Process text file: return its decoded content."""
        try:
            return SimpleDocumentProcessor._read_text(file)
        except Exception as e:
            logger.error(f"TXT processing error: {e}")
            return ""

    @staticmethod
    def process_csv(file):
        """Process CSV file: return its raw decoded content (no parsing)."""
        try:
            return SimpleDocumentProcessor._read_text(file)
        except Exception as e:
            logger.error(f"CSV processing error: {e}")
            return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ class SimpleRAGAssistant:
260
+ """Simplified RAG Assistant for HF Spaces"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  def __init__(self):
263
+ self.vector_store = SimpleVectorStore()
264
+ self.doc_processor = SimpleDocumentProcessor()
265
  self.model = None
266
  self.tokenizer = None
267
+
268
def load_model(self):
    """Load the language model and tokenizer if they are not loaded yet.

    Returns True when a model is available (already loaded or loaded now),
    False when loading failed; the failure is logged, not raised.
    """
    already_loaded = self.model is not None
    if already_loaded:
        return True

    # Smaller checkpoint used as the fallback model.
    checkpoint = "microsoft/DialoGPT-medium"

    try:
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)

        # Reuse the end-of-sequence token when the tokenizer has no pad token.
        tok = self.tokenizer
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        logger.info(f"Loaded model: {checkpoint}")
    except Exception as err:
        logger.error(f"Model loading failed: {err}")
        return False

    return True
289
 
290
def process_documents(self, files):
    """Extract, chunk and store each uploaded file.

    Returns ``{'success': [...], 'errors': [...]}``: one ``{'name', 'chunks'}``
    entry per stored file and one message per file that was unsupported,
    empty, could not be stored, or raised while being processed.
    """
    # Suffix -> name of the processor method that handles it.
    handler_names = {
        '.pdf': 'process_pdf',
        '.docx': 'process_docx',
        '.txt': 'process_txt',
        '.md': 'process_txt',
        '.csv': 'process_csv',
    }
    stored, failed = [], []

    for upload in files:
        try:
            suffix = Path(upload.name).suffix.lower()
            handler_name = handler_names.get(suffix)
            if handler_name is None:
                failed.append(f"Unsupported format: {upload.name}")
                continue

            text = getattr(self.doc_processor, handler_name)(upload)
            if not text.strip():
                failed.append(f"No text extracted: {upload.name}")
                continue

            # Simple chunking
            pieces = self._chunk_text(text, upload.name)
            piece_texts = [piece['text'] for piece in pieces]
            piece_meta = [piece['metadata'] for piece in pieces]

            if not self.vector_store.add_documents(piece_texts, piece_meta):
                failed.append(f"Failed to store: {upload.name}")
                continue

            stored.append({'name': upload.name, 'chunks': len(pieces)})

        except Exception as err:
            failed.append(f"Error processing {upload.name}: {str(err)}")

    return {'success': stored, 'errors': failed}
330
 
331
+ def _chunk_text(self, text, filename):
332
+ """Simple text chunking"""
333
+ chunk_size = HFSpacesConfig.CHUNK_SIZE
334
+ chunks = []
 
335
 
336
+ sentences = text.split('.')
337
+ current_chunk = ""
 
338
 
339
+ for sentence in sentences:
340
+ if len(current_chunk + sentence) < chunk_size:
341
+ current_chunk += sentence + "."
342
+ else:
343
+ if current_chunk:
344
+ chunks.append({
345
+ 'text': current_chunk.strip(),
346
+ 'metadata': {'source': filename, 'chunk_id': len(chunks)}
347
+ })
348
+ current_chunk = sentence + "."
349
 
350
+ if current_chunk:
351
+ chunks.append({
352
+ 'text': current_chunk.strip(),
353
+ 'metadata': {'source': filename, 'chunk_id': len(chunks)}
354
+ })
355
 
356
+ return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
+ def answer_query(self, query):
359
+ """Answer query using RAG"""
360
+ # Search for relevant documents
361
+ search_results = self.vector_store.search(query, k=HFSpacesConfig.TOP_K)
 
362
 
363
+ if not search_results:
364
+ return {
365
+ 'response': "I don't have enough information to answer your question. Please upload some documents first.",
366
+ 'sources': []
367
+ }
368
+
369
+ # Prepare context
370
+ context = "\n\n".join([result['text'][:200] + "..." for result in search_results])
371
 
372
+ # Generate response (simplified)
373
+ if self.model and self.tokenizer:
374
  try:
375
+ prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
376
+ inputs = self.tokenizer.encode(prompt, return_tensors='pt', max_length=512, truncation=True)
 
 
377
 
378
+ with torch.no_grad():
379
+ outputs = self.model.generate(
380
+ inputs,
381
+ max_length=inputs.shape[1] + 100,
382
+ num_return_sequences=1,
383
+ temperature=0.7,
384
+ pad_token_id=self.tokenizer.eos_token_id
385
+ )
 
 
 
 
 
 
 
 
 
 
386
 
387
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
 
388
 
389
  except Exception as e:
390
+ logger.error(f"Generation error: {e}")
391
+ response = f"Based on the available documents: {context[:300]}..."
392
+ else:
393
+ # Fallback response
394
+ response = f"Based on the available documents, here's what I found: {context[:300]}..."
 
 
 
 
 
 
 
 
 
395
 
396
  return {
397
  'response': response,
398
+ 'sources': search_results
 
399
  }
400
 
401
def _truncate_for_display(text, limit):
    """Return *text* cut to *limit* characters, adding "..." only if it was cut."""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def main():
    """Main Streamlit application.

    Renders the document-upload sidebar, vector-store statistics, the
    model-loading control and the chat interface. The assistant and the chat
    history are kept in ``st.session_state`` so they survive Streamlit reruns.
    """
    st.set_page_config(
        page_title="Enterprise RAG Assistant (HF Spaces)",
        page_icon="🏢",
        layout="wide"
    )

    # Check dependencies
    if not DEPENDENCIES_AVAILABLE:
        st.error("Some dependencies are missing. Please check the requirements.txt file.")
        st.stop()

    st.title("🏢 Enterprise RAG Assistant")
    st.caption("Optimized for Hugging Face Spaces")

    # Initialize session state
    if 'rag_assistant' not in st.session_state:
        with st.spinner("Initializing assistant..."):
            st.session_state.rag_assistant = SimpleRAGAssistant()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    assistant = st.session_state.rag_assistant

    # Sidebar
    with st.sidebar:
        st.header("📁 Document Upload")

        uploaded_files = st.file_uploader(
            "Upload documents",
            type=['pdf', 'docx', 'txt', 'csv', 'md'],
            accept_multiple_files=True
        )

        if uploaded_files and st.button("Process Documents"):
            with st.spinner("Processing documents..."):
                results = assistant.process_documents(uploaded_files)

            if results['success']:
                st.success(f"✅ Processed {len(results['success'])} files")
                for file in results['success']:
                    st.write(f"- {file['name']}: {file['chunks']} chunks")

            if results['errors']:
                st.error("❌ Some files had errors:")
                for error in results['errors']:
                    st.write(f"- {error}")

        # Stats
        st.header("📊 Statistics")
        stats = assistant.vector_store.get_stats()
        st.metric("Documents", stats['count'])
        st.info(f"Storage: {stats['type']}")

        # Model loading
        st.header("🤖 Model Status")
        if st.button("Load Model"):
            with st.spinner("Loading model..."):
                success = assistant.load_model()
            if success:
                st.success("✅ Model loaded")
            else:
                st.error("❌ Model loading failed")

    # Main chat interface
    st.header("💬 Chat")

    query = st.text_input("Ask a question about your documents:")

    if st.button("Send") and query:
        with st.spinner("Generating response..."):
            result = assistant.answer_query(query)

        st.session_state.chat_history.append({
            'query': query,
            'response': result['response'],
            'sources': result['sources']
        })

    # Display chat history, newest first; only the latest entry starts expanded
    for i, chat in enumerate(reversed(st.session_state.chat_history)):
        title = f"Q: {_truncate_for_display(chat['query'], 50)}"
        with st.expander(title, expanded=i == 0):
            st.write("**Question:**", chat['query'])
            st.write("**Answer:**", chat['response'])

            if chat['sources']:
                st.write("**Sources:**")
                # Show at most the two top sources to keep the view compact
                for j, source in enumerate(chat['sources'][:2], start=1):
                    st.write(f"{j}. {source['metadata']['source']} (Score: {source['score']:.2f})")
                    st.text(_truncate_for_display(source['text'], 150))

    # Footer
    st.markdown("---")
    st.markdown("🤗 Running on Hugging Face Spaces | Built with Streamlit")
 
 
 
495
 
496
# Script entry point: launch the Streamlit UI only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()