riteshraut committed on
Commit
46af083
·
1 Parent(s): 1dcbf42

feat: use the parent-document retriever

Browse files
Files changed (3) hide show
  1. app.py +128 -434
  2. rag_processor.py +14 -13
  3. templates/index.html +142 -40
app.py CHANGED
@@ -3,28 +3,20 @@
3
  import os
4
  import time
5
  import uuid
6
- from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context
7
  from werkzeug.utils import secure_filename
8
  from rag_processor import create_rag_chain
9
- from typing import Sequence, Any
10
- from langchain_experimental.text_splitter import SemanticChunker
11
-
12
- from gtts import gTTS
13
- import io
14
  import re
 
 
15
 
16
-
17
  from langchain_community.document_loaders import (
18
  TextLoader,
19
- PyPDFLoader,
20
  Docx2txtLoader,
21
  )
22
-
23
- # Additional imports for robust PDF handling
24
- from langchain_core.documents import Document
25
- import fitz
26
-
27
- # Text Splitter, Embeddings, Retrievers
28
  from langchain.text_splitter import RecursiveCharacterTextSplitter
29
  from langchain_huggingface import HuggingFaceEmbeddings
30
  from langchain_community.vectorstores import FAISS
@@ -32,8 +24,8 @@ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriev
32
  from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
33
  from langchain_community.retrievers import BM25Retriever
34
  from langchain_community.chat_message_histories import ChatMessageHistory
 
35
  from sentence_transformers.cross_encoder import CrossEncoder
36
- import numpy as np
37
 
38
 
39
  app = Flask(__name__)
@@ -42,7 +34,7 @@ app.config['SECRET_KEY'] = os.urandom(24)
42
 
43
  class LocalReranker(BaseDocumentCompressor):
44
  model: Any
45
- top_n: int = 5
46
 
47
  class Config:
48
  arbitrary_types_allowed = True
@@ -62,8 +54,11 @@ class LocalReranker(BaseDocumentCompressor):
62
  doc_scores = list(zip(documents, scores))
63
  sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
64
 
65
- return [doc for doc, score in sorted_doc_scores[:self.top_n]]
66
-
 
 
 
67
 
68
  is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
69
  if is_hf_spaces:
@@ -71,451 +66,225 @@ if is_hf_spaces:
71
  else:
72
  app.config['UPLOAD_FOLDER'] = 'uploads'
73
 
74
-
75
  try:
76
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
77
  print(f"✓ Upload folder ready: {app.config['UPLOAD_FOLDER']}")
78
  except Exception as e:
79
  print(f"✗ Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
80
- # Fallback to /tmp if the configured path fails
81
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
82
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
83
  print(f"✓ Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
84
 
85
-
86
-
87
-
88
-
89
  rag_chains = {}
90
  message_histories = {}
91
 
92
  print("Loading embedding model...")
93
-
94
-
95
- cache_base = os.path.expanduser("~/.cache") if os.path.expanduser("~") != "~" else "/tmp/hf_cache"
96
- os.environ.setdefault('HF_HOME', f'{cache_base}/huggingface')
97
- os.environ.setdefault('HF_HUB_CACHE', f'{cache_base}/huggingface/hub')
98
- os.environ.setdefault('TRANSFORMERS_CACHE', f'{cache_base}/transformers')
99
- os.environ.setdefault('SENTENCE_TRANSFORMERS_HOME', f'{cache_base}/sentence_transformers')
100
-
101
-
102
- cache_dirs = [
103
- os.environ['HF_HOME'],
104
- os.environ['HF_HUB_CACHE'],
105
- os.environ['TRANSFORMERS_CACHE'],
106
- os.environ['SENTENCE_TRANSFORMERS_HOME']
107
- ]
108
-
109
- for cache_dir in cache_dirs:
110
- try:
111
- os.makedirs(cache_dir, mode=0o777, exist_ok=True)
112
- # Set permissions explicitly after creation
113
- os.chmod(cache_dir, 0o777)
114
- print(f"Cache directory ready: {cache_dir}")
115
- except PermissionError as pe:
116
- print(f"Permission error for {cache_dir}: {pe}")
117
- # Try to use a fallback directory in /tmp
118
- fallback_dir = cache_dir.replace('/app/.cache', '/tmp/hf_cache')
119
- try:
120
- os.makedirs(fallback_dir, mode=0o777, exist_ok=True)
121
- print(f"Using fallback cache directory: {fallback_dir}")
122
- # Update environment variable to point to fallback
123
- env_var_map = {
124
- '/app/.cache/huggingface': 'HF_HOME',
125
- '/app/.cache/huggingface/hub': 'HF_HUB_CACHE',
126
- '/app/.cache/transformers': 'TRANSFORMERS_CACHE',
127
- '/app/.cache/sentence_transformers': 'SENTENCE_TRANSFORMERS_HOME'
128
- }
129
- if cache_dir in env_var_map:
130
- os.environ[env_var_map[cache_dir]] = fallback_dir
131
- except Exception as fe:
132
- print(f"Failed to create fallback directory {fallback_dir}: {fe}")
133
- except Exception as e:
134
- print(f"Warning: Could not create {cache_dir}: {e}")
135
-
136
-
137
-
138
- # Try loading embedding model with error handling and fallbacks
139
  try:
140
- print("Attempting to load embedding model...")
141
  EMBEDDING_MODEL = HuggingFaceEmbeddings(
142
- model_name="sentence-transformers/all-miniLM-L6-v2",
143
- model_kwargs={'device': 'cpu'}, # Ensure CPU usage
144
- cache_folder=os.environ.get('SENTENCE_TRANSFORMERS_HOME')
145
  )
146
- print("Embedding model loaded successfully.")
147
  except Exception as e:
148
- print(f"Error loading embedding model with full name: {e}")
149
- print("Trying shortened model name...")
150
- try:
151
- EMBEDDING_MODEL = HuggingFaceEmbeddings(
152
- model_name="all-miniLM-L6-v2",
153
- model_kwargs={'device': 'cpu'},
154
- cache_folder=os.environ.get('SENTENCE_TRANSFORMERS_HOME')
155
- )
156
- print("Embedding model with shortened name loaded successfully.")
157
- except Exception as e2:
158
- print(f"Failed to load embedding model with shortened name: {e2}")
159
- print("Trying without cache folder specification...")
160
- try:
161
- EMBEDDING_MODEL = HuggingFaceEmbeddings(
162
- model_name="sentence-transformers/all-miniLM-L6-v2",
163
- model_kwargs={'device': 'cpu'}
164
- )
165
- print("Embedding model loaded without cache folder specification.")
166
- except Exception as e3:
167
- print(f"Final attempt failed: {e3}")
168
- # Use a simpler fallback model or raise the error
169
- raise Exception(f"Could not load any embedding model. Last error: {e3}")
170
-
171
-
172
 
173
  print("Loading local re-ranking model...")
174
- RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
175
- print("Re-ranking model loaded successfully.")
176
-
 
 
 
177
 
178
  def load_pdf_with_fallback(filepath):
179
- """
180
- Tries to load PDF using multiple methods for better reliability.
181
- Returns a list of Document objects.
182
- """
183
- # Method 1: Try PyPDFLoader (primary method)
184
  try:
185
- loader = PyPDFLoader(filepath)
186
- docs = loader.load()
187
- if docs and any(doc.page_content.strip() for doc in docs):
188
- print(f"✓ Successfully loaded PDF using PyPDFLoader: {filepath}")
189
- return docs
190
- except Exception as e:
191
- print(f"PyPDFLoader failed for {filepath}: {e}")
192
-
193
- # Method 2: Try PyMuPDF (fitz) as fallback
194
- try:
195
- import fitz
196
  docs = []
197
- pdf_doc = fitz.open(filepath)
198
-
199
- for page_num in range(pdf_doc.page_count):
200
- page = pdf_doc[page_num]
201
- text = page.get_text()
202
-
203
- if text.strip(): # Only add pages with content
204
- doc = Document(
205
- page_content=text,
206
- metadata={
207
- "source": filepath,
208
- "page": page_num + 1,
209
- "total_pages": pdf_doc.page_count
210
- }
211
- )
212
- docs.append(doc)
213
-
214
- pdf_doc.close()
215
-
216
- if docs:
217
- print(f"✓ Successfully loaded PDF using PyMuPDF: {filepath}")
218
- return docs
219
-
220
- except Exception as e:
221
- print(f"PyMuPDF failed for {filepath}: {e}")
222
-
223
- # Method 3: Try to extract text using pdfplumber (if available)
224
- try:
225
- import pdfplumber
226
- docs = []
227
-
228
- with pdfplumber.open(filepath) as pdf:
229
- for page_num, page in enumerate(pdf.pages):
230
- text = page.extract_text()
231
- if text and text.strip():
232
- doc = Document(
233
  page_content=text,
234
  metadata={
235
- "source": filepath,
236
  "page": page_num + 1,
237
- "total_pages": len(pdf.pages)
238
  }
239
- )
240
- docs.append(doc)
241
-
242
  if docs:
243
- print(f"✓ Successfully loaded PDF using pdfplumber: {filepath}")
244
  return docs
245
-
246
- except ImportError:
247
- print("pdfplumber not available, skipping method 3")
248
  except Exception as e:
249
- print(f"pdfplumber failed for {filepath}: {e}")
250
-
251
- # If all methods fail, raise the original error
252
- raise Exception(f"Could not load PDF {filepath} using any available method. The file may be corrupted, password-protected, or in an unsupported format.")
253
 
254
- # A dictionary to map file extensions to their corresponding loader classes
255
  LOADER_MAPPING = {
256
  ".txt": TextLoader,
257
- ".pdf": load_pdf_with_fallback, # Use our custom PDF loader
258
  ".docx": Docx2txtLoader,
259
  }
260
 
261
  def get_session_history(session_id: str) -> ChatMessageHistory:
262
- """
263
- Retrieves the chat history for a given session ID. If it doesn't exist,
264
- a new history object is created.
265
- """
266
  if session_id not in message_histories:
267
  message_histories[session_id] = ChatMessageHistory()
268
  return message_histories[session_id]
269
 
270
  @app.route('/health', methods=['GET'])
271
  def health_check():
272
- """Health check endpoint for monitoring."""
273
- return jsonify({'status': 'healthy', 'message': 'CogniChat is running'}), 200
274
 
275
  @app.route('/', methods=['GET'])
276
  def index():
277
- """Renders the main page."""
278
  return render_template('index.html')
279
 
280
  @app.route('/upload', methods=['POST'])
281
  def upload_files():
282
- """Handles multiple file uploads, processing, and RAG chain creation."""
283
  files = request.files.getlist('file')
284
-
285
  if not files or all(f.filename == '' for f in files):
286
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
287
 
288
  all_docs = []
289
- all_filenames = []
290
-
291
- try:
292
- print(f"Processing {len(files)} files...")
293
- processed_files = []
294
- failed_files = []
295
-
296
- for file in files:
297
- if file and file.filename:
298
- filename = secure_filename(file.filename)
299
- all_filenames.append(filename)
300
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
301
 
302
- try:
303
- file.save(filepath)
304
- print(f"✓ Successfully saved file: {filepath}")
305
- except Exception as save_error:
306
- print(f"✗ Failed to save file {filename}: {save_error}")
307
- failed_files.append(f"{filename} (failed to save: {save_error})")
308
- continue
309
-
310
- file_extension = os.path.splitext(filename)[1].lower()
311
- if file_extension not in LOADER_MAPPING:
312
- print(f"Skipping unsupported file type: {filename}")
313
- failed_files.append(f"{filename} (unsupported format)")
314
- continue
315
 
316
- # Add file-specific error handling
317
- try:
318
- loader_or_function = LOADER_MAPPING[file_extension]
319
-
320
- # Handle custom PDF loader function vs regular loader classes
321
- if file_extension == ".pdf":
322
- file_docs = loader_or_function(filepath) # Direct function call
323
- else:
324
- loader = loader_or_function(filepath) # Create loader instance
325
- file_docs = loader.load()
326
-
327
- if file_docs: # Only add if documents were successfully loaded
328
- all_docs.extend(file_docs)
329
- processed_files.append(filename)
330
- print(f" Successfully processed: {filename} ({len(file_docs)} pages/sections)")
331
- else:
332
- failed_files.append(f"{filename} (no content extracted)")
333
- print(f"⚠ Warning: No content extracted from {filename}")
334
-
335
- except Exception as file_error:
336
- error_msg = str(file_error)
337
- print(f"✗ Error processing {filename}: {error_msg}")
338
-
339
- # Provide specific error messages for common PDF issues
340
- if "endstream" in error_msg.lower():
341
- failed_files.append(f"{filename} (corrupted PDF - missing endstream marker)")
342
- elif "pdf" in error_msg.lower() or "pypdf" in error_msg.lower():
343
- failed_files.append(f"{filename} (PDF parsing error)")
344
- elif "permission" in error_msg.lower():
345
- failed_files.append(f"{filename} (password protected or permission denied)")
346
- elif "encoding" in error_msg.lower():
347
- failed_files.append(f"{filename} (text encoding issue)")
348
- else:
349
- failed_files.append(f"{filename} ({error_msg[:50]}...)" if len(error_msg) > 50 else f"{filename} ({error_msg})")
350
-
351
- continue
352
-
353
- # Check if any documents were successfully processed
354
- if not all_docs:
355
- if failed_files:
356
- error_msg = f"Failed to process all uploaded files:\n" + "\n".join(failed_files[:5])
357
- if len(failed_files) > 5:
358
- error_msg += f"\n...and {len(failed_files) - 5} more files"
359
- return jsonify({'status': 'error', 'message': error_msg}), 400
360
- else:
361
- return jsonify({'status': 'error', 'message': 'No processable files were uploaded.'}), 400
362
-
363
- # --- Process all documents together ---
364
- print(f"Successfully processed {len(processed_files)} files, creating knowledge base...")
365
- text_splitter = SemanticChunker(EMBEDDING_MODEL, breakpoint_threshold_type="percentile")
366
- splits = text_splitter.split_documents(all_docs)
367
-
368
- print("Creating vector store for all documents...")
369
- vectorstore = FAISS.from_documents(documents=splits, embedding=EMBEDDING_MODEL)
370
-
371
- bm25_retriever = BM25Retriever.from_documents(splits)
372
- bm25_retriever.k = 7
373
- faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 7})
374
  ensemble_retriever = EnsembleRetriever(
375
  retrievers=[bm25_retriever, faiss_retriever],
376
- weights=[0.5, 0.5]
377
  )
378
- reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
 
 
 
 
 
 
379
 
380
  compression_retriever = ContextualCompressionRetriever(
381
- base_compressor=reranker,
382
- base_retriever=ensemble_retriever
383
- )
384
-
385
 
386
- session_id = str(uuid.uuid4())
387
- print(f"Creating RAG chain for session {session_id}...")
388
 
389
- try:
390
- rag_chain = create_rag_chain(compression_retriever, get_session_history)
391
- rag_chains[session_id] = rag_chain
392
- print(f"✓ RAG chain created successfully for session {session_id} with {len(processed_files)} documents.")
393
- except Exception as rag_error:
394
- print(f"✗ Failed to create RAG chain: {rag_error}")
395
- import traceback
396
- traceback.print_exc()
397
- raise rag_error
398
 
 
 
 
399
  session['session_id'] = session_id
400
- print(f"✓ Session stored in Flask session: {session_id}")
401
-
402
- # Prepare success message with file processing summary
403
  success_msg = f"Successfully processed: {', '.join(processed_files)}"
404
  if failed_files:
405
- success_msg += f"\n\nFailed to process {len(failed_files)} file(s):\n" + "\n".join(failed_files[:3])
406
- if len(failed_files) > 3:
407
- success_msg += f"\n...and {len(failed_files) - 3} more"
408
-
409
  return jsonify({
410
  'status': 'success',
411
  'filename': success_msg,
412
- 'processed_count': len(processed_files),
413
- 'failed_count': len(failed_files),
414
- 'session_id': session_id # Include session_id in response for client to store
415
  })
416
 
417
  except Exception as e:
418
- print(f"Error creating RAG chain: {e}")
419
- return jsonify({'status': 'error', 'message': f'Failed to process files: {str(e)}'}), 500
 
420
 
421
  @app.route('/chat', methods=['POST'])
422
  def chat():
423
- """Handles chat messages and streams the response with memory."""
424
  data = request.get_json()
425
  question = data.get('question')
426
-
427
- # Try to get session_id from multiple sources (Flask session or request body)
428
  session_id = session.get('session_id') or data.get('session_id')
429
 
430
- print(f"\n=== Chat Request Debug ===")
431
- print(f"Raw request data: {data}")
432
- print(f"Question: '{question}' (type: {type(question)})")
433
- print(f"Session ID from Flask session: {session.get('session_id')}")
434
- print(f"Session ID from request body: {data.get('session_id')}")
435
- print(f"Final session ID used: {session_id}")
436
- print(f"Available RAG chains: {list(rag_chains.keys())}")
437
- print(f"Total RAG chains: {len(rag_chains)}")
438
-
439
- # Check each condition individually with detailed logging
440
- if not question:
441
- print("✗ FAILURE: No question provided")
442
- return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
443
- else:
444
- print("✓ Question is valid")
445
-
446
- if not session_id:
447
- print("✗ FAILURE: No session ID in Flask session")
448
- return jsonify({'status': 'error', 'message': 'No session found. Please upload documents first.'}), 400
449
- else:
450
- print(f"✓ Session ID found: {session_id}")
451
-
452
- if session_id not in rag_chains:
453
- print(f"✗ FAILURE: Session ID {session_id} not found in RAG chains")
454
- print(f"Available sessions: {list(rag_chains.keys())}")
455
- return jsonify({'status': 'error', 'message': 'Session not found. Please upload documents again.'}), 400
456
- else:
457
- print(f"✓ RAG chain found for session: {session_id}")
458
 
459
  try:
460
  rag_chain = rag_chains[session_id]
461
  config = {"configurable": {"session_id": session_id}}
462
 
463
- def generate():
464
- """A generator function to stream the response."""
465
- # Create the input dictionary with question and config
466
- input_data = {
467
- "question": question,
468
- "config": config
469
- }
470
-
471
- try:
472
- # For streaming, we need to handle the chain differently
473
- # The final chain doesn't support streaming, so we'll invoke it directly
474
- response = rag_chain.invoke(input_data)
475
- yield response
476
- except Exception as stream_error:
477
- print(f"Streaming error: {stream_error}")
478
- # Fallback to direct invocation
479
- try:
480
- response = rag_chain.invoke(input_data)
481
- yield response
482
- except Exception as invoke_error:
483
- print(f"Invoke error: {invoke_error}")
484
- yield "I apologize, but I'm having trouble processing your question. Please try again or upload your documents again."
485
 
486
- return Response(stream_with_context(generate()), mimetype='text/plain')
 
 
487
 
488
  except Exception as e:
489
  print(f"Error during chat invocation: {e}")
490
- return Response("An error occurred while getting the answer.", status=500, mimetype='text/plain')
491
-
492
 
493
  def clean_markdown_for_tts(text: str) -> str:
494
- """Removes markdown formatting for cleaner text-to-speech output."""
495
- # Remove bold (**text**) and italics (*text* or _text_)
496
  text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
497
  text = re.sub(r'\_(.*?)\_', r'\1', text)
498
- # Remove inline code (`code`)
499
  text = re.sub(r'`(.*?)`', r'\1', text)
500
- # Remove headings (e.g., #, ##, ###)
501
  text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
502
- # Remove list item markers (*, -, 1.)
503
  text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
504
  text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
505
- # Remove blockquotes (>)
506
  text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
507
-
508
- # --- NEW ADDITION HERE ---
509
- # Remove horizontal rules (---, ***, ===)
510
  text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
511
-
512
- # Replace multiple newlines with a single space
513
  text = re.sub(r'\n+', ' ', text)
514
  return text.strip()
515
 
516
  @app.route('/tts', methods=['POST'])
517
  def text_to_speech():
518
- """Generates audio from text and returns it as an MP3 stream."""
519
  data = request.get_json()
520
  text = data.get('text')
521
 
@@ -523,9 +292,7 @@ def text_to_speech():
523
  return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
524
 
525
  try:
526
- # --- FIX IS HERE: Clean the text before sending to gTTS ---
527
  clean_text = clean_markdown_for_tts(text)
528
-
529
  tts = gTTS(clean_text, lang='en')
530
  mp3_fp = io.BytesIO()
531
  tts.write_to_fp(mp3_fp)
@@ -535,79 +302,6 @@ def text_to_speech():
535
  print(f"Error in TTS generation: {e}")
536
  return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
537
 
538
-
539
-
540
- @app.route('/debug', methods=['GET'])
541
- def debug_info():
542
- """Debug endpoint to check configuration."""
543
- api_key = os.getenv("GROQ_API_KEY")
544
- is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
545
-
546
- return jsonify({
547
- 'environment': 'Hugging Face Spaces' if is_hf_spaces else 'Local/Other',
548
- 'upload_folder': app.config['UPLOAD_FOLDER'],
549
- 'upload_folder_writable': os.access(app.config['UPLOAD_FOLDER'], os.W_OK),
550
- 'groq_api_key_set': bool(api_key and api_key != "your_groq_api_key_here"),
551
- 'groq_api_key_placeholder': api_key == "your_groq_api_key_here" if api_key else False,
552
- 'sessions_count': len(rag_chains),
553
- 'session_ids': list(rag_chains.keys()),
554
- 'flask_session_id': session.get('session_id'),
555
- 'flask_session_keys': list(session.keys()) if session else [],
556
- 'embedding_model_loaded': 'EMBEDDING_MODEL' in globals(),
557
- 'space_id': os.getenv("SPACE_ID"),
558
- 'hf_spaces_detected': is_hf_spaces
559
- })
560
-
561
- @app.route('/test-session', methods=['GET', 'POST'])
562
- def test_session():
563
- """Test endpoint to debug Flask session issues."""
564
- if request.method == 'POST':
565
- # Set a test session value
566
- session['test_key'] = 'test_value'
567
- session['timestamp'] = str(time.time())
568
- return jsonify({
569
- 'action': 'session_set',
570
- 'test_key': session.get('test_key'),
571
- 'timestamp': session.get('timestamp'),
572
- 'session_keys': list(session.keys())
573
- })
574
- else:
575
- # Read session values
576
- return jsonify({
577
- 'action': 'session_read',
578
- 'test_key': session.get('test_key'),
579
- 'timestamp': session.get('timestamp'),
580
- 'session_id': session.get('session_id'),
581
- 'session_keys': list(session.keys()),
582
- 'has_session_data': bool(session)
583
- })
584
-
585
  if __name__ == '__main__':
586
- print(f"\n=== Application Startup Complete at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
587
-
588
- # Check environment
589
- is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
590
- api_key = os.getenv('GROQ_API_KEY')
591
-
592
- print(f"Environment: {'Hugging Face Spaces' if is_hf_spaces else 'Local Development'}")
593
- print(f"Upload folder: {app.config['UPLOAD_FOLDER']}")
594
-
595
- if is_hf_spaces:
596
- print(f"Space ID: {os.getenv('SPACE_ID', 'Not set')}")
597
- print(f"GROQ API Key: {'Set via HF Secrets' if api_key and api_key != 'your_groq_api_key_here' else 'NOT SET - Add to Space Secrets'}")
598
- else:
599
- print(f"GROQ API Key: {'Set' if api_key and api_key != 'your_groq_api_key_here' else 'NOT SET'}")
600
-
601
- print(f"Embedding model: {'Loaded' if 'EMBEDDING_MODEL' in globals() else 'NOT LOADED'}")
602
-
603
- if not api_key or api_key == 'your_groq_api_key_here':
604
- print("\n⚠️ WARNING: GROQ API KEY NOT CONFIGURED!")
605
- if is_hf_spaces:
606
- print(" → Go to your Space Settings > Repository Secrets")
607
- print(" → Add GROQ_API_KEY as a secret")
608
- else:
609
- print(" → Update .env file with your GROQ API key")
610
-
611
- # Use port 7860 for Hugging Face Spaces, fallback to 5001 for local development
612
  port = int(os.environ.get("PORT", 7860))
613
- app.run(host="0.0.0.0", port=port, debug=False)
 
3
  import os
4
  import time
5
  import uuid
6
+ from flask import Flask, request, render_template, session, jsonify, Response
7
  from werkzeug.utils import secure_filename
8
  from rag_processor import create_rag_chain
9
+ from typing import Sequence, Any, List
10
+ import fitz
 
 
 
11
  import re
12
+ import io
13
+ from gtts import gTTS
14
 
15
+ from langchain_core.documents import Document
16
  from langchain_community.document_loaders import (
17
  TextLoader,
 
18
  Docx2txtLoader,
19
  )
 
 
 
 
 
 
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
21
  from langchain_huggingface import HuggingFaceEmbeddings
22
  from langchain_community.vectorstores import FAISS
 
24
  from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
25
  from langchain_community.retrievers import BM25Retriever
26
  from langchain_community.chat_message_histories import ChatMessageHistory
27
+ from langchain.storage import InMemoryStore
28
  from sentence_transformers.cross_encoder import CrossEncoder
 
29
 
30
 
31
  app = Flask(__name__)
 
34
 
35
  class LocalReranker(BaseDocumentCompressor):
36
  model: Any
37
+ top_n: int = 3
38
 
39
  class Config:
40
  arbitrary_types_allowed = True
 
54
  doc_scores = list(zip(documents, scores))
55
  sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
56
 
57
+ top_docs = []
58
+ for doc, score in sorted_doc_scores[:self.top_n]:
59
+ doc.metadata['rerank_score'] = float(score)
60
+ top_docs.append(doc)
61
+ return top_docs
62
 
63
  is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
64
  if is_hf_spaces:
 
66
  else:
67
  app.config['UPLOAD_FOLDER'] = 'uploads'
68
 
 
69
  try:
70
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
71
  print(f"✓ Upload folder ready: {app.config['UPLOAD_FOLDER']}")
72
  except Exception as e:
73
  print(f"✗ Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
 
74
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
75
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
76
  print(f"✓ Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
77
 
 
 
 
 
78
  rag_chains = {}
79
  message_histories = {}
80
 
81
  print("Loading embedding model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  try:
 
83
  EMBEDDING_MODEL = HuggingFaceEmbeddings(
84
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
85
+ model_kwargs={'device': 'cpu'}
 
86
  )
87
+ print("Embedding model loaded successfully.")
88
  except Exception as e:
89
+ print(f" FATAL: Could not load embedding model. Error: {e}")
90
+ raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  print("Loading local re-ranking model...")
93
+ try:
94
+ RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
95
+ print("✓ Re-ranking model loaded successfully.")
96
+ except Exception as e:
97
+ print(f"✗ FATAL: Could not load reranker model. Error: {e}")
98
+ raise e
99
 
100
  def load_pdf_with_fallback(filepath):
 
 
 
 
 
101
  try:
 
 
 
 
 
 
 
 
 
 
 
102
  docs = []
103
+ with fitz.open(filepath) as pdf_doc:
104
+ for page_num, page in enumerate(pdf_doc):
105
+ text = page.get_text()
106
+ if text.strip():
107
+ docs.append(Document(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  page_content=text,
109
  metadata={
110
+ "source": os.path.basename(filepath),
111
  "page": page_num + 1,
 
112
  }
113
+ ))
 
 
114
  if docs:
115
+ print(f"✓ Successfully loaded PDF with PyMuPDF: {filepath}")
116
  return docs
117
+ else:
118
+ raise ValueError("No text content found in PDF.")
 
119
  except Exception as e:
120
+ print(f" PyMuPDF failed for {filepath}: {e}")
121
+ raise
 
 
122
 
 
123
  LOADER_MAPPING = {
124
  ".txt": TextLoader,
125
+ ".pdf": load_pdf_with_fallback,
126
  ".docx": Docx2txtLoader,
127
  }
128
 
129
  def get_session_history(session_id: str) -> ChatMessageHistory:
 
 
 
 
130
  if session_id not in message_histories:
131
  message_histories[session_id] = ChatMessageHistory()
132
  return message_histories[session_id]
133
 
134
  @app.route('/health', methods=['GET'])
135
  def health_check():
136
+ return jsonify({'status': 'healthy'}), 200
 
137
 
138
  @app.route('/', methods=['GET'])
139
  def index():
 
140
  return render_template('index.html')
141
 
142
  @app.route('/upload', methods=['POST'])
143
  def upload_files():
 
144
  files = request.files.getlist('file')
 
145
  if not files or all(f.filename == '' for f in files):
146
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
147
 
148
  all_docs = []
149
+ processed_files, failed_files = [], []
150
+
151
+ for file in files:
152
+ if file and file.filename:
153
+ filename = secure_filename(file.filename)
154
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
155
+ try:
156
+ file.save(filepath)
157
+ file_ext = os.path.splitext(filename)[1].lower()
158
+ if file_ext not in LOADER_MAPPING:
159
+ raise ValueError("Unsupported file format.")
 
160
 
161
+ loader_func = LOADER_MAPPING[file_ext]
162
+ docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ if not docs:
165
+ raise ValueError("No content extracted.")
166
+
167
+ all_docs.extend(docs)
168
+ processed_files.append(filename)
169
+ print(f"✓ Successfully processed: {filename}")
170
+ except Exception as e:
171
+ error_msg = str(e)
172
+ print(f"✗ Error processing {filename}: {error_msg}")
173
+ failed_files.append(f"{filename} ({error_msg})")
174
+
175
+ if not all_docs:
176
+ error_summary = "Failed to process all files."
177
+ if failed_files:
178
+ error_summary += " Reasons: " + ", ".join(failed_files)
179
+ return jsonify({'status': 'error', 'message': error_summary}), 400
180
+
181
+ try:
182
+ print("Starting RAG pipeline setup...")
183
+
184
+ vectorstore = FAISS.from_texts([""], EMBEDDING_MODEL)
185
+ store = InMemoryStore()
186
+
187
+ parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
188
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
189
+
190
+ parent_docs = parent_splitter.split_documents(all_docs)
191
+ doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
192
+
193
+ child_docs = []
194
+ for i, doc in enumerate(parent_docs):
195
+ _id = doc_ids[i]
196
+ sub_docs = child_splitter.split_documents([doc])
197
+ for child in sub_docs:
198
+ child.metadata["doc_id"] = _id
199
+ child_docs.extend(sub_docs)
200
+
201
+ store.mset(list(zip(doc_ids, parent_docs)))
202
+ vectorstore.add_documents(child_docs)
203
+ print(f"✓ Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
204
+
205
+ bm25_retriever = BM25Retriever.from_documents(child_docs)
206
+ bm25_retriever.k = 8
207
+
208
+ faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 8})
209
+
 
 
 
 
 
 
 
 
 
 
 
 
210
  ensemble_retriever = EnsembleRetriever(
211
  retrievers=[bm25_retriever, faiss_retriever],
212
+ weights=[0.4, 0.6]
213
  )
214
+ print("✓ Created Hybrid Retriever for child documents.")
215
+
216
+ reranker = LocalReranker(model=RERANKER_MODEL, top_n=4)
217
+
218
+ def get_parents(docs: List[Document]) -> List[Document]:
219
+ parent_ids = {d.metadata["doc_id"] for d in docs}
220
+ return store.mget(list(parent_ids))
221
 
222
  compression_retriever = ContextualCompressionRetriever(
223
+ base_compressor=reranker, base_retriever=ensemble_retriever
224
+ )
 
 
225
 
226
+ final_retriever = compression_retriever | get_parents
 
227
 
228
+ print("✓ Final retriever chain created: (Hybrid -> Rerank) -> Parent Fetch")
 
 
 
 
 
 
 
 
229
 
230
+ session_id = str(uuid.uuid4())
231
+ rag_chain = create_rag_chain(final_retriever, get_session_history)
232
+ rag_chains[session_id] = rag_chain
233
  session['session_id'] = session_id
234
+
 
 
235
  success_msg = f"Successfully processed: {', '.join(processed_files)}"
236
  if failed_files:
237
+ success_msg += f"\nFailed to process: {', '.join(failed_files)}"
238
+
 
 
239
  return jsonify({
240
  'status': 'success',
241
  'filename': success_msg,
242
+ 'session_id': session_id
 
 
243
  })
244
 
245
  except Exception as e:
246
+ import traceback
247
+ traceback.print_exc()
248
+ return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
249
 
250
  @app.route('/chat', methods=['POST'])
251
  def chat():
 
252
  data = request.get_json()
253
  question = data.get('question')
 
 
254
  session_id = session.get('session_id') or data.get('session_id')
255
 
256
+ if not question or not session_id or session_id not in rag_chains:
257
+ return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  try:
260
  rag_chain = rag_chains[session_id]
261
  config = {"configurable": {"session_id": session_id}}
262
 
263
+ # Invoke the chain, which will return a string
264
+ answer_string = rag_chain.invoke({"question": question}, config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
+ # --- THIS IS THE FIX ---
267
+ # Directly use the returned string in the JSON response.
268
+ return jsonify({'answer': answer_string})
269
 
270
  except Exception as e:
271
  print(f"Error during chat invocation: {e}")
272
+ return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
 
273
 
274
  def clean_markdown_for_tts(text: str) -> str:
 
 
275
  text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
276
  text = re.sub(r'\_(.*?)\_', r'\1', text)
 
277
  text = re.sub(r'`(.*?)`', r'\1', text)
 
278
  text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
 
279
  text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
280
  text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
 
281
  text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
 
 
 
282
  text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
 
 
283
  text = re.sub(r'\n+', ' ', text)
284
  return text.strip()
285
 
286
  @app.route('/tts', methods=['POST'])
287
  def text_to_speech():
 
288
  data = request.get_json()
289
  text = data.get('text')
290
 
 
292
  return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
293
 
294
  try:
 
295
  clean_text = clean_markdown_for_tts(text)
 
296
  tts = gTTS(clean_text, lang='en')
297
  mp3_fp = io.BytesIO()
298
  tts.write_to_fp(mp3_fp)
 
302
  print(f"Error in TTS generation: {e}")
303
  return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  port = int(os.environ.get("PORT", 7860))
307
+ app.run(host="0.0.0.0", port=port, debug=False)
rag_processor.py CHANGED
@@ -35,9 +35,11 @@ def create_rag_chain(retriever, get_session_history_func):
35
 
36
  # Get API key from environment (works for both HF Spaces secrets and local .env)
37
  api_key = os.getenv("GROQ_API_KEY")
 
38
 
39
  if not api_key or api_key == "your_groq_api_key_here":
40
  error_msg = "GROQ_API_KEY not found or not configured properly.\n"
 
41
 
42
  # Detect if running in Hugging Face Spaces
43
  if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
@@ -56,7 +58,7 @@ def create_rag_chain(retriever, get_session_history_func):
56
 
57
  # --- 1. Initialize the LLM ---
58
  # Updated model_name to a standard, high-performance Groq model
59
- llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=1)
60
 
61
  # --- 2. Create Query Rewriting Chain 🧠 ---
62
  print("\nSetting up query rewriting chain...")
@@ -80,14 +82,12 @@ Standalone Question:"""
80
 
81
  # --- 3. Create Main RAG Chain with Memory ---
82
  print("\nSetting up main RAG chain...")
83
- rag_template = """You are an expert assistant named `Cognichat`, developed by Ritesh and Alish.
84
- Your primary job is to provide accurate, comprehensive answers based ONLY on the provided context document.
85
- Act as an expert historian analyzing a text. Synthesize information from all available context snippets to form a complete answer.
86
- Your answer must be grounded in the text. Quote key phrases where necessary but do not simply copy-paste.
87
- Structure your answers clearly using markdown, such as bold text and bullet points, for readability.
88
-
89
- **Crucially, do not state that the information is missing from the context.** Assume the provided text contains the necessary information and do your best to construct the answer from it. If the context is truly insufficient, answer the question as best you can with the information available without mentioning the context's limitations.
90
-
91
  Context:
92
  {context}"""
93
  rag_prompt = ChatPromptTemplate.from_messages([
@@ -100,10 +100,11 @@ Context:
100
  # Parallel process to fetch context and correctly pass through question and history.
101
  # We use itemgetter to select the specific keys from the input dictionary.
102
  setup_and_retrieval = RunnableParallel({
103
- "context": query_rewriter | retriever,
104
- "question": itemgetter("question"),
105
- "chat_history": itemgetter("chat_history"),
106
- })
 
107
  # =====================================================================
108
 
109
  # The initial RAG chain
 
35
 
36
  # Get API key from environment (works for both HF Spaces secrets and local .env)
37
  api_key = os.getenv("GROQ_API_KEY")
38
+ print("key loaded")
39
 
40
  if not api_key or api_key == "your_groq_api_key_here":
41
  error_msg = "GROQ_API_KEY not found or not configured properly.\n"
42
+ print("Not found key")
43
 
44
  # Detect if running in Hugging Face Spaces
45
  if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
 
58
 
59
  # --- 1. Initialize the LLM ---
60
  # Updated model_name to a standard, high-performance Groq model
61
+ llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.1)
62
 
63
  # --- 2. Create Query Rewriting Chain 🧠 ---
64
  print("\nSetting up query rewriting chain...")
 
82
 
83
  # --- 3. Create Main RAG Chain with Memory ---
84
  print("\nSetting up main RAG chain...")
85
+ rag_template = """You are CogniChat, an expert document analysis assistant.
86
+ IMPORTANT RULES:
87
+ 1. ONLY use information from the provided context
88
+ 2. If the answer isn't in the context, say "I cannot find this information in the uploaded documents"
89
+ 3. Cite specific sections when possible (e.g., "According to Section 2.3...")
90
+ 4. Be concise but comprehensive
 
 
91
  Context:
92
  {context}"""
93
  rag_prompt = ChatPromptTemplate.from_messages([
 
100
  # Parallel process to fetch context and correctly pass through question and history.
101
  # We use itemgetter to select the specific keys from the input dictionary.
102
  setup_and_retrieval = RunnableParallel({
103
+ "context": RunnablePassthrough.assign(
104
+ rewritten_question=query_rewriter
105
+ ) | (lambda x: x["rewritten_question"]) | retriever,
106
+ "question": itemgetter("question"),
107
+ "chat_history": itemgetter("chat_history")})
108
  # =====================================================================
109
 
110
  # The initial RAG chain
templates/index.html CHANGED
@@ -98,17 +98,122 @@
98
  40% { transform: scale(1.0); }
99
  }
100
 
101
- /* Markdown Styling */
102
- .markdown-content p { margin-bottom: 0.75rem; line-height: 1.75; }
103
- .markdown-content ul, .markdown-content ol { margin-left: 1.5rem; margin-bottom: 0.75rem; }
104
- .markdown-content code { background-color: rgba(0,0,0,0.05); padding: 0.2rem 0.4rem; border-radius: 0.25rem; font-family: 'Roboto Mono', monospace; font-size: 0.9em; }
105
- .dark .markdown-content code { background-color: rgba(255,255,255,0.1); }
106
- .markdown-content pre { position: relative; background-color: #f8f9fa; border: 1px solid var(--card-border); border-radius: 0.5rem; margin-bottom: 1rem; }
107
- .dark .markdown-content pre { background-color: #2e2f32; }
108
- .markdown-content pre code { background: none; padding: 1rem; display: block; overflow-x: auto; }
109
- .markdown-content pre .copy-code-btn { position: absolute; top: 0.5rem; right: 0.5rem; background-color: #e8eaed; border: 1px solid #dadce0; color: #5f6368; padding: 0.3rem 0.6rem; border-radius: 0.25rem; cursor: pointer; opacity: 0; transition: opacity 0.2s; font-size: 0.8em;}
110
- .dark .markdown-content pre .copy-code-btn { background-color: #3c4043; border-color: #5f6368; color: #e8eaed; }
111
- .markdown-content pre:hover .copy-code-btn { opacity: 1; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  /* Spinner for the TTS button */
114
  .tts-button-loader {
@@ -181,7 +286,6 @@
181
  const chatContent = document.getElementById('chat-content');
182
  const chatFilename = document.getElementById('chat-filename');
183
 
184
- // Persist the active Flask session between requests
185
  let sessionId = null;
186
  const storedSessionId = sessionStorage.getItem('cognichat_session_id');
187
  if (storedSessionId) {
@@ -257,7 +361,7 @@
257
  loadingOverlay.classList.add('hidden');
258
  loadingSubtext.textContent = '';
259
  fileNameSpan.textContent = '';
260
- fileUploadInput.value = ''; // Reset file input
261
  }
262
  }
263
 
@@ -289,26 +393,28 @@
289
  });
290
 
291
  if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
 
 
 
 
 
 
 
 
 
292
 
293
  typingIndicator.remove();
294
  botMessageContainer = appendMessage('', 'bot');
295
  contentDiv = botMessageContainer.querySelector('.markdown-content');
296
 
297
- const reader = response.body.getReader();
298
- const decoder = new TextDecoder();
299
- let fullResponse = '';
300
-
301
- while (true) {
302
- const { value, done } = await reader.read();
303
- if (done) break;
304
-
305
- fullResponse += decoder.decode(value, { stream: true });
306
- contentDiv.innerHTML = marked.parse(fullResponse);
307
- scrollToBottom();
308
- }
309
  contentDiv.querySelectorAll('pre').forEach(addCopyButton);
 
310
 
311
- addTextToSpeechControls(botMessageContainer, fullResponse);
 
 
312
 
313
  } catch (error) {
314
  console.error('Chat error:', error);
@@ -344,7 +450,10 @@
344
 
345
  const contentDiv = document.createElement('div');
346
  contentDiv.className = 'text-base markdown-content';
347
- contentDiv.innerHTML = marked.parse(text);
 
 
 
348
 
349
  const controlsContainer = document.createElement('div');
350
  controlsContainer.className = 'tts-controls mt-2';
@@ -412,7 +521,7 @@
412
  });
413
  }
414
 
415
- // ============================ MODIFICATIONS START ==============================
416
  let currentAudio = null;
417
  let currentPlayingButton = null;
418
 
@@ -424,7 +533,6 @@
424
  const ttsControls = messageBubble.querySelector('.tts-controls');
425
  if (text.trim().length > 0) {
426
  const speakButton = document.createElement('button');
427
- // STYLING CHANGE HERE: Replaced theme variables with specific dark blue colors.
428
  speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
429
  speakButton.title = 'Listen to this message';
430
  speakButton.setAttribute('data-state', 'play');
@@ -435,24 +543,19 @@
435
  }
436
 
437
  async function handleTTS(text, button) {
438
- // BUG FIX: Reworked the logic to correctly handle pause/resume.
439
-
440
- // Case 1: The clicked button is already active (playing or paused).
441
  if (button === currentPlayingButton) {
442
- if (currentAudio && !currentAudio.paused) { // If it's playing, pause it.
443
  currentAudio.pause();
444
  button.setAttribute('data-state', 'paused');
445
  button.innerHTML = `${playIconSVG} <span>Play</span>`;
446
- } else if (currentAudio && currentAudio.paused) { // If it's paused, resume it.
447
  currentAudio.play();
448
  button.setAttribute('data-state', 'playing');
449
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
450
  }
451
- return; // Stop the function here.
452
  }
453
 
454
- // Case 2: A new button is clicked (or no audio is active).
455
- // Stop any other audio that might be playing.
456
  resetAllSpeakButtons();
457
 
458
  currentPlayingButton = button;
@@ -488,7 +591,7 @@
488
  button.setAttribute('data-state', 'error');
489
  button.innerHTML = `${playIconSVG} <span>Error</span>`;
490
  alert('Failed to play audio. Please try again.');
491
- resetAllSpeakButtons(); // Reset state on error
492
  } finally {
493
  button.disabled = false;
494
  }
@@ -506,8 +609,7 @@
506
  }
507
  currentPlayingButton = null;
508
  }
509
- // ============================ MODIFICATIONS END ==============================
510
  });
511
  </script>
512
  </body>
513
- </html><!DOCTYPE html>
 
98
  40% { transform: scale(1.0); }
99
  }
100
 
101
+ /* Enhanced Markdown Styling for better readability and aesthetics */
102
+ .markdown-content p {
103
+ margin-bottom: 1rem;
104
+ line-height: 1.75;
105
+ }
106
+ .markdown-content h1, .markdown-content h2, .markdown-content h3, .markdown-content h4 {
107
+ font-family: 'Google Sans', sans-serif;
108
+ font-weight: 700;
109
+ margin-top: 1.75rem;
110
+ margin-bottom: 1rem;
111
+ line-height: 1.3;
112
+ }
113
+ .markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
114
+ .markdown-content h2 { font-size: 1.5em; }
115
+ .markdown-content h3 { font-size: 1.25em; }
116
+ .markdown-content h4 { font-size: 1.1em; }
117
+ .markdown-content ul, .markdown-content ol {
118
+ padding-left: 1.75rem;
119
+ margin-bottom: 1rem;
120
+ }
121
+ .markdown-content li {
122
+ margin-bottom: 0.5rem;
123
+ }
124
+ .dark .markdown-content ul > li::marker { color: var(--primary); }
125
+ .markdown-content ul > li::marker { color: var(--primary); }
126
+ .markdown-content a {
127
+ color: var(--primary);
128
+ text-decoration: none;
129
+ font-weight: 500;
130
+ border-bottom: 1px solid transparent;
131
+ transition: all 0.2s ease-in-out;
132
+ }
133
+ .markdown-content a:hover {
134
+ border-bottom-color: var(--primary-hover);
135
+ }
136
+ .markdown-content blockquote {
137
+ margin: 1.5rem 0;
138
+ padding-left: 1.5rem;
139
+ border-left: 4px solid var(--card-border);
140
+ color: #6c757d;
141
+ font-style: italic;
142
+ }
143
+ .dark .markdown-content blockquote {
144
+ color: #adb5bd;
145
+ }
146
+ .markdown-content hr {
147
+ border: none;
148
+ border-top: 1px solid var(--card-border);
149
+ margin: 2rem 0;
150
+ }
151
+ .markdown-content table {
152
+ width: 100%;
153
+ border-collapse: collapse;
154
+ margin: 1.5rem 0;
155
+ font-size: 0.9em;
156
+ box-shadow: 0 1px 3px rgba(0,0,0,0.05);
157
+ border-radius: 8px;
158
+ overflow: hidden;
159
+ }
160
+ .markdown-content th, .markdown-content td {
161
+ border: 1px solid var(--card-border);
162
+ padding: 0.75rem 1rem;
163
+ text-align: left;
164
+ }
165
+ .markdown-content th {
166
+ background-color: var(--bot-bubble);
167
+ font-weight: 500;
168
+ }
169
+ .markdown-content code {
170
+ background-color: rgba(0,0,0,0.05);
171
+ padding: 0.2rem 0.4rem;
172
+ border-radius: 0.25rem;
173
+ font-family: 'Roboto Mono', monospace;
174
+ font-size: 0.9em;
175
+ }
176
+ .dark .markdown-content code {
177
+ background-color: rgba(255,255,255,0.1);
178
+ }
179
+ .markdown-content pre {
180
+ position: relative;
181
+ background-color: #f8f9fa;
182
+ border: 1px solid var(--card-border);
183
+ border-radius: 0.5rem;
184
+ margin-bottom: 1rem;
185
+ }
186
+ .dark .markdown-content pre {
187
+ background-color: #2e2f32;
188
+ }
189
+ .markdown-content pre code {
190
+ background: none;
191
+ padding: 1rem;
192
+ display: block;
193
+ overflow-x: auto;
194
+ }
195
+ .markdown-content pre .copy-code-btn {
196
+ position: absolute;
197
+ top: 0.5rem;
198
+ right: 0.5rem;
199
+ background-color: #e8eaed;
200
+ border: 1px solid #dadce0;
201
+ color: #5f6368;
202
+ padding: 0.3rem 0.6rem;
203
+ border-radius: 0.25rem;
204
+ cursor: pointer;
205
+ opacity: 0;
206
+ transition: opacity 0.2s;
207
+ font-size: 0.8em;
208
+ }
209
+ .dark .markdown-content pre .copy-code-btn {
210
+ background-color: #3c4043;
211
+ border-color: #5f6368;
212
+ color: #e8eaed;
213
+ }
214
+ .markdown-content pre:hover .copy-code-btn {
215
+ opacity: 1;
216
+ }
217
 
218
  /* Spinner for the TTS button */
219
  .tts-button-loader {
 
286
  const chatContent = document.getElementById('chat-content');
287
  const chatFilename = document.getElementById('chat-filename');
288
 
 
289
  let sessionId = null;
290
  const storedSessionId = sessionStorage.getItem('cognichat_session_id');
291
  if (storedSessionId) {
 
361
  loadingOverlay.classList.add('hidden');
362
  loadingSubtext.textContent = '';
363
  fileNameSpan.textContent = '';
364
+ fileUploadInput.value = '';
365
  }
366
  }
367
 
 
393
  });
394
 
395
  if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
396
+
397
+ // ============================ MODIFICATION START ==============================
398
+ // Parse the JSON response instead of reading a stream
399
+ const result = await response.json();
400
+ const answer = result.answer; // Extract the 'answer' field
401
+
402
+ if (!answer) {
403
+ throw new Error("Received an empty or invalid response from the server.");
404
+ }
405
 
406
  typingIndicator.remove();
407
  botMessageContainer = appendMessage('', 'bot');
408
  contentDiv = botMessageContainer.querySelector('.markdown-content');
409
 
410
+ // Use the extracted answer for rendering
411
+ contentDiv.innerHTML = marked.parse(answer);
 
 
 
 
 
 
 
 
 
 
412
  contentDiv.querySelectorAll('pre').forEach(addCopyButton);
413
+ scrollToBottom(); // Scroll after content is added
414
 
415
+ // Use the extracted answer for TTS
416
+ addTextToSpeechControls(botMessageContainer, answer);
417
+ // ============================ MODIFICATION END ==============================
418
 
419
  } catch (error) {
420
  console.error('Chat error:', error);
 
450
 
451
  const contentDiv = document.createElement('div');
452
  contentDiv.className = 'text-base markdown-content';
453
+ // Only parse if text is not empty
454
+ if (text) {
455
+ contentDiv.innerHTML = marked.parse(text);
456
+ }
457
 
458
  const controlsContainer = document.createElement('div');
459
  controlsContainer.className = 'tts-controls mt-2';
 
521
  });
522
  }
523
 
524
+ // --- Text-to-Speech Logic ---
525
  let currentAudio = null;
526
  let currentPlayingButton = null;
527
 
 
533
  const ttsControls = messageBubble.querySelector('.tts-controls');
534
  if (text.trim().length > 0) {
535
  const speakButton = document.createElement('button');
 
536
  speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
537
  speakButton.title = 'Listen to this message';
538
  speakButton.setAttribute('data-state', 'play');
 
543
  }
544
 
545
  async function handleTTS(text, button) {
 
 
 
546
  if (button === currentPlayingButton) {
547
+ if (currentAudio && !currentAudio.paused) {
548
  currentAudio.pause();
549
  button.setAttribute('data-state', 'paused');
550
  button.innerHTML = `${playIconSVG} <span>Play</span>`;
551
+ } else if (currentAudio && currentAudio.paused) {
552
  currentAudio.play();
553
  button.setAttribute('data-state', 'playing');
554
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
555
  }
556
+ return;
557
  }
558
 
 
 
559
  resetAllSpeakButtons();
560
 
561
  currentPlayingButton = button;
 
591
  button.setAttribute('data-state', 'error');
592
  button.innerHTML = `${playIconSVG} <span>Error</span>`;
593
  alert('Failed to play audio. Please try again.');
594
+ resetAllSpeakButtons();
595
  } finally {
596
  button.disabled = false;
597
  }
 
609
  }
610
  currentPlayingButton = null;
611
  }
 
612
  });
613
  </script>
614
  </body>
615
+ </html>