SimranShaikh committed
Commit b8bcf74 · verified
1 Parent(s): cd9e823
Files changed (1)
  1. src/streamlit_app.py  +109 -49
src/streamlit_app.py CHANGED
@@ -7,7 +7,6 @@ os.environ['TRANSFORMERS_CACHE'] = tempfile.gettempdir()
 os.environ['HF_HOME'] = tempfile.gettempdir()
 os.environ['SENTENCE_TRANSFORMERS_HOME'] = tempfile.gettempdir()
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 import PyPDF2
 import docx
@@ -126,15 +125,14 @@ ANALYSIS_TYPES = {
 
 @st.cache_resource
 def load_models():
-    """Load and cache all models"""
+    """Load and cache models with better error handling"""
     try:
+        # Load embedding model first (most reliable)
+        st.info("Loading embedding model...")
         embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
-        model_name = "microsoft/DialoGPT-medium"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        model = AutoModelForCausalLM.from_pretrained(model_name)
 
+        # Initialize ChromaDB
+        st.info("Initializing vector database...")
         client = chromadb.Client()
         try:
             collection = client.get_collection("documents")
@@ -144,10 +142,13 @@ def load_models():
                 metadata={"hnsw:space": "cosine"}
             )
 
-        return embedding_model, tokenizer, model, collection
+        st.success("✅ Models loaded successfully!")
+        return embedding_model, collection
+
     except Exception as e:
-        st.error(f"Error loading models: {str(e)}")
-        return None, None, None, None
+        st.error(f"❌ Error loading models: {str(e)}")
+        st.error("Please check your internet connection and try refreshing the page.")
+        return None, None
 
 def validate_file(uploaded_file):
     """Validate uploaded file"""
@@ -168,7 +169,7 @@ def analyze_document_structure(text, filename):
         'filename': filename,
         'word_count': len(text.split()),
         'char_count': len(text),
-        'estimated_pages': len(text) // 2000,  # Rough estimate
+        'estimated_pages': max(1, len(text) // 2000),  # Minimum 1 page
         'has_financial_data': bool(re.search(r'\$|€|£|₹|\d+\.\d+%|\d+,\d+', text)),
         'has_tables': bool(re.search(r'\|\s*\w+\s*\|', text)),
         'sections': [],
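The max(1, ...) guard matters for short documents, which the old integer division reported as zero pages:

```python
text = "short memo"                    # 10 characters
print(len(text) // 2000)               # 0  (old estimate)
print(max(1, len(text) // 2000))       # 1  (new estimate)
```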
@@ -177,28 +178,33 @@
     }
 
     # Detect document type
-    if any(term in text.lower() for term in ['financial statement', 'balance sheet', 'income statement']):
+    text_lower = text.lower()
+    if any(term in text_lower for term in ['financial statement', 'balance sheet', 'income statement']):
         analysis['document_type'] = 'Financial Statement'
-    elif any(term in text.lower() for term in ['annual report', '10-k', '10-q']):
+    elif any(term in text_lower for term in ['annual report', '10-k', '10-q']):
         analysis['document_type'] = 'Annual Report'
-    elif any(term in text.lower() for term in ['investment', 'portfolio', 'fund']):
+    elif any(term in text_lower for term in ['investment', 'portfolio', 'fund']):
         analysis['document_type'] = 'Investment Document'
-    elif any(term in text.lower() for term in ['contract', 'agreement', 'terms']):
+    elif any(term in text_lower for term in ['contract', 'agreement', 'terms']):
         analysis['document_type'] = 'Legal Document'
+    elif any(term in text_lower for term in ['budget', 'forecast', 'projection']):
+        analysis['document_type'] = 'Financial Planning'
+    else:
+        analysis['document_type'] = 'Business Document'
 
     # Extract sections (headers)
     headers = re.findall(r'^[A-Z][A-Za-z\s]{10,50}$', text, re.MULTILINE)
     analysis['sections'] = headers[:10]  # Top 10 sections
 
     # Extract key financial terms
-    financial_terms = re.findall(r'\b(?:revenue|profit|loss|assets|liabilities|equity|cash|debt|investment|ROI|EBITDA|margin)\b', text, re.IGNORECASE)
+    financial_terms = re.findall(r'\b(?:revenue|profit|loss|assets|liabilities|equity|cash|debt|investment|ROI|EBITDA|margin|expenses|income|growth|risk|return)\b', text, re.IGNORECASE)
     analysis['key_terms'] = list(set(financial_terms))[:15]
 
     return analysis
 
 @st.cache_data
 def process_document(uploaded_file):
-    """Process uploaded document with enhanced analysis"""
+    """Process uploaded document with enhanced error handling"""
     is_valid, message = validate_file(uploaded_file)
     if not is_valid:
         raise ValueError(message)
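For a sense of what the detection patterns above actually match, here is a small standalone check (the sample text is invented for illustration):

```python
import re

sample = "Annual Report 2023. Revenue grew 12.5% to $4,200 million and the EBITDA margin improved."

has_financial_data = bool(re.search(r'\$|€|£|₹|\d+\.\d+%|\d+,\d+', sample))
terms = re.findall(
    r'\b(?:revenue|profit|loss|assets|liabilities|equity|cash|debt|investment|ROI|EBITDA|margin)\b',
    sample,
    re.IGNORECASE,
)

print(has_financial_data)                        # True: both "12.5%" and "4,200" match
print(sorted(set(t.lower() for t in terms)))     # ['ebitda', 'margin', 'revenue']
```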
@@ -218,8 +224,14 @@ def process_document(uploaded_file):
         try:
             with open(tmp_path, 'rb') as file:
                 reader = PyPDF2.PdfReader(file)
+                if len(reader.pages) == 0:
+                    raise ValueError("PDF file appears to be empty")
                 for page in reader.pages:
-                    text += page.extract_text() + "\n"
+                    page_text = page.extract_text()
+                    if page_text:
+                        text += page_text + "\n"
+            if not text.strip():
+                raise ValueError("Could not extract text from PDF")
         except Exception as e:
             raise ValueError(f"Error reading PDF: {str(e)}")
 
@@ -227,29 +239,43 @@
         try:
             doc = docx.Document(tmp_path)
             for paragraph in doc.paragraphs:
-                text += paragraph.text + "\n"
+                if paragraph.text.strip():
+                    text += paragraph.text + "\n"
+            if not text.strip():
+                raise ValueError("DOCX file appears to be empty")
         except Exception as e:
             raise ValueError(f"Error reading DOCX: {str(e)}")
 
     elif file_extension == 'txt':
         try:
+            # Try UTF-8 first
             with open(tmp_path, 'r', encoding='utf-8') as file:
                 text = file.read()
         except UnicodeDecodeError:
-            with open(tmp_path, 'r', encoding='latin-1') as file:
-                text = file.read()
+            try:
+                # Fallback to latin-1
+                with open(tmp_path, 'r', encoding='latin-1') as file:
+                    text = file.read()
+            except Exception as e:
+                raise ValueError(f"Error reading TXT file: {str(e)}")
         except Exception as e:
-            raise ValueError(f"Error reading TXT: {str(e)}")
+            raise ValueError(f"Error reading TXT file: {str(e)}")
 
     elif file_extension in ['xlsx', 'xls']:
         try:
-            df = pd.read_excel(tmp_path)
-            text = df.to_string()
+            df = pd.read_excel(tmp_path, sheet_name=0)  # Read first sheet
+            if df.empty:
+                raise ValueError("Excel file appears to be empty")
+            text = df.to_string(index=False)
         except Exception as e:
-            raise ValueError(f"Error reading Excel: {str(e)}")
+            raise ValueError(f"Error reading Excel file: {str(e)}")
 
-    if not text.strip():
-        raise ValueError("No text content found in the file")
+    if not text or not text.strip():
+        raise ValueError("No readable text content found in the file")
+
+    # Clean up text
+    text = re.sub(r'\n\s*\n', '\n\n', text)  # Remove excessive newlines
+    text = text.strip()
 
     # Analyze document structure
     analysis = analyze_document_structure(text, uploaded_file.name)
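The new cleanup step collapses runs of blank lines before the text is analysed; a quick demonstration of that substitution on a made-up string:

```python
import re

raw = "Section 1\n\n\n   \nSection 2\n"
cleaned = re.sub(r'\n\s*\n', '\n\n', raw).strip()
print(repr(cleaned))    # 'Section 1\n\nSection 2'
```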
@@ -275,12 +301,30 @@ def generate_analysis_by_type(text, analysis_type, analysis_info):
     for keyword in keywords:
         if keyword in text_lower:
             # Find context around keywords
-            pattern = rf'.{0,200}\b{keyword}\b.{0,200}'
+            pattern = rf'.{{0,200}}\b{keyword}\b.{{0,200}}'
             matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
-            relevant_sections.extend(matches[:3])  # Max 3 matches per keyword
+            relevant_sections.extend(matches[:2])  # Max 2 matches per keyword
 
     if not relevant_sections:
-        return f"No specific information found for {analysis_type} in this document."
+        # If no keyword matches, provide general analysis
+        words = text.split()
+        if len(words) > 500:
+            sample_text = ' '.join(words[:500]) + "..."
+        else:
+            sample_text = text
+
+        return f"""
+## {analysis_type}
+
+**Analysis Focus**: {description}
+
+**Document Analysis**:
+Based on the document content, here are the key insights related to {analysis_type.lower()}:
+
+{sample_text}
+
+**Summary**: The document has been analyzed for {analysis_type.lower()} content. While specific keywords weren't found, the above content provides relevant context for your analysis needs.
+"""
 
     # Create structured analysis
     analysis_result = f"""
@@ -293,31 +337,43 @@ def generate_analysis_by_type(text, analysis_type, analysis_info):
 
     for i, section in enumerate(relevant_sections[:5], 1):
         cleaned_section = re.sub(r'\s+', ' ', section.strip())
-        analysis_result += f"\n{i}. {cleaned_section[:300]}...\n"
+        if len(cleaned_section) > 300:
+            cleaned_section = cleaned_section[:300] + "..."
+        analysis_result += f"\n**Finding {i}**: {cleaned_section}\n"
 
-    analysis_result += f"\n**Summary**: Based on the document analysis, {len(relevant_sections)} relevant sections were identified related to {analysis_type.lower()}."
+    analysis_result += f"\n**Summary**: Based on the document analysis, {len(relevant_sections)} relevant sections were identified related to {analysis_type.lower()}. These findings provide insights into the document's content from the perspective of {description.lower()}."
 
     return analysis_result
 
 def chunk_text(text, chunk_size=1000, overlap=200):
-    """Split text into chunks"""
+    """Split text into chunks with better handling"""
     if not text or not text.strip():
         return []
 
+    # Clean text first
+    text = re.sub(r'\s+', ' ', text.strip())
+
     chunks = []
     start = 0
 
     while start < len(text):
         end = start + chunk_size
-        chunk = text[start:end]
 
-        if end < len(text):
+        if end >= len(text):
+            # Last chunk
+            chunk = text[start:]
+        else:
+            chunk = text[start:end]
+            # Try to break at sentence boundary
             last_period = chunk.rfind('.')
-            if last_period > chunk_size * 0.7:
-                end = start + last_period + 1
+            last_newline = chunk.rfind('\n')
+            break_point = max(last_period, last_newline)
+
+            if break_point > chunk_size * 0.5:  # If we found a good break point
+                end = start + break_point + 1
                 chunk = text[start:end]
 
-        if chunk.strip():
+        if chunk.strip() and len(chunk.strip()) > 50:  # Only add substantial chunks
            chunks.append(chunk.strip())
 
         start = end - overlap
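With the revised loop, each chunk prefers to end at a sentence or line boundary when one falls past the midpoint of the window, and fragments shorter than about 50 characters are dropped. A rough usage sketch against the function above (the sample text is synthetic; exact boundaries depend on the input):

```python
sample = ("This is sentence one. " * 30).strip()        # roughly 660 characters
pieces = chunk_text(sample, chunk_size=200, overlap=50)

print(len(pieces))                                      # a few overlapping chunks (5 for this sample)
print(all(p.endswith('.') for p in pieces))             # True: each chunk ends on a sentence boundary
```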
@@ -328,13 +384,15 @@ def chunk_text(text, chunk_size=1000, overlap=200):
     return chunks
 
 def search_documents(query, collection, embedding_model, n_results=3):
-    """Search for relevant document chunks"""
+    """Search for relevant document chunks with better error handling"""
     try:
         if collection.count() == 0:
             return []
-
+
+        # Generate query embedding
         query_embedding = embedding_model.encode([query]).tolist()
 
+        # Search the collection
         results = collection.query(
             query_embeddings=query_embedding,
             n_results=min(n_results, collection.count()),
@@ -346,7 +404,7 @@
         for i in range(len(results['documents'][0])):
             search_results.append({
                 'content': results['documents'][0][i],
-                'metadata': results['metadatas'][0][i],
+                'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
                 'score': 1 - results['distances'][0][i] if results['distances'][0][i] else 1.0
             })
 
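Because the collection is created with "hnsw:space": "cosine", the distances returned by query() are cosine distances, and the expression above converts them back into a similarity-style score (a falsy distance of 0.0, i.e. an exact match, also ends up as 1.0). A tiny illustration of the mapping, with invented values:

```python
distances = [0.0, 0.25, 0.9]                      # cosine distances as ChromaDB would report them
scores = [1 - d if d else 1.0 for d in distances]
print(scores)                                     # [1.0, 0.75, 0.09999999999999998]
```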
@@ -361,7 +419,7 @@ def main():
 
     st.markdown("""
     <div style="text-align: center; font-size: 1.2rem; color: #666; margin-bottom: 2rem;">
-        🚀 Powered by IBM Granite Models | 📊 Advanced Document Intelligence | 🔒 Secure & Compliant
+        🚀 Powered by Advanced AI | 📊 Document Intelligence | 🔒 Secure & Compliant
     </div>
     """, unsafe_allow_html=True)
 
@@ -369,9 +427,10 @@
     with st.spinner("🔄 Loading AI models..."):
         models = load_models()
     if models[0] is None:
-        st.error("Failed to load AI models. Please refresh the page.")
-        return
-    embedding_model, tokenizer, model, collection = models
+        st.error("❌ Failed to load AI models. Please refresh the page and check your internet connection.")
+        st.stop()
+
+    embedding_model, collection = models
 
     # Sidebar for document management
     with st.sidebar:
@@ -425,13 +484,14 @@
             chunk_id = f"{filename}_{j}_{uuid.uuid4().hex[:8]}"
             embedding = embedding_model.encode([chunk]).tolist()
 
-            collection.add(
+            collection.upsert(
                 embeddings=embedding,
                 documents=[chunk],
                 metadatas=[{'filename': filename, 'chunk_id': j}],
                 ids=[chunk_id]
             )
         except Exception as e:
+            st.warning(f"Warning: Could not process chunk {j} of {filename}")
             continue
 
         st.success(f"✅ {filename}")
@@ -593,7 +653,7 @@
     **Query**: {query}
 
     **Key Findings**:
-    {context[:1000]}...
+    {context[:1500]}...
 
     **Summary**: Based on analysis of {len(search_results)} relevant sections from {len(source_files)} document(s), the information above directly addresses your question.
 
@@ -650,10 +710,10 @@
     st.header("🎯 Project Info")
 
     st.markdown("""
-    ### **Built For IBM Hackathon**
+    ### **Enterprise AI Assistant**
 
     **🔧 Technology Stack:**
-    - 🧠 IBM Granite Models
+    - 🧠 Advanced AI Models
     - 🔍 RAG (Retrieval-Augmented Generation)
     - 📊 Streamlit UI
     - 🗄️ ChromaDB Vector Database