SimranShaikh committed on
Commit
4088bc5
·
verified ·
1 Parent(s): c6df280
Files changed (1) hide show
  1. src/streamlit_app.py +218 -106
src/streamlit_app.py CHANGED
@@ -7,7 +7,6 @@ os.environ['TRANSFORMERS_CACHE'] = tempfile.gettempdir()
7
  os.environ['HF_HOME'] = tempfile.gettempdir()
8
  os.environ['SENTENCE_TRANSFORMERS_HOME'] = tempfile.gettempdir()
9
 
10
-
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
12
  import torch
13
  import PyPDF2
@@ -19,7 +18,7 @@ from chromadb.config import Settings
19
  import tempfile
20
  import uuid
21
 
22
- # Page config
23
  st.set_page_config(
24
  page_title="FinanceGPT - Enterprise AI Assistant",
25
  page_icon="πŸ’°",
@@ -55,69 +54,127 @@ st.markdown("""
55
  @st.cache_resource
56
  def load_models():
57
  """Load and cache all models"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Initialize embedding model
60
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
61
-
62
- # Initialize Granite model (using a smaller model for demo)
63
- model_name = "microsoft/DialoGPT-medium" # Fallback for demo
64
- tokenizer = AutoTokenizer.from_pretrained(model_name)
65
- model = AutoModelForCausalLM.from_pretrained(model_name)
66
-
67
- # Initialize vector database (in-memory for HF Spaces)
68
- client = chromadb.Client()
69
- collection = client.create_collection(
70
- name="documents",
71
- metadata={"hnsw:space": "cosine"}
72
- )
73
 
74
- return embedding_model, tokenizer, model, collection
75
 
76
  @st.cache_data
77
  def process_document(uploaded_file):
78
- """Process uploaded document"""
79
 
80
- # Create temporary file
81
- with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp_file:
82
- tmp_file.write(uploaded_file.getvalue())
83
- tmp_path = tmp_file.name
 
 
 
 
 
 
 
 
84
 
85
  try:
86
  file_extension = uploaded_file.name.split('.')[-1].lower()
 
87
 
88
  if file_extension == 'pdf':
89
- with open(tmp_path, 'rb') as file:
90
- reader = PyPDF2.PdfReader(file)
91
- text = ""
92
- for page in reader.pages:
93
- text += page.extract_text() + "\n"
 
 
94
 
95
  elif file_extension == 'docx':
96
- doc = docx.Document(tmp_path)
97
- text = ""
98
- for paragraph in doc.paragraphs:
99
- text += paragraph.text + "\n"
 
 
100
 
101
  elif file_extension == 'txt':
102
- with open(tmp_path, 'r', encoding='utf-8') as file:
103
- text = file.read()
 
 
 
 
 
 
 
104
 
105
  elif file_extension in ['xlsx', 'xls']:
106
- df = pd.read_excel(tmp_path)
107
- text = df.to_string()
 
 
 
108
 
109
  else:
110
- text = "Unsupported file format"
 
 
 
111
 
112
  return text, uploaded_file.name
113
 
114
  finally:
115
- # Clean up
116
- if os.path.exists(tmp_path):
117
- os.remove(tmp_path)
 
 
 
118
 
119
  def chunk_text(text, chunk_size=1000, overlap=200):
120
  """Split text into chunks"""
 
 
 
121
  chunks = []
122
  start = 0
123
 
@@ -131,7 +188,9 @@ def chunk_text(text, chunk_size=1000, overlap=200):
131
  end = start + last_period + 1
132
  chunk = text[start:end]
133
 
134
- chunks.append(chunk.strip())
 
 
135
  start = end - overlap
136
 
137
  if start >= len(text):
@@ -142,33 +201,45 @@ def chunk_text(text, chunk_size=1000, overlap=200):
142
  def search_documents(query, collection, embedding_model, n_results=3):
143
  """Search for relevant document chunks"""
144
  try:
 
 
 
145
  query_embedding = embedding_model.encode([query]).tolist()
146
 
147
  results = collection.query(
148
  query_embeddings=query_embedding,
149
- n_results=n_results,
150
  include=['documents', 'metadatas', 'distances']
151
  )
152
 
153
  search_results = []
154
- for i in range(len(results['documents'][0])):
155
- search_results.append({
156
- 'content': results['documents'][0][i],
157
- 'metadata': results['metadatas'][0][i],
158
- 'score': 1 - results['distances'][0][i]
159
- })
 
160
 
161
  return search_results
162
- except:
 
163
  return []
164
 
165
  def generate_response(query, context_chunks):
166
  """Generate response using available model"""
167
 
 
 
 
168
  # Build context
169
  context = ""
 
 
170
  for i, chunk in enumerate(context_chunks):
171
- context += f"[Document {i+1}: {chunk['metadata']['filename']}]\n"
 
 
172
  context += f"{chunk['content'][:500]}...\n\n"
173
 
174
  # For demo purposes, create a structured response
@@ -178,9 +249,9 @@ def generate_response(query, context_chunks):
178
  {context[:800]}...
179
 
180
  πŸ’‘ **Analysis:**
181
- The documents contain relevant information that addresses your question. The most relevant sections have been identified and analyzed.
182
 
183
- πŸ“š **Sources:** {len(context_chunks)} document sections were used to generate this response.
184
  """
185
 
186
  return response
@@ -196,60 +267,92 @@ def main():
196
  </div>
197
  """, unsafe_allow_html=True)
198
 
199
- # Load models
200
  with st.spinner("πŸ”„ Loading AI models..."):
201
- embedding_model, tokenizer, model, collection = load_models()
 
 
 
 
202
 
203
  # Sidebar for document upload
204
  with st.sidebar:
205
  st.header("πŸ“ Document Management")
206
  st.markdown("Upload your financial documents to get started!")
207
 
 
 
 
208
  uploaded_files = st.file_uploader(
209
  "Choose files",
210
  accept_multiple_files=True,
211
  type=['pdf', 'docx', 'txt', 'xlsx'],
212
- help="Supported formats: PDF, DOCX, TXT, XLSX"
213
  )
214
 
215
  if uploaded_files:
216
- st.success(f"βœ… {len(uploaded_files)} files uploaded!")
 
 
 
 
 
 
 
217
 
218
- if st.button("πŸ”„ Process Documents", type="primary"):
219
- progress_bar = st.progress(0)
220
- status_text = st.empty()
221
 
222
- for i, file in enumerate(uploaded_files):
223
- status_text.text(f"Processing {file.name}...")
 
 
224
 
225
- try:
226
- # Process document
227
- text, filename = process_document(file)
228
-
229
- # Create chunks
230
- chunks = chunk_text(text)
231
 
232
- # Generate embeddings and store
233
- for j, chunk in enumerate(chunks):
234
- chunk_id = f"{filename}_{j}"
235
- embedding = embedding_model.encode([chunk]).tolist()
236
 
237
- collection.add(
238
- embeddings=embedding,
239
- documents=[chunk],
240
- metadatas=[{'filename': filename, 'chunk_id': j}],
241
- ids=[chunk_id]
242
- )
243
-
244
- st.success(f"βœ… {filename}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- except Exception as e:
247
- st.error(f"❌ Error processing {file.name}: {str(e)}")
248
 
249
- progress_bar.progress((i + 1) / len(uploaded_files))
250
-
251
- status_text.text("βœ… All documents processed!")
252
- st.balloons()
 
 
 
253
 
254
  # Main interface
255
  col1, col2 = st.columns([2, 1])
@@ -289,24 +392,30 @@ def main():
289
  return
290
 
291
  with st.spinner("πŸ€– Analyzing documents and generating response..."):
292
- # Search for relevant context
293
- search_results = search_documents(query, collection, embedding_model)
294
-
295
- if search_results:
296
- # Generate response
297
- response = generate_response(query, search_results)
298
-
299
- # Display response
300
- st.markdown("### πŸ€– AI Response")
301
- st.markdown(f'<div class="chat-message">{response}</div>', unsafe_allow_html=True)
302
 
303
- # Show sources
304
- st.markdown("### πŸ“š Sources")
305
- for i, result in enumerate(search_results):
306
- with st.expander(f"πŸ“„ Source {i+1}: {result['metadata']['filename']} (Relevance: {result['score']:.1%})"):
307
- st.markdown(f'<div class="source-box">{result["content"][:500]}...</div>', unsafe_allow_html=True)
308
- else:
309
- st.error("❌ No relevant information found in the uploaded documents.")
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  with col2:
312
  st.header("πŸ“Š Project Info")
@@ -337,9 +446,11 @@ def main():
337
  """)
338
 
339
  # Stats
340
- if 'collection' in locals():
341
  doc_count = collection.count()
342
- st.metric("πŸ“„ Documents Processed", doc_count)
 
 
343
 
344
  # Demo link
345
  st.markdown("""
@@ -348,9 +459,10 @@ def main():
348
  This is a fully functional prototype!
349
 
350
  **Try it:**
351
- 1. Upload financial documents
352
- 2. Ask intelligent questions
353
- 3. Get instant answers with sources
 
354
  """)
355
 
356
  if __name__ == "__main__":
 
7
  os.environ['HF_HOME'] = tempfile.gettempdir()
8
  os.environ['SENTENCE_TRANSFORMERS_HOME'] = tempfile.gettempdir()
9
 
 
10
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
11
  import torch
12
  import PyPDF2
 
18
  import tempfile
19
  import uuid
20
 
21
+ # Page config - ADD FILE SIZE LIMIT
22
  st.set_page_config(
23
  page_title="FinanceGPT - Enterprise AI Assistant",
24
  page_icon="πŸ’°",
 
54
  @st.cache_resource
55
  def load_models():
56
  """Load and cache all models"""
57
+ try:
58
+ # Initialize embedding model
59
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
60
+
61
+ # Initialize Granite model (using a smaller model for demo)
62
+ model_name = "microsoft/DialoGPT-medium" # Fallback for demo
63
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
64
+ if tokenizer.pad_token is None:
65
+ tokenizer.pad_token = tokenizer.eos_token
66
+ model = AutoModelForCausalLM.from_pretrained(model_name)
67
+
68
+ # Initialize vector database (in-memory for HF Spaces)
69
+ client = chromadb.Client()
70
+
71
+ # Check if collection exists, if not create it
72
+ try:
73
+ collection = client.get_collection("documents")
74
+ except:
75
+ collection = client.create_collection(
76
+ name="documents",
77
+ metadata={"hnsw:space": "cosine"}
78
+ )
79
+
80
+ return embedding_model, tokenizer, model, collection
81
+ except Exception as e:
82
+ st.error(f"Error loading models: {str(e)}")
83
+ return None, None, None, None
84
+
85
+ # FIX: Add file validation function
86
def validate_file(uploaded_file):
    """Validate an uploaded file's size and extension.

    Args:
        uploaded_file: Object exposing ``name`` (str) and ``size`` (int, bytes).

    Returns:
        tuple: ``(True, "Valid file")`` when the file is acceptable, otherwise
        ``(False, message)`` where ``message`` explains the rejection.
    """
    # Check file size (limit to 50MB to avoid 403 errors)
    max_size = 50 * 1024 * 1024  # 50MB in bytes
    if uploaded_file.size > max_size:
        return False, f"File {uploaded_file.name} is too large. Maximum size is 50MB."

    # Check file type
    allowed_extensions = ('pdf', 'docx', 'txt', 'xlsx', 'xls')

    # rpartition tells us whether a dot was present at all; split('.')[-1]
    # would return the whole name for "README" and report it as a file type.
    _, dot, suffix = uploaded_file.name.rpartition('.')
    file_extension = suffix.lower()
    if not dot or not file_extension:
        supported = ', '.join(allowed_extensions)
        return False, (
            f"File {uploaded_file.name} has no file extension. "
            f"Supported types: {supported}."
        )

    if file_extension not in allowed_extensions:
        return False, f"File type .{file_extension} is not supported."

    return True, "Valid file"
100
 
101
  @st.cache_data
102
  def process_document(uploaded_file):
103
+ """Process uploaded document with better error handling"""
104
 
105
+ # Validate file first
106
+ is_valid, message = validate_file(uploaded_file)
107
+ if not is_valid:
108
+ raise ValueError(message)
109
+
110
+ # Create temporary file with better error handling
111
+ try:
112
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp_file:
113
+ tmp_file.write(uploaded_file.getvalue())
114
+ tmp_path = tmp_file.name
115
+ except Exception as e:
116
+ raise ValueError(f"Failed to create temporary file: {str(e)}")
117
 
118
  try:
119
  file_extension = uploaded_file.name.split('.')[-1].lower()
120
+ text = ""
121
 
122
  if file_extension == 'pdf':
123
+ try:
124
+ with open(tmp_path, 'rb') as file:
125
+ reader = PyPDF2.PdfReader(file)
126
+ for page in reader.pages:
127
+ text += page.extract_text() + "\n"
128
+ except Exception as e:
129
+ raise ValueError(f"Error reading PDF: {str(e)}")
130
 
131
  elif file_extension == 'docx':
132
+ try:
133
+ doc = docx.Document(tmp_path)
134
+ for paragraph in doc.paragraphs:
135
+ text += paragraph.text + "\n"
136
+ except Exception as e:
137
+ raise ValueError(f"Error reading DOCX: {str(e)}")
138
 
139
  elif file_extension == 'txt':
140
+ try:
141
+ with open(tmp_path, 'r', encoding='utf-8') as file:
142
+ text = file.read()
143
+ except UnicodeDecodeError:
144
+ # Try with different encoding
145
+ with open(tmp_path, 'r', encoding='latin-1') as file:
146
+ text = file.read()
147
+ except Exception as e:
148
+ raise ValueError(f"Error reading TXT: {str(e)}")
149
 
150
  elif file_extension in ['xlsx', 'xls']:
151
+ try:
152
+ df = pd.read_excel(tmp_path)
153
+ text = df.to_string()
154
+ except Exception as e:
155
+ raise ValueError(f"Error reading Excel: {str(e)}")
156
 
157
  else:
158
+ raise ValueError("Unsupported file format")
159
+
160
+ if not text.strip():
161
+ raise ValueError("No text content found in the file")
162
 
163
  return text, uploaded_file.name
164
 
165
  finally:
166
+ # Clean up temporary file
167
+ try:
168
+ if os.path.exists(tmp_path):
169
+ os.remove(tmp_path)
170
+ except:
171
+ pass
172
 
173
  def chunk_text(text, chunk_size=1000, overlap=200):
174
  """Split text into chunks"""
175
+ if not text or not text.strip():
176
+ return []
177
+
178
  chunks = []
179
  start = 0
180
 
 
188
  end = start + last_period + 1
189
  chunk = text[start:end]
190
 
191
+ if chunk.strip(): # Only add non-empty chunks
192
+ chunks.append(chunk.strip())
193
+
194
  start = end - overlap
195
 
196
  if start >= len(text):
 
201
  def search_documents(query, collection, embedding_model, n_results=3):
202
  """Search for relevant document chunks"""
203
  try:
204
+ if collection.count() == 0:
205
+ return []
206
+
207
  query_embedding = embedding_model.encode([query]).tolist()
208
 
209
  results = collection.query(
210
  query_embeddings=query_embedding,
211
+ n_results=min(n_results, collection.count()),
212
  include=['documents', 'metadatas', 'distances']
213
  )
214
 
215
  search_results = []
216
+ if results['documents'] and results['documents'][0]:
217
+ for i in range(len(results['documents'][0])):
218
+ search_results.append({
219
+ 'content': results['documents'][0][i],
220
+ 'metadata': results['metadatas'][0][i],
221
+ 'score': 1 - results['distances'][0][i] if results['distances'][0][i] else 1.0
222
+ })
223
 
224
  return search_results
225
+ except Exception as e:
226
+ st.error(f"Search error: {str(e)}")
227
  return []
228
 
229
  def generate_response(query, context_chunks):
230
  """Generate response using available model"""
231
 
232
+ if not context_chunks:
233
+ return "No relevant information found in the uploaded documents."
234
+
235
  # Build context
236
  context = ""
237
+ source_files = set()
238
+
239
  for i, chunk in enumerate(context_chunks):
240
+ filename = chunk['metadata'].get('filename', 'Unknown')
241
+ source_files.add(filename)
242
+ context += f"[Document {i+1}: {filename}]\n"
243
  context += f"{chunk['content'][:500]}...\n\n"
244
 
245
  # For demo purposes, create a structured response
 
249
  {context[:800]}...
250
 
251
  πŸ’‘ **Analysis:**
252
+ The documents contain relevant information that addresses your question. The analysis is based on {len(context_chunks)} relevant sections from your uploaded documents.
253
 
254
+ πŸ“š **Sources:** {len(source_files)} document(s) - {', '.join(source_files)}
255
  """
256
 
257
  return response
 
267
  </div>
268
  """, unsafe_allow_html=True)
269
 
270
+ # Load models with error handling
271
  with st.spinner("πŸ”„ Loading AI models..."):
272
+ models = load_models()
273
+ if models[0] is None:
274
+ st.error("Failed to load AI models. Please refresh the page.")
275
+ return
276
+ embedding_model, tokenizer, model, collection = models
277
 
278
  # Sidebar for document upload
279
  with st.sidebar:
280
  st.header("πŸ“ Document Management")
281
  st.markdown("Upload your financial documents to get started!")
282
 
283
+ # ADD FILE SIZE WARNING
284
+ st.info("πŸ“‹ **File Requirements:**\n- Max size: 50MB per file\n- Formats: PDF, DOCX, TXT, XLSX")
285
+
286
  uploaded_files = st.file_uploader(
287
  "Choose files",
288
  accept_multiple_files=True,
289
  type=['pdf', 'docx', 'txt', 'xlsx'],
290
+ help="Supported formats: PDF, DOCX, TXT, XLSX (Max 50MB each)"
291
  )
292
 
293
  if uploaded_files:
294
+ # Validate files before processing
295
+ valid_files = []
296
+ for file in uploaded_files:
297
+ is_valid, message = validate_file(file)
298
+ if is_valid:
299
+ valid_files.append(file)
300
+ else:
301
+ st.error(f"❌ {message}")
302
 
303
+ if valid_files:
304
+ st.success(f"βœ… {len(valid_files)} valid files ready for processing!")
 
305
 
306
+ if st.button("πŸ”„ Process Documents", type="primary"):
307
+ progress_bar = st.progress(0)
308
+ status_text = st.empty()
309
+ processed_count = 0
310
 
311
+ for i, file in enumerate(valid_files):
312
+ status_text.text(f"Processing {file.name}...")
 
 
 
 
313
 
314
+ try:
315
+ # Process document
316
+ text, filename = process_document(file)
 
317
 
318
+ # Create chunks
319
+ chunks = chunk_text(text)
320
+
321
+ if not chunks:
322
+ st.warning(f"⚠️ No content extracted from {filename}")
323
+ continue
324
+
325
+ # Generate embeddings and store
326
+ for j, chunk in enumerate(chunks):
327
+ try:
328
+ chunk_id = f"{filename}_{j}_{uuid.uuid4().hex[:8]}"
329
+ embedding = embedding_model.encode([chunk]).tolist()
330
+
331
+ collection.add(
332
+ embeddings=embedding,
333
+ documents=[chunk],
334
+ metadatas=[{'filename': filename, 'chunk_id': j}],
335
+ ids=[chunk_id]
336
+ )
337
+ except Exception as e:
338
+ st.warning(f"⚠️ Error adding chunk {j} from {filename}: {str(e)}")
339
+ continue
340
+
341
+ st.success(f"βœ… {filename} ({len(chunks)} chunks)")
342
+ processed_count += 1
343
+
344
+ except Exception as e:
345
+ st.error(f"❌ Error processing {file.name}: {str(e)}")
346
 
347
+ progress_bar.progress((i + 1) / len(valid_files))
 
348
 
349
+ if processed_count > 0:
350
+ status_text.text(f"βœ… {processed_count} documents processed successfully!")
351
+ st.balloons()
352
+ else:
353
+ status_text.text("❌ No documents were processed successfully.")
354
+ else:
355
+ st.error("❌ No valid files to process!")
356
 
357
  # Main interface
358
  col1, col2 = st.columns([2, 1])
 
392
  return
393
 
394
  with st.spinner("πŸ€– Analyzing documents and generating response..."):
395
+ try:
396
+ # Search for relevant context
397
+ search_results = search_documents(query, collection, embedding_model)
 
 
 
 
 
 
 
398
 
399
+ if search_results:
400
+ # Generate response
401
+ response = generate_response(query, search_results)
402
+
403
+ # Display response
404
+ st.markdown("### πŸ€– AI Response")
405
+ st.markdown(f'<div class="chat-message">{response}</div>', unsafe_allow_html=True)
406
+
407
+ # Show sources
408
+ st.markdown("### πŸ“š Sources")
409
+ for i, result in enumerate(search_results):
410
+ score_percent = f"{result['score']:.1%}" if result['score'] else "N/A"
411
+ filename = result['metadata'].get('filename', 'Unknown')
412
+ with st.expander(f"πŸ“„ Source {i+1}: {filename} (Relevance: {score_percent})"):
413
+ st.markdown(f'<div class="source-box">{result["content"][:500]}...</div>', unsafe_allow_html=True)
414
+ else:
415
+ st.error("❌ No relevant information found in the uploaded documents.")
416
+
417
+ except Exception as e:
418
+ st.error(f"❌ Error processing your question: {str(e)}")
419
 
420
  with col2:
421
  st.header("πŸ“Š Project Info")
 
446
  """)
447
 
448
  # Stats
449
+ try:
450
  doc_count = collection.count()
451
+ st.metric("πŸ“„ Document Chunks", doc_count)
452
+ except:
453
+ st.metric("πŸ“„ Document Chunks", 0)
454
 
455
  # Demo link
456
  st.markdown("""
 
459
  This is a fully functional prototype!
460
 
461
  **Try it:**
462
+ 1. Upload financial documents (max 50MB each)
463
+ 2. Process the documents
464
+ 3. Ask intelligent questions
465
+ 4. Get instant answers with sources
466
  """)
467
 
468
  if __name__ == "__main__":