Spaces:
Sleeping
Sleeping
riteshraut committed on
Commit ·
46af083
1
Parent(s): 1dcbf42
feat/used the parentdocument retriever
Browse files- app.py +128 -434
- rag_processor.py +14 -13
- templates/index.html +142 -40
app.py
CHANGED
|
@@ -3,28 +3,20 @@
|
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
import uuid
|
| 6 |
-
from flask import Flask, request, render_template, session, jsonify, Response
|
| 7 |
from werkzeug.utils import secure_filename
|
| 8 |
from rag_processor import create_rag_chain
|
| 9 |
-
from typing import Sequence, Any
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from gtts import gTTS
|
| 13 |
-
import io
|
| 14 |
import re
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
from langchain_community.document_loaders import (
|
| 18 |
TextLoader,
|
| 19 |
-
PyPDFLoader,
|
| 20 |
Docx2txtLoader,
|
| 21 |
)
|
| 22 |
-
|
| 23 |
-
# Additional imports for robust PDF handling
|
| 24 |
-
from langchain_core.documents import Document
|
| 25 |
-
import fitz
|
| 26 |
-
|
| 27 |
-
# Text Splitter, Embeddings, Retrievers
|
| 28 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 29 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 30 |
from langchain_community.vectorstores import FAISS
|
|
@@ -32,8 +24,8 @@ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriev
|
|
| 32 |
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
| 33 |
from langchain_community.retrievers import BM25Retriever
|
| 34 |
from langchain_community.chat_message_histories import ChatMessageHistory
|
|
|
|
| 35 |
from sentence_transformers.cross_encoder import CrossEncoder
|
| 36 |
-
import numpy as np
|
| 37 |
|
| 38 |
|
| 39 |
app = Flask(__name__)
|
|
@@ -42,7 +34,7 @@ app.config['SECRET_KEY'] = os.urandom(24)
|
|
| 42 |
|
| 43 |
class LocalReranker(BaseDocumentCompressor):
|
| 44 |
model: Any
|
| 45 |
-
top_n: int =
|
| 46 |
|
| 47 |
class Config:
|
| 48 |
arbitrary_types_allowed = True
|
|
@@ -62,8 +54,11 @@ class LocalReranker(BaseDocumentCompressor):
|
|
| 62 |
doc_scores = list(zip(documents, scores))
|
| 63 |
sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
|
| 69 |
if is_hf_spaces:
|
|
@@ -71,451 +66,225 @@ if is_hf_spaces:
|
|
| 71 |
else:
|
| 72 |
app.config['UPLOAD_FOLDER'] = 'uploads'
|
| 73 |
|
| 74 |
-
|
| 75 |
try:
|
| 76 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 77 |
print(f"✓ Upload folder ready: {app.config['UPLOAD_FOLDER']}")
|
| 78 |
except Exception as e:
|
| 79 |
print(f"✗ Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
|
| 80 |
-
# Fallback to /tmp if the configured path fails
|
| 81 |
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
|
| 82 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 83 |
print(f"✓ Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
rag_chains = {}
|
| 90 |
message_histories = {}
|
| 91 |
|
| 92 |
print("Loading embedding model...")
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
cache_base = os.path.expanduser("~/.cache") if os.path.expanduser("~") != "~" else "/tmp/hf_cache"
|
| 96 |
-
os.environ.setdefault('HF_HOME', f'{cache_base}/huggingface')
|
| 97 |
-
os.environ.setdefault('HF_HUB_CACHE', f'{cache_base}/huggingface/hub')
|
| 98 |
-
os.environ.setdefault('TRANSFORMERS_CACHE', f'{cache_base}/transformers')
|
| 99 |
-
os.environ.setdefault('SENTENCE_TRANSFORMERS_HOME', f'{cache_base}/sentence_transformers')
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
cache_dirs = [
|
| 103 |
-
os.environ['HF_HOME'],
|
| 104 |
-
os.environ['HF_HUB_CACHE'],
|
| 105 |
-
os.environ['TRANSFORMERS_CACHE'],
|
| 106 |
-
os.environ['SENTENCE_TRANSFORMERS_HOME']
|
| 107 |
-
]
|
| 108 |
-
|
| 109 |
-
for cache_dir in cache_dirs:
|
| 110 |
-
try:
|
| 111 |
-
os.makedirs(cache_dir, mode=0o777, exist_ok=True)
|
| 112 |
-
# Set permissions explicitly after creation
|
| 113 |
-
os.chmod(cache_dir, 0o777)
|
| 114 |
-
print(f"Cache directory ready: {cache_dir}")
|
| 115 |
-
except PermissionError as pe:
|
| 116 |
-
print(f"Permission error for {cache_dir}: {pe}")
|
| 117 |
-
# Try to use a fallback directory in /tmp
|
| 118 |
-
fallback_dir = cache_dir.replace('/app/.cache', '/tmp/hf_cache')
|
| 119 |
-
try:
|
| 120 |
-
os.makedirs(fallback_dir, mode=0o777, exist_ok=True)
|
| 121 |
-
print(f"Using fallback cache directory: {fallback_dir}")
|
| 122 |
-
# Update environment variable to point to fallback
|
| 123 |
-
env_var_map = {
|
| 124 |
-
'/app/.cache/huggingface': 'HF_HOME',
|
| 125 |
-
'/app/.cache/huggingface/hub': 'HF_HUB_CACHE',
|
| 126 |
-
'/app/.cache/transformers': 'TRANSFORMERS_CACHE',
|
| 127 |
-
'/app/.cache/sentence_transformers': 'SENTENCE_TRANSFORMERS_HOME'
|
| 128 |
-
}
|
| 129 |
-
if cache_dir in env_var_map:
|
| 130 |
-
os.environ[env_var_map[cache_dir]] = fallback_dir
|
| 131 |
-
except Exception as fe:
|
| 132 |
-
print(f"Failed to create fallback directory {fallback_dir}: {fe}")
|
| 133 |
-
except Exception as e:
|
| 134 |
-
print(f"Warning: Could not create {cache_dir}: {e}")
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
# Try loading embedding model with error handling and fallbacks
|
| 139 |
try:
|
| 140 |
-
print("Attempting to load embedding model...")
|
| 141 |
EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
| 142 |
-
model_name="sentence-transformers/all-
|
| 143 |
-
model_kwargs={'device': 'cpu'}
|
| 144 |
-
cache_folder=os.environ.get('SENTENCE_TRANSFORMERS_HOME')
|
| 145 |
)
|
| 146 |
-
print("Embedding model loaded successfully.")
|
| 147 |
except Exception as e:
|
| 148 |
-
print(f"
|
| 149 |
-
|
| 150 |
-
try:
|
| 151 |
-
EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
| 152 |
-
model_name="all-miniLM-L6-v2",
|
| 153 |
-
model_kwargs={'device': 'cpu'},
|
| 154 |
-
cache_folder=os.environ.get('SENTENCE_TRANSFORMERS_HOME')
|
| 155 |
-
)
|
| 156 |
-
print("Embedding model with shortened name loaded successfully.")
|
| 157 |
-
except Exception as e2:
|
| 158 |
-
print(f"Failed to load embedding model with shortened name: {e2}")
|
| 159 |
-
print("Trying without cache folder specification...")
|
| 160 |
-
try:
|
| 161 |
-
EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
| 162 |
-
model_name="sentence-transformers/all-miniLM-L6-v2",
|
| 163 |
-
model_kwargs={'device': 'cpu'}
|
| 164 |
-
)
|
| 165 |
-
print("Embedding model loaded without cache folder specification.")
|
| 166 |
-
except Exception as e3:
|
| 167 |
-
print(f"Final attempt failed: {e3}")
|
| 168 |
-
# Use a simpler fallback model or raise the error
|
| 169 |
-
raise Exception(f"Could not load any embedding model. Last error: {e3}")
|
| 170 |
-
|
| 171 |
-
|
| 172 |
|
| 173 |
print("Loading local re-ranking model...")
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
def load_pdf_with_fallback(filepath):
|
| 179 |
-
"""
|
| 180 |
-
Tries to load PDF using multiple methods for better reliability.
|
| 181 |
-
Returns a list of Document objects.
|
| 182 |
-
"""
|
| 183 |
-
# Method 1: Try PyPDFLoader (primary method)
|
| 184 |
try:
|
| 185 |
-
loader = PyPDFLoader(filepath)
|
| 186 |
-
docs = loader.load()
|
| 187 |
-
if docs and any(doc.page_content.strip() for doc in docs):
|
| 188 |
-
print(f"✓ Successfully loaded PDF using PyPDFLoader: {filepath}")
|
| 189 |
-
return docs
|
| 190 |
-
except Exception as e:
|
| 191 |
-
print(f"PyPDFLoader failed for {filepath}: {e}")
|
| 192 |
-
|
| 193 |
-
# Method 2: Try PyMuPDF (fitz) as fallback
|
| 194 |
-
try:
|
| 195 |
-
import fitz
|
| 196 |
docs = []
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
if text.strip(): # Only add pages with content
|
| 204 |
-
doc = Document(
|
| 205 |
-
page_content=text,
|
| 206 |
-
metadata={
|
| 207 |
-
"source": filepath,
|
| 208 |
-
"page": page_num + 1,
|
| 209 |
-
"total_pages": pdf_doc.page_count
|
| 210 |
-
}
|
| 211 |
-
)
|
| 212 |
-
docs.append(doc)
|
| 213 |
-
|
| 214 |
-
pdf_doc.close()
|
| 215 |
-
|
| 216 |
-
if docs:
|
| 217 |
-
print(f"✓ Successfully loaded PDF using PyMuPDF: {filepath}")
|
| 218 |
-
return docs
|
| 219 |
-
|
| 220 |
-
except Exception as e:
|
| 221 |
-
print(f"PyMuPDF failed for {filepath}: {e}")
|
| 222 |
-
|
| 223 |
-
# Method 3: Try to extract text using pdfplumber (if available)
|
| 224 |
-
try:
|
| 225 |
-
import pdfplumber
|
| 226 |
-
docs = []
|
| 227 |
-
|
| 228 |
-
with pdfplumber.open(filepath) as pdf:
|
| 229 |
-
for page_num, page in enumerate(pdf.pages):
|
| 230 |
-
text = page.extract_text()
|
| 231 |
-
if text and text.strip():
|
| 232 |
-
doc = Document(
|
| 233 |
page_content=text,
|
| 234 |
metadata={
|
| 235 |
-
"source": filepath,
|
| 236 |
"page": page_num + 1,
|
| 237 |
-
"total_pages": len(pdf.pages)
|
| 238 |
}
|
| 239 |
-
)
|
| 240 |
-
docs.append(doc)
|
| 241 |
-
|
| 242 |
if docs:
|
| 243 |
-
print(f"✓ Successfully loaded PDF
|
| 244 |
return docs
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
print("pdfplumber not available, skipping method 3")
|
| 248 |
except Exception as e:
|
| 249 |
-
print(f"
|
| 250 |
-
|
| 251 |
-
# If all methods fail, raise the original error
|
| 252 |
-
raise Exception(f"Could not load PDF {filepath} using any available method. The file may be corrupted, password-protected, or in an unsupported format.")
|
| 253 |
|
| 254 |
-
# A dictionary to map file extensions to their corresponding loader classes
|
| 255 |
LOADER_MAPPING = {
|
| 256 |
".txt": TextLoader,
|
| 257 |
-
".pdf": load_pdf_with_fallback,
|
| 258 |
".docx": Docx2txtLoader,
|
| 259 |
}
|
| 260 |
|
| 261 |
def get_session_history(session_id: str) -> ChatMessageHistory:
|
| 262 |
-
"""
|
| 263 |
-
Retrieves the chat history for a given session ID. If it doesn't exist,
|
| 264 |
-
a new history object is created.
|
| 265 |
-
"""
|
| 266 |
if session_id not in message_histories:
|
| 267 |
message_histories[session_id] = ChatMessageHistory()
|
| 268 |
return message_histories[session_id]
|
| 269 |
|
| 270 |
@app.route('/health', methods=['GET'])
|
| 271 |
def health_check():
|
| 272 |
-
|
| 273 |
-
return jsonify({'status': 'healthy', 'message': 'CogniChat is running'}), 200
|
| 274 |
|
| 275 |
@app.route('/', methods=['GET'])
|
| 276 |
def index():
|
| 277 |
-
"""Renders the main page."""
|
| 278 |
return render_template('index.html')
|
| 279 |
|
| 280 |
@app.route('/upload', methods=['POST'])
|
| 281 |
def upload_files():
|
| 282 |
-
"""Handles multiple file uploads, processing, and RAG chain creation."""
|
| 283 |
files = request.files.getlist('file')
|
| 284 |
-
|
| 285 |
if not files or all(f.filename == '' for f in files):
|
| 286 |
return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
|
| 287 |
|
| 288 |
all_docs = []
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
print(f"✓ Successfully saved file: {filepath}")
|
| 305 |
-
except Exception as save_error:
|
| 306 |
-
print(f"✗ Failed to save file {filename}: {save_error}")
|
| 307 |
-
failed_files.append(f"{filename} (failed to save: {save_error})")
|
| 308 |
-
continue
|
| 309 |
-
|
| 310 |
-
file_extension = os.path.splitext(filename)[1].lower()
|
| 311 |
-
if file_extension not in LOADER_MAPPING:
|
| 312 |
-
print(f"Skipping unsupported file type: {filename}")
|
| 313 |
-
failed_files.append(f"{filename} (unsupported format)")
|
| 314 |
-
continue
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
# --- Process all documents together ---
|
| 364 |
-
print(f"Successfully processed {len(processed_files)} files, creating knowledge base...")
|
| 365 |
-
text_splitter = SemanticChunker(EMBEDDING_MODEL, breakpoint_threshold_type="percentile")
|
| 366 |
-
splits = text_splitter.split_documents(all_docs)
|
| 367 |
-
|
| 368 |
-
print("Creating vector store for all documents...")
|
| 369 |
-
vectorstore = FAISS.from_documents(documents=splits, embedding=EMBEDDING_MODEL)
|
| 370 |
-
|
| 371 |
-
bm25_retriever = BM25Retriever.from_documents(splits)
|
| 372 |
-
bm25_retriever.k = 7
|
| 373 |
-
faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 7})
|
| 374 |
ensemble_retriever = EnsembleRetriever(
|
| 375 |
retrievers=[bm25_retriever, faiss_retriever],
|
| 376 |
-
weights=[0.
|
| 377 |
)
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
compression_retriever = ContextualCompressionRetriever(
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
|
| 386 |
-
|
| 387 |
-
print(f"Creating RAG chain for session {session_id}...")
|
| 388 |
|
| 389 |
-
|
| 390 |
-
rag_chain = create_rag_chain(compression_retriever, get_session_history)
|
| 391 |
-
rag_chains[session_id] = rag_chain
|
| 392 |
-
print(f"✓ RAG chain created successfully for session {session_id} with {len(processed_files)} documents.")
|
| 393 |
-
except Exception as rag_error:
|
| 394 |
-
print(f"✗ Failed to create RAG chain: {rag_error}")
|
| 395 |
-
import traceback
|
| 396 |
-
traceback.print_exc()
|
| 397 |
-
raise rag_error
|
| 398 |
|
|
|
|
|
|
|
|
|
|
| 399 |
session['session_id'] = session_id
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
# Prepare success message with file processing summary
|
| 403 |
success_msg = f"Successfully processed: {', '.join(processed_files)}"
|
| 404 |
if failed_files:
|
| 405 |
-
success_msg += f"\
|
| 406 |
-
|
| 407 |
-
success_msg += f"\n...and {len(failed_files) - 3} more"
|
| 408 |
-
|
| 409 |
return jsonify({
|
| 410 |
'status': 'success',
|
| 411 |
'filename': success_msg,
|
| 412 |
-
'
|
| 413 |
-
'failed_count': len(failed_files),
|
| 414 |
-
'session_id': session_id # Include session_id in response for client to store
|
| 415 |
})
|
| 416 |
|
| 417 |
except Exception as e:
|
| 418 |
-
|
| 419 |
-
|
|
|
|
| 420 |
|
| 421 |
@app.route('/chat', methods=['POST'])
|
| 422 |
def chat():
|
| 423 |
-
"""Handles chat messages and streams the response with memory."""
|
| 424 |
data = request.get_json()
|
| 425 |
question = data.get('question')
|
| 426 |
-
|
| 427 |
-
# Try to get session_id from multiple sources (Flask session or request body)
|
| 428 |
session_id = session.get('session_id') or data.get('session_id')
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
print(f"Question: '{question}' (type: {type(question)})")
|
| 433 |
-
print(f"Session ID from Flask session: {session.get('session_id')}")
|
| 434 |
-
print(f"Session ID from request body: {data.get('session_id')}")
|
| 435 |
-
print(f"Final session ID used: {session_id}")
|
| 436 |
-
print(f"Available RAG chains: {list(rag_chains.keys())}")
|
| 437 |
-
print(f"Total RAG chains: {len(rag_chains)}")
|
| 438 |
-
|
| 439 |
-
# Check each condition individually with detailed logging
|
| 440 |
-
if not question:
|
| 441 |
-
print("✗ FAILURE: No question provided")
|
| 442 |
-
return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
|
| 443 |
-
else:
|
| 444 |
-
print("✓ Question is valid")
|
| 445 |
-
|
| 446 |
-
if not session_id:
|
| 447 |
-
print("✗ FAILURE: No session ID in Flask session")
|
| 448 |
-
return jsonify({'status': 'error', 'message': 'No session found. Please upload documents first.'}), 400
|
| 449 |
-
else:
|
| 450 |
-
print(f"✓ Session ID found: {session_id}")
|
| 451 |
-
|
| 452 |
-
if session_id not in rag_chains:
|
| 453 |
-
print(f"✗ FAILURE: Session ID {session_id} not found in RAG chains")
|
| 454 |
-
print(f"Available sessions: {list(rag_chains.keys())}")
|
| 455 |
-
return jsonify({'status': 'error', 'message': 'Session not found. Please upload documents again.'}), 400
|
| 456 |
-
else:
|
| 457 |
-
print(f"✓ RAG chain found for session: {session_id}")
|
| 458 |
|
| 459 |
try:
|
| 460 |
rag_chain = rag_chains[session_id]
|
| 461 |
config = {"configurable": {"session_id": session_id}}
|
| 462 |
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
# Create the input dictionary with question and config
|
| 466 |
-
input_data = {
|
| 467 |
-
"question": question,
|
| 468 |
-
"config": config
|
| 469 |
-
}
|
| 470 |
-
|
| 471 |
-
try:
|
| 472 |
-
# For streaming, we need to handle the chain differently
|
| 473 |
-
# The final chain doesn't support streaming, so we'll invoke it directly
|
| 474 |
-
response = rag_chain.invoke(input_data)
|
| 475 |
-
yield response
|
| 476 |
-
except Exception as stream_error:
|
| 477 |
-
print(f"Streaming error: {stream_error}")
|
| 478 |
-
# Fallback to direct invocation
|
| 479 |
-
try:
|
| 480 |
-
response = rag_chain.invoke(input_data)
|
| 481 |
-
yield response
|
| 482 |
-
except Exception as invoke_error:
|
| 483 |
-
print(f"Invoke error: {invoke_error}")
|
| 484 |
-
yield "I apologize, but I'm having trouble processing your question. Please try again or upload your documents again."
|
| 485 |
|
| 486 |
-
|
|
|
|
|
|
|
| 487 |
|
| 488 |
except Exception as e:
|
| 489 |
print(f"Error during chat invocation: {e}")
|
| 490 |
-
return
|
| 491 |
-
|
| 492 |
|
| 493 |
def clean_markdown_for_tts(text: str) -> str:
    """Strip common Markdown syntax so the text reads cleanly as speech.

    Line-level constructs (horizontal rules, headings, list markers and
    blockquote prefixes) are removed *before* inline emphasis. Doing it the
    other way round lets a leading ``*`` list bullet be consumed as the
    opening marker of an italic span, and reduces a ``***`` / ``___`` rule
    to a stray ``*`` / ``_`` that is then read aloud.

    Args:
        text: Markdown-formatted text.

    Returns:
        The text with the handled Markdown markers removed, every run of
        newlines collapsed to a single space, and surrounding whitespace
        stripped.
    """
    # Horizontal rules made of three or more '-', '*' or '_' on their own line.
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
    # Heading prefixes (#, ##, ... ######).
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
    # Bullet list markers (* or -) and numbered list markers (1.).
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    # Blockquote prefix (>).
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)

    # Inline emphasis: **bold** / *italic*, then _italic_.
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
    text = re.sub(r'\_(.*?)\_', r'\1', text)
    # Inline code (`code`).
    text = re.sub(r'`(.*?)`', r'\1', text)

    # Speech has no line breaks: collapse each run of newlines to one space.
    text = re.sub(r'\n+', ' ', text)
    return text.strip()
|
| 515 |
|
| 516 |
@app.route('/tts', methods=['POST'])
|
| 517 |
def text_to_speech():
|
| 518 |
-
"""Generates audio from text and returns it as an MP3 stream."""
|
| 519 |
data = request.get_json()
|
| 520 |
text = data.get('text')
|
| 521 |
|
|
@@ -523,9 +292,7 @@ def text_to_speech():
|
|
| 523 |
return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
|
| 524 |
|
| 525 |
try:
|
| 526 |
-
# --- FIX IS HERE: Clean the text before sending to gTTS ---
|
| 527 |
clean_text = clean_markdown_for_tts(text)
|
| 528 |
-
|
| 529 |
tts = gTTS(clean_text, lang='en')
|
| 530 |
mp3_fp = io.BytesIO()
|
| 531 |
tts.write_to_fp(mp3_fp)
|
|
@@ -535,79 +302,6 @@ def text_to_speech():
|
|
| 535 |
print(f"Error in TTS generation: {e}")
|
| 536 |
return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
|
| 537 |
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
@app.route('/debug', methods=['GET'])
|
| 541 |
-
def debug_info():
|
| 542 |
-
"""Debug endpoint to check configuration."""
|
| 543 |
-
api_key = os.getenv("GROQ_API_KEY")
|
| 544 |
-
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
|
| 545 |
-
|
| 546 |
-
return jsonify({
|
| 547 |
-
'environment': 'Hugging Face Spaces' if is_hf_spaces else 'Local/Other',
|
| 548 |
-
'upload_folder': app.config['UPLOAD_FOLDER'],
|
| 549 |
-
'upload_folder_writable': os.access(app.config['UPLOAD_FOLDER'], os.W_OK),
|
| 550 |
-
'groq_api_key_set': bool(api_key and api_key != "your_groq_api_key_here"),
|
| 551 |
-
'groq_api_key_placeholder': api_key == "your_groq_api_key_here" if api_key else False,
|
| 552 |
-
'sessions_count': len(rag_chains),
|
| 553 |
-
'session_ids': list(rag_chains.keys()),
|
| 554 |
-
'flask_session_id': session.get('session_id'),
|
| 555 |
-
'flask_session_keys': list(session.keys()) if session else [],
|
| 556 |
-
'embedding_model_loaded': 'EMBEDDING_MODEL' in globals(),
|
| 557 |
-
'space_id': os.getenv("SPACE_ID"),
|
| 558 |
-
'hf_spaces_detected': is_hf_spaces
|
| 559 |
-
})
|
| 560 |
-
|
| 561 |
-
@app.route('/test-session', methods=['GET', 'POST'])
|
| 562 |
-
def test_session():
|
| 563 |
-
"""Test endpoint to debug Flask session issues."""
|
| 564 |
-
if request.method == 'POST':
|
| 565 |
-
# Set a test session value
|
| 566 |
-
session['test_key'] = 'test_value'
|
| 567 |
-
session['timestamp'] = str(time.time())
|
| 568 |
-
return jsonify({
|
| 569 |
-
'action': 'session_set',
|
| 570 |
-
'test_key': session.get('test_key'),
|
| 571 |
-
'timestamp': session.get('timestamp'),
|
| 572 |
-
'session_keys': list(session.keys())
|
| 573 |
-
})
|
| 574 |
-
else:
|
| 575 |
-
# Read session values
|
| 576 |
-
return jsonify({
|
| 577 |
-
'action': 'session_read',
|
| 578 |
-
'test_key': session.get('test_key'),
|
| 579 |
-
'timestamp': session.get('timestamp'),
|
| 580 |
-
'session_id': session.get('session_id'),
|
| 581 |
-
'session_keys': list(session.keys()),
|
| 582 |
-
'has_session_data': bool(session)
|
| 583 |
-
})
|
| 584 |
-
|
| 585 |
if __name__ == '__main__':
|
| 586 |
-
print(f"\n=== Application Startup Complete at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
|
| 587 |
-
|
| 588 |
-
# Check environment
|
| 589 |
-
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
|
| 590 |
-
api_key = os.getenv('GROQ_API_KEY')
|
| 591 |
-
|
| 592 |
-
print(f"Environment: {'Hugging Face Spaces' if is_hf_spaces else 'Local Development'}")
|
| 593 |
-
print(f"Upload folder: {app.config['UPLOAD_FOLDER']}")
|
| 594 |
-
|
| 595 |
-
if is_hf_spaces:
|
| 596 |
-
print(f"Space ID: {os.getenv('SPACE_ID', 'Not set')}")
|
| 597 |
-
print(f"GROQ API Key: {'Set via HF Secrets' if api_key and api_key != 'your_groq_api_key_here' else 'NOT SET - Add to Space Secrets'}")
|
| 598 |
-
else:
|
| 599 |
-
print(f"GROQ API Key: {'Set' if api_key and api_key != 'your_groq_api_key_here' else 'NOT SET'}")
|
| 600 |
-
|
| 601 |
-
print(f"Embedding model: {'Loaded' if 'EMBEDDING_MODEL' in globals() else 'NOT LOADED'}")
|
| 602 |
-
|
| 603 |
-
if not api_key or api_key == 'your_groq_api_key_here':
|
| 604 |
-
print("\n⚠️ WARNING: GROQ API KEY NOT CONFIGURED!")
|
| 605 |
-
if is_hf_spaces:
|
| 606 |
-
print(" → Go to your Space Settings > Repository Secrets")
|
| 607 |
-
print(" → Add GROQ_API_KEY as a secret")
|
| 608 |
-
else:
|
| 609 |
-
print(" → Update .env file with your GROQ API key")
|
| 610 |
-
|
| 611 |
-
# Use port 7860 for Hugging Face Spaces, fallback to 5001 for local development
|
| 612 |
port = int(os.environ.get("PORT", 7860))
|
| 613 |
-
app.run(host="0.0.0.0", port=port, debug=False)
|
|
|
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
import uuid
|
| 6 |
+
from flask import Flask, request, render_template, session, jsonify, Response
|
| 7 |
from werkzeug.utils import secure_filename
|
| 8 |
from rag_processor import create_rag_chain
|
| 9 |
+
from typing import Sequence, Any, List
|
| 10 |
+
import fitz
|
|
|
|
|
|
|
|
|
|
| 11 |
import re
|
| 12 |
+
import io
|
| 13 |
+
from gtts import gTTS
|
| 14 |
|
| 15 |
+
from langchain_core.documents import Document
|
| 16 |
from langchain_community.document_loaders import (
|
| 17 |
TextLoader,
|
|
|
|
| 18 |
Docx2txtLoader,
|
| 19 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 21 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 22 |
from langchain_community.vectorstores import FAISS
|
|
|
|
| 24 |
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
| 25 |
from langchain_community.retrievers import BM25Retriever
|
| 26 |
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 27 |
+
from langchain.storage import InMemoryStore
|
| 28 |
from sentence_transformers.cross_encoder import CrossEncoder
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
app = Flask(__name__)
|
|
|
|
| 34 |
|
| 35 |
class LocalReranker(BaseDocumentCompressor):
|
| 36 |
model: Any
|
| 37 |
+
top_n: int = 3
|
| 38 |
|
| 39 |
class Config:
|
| 40 |
arbitrary_types_allowed = True
|
|
|
|
| 54 |
doc_scores = list(zip(documents, scores))
|
| 55 |
sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
|
| 56 |
|
| 57 |
+
top_docs = []
|
| 58 |
+
for doc, score in sorted_doc_scores[:self.top_n]:
|
| 59 |
+
doc.metadata['rerank_score'] = float(score)
|
| 60 |
+
top_docs.append(doc)
|
| 61 |
+
return top_docs
|
| 62 |
|
| 63 |
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
|
| 64 |
if is_hf_spaces:
|
|
|
|
| 66 |
else:
|
| 67 |
app.config['UPLOAD_FOLDER'] = 'uploads'
|
| 68 |
|
|
|
|
| 69 |
try:
|
| 70 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 71 |
print(f"✓ Upload folder ready: {app.config['UPLOAD_FOLDER']}")
|
| 72 |
except Exception as e:
|
| 73 |
print(f"✗ Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
|
|
|
|
| 74 |
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
|
| 75 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 76 |
print(f"✓ Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# In-memory, per-process registries keyed by session id: the RAG chain built
# at upload time and the chat history handed out by get_session_history().
rag_chains = {}
message_histories = {}

# Both models are loaded once at import time; a failure is logged and then
# re-raised so the application does not start without them.
print("Loading embedding model...")
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
except Exception as e:
    print(f"✗ FATAL: Could not load embedding model. Error: {e}")
    raise
else:
    print("✓ Embedding model loaded successfully.")

print("Loading local re-ranking model...")
try:
    RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
except Exception as e:
    print(f"✗ FATAL: Could not load reranker model. Error: {e}")
    raise
else:
    print("✓ Re-ranking model loaded successfully.")
|
| 99 |
|
| 100 |
def load_pdf_with_fallback(filepath):
    """Extract the text of a PDF into one ``Document`` per non-empty page.

    Despite the name, the only extraction method used here is PyMuPDF
    (``fitz``); there is no secondary loader.

    Args:
        filepath: Path of the PDF file to read.

    Returns:
        A list of ``Document`` objects. Each carries the page text plus
        metadata ``source`` (the file's base name) and ``page`` (1-based
        page number). Pages whose text is empty or whitespace are skipped.

    Raises:
        ValueError: If no page yields any text.
        Exception: Any error raised while opening or reading the file is
            logged and re-raised unchanged.
    """
    try:
        with fitz.open(filepath) as pdf_doc:
            pages = []
            for page_number, page in enumerate(pdf_doc, start=1):
                page_text = page.get_text()
                if not page_text.strip():
                    continue
                pages.append(Document(
                    page_content=page_text,
                    metadata={
                        "source": os.path.basename(filepath),
                        "page": page_number,
                    },
                ))
        if not pages:
            raise ValueError("No text content found in PDF.")
        print(f"✓ Successfully loaded PDF with PyMuPDF: {filepath}")
        return pages
    except Exception as err:
        print(f"✗ PyMuPDF failed for {filepath}: {err}")
        raise
|
|
|
|
|
|
|
| 122 |
|
|
|
|
| 123 |
# Maps a lower-cased file extension to the callable that loads it.
# Note the entries are not uniform: the ".pdf" entry is a plain function that
# returns documents directly, while the others are loader classes whose
# instances expose .load() (upload_files special-cases ".pdf" accordingly).
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}
|
| 128 |
|
| 129 |
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories live in the module-level ``message_histories`` dict, so the
    same object is returned on every later call with the same id.
    """
    try:
        return message_histories[session_id]
    except KeyError:
        history = ChatMessageHistory()
        message_histories[session_id] = history
        return history
|
| 133 |
|
| 134 |
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reply with a fixed JSON body and HTTP 200."""
    body = {'status': 'healthy'}
    return jsonify(body), 200
|
|
|
|
| 137 |
|
| 138 |
@app.route('/', methods=['GET'])
def index():
    """Serve the main page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
| 141 |
|
| 142 |
@app.route('/upload', methods=['POST'])
|
| 143 |
def upload_files():
|
|
|
|
| 144 |
files = request.files.getlist('file')
|
|
|
|
| 145 |
if not files or all(f.filename == '' for f in files):
|
| 146 |
return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
|
| 147 |
|
| 148 |
all_docs = []
|
| 149 |
+
processed_files, failed_files = [], []
|
| 150 |
+
|
| 151 |
+
for file in files:
|
| 152 |
+
if file and file.filename:
|
| 153 |
+
filename = secure_filename(file.filename)
|
| 154 |
+
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 155 |
+
try:
|
| 156 |
+
file.save(filepath)
|
| 157 |
+
file_ext = os.path.splitext(filename)[1].lower()
|
| 158 |
+
if file_ext not in LOADER_MAPPING:
|
| 159 |
+
raise ValueError("Unsupported file format.")
|
|
|
|
| 160 |
|
| 161 |
+
loader_func = LOADER_MAPPING[file_ext]
|
| 162 |
+
docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
if not docs:
|
| 165 |
+
raise ValueError("No content extracted.")
|
| 166 |
+
|
| 167 |
+
all_docs.extend(docs)
|
| 168 |
+
processed_files.append(filename)
|
| 169 |
+
print(f"✓ Successfully processed: {filename}")
|
| 170 |
+
except Exception as e:
|
| 171 |
+
error_msg = str(e)
|
| 172 |
+
print(f"✗ Error processing {filename}: {error_msg}")
|
| 173 |
+
failed_files.append(f"{filename} ({error_msg})")
|
| 174 |
+
|
| 175 |
+
if not all_docs:
|
| 176 |
+
error_summary = "Failed to process all files."
|
| 177 |
+
if failed_files:
|
| 178 |
+
error_summary += " Reasons: " + ", ".join(failed_files)
|
| 179 |
+
return jsonify({'status': 'error', 'message': error_summary}), 400
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
print("Starting RAG pipeline setup...")
|
| 183 |
+
|
| 184 |
+
vectorstore = FAISS.from_texts([""], EMBEDDING_MODEL)
|
| 185 |
+
store = InMemoryStore()
|
| 186 |
+
|
| 187 |
+
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
|
| 188 |
+
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
|
| 189 |
+
|
| 190 |
+
parent_docs = parent_splitter.split_documents(all_docs)
|
| 191 |
+
doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
|
| 192 |
+
|
| 193 |
+
child_docs = []
|
| 194 |
+
for i, doc in enumerate(parent_docs):
|
| 195 |
+
_id = doc_ids[i]
|
| 196 |
+
sub_docs = child_splitter.split_documents([doc])
|
| 197 |
+
for child in sub_docs:
|
| 198 |
+
child.metadata["doc_id"] = _id
|
| 199 |
+
child_docs.extend(sub_docs)
|
| 200 |
+
|
| 201 |
+
store.mset(list(zip(doc_ids, parent_docs)))
|
| 202 |
+
vectorstore.add_documents(child_docs)
|
| 203 |
+
print(f"✓ Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
|
| 204 |
+
|
| 205 |
+
bm25_retriever = BM25Retriever.from_documents(child_docs)
|
| 206 |
+
bm25_retriever.k = 8
|
| 207 |
+
|
| 208 |
+
faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 8})
|
| 209 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
ensemble_retriever = EnsembleRetriever(
|
| 211 |
retrievers=[bm25_retriever, faiss_retriever],
|
| 212 |
+
weights=[0.4, 0.6]
|
| 213 |
)
|
| 214 |
+
print("✓ Created Hybrid Retriever for child documents.")
|
| 215 |
+
|
| 216 |
+
reranker = LocalReranker(model=RERANKER_MODEL, top_n=4)
|
| 217 |
+
|
| 218 |
+
def get_parents(docs: List[Document]) -> List[Document]:
    """Map reranked child chunks to their parent documents.

    Args:
        docs: Child documents produced by the retriever/reranker. Each is
            expected to carry the parent's id under ``metadata["doc_id"]``.

    Returns:
        The parent documents looked up in ``store``, de-duplicated and in
        the order in which their children first appear in ``docs`` (so the
        reranker's ordering is preserved).
    """
    ordered_ids = []
    seen = set()
    for child in docs:
        # The vector store is seeded with an empty text that has no
        # metadata, so a retrieved document may lack "doc_id"; skip those
        # instead of raising KeyError.
        parent_id = child.metadata.get("doc_id")
        if parent_id is None or parent_id in seen:
            continue
        seen.add(parent_id)
        ordered_ids.append(parent_id)

    if not ordered_ids:
        return []

    # mget yields None for ids that are not in the store; drop them so only
    # real documents reach the prompt context.
    return [parent for parent in store.mget(ordered_ids) if parent is not None]
|
| 221 |
|
| 222 |
compression_retriever = ContextualCompressionRetriever(
|
| 223 |
+
base_compressor=reranker, base_retriever=ensemble_retriever
|
| 224 |
+
)
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
final_retriever = compression_retriever | get_parents
|
|
|
|
| 227 |
|
| 228 |
+
print("✓ Final retriever chain created: (Hybrid -> Rerank) -> Parent Fetch")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
+
session_id = str(uuid.uuid4())
|
| 231 |
+
rag_chain = create_rag_chain(final_retriever, get_session_history)
|
| 232 |
+
rag_chains[session_id] = rag_chain
|
| 233 |
session['session_id'] = session_id
|
| 234 |
+
|
|
|
|
|
|
|
| 235 |
success_msg = f"Successfully processed: {', '.join(processed_files)}"
|
| 236 |
if failed_files:
|
| 237 |
+
success_msg += f"\nFailed to process: {', '.join(failed_files)}"
|
| 238 |
+
|
|
|
|
|
|
|
| 239 |
return jsonify({
|
| 240 |
'status': 'success',
|
| 241 |
'filename': success_msg,
|
| 242 |
+
'session_id': session_id
|
|
|
|
|
|
|
| 243 |
})
|
| 244 |
|
| 245 |
except Exception as e:
|
| 246 |
+
import traceback
|
| 247 |
+
traceback.print_exc()
|
| 248 |
+
return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
|
| 249 |
|
| 250 |
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question using the RAG chain bound to the caller's session.

    Expects a JSON body with a ``question`` field and, optionally, a
    ``session_id`` field that is used when the Flask session cookie does not
    carry one.

    Returns:
        ``{"answer": ...}`` on success; a JSON error payload with status 400
        when the question or session is missing/unknown, or 500 when the
        chain invocation fails.
    """
    # silent=True: a missing, malformed, or non-JSON body yields None instead
    # of raising, so the request falls through to the 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    question = data.get('question')
    session_id = session.get('session_id') or data.get('session_id')

    if not question or not session_id or session_id not in rag_chains:
        return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400

    try:
        rag_chain = rag_chains[session_id]
        config = {"configurable": {"session_id": session_id}}

        # The chain's result is placed directly in the response; it is
        # expected to be the answer text (string).
        answer_string = rag_chain.invoke({"question": question}, config=config)
        return jsonify({'answer': answer_string})

    except Exception as e:
        # Top-level boundary: log the failure and return a generic error so
        # internal details are not leaked to the client.
        print(f"Error during chat invocation: {e}")
        return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
|
|
|
|
| 273 |
|
| 274 |
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown markup from *text* so it can be read aloud.

    Block-level markers (horizontal rules, headings, list bullets, numbered
    items, block quotes) are removed first, then inline markers (emphasis,
    inline code). Finally all line breaks are collapsed to single spaces.

    Args:
        text: Markdown-formatted text.

    Returns:
        The text with Markdown markup removed, on a single line, with
        leading/trailing whitespace stripped.
    """
    # --- Block-level markers (must run before inline emphasis handling) ---
    # Horizontal rules go first so '***' / '___' are not half-consumed by the
    # emphasis patterns below.
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
    # Bullets are stripped before emphasis; otherwise the bullet '*' in
    # "* **Bold** item" pairs up with the emphasis asterisks and leaves a
    # stray '*' behind.
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)

    # --- Inline markers ---
    # *italic*, **bold**, ***both***: the closing run must match the opening
    # run, and the emphasised text must not start or end with whitespace.
    text = re.sub(r'(\*{1,3})(?=\S)(.+?)(?<=\S)\1', r'\2', text)
    # Underscore emphasis only counts at word boundaries, so identifiers such
    # as my_var_name keep their underscores.
    text = re.sub(r'(?<!\w)(_{1,3})(?=\S)(.+?)(?<=\S)\1(?!\w)', r'\2', text)
    text = re.sub(r'`(.*?)`', r'\1', text)

    # Collapse line breaks so the speech engine receives one continuous line.
    text = re.sub(r'\n+', ' ', text)
    return text.strip()
|
| 285 |
|
| 286 |
@app.route('/tts', methods=['POST'])
|
| 287 |
def text_to_speech():
|
|
|
|
| 288 |
data = request.get_json()
|
| 289 |
text = data.get('text')
|
| 290 |
|
|
|
|
| 292 |
return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
|
| 293 |
|
| 294 |
try:
|
|
|
|
| 295 |
clean_text = clean_markdown_for_tts(text)
|
|
|
|
| 296 |
tts = gTTS(clean_text, lang='en')
|
| 297 |
mp3_fp = io.BytesIO()
|
| 298 |
tts.write_to_fp(mp3_fp)
|
|
|
|
| 302 |
print(f"Error in TTS generation: {e}")
|
| 303 |
return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
if __name__ == '__main__':
    # Script entry point: listen on all interfaces, taking the port from the
    # PORT environment variable and falling back to 7860 when it is unset.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )
|
rag_processor.py
CHANGED
|
@@ -35,9 +35,11 @@ def create_rag_chain(retriever, get_session_history_func):
|
|
| 35 |
|
| 36 |
# Get API key from environment (works for both HF Spaces secrets and local .env)
|
| 37 |
api_key = os.getenv("GROQ_API_KEY")
|
|
|
|
| 38 |
|
| 39 |
if not api_key or api_key == "your_groq_api_key_here":
|
| 40 |
error_msg = "GROQ_API_KEY not found or not configured properly.\n"
|
|
|
|
| 41 |
|
| 42 |
# Detect if running in Hugging Face Spaces
|
| 43 |
if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
|
|
@@ -56,7 +58,7 @@ def create_rag_chain(retriever, get_session_history_func):
|
|
| 56 |
|
| 57 |
# --- 1. Initialize the LLM ---
|
| 58 |
# Updated model_name to a standard, high-performance Groq model
|
| 59 |
-
llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=1)
|
| 60 |
|
| 61 |
# --- 2. Create Query Rewriting Chain 🧠 ---
|
| 62 |
print("\nSetting up query rewriting chain...")
|
|
@@ -80,14 +82,12 @@ Standalone Question:"""
|
|
| 80 |
|
| 81 |
# --- 3. Create Main RAG Chain with Memory ---
|
| 82 |
print("\nSetting up main RAG chain...")
|
| 83 |
-
rag_template = """You are an expert
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
**Crucially, do not state that the information is missing from the context.** Assume the provided text contains the necessary information and do your best to construct the answer from it. If the context is truly insufficient, answer the question as best you can with the information available without mentioning the context's limitations.
|
| 90 |
-
|
| 91 |
Context:
|
| 92 |
{context}"""
|
| 93 |
rag_prompt = ChatPromptTemplate.from_messages([
|
|
@@ -100,10 +100,11 @@ Context:
|
|
| 100 |
# Parallel process to fetch context and correctly pass through question and history.
|
| 101 |
# We use itemgetter to select the specific keys from the input dictionary.
|
| 102 |
setup_and_retrieval = RunnableParallel({
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
| 107 |
# =====================================================================
|
| 108 |
|
| 109 |
# The initial RAG chain
|
|
|
|
| 35 |
|
| 36 |
# Get API key from environment (works for both HF Spaces secrets and local .env)
|
| 37 |
api_key = os.getenv("GROQ_API_KEY")
|
| 38 |
+
print("key loaded")
|
| 39 |
|
| 40 |
if not api_key or api_key == "your_groq_api_key_here":
|
| 41 |
error_msg = "GROQ_API_KEY not found or not configured properly.\n"
|
| 42 |
+
print("Not found key")
|
| 43 |
|
| 44 |
# Detect if running in Hugging Face Spaces
|
| 45 |
if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
|
|
|
|
| 58 |
|
| 59 |
# --- 1. Initialize the LLM ---
|
| 60 |
# Updated model_name to a standard, high-performance Groq model
|
| 61 |
+
llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.1)
|
| 62 |
|
| 63 |
# --- 2. Create Query Rewriting Chain 🧠 ---
|
| 64 |
print("\nSetting up query rewriting chain...")
|
|
|
|
| 82 |
|
| 83 |
# --- 3. Create Main RAG Chain with Memory ---
|
| 84 |
print("\nSetting up main RAG chain...")
|
| 85 |
+
rag_template = """You are CogniChat, an expert document analysis assistant.
|
| 86 |
+
IMPORTANT RULES:
|
| 87 |
+
1. ONLY use information from the provided context
|
| 88 |
+
2. If the answer isn't in the context, say "I cannot find this information in the uploaded documents"
|
| 89 |
+
3. Cite specific sections when possible (e.g., "According to Section 2.3...")
|
| 90 |
+
4. Be concise but comprehensive
|
|
|
|
|
|
|
| 91 |
Context:
|
| 92 |
{context}"""
|
| 93 |
rag_prompt = ChatPromptTemplate.from_messages([
|
|
|
|
| 100 |
# Parallel process to fetch context and correctly pass through question and history.
|
| 101 |
# We use itemgetter to select the specific keys from the input dictionary.
|
| 102 |
setup_and_retrieval = RunnableParallel({
|
| 103 |
+
"context": RunnablePassthrough.assign(
|
| 104 |
+
rewritten_question=query_rewriter
|
| 105 |
+
) | (lambda x: x["rewritten_question"]) | retriever,
|
| 106 |
+
"question": itemgetter("question"),
|
| 107 |
+
"chat_history": itemgetter("chat_history")})
|
| 108 |
# =====================================================================
|
| 109 |
|
| 110 |
# The initial RAG chain
|
templates/index.html
CHANGED
|
@@ -98,17 +98,122 @@
|
|
| 98 |
40% { transform: scale(1.0); }
|
| 99 |
}
|
| 100 |
|
| 101 |
-
/* Markdown Styling */
|
| 102 |
-
.markdown-content p {
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
.markdown-content
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
/* Spinner for the TTS button */
|
| 114 |
.tts-button-loader {
|
|
@@ -181,7 +286,6 @@
|
|
| 181 |
const chatContent = document.getElementById('chat-content');
|
| 182 |
const chatFilename = document.getElementById('chat-filename');
|
| 183 |
|
| 184 |
-
// Persist the active Flask session between requests
|
| 185 |
let sessionId = null;
|
| 186 |
const storedSessionId = sessionStorage.getItem('cognichat_session_id');
|
| 187 |
if (storedSessionId) {
|
|
@@ -257,7 +361,7 @@
|
|
| 257 |
loadingOverlay.classList.add('hidden');
|
| 258 |
loadingSubtext.textContent = '';
|
| 259 |
fileNameSpan.textContent = '';
|
| 260 |
-
fileUploadInput.value = '';
|
| 261 |
}
|
| 262 |
}
|
| 263 |
|
|
@@ -289,26 +393,28 @@
|
|
| 289 |
});
|
| 290 |
|
| 291 |
if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
typingIndicator.remove();
|
| 294 |
botMessageContainer = appendMessage('', 'bot');
|
| 295 |
contentDiv = botMessageContainer.querySelector('.markdown-content');
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
let fullResponse = '';
|
| 300 |
-
|
| 301 |
-
while (true) {
|
| 302 |
-
const { value, done } = await reader.read();
|
| 303 |
-
if (done) break;
|
| 304 |
-
|
| 305 |
-
fullResponse += decoder.decode(value, { stream: true });
|
| 306 |
-
contentDiv.innerHTML = marked.parse(fullResponse);
|
| 307 |
-
scrollToBottom();
|
| 308 |
-
}
|
| 309 |
contentDiv.querySelectorAll('pre').forEach(addCopyButton);
|
|
|
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
|
| 313 |
} catch (error) {
|
| 314 |
console.error('Chat error:', error);
|
|
@@ -344,7 +450,10 @@
|
|
| 344 |
|
| 345 |
const contentDiv = document.createElement('div');
|
| 346 |
contentDiv.className = 'text-base markdown-content';
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
const controlsContainer = document.createElement('div');
|
| 350 |
controlsContainer.className = 'tts-controls mt-2';
|
|
@@ -412,7 +521,7 @@
|
|
| 412 |
});
|
| 413 |
}
|
| 414 |
|
| 415 |
-
//
|
| 416 |
let currentAudio = null;
|
| 417 |
let currentPlayingButton = null;
|
| 418 |
|
|
@@ -424,7 +533,6 @@
|
|
| 424 |
const ttsControls = messageBubble.querySelector('.tts-controls');
|
| 425 |
if (text.trim().length > 0) {
|
| 426 |
const speakButton = document.createElement('button');
|
| 427 |
-
// STYLING CHANGE HERE: Replaced theme variables with specific dark blue colors.
|
| 428 |
speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
|
| 429 |
speakButton.title = 'Listen to this message';
|
| 430 |
speakButton.setAttribute('data-state', 'play');
|
|
@@ -435,24 +543,19 @@
|
|
| 435 |
}
|
| 436 |
|
| 437 |
async function handleTTS(text, button) {
|
| 438 |
-
// BUG FIX: Reworked the logic to correctly handle pause/resume.
|
| 439 |
-
|
| 440 |
-
// Case 1: The clicked button is already active (playing or paused).
|
| 441 |
if (button === currentPlayingButton) {
|
| 442 |
-
if (currentAudio && !currentAudio.paused) {
|
| 443 |
currentAudio.pause();
|
| 444 |
button.setAttribute('data-state', 'paused');
|
| 445 |
button.innerHTML = `${playIconSVG} <span>Play</span>`;
|
| 446 |
-
} else if (currentAudio && currentAudio.paused) {
|
| 447 |
currentAudio.play();
|
| 448 |
button.setAttribute('data-state', 'playing');
|
| 449 |
button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
|
| 450 |
}
|
| 451 |
-
return;
|
| 452 |
}
|
| 453 |
|
| 454 |
-
// Case 2: A new button is clicked (or no audio is active).
|
| 455 |
-
// Stop any other audio that might be playing.
|
| 456 |
resetAllSpeakButtons();
|
| 457 |
|
| 458 |
currentPlayingButton = button;
|
|
@@ -488,7 +591,7 @@
|
|
| 488 |
button.setAttribute('data-state', 'error');
|
| 489 |
button.innerHTML = `${playIconSVG} <span>Error</span>`;
|
| 490 |
alert('Failed to play audio. Please try again.');
|
| 491 |
-
resetAllSpeakButtons();
|
| 492 |
} finally {
|
| 493 |
button.disabled = false;
|
| 494 |
}
|
|
@@ -506,8 +609,7 @@
|
|
| 506 |
}
|
| 507 |
currentPlayingButton = null;
|
| 508 |
}
|
| 509 |
-
// ============================ MODIFICATIONS END ==============================
|
| 510 |
});
|
| 511 |
</script>
|
| 512 |
</body>
|
| 513 |
-
</html
|
|
|
|
| 98 |
40% { transform: scale(1.0); }
|
| 99 |
}
|
| 100 |
|
| 101 |
+
/* Enhanced Markdown Styling for better readability and aesthetics */
|
| 102 |
+
.markdown-content p {
|
| 103 |
+
margin-bottom: 1rem;
|
| 104 |
+
line-height: 1.75;
|
| 105 |
+
}
|
| 106 |
+
.markdown-content h1, .markdown-content h2, .markdown-content h3, .markdown-content h4 {
|
| 107 |
+
font-family: 'Google Sans', sans-serif;
|
| 108 |
+
font-weight: 700;
|
| 109 |
+
margin-top: 1.75rem;
|
| 110 |
+
margin-bottom: 1rem;
|
| 111 |
+
line-height: 1.3;
|
| 112 |
+
}
|
| 113 |
+
.markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
|
| 114 |
+
.markdown-content h2 { font-size: 1.5em; }
|
| 115 |
+
.markdown-content h3 { font-size: 1.25em; }
|
| 116 |
+
.markdown-content h4 { font-size: 1.1em; }
|
| 117 |
+
.markdown-content ul, .markdown-content ol {
|
| 118 |
+
padding-left: 1.75rem;
|
| 119 |
+
margin-bottom: 1rem;
|
| 120 |
+
}
|
| 121 |
+
.markdown-content li {
|
| 122 |
+
margin-bottom: 0.5rem;
|
| 123 |
+
}
|
| 124 |
+
.dark .markdown-content ul > li::marker { color: var(--primary); }
|
| 125 |
+
.markdown-content ul > li::marker { color: var(--primary); }
|
| 126 |
+
.markdown-content a {
|
| 127 |
+
color: var(--primary);
|
| 128 |
+
text-decoration: none;
|
| 129 |
+
font-weight: 500;
|
| 130 |
+
border-bottom: 1px solid transparent;
|
| 131 |
+
transition: all 0.2s ease-in-out;
|
| 132 |
+
}
|
| 133 |
+
.markdown-content a:hover {
|
| 134 |
+
border-bottom-color: var(--primary-hover);
|
| 135 |
+
}
|
| 136 |
+
.markdown-content blockquote {
|
| 137 |
+
margin: 1.5rem 0;
|
| 138 |
+
padding-left: 1.5rem;
|
| 139 |
+
border-left: 4px solid var(--card-border);
|
| 140 |
+
color: #6c757d;
|
| 141 |
+
font-style: italic;
|
| 142 |
+
}
|
| 143 |
+
.dark .markdown-content blockquote {
|
| 144 |
+
color: #adb5bd;
|
| 145 |
+
}
|
| 146 |
+
.markdown-content hr {
|
| 147 |
+
border: none;
|
| 148 |
+
border-top: 1px solid var(--card-border);
|
| 149 |
+
margin: 2rem 0;
|
| 150 |
+
}
|
| 151 |
+
.markdown-content table {
|
| 152 |
+
width: 100%;
|
| 153 |
+
border-collapse: collapse;
|
| 154 |
+
margin: 1.5rem 0;
|
| 155 |
+
font-size: 0.9em;
|
| 156 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
| 157 |
+
border-radius: 8px;
|
| 158 |
+
overflow: hidden;
|
| 159 |
+
}
|
| 160 |
+
.markdown-content th, .markdown-content td {
|
| 161 |
+
border: 1px solid var(--card-border);
|
| 162 |
+
padding: 0.75rem 1rem;
|
| 163 |
+
text-align: left;
|
| 164 |
+
}
|
| 165 |
+
.markdown-content th {
|
| 166 |
+
background-color: var(--bot-bubble);
|
| 167 |
+
font-weight: 500;
|
| 168 |
+
}
|
| 169 |
+
.markdown-content code {
|
| 170 |
+
background-color: rgba(0,0,0,0.05);
|
| 171 |
+
padding: 0.2rem 0.4rem;
|
| 172 |
+
border-radius: 0.25rem;
|
| 173 |
+
font-family: 'Roboto Mono', monospace;
|
| 174 |
+
font-size: 0.9em;
|
| 175 |
+
}
|
| 176 |
+
.dark .markdown-content code {
|
| 177 |
+
background-color: rgba(255,255,255,0.1);
|
| 178 |
+
}
|
| 179 |
+
.markdown-content pre {
|
| 180 |
+
position: relative;
|
| 181 |
+
background-color: #f8f9fa;
|
| 182 |
+
border: 1px solid var(--card-border);
|
| 183 |
+
border-radius: 0.5rem;
|
| 184 |
+
margin-bottom: 1rem;
|
| 185 |
+
}
|
| 186 |
+
.dark .markdown-content pre {
|
| 187 |
+
background-color: #2e2f32;
|
| 188 |
+
}
|
| 189 |
+
.markdown-content pre code {
|
| 190 |
+
background: none;
|
| 191 |
+
padding: 1rem;
|
| 192 |
+
display: block;
|
| 193 |
+
overflow-x: auto;
|
| 194 |
+
}
|
| 195 |
+
.markdown-content pre .copy-code-btn {
|
| 196 |
+
position: absolute;
|
| 197 |
+
top: 0.5rem;
|
| 198 |
+
right: 0.5rem;
|
| 199 |
+
background-color: #e8eaed;
|
| 200 |
+
border: 1px solid #dadce0;
|
| 201 |
+
color: #5f6368;
|
| 202 |
+
padding: 0.3rem 0.6rem;
|
| 203 |
+
border-radius: 0.25rem;
|
| 204 |
+
cursor: pointer;
|
| 205 |
+
opacity: 0;
|
| 206 |
+
transition: opacity 0.2s;
|
| 207 |
+
font-size: 0.8em;
|
| 208 |
+
}
|
| 209 |
+
.dark .markdown-content pre .copy-code-btn {
|
| 210 |
+
background-color: #3c4043;
|
| 211 |
+
border-color: #5f6368;
|
| 212 |
+
color: #e8eaed;
|
| 213 |
+
}
|
| 214 |
+
.markdown-content pre:hover .copy-code-btn {
|
| 215 |
+
opacity: 1;
|
| 216 |
+
}
|
| 217 |
|
| 218 |
/* Spinner for the TTS button */
|
| 219 |
.tts-button-loader {
|
|
|
|
| 286 |
const chatContent = document.getElementById('chat-content');
|
| 287 |
const chatFilename = document.getElementById('chat-filename');
|
| 288 |
|
|
|
|
| 289 |
let sessionId = null;
|
| 290 |
const storedSessionId = sessionStorage.getItem('cognichat_session_id');
|
| 291 |
if (storedSessionId) {
|
|
|
|
| 361 |
loadingOverlay.classList.add('hidden');
|
| 362 |
loadingSubtext.textContent = '';
|
| 363 |
fileNameSpan.textContent = '';
|
| 364 |
+
fileUploadInput.value = '';
|
| 365 |
}
|
| 366 |
}
|
| 367 |
|
|
|
|
| 393 |
});
|
| 394 |
|
| 395 |
if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
|
| 396 |
+
|
| 397 |
+
// ============================ MODIFICATION START ==============================
|
| 398 |
+
// Parse the JSON response instead of reading a stream
|
| 399 |
+
const result = await response.json();
|
| 400 |
+
const answer = result.answer; // Extract the 'answer' field
|
| 401 |
+
|
| 402 |
+
if (!answer) {
|
| 403 |
+
throw new Error("Received an empty or invalid response from the server.");
|
| 404 |
+
}
|
| 405 |
|
| 406 |
typingIndicator.remove();
|
| 407 |
botMessageContainer = appendMessage('', 'bot');
|
| 408 |
contentDiv = botMessageContainer.querySelector('.markdown-content');
|
| 409 |
|
| 410 |
+
// Use the extracted answer for rendering
|
| 411 |
+
contentDiv.innerHTML = marked.parse(answer);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
contentDiv.querySelectorAll('pre').forEach(addCopyButton);
|
| 413 |
+
scrollToBottom(); // Scroll after content is added
|
| 414 |
|
| 415 |
+
// Use the extracted answer for TTS
|
| 416 |
+
addTextToSpeechControls(botMessageContainer, answer);
|
| 417 |
+
// ============================ MODIFICATION END ==============================
|
| 418 |
|
| 419 |
} catch (error) {
|
| 420 |
console.error('Chat error:', error);
|
|
|
|
| 450 |
|
| 451 |
const contentDiv = document.createElement('div');
|
| 452 |
contentDiv.className = 'text-base markdown-content';
|
| 453 |
+
// Only parse if text is not empty
|
| 454 |
+
if (text) {
|
| 455 |
+
contentDiv.innerHTML = marked.parse(text);
|
| 456 |
+
}
|
| 457 |
|
| 458 |
const controlsContainer = document.createElement('div');
|
| 459 |
controlsContainer.className = 'tts-controls mt-2';
|
|
|
|
| 521 |
});
|
| 522 |
}
|
| 523 |
|
| 524 |
+
// --- Text-to-Speech Logic ---
|
| 525 |
let currentAudio = null;
|
| 526 |
let currentPlayingButton = null;
|
| 527 |
|
|
|
|
| 533 |
const ttsControls = messageBubble.querySelector('.tts-controls');
|
| 534 |
if (text.trim().length > 0) {
|
| 535 |
const speakButton = document.createElement('button');
|
|
|
|
| 536 |
speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
|
| 537 |
speakButton.title = 'Listen to this message';
|
| 538 |
speakButton.setAttribute('data-state', 'play');
|
|
|
|
| 543 |
}
|
| 544 |
|
| 545 |
async function handleTTS(text, button) {
|
|
|
|
|
|
|
|
|
|
| 546 |
if (button === currentPlayingButton) {
|
| 547 |
+
if (currentAudio && !currentAudio.paused) {
|
| 548 |
currentAudio.pause();
|
| 549 |
button.setAttribute('data-state', 'paused');
|
| 550 |
button.innerHTML = `${playIconSVG} <span>Play</span>`;
|
| 551 |
+
} else if (currentAudio && currentAudio.paused) {
|
| 552 |
currentAudio.play();
|
| 553 |
button.setAttribute('data-state', 'playing');
|
| 554 |
button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
|
| 555 |
}
|
| 556 |
+
return;
|
| 557 |
}
|
| 558 |
|
|
|
|
|
|
|
| 559 |
resetAllSpeakButtons();
|
| 560 |
|
| 561 |
currentPlayingButton = button;
|
|
|
|
| 591 |
button.setAttribute('data-state', 'error');
|
| 592 |
button.innerHTML = `${playIconSVG} <span>Error</span>`;
|
| 593 |
alert('Failed to play audio. Please try again.');
|
| 594 |
+
resetAllSpeakButtons();
|
| 595 |
} finally {
|
| 596 |
button.disabled = false;
|
| 597 |
}
|
|
|
|
| 609 |
}
|
| 610 |
currentPlayingButton = null;
|
| 611 |
}
|
|
|
|
| 612 |
});
|
| 613 |
</script>
|
| 614 |
</body>
|
| 615 |
+
</html>
|