riteshraut committed on
Commit
bd93e48
·
1 Parent(s): 5fd359f

new feature

Browse files
Files changed (5) hide show
  1. app.py +282 -152
  2. evaluate.py +205 -0
  3. query_expansion.py +524 -0
  4. rag_processor.py +382 -78
  5. templates/index.html +880 -610
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
- import time
3
  import uuid
4
- from flask import Flask, request, render_template, session, jsonify, Response
5
  from werkzeug.utils import secure_filename
6
  from rag_processor import create_rag_chain
7
  from typing import Sequence, Any, List
@@ -10,10 +9,7 @@ import re
10
  import io
11
  from gtts import gTTS
12
  from langchain_core.documents import Document
13
- from langchain_community.document_loaders import (
14
- TextLoader,
15
- Docx2txtLoader,
16
- )
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
18
  from langchain_huggingface import HuggingFaceEmbeddings
19
  from langchain_community.vectorstores import FAISS
@@ -23,124 +19,203 @@ from langchain_community.retrievers import BM25Retriever
23
  from langchain_community.chat_message_histories import ChatMessageHistory
24
  from langchain.storage import InMemoryStore
25
  from sentence_transformers.cross_encoder import CrossEncoder
 
26
  app = Flask(__name__)
27
  app.config['SECRET_KEY'] = os.urandom(24)
 
 
 
 
 
 
 
28
  class LocalReranker(BaseDocumentCompressor):
29
  model: Any
30
- top_n: int = 3
31
 
32
  class Config:
33
  arbitrary_types_allowed = True
34
 
35
- def compress_documents(
36
- self,
37
- documents: Sequence[Document],
38
- query: str,
39
- callbacks=None,
40
- ) -> Sequence[Document]:
41
  if not documents:
42
  return []
43
-
44
  pairs = [[query, doc.page_content] for doc in documents]
45
  scores = self.model.predict(pairs, show_progress_bar=False)
46
-
47
  doc_scores = list(zip(documents, scores))
48
- sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
49
-
50
  top_docs = []
51
- for doc, score in sorted_doc_scores[:self.top_n]:
52
  doc.metadata['rerank_score'] = float(score)
53
  top_docs.append(doc)
54
  return top_docs
55
-
56
- is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
57
- if is_hf_spaces:
58
- app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
59
- else:
60
- app.config['UPLOAD_FOLDER'] = 'uploads'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  try:
63
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
64
- print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
65
  except Exception as e:
66
- print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
67
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
68
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
69
- print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
70
 
71
- rag_chains = {}
72
  message_histories = {}
73
 
74
- print("Loading embedding model...")
75
  try:
76
- EMBEDDING_MODEL = HuggingFaceEmbeddings(
77
- model_name="sentence-transformers/all-MiniLM-L6-v2",
78
- model_kwargs={'device': 'cpu'}
79
- )
80
- print("Embedding model loaded successfully.")
81
  except Exception as e:
82
- print(f"FATAL: Could not load embedding model. Error: {e}")
83
  raise e
84
 
85
- print("Loading local re-ranking model...")
86
  try:
87
- RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
88
- print("Re-ranking model loaded successfully.")
 
 
89
  except Exception as e:
90
- print(f"FATAL: Could not load reranker model. Error: {e}")
91
  raise e
92
 
 
93
  def load_pdf_with_fallback(filepath):
94
  try:
95
  docs = []
96
  with fitz.open(filepath) as pdf_doc:
97
- for page_num, page in enumerate(pdf_doc):
98
  text = page.get_text()
99
  if text.strip():
100
- docs.append(Document(
101
- page_content=text,
102
- metadata={
103
- "source": os.path.basename(filepath),
104
- "page": page_num + 1,
105
- }
106
- ))
107
  if docs:
108
- print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
 
109
  return docs
110
  else:
111
- raise ValueError("No text content found in PDF.")
112
  except Exception as e:
113
- print(f"PyMuPDF failed for {filepath}: {e}")
114
  raise
115
 
116
- LOADER_MAPPING = {
117
- ".txt": TextLoader,
118
- ".pdf": load_pdf_with_fallback,
119
- ".docx": Docx2txtLoader,
120
- }
121
 
122
  def get_session_history(session_id: str) -> ChatMessageHistory:
123
  if session_id not in message_histories:
124
  message_histories[session_id] = ChatMessageHistory()
125
  return message_histories[session_id]
126
 
 
127
  @app.route('/health', methods=['GET'])
128
  def health_check():
129
- return jsonify({'status': 'healthy'}), 200
 
130
 
131
  @app.route('/', methods=['GET'])
132
  def index():
133
  return render_template('index.html')
134
 
 
135
  @app.route('/upload', methods=['POST'])
136
  def upload_files():
137
  files = request.files.getlist('file')
138
- if not files or all(f.filename == '' for f in files):
139
- return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
140
 
141
- all_docs = []
142
- processed_files, failed_files = [], []
 
 
 
143
 
 
 
 
 
 
 
144
  for file in files:
145
  if file and file.filename:
146
  filename = secure_filename(file.filename)
@@ -149,147 +224,202 @@ def upload_files():
149
  file.save(filepath)
150
  file_ext = os.path.splitext(filename)[1].lower()
151
  if file_ext not in LOADER_MAPPING:
152
- raise ValueError("Unsupported file format.")
153
-
154
  loader_func = LOADER_MAPPING[file_ext]
155
- docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
156
-
157
  if not docs:
158
- raise ValueError("No content extracted.")
159
-
160
  all_docs.extend(docs)
161
  processed_files.append(filename)
162
- print(f"✓ Successfully processed: {filename}")
163
  except Exception as e:
164
- error_msg = str(e)
165
- print(f"✗ Error processing {filename}: {error_msg}")
166
- failed_files.append(f"{filename} ({error_msg})")
167
 
168
  if not all_docs:
169
- error_summary = "Failed to process all files."
170
- if failed_files:
171
- error_summary += " Reasons: " + ", ".join(failed_files)
172
- return jsonify({'status': 'error', 'message': error_summary}), 400
173
 
 
174
  try:
175
- print("Starting RAG pipeline setup...")
 
 
 
 
176
 
177
- vectorstore = FAISS.from_texts([""], EMBEDDING_MODEL)
178
  store = InMemoryStore()
179
-
180
- parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
181
- child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
182
-
183
- parent_docs = parent_splitter.split_documents(all_docs)
184
- doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
185
-
186
- child_docs = []
187
- for i, doc in enumerate(parent_docs):
188
- _id = doc_ids[i]
189
- sub_docs = child_splitter.split_documents([doc])
190
- for child in sub_docs:
191
- child.metadata["doc_id"] = _id
192
- child_docs.extend(sub_docs)
193
-
194
  store.mset(list(zip(doc_ids, parent_docs)))
195
- vectorstore.add_documents(child_docs)
196
- print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
197
 
198
  bm25_retriever = BM25Retriever.from_documents(child_docs)
199
- bm25_retriever.k = 8
200
-
201
- faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 8})
202
-
203
- ensemble_retriever = EnsembleRetriever(
204
- retrievers=[bm25_retriever, faiss_retriever],
205
- weights=[0.4, 0.6]
206
- )
207
- print("Created Hybrid Retriever for child documents.")
208
 
209
- reranker = LocalReranker(model=RERANKER_MODEL, top_n=4)
210
-
211
  def get_parents(docs: List[Document]) -> List[Document]:
212
- parent_ids = {d.metadata["doc_id"] for d in docs}
213
- return store.mget(list(parent_ids))
214
 
215
- compression_retriever = ContextualCompressionRetriever(
216
- base_compressor=reranker, base_retriever=ensemble_retriever
217
- )
218
-
219
  final_retriever = compression_retriever | get_parents
220
-
221
- print("Final retriever chain created: (Hybrid -> Rerank) -> Parent Fetch")
222
 
223
  session_id = str(uuid.uuid4())
224
- rag_chain = create_rag_chain(final_retriever, get_session_history)
225
- rag_chains[session_id] = rag_chain
226
- session['session_id'] = session_id
227
-
228
- success_msg = f"Successfully processed: {', '.join(processed_files)}"
 
 
 
 
 
 
 
229
  if failed_files:
230
- success_msg += f"\nFailed to process: {', '.join(failed_files)}"
 
 
 
 
 
231
 
232
  return jsonify({
233
- 'status': 'success',
234
  'filename': success_msg,
235
- 'session_id': session_id
236
- })
237
-
 
238
  except Exception as e:
239
  import traceback
240
  traceback.print_exc()
241
- return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
 
242
 
243
- @app.route('/chat', methods=['POST'])
244
  def chat():
245
- data = request.get_json()
246
- question = data.get('question')
247
- session_id = session.get('session_id') or data.get('session_id')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- if not question or not session_id or session_id not in rag_chains:
250
- return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
251
 
252
- try:
253
- rag_chain = rag_chains[session_id]
254
- config = {"configurable": {"session_id": session_id}}
255
- answer_string = rag_chain.invoke({"question": question}, config=config)
256
- return jsonify({'answer': answer_string})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  except Exception as e:
259
- print(f"Error during chat invocation: {e}")
260
- return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
 
 
 
 
 
 
 
 
 
261
 
262
  def clean_markdown_for_tts(text: str) -> str:
263
- text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
264
- text = re.sub(r'\_(.*?)\_', r'\1', text)
265
- text = re.sub(r'`(.*?)`', r'\1', text)
266
- text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
267
- text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
268
- text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
269
- text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
270
- text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
271
- text = re.sub(r'\n+', ' ', text)
272
  return text.strip()
273
 
 
274
  @app.route('/tts', methods=['POST'])
275
  def text_to_speech():
276
  data = request.get_json()
277
  text = data.get('text')
278
-
279
  if not text:
280
- return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
281
-
282
  try:
283
  clean_text = clean_markdown_for_tts(text)
 
 
 
284
  tts = gTTS(clean_text, lang='en')
285
  mp3_fp = io.BytesIO()
286
  tts.write_to_fp(mp3_fp)
287
  mp3_fp.seek(0)
288
  return Response(mp3_fp, mimetype='audio/mpeg')
289
  except Exception as e:
290
- print(f"Error in TTS generation: {e}")
291
- return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
 
 
292
 
293
  if __name__ == '__main__':
294
- port = int(os.environ.get("PORT", 7860))
295
- app.run(host="0.0.0.0", port=port, debug=False)
 
 
1
  import os
 
2
  import uuid
3
+ from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context
4
  from werkzeug.utils import secure_filename
5
  from rag_processor import create_rag_chain
6
  from typing import Sequence, Any, List
 
9
  import io
10
  from gtts import gTTS
11
  from langchain_core.documents import Document
12
+ from langchain_community.document_loaders import TextLoader, Docx2txtLoader
 
 
 
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
  from langchain_community.vectorstores import FAISS
 
19
  from langchain_community.chat_message_histories import ChatMessageHistory
20
  from langchain.storage import InMemoryStore
21
  from sentence_transformers.cross_encoder import CrossEncoder
22
+
# Flask application object. The secret key is random bytes generated anew
# each time this module is loaded.
app = Flask(__name__)
app.config.update(SECRET_KEY=os.urandom(24))

# Display names for the temperature presets, keyed by the temperature's
# string form.
TEMPERATURE_LABELS = {
    '0.2': 'Precise',
    '0.4': 'Confident',
    '0.6': 'Balanced',
    '0.8': 'Flexible',
    '1.0': 'Creative',
}
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that reranks candidates with a local scoring model.

    ``model`` must expose ``predict(pairs, show_progress_bar=...)`` and return
    one score per ``[query, text]`` pair. Only the ``top_n`` highest-scoring
    documents are kept.
    """

    model: Any
    top_n: int = 5

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Return the ``top_n`` best-scoring documents for ``query``, best first.

        The score of each returned document is written to
        ``metadata['rerank_score']`` on the document itself. An empty input
        yields an empty list without calling the model.
        """
        if not documents:
            return []

        query_doc_pairs = [[query, candidate.page_content] for candidate in documents]
        relevance = self.model.predict(query_doc_pairs, show_progress_bar=False)
        ranked = sorted(zip(documents, relevance),
                        key=lambda item: item[1], reverse=True)

        kept = []
        for candidate, value in ranked[:self.top_n]:
            candidate.metadata['rerank_score'] = float(value)
            kept.append(candidate)
        return kept
53
+
54
+
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into parent chunks and smaller child chunks linked to them.

    Returns a ``(parent_docs, child_docs, doc_ids)`` tuple in which
    ``doc_ids[i]`` is the UUID string generated for ``parent_docs[i]``. Every
    child stores its parent's id under ``metadata['doc_id']`` together with its
    position among its siblings. Three empty lists are returned when
    ``all_docs`` is empty.
    """
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=['\n\n', '\n', '. ', '! ', '? ', '; ', ', ', ' ', ''],
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=['\n', '. ', '! ', '? ', '; ', ', ', ' ', ''],
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
    child_docs = []

    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        pieces = child_splitter.split_documents([parent_doc])
        last_index = len(pieces) - 1
        for position, piece in enumerate(pieces):
            piece.metadata.update({
                'doc_id': parent_id,
                'chunk_index': position,
                'total_chunks': len(pieces),
                'is_first_chunk': position == 0,
                'is_last_chunk': position == last_index,
            })
            # Only parents split into several children get a text marker, and
            # only on their first and last child.
            if last_index > 0:
                if position == 0:
                    piece.page_content = '[Beginning] ' + piece.page_content
                elif position == last_index:
                    piece.page_content = '[Continues...] ' + piece.page_content
            child_docs.append(piece)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return (parent_docs, child_docs, doc_ids)
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Replace retrieved child chunks with their parents, enriched with excerpts.

    Children are grouped by ``metadata['doc_id']``; children without one are
    ignored. Every parent found in ``store`` becomes a new Document whose text
    is the parent text followed by up to three of its matching child texts.
    The result is sorted by number of matching children, highest first.
    """
    if not docs:
        return []

    # Plain dicts keep insertion order, so parents stay in first-seen order.
    excerpts_by_parent = {}
    for child in docs:
        parent_id = child.metadata.get('doc_id')
        if not parent_id:
            continue
        excerpts_by_parent.setdefault(parent_id, []).append(child.page_content)

    parent_ids = list(excerpts_by_parent)
    enhanced_parents = []
    for parent_id, parent in zip(parent_ids, store.mget(parent_ids)):
        if parent is None:
            print(f"PARENT_FETCH: Parent {parent_id} not found in store!")
            continue
        excerpts = excerpts_by_parent[parent_id]
        hit_count = len(excerpts)
        child_excerpts = '\n'.join(excerpts[:3])
        enhanced_parents.append(Document(
            page_content=f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}",
            metadata={**parent.metadata,
                      'child_relevance_score': hit_count,
                      'matching_children': hit_count},
        ))

    enhanced_parents.sort(key=lambda p: p.metadata.get('child_relevance_score', 0),
                          reverse=True)
    return enhanced_parents
123
+
124
+
# Uploads go to /tmp when SPACE_ID or SPACES_ZERO_GPU is set in the
# environment (presumably a Hugging Face Spaces deployment), otherwise to a
# local "uploads" directory.
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# In-memory, per-process state keyed by session id.
session_data = {}
message_histories = {}

# Both models are loaded once at import time; a load failure aborts startup.
print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise e

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2',
                                  device='cpu')
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise e
160
 
161
+
def load_pdf_with_fallback(filepath):
    """Extract a PDF's text with PyMuPDF, one Document per non-blank page.

    Each Document's metadata holds the file's base name under ``source`` and
    the 1-based page number under ``page``. Any failure, including a PDF with
    no extractable text (ValueError), is printed and re-raised.
    """
    try:
        pages = []
        with fitz.open(filepath) as pdf_doc:
            for page_number, page in enumerate(pdf_doc, start=1):
                text = page.get_text()
                if not text.strip():
                    continue
                page_meta = {'source': os.path.basename(filepath),
                             'page': page_number}
                pages.append(Document(page_content=text, metadata=page_meta))
        if not pages:
            raise ValueError('No text content found in PDF.')
        print(f"Loaded PDF: {os.path.basename(filepath)} - {len(pages)} pages")
        return pages
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise
181
 
182
+
# Loader to use for each supported upload extension (lower-case, with dot).
LOADER_MAPPING = {
    '.txt': TextLoader,
    '.pdf': load_pdf_with_fallback,
    '.docx': Docx2txtLoader,
}
185
+
 
186
 
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating an empty one on first use."""
    try:
        return message_histories[session_id]
    except KeyError:
        history = message_histories[session_id] = ChatMessageHistory()
        return history
191
 
192
+
@app.route('/health', methods=['GET'])
def health_check():
    """Answer HTTP 200 with the JSON body ``{"status": "healthy"}``."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
196
+
197
 
@app.route('/', methods=['GET'])
def index():
    """Render and return the ``index.html`` template."""
    page = render_template('index.html')
    return page
201
 
202
+
203
  @app.route('/upload', methods=['POST'])
204
  def upload_files():
205
  files = request.files.getlist('file')
 
 
206
 
207
+ temperature_str = request.form.get('temperature', '0.2')
208
+ temperature = float(temperature_str)
209
+ model_name = request.form.get('model_name',
210
+ 'moonshotai/kimi-k2-instruct')
211
+ print (f"UPLOAD: Model: {model_name}, Temp: {temperature}")
212
 
213
+ if not files or all(f.filename == '' for f in files):
214
+ return (jsonify({'status': 'error',
215
+ 'message': 'No selected files.'}), 400)
216
+
217
+ (all_docs, processed_files, failed_files) = ([], [], [])
218
+ print (f"Processing {len(files)} file(s)...")
219
  for file in files:
220
  if file and file.filename:
221
  filename = secure_filename(file.filename)
 
224
  file.save(filepath)
225
  file_ext = os.path.splitext(filename)[1].lower()
226
  if file_ext not in LOADER_MAPPING:
227
+ raise ValueError('Unsupported file format.')
 
228
  loader_func = LOADER_MAPPING[file_ext]
229
+ docs = loader_func(filepath) if file_ext == '.pdf' \
230
+ else loader_func(filepath).load()
231
  if not docs:
232
+ raise ValueError('No content extracted.')
 
233
  all_docs.extend(docs)
234
  processed_files.append(filename)
 
235
  except Exception as e:
236
+ print (f"✗ Error processing {filename}: {e}")
237
+ failed_files.append(f"{filename} ({e})")
 
238
 
239
  if not all_docs:
240
+ return (jsonify({'status': 'error',
241
+ 'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"
242
+ }), 400)
 
243
 
244
+ print (f"UPLOAD: Processed {len(processed_files)} files.")
245
  try:
246
+ print ('Starting RAG pipeline setup...')
247
+ (parent_docs, child_docs, doc_ids) = \
248
+ create_optimized_parent_child_chunks(all_docs)
249
+ if not child_docs:
250
+ raise ValueError('No child documents created during chunking.')
251
 
252
+ vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
253
  store = InMemoryStore()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  store.mset(list(zip(doc_ids, parent_docs)))
255
+ print (f"Indexed {len(child_docs)} document chunks.")
 
256
 
257
  bm25_retriever = BM25Retriever.from_documents(child_docs)
258
+ bm25_retriever.k = 12
259
+ faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
260
+ ensemble_retriever = \
261
+ EnsembleRetriever(retrievers=[bm25_retriever,
262
+ faiss_retriever], weights=[0.6, 0.4])
263
+ reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
 
 
 
264
 
 
 
265
  def get_parents(docs: List[Document]) -> List[Document]:
266
+ return get_context_aware_parents(docs, store)
 
267
 
268
+ compression_retriever = \
269
+ ContextualCompressionRetriever(base_compressor=reranker,
270
+ base_retriever=ensemble_retriever)
 
271
  final_retriever = compression_retriever | get_parents
 
 
272
 
273
  session_id = str(uuid.uuid4())
274
+ (rag_chain, api_key_manager) = \
275
+ create_rag_chain(retriever=final_retriever,
276
+ get_session_history_func=get_session_history,
277
+ model_name=model_name,
278
+ temperature=temperature)
279
+
280
+ session_data[session_id] = {'chain': rag_chain,
281
+ 'model_name': model_name,
282
+ 'temperature': temperature,
283
+ 'api_key_manager': api_key_manager}
284
+
285
+ success_msg = f"Processed: {', '.join(processed_files)}"
286
  if failed_files:
287
+ success_msg += f". Failed: {', '.join(failed_files)}"
288
+
289
+ mode_label = TEMPERATURE_LABELS.get(temperature_str,
290
+ temperature_str)
291
+
292
+ print (f"UPLOAD COMPLETE: Session {session_id} is ready.")
293
 
294
  return jsonify({
295
+ 'status': 'success',
296
  'filename': success_msg,
297
+ 'session_id': session_id,
298
+ 'model_name': model_name,
299
+ 'mode': mode_label,
300
+ })
301
  except Exception as e:
302
  import traceback
303
  traceback.print_exc()
304
+ return (jsonify({'status': 'error',
305
+ 'message': f'RAG setup failed: {e}'}), 500)
306
 
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    """Answer a question for an existing session as a Server-Sent Events stream.

    GET reads ``question`` and ``session_id`` from the query string. POST reads
    them from the JSON body and falls back to the Flask session for the
    session id. Every streamed event has the form ``data: <json>`` and carries
    either a ``token`` (with ``model_name`` and ``mode``) or an ``error``.

    Validation failures (status 400) and setup failures (status 500) are
    reported as a single SSE error event for GET requests and as a JSON error
    body for POST requests. Any other method gets a JSON 405.
    """
    import json
    import traceback

    def sse_event(payload):
        # json.dumps escapes quotes, backslashes, newlines and control
        # characters, so every event stays on one "data:" line.
        return f'data: {json.dumps(payload, ensure_ascii=False)}\n\n'

    def error_response(message, status):
        # GET callers consume an event stream, so the error is sent as an event.
        if request.method == 'GET':
            def error_stream():
                yield sse_event({'error': message})
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=status)
        return jsonify({'status': 'error', 'message': message}), status

    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
    elif request.method == 'POST':
        # A missing, malformed or non-object JSON body is treated as empty so
        # that the validation below produces the error response.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
    else:
        return (jsonify({'status': 'error', 'message': 'Method not allowed'}), 405)

    # The question may be absent at this point, so build the log preview
    # without assuming it can be sliced.
    preview = str(question)[:50] if question else ''
    print(f"Received {request.method} request for chat: session={session_id}, question={preview}...")

    if not question:
        error_msg = "Error: No question provided."
        print(f"CHAT Validation Error: {error_msg}")
        return error_response(error_msg, 400)

    if not session_id or session_id not in session_data:
        error_msg = "Error: Invalid session. Please upload documents first."
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        return error_response(error_msg, 400)

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        temperature_str = str(temperature_float)
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"CHAT: Streaming response for session {session_id} (Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}},
                )
                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        yield sse_event({'token': chunk,
                                         'model_name': model_name,
                                         'mode': mode_label})
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")
                print('CHAT: Streaming finished successfully.')
            except Exception as e:
                # The response has already started, so the failure can only be
                # reported as one more event in the stream.
                print(f"CHAT Error during streaming generation: {e}")
                traceback.print_exc()
                yield sse_event({'error': f"Error during response generation: {e}"})

        return Response(stream_with_context(generate_chunks()),
                        mimetype='text/event-stream')

    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        traceback.print_exc()
        return error_response(f"Error setting up chat stream: {e}", 500)
386
+
387
 
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax from ``text`` so it can be read aloud.

    Links and images are reduced to their visible text, emphasis/code/heading
    characters are removed, list and block-quote markers at line starts are
    dropped, and all line breaks and repeated whitespace collapse to single
    spaces. Returns the stripped result, which may be empty.
    """
    # Keep the link (or image alt) text and drop only the URL part, so linked
    # words are still spoken.
    text = re.sub(r'!?\[(.*?)\]\(.*?\)', r'\1', text)
    text = re.sub(r'[`*_#]', '', text)
    text = re.sub(r'^\s*[\-\*\+]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()
397
 
398
+
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert the posted ``text`` to English speech and return it as MP3 audio.

    Expects a JSON object with a ``text`` field. Responds with status 400 when
    the text is missing or nothing speakable remains after Markdown cleanup,
    and with status 500 when audio generation fails.
    """
    # A missing, malformed or non-object JSON body is treated as "no text"
    # instead of raising before the error handling below.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text:
        return (jsonify({'status': 'error',
                         'message': 'No text provided.'}), 400)
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            return (jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400)

        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        # Rewind so the response streams the audio from its first byte.
        mp3_fp.seek(0)
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error',
                         'message': 'Failed to generate audio.'}), 500)
420
+
421
 
if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
evaluate.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import uuid
4
+ from dotenv import load_dotenv
5
+ from datasets import Dataset
6
+ import pandas as pd
7
+ from typing import Sequence, Any, List
8
+
9
+ # Ragas and LangChain components
10
+ from ragas import evaluate
11
+ from ragas.metrics import (
12
+ faithfulness,
13
+ answer_relevancy,
14
+ context_recall,
15
+ context_precision,
16
+ )
17
+ from ragas.testset import TestsetGenerator
18
+ # NOTE: The 'evolutions' import has been completely removed.
19
+
20
+ # Your specific RAG components from app.py
21
+ from langchain_groq import ChatGroq
22
+ from langchain_community.document_loaders import PyMuPDFLoader
23
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
24
+ from langchain_huggingface import HuggingFaceEmbeddings
25
+ from langchain_community.vectorstores import FAISS
26
+ from langchain.storage import InMemoryStore
27
+ from langchain_community.retrievers import BM25Retriever
28
+ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
29
+ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
30
+ from langchain_core.documents import Document
31
+ from sentence_transformers.cross_encoder import CrossEncoder
32
+ from rag_processor import create_rag_chain
33
+ from langchain_community.chat_message_histories import ChatMessageHistory
34
+ import fitz
35
+
36
+ # Load environment variables
37
+ load_dotenv()
38
+
39
+ # --- Re-implementing LocalReranker from app.py ---
class LocalReranker(BaseDocumentCompressor):
    """Rerank retrieved documents with a local scoring model and keep the best.

    ``model`` must provide ``predict(pairs, show_progress_bar=...)`` returning
    one score per ``[query, text]`` pair; ``top_n`` limits how many documents
    are returned.
    """

    model: Any
    top_n: int = 3

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Return up to ``top_n`` documents ordered by descending score.

        Each returned document has its score stored in
        ``metadata['rerank_score']``. An empty input returns an empty list.
        """
        if not documents:
            return []

        query_doc_pairs = [[query, candidate.page_content] for candidate in documents]
        relevance = self.model.predict(query_doc_pairs, show_progress_bar=False)
        ranked = sorted(zip(documents, relevance),
                        key=lambda item: item[1], reverse=True)

        kept = []
        for candidate, value in ranked[:self.top_n]:
            candidate.metadata['rerank_score'] = float(value)
            kept.append(candidate)
        return kept
56
+
57
+ # --- Helper Functions ---
def load_pdf_with_fallback(filepath):
    """Read a PDF with PyMuPDF and return one Document per non-blank page.

    Metadata on each Document records the file's base name (``source``) and
    the 1-based page number (``page``). Failures, including a PDF without any
    text (ValueError), are printed and re-raised.
    """
    try:
        pages = []
        with fitz.open(filepath) as pdf_doc:
            for page_number, page in enumerate(pdf_doc, start=1):
                text = page.get_text()
                if not text.strip():
                    continue
                page_meta = {"source": os.path.basename(filepath),
                             "page": page_number}
                pages.append(Document(page_content=text, metadata=page_meta))
        if not pages:
            raise ValueError("No text content found in PDF.")
        print(f"✓ Successfully loaded PDF: {filepath}")
        return pages
    except Exception as e:
        print(f"✗ PyMuPDF failed for {filepath}: {e}")
        raise
78
+
79
async def main():
    """Run the end-to-end RAGAS evaluation of the RAG pipeline.

    Steps, in order:
      1. Build the Groq LLMs, the embedding model and the cross-encoder.
      2. Load the PDF, split it into parent/child chunks, and assemble the
         BM25 + FAISS ensemble retriever with reranking and parent lookup.
      3. Generate a small synthetic test set from the loaded documents.
      4. Run every generated question through the retriever and the chain.
      5. Score the answers with ragas.
      6. Write the scores to ``evaluation_results.csv``.

    Returns early (after printing a message) when the PDF is missing. Any
    exception raised after that point is printed with its traceback and
    swallowed, so the coroutine itself does not raise.
    """
    print("\n" + "="*60 + "\nSTARTING RAGAS EVALUATION\n" + "="*60)

    pdf_path = "uploads/Unit_-_1_Introduction.pdf"
    if not os.path.exists(pdf_path):
        print(f"✗ Error: PDF not found at {pdf_path}")
        return

    try:
        # --- 1. Setup Models ---
        print("\n--- 1. Initializing Models ---")
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key or groq_api_key == "your_groq_api_key_here":
            raise ValueError("GROQ_API_KEY not found or is a placeholder.")

        generator_llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_api_key)
        critic_llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
        print("✓ Models initialized.")

        # --- 2. Setup RAG Pipeline ---
        print("\n--- 2. Setting up RAG Pipeline ---")
        documents = load_pdf_with_fallback(pdf_path)

        # Split documents: large "parent" chunks are what the LLM sees,
        # small "child" chunks are what gets indexed and searched.
        parent_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=400)
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
        parent_docs = parent_splitter.split_documents(documents)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        child_docs = []
        for _id, doc in zip(doc_ids, parent_docs):
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                # Back-reference used by get_parents() below.
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        vectorstore = FAISS.from_documents(child_docs, embedding_model)

        bm25_retriever = BM25Retriever.from_documents(child_docs, k=10)
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.4, 0.6])

        reranker = LocalReranker(model=reranker_model, top_n=5)
        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)

        def get_parents(docs: List[Document]) -> List[Document]:
            """Map reranked child chunks to their parent chunks, keeping rank order."""
            # dict.fromkeys de-duplicates while preserving first-seen order, so
            # the parents come back in the order the reranker produced. A set
            # would hand the contexts over in arbitrary order.
            parent_ids = list(dict.fromkeys(d.metadata["doc_id"] for d in docs))
            # Drop ids the store cannot resolve instead of passing None on.
            return [parent for parent in store.mget(parent_ids) if parent is not None]

        final_retriever = compression_retriever | get_parents

        message_histories = {}

        def get_session_history(session_id: str):
            """Return the chat history for session_id, creating it on first use."""
            if session_id not in message_histories:
                message_histories[session_id] = ChatMessageHistory()
            return message_histories[session_id]

        rag_chain = create_rag_chain(final_retriever, get_session_history)
        print("✓ RAG chain created successfully.")

        # --- 3. Generate Testset ---
        print("\n--- 3. Generating Test Questions ---")
        generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embedding_model)

        # Generate a simple test set without complex distributions
        testset = generator.generate_with_langchain_docs(documents, testset_size=5)
        print("✓ Testset generated.")

        # --- 4. Run RAG Chain on Testset ---
        print("\n--- 4. Running RAG Chain to Generate Answers ---")
        # Convert once; both columns are read from the same records.
        test_records = testset.to_pandas().to_dict('records')
        test_questions = [item['question'] for item in test_records]
        ground_truths = [item['ground_truth'] for item in test_records]

        answers = []
        contexts = []

        for i, question in enumerate(test_questions):
            print(f"  Processing question {i+1}/{len(test_questions)}...")
            # Retrieve contexts
            retrieved_docs = final_retriever.invoke(question)
            contexts.append([doc.page_content for doc in retrieved_docs])
            # Get answer from chain; a fresh session id keeps questions independent
            config = {"configurable": {"session_id": str(uuid.uuid4())}}
            answer = await rag_chain.ainvoke({"question": question}, config=config)
            answers.append(answer)

        # --- 5. Evaluate with Ragas ---
        print("\n--- 5. Evaluating Results with Ragas ---")
        eval_data = {
            'question': test_questions,
            'answer': answers,
            'contexts': contexts,
            'ground_truth': ground_truths
        }
        eval_dataset = Dataset.from_dict(eval_data)

        result = evaluate(
            eval_dataset,
            metrics=[faithfulness, answer_relevancy, context_precision, context_recall],
            llm=critic_llm,
            embeddings=embedding_model
        )

        print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
        print(result)

        # --- 6. Save Results ---
        print("\n--- 6. Saving Results ---")
        results_df = result.to_pandas()
        results_df.to_csv("evaluation_results.csv", index=False)
        print("✓ Evaluation results saved to evaluation_results.csv")

        print("\n" + "="*60 + "\nEVALUATION COMPLETE!\n" + "="*60)

    except Exception as e:
        # Top-level boundary of the script: report and exit without re-raising.
        print(f"\n✗ An error occurred during the process: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
query_expansion.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Query Expansion System for CogniChat RAG Application
4
+
5
+ This module implements advanced query expansion techniques to improve retrieval quality:
6
+ - QueryAnalyzer: Extracts intent, entities, and keywords
7
+ - QueryRephraser: Generates natural language variations
8
+ - MultiQueryExpander: Creates diverse query formulations
9
+ - MultiHopReasoner: Connects concepts across documents
10
+ - FallbackStrategies: Handles edge cases gracefully
11
+
12
+ Author: CogniChat Team
13
+ Date: October 19, 2025
14
+ """
15
+
16
+ import re
17
+ from typing import List, Dict, Any, Optional
18
+ from dataclasses import dataclass
19
+ from enum import Enum
20
+
21
+
22
class QueryStrategy(Enum):
    """Effort levels for query expansion, from cheapest to most thorough."""

    # About 2 queries: fastest option, minimal expansion.
    QUICK = "quick"
    # About 3-4 queries: the middle ground.
    BALANCED = "balanced"
    # About 5-6 queries: widest coverage.
    COMPREHENSIVE = "comprehensive"
27
+
28
+
29
@dataclass
class QueryAnalysis:
    """What the analyzer learned about a single query."""

    # Intent label such as question, definition, comparison or explanation.
    intent: str
    # Named entities pulled out of the query text.
    entities: List[str]
    # Important keywords found in the query.
    keywords: List[str]
    # One of: simple, medium, complex.
    complexity: str
    # Technical domain, when one was detected.
    domain: Optional[str] = None
37
+
38
+
39
@dataclass
class ExpandedQuery:
    """Bundle of an original query and the variations generated from it."""

    # The query exactly as the user supplied it.
    original: str
    # Generated query variations.
    variations: List[str]
    # Strategy that produced the variations.
    strategy_used: QueryStrategy
    # Analysis the variations were derived from.
    analysis: QueryAnalysis
46
+
47
+
48
class QueryAnalyzer:
    """
    Analyzes queries to extract intent, entities, and key information.

    Everything in this class is rule-based (regular expressions and word
    lists). The optional ``llm`` is stored on the instance but no method
    here calls it.
    """

    # Words ignored by _extract_keywords (simplified list).
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
        'what', 'how', 'why', 'when', 'where', 'which', 'who',
        'do', 'does', 'did', 'can', 'could', 'should', 'would',
        'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
    })

    # Capitalised sentence openers that must not count as entities.
    _QUESTION_WORDS = frozenset({'what', 'how', 'why', 'when', 'where', 'which'})

    # Characters trimmed from both ends of a token before the entity check.
    # Symbols such as '+' and '#' are kept so "C++" and "C#" survive.
    _EDGE_PUNCTUATION = '.,;:!?()[]{}"\''

    # Domain -> trigger terms. Domains are tried in this order; first hit wins.
    _DOMAIN_TERMS = {
        'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
        'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
        'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
        'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
        'database': ['sql', 'query', 'database', 'table', 'index'],
        'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
    }

    def __init__(self, llm=None):
        """
        Initialize QueryAnalyzer.

        Args:
            llm: Optional LangChain LLM, kept on ``self.llm`` for callers
                that want it; not used by the rule-based analysis.
        """
        self.llm = llm
        # Checked in insertion order by _detect_intent; first match wins.
        self.intent_patterns = {
            'definition': r'\b(what is|define|meaning of|definition)\b',
            'how_to': r'\b(how to|how do|how can|steps to)\b',
            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
            'explanation': r'\b(why|explain|reason|cause)\b',
            'listing': r'\b(list|enumerate|what are|types of)\b',
            'example': r'\b(example|instance|sample|case)\b',
        }

    def analyze(self, query: str) -> QueryAnalysis:
        """
        Analyze query to extract intent, entities, and keywords.

        Args:
            query: User's original query

        Returns:
            QueryAnalysis object with extracted information
        """
        query_lower = query.lower()
        entities = self._extract_entities(query)
        keywords = self._extract_keywords(query)

        return QueryAnalysis(
            intent=self._detect_intent(query_lower),
            entities=entities,
            keywords=keywords,
            complexity=self._assess_complexity(query, entities, keywords),
            domain=self._detect_domain(query_lower),
        )

    def _detect_intent(self, query_lower: str) -> str:
        """Return the first intent whose pattern matches, else 'general'."""
        for intent, pattern in self.intent_patterns.items():
            if re.search(pattern, query_lower):
                return intent
        return 'general'

    def _extract_entities(self, query: str) -> List[str]:
        """Extract likely named entities (simplified, capitalisation-based).

        A token counts as an entity when, after trimming surrounding
        punctuation, it starts with an upper-case letter and is not a
        question word. Text inside double quotes is added as one entity.
        The result keeps first-seen order and contains no duplicates, so
        it is identical from run to run.
        """
        entities = []

        for word in query.split():
            # "Java?" and "(Python)" should yield "Java" and "Python".
            token = word.strip(self._EDGE_PUNCTUATION)
            if token and token[0].isupper() and token.lower() not in self._QUESTION_WORDS:
                entities.append(token)

        # Quoted phrases are treated as entities in their entirety.
        entities.extend(re.findall(r'"([^"]+)"', query))

        # dict.fromkeys de-duplicates without the arbitrary ordering of set().
        return list(dict.fromkeys(entities))

    def _extract_keywords(self, query: str) -> List[str]:
        """Return up to 10 lower-cased non-stop-words longer than two characters."""
        words = re.findall(r'\b\w+\b', query.lower())
        keywords = [w for w in words if w not in self._STOP_WORDS and len(w) > 2]

        return keywords[:10]  # Limit to top 10

    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
        """Classify the query as 'simple', 'medium' or 'complex' by a weighted count."""
        # Entities weigh 2, keywords 1.5, every whitespace-separated word 1.
        score = len(query.split()) + (len(entities) * 2) + (len(keywords) * 1.5)

        if score < 15:
            return 'simple'
        elif score < 30:
            return 'medium'
        else:
            return 'complex'

    @staticmethod
    def _mentions(text: str, term: str) -> bool:
        """Report whether text contains term at a word boundary.

        Terms of up to three characters ('ai', 'api', 'sql', 'css') must
        match a whole word (an optional plural 's' is allowed); otherwise
        'ai' would fire inside 'explain'. Longer terms only need to start
        a word, so 'model' still matches 'models' and 'modeling'.
        """
        escaped = re.escape(term)
        if len(term) <= 3:
            pattern = rf'\b{escaped}s?\b'
        else:
            pattern = rf'\b{escaped}'
        return re.search(pattern, text) is not None

    def _detect_domain(self, query_lower: str) -> Optional[str]:
        """Return the first domain with a trigger term in the query, else None.

        Args:
            query_lower: The query, already lower-cased by the caller.
        """
        for domain, terms in self._DOMAIN_TERMS.items():
            if any(self._mentions(query_lower, term) for term in terms):
                return domain

        return None
178
+
179
+
180
class QueryRephraser:
    """
    Generates natural language variations of queries using multiple strategies.

    All variations are produced by simple string rules. The optional
    ``llm`` is stored on the instance but is not called by this class.
    """

    # Single-word substitutions tried by _synonym_variation (keys lower-case).
    _SYNONYMS = {
        'error': 'issue',
        'problem': 'issue',
        'fix': 'resolve',
        'use': 'utilize',
        'create': 'generate',
        'make': 'create',
        'get': 'retrieve',
        'show': 'display',
        'find': 'locate',
        'explain': 'describe',
    }

    def __init__(self, llm=None):
        """
        Initialize QueryRephraser.

        Args:
            llm: LangChain LLM kept on ``self.llm``; not used by the
                rule-based variations below.
        """
        self.llm = llm

    def generate_variations(
        self,
        query: str,
        analysis: QueryAnalysis,
        strategy: QueryStrategy = QueryStrategy.BALANCED
    ) -> List[str]:
        """
        Generate query variations based on strategy.

        Args:
            query: Original query
            analysis: Query analysis results
            strategy: Expansion strategy to use

        Returns:
            List of query variations; the original query is always first,
            empty results are dropped and duplicates are removed while
            keeping order.
        """
        # Which builders run for which strategy; each level extends the last.
        quick = (self._synonym_variation,)
        balanced = quick + (self._expanded_variation, self._simplified_variation)
        comprehensive = balanced + (self._keyword_focused, self._context_variation)
        builders = {
            QueryStrategy.QUICK: quick,
            QueryStrategy.BALANCED: balanced,
            QueryStrategy.COMPREHENSIVE: comprehensive,
        }

        variations = [query]  # Always include original
        variations.extend(build(query, analysis) for build in builders.get(strategy, ()))

        # One extra alternate phrasing for procedural/explanatory questions.
        # Skipped without keywords, which would leave a dangling "Guide to ".
        if (
            strategy == QueryStrategy.COMPREHENSIVE
            and analysis.intent in ('how_to', 'explanation')
            and analysis.keywords
        ):
            variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")

        # Drop empty/None entries, then de-duplicate preserving order.
        return list(dict.fromkeys(v for v in variations if v))

    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Replace the first word that has a known synonym.

        Only that one word changes: the casing of every other word and any
        punctuation attached to the replaced word are kept. When no word
        has a synonym the query is returned unchanged, so the caller's
        de-duplication removes it instead of keeping a case-mangled copy.
        """
        tokens = query.split()
        for position, token in enumerate(tokens):
            # Separate the word from punctuation glued to it ("error?").
            parts = re.fullmatch(r'(\W*)(\w+)(\W*)', token)
            if parts is None:
                continue
            lead, core, trail = parts.groups()
            replacement = self._SYNONYMS.get(core.lower())
            if replacement is None:
                continue
            if core[0].isupper():
                replacement = replacement.capitalize()
            tokens[position] = f"{lead}{replacement}{trail}"
            return ' '.join(tokens)  # Only replace one word to keep natural
        return query

    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate expanded version with more detail."""
        if analysis.intent == 'definition':
            # Fall back to the query itself when no keywords survived filtering.
            subject = ' '.join(analysis.keywords) or query
            return f"Detailed explanation and definition of {subject}"
        elif analysis.intent == 'how_to':
            return f"Step-by-step guide on {query.lower()}"
        elif analysis.intent == 'comparison':
            return f"Comprehensive comparison: {query}"
        else:
            # Add qualifying words
            return f"Detailed information about {query.lower()}"

    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate simplified version focusing on core concepts."""
        # Use just the keywords
        if len(analysis.keywords) >= 2:
            return ' '.join(analysis.keywords[:3])
        return query

    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
        """Create keyword-focused variation for BM25."""
        keywords = analysis.keywords + analysis.entities
        return ' '.join(keywords[:5])

    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Add contextual information if domain detected."""
        if analysis.domain:
            return f"{query} in {analysis.domain} context"
        return query
291
+
292
+
293
class MultiQueryExpander:
    """
    Orchestrates query expansion: analyse the query first, then rephrase it.
    """

    def __init__(self, llm=None):
        """
        Build the analyzer and rephraser this expander delegates to.

        Args:
            llm: LangChain LLM handed to both helpers
        """
        self.analyzer = QueryAnalyzer(llm)
        self.rephraser = QueryRephraser(llm)

    def expand(
        self,
        query: str,
        strategy: QueryStrategy = QueryStrategy.BALANCED,
        max_queries: int = 6
    ) -> ExpandedQuery:
        """
        Turn one query into a capped list of variations.

        Args:
            query: Original user query
            strategy: Expansion strategy
            max_queries: Upper bound on the number of variations kept

        Returns:
            ExpandedQuery holding the original query, the kept variations,
            the strategy and the analysis they were built from
        """
        query_analysis = self.analyzer.analyze(query)
        candidates = self.rephraser.generate_variations(query, query_analysis, strategy)

        return ExpandedQuery(
            original=query,
            variations=candidates[:max_queries],
            strategy_used=strategy,
            analysis=query_analysis,
        )
340
+
341
+
342
class MultiHopReasoner:
    """
    Splits a query into sub-queries so that information spread over
    several documents can be retrieved piece by piece.
    """

    def __init__(self, llm=None):
        """
        Store the optional LLM.

        Args:
            llm: LangChain LLM for reasoning
        """
        self.llm = llm

    def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """
        Derive sub-queries for multi-hop retrieval.

        Args:
            query: Original complex query
            analysis: Query analysis

        Returns:
            At most five queries; the original query always comes first
        """
        collected = [query]
        terms = analysis.keywords

        if analysis.intent == 'comparison':
            # One lookup per compared subject: entities when at least two
            # were found, otherwise the first two keywords.
            if len(analysis.entities) >= 2:
                subjects = analysis.entities[:2]
            elif len(terms) >= 2:
                subjects = terms[:2]
            else:
                subjects = []
            collected += [f"Information about {subject}" for subject in subjects]

        if analysis.intent == 'how_to' and len(terms) >= 2:
            # Procedural questions: ask for prerequisites and for the steps.
            topic = ' '.join(terms[:2])
            collected += [f"Prerequisites for {topic}", f"Steps to {topic}"]

        if analysis.complexity == 'complex' and len(terms) > 3:
            # Halve the keyword list and query each half separately.
            half = len(terms) // 2
            collected += [' '.join(terms[:half]), ' '.join(terms[half:])]

        return collected[:5]  # Limit to 5 sub-queries
393
+
394
+
395
class FallbackStrategies:
    """
    Alternative query formulations to try when the first retrieval is poor.
    """

    @staticmethod
    def simplify_query(query: str) -> str:
        """Strip question words and common filler words, then tidy whitespace."""
        removal_passes = (
            # Question words
            r'\b(what|how|why|when|where|which|who|can|could|should|would)\b',
            # Common phrases
            r'\b(is|are|was|were|be|been|the|a|an)\b',
        )
        for pattern in removal_passes:
            query = re.sub(pattern, '', query, flags=re.IGNORECASE)

        # Collapse the gaps left behind by the removed words.
        return re.sub(r'\s+', ' ', query).strip()

    @staticmethod
    def broaden_query(query: str, analysis: QueryAnalysis) -> str:
        """Return a broader query to increase recall.

        With keywords available the result is "<first keyword> overview";
        otherwise it is the query with narrowing words removed.
        """
        loosened = re.sub(r'\b(specific|exactly|precisely|only|just)\b', '', query, flags=re.IGNORECASE)
        return f"{analysis.keywords[0]} overview" if analysis.keywords else loosened

    @staticmethod
    def focus_entities(analysis: QueryAnalysis) -> str:
        """Join the entities, else the first three keywords, else return ''."""
        if analysis.entities:
            return ' '.join(analysis.entities)
        if analysis.keywords:
            return ' '.join(analysis.keywords[:3])
        return ""
434
+
435
+
436
+ # Convenience function for easy integration
437
def expand_query_simple(
    query: str,
    strategy: str = "balanced",
    llm=None
) -> List[str]:
    """
    Simple function to expand a query without dealing with classes.

    Args:
        query: User's query to expand
        strategy: "quick", "balanced", or "comprehensive". Strings are
            matched case-insensitively and surrounding whitespace is
            ignored; a QueryStrategy member is accepted as well.
        llm: Optional LangChain LLM

    Returns:
        List of expanded query variations, original query first

    Raises:
        ValueError: If strategy does not name a QueryStrategy value.

    Example:
        >>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
        >>> queries[0]
        'How do I debug Python code?'
    """
    if isinstance(strategy, str):
        # Tolerate "Balanced" or " quick " coming from config files or forms.
        strategy = strategy.strip().lower()
    strategy_enum = QueryStrategy(strategy)

    expander = MultiQueryExpander(llm=llm)
    return expander.expand(query, strategy=strategy_enum).variations
462
+
463
+
464
+ # Example usage and testing
465
if __name__ == "__main__":
    # Demo script: prints the output of each component for sample queries.

    # Example 1: Simple query expansion
    print("=" * 60)
    print("Example 1: Simple Query Expansion")
    print("=" * 60)

    query = "What is machine learning?"
    queries = expand_query_simple(query, strategy="balanced")

    print(f"\nOriginal: {query}")
    print(f"\nExpanded queries ({len(queries)}):")
    for i, q in enumerate(queries, 1):
        print(f"  {i}. {q}")

    # Example 2: Complex query with full analysis
    print("\n" + "=" * 60)
    print("Example 2: Complex Query with Analysis")
    print("=" * 60)

    expander = MultiQueryExpander()
    query = "How do I compare the performance of different neural network architectures?"
    result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE)

    print(f"\nOriginal: {result.original}")
    print("\nAnalysis:")
    print(f"  Intent: {result.analysis.intent}")
    print(f"  Entities: {result.analysis.entities}")
    print(f"  Keywords: {result.analysis.keywords}")
    print(f"  Complexity: {result.analysis.complexity}")
    print(f"  Domain: {result.analysis.domain}")
    print(f"\nExpanded queries ({len(result.variations)}):")
    for i, q in enumerate(result.variations, 1):
        print(f"  {i}. {q}")

    # Example 3: Multi-hop reasoning
    print("\n" + "=" * 60)
    print("Example 3: Multi-Hop Reasoning")
    print("=" * 60)

    reasoner = MultiHopReasoner()
    analyzer = QueryAnalyzer()

    query = "Compare Python and Java for web development"
    analysis = analyzer.analyze(query)
    sub_queries = reasoner.generate_sub_queries(query, analysis)

    print(f"\nOriginal: {query}")
    print("\nSub-queries for multi-hop reasoning:")
    for i, sq in enumerate(sub_queries, 1):
        print(f"  {i}. {sq}")

    # Example 4: Fallback strategies
    print("\n" + "=" * 60)
    print("Example 4: Fallback Strategies")
    print("=" * 60)

    query = "What is the specific difference between supervised and unsupervised learning?"
    analysis = analyzer.analyze(query)

    # Show what each fallback produces for this query.
    print(f"\nOriginal: {query}")
    print("\nFallback queries:")
    print(f"  Simplified: {FallbackStrategies.simplify_query(query)}")
    print(f"  Broadened: {FallbackStrategies.broaden_query(query, analysis)}")
    print(f"  Entity-focused: {FallbackStrategies.focus_entities(analysis)}")
523
+
524
+
rag_processor.py CHANGED
@@ -6,109 +6,413 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables.history import RunnableWithMessageHistory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def create_rag_chain(retriever, get_session_history_func):
11
- """
12
- Creates an advanced Retrieval-Augmented Generation (RAG) chain with hybrid search,
13
- query rewriting, answer refinement, and conversational memory.
14
 
15
- Args:
16
- retriever: A configured LangChain retriever object.
17
- get_session_history_func: A function to get the chat history for a session.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- Returns:
20
- A LangChain runnable object representing the RAG chain with memory.
 
 
 
 
 
 
21
 
22
- Raises:
23
- ValueError: If the GROQ_API_KEY is missing.
24
- """
25
- load_dotenv()
26
- api_key = os.getenv("GROQ_API_KEY")
27
- print("key loaded")
28
- if not api_key or api_key == "your_groq_api_key_here":
29
- error_msg = "GROQ_API_KEY not found or not configured properly.\n"
30
- print("Not found key")
31
- if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
32
- error_msg += (
33
- "For Hugging Face Spaces: Set GROQ_API_KEY in your Space's Settings > Repository Secrets.\n"
34
- "Go to your Space settings and add GROQ_API_KEY as a secret variable."
35
  )
36
- else:
37
- error_msg += (
38
- "For local development: Set your GROQ API key in the .env file.\n"
39
- "Copy .env.example to .env and add your actual API key.\n"
40
- "Get your API key from: https://console.groq.com/keys"
41
- )
42
-
43
- raise ValueError(error_msg)
44
- llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- print("\nSetting up query rewriting chain...")
47
- rewrite_template = """You are an expert at rewriting user questions for a vector database.
48
- You are here to help the user with their document.
49
- Based on the chat history, reformulate the follow-up question to be a standalone question.
50
- This new query should be optimized to find the most relevant documents in a knowledge base.
51
- Do NOT answer the question, only provide the rewritten, optimized question.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  Chat History:
54
  {chat_history}
55
 
56
  Follow-up Question: {question}
57
- Standalone Question:"""
 
58
  rewrite_prompt = ChatPromptTemplate.from_messages([
59
  ("system", rewrite_template),
60
  MessagesPlaceholder(variable_name="chat_history"),
61
- ("human", "Based on our conversation, reformulate this question to be a standalone query: {question}")
62
  ])
63
  query_rewriter = rewrite_prompt | llm | StrOutputParser()
64
- print("\nSetting up main RAG chain...")
65
- rag_template = """You are CogniChat, an expert document analysis assistant.
66
- IMPORTANT RULES:
67
- 1. ONLY use information from the provided context
68
- 2. If the answer isn't in the context, say "I cannot find this information in the uploaded documents"
69
- 3. Cite specific sections when possible (e.g., "According to Section 2.3...")
70
- 4. Be concise but comprehensive
71
- Context:
72
- {context}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  rag_prompt = ChatPromptTemplate.from_messages([
74
  ("system", rag_template),
75
  MessagesPlaceholder(variable_name="chat_history"),
76
  ("human", "{question}"),
77
  ])
78
- setup_and_retrieval = RunnableParallel({
79
- "context": RunnablePassthrough.assign(
80
- rewritten_question=query_rewriter
81
- ) | (lambda x: x["rewritten_question"]) | retriever,
82
- "question": itemgetter("question"),
83
- "chat_history": itemgetter("chat_history")})
84
- conversational_rag_chain = (
85
- setup_and_retrieval
86
- | rag_prompt
87
- | llm
88
- | StrOutputParser()
89
- )
 
 
90
  chain_with_memory = RunnableWithMessageHistory(
91
  conversational_rag_chain,
92
  get_session_history_func,
93
  input_messages_key="question",
94
  history_messages_key="chat_history",
95
  )
96
- print("\nSetting up answer refinement chain...")
97
- refine_template = """You are an expert at editing and refining content.
98
- Your task is to take a given answer and improve its clarity, structure, and readability.
99
- Use formatting such as bold text, bullet points, or numbered lists where it enhances the explanation.
100
- Do not add any new information that wasn't in the original answer.
101
-
102
- Original Answer:
103
- {answer}
104
-
105
- Refined Answer:"""
106
- refine_prompt = ChatPromptTemplate.from_template(refine_template)
107
- refinement_chain = refine_prompt | llm | StrOutputParser()
108
-
109
- final_chain = (
110
- lambda input_dict: {"answer": chain_with_memory.invoke(input_dict, config=input_dict.get('config'))}
111
- ) | refinement_chain
112
 
113
- print("\nFinalizing the complete chain with memory...")
114
- return final_chain
 
 
 
6
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables.history import RunnableWithMessageHistory
9
+ from langchain_core.documents import Document
10
+ from query_expansion import expand_query_simple
11
+ from typing import List, Optional
12
+ import time
13
+
14
class GroqAPIKeyManager:
    """Round-robin holder for several Groq API keys with success/failure bookkeeping.

    Attributes set by the constructor:
        api_keys: usable keys, in the order supplied
        current_index: position of the key handed out by get_current_key
        failed_keys: keys currently marked as failed
        success_count / failure_count: per-key counters
    """

    def __init__(self, api_keys: List[str]):
        """Keep the non-empty, non-placeholder keys.

        Raises:
            ValueError: If no usable key remains after filtering.
        """
        usable = []
        for candidate in api_keys:
            if candidate and candidate != "your_groq_api_key_here":
                usable.append(candidate)
        self.api_keys = usable
        if not self.api_keys:
            raise ValueError("No valid API keys provided!")

        self.current_index = 0
        self.failed_keys = set()
        self.success_count = dict.fromkeys(self.api_keys, 0)
        self.failure_count = dict.fromkeys(self.api_keys, 0)

        print(f"API Key Manager: Loaded {len(self.api_keys)} API keys")

    def get_current_key(self) -> str:
        """Return the key at the current rotation position."""
        return self.api_keys[self.current_index]

    def mark_success(self, api_key: str):
        """Count a success for a managed key and clear its failed flag if set."""
        if api_key not in self.success_count:
            return
        self.success_count[api_key] += 1
        if api_key in self.failed_keys:
            self.failed_keys.remove(api_key)
            print(f"API Key #{self.api_keys.index(api_key) + 1} recovered!")

    def mark_failure(self, api_key: str):
        """Count a failure for a managed key and flag it as failed."""
        if api_key not in self.failure_count:
            return
        self.failure_count[api_key] += 1
        self.failed_keys.add(api_key)

    def rotate_to_next_key(self) -> bool:
        """Advance to the next key that is not flagged as failed.

        When the whole ring has been walked without finding one, the
        position reached last is kept and True is still returned, so the
        caller retries with that key. False is returned only when there
        are no keys to iterate over.
        """
        total = len(self.api_keys)

        for attempt in range(1, total + 1):
            self.current_index = (self.current_index + 1) % total
            candidate = self.api_keys[self.current_index]

            if attempt >= total:
                print(f"All keys attempted, retrying with key #{self.current_index + 1}")
                return True
            if candidate not in self.failed_keys:
                print(f"Switching to API Key #{self.current_index + 1}")
                return True

        return False

    def get_statistics(self) -> str:
        """Return one line per key with a masked key, its counters and its status."""
        lines = []
        for number, key in enumerate(self.api_keys, start=1):
            status = "FAILED" if key in self.failed_keys else "ACTIVE"
            # Short keys are hidden entirely; longer ones show the first 8
            # and the last 4 characters.
            masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
            success = self.success_count[key]
            failure = self.failure_count[key]
            lines.append(f"  Key #{number} ({masked_key}): {success} success, {failure} failures [{status}]")
        return "\n".join(lines)
69
 
 
 
 
 
70
 
71
def load_api_keys_from_hf_secrets() -> List[str]:
    """Collect Groq API keys from GROQ_API_KEY_1 .. GROQ_API_KEY_4.

    Each variable is read from the process environment. Values that are
    unset, blank, or equal to the placeholder text are skipped. Accepted
    values are returned with surrounding whitespace removed, in slot order.
    Progress is printed for every slot.
    """
    print("Loading API keys from Hugging Face Secrets...")

    collected = []
    for slot in range(1, 5):
        secret_name = f"GROQ_API_KEY_{slot}"
        try:
            raw_value = os.getenv(secret_name)

            if not (raw_value and raw_value.strip() and raw_value != "your_groq_api_key_here"):
                print(f"  Not found or empty: {secret_name}")
                continue

            collected.append(raw_value.strip())
            print(f"  Loaded: {secret_name}")
        except Exception as e:
            print(f"  Error loading {secret_name}: {str(e)}")
    return collected
89
+
90
 
91
def create_llm_with_fallback(
    api_key_manager: GroqAPIKeyManager,
    model_name: str,
    temperature: float,
    max_retries: int = 3
) -> ChatGroq:
    """Create a ChatGroq client, rotating API keys until one works.

    Each attempt builds a client with the manager's current key and sends
    a short probe prompt. On success the key is marked as good and the
    client is returned. On failure the key is marked as failed, the error
    is classified for the log by keywords in its message, and the manager
    is asked for the next key before retrying after a one-second pause.

    Args:
        api_key_manager: Source of keys; also receives success/failure marks.
        model_name: Groq model name passed to ChatGroq.
        temperature: Sampling temperature passed to ChatGroq.
        max_retries: Maximum number of attempts.

    Returns:
        A ChatGroq instance that answered the probe prompt.

    Raises:
        ValueError: When every attempt failed or no key could be selected.
            The last underlying error is attached as ``__cause__``.
    """
    for attempt in range(max_retries):
        current_key = api_key_manager.get_current_key()

        try:
            llm = ChatGroq(
                model_name=model_name,
                api_key=current_key,
                temperature=temperature
            )
            # Probe call: the constructor alone does not prove the key works.
            llm.invoke("test")
            api_key_manager.mark_success(current_key)
            return llm

        except Exception as e:
            error_msg = str(e).lower()
            api_key_manager.mark_failure(current_key)
            key_number = api_key_manager.current_index + 1
            if "rate" in error_msg or "limit" in error_msg:
                print(f"  Rate limit hit on API Key #{key_number}")
            elif "auth" in error_msg or "api" in error_msg:
                print(f"  Authentication failed on API Key #{key_number}")
            else:
                print(f"  Error with API Key #{key_number}: {str(e)[:50]}")

            if attempt < max_retries - 1:
                if api_key_manager.rotate_to_next_key():
                    print(f"  Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
                    time.sleep(1)
                else:
                    # Chain the cause so the real failure stays in the traceback.
                    raise ValueError("All API keys failed!") from e
            else:
                raise ValueError(f"Failed to initialize LLM after {max_retries} attempts") from e

    # Only reached when max_retries is zero or negative.
    raise ValueError("Failed to create LLM with any available API key")
130
+
131
+
132
def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
    """Wrap a retriever so each query is expanded into several variations.

    Args:
        base_retriever: Object exposing ``invoke(query)`` that returns
            documents with a ``page_content`` attribute.
        llm: Language model handed to ``expand_query_simple``.
        strategy: Expansion strategy name handed to ``expand_query_simple``.

    Returns:
        A function ``multi_query_retrieve(query)`` that runs every query
        variation through ``base_retriever`` and returns the combined
        documents, de-duplicated by exact ``page_content`` and kept in
        first-seen order.
    """
    def multi_query_retrieve(query: str) -> List[Document]:
        query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
        if not query_variations:
            # Expansion produced nothing; still search with the user's query
            # rather than silently returning no documents.
            query_variations = [query]

        all_docs = []
        # Track the text itself rather than hash(text): two different chunks
        # can share a hash value, which would wrongly discard one of them.
        seen_content = set()
        for i, query_var in enumerate(query_variations):
            try:
                docs = base_retriever.invoke(query_var)
                for doc in docs:
                    content = doc.page_content
                    if content not in seen_content:
                        seen_content.add(content)
                        all_docs.append(doc)
            except Exception as e:
                # Best effort: a failing variation must not abort the others.
                print(f" Query Expansion Error (Query {i+1}): {str(e)[:50]}")
                continue
        print(f" Query Expansion: Retrieved {len(all_docs)} unique documents.")
        return all_docs
    return multi_query_retrieve
151
+
152
+
153
def get_system_prompt(temperature: float) -> str:
    """Select the system prompt template that matches the sampling temperature.

    Three templates are available, each containing a single ``{context}``
    placeholder for the retrieved documents:

    * ``temperature <= 0.4`` -- factual, strictly structured answers.
    * ``temperature <= 0.8`` -- balanced, engaging answers.
    * anything else -- creative presentation with clarity as the priority.

    Args:
        temperature: The LLM sampling temperature in use.

    Returns:
        The prompt template text for the matching tier.
    """
    factual_prompt = """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.

RESPONSE GUIDELINES:

**Structure & Formatting:**
- Start with a direct answer to the question
- Use **bold** for key terms, important concepts, and technical terminology
- Use bullet points (•) for lists, features, or multiple items
- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
- Use ### Headers to organize different sections or topics
- Add blank lines between sections for readability

**Source Citation:**
- Always cite information using: [Source: filename, Page: X] and cite it at the end of the entire answer only
- Place citations at the end of your final answer only
- Do not cite sources within the body of your answer
- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]

**Completeness:**
- Provide thorough, detailed answers using ALL relevant information from context
- Summarize and properly elaborate each point for increased clarity
- If the question has multiple parts, address each part clearly

**Accuracy:**
- ONLY use information from the provided context documents below
- If information is incomplete, state what IS available and what ISN'T
- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
- Never make assumptions or add information not in the context

---

{context}

---

Now answer the following question comprehensively using the context above:"""

    balanced_prompt = """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.

RESPONSE GUIDELINES:

**Communication Style:**
- Present information in a clear, engaging manner
- Use **bold** for emphasis on important concepts
- Balance structure with natural flow
- Make complex topics accessible and interesting

**Content Approach:**
- Ground your response firmly in the provided context
- Add helpful explanations and connections between concepts
- Use analogies or examples when they help clarify ideas (but keep them brief)
- Organize information logically with headers (###) and lists where appropriate

**Source Attribution:**
- Cite sources at the end: [Source: filename, Page: X]
- Be transparent about what the documents do and don't contain

**Accuracy:**
- Base your answer on the context documents provided
- If information is partial, explain what's available
- Acknowledge gaps: "The documents don't cover this aspect"

---

{context}

---

Now answer the following question in an engaging yet accurate way:"""

    # Creative tier: presentation may be livelier, but clarity stays first.
    creative_prompt = """You are CogniChat, a creative document analyst who makes complex information clear, memorable, and engaging.

YOUR CORE MISSION: **CLARITY FIRST, CREATIVITY SECOND**

Make information easier to understand, not harder. Your creativity should illuminate, not obscure.

**CREATIVE CLARITY PRINCIPLES:**

1. **Simplify, Don't Complicate**
- Break down complex concepts into simple, digestible parts
- Use everyday language alongside technical terms
- Explain jargon immediately in plain English
- Short sentences for complex ideas, varied length for rhythm

2. **Smart Use of Examples & Analogies** (Use Sparingly!)
- Only use analogies when they genuinely make something clearer
- Keep analogies simple and relatable (everyday objects/experiences)
- Never use metaphors that require explanation themselves
- If you can explain it directly in simple terms, do that instead

3. **Engaging Structure**
- Start with the core answer in one clear sentence
- Use **bold** to highlight key takeaways
- Break information into logical chunks with ### headers
- Use bullet points for clarity, not decoration
- Add brief transition phrases to connect ideas smoothly

4. **Conversational Yet Precise**
- Write like you're explaining to a smart friend
- Use "you" and active voice to engage readers
- Ask occasional rhetorical questions only if they aid understanding
- Vary sentence length to maintain interest
- Use emojis sparingly (1-2 max) and only where they add clarity

5. **Visual Clarity**
- Strategic use of formatting: **bold** for key terms, *italics* for emphasis
- White space between sections for easy scanning
- Progressive disclosure: simple concepts first, details after
- Numbered lists for sequences, bullets for related items

**WHAT TO AVOID:**
- Flowery or overly descriptive language
- Complex metaphors that need their own explanation
- Long narrative storytelling that buries the facts
- Multiple rhetorical questions in a row
- Overuse of emojis or exclamation points
- Making simple things sound complicated

**ACCURACY BOUNDARIES:**
- Creative explanation and presentation of facts
- Simple, helpful examples from common knowledge
- Reorganizing information for better understanding
- Never invent facts not in the documents
- Don't contradict source material
- If info is missing, say so clearly and briefly

**Source Attribution:**
- End with: [Source: filename, Page: X]
- Keep it simple and clear

---

{context}

---

Now, explain the answer clearly and engagingly. Remember: if your grandmother couldn't understand it, simplify more:"""

    # Thresholds are checked low-to-high; whatever matches neither bound
    # falls through to the creative tier.
    if temperature <= 0.4:
        return factual_prompt
    if temperature <= 0.8:
        return balanced_prompt
    return creative_prompt
295
+
296
+
297
+
298
def create_rag_chain(
    retriever,
    get_session_history_func,
    enable_query_expansion=True,
    expansion_strategy="balanced",
    model_name: str = "moonshotai/kimi-k2-instruct",
    temperature: float = 0.2,
    api_keys: Optional[List[str]] = None
):
    """Assemble the conversational RAG chain with per-session message history.

    The chain rewrites the user's question into a standalone query using
    the chat history, retrieves documents for it (optionally through
    multi-query expansion), formats them into a context block, and asks the
    LLM to answer with the temperature-specific system prompt.

    Args:
        retriever: Base retriever run on the rewritten question.
        get_session_history_func: Callable handed to
            ``RunnableWithMessageHistory`` to obtain a session's history.
        enable_query_expansion: When true, wrap ``retriever`` with
            ``create_multi_query_retriever``.
        expansion_strategy: Strategy name passed to the query expander.
        model_name: Groq model identifier.
        temperature: Sampling temperature; also selects the system prompt.
        api_keys: Groq API keys to use. When ``None`` they are loaded via
            ``load_api_keys_from_hf_secrets``.

    Returns:
        A ``(chain_with_memory, api_key_manager)`` tuple.

    Raises:
        ValueError: If no API keys are available, or if no LLM could be
            initialised with any of them.
    """
    if api_keys is None:
        api_keys = load_api_keys_from_hf_secrets()

    if not api_keys:
        raise ValueError(
            "No valid API keys found! Please set GROQ_API_KEY or GROQ_API_KEY_1, "
            "GROQ_API_KEY_2, GROQ_API_KEY_3, GROQ_API_KEY_4 in your .env file"
        )

    api_key_manager = GroqAPIKeyManager(api_keys)

    print(f" RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")

    # Same thresholds as get_system_prompt, used here only for the log line.
    if temperature <= 0.4:
        creativity_mode = "FACTUAL & STRUCTURED"
    elif temperature <= 0.8:
        creativity_mode = "BALANCED & ENGAGING"
    else:
        creativity_mode = "CREATIVE & STORYTELLING"
    print(f"Creativity Mode: {creativity_mode}")

    llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
    print(f"LLM initialized with API Key #{api_key_manager.current_index + 1}")

    if enable_query_expansion:
        print(f"RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
        enhanced_retriever = create_multi_query_retriever(
            base_retriever=retriever,
            llm=llm,
            strategy=expansion_strategy
        )
    else:
        enhanced_retriever = retriever

    # NOTE(review): the history and question appear twice in the rewrite
    # prompt -- once interpolated into this system text and once via the
    # MessagesPlaceholder / human message below. Left as-is because removing
    # either would change what the model sees; confirm which one is intended.
    rewrite_template = """You are an expert at optimizing search queries for document retrieval.

Given the conversation history and a follow-up question, create a comprehensive standalone question that:
1. Incorporates all relevant context from the chat history
2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
3. Includes key technical terms and concepts that would help find relevant documents
4. Maintains the original intent, specificity, and detail level
5. If the question asks for comparison or multiple items, ensure all items are in the query

Chat History:
{chat_history}

Follow-up Question: {question}

Optimized Standalone Question:"""
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}")
    ])
    query_rewriter = rewrite_prompt | llm | StrOutputParser()

    def format_docs(docs):
        """Render retrieved documents as one delimited context string."""
        if not docs:
            return "No relevant documents found in the knowledge base."

        formatted_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get('source', 'Unknown Document')
            page = doc.metadata.get('page', 'N/A')
            rerank_score = doc.metadata.get('rerank_score')
            content = doc.page_content.strip()

            doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
            metadata_line = f"Source: {source} | Page: {page}"
            # Compare against None: a score of exactly 0.0 is a real score
            # and must not be hidden by a truthiness check.
            if rerank_score is not None:
                metadata_line += f" | Relevance: {rerank_score:.3f}"

            formatted_parts.append(
                f"{doc_header}\n"
                f"{metadata_line}\n"
                f"{'-'*60}\n"
                f"{content}\n"
            )
        return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)

    rag_template = get_system_prompt(temperature)

    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    rewriter_input = RunnableParallel({
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    })

    # question + history -> standalone query -> documents -> context text
    retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs

    conversational_rag_chain = RunnableParallel({
        "context": retrieval_chain,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    }) | rag_prompt | llm | StrOutputParser()

    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("RAG: Chain created successfully.")
    print("\n" + api_key_manager.get_statistics())

    return chain_with_memory, api_key_manager
templates/index.html CHANGED
@@ -1,615 +1,885 @@
1
  <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>CogniChat - Chat with your Documents</title>
7
- <script src="https://cdn.tailwindcss.com"></script>
8
- <link rel="preconnect" href="https://fonts.googleapis.com">
9
- <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
- <link href="https://fonts.googleapis.com/css2?family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
11
- <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
12
- <style>
13
- :root {
14
- --background: #f0f4f9;
15
- --foreground: #1f1f1f;
16
- --primary: #1a73e8;
17
- --primary-hover: #1867cf;
18
- --card: #ffffff;
19
- --card-border: #dadce0;
20
- --input-bg: #e8f0fe;
21
- --user-bubble: #d9e7ff;
22
- --bot-bubble: #f1f3f4;
23
- }
24
-
25
- /* Dark mode styles */
26
- .dark {
27
- --background: #202124;
28
- --foreground: #e8eaed;
29
- --primary: #8ab4f8;
30
- --primary-hover: #99bdfa;
31
- --card: #303134;
32
- --card-border: #5f6368;
33
- --input-bg: #303134;
34
- --user-bubble: #3c4043;
35
- --bot-bubble: #3c4043;
36
- }
37
-
38
- body {
39
- font-family: 'Google Sans', 'Roboto', sans-serif;
40
- background-color: var(--background);
41
- color: var(--foreground);
42
- overflow: hidden;
43
- }
44
-
45
- #chat-window::-webkit-scrollbar { width: 8px; }
46
- #chat-window::-webkit-scrollbar-track { background: transparent; }
47
- #chat-window::-webkit-scrollbar-thumb { background-color: #bdc1c6; border-radius: 20px; }
48
- .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }
49
-
50
- .drop-zone--over {
51
- border-color: var(--primary);
52
- box-shadow: 0 0 15px rgba(26, 115, 232, 0.3);
53
- }
54
-
55
- /* Loading Spinner */
56
- .loader {
57
- width: 48px;
58
- height: 48px;
59
- border: 3px solid var(--card-border);
60
- border-radius: 50%;
61
- display: inline-block;
62
- position: relative;
63
- box-sizing: border-box;
64
- animation: rotation 1s linear infinite;
65
- }
66
- .loader::after {
67
- content: '';
68
- box-sizing: border-box;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  position: absolute;
70
- left: 50%;
71
- top: 50%;
72
- transform: translate(-50%, -50%);
73
- width: 56px;
74
- height: 56px;
75
- border-radius: 50%;
76
- border: 3px solid;
77
- border-color: var(--primary) transparent;
78
- }
79
-
80
- @keyframes rotation {
81
- 0% { transform: rotate(0deg); }
82
- 100% { transform: rotate(360deg); }
83
- }
84
-
85
- /* Typing Indicator Animation */
86
- .typing-indicator span {
87
- height: 10px;
88
- width: 10px;
89
- background-color: #9E9E9E;
90
- border-radius: 50%;
91
- display: inline-block;
92
- animation: bounce 1.4s infinite ease-in-out both;
93
- }
94
- .typing-indicator span:nth-child(1) { animation-delay: -0.32s; }
95
- .typing-indicator span:nth-child(2) { animation-delay: -0.16s; }
96
- @keyframes bounce {
97
- 0%, 80%, 100% { transform: scale(0); }
98
- 40% { transform: scale(1.0); }
99
- }
100
-
101
- /* Enhanced Markdown Styling for better readability and aesthetics */
102
- .markdown-content p {
103
- margin-bottom: 1rem;
104
- line-height: 1.75;
105
- }
106
- .markdown-content h1, .markdown-content h2, .markdown-content h3, .markdown-content h4 {
107
- font-family: 'Google Sans', sans-serif;
108
- font-weight: 700;
109
- margin-top: 1.75rem;
110
- margin-bottom: 1rem;
111
- line-height: 1.3;
112
- }
113
- .markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
114
- .markdown-content h2 { font-size: 1.5em; }
115
- .markdown-content h3 { font-size: 1.25em; }
116
- .markdown-content h4 { font-size: 1.1em; }
117
- .markdown-content ul, .markdown-content ol {
118
- padding-left: 1.75rem;
119
- margin-bottom: 1rem;
120
- }
121
- .markdown-content li {
122
- margin-bottom: 0.5rem;
123
- }
124
- .dark .markdown-content ul > li::marker { color: var(--primary); }
125
- .markdown-content ul > li::marker { color: var(--primary); }
126
- .markdown-content a {
127
- color: var(--primary);
128
- text-decoration: none;
129
- font-weight: 500;
130
- border-bottom: 1px solid transparent;
131
- transition: all 0.2s ease-in-out;
132
- }
133
- .markdown-content a:hover {
134
- border-bottom-color: var(--primary-hover);
135
- }
136
- .markdown-content blockquote {
137
- margin: 1.5rem 0;
138
- padding-left: 1.5rem;
139
- border-left: 4px solid var(--card-border);
140
- color: #6c757d;
141
- font-style: italic;
142
- }
143
- .dark .markdown-content blockquote {
144
- color: #adb5bd;
145
- }
146
- .markdown-content hr {
147
- border: none;
148
- border-top: 1px solid var(--card-border);
149
- margin: 2rem 0;
150
- }
151
- .markdown-content table {
152
- width: 100%;
153
- border-collapse: collapse;
154
- margin: 1.5rem 0;
155
- font-size: 0.9em;
156
- box-shadow: 0 1px 3px rgba(0,0,0,0.05);
157
- border-radius: 8px;
158
- overflow: hidden;
159
- }
160
- .markdown-content th, .markdown-content td {
161
  border: 1px solid var(--card-border);
162
- padding: 0.75rem 1rem;
163
- text-align: left;
164
- }
165
- .markdown-content th {
166
- background-color: var(--bot-bubble);
167
- font-weight: 500;
168
- }
169
- .markdown-content code {
170
- background-color: rgba(0,0,0,0.05);
171
- padding: 0.2rem 0.4rem;
172
- border-radius: 0.25rem;
173
- font-family: 'Roboto Mono', monospace;
174
- font-size: 0.9em;
175
- }
176
- .dark .markdown-content code {
177
- background-color: rgba(255,255,255,0.1);
178
- }
179
- .markdown-content pre {
180
- position: relative;
181
- background-color: #f8f9fa;
182
- border: 1px solid var(--card-border);
183
- border-radius: 0.5rem;
184
- margin-bottom: 1rem;
185
- }
186
- .dark .markdown-content pre {
187
- background-color: #2e2f32;
188
- }
189
- .markdown-content pre code {
190
- background: none;
191
- padding: 1rem;
192
- display: block;
193
- overflow-x: auto;
194
- }
195
- .markdown-content pre .copy-code-btn {
196
- position: absolute;
197
- top: 0.5rem;
198
- right: 0.5rem;
199
- background-color: #e8eaed;
200
- border: 1px solid #dadce0;
201
- color: #5f6368;
202
- padding: 0.3rem 0.6rem;
203
- border-radius: 0.25rem;
204
- cursor: pointer;
205
- opacity: 0;
206
- transition: opacity 0.2s;
207
  font-size: 0.8em;
208
- }
209
- .dark .markdown-content pre .copy-code-btn {
210
- background-color: #3c4043;
211
- border-color: #5f6368;
212
- color: #e8eaed;
213
- }
214
- .markdown-content pre:hover .copy-code-btn {
215
- opacity: 1;
216
- }
217
-
218
- /* Spinner for the TTS button */
219
- .tts-button-loader {
220
- width: 16px;
221
- height: 16px;
222
- border: 2px solid currentColor; /* Use button's text color */
223
- border-radius: 50%;
224
- display: inline-block;
225
- box-sizing: border-box;
226
- animation: rotation 0.8s linear infinite;
227
- border-bottom-color: transparent; /* Makes it a half circle spinner */
228
- }
229
- </style>
230
- </head>
231
- <body class="w-screen h-screen dark">
232
- <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
233
- <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
234
- <header class="text-center p-4 border-b border-[var(--card-border)] flex-shrink-0">
235
- <h1 class="text-xl font-medium">Chat with your Docs</h1>
236
- <p id="chat-filename" class="text-xs text-gray-500 dark:text-gray-400 mt-1"></p>
237
- </header>
238
- <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
239
- <div id="chat-content" class="max-w-4xl mx-auto space-y-8">
240
- </div>
241
- </div>
242
- <div class="p-4 flex-shrink-0 bg-[var(--background)] border-t border-[var(--card-border)]">
243
- <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-sm border border-transparent focus-within:border-[var(--primary)] transition-colors">
244
- <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
245
- <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed disabled:bg-gray-500" title="Send">
246
- <svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M3.49941 11.5556L11.555 3.5L12.4438 4.38889L6.27721 10.5556H21.9994V11.5556H6.27721L12.4438 17.7222L11.555 18.6111L3.49941 10.5556V11.5556Z" transform="rotate(180, 12.7497, 11.0556)" fill="currentColor"></path></svg>
247
- </button>
248
- </form>
249
- </div>
250
- </div>
251
-
252
- <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
253
- <div class="text-center">
254
- <h1 class="text-5xl font-medium mb-4">Upload docs to chat</h1>
255
- <div id="drop-zone" class="w-full max-w-lg text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer bg-[var(--card)] hover:border-[var(--primary)]">
256
- <input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx,.jpg,.jpeg,.png" multiple title="input">
257
- <svg class="mx-auto h-12 w-12 text-gray-400" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" ><path stroke-linecap="round" stroke-linejoin="round" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
258
- <p class="mt-4 text-sm font-medium">Drag & drop files or click to upload</p>
259
- <p id="file-name" class="mt-2 text-xs text-gray-500"></p>
260
- </div>
261
- </div>
262
- </div>
263
-
264
- <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50 text-center p-4">
265
- <div class="loader"></div>
266
- <p id="loading-text" class="mt-6 text-sm font-medium"></p>
267
- <p id="loading-subtext" class="mt-2 text-xs text-gray-500 dark:text-gray-400"></p>
268
- </div>
269
- </main>
270
-
271
- <script>
272
- document.addEventListener('DOMContentLoaded', () => {
273
- const uploadContainer = document.getElementById('upload-container');
274
- const chatContainer = document.getElementById('chat-container');
275
- const dropZone = document.getElementById('drop-zone');
276
- const fileUploadInput = document.getElementById('file-upload');
277
- const fileNameSpan = document.getElementById('file-name');
278
- const loadingOverlay = document.getElementById('loading-overlay');
279
- const loadingText = document.getElementById('loading-text');
280
- const loadingSubtext = document.getElementById('loading-subtext');
281
-
282
- const chatForm = document.getElementById('chat-form');
283
- const chatInput = document.getElementById('chat-input');
284
- const chatSubmitBtn = document.getElementById('chat-submit-btn');
285
- const chatWindow = document.getElementById('chat-window');
286
- const chatContent = document.getElementById('chat-content');
287
- const chatFilename = document.getElementById('chat-filename');
288
-
289
- let sessionId = null;
290
- const storedSessionId = sessionStorage.getItem('cognichat_session_id');
291
- if (storedSessionId) {
292
- sessionId = storedSessionId;
293
- console.debug('Restored session ID from storage:', sessionId);
294
- }
295
-
296
- // --- File Upload Logic ---
297
- dropZone.addEventListener('click', () => fileUploadInput.click());
298
-
299
- ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
300
- dropZone.addEventListener(eventName, preventDefaults, false);
301
- document.body.addEventListener(eventName, preventDefaults, false);
302
- });
303
-
304
- ['dragenter', 'dragover'].forEach(eventName => {
305
- dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over'));
306
- });
307
- ['dragleave', 'drop'].forEach(eventName => {
308
- dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over'));
309
- });
310
-
311
- dropZone.addEventListener('drop', (e) => {
312
- const files = e.dataTransfer.files;
313
- if (files.length > 0) handleFiles(files);
314
- });
315
-
316
- fileUploadInput.addEventListener('change', (e) => {
317
- if (e.target.files.length > 0) handleFiles(e.target.files);
318
- });
319
-
320
- function preventDefaults(e) { e.preventDefault(); e.stopPropagation(); }
321
-
322
- async function handleFiles(files) {
323
- const formData = new FormData();
324
- let fileNames = [];
325
- for (const file of files) {
326
- formData.append('file', file);
327
- fileNames.push(file.name);
328
- }
329
-
330
- fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
331
- await uploadAndProcessFiles(formData, fileNames);
332
- }
333
-
334
- async function uploadAndProcessFiles(formData, fileNames) {
335
- loadingOverlay.classList.remove('hidden');
336
- loadingText.textContent = `Processing ${fileNames.length} document(s)...`;
337
- loadingSubtext.textContent = "🤓Creating a knowledge base may take a minute or two. So please hold on tight";
338
-
339
- try {
340
- const response = await fetch('/upload', { method: 'POST', body: formData });
341
- const result = await response.json();
342
-
343
- if (!response.ok) throw new Error(result.message || 'Unknown error occurred.');
344
- if (result.session_id) {
345
- sessionId = result.session_id;
346
- sessionStorage.setItem('cognichat_session_id', sessionId);
347
- console.debug('Stored session ID from upload:', sessionId);
348
- } else {
349
- console.warn('Upload response missing session_id field.');
350
- }
351
-
352
- chatFilename.textContent = `Chatting with: ${result.filename}`;
353
- uploadContainer.classList.add('hidden');
354
- chatContainer.classList.remove('hidden');
355
- appendMessage("I've analyzed your documents. What would you like to know?", "bot");
356
-
357
- } catch (error) {
358
- console.error('Upload error:', error);
359
- alert(`Error: ${error.message}`);
360
- } finally {
361
- loadingOverlay.classList.add('hidden');
362
- loadingSubtext.textContent = '';
363
- fileNameSpan.textContent = '';
364
- fileUploadInput.value = '';
365
- }
366
- }
367
-
368
- // --- Chat Logic ---
369
- chatForm.addEventListener('submit', async (e) => {
370
- e.preventDefault();
371
- const question = chatInput.value.trim();
372
- if (!question) return;
373
-
374
- appendMessage(question, 'user');
375
- chatInput.value = '';
376
- chatInput.disabled = true;
377
- chatSubmitBtn.disabled = true;
378
-
379
- const typingIndicator = showTypingIndicator();
380
- let botMessageContainer = null;
381
- let contentDiv = null;
382
-
383
- try {
384
- const requestBody = { question: question };
385
- if (sessionId) {
386
- requestBody.session_id = sessionId;
387
- }
388
-
389
- const response = await fetch('/chat', {
390
- method: 'POST',
391
- headers: { 'Content-Type': 'application/json' },
392
- body: JSON.stringify(requestBody),
393
- });
394
-
395
- if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
396
-
397
- // ============================ MODIFICATION START ==============================
398
- // Parse the JSON response instead of reading a stream
399
- const result = await response.json();
400
- const answer = result.answer; // Extract the 'answer' field
401
-
402
- if (!answer) {
403
- throw new Error("Received an empty or invalid response from the server.");
404
- }
405
-
406
- typingIndicator.remove();
407
- botMessageContainer = appendMessage('', 'bot');
408
- contentDiv = botMessageContainer.querySelector('.markdown-content');
409
-
410
- // Use the extracted answer for rendering
411
- contentDiv.innerHTML = marked.parse(answer);
412
- contentDiv.querySelectorAll('pre').forEach(addCopyButton);
413
- scrollToBottom(); // Scroll after content is added
414
-
415
- // Use the extracted answer for TTS
416
- addTextToSpeechControls(botMessageContainer, answer);
417
- // ============================ MODIFICATION END ==============================
418
-
419
- } catch (error) {
420
- console.error('Chat error:', error);
421
- if (typingIndicator) typingIndicator.remove();
422
- if (contentDiv) {
423
- contentDiv.innerHTML = `<p class="text-red-500">Error: ${error.message}</p>`;
424
- } else {
425
- appendMessage(`Error: ${error.message}`, 'bot');
426
- }
427
- } finally {
428
- chatInput.disabled = false;
429
- chatSubmitBtn.disabled = false;
430
- chatInput.focus();
431
- }
432
- });
433
-
434
- // --- UI Helper Functions ---
435
-
436
- function appendMessage(text, sender) {
437
- const messageWrapper = document.createElement('div');
438
- messageWrapper.className = `flex items-start gap-4`;
439
-
440
- const iconSVG = sender === 'user'
441
- ? `<div class="bg-blue-100 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1"><svg class="w-5 h-5 text-blue-600 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
442
- : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
443
-
444
- const messageBubble = document.createElement('div');
445
- messageBubble.className = `flex-1 pt-1`;
446
-
447
- const senderName = document.createElement('p');
448
- senderName.className = 'font-medium text-sm mb-1';
449
- senderName.textContent = sender === 'user' ? 'You' : 'CogniChat';
450
-
451
- const contentDiv = document.createElement('div');
452
- contentDiv.className = 'text-base markdown-content';
453
- // Only parse if text is not empty
454
- if (text) {
455
- contentDiv.innerHTML = marked.parse(text);
456
- }
457
-
458
- const controlsContainer = document.createElement('div');
459
- controlsContainer.className = 'tts-controls mt-2';
460
-
461
- messageBubble.appendChild(senderName);
462
- messageBubble.appendChild(contentDiv);
463
- messageBubble.appendChild(controlsContainer);
464
- messageWrapper.innerHTML = iconSVG;
465
- messageWrapper.appendChild(messageBubble);
466
-
467
- chatContent.appendChild(messageWrapper);
468
- scrollToBottom();
469
-
470
- return messageBubble;
471
- }
472
-
473
- function showTypingIndicator() {
474
- const indicatorWrapper = document.createElement('div');
475
- indicatorWrapper.className = `flex items-start gap-4`;
476
- indicatorWrapper.id = 'typing-indicator';
477
-
478
- const iconSVG = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
479
-
480
- const messageBubble = document.createElement('div');
481
- messageBubble.className = 'flex-1 pt-1';
482
-
483
- const senderName = document.createElement('p');
484
- senderName.className = 'font-medium text-sm mb-1';
485
- senderName.textContent = 'CogniChat is thinking...';
486
-
487
- const indicator = document.createElement('div');
488
- indicator.className = 'typing-indicator';
489
- indicator.innerHTML = '<span></span><span></span><span></span>';
490
-
491
- messageBubble.appendChild(senderName);
492
- messageBubble.appendChild(indicator);
493
- indicatorWrapper.innerHTML = iconSVG;
494
- indicatorWrapper.appendChild(messageBubble);
495
-
496
- chatContent.appendChild(indicatorWrapper);
497
- scrollToBottom();
498
-
499
- return indicatorWrapper;
500
- }
501
-
502
- function scrollToBottom() {
503
- chatWindow.scrollTo({
504
- top: chatWindow.scrollHeight,
505
- behavior: 'smooth'
506
- });
507
- }
508
-
509
- function addCopyButton(pre) {
510
- const button = document.createElement('button');
511
- button.className = 'copy-code-btn';
512
- button.textContent = 'Copy';
513
- pre.appendChild(button);
514
-
515
- button.addEventListener('click', () => {
516
- const code = pre.querySelector('code').innerText;
517
- navigator.clipboard.writeText(code).then(() => {
518
- button.textContent = 'Copied!';
519
- setTimeout(() => button.textContent = 'Copy', 2000);
520
- });
521
- });
522
- }
523
-
524
- // --- Text-to-Speech Logic ---
525
- let currentAudio = null;
526
- let currentPlayingButton = null;
527
-
528
- const playIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M8 5v14l11-7z"/></svg>`;
529
- const pauseIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M6 19h4V5H6v14zm8-14v14h4V5h-4z"/></svg>`;
530
-
531
-
532
- function addTextToSpeechControls(messageBubble, text) {
533
- const ttsControls = messageBubble.querySelector('.tts-controls');
534
- if (text.trim().length > 0) {
535
- const speakButton = document.createElement('button');
536
- speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
537
- speakButton.title = 'Listen to this message';
538
- speakButton.setAttribute('data-state', 'play');
539
- speakButton.innerHTML = `${playIconSVG} <span>Play</span>`;
540
- ttsControls.appendChild(speakButton);
541
- speakButton.addEventListener('click', () => handleTTS(text, speakButton));
542
- }
543
- }
544
-
545
- async function handleTTS(text, button) {
546
- if (button === currentPlayingButton) {
547
- if (currentAudio && !currentAudio.paused) {
548
- currentAudio.pause();
549
- button.setAttribute('data-state', 'paused');
550
- button.innerHTML = `${playIconSVG} <span>Play</span>`;
551
- } else if (currentAudio && currentAudio.paused) {
552
- currentAudio.play();
553
- button.setAttribute('data-state', 'playing');
554
- button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
555
- }
556
- return;
557
- }
558
-
559
- resetAllSpeakButtons();
560
-
561
- currentPlayingButton = button;
562
- button.setAttribute('data-state', 'loading');
563
- button.innerHTML = `<div class="tts-button-loader"></div> <span>Loading...</span>`;
564
- button.disabled = true;
565
-
566
- try {
567
- const response = await fetch('/tts', {
568
- method: 'POST',
569
- headers: { 'Content-Type': 'application/json' },
570
- body: JSON.stringify({ text: text })
571
- });
572
- if (!response.ok) throw new Error('Failed to generate audio.');
573
-
574
- const blob = await response.blob();
575
- const audioUrl = URL.createObjectURL(blob);
576
- currentAudio = new Audio(audioUrl);
577
- currentAudio.play();
578
-
579
- button.setAttribute('data-state', 'playing');
580
- button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
581
-
582
- currentAudio.onended = () => {
583
- button.setAttribute('data-state', 'play');
584
- button.innerHTML = `${playIconSVG} <span>Play</span>`;
585
- currentAudio = null;
586
- currentPlayingButton = null;
587
- };
588
-
589
- } catch (error) {
590
- console.error('TTS Error:', error);
591
- button.setAttribute('data-state', 'error');
592
- button.innerHTML = `${playIconSVG} <span>Error</span>`;
593
- alert('Failed to play audio. Please try again.');
594
- resetAllSpeakButtons();
595
- } finally {
596
- button.disabled = false;
597
- }
598
- }
599
-
600
- function resetAllSpeakButtons() {
601
- document.querySelectorAll('.speak-btn').forEach(btn => {
602
- btn.setAttribute('data-state', 'play');
603
- btn.innerHTML = `${playIconSVG} <span>Play</span>`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  btn.disabled = false;
605
- });
606
- if (currentAudio) {
607
- currentAudio.pause();
608
- currentAudio = null;
609
- }
610
- currentPlayingButton = null;
611
- }
612
- });
613
- </script>
614
- </body>
615
- </html>
 
 
 
 
 
 
 
 
 
 
1
  <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>CogniChat - Chat with your Documents</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ <link rel="preconnect" href="https://fonts.googleapis.com">
9
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
11
+ <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
12
+ <style>
13
+ :root {
14
+ --background: #f0f4f9;
15
+ --foreground: #1f1f1f;
16
+ --primary: #1a73e8;
17
+ --primary-hover: #1867cf;
18
+ --card: #ffffff;
19
+ --card-border: #dadce0;
20
+ --input-bg: #e8f0fe;
21
+ --user-bubble: #d9e7ff;
22
+ --bot-bubble: #f1f3f4;
23
+ --select-bg: #ffffff;
24
+ --select-border: #dadce0;
25
+ --select-text: #1f1f1f;
26
+ }
27
+
28
+ .dark {
29
+ --background: #111827; /* Darker background */
30
+ --foreground: #e5e7eb;
31
+ --primary: #3b82f6; /* Adjusted primary blue */
32
+ --primary-hover: #60a5fa; /* Lighter hover blue */
33
+ --card: #1f2937; /* Dark card background */
34
+ --card-border: #4b5563; /* Greyer border */
35
+ --input-bg: #374151; /* Darker input background */
36
+ --user-bubble: #374151; /* Darker user bubble */
37
+ --bot-bubble: #374151; /* Darker bot bubble */
38
+ --select-bg: #374151;
39
+ --select-border: #6b7280;
40
+ --select-text: #f3f4f6;
41
+ --code-bg: #2d2d2d; /* Specific background for code blocks */
42
+ --code-text: #d4d4d4; /* Light grey text for code */
43
+ --copy-btn-bg: #4a4a4a;
44
+ --copy-btn-hover-bg: #5a5a5a;
45
+ --copy-btn-text: #e0e0e0;
46
+ }
47
+
48
+ body {
49
+ font-family: 'Inter', 'Google Sans', 'Roboto', sans-serif;
50
+ background-color: var(--background);
51
+ color: var(--foreground);
52
+ overflow: hidden; /* Prevent body scroll */
53
+ }
54
+
55
+ #chat-window {
56
+ scroll-behavior: smooth; /* Ensure smooth programmatic scroll */
57
+ }
58
+ #chat-window::-webkit-scrollbar { width: 8px; }
59
+ #chat-window::-webkit-scrollbar-track { background: transparent; }
60
+ #chat-window::-webkit-scrollbar-thumb { background-color: #4b5563; border-radius: 20px; }
61
+ .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }
62
+
63
+ .drop-zone--over {
64
+ border-color: var(--primary);
65
+ box-shadow: 0 0 20px rgba(59, 130, 246, 0.4);
66
+ }
67
+
68
+ .loader {
69
+ width: 48px;
70
+ height: 48px;
71
+ border: 3px solid var(--card-border);
72
+ border-radius: 50%;
73
+ display: inline-block;
74
+ position: relative;
75
+ box-sizing: border-box;
76
+ animation: rotation 1s linear infinite;
77
+ }
78
+ .loader::after {
79
+ content: '';
80
+ box-sizing: border-box;
81
+ position: absolute;
82
+ left: 50%;
83
+ top: 50%;
84
+ transform: translate(-50%, -50%);
85
+ width: 56px;
86
+ height: 56px;
87
+ border-radius: 50%;
88
+ border: 3px solid;
89
+ border-color: var(--primary) transparent;
90
+ }
91
+
92
+ @keyframes rotation {
93
+ 0% { transform: rotate(0deg); }
94
+ 100% { transform: rotate(360deg); }
95
+ }
96
+
97
+ /* --- Updated Typing Indicator --- */
98
+ .typing-indicator {
99
+ display: inline-flex; /* Changed to inline-flex */
100
+ align-items: center;
101
+ padding: 8px 0; /* Add some vertical padding */
102
+ }
103
+ .typing-indicator span {
104
+ height: 8px; /* Slightly smaller dots */
105
+ width: 8px;
106
+ margin: 0 2px;
107
+ background-color: #9E9E9E;
108
+ border-radius: 50%;
109
+ opacity: 0; /* Start invisible */
110
+ animation: typing-pulse 1.4s infinite ease-in-out;
111
+ }
112
+ .typing-indicator span:nth-child(1) { animation-delay: 0s; }
113
+ .typing-indicator span:nth-child(2) { animation-delay: 0.2s; }
114
+ .typing-indicator span:nth-child(3) { animation-delay: 0.4s; }
115
+
116
+ @keyframes typing-pulse {
117
+ 0%, 100% { opacity: 0; transform: scale(0.7); }
118
+ 50% { opacity: 1; transform: scale(1); }
119
+ }
120
+ /* --- End Typing Indicator --- */
121
+
122
+ /* --- Updated Markdown Styling --- */
123
+ .markdown-content { /* Base styles for the content area */
124
+ line-height: 1.75;
125
+ }
126
+ .markdown-content p { margin-bottom: 1rem; }
127
+ .markdown-content h1, .markdown-content h2, .markdown-content h3,
128
+ .markdown-content h4, .markdown-content h5, .markdown-content h6 {
129
+ font-weight: 600;
130
+ margin-top: 1.5rem;
131
+ margin-bottom: 0.75rem;
132
+ line-height: 1.3;
133
+ }
134
+ .markdown-content h1 { font-size: 1.5em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3rem;}
135
+ .markdown-content h2 { font-size: 1.25em; }
136
+ .markdown-content h3 { font-size: 1.1em; }
137
+ .markdown-content ul, .markdown-content ol { padding-left: 1.75rem; margin-bottom: 1rem; }
138
+ .markdown-content li { margin-bottom: 0.5rem; }
139
+ .markdown-content a { color: var(--primary); text-decoration: none; font-weight: 500; }
140
+ .markdown-content a:hover { text-decoration: underline; }
141
+ .markdown-content strong, .markdown-content b { font-weight: 600; } /* Ensure bold works */
142
+ .markdown-content blockquote {
143
+ border-left: 4px solid var(--card-border);
144
+ padding-left: 1rem;
145
+ margin-left: 0;
146
+ margin-bottom: 1rem;
147
+ color: #a0aec0; /* Lighter text for quotes */
148
+ }
149
+ /* --- Code Block Styling --- */
150
+ .markdown-content pre {
151
+ position: relative;
152
+ background-color: var(--code-bg);
153
+ border: 1px solid var(--card-border);
154
+ border-radius: 0.5rem;
155
+ margin-bottom: 1rem;
156
+ font-size: 0.9em;
157
+ color: var(--code-text);
158
+ overflow: hidden; /* Clip overflow on the block; the inner code element scrolls horizontally */
159
+ }
160
+ .markdown-content pre code {
161
+ display: block;
162
+ padding: 1rem;
163
+ overflow-x: auto; /* Enable horizontal scroll on the code itself */
164
+ background: none !important; /* Override potential highlight.js background */
165
+ font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
166
+ white-space: pre; /* Ensure whitespace is preserved */
167
+ }
168
+ /* --- Copy Button Styling --- */
169
+ .markdown-content pre .copy-code-btn {
170
  position: absolute;
171
+ top: 0.5rem;
172
+ right: 0.5rem;
173
+ background-color: var(--copy-btn-bg);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  border: 1px solid var(--card-border);
175
+ color: var(--copy-btn-text);
176
+ padding: 0.3rem 0.6rem;
177
+ border-radius: 0.25rem;
178
+ cursor: pointer;
179
+ opacity: 0; /* Initially hidden */
180
+ transition: opacity 0.2s, background-color 0.2s;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  font-size: 0.8em;
182
+ display: flex; /* For icon alignment */
183
+ align-items: center;
184
+ gap: 0.25rem;
185
+ }
186
+ .markdown-content pre .copy-code-btn:hover {
187
+ background-color: var(--copy-btn-hover-bg);
188
+ }
189
+ .markdown-content pre:hover .copy-code-btn {
190
+ opacity: 1; /* Show on hover */
191
+ }
192
+ /* --- Inline Code Styling --- */
193
+ .markdown-content code:not(pre code) {
194
+ background-color: rgba(110, 118, 129, 0.4);
195
+ padding: 0.2em 0.4em;
196
+ margin: 0 0.1em; /* Add slight horizontal margin */
197
+ font-size: 85%;
198
+ border-radius: 6px;
199
+ font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
200
+ }
201
+ /* --- End Markdown Styling --- */
202
+
203
+ .tts-button-loader {
204
+ width: 16px;
205
+ height: 16px;
206
+ border: 2px solid currentColor;
207
+ border-radius: 50%;
208
+ display: inline-block;
209
+ box-sizing: border-box;
210
+ animation: rotation 0.8s linear infinite;
211
+ border-bottom-color: transparent;
212
+ }
213
+
214
+ /* --- Style for TTS controls container --- */
215
+ .tts-controls {
216
+ display: flex;
217
+ align-items: center;
218
+ gap: 0.5rem; /* Space between play and speed buttons */
219
+ margin-top: 0.5rem;
220
+ }
221
+
222
+ /* --- Style for Speed Cycle Button --- */
223
+ .speed-cycle-btn {
224
+ padding: 0.25rem 0.6rem; /* Smaller padding */
225
+ font-size: 0.75rem; /* Smaller text */
226
+ background-color: #4b5563; /* Grey background */
227
+ color: #e5e7eb; /* Light text */
228
+ border-radius: 9999px; /* Pill shape */
229
+ border: none;
230
+ cursor: pointer;
231
+ transition: background-color 0.2s;
232
+ white-space: nowrap; /* Prevent text wrapping */
233
+ margin-top: 0.5rem;
234
+ }
235
+ .speed-cycle-btn:hover {
236
+ background-color: #1f0bb8e6; /* Semi-transparent indigo on hover */
237
+ }
238
+ .speed-cycle-btn:disabled {
239
+ opacity: 0.5;
240
+ cursor: not-allowed;
241
+ }
242
+
243
+ /* --- Select dropdown styles (kept for consistency if needed elsewhere) --- */
244
+ .select-wrapper {
245
+ position: relative;
246
+ }
247
+ .select-wrapper select {
248
+ background-color: var(--select-bg);
249
+ border: 1px solid var(--select-border);
250
+ color: var(--select-text);
251
+ padding: 0.75rem 2.5rem 0.75rem 1rem;
252
+ border-radius: 0.75rem;
253
+ font-size: 0.875rem;
254
+ width: 100%;
255
+ appearance: none;
256
+ -webkit-appearance: none;
257
+ transition: all 0.2s ease-in-out;
258
+ cursor: pointer;
259
+ background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
260
+ background-position: right 0.75rem center;
261
+ background-repeat: no-repeat;
262
+ background-size: 1.25em 1.25em;
263
+ }
264
+ </style>
265
+ </head>
266
+ <body class="w-screen h-screen dark">
267
+ <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
268
+ <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
269
+
270
+ <header class="p-4 border-b border-[var(--card-border)] flex-shrink-0 flex justify-between items-center w-full">
271
+ <div class="w-1/4"></div> <div class="w-1/2 text-center">
272
+ <h1 class="text-xl font-medium tracking-wide">CogniChat ✨</h1>
273
+ <p id="chat-filename" class="text-xs text-gray-400 mt-1 truncate"></p>
274
+ </div>
275
+ <div id="chat-session-info" class="w-1/4 text-right text-xs space-y-1 pr-4">
276
+ </div>
277
+ </header>
278
+
279
+ <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
280
+ <div id="chat-content" class="max-w-4xl mx-auto space-y-8"></div>
281
+ </div>
282
+ <div class="p-4 flex-shrink-0 bg-opacity-50 backdrop-blur-md border-t border-[var(--card-border)]">
283
+ <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-lg border border-[var(--card-border)] focus-within:ring-2 focus-within:ring-[var(--primary)] transition-all">
284
+ <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
285
+ <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2.5 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed" title="Send">
286
+ <svg class="w-5 h-5" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-8.707l-3-3a1 1 0 00-1.414 1.414L10.586 9H7a1 1 0 100 2h3.586l-1.293 1.293a1 1 0 101.414 1.414l3-3a1 1 0 000-1.414z" clip-rule="evenodd"></path></svg>
287
+ </button>
288
+ </form>
289
+ </div>
290
+ </div>
291
+
292
+ <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
293
+ <div class="text-center max-w-xl w-full">
294
+ <h1 class="text-5xl font-bold mb-3 tracking-tight">CogniChat ✨</h1>
295
+ <p class="text-lg text-gray-400 mb-8">Upload your documents to start a conversation.</p>
296
+ <div class="mb-8 p-5 bg-[var(--card)] rounded-2xl border border-[var(--card-border)] shadow-lg">
297
+ <div class="flex flex-col sm:flex-row items-center gap-6">
298
+ <div class="w-full sm:w-1/2">
299
+ <div class="flex items-center gap-2 mb-2">
300
+ <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path d="M7 3a1 1 0 000 2h6a1 1 0 100-2H7zM4 7a1 1 0 011-1h10a1 1 0 110 2H5a1 1 0 01-1-1zM2 11a2 2 0 012-2h12a2 2 0 012 2v4a2 2 0 01-2 2H4a2 2 0 01-2-2v-4z" /></svg>
301
+ <label for="model-select" class="block text-sm font-medium text-gray-300">Model</label>
302
+ </div>
303
+ <div class="select-wrapper">
304
+ <select id="model-select" name="model_name">
305
+ <option value="moonshotai/kimi-k2-instruct" selected>Kimi Instruct</option>
306
+ <option value="openai/gpt-oss-20b">GPT OSS 20b</option>
307
+ <option value="llama-3.3-70b-versatile">Llama 3.3 70b</option>
308
+ <option value="llama-3.1-8b-instant">Llama 3.1 8b Instant</option>
309
+ </select>
310
+ </div>
311
+ </div>
312
+ <div class="w-full sm:w-1/2">
313
+ <div class="flex items-center gap-2 mb-2">
314
+ <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M5.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7zM12 5.5a3.5 3.5 0 11-7 0 3.5 3.5 0 017 0zM14.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7z" clip-rule="evenodd" /></svg>
315
+ <label for="temperature-select" class="block text-sm font-medium text-gray-300">Mode</label>
316
+ </div>
317
+ <div class="select-wrapper">
318
+ <select id="temperature-select" name="temperature">
319
+ <option value="0.2" selected>0.2 - Precise</option>
320
+ <option value="0.4">0.4 - Confident</option>
321
+ <option value="0.6">0.6 - Balanced</option>
322
+ <option value="0.8">0.8 - Flexible</option>
323
+ <option value="1.0">1.0 - Creative</option>
324
+ </select>
325
+ </div>
326
+ </div>
327
+ </div>
328
+ <p class="text-xs text-gray-500 mt-4 text-center">Higher creativity modes may reduce factual accuracy.</p>
329
+ </div>
330
+ <div id="drop-zone" class="w-full text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer hover:bg-[var(--card)] hover:border-[var(--primary)]">
331
+ <div class="flex flex-col items-center justify-center pointer-events-none">
332
+ <svg class="mx-auto h-12 w-12 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
333
+ <p class="mt-4 text-sm font-medium text-gray-400">Drag & drop files or <span class="text-[var(--primary)] font-semibold">click to upload</span></p>
334
+ <p class="text-xs text-gray-400 mt-1">Supports PDF, DOCX, TXT</p>
335
+ <p id="file-name" class="mt-2 text-xs text-gray-500"></p>
336
+ </div>
337
+ <input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx" multiple>
338
+ </div>
339
+ </div>
340
+ </div>
341
+
342
+ <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50">
343
+ <div class="loader"></div>
344
+ <p id="loading-text" class="mt-6 text-sm font-medium"></p>
345
+ <p id="loading-subtext" class="mt-2 text-xs text-gray-400"></p>
346
+ </div>
347
+ </main>
348
+
349
+ <script>
350
+ document.addEventListener('DOMContentLoaded', () => {
351
+ const uploadContainer = document.getElementById('upload-container');
352
+ const chatContainer = document.getElementById('chat-container');
353
+ const dropZone = document.getElementById('drop-zone');
354
+ const fileUploadInput = document.getElementById('file-upload');
355
+ const fileNameSpan = document.getElementById('file-name');
356
+ const loadingOverlay = document.getElementById('loading-overlay');
357
+ const loadingText = document.getElementById('loading-text');
358
+ const loadingSubtext = document.getElementById('loading-subtext');
359
+ const chatForm = document.getElementById('chat-form');
360
+ const chatInput = document.getElementById('chat-input');
361
+ const chatSubmitBtn = document.getElementById('chat-submit-btn');
362
+ const chatWindow = document.getElementById('chat-window');
363
+ const chatContent = document.getElementById('chat-content');
364
+ const modelSelect = document.getElementById('model-select');
365
+ const temperatureSelect = document.getElementById('temperature-select');
366
+ const chatFilename = document.getElementById('chat-filename');
367
+ const chatSessionInfo = document.getElementById('chat-session-info');
368
+
369
+ let sessionId = sessionStorage.getItem('cognichat_session_id');
370
+ let currentModelInfo = JSON.parse(sessionStorage.getItem('cognichat_model_info'));
371
+
372
+ marked.setOptions({
373
+ breaks: true,
374
+ gfm: true,
375
+ });
376
+
377
+ if (sessionId && currentModelInfo) {
378
+ console.log("Restoring session:", sessionId);
379
+ uploadContainer.classList.add('hidden');
380
+ chatContainer.classList.remove('hidden');
381
+ chatFilename.innerHTML = `Chatting with: <strong class="font-semibold">${sessionStorage.getItem('cognichat_filename') || 'documents'}</strong>`;
382
+ chatFilename.title = sessionStorage.getItem('cognichat_filename') || 'documents';
383
+ chatSessionInfo.innerHTML = `
384
+ <p>Model: ${currentModelInfo.simpleModelName}</p>
385
+ <p>Mode: ${currentModelInfo.mode}</p>
386
+ <button class="mt-1 text-xs text-blue-400 hover:text-blue-300 focus:outline-none" onclick="sessionStorage.clear(); location.reload();">New Chat</button>`;
387
+ }
388
+
389
+
390
+ // --- File Upload Logic ---
391
+ dropZone.addEventListener('click', () => fileUploadInput.click());
392
+
393
+ ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
394
+ dropZone.addEventListener(eventName, e => {e.preventDefault(); e.stopPropagation();}, false);
395
+ document.body.addEventListener(eventName, e => {e.preventDefault(); e.stopPropagation();}, false);
396
+ });
397
+ ['dragenter', 'dragover'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over')));
398
+ ['dragleave', 'drop'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over')));
399
+
400
+ dropZone.addEventListener('drop', (e) => {
401
+ if (e.dataTransfer.files.length > 0) handleFiles(e.dataTransfer.files);
402
+ });
403
+ fileUploadInput.addEventListener('change', (e) => {
404
+ if (e.target.files.length > 0) handleFiles(e.target.files);
405
+ });
406
+
407
+ async function handleFiles(files) {
408
+ const formData = new FormData();
409
+ let fileNames = Array.from(files).map(f => f.name);
410
+ for (const file of files) { formData.append('file', file); }
411
+
412
+ formData.append('model_name', modelSelect.value);
413
+ formData.append('temperature', temperatureSelect.value);
414
+
415
+ fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
416
+ await uploadAndProcessFiles(formData);
417
+ }
418
+
419
+ async function uploadAndProcessFiles(formData) {
420
+ loadingOverlay.classList.remove('hidden');
421
+ loadingText.textContent = `Processing document(s)...`;
422
+ loadingSubtext.textContent = "Creating a knowledge base... this might take a minute 🧠";
423
+ chatContent.innerHTML = ''; // Clear previous chat content on new upload
424
+
425
+ try {
426
+ const response = await fetch('/upload', { method: 'POST', body: formData });
427
+ const result = await response.json();
428
+ if (!response.ok) throw new Error(result.message || 'Unknown error occurred during upload.');
429
+
430
+ sessionId = result.session_id;
431
+ sessionStorage.setItem('cognichat_session_id', sessionId);
432
+
433
+ const modelOption = modelSelect.querySelector(`option[value="${result.model_name}"]`);
434
+ const simpleModelName = modelOption ? modelOption.textContent : result.model_name; // Fall back to the raw model id when no matching option exists
435
+
436
+ currentModelInfo = {
437
+ model: result.model_name,
438
+ mode: result.mode,
439
+ simpleModelName: simpleModelName // Use the derived simpler name
440
+ };
441
+ sessionStorage.setItem('cognichat_model_info', JSON.stringify(currentModelInfo)); // Store model info
442
+ sessionStorage.setItem('cognichat_filename', result.filename); // Store filename
443
+
444
+ chatFilename.innerHTML = `Chatting with: <strong class="font-semibold">${result.filename}</strong>`;
445
+ chatFilename.title = result.filename;
446
+
447
+ chatSessionInfo.innerHTML = `
448
+ <p>Model: ${currentModelInfo.simpleModelName}</p>
449
+ <p>Mode: ${currentModelInfo.mode}</p>
450
+ <button class="mt-1 text-xs text-blue-400 hover:text-blue-300 focus:outline-none" onclick="sessionStorage.clear(); location.reload();">New Chat</button>`;
451
+
452
+ uploadContainer.classList.add('hidden');
453
+ chatContainer.classList.remove('hidden');
454
+ appendMessage("Hello! 👋 I've analyzed your documents. What would you like to know?", "bot", currentModelInfo);
455
+
456
+ } catch (error) {
457
+ console.error('Upload error:', error);
458
+ alert(`Error processing files: ${error.message}`);
459
+ sessionStorage.clear(); // Clear session if upload fails
460
+ } finally {
461
+ loadingOverlay.classList.add('hidden');
462
+ fileNameSpan.textContent = '';
463
+ fileUploadInput.value = '';
464
+ }
465
+ }
466
+
467
+ // --- Chat Logic (Using Server-Sent Events - UPDATED FOR STREAMING & INDICATOR) ---
468
+ chatForm.addEventListener('submit', async (e) => {
469
+ e.preventDefault();
470
+ const question = chatInput.value.trim();
471
+ if (!question || !sessionId) {
472
+ console.warn("Submit ignored: No question or session ID.");
473
+ return;
474
+ }
475
+
476
+ appendMessage(question, 'user');
477
+ chatInput.value = '';
478
+ chatInput.disabled = true;
479
+ chatSubmitBtn.disabled = true;
480
+
481
+ let botMessageContainer;
482
+ let contentDiv;
483
+ let fullResponse = '';
484
+ let eventSource = null;
485
+ let inactivityTimeout = null;
486
+ let streamClosedCleanly = false; // Flag to check if stream ended normally vs error
487
+ let typingIndicatorElement = null; // Store indicator element
488
+
489
+ // Function to finalize chat (called on error, timeout, or successful completion)
490
+ function finalizeChat(isError = false) {
491
+ console.log(`Finalizing chat. Was error: ${isError}, Stream ended cleanly: ${streamClosedCleanly}`);
492
+ if (eventSource) {
493
+ eventSource.close();
494
+ eventSource = null;
495
+ console.log("SSE connection explicitly closed in finalizeChat.");
496
+ }
497
+ if (inactivityTimeout) {
498
+ clearTimeout(inactivityTimeout);
499
+ inactivityTimeout = null;
500
+ }
501
+ // Remove indicator if it's still there
502
+ if (typingIndicatorElement && typingIndicatorElement.parentNode) {
503
+ typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
504
+ typingIndicatorElement = null;
505
+ }
506
+
507
+
508
+ if (botMessageContainer && contentDiv) {
509
+ const hasErrorMsg = contentDiv.innerHTML.includes('⚠️');
510
+ // Ensure final render, apply copy buttons and TTS ONLY if response wasn't an error
511
+ if (!hasErrorMsg && fullResponse) {
512
+ // Re-parse the complete response to ensure correct final Markdown
513
+ contentDiv.innerHTML = marked.parse(fullResponse);
514
+ // Apply final touches like copy buttons and TTS
515
+ contentDiv.querySelectorAll('pre').forEach(addCopyButton);
516
+ addTextToSpeechControls(botMessageContainer, fullResponse);
517
+ // Optional: Final highlighting if using highlight.js
518
+ // contentDiv.querySelectorAll('pre code').forEach(block => hljs.highlightElement(block));
519
+ }
520
+ scrollToBottom(true); // Ensure scrolled to the end
521
+ }
522
+
523
+ // Always re-enable input fields
524
+ chatInput.disabled = false;
525
+ chatSubmitBtn.disabled = false;
526
+ chatInput.focus();
527
+ }
528
+
529
+ try {
530
+ // Create the bot message container *before* starting the stream
531
+ botMessageContainer = appendMessage('', 'bot', currentModelInfo); // Append empty bot message
532
+ contentDiv = botMessageContainer.querySelector('.markdown-content');
533
+
534
+ // Show typing indicator *inside* the contentDiv
535
+ typingIndicatorElement = showTypingIndicator();
536
+ if (contentDiv) {
537
+ contentDiv.appendChild(typingIndicatorElement);
538
+ scrollToBottom(true); // Scroll to show indicator
539
+ } else {
540
+ console.error("Could not find contentDiv to append typing indicator.");
541
+ }
542
+
543
+
544
+ // Establish SSE connection via GET request
545
+ const chatUrl = `/chat?question=${encodeURIComponent(question)}&session_id=${encodeURIComponent(sessionId)}`;
546
+ console.log("Connecting to SSE:", chatUrl);
547
+ eventSource = new EventSource(chatUrl);
548
+
549
+ eventSource.onopen = () => {
550
+ console.log("SSE Connection opened.");
551
+ // Remove indicator when connection opens and stream is about to start
552
+ if (typingIndicatorElement && typingIndicatorElement.parentNode) {
553
+ typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
554
+ typingIndicatorElement = null;
555
+ }
556
+ streamClosedCleanly = false; // Reset flag on new connection
557
+ };
558
+
559
+ eventSource.onmessage = (event) => {
560
+ // Remove indicator on first message just in case onopen didn't fire reliably
561
+ if (typingIndicatorElement && typingIndicatorElement.parentNode) {
562
+ typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
563
+ typingIndicatorElement = null;
564
+ }
565
+
566
+ // Reset inactivity timeout on each message
567
+ if (inactivityTimeout) clearTimeout(inactivityTimeout);
568
+ inactivityTimeout = setTimeout(() => {
569
+ console.log("Inactivity timeout triggered after message.");
570
+ streamClosedCleanly = true; // Assume normal end
571
+ finalizeChat(false);
572
+ }, 5000); // 5 seconds of inactivity
573
+
574
+ let data;
575
+ try {
576
+ data = JSON.parse(event.data);
577
+ } catch (parseError){
578
+ console.error("Failed to parse SSE data:", event.data, parseError);
579
+ contentDiv.innerHTML += `<p class="text-red-400 text-sm">Error receiving data chunk.</p>`;
580
+ return;
581
+ }
582
+
583
+ if (data.error) {
584
+ console.error('SSE Error from server:', data.error);
585
+ contentDiv.innerHTML = `<p class="text-red-500 font-semibold">⚠️ Server Error: ${data.error}</p>`;
586
+ streamClosedCleanly = false;
587
+ finalizeChat(true); // Pass true for error
588
+ return;
589
+ }
590
+
591
+ if (data.token !== undefined && data.token !== null) {
592
+ fullResponse += data.token;
593
+ // Update content by parsing the accumulated response
594
+ contentDiv.innerHTML = marked.parse(fullResponse);
595
+ scrollToBottom(); // Scroll smoothly as content arrives
596
+ }
597
+ };
598
+
599
+ eventSource.onerror = (error) => {
600
+ console.error('SSE connection error event:', error);
601
+ // Remove indicator on error
602
+ if (typingIndicatorElement && typingIndicatorElement.parentNode) {
603
+ typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
604
+ typingIndicatorElement = null;
605
+ }
606
+ // Don't show generic error if we received data and the stream likely just closed normally
607
+ if (!fullResponse && !streamClosedCleanly) { // Only show error if nothing received AND not already cleanly closed
608
+ const errorMsg = "⚠️ Connection error. Please try again.";
609
+ if (contentDiv) {
610
+ contentDiv.innerHTML = `<p class="text-red-500 font-semibold">${errorMsg}</p>`;
611
+ } else {
612
+ // Fallback if container wasn't created somehow
613
+ appendMessage(errorMsg, 'bot', currentModelInfo); // Pass model info here too
614
+ }
615
+ streamClosedCleanly = false;
616
+ } else if (!streamClosedCleanly) {
617
+ // If we received data, assume it's a normal closure misinterpreted as error
618
+ console.log("SSE connection closed (likely normal end detected by onerror).");
619
+ streamClosedCleanly = true; // Mark as clean closure NOW
620
+ } else {
621
+ console.log("SSE onerror event after stream already marked cleanly closed.")
622
+ }
623
+ finalizeChat(!streamClosedCleanly); // Finalize, indicate error if not clean
624
+ };
625
+
626
+ } catch (error) {
627
+ // For setup errors before SSE starts
628
+ console.error('Chat setup error:', error);
629
+ // Remove indicator on setup error
630
+ if (typingIndicatorElement && typingIndicatorElement.parentNode) {
631
+ typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
632
+ typingIndicatorElement = null;
633
+ }
634
+ if (botMessageContainer && contentDiv) {
635
+ contentDiv.innerHTML = `<p class="text-red-500 font-semibold">⚠️ Error starting chat: ${error.message}</p>`;
636
+ } else {
637
+ appendMessage(`Error starting chat: ${error.message}`, 'bot', currentModelInfo); // Pass model info
638
+ }
639
+ finalizeChat(true);
640
+ }
641
+ });
642
+
643
+
644
+ // --- UI Helper Functions ---
645
+
646
/**
 * Build a chat row (avatar + sender header + markdown body + empty TTS slot)
 * and append it to `chatContent`.
 *
 * @param {string} text - Markdown source for the message body; an empty value
 *     yields an empty `.markdown-content` div (used as a streaming target).
 * @param {string} sender - 'user' for the user's own messages; any other value
 *     is rendered as a bot message.
 * @param {?Object} modelInfo - Optional object with `simpleModelName` and
 *     `mode`; falls back to `currentModelInfo` for bot messages.
 * @returns {Element} The `.flex-1` container holding the header and content.
 */
function appendMessage(text, sender, modelInfo = null) {
    const fromUser = sender === 'user';

    // Avatar: person glyph for the user, sparkle for the bot.
    let avatarMarkup;
    if (fromUser) {
        avatarMarkup = `<div class="bg-blue-200 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1 self-start"><svg class="w-5 h-5 text-blue-700 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`;
    } else {
        avatarMarkup = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 self-start text-xl flex items-center justify-center w-10 h-10">✨</div>`;
    }

    // Header line: plain "You" for the user, product name plus model badge for the bot.
    let headerMarkup = '<p class="font-medium text-sm mb-1">You</p>';
    if (!fromUser) {
        const info = modelInfo || currentModelInfo;
        let badgeMarkup = '';
        if (info && info.simpleModelName) {
            badgeMarkup = `
                <span class="ml-2 text-xs font-normal text-gray-400">
                    (Model: ${info.simpleModelName} | Mode: ${info.mode})
                </span>
            `;
        }
        headerMarkup = `<div class="font-medium text-sm mb-1 flex items-center">CogniChat ${badgeMarkup}</div>`;
    }

    // NOTE(review): the output of marked.parse() is inserted as raw HTML below.
    // Confirm that marked is configured to sanitize, or that `text` is trusted.
    const bodyMarkup = text ? marked.parse(text) : '';

    const row = document.createElement('div');
    row.className = `flex items-start gap-3`;
    // The markdown-content div is always emitted, even when empty, so callers
    // can stream content into it later.
    row.innerHTML = `
        ${avatarMarkup}
        <div class="flex-1 pt-1 min-w-0"> ${headerMarkup}
            <div class="text-base markdown-content prose dark:prose-invert max-w-none">${bodyMarkup}</div>
            <div class="tts-controls mt-2"></div>
        </div>
    `;
    chatContent.appendChild(row);

    // Force a scroll for user messages and for bot messages that already have content.
    if (fromUser || text) {
        scrollToBottom(true);
    }

    return row.querySelector('.flex-1');
}
685
+
686
+ // --- UPDATED showTypingIndicator ---
687
/**
 * Create (but do not attach) the three-dot typing indicator element.
 *
 * @returns {HTMLDivElement} A detached `.typing-indicator` div; the caller
 *     decides where to insert it.
 */
function showTypingIndicator() {
    const dots = document.createElement('div');
    dots.innerHTML = '<span></span><span></span><span></span>';
    dots.className = 'typing-indicator';
    return dots;
}
694
+ // --- End UPDATED showTypingIndicator ---
695
+
696
+
697
/**
 * Smooth-scroll `chatWindow` to its bottom edge.
 *
 * Scrolling happens only when `force` is true or the viewport is already
 * within 150px of the bottom, so a user who scrolled up is not yanked back.
 *
 * @param {boolean} [force=false] - Scroll regardless of current position.
 */
function scrollToBottom(force = false) {
    const NEAR_BOTTOM_PX = 150;
    const nearBottom =
        chatWindow.scrollHeight - chatWindow.clientHeight <= chatWindow.scrollTop + NEAR_BOTTOM_PX;

    if (!force && !nearBottom) {
        return;
    }

    // Defer to the next animation frame so the scroll uses the updated layout.
    requestAnimationFrame(() => {
        chatWindow.scrollTo({ top: chatWindow.scrollHeight, behavior: 'smooth' });
    });
}
709
+
710
/**
 * Attach a "Copy" button to a <pre> block that copies its <code> text.
 *
 * Idempotent: does nothing if the block already has a `.copy-code-btn`.
 * After a click the button briefly shows "Copied!" or "Error" and then
 * restores its default icon + label.
 *
 * @param {HTMLElement} pre - The <pre> element to decorate.
 */
function addCopyButton(pre) {
    if (pre.querySelector('.copy-code-btn')) return;

    // Single source of truth for the resting-state markup.
    const defaultMarkup = `<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path></svg> Copy`;
    const RESTORE_DELAY_MS = 1500;

    const button = document.createElement('button');
    button.className = 'copy-code-btn absolute top-2 right-2 p-1 rounded bg-[var(--copy-btn-bg)] text-[var(--copy-btn-text)] hover:bg-[var(--copy-btn-hover-bg)] transition-opacity duration-200 flex items-center gap-1 text-xs';
    button.innerHTML = defaultMarkup;
    pre.style.position = 'relative'; // Anchor for the absolutely positioned button.
    pre.appendChild(button);

    // Show a transient label, then restore the default markup. A pending
    // restore is cancelled first so rapid clicks cannot cut the label short.
    let restoreTimer = null;
    const flashLabel = (label) => {
        button.textContent = label;
        if (restoreTimer !== null) clearTimeout(restoreTimer);
        restoreTimer = setTimeout(() => {
            button.innerHTML = defaultMarkup;
            restoreTimer = null;
        }, RESTORE_DELAY_MS);
    };

    button.addEventListener('click', () => {
        const code = pre.querySelector('code')?.innerText || '';

        // navigator.clipboard is only exposed in secure contexts; without this
        // guard the call below throws synchronously and the .catch never runs.
        if (!navigator.clipboard || typeof navigator.clipboard.writeText !== 'function') {
            console.error('Failed to copy code: ', new Error('Clipboard API is unavailable in this context.'));
            flashLabel('Error');
            return;
        }

        navigator.clipboard.writeText(code)
            .then(() => flashLabel('Copied!'))
            .catch(err => {
                console.error('Failed to copy code: ', err);
                flashLabel('Error');
            });
    });
}
734
+
735
+ // --- TTS Functions (UPDATED FOR SPEED CYCLE) ---
736
// Shared text-to-speech state: the Audio object in use (if any) and the
// play/pause button that owns it.
let currentAudio = null;
let currentPlayingButton = null;

// Icon markup swapped into the play/pause button.
const playIconSVG = '<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20"><path d="M6.3 2.841A1.5 1.5 0 004 4.11V15.89a1.5 1.5 0 002.3 1.269l9.344-5.89a1.5 1.5 0 000-2.538L6.3 2.84z"></path></svg>';
const pauseIconSVG = '<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20"><path d="M5.75 4.75a.75.75 0 00-.75.75v9.5c0 .414.336.75.75.75h1.5a.75.75 0 00.75-.75v-9.5a.75.75 0 00-.75-.75h-1.5zm6.5 0a.75.75 0 00-.75.75v9.5c0 .414.336.75.75.75h1.5a.75.75 0 00.75-.75v-9.5a.75.75 0 00-.75-.75h-1.5z"></path></svg>';

// Playback rates in cycle order: normal, fast, slow.
const availableSpeeds = [1.0, 1.5, 0.75];
741
+
742
+ // --- UPDATED: addTextToSpeechControls ---
743
/**
 * Add a "Listen" play/pause button and a speed-cycle button to a message.
 *
 * No-op when `text` is blank, when the bubble has no `.tts-controls` slot,
 * or when a `.speak-btn` is already present in that slot.
 *
 * @param {Element} messageBubble - Container holding a `.tts-controls` div.
 * @param {string} text - The message text passed to handleTTS on click.
 */
function addTextToSpeechControls(messageBubble, text) {
    if (!text || !text.trim()) return;

    const slot = messageBubble.querySelector('.tts-controls');
    if (!slot) return;
    if (slot.querySelector('.speak-btn')) return; // Already decorated.

    // Play/pause button; the selected playback speed lives in a data attribute on it.
    const listenBtn = document.createElement('button');
    listenBtn.className = 'speak-btn mt-2 px-3 py-1.5 bg-blue-700 text-white rounded-full text-xs font-medium hover:bg-blue-800 transition-colors flex items-center gap-1.5 disabled:opacity-50 disabled:cursor-not-allowed';
    listenBtn.title = 'Listen to this message';
    listenBtn.innerHTML = `${playIconSVG} <span>Listen</span>`;
    listenBtn.setAttribute('data-current-speed', '1.0');
    slot.appendChild(listenBtn);
    listenBtn.addEventListener('click', () => handleTTS(text, listenBtn));

    // Speed button; carries the list of selectable speeds as JSON.
    const rateBtn = document.createElement('button');
    rateBtn.className = 'speed-cycle-btn';
    rateBtn.title = 'Cycle playback speed';
    rateBtn.textContent = 'Speed: 1x';
    rateBtn.setAttribute('data-speeds', JSON.stringify(availableSpeeds));
    slot.appendChild(rateBtn);
    rateBtn.addEventListener('click', () => cycleSpeed(rateBtn, listenBtn));
}
766
+
767
+ // --- NEW: cycleSpeed ---
768
/**
 * Advance a message's playback speed to the next entry in its speed list.
 *
 * The list is read from the speed button's `data-speeds` attribute and the
 * current value from the speak button's `data-current-speed`; the cycle
 * wraps around at the end of the list.
 *
 * @param {Element} speedBtn - The speed-cycle button (label is updated).
 * @param {Element} speakBtn - The play/pause button (stores the speed).
 */
function cycleSpeed(speedBtn, speakBtn) {
    const options = JSON.parse(speedBtn.getAttribute('data-speeds'));
    const activeSpeed = parseFloat(speakBtn.getAttribute('data-current-speed'));

    // indexOf yields -1 for an unknown speed, which wraps to the first entry.
    const upcoming = options[(options.indexOf(activeSpeed) + 1) % options.length];

    speakBtn.setAttribute('data-current-speed', upcoming.toString());
    speedBtn.textContent = `Speed: ${upcoming}x`;

    // Apply immediately only when this message's audio is the one playing.
    const ownsLivePlayback =
        currentAudio && !currentAudio.paused && speakBtn === currentPlayingButton;
    if (ownsLivePlayback) {
        currentAudio.playbackRate = upcoming;
    }
}
786
+
787
+ // --- UPDATED: handleTTS ---
788
/**
 * Play, pause, or resume text-to-speech audio for one message.
 *
 * Clicking the active button toggles pause/resume. Clicking any other button
 * stops the current audio, POSTs `{ text }` as JSON to `/tts`, and plays the
 * returned audio blob at the speed stored in `data-current-speed`.
 *
 * @param {string} text - Message text to synthesize; blank text is ignored.
 * @param {HTMLButtonElement} button - The play/pause button that was clicked.
 * @returns {Promise<void>}
 */
async function handleTTS(text, button) {
    if (!text || !text.trim()) return;

    // Read the speed before any reset below clears the attribute.
    const selectedSpeed = parseFloat(button.getAttribute('data-current-speed')) || 1.0;

    // Clicking the active button toggles pause/resume.
    if (button === currentPlayingButton) {
        if (currentAudio && !currentAudio.paused) {
            currentAudio.pause();
            button.innerHTML = `${playIconSVG} <span>Listen</span>`;
        } else if (currentAudio && currentAudio.paused) {
            currentAudio.playbackRate = selectedSpeed; // Speed may have changed while paused.
            currentAudio.play().catch(e => {
                console.error("Audio resume error:", e);
                resetAllSpeakButtons();
            });
            button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
        }
        return;
    }

    // New playback: stop whatever else is playing and claim the active slot.
    resetAllSpeakButtons();
    currentPlayingButton = button;

    const speedBtn = button.parentElement.querySelector('.speed-cycle-btn');

    // resetAllSpeakButtons() just set every button back to 1x, including this
    // one. Restore the user's choice so the label matches the actual rate.
    button.setAttribute('data-current-speed', selectedSpeed.toString());
    if (speedBtn) speedBtn.textContent = `Speed: ${selectedSpeed}x`;

    button.innerHTML = `<div class="tts-button-loader mr-1"></div> <span>Loading...</span>`;
    button.disabled = true;
    if (speedBtn) speedBtn.disabled = true; // No speed changes while loading.

    // Tag this request so a slow response can tell it has been superseded
    // by a later click on this or another button.
    const requestToken = {};
    handleTTS.latestRequest = requestToken;
    const isStale = () =>
        handleTTS.latestRequest !== requestToken || currentPlayingButton !== button;

    try {
        const response = await fetch('/tts', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ text: text })
        });
        if (!response.ok) throw new Error(`TTS generation failed (${response.status})`);
        const blob = await response.blob();
        if (!blob || blob.size === 0) throw new Error("Received empty audio blob.");

        // A newer request owns the UI now; do not start a second audio over it.
        if (isStale()) return;

        const audioUrl = URL.createObjectURL(blob);
        currentAudio = new Audio(audioUrl);
        currentAudio.playbackRate = selectedSpeed;

        await currentAudio.play();
        button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
        button.disabled = false;
        if (speedBtn) speedBtn.disabled = false;

        currentAudio.onended = () => {
            // Only reset if this button is still the active one.
            if (button === currentPlayingButton) resetAllSpeakButtons();
        };
        currentAudio.onerror = (e) => {
            console.error('Audio object error:', e);
            alert('Error playing audio.');
            resetAllSpeakButtons();
        };
    } catch (error) {
        if (isStale()) {
            // Superseded request: resetting here would stop the newer
            // playback, and alerting would blame the wrong message.
            console.warn('Ignoring error from superseded TTS request:', error);
            return;
        }
        console.error('TTS Handling Error:', error);
        alert(`Failed to play audio: ${error.message}`);
        resetAllSpeakButtons();
    }
}
855
+
856
+ // --- UPDATED: resetAllSpeakButtons ---
857
/**
 * Return every TTS control to its idle state and stop any current audio.
 *
 * All `.speak-btn` buttons get the "Listen" label, are re-enabled, and have
 * their stored speed reset to 1.0; all `.speed-cycle-btn` buttons show
 * "Speed: 1x". The current Audio object is paused, detached from its
 * handlers, and its blob URL is released.
 */
function resetAllSpeakButtons() {
    document.querySelectorAll('.speak-btn').forEach(btn => {
        btn.innerHTML = `${playIconSVG} <span>Listen</span>`;
        btn.disabled = false;
        btn.setAttribute('data-current-speed', '1.0');
    });
    document.querySelectorAll('.speed-cycle-btn').forEach(btn => {
        btn.textContent = 'Speed: 1x';
        btn.disabled = false;
    });

    if (currentAudio) {
        currentAudio.pause();
        currentAudio.onended = null; // Drop handlers so a discarded clip cannot call back.
        currentAudio.onerror = null;

        // Object URLs keep their Blob alive until revoked; release this one
        // now so each played clip does not leak its audio data.
        const src = currentAudio.src;
        if (src && src.startsWith('blob:')) {
            URL.revokeObjectURL(src);
        }
        currentAudio = null;
    }
    currentPlayingButton = null;
}
876
+
877
+ // Remove the resetSpecificButton function if it exists, it's integrated now.
878
+ // --- End of TTS Functions ---
879
+
880
+ // ... (keep the optional highlight.js part if you have it)
881
+
882
+ });
883
+ </script>
884
+ </body>
885
+ </html>