Spaces:
Sleeping
Sleeping
riteshraut committed on
Commit ·
bd93e48
1
Parent(s): 5fd359f
new feature
Browse files
- app.py +282 -152
- evaluate.py +205 -0
- query_expansion.py +524 -0
- rag_processor.py +382 -78
- templates/index.html +880 -610
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
-
import time
|
| 3 |
import uuid
|
| 4 |
-
from flask import Flask, request, render_template, session, jsonify, Response
|
| 5 |
from werkzeug.utils import secure_filename
|
| 6 |
from rag_processor import create_rag_chain
|
| 7 |
from typing import Sequence, Any, List
|
|
@@ -10,10 +9,7 @@ import re
|
|
| 10 |
import io
|
| 11 |
from gtts import gTTS
|
| 12 |
from langchain_core.documents import Document
|
| 13 |
-
from langchain_community.document_loaders import
|
| 14 |
-
TextLoader,
|
| 15 |
-
Docx2txtLoader,
|
| 16 |
-
)
|
| 17 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 18 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 19 |
from langchain_community.vectorstores import FAISS
|
|
@@ -23,124 +19,203 @@ from langchain_community.retrievers import BM25Retriever
|
|
| 23 |
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 24 |
from langchain.storage import InMemoryStore
|
| 25 |
from sentence_transformers.cross_encoder import CrossEncoder
|
|
|
|
| 26 |
app = Flask(__name__)
|
| 27 |
app.config['SECRET_KEY'] = os.urandom(24)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
class LocalReranker(BaseDocumentCompressor):
|
| 29 |
model: Any
|
| 30 |
-
top_n: int =
|
| 31 |
|
| 32 |
class Config:
|
| 33 |
arbitrary_types_allowed = True
|
| 34 |
|
| 35 |
-
def compress_documents(
|
| 36 |
-
|
| 37 |
-
documents: Sequence[Document],
|
| 38 |
-
query: str,
|
| 39 |
-
callbacks=None,
|
| 40 |
-
) -> Sequence[Document]:
|
| 41 |
if not documents:
|
| 42 |
return []
|
| 43 |
-
|
| 44 |
pairs = [[query, doc.page_content] for doc in documents]
|
| 45 |
scores = self.model.predict(pairs, show_progress_bar=False)
|
| 46 |
-
|
| 47 |
doc_scores = list(zip(documents, scores))
|
| 48 |
-
sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1],
|
| 49 |
-
|
| 50 |
top_docs = []
|
| 51 |
-
for doc, score in sorted_doc_scores[:self.top_n]:
|
| 52 |
doc.metadata['rerank_score'] = float(score)
|
| 53 |
top_docs.append(doc)
|
| 54 |
return top_docs
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
try:
|
| 63 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 64 |
-
print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
|
| 65 |
except Exception as e:
|
| 66 |
-
print(f"Failed to create upload folder
|
| 67 |
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
|
| 68 |
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 69 |
-
print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
|
| 70 |
|
| 71 |
-
|
| 72 |
message_histories = {}
|
| 73 |
|
| 74 |
-
print(
|
| 75 |
try:
|
| 76 |
-
EMBEDDING_MODEL =
|
| 77 |
-
model_name=
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
print(
|
| 81 |
except Exception as e:
|
| 82 |
-
print(f"FATAL: Could not load embedding model. Error: {e}")
|
| 83 |
raise e
|
| 84 |
|
| 85 |
-
print(
|
| 86 |
try:
|
| 87 |
-
RERANKER_MODEL =
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
except Exception as e:
|
| 90 |
-
print(f"FATAL: Could not load reranker model. Error: {e}")
|
| 91 |
raise e
|
| 92 |
|
|
|
|
| 93 |
def load_pdf_with_fallback(filepath):
|
| 94 |
try:
|
| 95 |
docs = []
|
| 96 |
with fitz.open(filepath) as pdf_doc:
|
| 97 |
-
for page_num, page in enumerate(pdf_doc):
|
| 98 |
text = page.get_text()
|
| 99 |
if text.strip():
|
| 100 |
-
docs.append(Document(
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
"source": os.path.basename(filepath),
|
| 104 |
-
"page": page_num + 1,
|
| 105 |
-
}
|
| 106 |
-
))
|
| 107 |
if docs:
|
| 108 |
-
print(f"
|
|
|
|
| 109 |
return docs
|
| 110 |
else:
|
| 111 |
-
raise ValueError(
|
| 112 |
except Exception as e:
|
| 113 |
-
print(f"PyMuPDF failed for {filepath}: {e}")
|
| 114 |
raise
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
}
|
| 121 |
|
| 122 |
def get_session_history(session_id: str) -> ChatMessageHistory:
|
| 123 |
if session_id not in message_histories:
|
| 124 |
message_histories[session_id] = ChatMessageHistory()
|
| 125 |
return message_histories[session_id]
|
| 126 |
|
|
|
|
| 127 |
@app.route('/health', methods=['GET'])
|
| 128 |
def health_check():
|
| 129 |
-
return jsonify({'status': 'healthy'}), 200
|
|
|
|
| 130 |
|
| 131 |
@app.route('/', methods=['GET'])
|
| 132 |
def index():
|
| 133 |
return render_template('index.html')
|
| 134 |
|
|
|
|
| 135 |
@app.route('/upload', methods=['POST'])
|
| 136 |
def upload_files():
|
| 137 |
files = request.files.getlist('file')
|
| 138 |
-
if not files or all(f.filename == '' for f in files):
|
| 139 |
-
return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
for file in files:
|
| 145 |
if file and file.filename:
|
| 146 |
filename = secure_filename(file.filename)
|
|
@@ -149,147 +224,202 @@ def upload_files():
|
|
| 149 |
file.save(filepath)
|
| 150 |
file_ext = os.path.splitext(filename)[1].lower()
|
| 151 |
if file_ext not in LOADER_MAPPING:
|
| 152 |
-
raise ValueError(
|
| 153 |
-
|
| 154 |
loader_func = LOADER_MAPPING[file_ext]
|
| 155 |
-
docs = loader_func(filepath) if file_ext ==
|
| 156 |
-
|
| 157 |
if not docs:
|
| 158 |
-
raise ValueError(
|
| 159 |
-
|
| 160 |
all_docs.extend(docs)
|
| 161 |
processed_files.append(filename)
|
| 162 |
-
print(f"✓ Successfully processed: {filename}")
|
| 163 |
except Exception as e:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
failed_files.append(f"{filename} ({error_msg})")
|
| 167 |
|
| 168 |
if not all_docs:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
return jsonify({'status': 'error', 'message': error_summary}), 400
|
| 173 |
|
|
|
|
| 174 |
try:
|
| 175 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
vectorstore = FAISS.
|
| 178 |
store = InMemoryStore()
|
| 179 |
-
|
| 180 |
-
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
|
| 181 |
-
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
|
| 182 |
-
|
| 183 |
-
parent_docs = parent_splitter.split_documents(all_docs)
|
| 184 |
-
doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
|
| 185 |
-
|
| 186 |
-
child_docs = []
|
| 187 |
-
for i, doc in enumerate(parent_docs):
|
| 188 |
-
_id = doc_ids[i]
|
| 189 |
-
sub_docs = child_splitter.split_documents([doc])
|
| 190 |
-
for child in sub_docs:
|
| 191 |
-
child.metadata["doc_id"] = _id
|
| 192 |
-
child_docs.extend(sub_docs)
|
| 193 |
-
|
| 194 |
store.mset(list(zip(doc_ids, parent_docs)))
|
| 195 |
-
|
| 196 |
-
print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
|
| 197 |
|
| 198 |
bm25_retriever = BM25Retriever.from_documents(child_docs)
|
| 199 |
-
bm25_retriever.k =
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
weights=[0.4, 0.6]
|
| 206 |
-
)
|
| 207 |
-
print("Created Hybrid Retriever for child documents.")
|
| 208 |
|
| 209 |
-
reranker = LocalReranker(model=RERANKER_MODEL, top_n=4)
|
| 210 |
-
|
| 211 |
def get_parents(docs: List[Document]) -> List[Document]:
|
| 212 |
-
|
| 213 |
-
return store.mget(list(parent_ids))
|
| 214 |
|
| 215 |
-
compression_retriever =
|
| 216 |
-
base_compressor=reranker,
|
| 217 |
-
|
| 218 |
-
|
| 219 |
final_retriever = compression_retriever | get_parents
|
| 220 |
-
|
| 221 |
-
print("Final retriever chain created: (Hybrid -> Rerank) -> Parent Fetch")
|
| 222 |
|
| 223 |
session_id = str(uuid.uuid4())
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
if failed_files:
|
| 230 |
-
success_msg += f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
return jsonify({
|
| 233 |
-
'status': 'success',
|
| 234 |
'filename': success_msg,
|
| 235 |
-
'session_id': session_id
|
| 236 |
-
|
| 237 |
-
|
|
|
|
| 238 |
except Exception as e:
|
| 239 |
import traceback
|
| 240 |
traceback.print_exc()
|
| 241 |
-
return jsonify({'status': 'error',
|
|
|
|
| 242 |
|
| 243 |
-
@app.route('/chat', methods=['POST'])
|
| 244 |
def chat():
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
-
|
| 250 |
-
return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
except Exception as e:
|
| 259 |
-
print(f"
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
def clean_markdown_for_tts(text: str) -> str:
|
| 263 |
-
text = re.sub(r'\*
|
| 264 |
-
text = re.sub(r'
|
| 265 |
-
text = re.sub(r'
|
| 266 |
-
text = re.sub(r'^\s*
|
| 267 |
-
text = re.sub(r'^\s*
|
| 268 |
-
text = re.sub(r'
|
| 269 |
-
text = re.sub(r'
|
| 270 |
-
text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
|
| 271 |
-
text = re.sub(r'\n+', ' ', text)
|
| 272 |
return text.strip()
|
| 273 |
|
|
|
|
| 274 |
@app.route('/tts', methods=['POST'])
|
| 275 |
def text_to_speech():
|
| 276 |
data = request.get_json()
|
| 277 |
text = data.get('text')
|
| 278 |
-
|
| 279 |
if not text:
|
| 280 |
-
return jsonify({'status': 'error',
|
| 281 |
-
|
| 282 |
try:
|
| 283 |
clean_text = clean_markdown_for_tts(text)
|
|
|
|
|
|
|
|
|
|
| 284 |
tts = gTTS(clean_text, lang='en')
|
| 285 |
mp3_fp = io.BytesIO()
|
| 286 |
tts.write_to_fp(mp3_fp)
|
| 287 |
mp3_fp.seek(0)
|
| 288 |
return Response(mp3_fp, mimetype='audio/mpeg')
|
| 289 |
except Exception as e:
|
| 290 |
-
print(f"
|
| 291 |
-
return jsonify({'status': 'error',
|
|
|
|
|
|
|
| 292 |
|
| 293 |
if __name__ == '__main__':
|
| 294 |
-
port = int(os.environ.get(
|
| 295 |
-
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import uuid
|
| 3 |
+
from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context
|
| 4 |
from werkzeug.utils import secure_filename
|
| 5 |
from rag_processor import create_rag_chain
|
| 6 |
from typing import Sequence, Any, List
|
|
|
|
| 9 |
import io
|
| 10 |
from gtts import gTTS
|
| 11 |
from langchain_core.documents import Document
|
| 12 |
+
from langchain_community.document_loaders import TextLoader, Docx2txtLoader
|
|
|
|
|
|
|
|
|
|
| 13 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 14 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 15 |
from langchain_community.vectorstores import FAISS
|
|
|
|
| 19 |
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 20 |
from langchain.storage import InMemoryStore
|
| 21 |
from sentence_transformers.cross_encoder import CrossEncoder
|
| 22 |
+
|
| 23 |
app = Flask(__name__)
# The key is regenerated on every process start, so values signed with a
# previous key stop validating after a restart.
app.config.update(SECRET_KEY=os.urandom(24))

# Display label for each temperature preset, keyed by the temperature's
# string form. Unknown keys are shown verbatim by the callers (they use
# ``TEMPERATURE_LABELS.get(value, value)``).
TEMPERATURE_LABELS = {
    '0.2': 'Precise',
    '0.4': 'Confident',
    '0.6': 'Balanced',
    '0.8': 'Flexible',
    '1.0': 'Creative',
}
|
| 32 |
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that keeps the ``top_n`` best-scoring documents.

    ``model`` must expose ``predict(pairs, show_progress_bar=False)`` and
    return one score per ``[query, text]`` pair; higher scores rank first.
    """

    model: Any
    top_n: int = 5

    class Config:
        # ``model`` is an arbitrary object, not a validated field type.
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Score ``documents`` against ``query`` and return the best ``top_n``.

        Each returned document gets its score stored (as a float) under
        ``metadata['rerank_score']``; the documents are mutated in place.
        An empty input yields an empty list without calling the model.
        """
        if not documents:
            return []

        relevance = self.model.predict(
            [[query, candidate.page_content] for candidate in documents],
            show_progress_bar=False,
        )
        # Stable sort: documents with equal scores keep their input order.
        ranked = sorted(zip(documents, relevance),
                        key=lambda pair: pair[1], reverse=True)

        kept = []
        for candidate, value in ranked[:self.top_n]:
            candidate.metadata['rerank_score'] = float(value)
            kept.append(candidate)
        return kept
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into parent chunks and smaller, linked child chunks.

    Parents are cut at 900 characters (overlap 200); each parent is then cut
    into children of 350 characters (overlap 80). Every child records its
    parent's id and its position among its siblings in ``metadata``.

    Returns:
        ``(parent_docs, child_docs, doc_ids)`` where ``doc_ids[i]`` is the
        freshly generated UUID string of ``parent_docs[i]``. All three are
        empty lists when ``all_docs`` is empty.
    """
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])

    # Separators shared by both splitters, from sentence level down to
    # single characters.
    sentence_breaks = ['. ', '! ', '? ', '; ', ', ', ' ', '']
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=['\n\n', '\n'] + sentence_breaks,
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=['\n'] + sentence_breaks,
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    child_docs = []
    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        pieces = child_splitter.split_documents([parent_doc])
        piece_count = len(pieces)
        last_position = piece_count - 1
        for position, piece in enumerate(pieces):
            piece.metadata.update({
                'doc_id': parent_id,
                'chunk_index': position,
                'total_chunks': piece_count,
                'is_first_chunk': position == 0,
                'is_last_chunk': position == last_position,
            })
            # Only multi-piece parents get a marker prefix; middle pieces
            # are left untouched.
            # NOTE(review): the *last* piece is tagged '[Continues...]' --
            # confirm this is the intended marker for a final chunk.
            if piece_count > 1:
                if position == 0:
                    piece.page_content = '[Beginning] ' + piece.page_content
                elif position == last_position:
                    piece.page_content = '[Continues...] ' + piece.page_content
            child_docs.append(piece)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return (parent_docs, child_docs, doc_ids)
|
| 90 |
+
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Swap retrieved child chunks for their parents, enriched with excerpts.

    Children are grouped by ``metadata['doc_id']`` (children without one are
    ignored). Each parent found in ``store`` is returned as a new
    ``Document`` whose text is the parent text followed by up to three of
    its matching child texts. The number of matching children is stored in
    the metadata as both ``child_relevance_score`` and
    ``matching_children``. Parents with more matching children come first;
    ties keep first-seen order. Missing parents are logged and skipped.
    """
    if not docs:
        return []

    # Dict insertion order records the order in which parents were first hit.
    excerpts_by_parent = {}
    for child in docs:
        parent_id = child.metadata.get('doc_id')
        if not parent_id:
            continue
        excerpts_by_parent.setdefault(parent_id, []).append(child.page_content)

    wanted_ids = list(excerpts_by_parent)
    enriched = []
    for parent_id, parent in zip(wanted_ids, store.mget(wanted_ids)):
        if parent is None:
            print(f"PARENT_FETCH: Parent {parent_id} not found in store!")
            continue
        excerpts = excerpts_by_parent[parent_id]
        hit_count = len(excerpts)
        joined_excerpts = '\n'.join(excerpts[:3])
        enriched.append(Document(
            page_content=f"{parent.page_content}\n\nRelevant excerpts:\n{joined_excerpts}",
            metadata={
                **parent.metadata,
                'child_relevance_score': hit_count,
                'matching_children': hit_count,
            },
        ))

    enriched.sort(key=lambda item: item.metadata.get('child_relevance_score', 0),
                  reverse=True)
    return enriched
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Either environment variable being set selects the /tmp upload location.
# NOTE(review): presumably these are set by the Hugging Face Spaces runtime.
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # Fall back to /tmp/uploads if the preferred folder cannot be created.
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# In-process state, keyed by session id. Nothing here survives a restart.
session_data = {}
message_histories = {}

# Both models are loaded once at import time; a failure is logged and
# re-raised so that startup aborts.
print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise e

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder(
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        device='cpu',
    )
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise e
|
| 160 |
|
| 161 |
+
|
| 162 |
def load_pdf_with_fallback(filepath):
    """Load a PDF with PyMuPDF (``fitz``) as one ``Document`` per page.

    Pages whose extracted text is blank are skipped. Each document carries
    the file's base name under ``source`` and the 1-based page number under
    ``page``.

    Raises:
        ValueError: if no page yields any text.
        Exception: any error from opening or reading the file is logged and
            re-raised unchanged. Despite the name, no alternative loader is
            tried in this function.
    """
    try:
        pages = []
        with fitz.open(filepath) as pdf:
            for index, page in enumerate(pdf):
                content = page.get_text()
                if not content.strip():
                    continue
                page_meta = {'source': os.path.basename(filepath),
                             'page': index + 1}
                pages.append(Document(page_content=content, metadata=page_meta))
        if not pages:
            raise ValueError('No text content found in PDF.')
        print(f"Loaded PDF: {os.path.basename(filepath)} - {len(pages)} pages")
        return pages
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise
|
| 181 |
|
| 182 |
+
|
| 183 |
+
# Lower-cased file extension -> loader. The '.pdf' entry is a plain function
# that returns documents directly; the other entries are loader classes that
# the upload handler instantiates and calls ``.load()`` on.
LOADER_MAPPING = {
    '.txt': TextLoader,
    '.pdf': load_pdf_with_fallback,
    '.docx': Docx2txtLoader,
}
|
| 185 |
+
|
|
|
|
| 186 |
|
| 187 |
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories live in the module-level ``message_histories`` dict, so the
    same object is returned on every later call for that session.
    """
    try:
        return message_histories[session_id]
    except KeyError:
        history = ChatMessageHistory()
        message_histories[session_id] = history
        return history
|
| 191 |
|
| 192 |
+
|
| 193 |
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always answer ``{"status": "healthy"}`` with HTTP 200."""
    body = jsonify({'status': 'healthy'})
    return (body, 200)
|
| 196 |
+
|
| 197 |
|
| 198 |
@app.route('/', methods=['GET'])
def index():
    """Serve the application's main page from ``templates/index.html``."""
    page = render_template('index.html')
    return page
|
| 201 |
|
| 202 |
+
|
| 203 |
@app.route('/upload', methods=['POST'])
|
| 204 |
def upload_files():
|
| 205 |
files = request.files.getlist('file')
|
|
|
|
|
|
|
| 206 |
|
| 207 |
+
temperature_str = request.form.get('temperature', '0.2')
|
| 208 |
+
temperature = float(temperature_str)
|
| 209 |
+
model_name = request.form.get('model_name',
|
| 210 |
+
'moonshotai/kimi-k2-instruct')
|
| 211 |
+
print (f"UPLOAD: Model: {model_name}, Temp: {temperature}")
|
| 212 |
|
| 213 |
+
if not files or all(f.filename == '' for f in files):
|
| 214 |
+
return (jsonify({'status': 'error',
|
| 215 |
+
'message': 'No selected files.'}), 400)
|
| 216 |
+
|
| 217 |
+
(all_docs, processed_files, failed_files) = ([], [], [])
|
| 218 |
+
print (f"Processing {len(files)} file(s)...")
|
| 219 |
for file in files:
|
| 220 |
if file and file.filename:
|
| 221 |
filename = secure_filename(file.filename)
|
|
|
|
| 224 |
file.save(filepath)
|
| 225 |
file_ext = os.path.splitext(filename)[1].lower()
|
| 226 |
if file_ext not in LOADER_MAPPING:
|
| 227 |
+
raise ValueError('Unsupported file format.')
|
|
|
|
| 228 |
loader_func = LOADER_MAPPING[file_ext]
|
| 229 |
+
docs = loader_func(filepath) if file_ext == '.pdf' \
|
| 230 |
+
else loader_func(filepath).load()
|
| 231 |
if not docs:
|
| 232 |
+
raise ValueError('No content extracted.')
|
|
|
|
| 233 |
all_docs.extend(docs)
|
| 234 |
processed_files.append(filename)
|
|
|
|
| 235 |
except Exception as e:
|
| 236 |
+
print (f"✗ Error processing {filename}: {e}")
|
| 237 |
+
failed_files.append(f"{filename} ({e})")
|
|
|
|
| 238 |
|
| 239 |
if not all_docs:
|
| 240 |
+
return (jsonify({'status': 'error',
|
| 241 |
+
'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"
|
| 242 |
+
}), 400)
|
|
|
|
| 243 |
|
| 244 |
+
print (f"UPLOAD: Processed {len(processed_files)} files.")
|
| 245 |
try:
|
| 246 |
+
print ('Starting RAG pipeline setup...')
|
| 247 |
+
(parent_docs, child_docs, doc_ids) = \
|
| 248 |
+
create_optimized_parent_child_chunks(all_docs)
|
| 249 |
+
if not child_docs:
|
| 250 |
+
raise ValueError('No child documents created during chunking.')
|
| 251 |
|
| 252 |
+
vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
|
| 253 |
store = InMemoryStore()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
store.mset(list(zip(doc_ids, parent_docs)))
|
| 255 |
+
print (f"Indexed {len(child_docs)} document chunks.")
|
|
|
|
| 256 |
|
| 257 |
bm25_retriever = BM25Retriever.from_documents(child_docs)
|
| 258 |
+
bm25_retriever.k = 12
|
| 259 |
+
faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
|
| 260 |
+
ensemble_retriever = \
|
| 261 |
+
EnsembleRetriever(retrievers=[bm25_retriever,
|
| 262 |
+
faiss_retriever], weights=[0.6, 0.4])
|
| 263 |
+
reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
|
|
|
|
|
|
|
|
|
|
| 264 |
|
|
|
|
|
|
|
| 265 |
def get_parents(docs: List[Document]) -> List[Document]:
|
| 266 |
+
return get_context_aware_parents(docs, store)
|
|
|
|
| 267 |
|
| 268 |
+
compression_retriever = \
|
| 269 |
+
ContextualCompressionRetriever(base_compressor=reranker,
|
| 270 |
+
base_retriever=ensemble_retriever)
|
|
|
|
| 271 |
final_retriever = compression_retriever | get_parents
|
|
|
|
|
|
|
| 272 |
|
| 273 |
session_id = str(uuid.uuid4())
|
| 274 |
+
(rag_chain, api_key_manager) = \
|
| 275 |
+
create_rag_chain(retriever=final_retriever,
|
| 276 |
+
get_session_history_func=get_session_history,
|
| 277 |
+
model_name=model_name,
|
| 278 |
+
temperature=temperature)
|
| 279 |
+
|
| 280 |
+
session_data[session_id] = {'chain': rag_chain,
|
| 281 |
+
'model_name': model_name,
|
| 282 |
+
'temperature': temperature,
|
| 283 |
+
'api_key_manager': api_key_manager}
|
| 284 |
+
|
| 285 |
+
success_msg = f"Processed: {', '.join(processed_files)}"
|
| 286 |
if failed_files:
|
| 287 |
+
success_msg += f". Failed: {', '.join(failed_files)}"
|
| 288 |
+
|
| 289 |
+
mode_label = TEMPERATURE_LABELS.get(temperature_str,
|
| 290 |
+
temperature_str)
|
| 291 |
+
|
| 292 |
+
print (f"UPLOAD COMPLETE: Session {session_id} is ready.")
|
| 293 |
|
| 294 |
return jsonify({
|
| 295 |
+
'status': 'success',
|
| 296 |
'filename': success_msg,
|
| 297 |
+
'session_id': session_id,
|
| 298 |
+
'model_name': model_name,
|
| 299 |
+
'mode': mode_label,
|
| 300 |
+
})
|
| 301 |
except Exception as e:
|
| 302 |
import traceback
|
| 303 |
traceback.print_exc()
|
| 304 |
+
return (jsonify({'status': 'error',
|
| 305 |
+
'message': f'RAG setup failed: {e}'}), 500)
|
| 306 |
|
| 307 |
+
def _sse_event(payload):
    """Format ``payload`` (a JSON-serializable dict) as one SSE ``data:`` frame."""
    # Function-scope import, like the ``traceback`` imports below.
    import json
    # json.dumps escapes quotes, backslashes, newlines and control characters
    # correctly, so every frame is valid JSON on a single line.
    return f'data: {json.dumps(payload, ensure_ascii=False)}\n\n'


def _chat_error(message, status):
    """Build an error reply: an SSE frame for GET clients, JSON otherwise."""
    if request.method == 'GET':
        def error_stream():
            yield _sse_event({'error': message})
        return Response(stream_with_context(error_stream()),
                        mimetype='text/event-stream', status=status)
    return (jsonify({'status': 'error', 'message': message}), status)


@app.route('/chat', methods=['POST', 'GET'])
def chat():
    """Stream an answer to ``question`` for an existing upload session.

    Inputs:
        GET:  ``question`` and ``session_id`` query parameters.
        POST: JSON object with ``question`` and ``session_id``; the session
              id falls back to the Flask session cookie.

    Returns:
        On success, a ``text/event-stream`` response whose frames are JSON
        objects ``{"token", "model_name", "mode"}``. A failure during
        generation is reported as a final ``{"error": ...}`` frame.
        Validation failures return 400 and setup failures return 500, as an
        SSE frame for GET requests or as JSON otherwise. Any other method
        (e.g. HEAD) returns 405.
    """
    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
    elif request.method == 'POST':
        # silent=True: a missing or malformed body becomes the 400 below
        # instead of an unhandled exception.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
    else:
        return (jsonify({'status': 'error', 'message': 'Method not allowed'}), 405)

    # str() keeps the log line safe when the question is missing or not a string.
    print(f"Received {request.method} request for chat: session={session_id}, question={str(question)[:50]}...")

    if not question or not isinstance(question, str):
        error_msg = "Error: No question provided."
        print(f"CHAT Validation Error: {error_msg}")
        return _chat_error(error_msg, 400)

    if not session_id or session_id not in session_data:
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        return _chat_error("Error: Invalid session. Please upload documents first.", 400)

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        temperature_str = str(temperature_float)
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"CHAT: Streaming response for session {session_id} (Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            """Yield one SSE frame per string chunk produced by the chain."""
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}},
                )
                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        yield _sse_event({'token': chunk,
                                          'model_name': model_name,
                                          'mode': mode_label})
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")
                print('CHAT: Streaming finished successfully.')
            except Exception as e:
                # Headers are already sent once streaming has started, so the
                # failure is reported in-band as an error frame.
                print(f"CHAT Error during streaming generation: {e}")
                import traceback
                traceback.print_exc()
                yield _sse_event({'error': f"Error during response generation: {str(e)}"})

        return Response(stream_with_context(generate_chunks()), mimetype='text/event-stream')

    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        import traceback
        traceback.print_exc()
        return _chat_error(f"Error setting up chat stream: {str(e)}", 500)
|
| 386 |
+
|
| 387 |
|
| 388 |
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so ``text`` can be read aloud as one line.

    The rules run in order: whole inline links are dropped (link text
    included), emphasis/code/heading characters are removed, list and
    block-quote markers at line starts are removed, then all newlines and
    repeated whitespace collapse to single spaces.
    """
    # (pattern, replacement, flags), applied top to bottom. Order matters:
    # later rules see the output of earlier ones.
    rules = (
        (r'\[.*?\]\(.*?\)', '', 0),                # [label](url) links
        (r'[`*_#]', '', 0),                        # emphasis, code, headings
        (r'^\s*[\-\*\+]\s+', '', re.MULTILINE),    # bullet markers
        (r'^\s*\d+\.\s+', '', re.MULTILINE),       # numbered-list markers
        (r'^\s*>\s?', '', re.MULTILINE),           # block-quote markers
        (r'\n+', ' ', 0),                          # join lines
        (r'\s{2,}', ' ', 0),                       # collapse whitespace
    )
    cleaned = text
    for pattern, replacement, flags in rules:
        cleaned = re.sub(pattern, replacement, cleaned, flags=flags)
    return cleaned.strip()
|
| 397 |
|
| 398 |
+
|
| 399 |
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert the posted ``text`` to spoken English and return it as MP3.

    Expects a JSON object with a non-empty string under ``text``. Markdown
    is stripped before synthesis.

    Returns:
        ``audio/mpeg`` bytes on success; JSON ``{"status": "error", ...}``
        with 400 when no usable text is supplied (including a missing,
        malformed or non-object JSON body) and 500 when synthesis fails.
    """
    # silent=True plus the dict check: a body of ``null``, an array, or
    # invalid JSON becomes the 400 below instead of an AttributeError.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text or not isinstance(text, str):
        return (jsonify({'status': 'error',
                         'message': 'No text provided.'}), 400)
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            # Input contained only Markdown syntax or whitespace.
            return (jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400)

        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)  # rewind so the response is read from the start
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        # Top-level boundary: log the cause, return a generic message.
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error',
                         'message': 'Failed to generate audio.'}), 500)
|
| 420 |
+
|
| 421 |
|
| 422 |
if __name__ == '__main__':
    # Script entry point: the PORT environment variable overrides the
    # default port 7860.
    port = int(os.getenv('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
|
evaluate.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import uuid
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import Sequence, Any, List
|
| 8 |
+
|
| 9 |
+
# Ragas and LangChain components
|
| 10 |
+
from ragas import evaluate
|
| 11 |
+
from ragas.metrics import (
|
| 12 |
+
faithfulness,
|
| 13 |
+
answer_relevancy,
|
| 14 |
+
context_recall,
|
| 15 |
+
context_precision,
|
| 16 |
+
)
|
| 17 |
+
from ragas.testset import TestsetGenerator
|
| 18 |
+
# NOTE: The 'evolutions' import has been completely removed.
|
| 19 |
+
|
| 20 |
+
# Your specific RAG components from app.py
|
| 21 |
+
from langchain_groq import ChatGroq
|
| 22 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 23 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 24 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 25 |
+
from langchain_community.vectorstores import FAISS
|
| 26 |
+
from langchain.storage import InMemoryStore
|
| 27 |
+
from langchain_community.retrievers import BM25Retriever
|
| 28 |
+
from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
|
| 29 |
+
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
| 30 |
+
from langchain_core.documents import Document
|
| 31 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
| 32 |
+
from rag_processor import create_rag_chain
|
| 33 |
+
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 34 |
+
import fitz
|
| 35 |
+
|
| 36 |
+
# Load environment variables
|
| 37 |
+
load_dotenv()
|
| 38 |
+
|
| 39 |
+
# --- Re-implementing LocalReranker from app.py ---
|
| 40 |
+
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that keeps only the highest-scoring documents.

    Mirrors the reranker used by the web app so the evaluation exercises the
    same retrieval behaviour.
    """

    # Scoring model; must expose ``predict(pairs, show_progress_bar=...)``.
    model: Any
    # How many of the best-scoring documents to return.
    top_n: int = 3

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks=None) -> Sequence[Document]:
        """Score each document against ``query`` and return the best ``top_n``.

        Args:
            documents: Candidate documents to rank.
            query: The query every document is scored against.
            callbacks: Accepted for interface compatibility; not used.

        Returns:
            Up to ``top_n`` documents, best first. Each returned document has
            its score stored in ``metadata['rerank_score']``.
        """
        if not documents:
            return []

        scores = self.model.predict(
            [[query, doc.page_content] for doc in documents],
            show_progress_bar=False,
        )
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

        kept = []
        for doc, score in ranked[:self.top_n]:
            doc.metadata['rerank_score'] = float(score)
            kept.append(doc)
        return kept
|
| 56 |
+
|
| 57 |
+
# --- Helper Functions ---
|
| 58 |
+
def load_pdf_with_fallback(filepath):
    """Read a PDF with PyMuPDF and return one Document per page that has text.

    Args:
        filepath: Path of the PDF to open.

    Returns:
        A list of Documents whose metadata holds the file's base name
        (``source``) and the 1-based page number (``page``). Pages with only
        whitespace are skipped.

    Raises:
        ValueError: If no page yields any text.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    try:
        with fitz.open(filepath) as pdf_doc:
            pages = [
                Document(
                    page_content=text,
                    metadata={"source": os.path.basename(filepath), "page": index + 1},
                )
                for index, text in enumerate(page.get_text() for page in pdf_doc)
                if text.strip()
            ]

        if not pages:
            raise ValueError("No text content found in PDF.")

        print(f"✓ Successfully loaded PDF: {filepath}")
        return pages
    except Exception as e:
        print(f"✗ PyMuPDF failed for {filepath}: {e}")
        raise
|
| 78 |
+
|
| 79 |
+
async def main():
    """Run an end-to-end Ragas evaluation of the RAG pipeline against one PDF.

    The run has six stages: initialise the Groq LLMs, the embedding model and
    the cross-encoder; build the parent/child hybrid retriever and the RAG
    chain; generate a synthetic test set; answer every generated question;
    score the answers with Ragas; and write the scores to
    ``evaluation_results.csv``.

    Returns early (after printing a message) when the PDF is missing. Any
    exception raised after that point is printed together with its traceback
    and is not re-raised.
    """
    print("\n" + "="*60 + "\nSTARTING RAGAS EVALUATION\n" + "="*60)

    pdf_path = "uploads/Unit_-_1_Introduction.pdf"
    if not os.path.exists(pdf_path):
        print(f"✗ Error: PDF not found at {pdf_path}")
        return

    try:
        # --- 1. Setup Models ---
        print("\n--- 1. Initializing Models ---")
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key or groq_api_key == "your_groq_api_key_here":
            raise ValueError("GROQ_API_KEY not found or is a placeholder.")

        generator_llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_api_key)
        critic_llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
        print("✓ Models initialized.")

        # --- 2. Setup RAG Pipeline ---
        print("\n--- 2. Setting up RAG Pipeline ---")
        documents = load_pdf_with_fallback(pdf_path)

        # Split documents: small child chunks are indexed for search, and each
        # one carries the id of the larger parent chunk it was cut from.
        parent_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=400)
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
        parent_docs = parent_splitter.split_documents(documents)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        child_docs = []
        for _id, doc in zip(doc_ids, parent_docs):
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        vectorstore = FAISS.from_documents(child_docs, embedding_model)

        bm25_retriever = BM25Retriever.from_documents(child_docs, k=10)
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.4, 0.6])

        reranker = LocalReranker(model=reranker_model, top_n=5)
        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)

        def get_parents(docs: List[Document]) -> List[Document]:
            """Map reranked child chunks to their parent chunks, best first."""
            # dict.fromkeys de-duplicates while keeping the reranker's order;
            # a set would hand the parents back in arbitrary order.
            parent_ids = list(dict.fromkeys(d.metadata["doc_id"] for d in docs))
            # mget yields None for keys it does not hold; drop those so the
            # caller only ever sees Documents.
            return [parent for parent in store.mget(parent_ids) if parent is not None]

        final_retriever = compression_retriever | get_parents

        message_histories = {}

        def get_session_history(session_id: str):
            """Return the chat history for ``session_id``, creating it on first use."""
            if session_id not in message_histories:
                message_histories[session_id] = ChatMessageHistory()
            return message_histories[session_id]

        rag_chain = create_rag_chain(final_retriever, get_session_history)
        print("✓ RAG chain created successfully.")

        # --- 3. Generate Testset ---
        print("\n--- 3. Generating Test Questions ---")
        # NOTE(review): the three-argument from_langchain call, the
        # ``testset_size`` keyword and the 'question'/'ground_truth' columns
        # read below may belong to different ragas releases — confirm against
        # the installed version.
        generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embedding_model)

        # Generate a simple test set without complex distributions
        testset = generator.generate_with_langchain_docs(documents, testset_size=5)
        print("✓ Testset generated.")

        # --- 4. Run RAG Chain on Testset ---
        print("\n--- 4. Running RAG Chain to Generate Answers ---")
        # Convert once and reuse, rather than rebuilding the DataFrame per column.
        records = testset.to_pandas().to_dict('records')
        test_questions = [item['question'] for item in records]
        ground_truths = [item['ground_truth'] for item in records]

        answers = []
        contexts = []

        for i, question in enumerate(test_questions):
            print(f"  Processing question {i+1}/{len(test_questions)}...")
            # Retrieve contexts
            retrieved_docs = final_retriever.invoke(question)
            contexts.append([doc.page_content for doc in retrieved_docs])
            # Get answer from chain; a fresh session id keeps questions from
            # sharing chat history.
            config = {"configurable": {"session_id": str(uuid.uuid4())}}
            answer = await rag_chain.ainvoke({"question": question}, config=config)
            answers.append(answer)

        # --- 5. Evaluate with Ragas ---
        print("\n--- 5. Evaluating Results with Ragas ---")
        eval_data = {
            'question': test_questions,
            'answer': answers,
            'contexts': contexts,
            'ground_truth': ground_truths
        }
        eval_dataset = Dataset.from_dict(eval_data)

        result = evaluate(
            eval_dataset,
            metrics=[faithfulness, answer_relevancy, context_precision, context_recall],
            llm=critic_llm,
            embeddings=embedding_model
        )

        print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
        print(result)

        # --- 6. Save Results ---
        print("\n--- 6. Saving Results ---")
        results_df = result.to_pandas()
        results_df.to_csv("evaluation_results.csv", index=False)
        print("✓ Evaluation results saved to evaluation_results.csv")

        print("\n" + "="*60 + "\nEVALUATION COMPLETE!\n" + "="*60)

    except Exception as e:
        print(f"\n✗ An error occurred during the process: {e}")
        import traceback
        traceback.print_exc()
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
asyncio.run(main())
|
query_expansion.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
Query Expansion System for CogniChat RAG Application
|
| 4 |
+
|
| 5 |
+
This module implements advanced query expansion techniques to improve retrieval quality:
|
| 6 |
+
- QueryAnalyzer: Extracts intent, entities, and keywords
|
| 7 |
+
- QueryRephraser: Generates natural language variations
|
| 8 |
+
- MultiQueryExpander: Creates diverse query formulations
|
| 9 |
+
- MultiHopReasoner: Connects concepts across documents
|
| 10 |
+
- FallbackStrategies: Handles edge cases gracefully
|
| 11 |
+
|
| 12 |
+
Author: CogniChat Team
|
| 13 |
+
Date: October 19, 2025
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
from typing import List, Dict, Any, Optional
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from enum import Enum
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class QueryStrategy(Enum):
    """How aggressively a query is expanded; values are the public string names."""

    # Fast, minimal expansion (about 2 queries).
    QUICK = "quick"
    # Middle ground (about 3-4 queries).
    BALANCED = "balanced"
    # Maximum coverage (about 5-6 queries).
    COMPREHENSIVE = "comprehensive"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class QueryAnalysis:
    """What the analyzer learned about a single query."""

    # Kind of request: definition, how_to, comparison, explanation, ...
    intent: str
    # Named entities found in the query text.
    entities: List[str]
    # Significant words left after stop-word filtering.
    keywords: List[str]
    # One of: simple, medium, complex.
    complexity: str
    # Technical domain, when one was detected.
    domain: Optional[str] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class ExpandedQuery:
    """A query together with the variations generated for it."""

    # The query exactly as the user supplied it.
    original: str
    # Generated formulations of the query.
    variations: List[str]
    # Strategy that produced ``variations``.
    strategy_used: QueryStrategy
    # Analysis the variations were derived from.
    analysis: QueryAnalysis
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class QueryAnalyzer:
    """
    Analyzes queries to extract intent, entities, and key information.

    All analysis here is rule-based (regular expressions and word lists).
    The optional ``llm`` is stored on the instance but no method of this
    class consults it.
    """

    # Words ignored when extracting keywords (simplified list).
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
        'what', 'how', 'why', 'when', 'where', 'which', 'who',
        'do', 'does', 'did', 'can', 'could', 'should', 'would',
        'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
    })

    # Question words that are capitalized only because they open a sentence.
    _QUESTION_WORDS = frozenset({'what', 'how', 'why', 'when', 'where', 'which'})

    # Domain -> indicative terms. Checked in this order; first match wins.
    _DOMAIN_TERMS = {
        'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
        'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
        'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
        'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
        'database': ['sql', 'query', 'database', 'table', 'index'],
        'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
    }

    # Whole-word patterns (with an optional plural suffix) so that short terms
    # such as 'ai' no longer match inside unrelated words like "explain".
    _DOMAIN_PATTERNS = {
        domain: re.compile(
            r'\b(?:' + '|'.join(re.escape(term) for term in terms) + r')(?:s|es)?\b'
        )
        for domain, terms in _DOMAIN_TERMS.items()
    }

    def __init__(self, llm=None):
        """
        Initialize QueryAnalyzer.

        Args:
            llm: Optional LangChain LLM; kept for callers, not used internally
        """
        self.llm = llm
        # Checked in insertion order; the first pattern that matches wins.
        self.intent_patterns = {
            'definition': r'\b(what is|define|meaning of|definition)\b',
            'how_to': r'\b(how to|how do|how can|steps to)\b',
            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
            'explanation': r'\b(why|explain|reason|cause)\b',
            'listing': r'\b(list|enumerate|what are|types of)\b',
            'example': r'\b(example|instance|sample|case)\b',
        }

    def analyze(self, query: str) -> "QueryAnalysis":
        """
        Analyze query to extract intent, entities, and keywords.

        Args:
            query: User's original query

        Returns:
            QueryAnalysis object with extracted information
        """
        query_lower = query.lower()

        entities = self._extract_entities(query)
        keywords = self._extract_keywords(query)

        return QueryAnalysis(
            intent=self._detect_intent(query_lower),
            entities=entities,
            keywords=keywords,
            complexity=self._assess_complexity(query, entities, keywords),
            domain=self._detect_domain(query_lower),
        )

    def _detect_intent(self, query_lower: str) -> str:
        """Return the first intent whose pattern matches, else ``'general'``."""
        for intent, pattern in self.intent_patterns.items():
            if re.search(pattern, query_lower):
                return intent
        return 'general'

    def _extract_entities(self, query: str) -> List[str]:
        """Extract likely named entities: capitalized words and quoted terms.

        Surrounding quotes/brackets and trailing sentence punctuation are
        removed from each word, so "Python?" yields "Python". The result is
        de-duplicated in first-seen order, which keeps it stable between runs.
        """
        entities = []

        for token in query.split():
            word = token.lstrip('"\'([{').rstrip('"\')]}.,;:!?')
            # Skip question words that are capitalized at the sentence start.
            if word and word[0].isupper() and word.lower() not in self._QUESTION_WORDS:
                entities.append(word)

        # Quoted terms are kept whole, even when they span several words.
        entities.extend(re.findall(r'"([^"]+)"', query))

        return list(dict.fromkeys(entities))

    def _extract_keywords(self, query: str) -> List[str]:
        """Return up to 10 lowercase non-stop-words longer than two characters."""
        words = re.findall(r'\b\w+\b', query.lower())
        keywords = [w for w in words if w not in self._STOP_WORDS and len(w) > 2]
        return keywords[:10]

    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
        """Classify the query as simple, medium or complex from a weighted count."""
        # Entities weigh 2 and keywords 1.5 on top of the raw word count.
        score = len(query.split()) + (len(entities) * 2) + (len(keywords) * 1.5)

        if score < 15:
            return 'simple'
        if score < 30:
            return 'medium'
        return 'complex'

    def _detect_domain(self, query_lower: str) -> Optional[str]:
        """Return the first domain with a term present as a whole word, else None.

        Args:
            query_lower: The query, already lower-cased by the caller.
        """
        for domain, pattern in self._DOMAIN_PATTERNS.items():
            if pattern.search(query_lower):
                return domain
        return None
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class QueryRephraser:
    """
    Generates natural language variations of queries using multiple strategies.

    Every variation is built from fixed rules; the optional ``llm`` is stored
    on the instance but no method of this class consults it.
    """

    # Common single-word replacements used by the synonym variation.
    _SYNONYMS = {
        'error': 'issue',
        'problem': 'issue',
        'fix': 'resolve',
        'use': 'utilize',
        'create': 'generate',
        'make': 'create',
        'get': 'retrieve',
        'show': 'display',
        'find': 'locate',
        'explain': 'describe',
    }

    # Matches any synonym key as a whole word, regardless of case.
    _SYNONYM_RE = re.compile(r'\b(' + '|'.join(_SYNONYMS) + r')\b', re.IGNORECASE)

    def __init__(self, llm=None):
        """
        Initialize QueryRephraser.

        Args:
            llm: Optional LangChain LLM; kept for callers, not used internally
        """
        self.llm = llm

    def generate_variations(
        self,
        query: str,
        analysis: QueryAnalysis,
        strategy: QueryStrategy = QueryStrategy.BALANCED
    ) -> List[str]:
        """
        Generate query variations based on strategy.

        Args:
            query: Original query
            analysis: Query analysis results
            strategy: Expansion strategy to use

        Returns:
            List of query variations. The original query is always first;
            empty and duplicate variations are dropped, order preserved.
        """
        variations = [query]  # Always include original

        if strategy == QueryStrategy.QUICK:
            # Just add synonym variation
            variations.append(self._synonym_variation(query, analysis))

        elif strategy in (QueryStrategy.BALANCED, QueryStrategy.COMPREHENSIVE):
            # Both strategies share the synonym, expanded and simplified versions.
            variations.append(self._synonym_variation(query, analysis))
            variations.append(self._expanded_variation(query, analysis))
            variations.append(self._simplified_variation(query, analysis))

            if strategy == QueryStrategy.COMPREHENSIVE:
                variations.append(self._keyword_focused(query, analysis))
                variations.append(self._context_variation(query, analysis))
                # Alternate phrasing; skipped without keywords, because it
                # would otherwise read just "Guide to ".
                if analysis.intent in ['how_to', 'explanation'] and analysis.keywords:
                    variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")

        # dict.fromkeys removes duplicates while preserving order.
        return list(dict.fromkeys(v for v in variations if v))

    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Replace the first word that has a known synonym.

        The rest of the query keeps its original casing, and a word is found
        even when punctuation is attached to it. When nothing matches, the
        query is returned unchanged so the caller's de-duplication drops it.
        """
        def _swap(match):
            word = match.group(0)
            replacement = self._SYNONYMS[word.lower()]
            # Keep a leading capital, e.g. "Fix the bug" -> "Resolve the bug".
            return replacement.capitalize() if word[0].isupper() else replacement

        # count=1: only replace one word to keep the phrasing natural.
        return self._SYNONYM_RE.sub(_swap, query, count=1)

    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate expanded version with more detail."""
        if analysis.intent == 'definition':
            # Fall back to the query itself when no keyword survived filtering.
            topic = ' '.join(analysis.keywords) or query
            return f"Detailed explanation and definition of {topic}"
        if analysis.intent == 'how_to':
            return f"Step-by-step guide on {query.lower()}"
        if analysis.intent == 'comparison':
            return f"Comprehensive comparison: {query}"
        # Any other intent just gets qualifying words.
        return f"Detailed information about {query.lower()}"

    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Reduce the query to its first three keywords, if it has at least two."""
        if len(analysis.keywords) >= 2:
            return ' '.join(analysis.keywords[:3])
        return query

    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
        """Create keyword-focused variation for BM25.

        Joins up to five keywords and entities; a term that appears as both
        (differing only in case) is included once.
        """
        seen = set()
        terms = []
        for term in analysis.keywords + analysis.entities:
            key = term.lower()
            if key not in seen:
                seen.add(key)
                terms.append(term)
        return ' '.join(terms[:5])

    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Add contextual information if domain detected."""
        if analysis.domain:
            return f"{query} in {analysis.domain} context"
        return query
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class MultiQueryExpander:
    """
    Orchestrates query expansion: analyzes a query, then rephrases it.
    """

    def __init__(self, llm=None):
        """
        Build the analyzer and rephraser this expander delegates to.

        Args:
            llm: Optional LangChain LLM, handed to both helpers
        """
        self.analyzer = QueryAnalyzer(llm)
        self.rephraser = QueryRephraser(llm)

    def expand(
        self,
        query: str,
        strategy: QueryStrategy = QueryStrategy.BALANCED,
        max_queries: int = 6
    ) -> ExpandedQuery:
        """
        Produce up to ``max_queries`` formulations of ``query``.

        Args:
            query: Original user query
            strategy: Expansion strategy
            max_queries: Upper bound on the number of variations returned

        Returns:
            ExpandedQuery holding the original query, the variations, the
            strategy that was used and the underlying analysis
        """
        analysis = self.analyzer.analyze(query)
        candidates = self.rephraser.generate_variations(query, analysis, strategy)

        return ExpandedQuery(
            original=query,
            variations=candidates[:max_queries],
            strategy_used=strategy,
            analysis=analysis,
        )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class MultiHopReasoner:
    """
    Splits a query into sub-queries so each part can be retrieved separately.
    Intended for complex queries whose answer draws on several sources.
    """

    def __init__(self, llm=None):
        """
        Store the optional LLM.

        Args:
            llm: Optional LangChain LLM; not consulted by the methods below
        """
        self.llm = llm

    def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """
        Derive sub-queries from the query's intent, keywords and complexity.

        Args:
            query: Original complex query
            analysis: Query analysis

        Returns:
            At most five queries; the original query is always the first
        """
        intent = analysis.intent
        keywords = analysis.keywords
        sub_queries = [query]

        if intent == 'comparison':
            # One query per compared subject: entities first, keywords as fallback.
            if len(analysis.entities) >= 2:
                subjects = analysis.entities[:2]
            elif len(keywords) >= 2:
                subjects = keywords[:2]
            else:
                subjects = []
            sub_queries.extend(f"Information about {subject}" for subject in subjects)

        if intent == 'how_to' and len(keywords) >= 2:
            # How-to questions split into what is needed and what to do.
            main_topic = ' '.join(keywords[:2])
            sub_queries += [f"Prerequisites for {main_topic}", f"Steps to {main_topic}"]

        if analysis.complexity == 'complex' and len(keywords) > 3:
            # Complex questions get one query per half of the keyword list.
            half = len(keywords) // 2
            sub_queries += [' '.join(keywords[:half]), ' '.join(keywords[half:])]

        return sub_queries[:5]
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class FallbackStrategies:
    """
    Implements fallback strategies for queries that don't retrieve good results.
    """

    @staticmethod
    def simplify_query(query: str) -> str:
        """Simplify query by removing modifiers and focusing on core terms.

        Question words, modal verbs, forms of "to be" and articles are removed
        (case-insensitively); punctuation is left in place.
        """
        # Remove question words
        query = re.sub(r'\b(what|how|why|when|where|which|who|can|could|should|would)\b', '', query, flags=re.IGNORECASE)

        # Remove common phrases
        query = re.sub(r'\b(is|are|was|were|be|been|the|a|an)\b', '', query, flags=re.IGNORECASE)

        # Clean up extra spaces
        return re.sub(r'\s+', ' ', query).strip()

    @staticmethod
    def broaden_query(query: str, analysis: "QueryAnalysis") -> str:
        """Broaden query to increase recall.

        Returns "<keyword> overview" for the first keyword that is not itself
        a constraint word ("specific", "exactly", ...). Without such a
        keyword, returns the query with the constraint words removed and
        whitespace collapsed.

        Args:
            query: Query that retrieved poor results
            analysis: Analysis of that query; only ``keywords`` is read
        """
        constraint_words = ('specific', 'exactly', 'precisely', 'only', 'just')

        # Add general terms. Constraint words are skipped so that the result
        # cannot be e.g. "specific overview".
        for keyword in analysis.keywords:
            if keyword.lower() not in constraint_words:
                return f"{keyword} overview"

        # Remove specific constraints
        pattern = r'\b(' + '|'.join(constraint_words) + r')\b'
        query = re.sub(pattern, '', query, flags=re.IGNORECASE)
        return re.sub(r'\s+', ' ', query).strip()

    @staticmethod
    def focus_entities(analysis: "QueryAnalysis") -> str:
        """Create entity-focused query as fallback.

        Uses all entities when there are any, otherwise the first three
        keywords, otherwise an empty string.
        """
        if analysis.entities:
            return ' '.join(analysis.entities)
        if analysis.keywords:
            return ' '.join(analysis.keywords[:3])
        return ""
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# Convenience function for easy integration
|
| 437 |
+
def expand_query_simple(
    query: str,
    strategy: str = "balanced",
    llm=None
) -> List[str]:
    """
    Simple function to expand a query without dealing with classes.

    Args:
        query: User's query to expand
        strategy: "quick", "balanced", or "comprehensive". Matching ignores
            case and surrounding whitespace; a QueryStrategy member is
            accepted as well.
        llm: Optional LangChain LLM

    Returns:
        List of expanded query variations, starting with the original query

    Raises:
        ValueError: If ``strategy`` does not name a known strategy

    Example:
        >>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
        >>> queries[0]
        'How do I debug Python code?'
    """
    if isinstance(strategy, QueryStrategy):
        strategy_enum = strategy
    else:
        # Normalize so that "Balanced" or " quick " are not rejected.
        strategy_enum = QueryStrategy(strategy.strip().lower())

    expander = MultiQueryExpander(llm=llm)
    expanded = expander.expand(query, strategy=strategy_enum)
    return expanded.variations
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# Example usage and testing
|
| 465 |
+
if __name__ == "__main__":
    # Demo: runs each component on a sample query and prints the results.

    # Example 1: Simple query expansion
    print("=" * 60)
    print("Example 1: Simple Query Expansion")
    print("=" * 60)

    query = "What is machine learning?"
    queries = expand_query_simple(query, strategy="balanced")

    print(f"\nOriginal: {query}")
    print(f"\nExpanded queries ({len(queries)}):")
    for i, q in enumerate(queries, 1):
        print(f"  {i}. {q}")

    # Example 2: Complex query with full analysis
    print("\n" + "=" * 60)
    print("Example 2: Complex Query with Analysis")
    print("=" * 60)

    expander = MultiQueryExpander()
    query = "How do I compare the performance of different neural network architectures?"
    result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE)

    print(f"\nOriginal: {result.original}")
    print("\nAnalysis:")
    print(f"  Intent: {result.analysis.intent}")
    print(f"  Entities: {result.analysis.entities}")
    print(f"  Keywords: {result.analysis.keywords}")
    print(f"  Complexity: {result.analysis.complexity}")
    print(f"  Domain: {result.analysis.domain}")
    print(f"\nExpanded queries ({len(result.variations)}):")
    for i, q in enumerate(result.variations, 1):
        print(f"  {i}. {q}")

    # Example 3: Multi-hop reasoning
    print("\n" + "=" * 60)
    print("Example 3: Multi-Hop Reasoning")
    print("=" * 60)

    reasoner = MultiHopReasoner()
    analyzer = QueryAnalyzer()

    query = "Compare Python and Java for web development"
    analysis = analyzer.analyze(query)
    sub_queries = reasoner.generate_sub_queries(query, analysis)

    print(f"\nOriginal: {query}")
    print("\nSub-queries for multi-hop reasoning:")
    for i, sq in enumerate(sub_queries, 1):
        print(f"  {i}. {sq}")

    # Example 4: Fallback strategies
    print("\n" + "=" * 60)
    print("Example 4: Fallback Strategies")
    print("=" * 60)

    query = "What is the specific difference between supervised and unsupervised learning?"
    analysis = analyzer.analyze(query)

    # Show what each fallback produces for the analysed query.
    print(f"\nOriginal: {query}")
    print("\nFallback queries:")
    print(f"  Simplified: {FallbackStrategies.simplify_query(query)}")
    print(f"  Broadened: {FallbackStrategies.broaden_query(query, analysis)}")
    print(f"  Entity-focused: {FallbackStrategies.focus_entities(analysis)}")
|
| 523 |
+
|
| 524 |
+
|
rag_processor.py
CHANGED
|
@@ -6,109 +6,413 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
| 6 |
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
| 7 |
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
def create_rag_chain(retriever, get_session_history_func):
|
| 11 |
-
"""
|
| 12 |
-
Creates an advanced Retrieval-Augmented Generation (RAG) chain with hybrid search,
|
| 13 |
-
query rewriting, answer refinement, and conversational memory.
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
print("key loaded")
|
| 28 |
-
if not api_key or api_key == "your_groq_api_key_here":
|
| 29 |
-
error_msg = "GROQ_API_KEY not found or not configured properly.\n"
|
| 30 |
-
print("Not found key")
|
| 31 |
-
if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
|
| 32 |
-
error_msg += (
|
| 33 |
-
"For Hugging Face Spaces: Set GROQ_API_KEY in your Space's Settings > Repository Secrets.\n"
|
| 34 |
-
"Go to your Space settings and add GROQ_API_KEY as a secret variable."
|
| 35 |
)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
Chat History:
|
| 54 |
{chat_history}
|
| 55 |
|
| 56 |
Follow-up Question: {question}
|
| 57 |
-
|
|
|
|
| 58 |
rewrite_prompt = ChatPromptTemplate.from_messages([
|
| 59 |
("system", rewrite_template),
|
| 60 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 61 |
-
("human", "
|
| 62 |
])
|
| 63 |
query_rewriter = rewrite_prompt | llm | StrOutputParser()
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
rag_prompt = ChatPromptTemplate.from_messages([
|
| 74 |
("system", rag_template),
|
| 75 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 76 |
("human", "{question}"),
|
| 77 |
])
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
chain_with_memory = RunnableWithMessageHistory(
|
| 91 |
conversational_rag_chain,
|
| 92 |
get_session_history_func,
|
| 93 |
input_messages_key="question",
|
| 94 |
history_messages_key="chat_history",
|
| 95 |
)
|
| 96 |
-
print("\nSetting up answer refinement chain...")
|
| 97 |
-
refine_template = """You are an expert at editing and refining content.
|
| 98 |
-
Your task is to take a given answer and improve its clarity, structure, and readability.
|
| 99 |
-
Use formatting such as bold text, bullet points, or numbered lists where it enhances the explanation.
|
| 100 |
-
Do not add any new information that wasn't in the original answer.
|
| 101 |
-
|
| 102 |
-
Original Answer:
|
| 103 |
-
{answer}
|
| 104 |
-
|
| 105 |
-
Refined Answer:"""
|
| 106 |
-
refine_prompt = ChatPromptTemplate.from_template(refine_template)
|
| 107 |
-
refinement_chain = refine_prompt | llm | StrOutputParser()
|
| 108 |
-
|
| 109 |
-
final_chain = (
|
| 110 |
-
lambda input_dict: {"answer": chain_with_memory.invoke(input_dict, config=input_dict.get('config'))}
|
| 111 |
-
) | refinement_chain
|
| 112 |
|
| 113 |
-
print("
|
| 114 |
-
|
|
|
|
|
|
|
|
|
| 6 |
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
| 7 |
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
from langchain_core.runnables.history import RunnableWithMessageHistory
|
| 9 |
+
from langchain_core.documents import Document
|
| 10 |
+
from query_expansion import expand_query_simple
|
| 11 |
+
from typing import List, Optional
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
class GroqAPIKeyManager:
|
| 15 |
+
def __init__(self, api_keys: List[str]):
|
| 16 |
+
self.api_keys = [key for key in api_keys if key and key != "your_groq_api_key_here"]
|
| 17 |
+
if not self.api_keys:
|
| 18 |
+
raise ValueError("No valid API keys provided!")
|
| 19 |
+
|
| 20 |
+
self.current_index = 0
|
| 21 |
+
self.failed_keys = set()
|
| 22 |
+
self.success_count = {key: 0 for key in self.api_keys}
|
| 23 |
+
self.failure_count = {key: 0 for key in self.api_keys}
|
| 24 |
+
|
| 25 |
+
print(f"API Key Manager: Loaded {len(self.api_keys)} API keys")
|
| 26 |
+
|
| 27 |
+
def get_current_key(self) -> str:
|
| 28 |
+
return self.api_keys[self.current_index]
|
| 29 |
+
|
| 30 |
+
def mark_success(self, api_key: str):
|
| 31 |
+
if api_key in self.success_count:
|
| 32 |
+
self.success_count[api_key] += 1
|
| 33 |
+
if api_key in self.failed_keys:
|
| 34 |
+
self.failed_keys.remove(api_key)
|
| 35 |
+
print(f"API Key #{self.api_keys.index(api_key) + 1} recovered!")
|
| 36 |
+
|
| 37 |
+
def mark_failure(self, api_key: str):
|
| 38 |
+
if api_key in self.failure_count:
|
| 39 |
+
self.failure_count[api_key] += 1
|
| 40 |
+
self.failed_keys.add(api_key)
|
| 41 |
+
|
| 42 |
+
def rotate_to_next_key(self) -> bool:
|
| 43 |
+
initial_index = self.current_index
|
| 44 |
+
attempts = 0
|
| 45 |
+
|
| 46 |
+
while attempts < len(self.api_keys):
|
| 47 |
+
self.current_index = (self.current_index + 1) % len(self.api_keys)
|
| 48 |
+
attempts += 1
|
| 49 |
+
|
| 50 |
+
current_key = self.api_keys[self.current_index]
|
| 51 |
+
if attempts >= len(self.api_keys):
|
| 52 |
+
print(f"All keys attempted, retrying with key #{self.current_index + 1}")
|
| 53 |
+
return True
|
| 54 |
+
if current_key not in self.failed_keys:
|
| 55 |
+
print(f"Switching to API Key #{self.current_index + 1}")
|
| 56 |
+
return True
|
| 57 |
+
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
def get_statistics(self) -> str:
|
| 61 |
+
stats = []
|
| 62 |
+
for i, key in enumerate(self.api_keys):
|
| 63 |
+
success = self.success_count[key]
|
| 64 |
+
failure = self.failure_count[key]
|
| 65 |
+
status = "FAILED" if key in self.failed_keys else "ACTIVE"
|
| 66 |
+
masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
|
| 67 |
+
stats.append(f" Key #{i+1} ({masked_key}): {success} success, {failure} failures [{status}]")
|
| 68 |
+
return "\n".join(stats)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
def load_api_keys_from_hf_secrets() -> List[str]:
|
| 72 |
+
api_keys = []
|
| 73 |
+
secret_names = ['GROQ_API_KEY_1', 'GROQ_API_KEY_2', 'GROQ_API_KEY_3', 'GROQ_API_KEY_4']
|
| 74 |
+
|
| 75 |
+
print("Loading API keys from Hugging Face Secrets...")
|
| 76 |
+
|
| 77 |
+
for secret_name in secret_names:
|
| 78 |
+
try:
|
| 79 |
+
api_key = os.getenv(secret_name)
|
| 80 |
+
|
| 81 |
+
if api_key and api_key.strip() and api_key != "your_groq_api_key_here":
|
| 82 |
+
api_keys.append(api_key.strip())
|
| 83 |
+
print(f" Loaded: {secret_name}")
|
| 84 |
+
else:
|
| 85 |
+
print(f" Not found or empty: {secret_name}")
|
| 86 |
+
except Exception as e:
|
| 87 |
+
print(f" Error loading {secret_name}: {str(e)}")
|
| 88 |
+
return api_keys
|
| 89 |
+
|
| 90 |
|
| 91 |
+
def create_llm_with_fallback(
|
| 92 |
+
api_key_manager: GroqAPIKeyManager,
|
| 93 |
+
model_name: str,
|
| 94 |
+
temperature: float,
|
| 95 |
+
max_retries: int = 3
|
| 96 |
+
) -> ChatGroq:
|
| 97 |
+
for attempt in range(max_retries):
|
| 98 |
+
current_key = api_key_manager.get_current_key()
|
| 99 |
|
| 100 |
+
try:
|
| 101 |
+
llm = ChatGroq(
|
| 102 |
+
model_name=model_name,
|
| 103 |
+
api_key=current_key,
|
| 104 |
+
temperature=temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
)
|
| 106 |
+
test_result = llm.invoke("test")
|
| 107 |
+
api_key_manager.mark_success(current_key)
|
| 108 |
+
return llm
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
error_msg = str(e).lower()
|
| 112 |
+
api_key_manager.mark_failure(current_key)
|
| 113 |
+
if "rate" in error_msg or "limit" in error_msg:
|
| 114 |
+
print(f" Rate limit hit on API Key #{api_key_manager.current_index + 1}")
|
| 115 |
+
elif "auth" in error_msg or "api" in error_msg:
|
| 116 |
+
print(f" Authentication failed on API Key #{api_key_manager.current_index + 1}")
|
| 117 |
+
else:
|
| 118 |
+
print(f" Error with API Key #{api_key_manager.current_index + 1}: {str(e)[:50]}")
|
| 119 |
+
|
| 120 |
+
if attempt < max_retries - 1:
|
| 121 |
+
if api_key_manager.rotate_to_next_key():
|
| 122 |
+
print(f" Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
|
| 123 |
+
time.sleep(1)
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError("All API keys failed!")
|
| 126 |
+
else:
|
| 127 |
+
raise ValueError(f"Failed to initialize LLM after {max_retries} attempts")
|
| 128 |
+
|
| 129 |
+
raise ValueError("Failed to create LLM with any available API key")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
|
| 133 |
+
def multi_query_retrieve(query: str) -> List[Document]:
|
| 134 |
+
query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
|
| 135 |
+
all_docs = []
|
| 136 |
+
seen_content = set()
|
| 137 |
+
for i, query_var in enumerate(query_variations):
|
| 138 |
+
try:
|
| 139 |
+
docs = base_retriever.invoke(query_var)
|
| 140 |
+
for doc in docs:
|
| 141 |
+
content_hash = hash(doc.page_content)
|
| 142 |
+
if content_hash not in seen_content:
|
| 143 |
+
seen_content.add(content_hash)
|
| 144 |
+
all_docs.append(doc)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f" Query Expansion Error (Query {i+1}): {str(e)[:50]}")
|
| 147 |
+
continue
|
| 148 |
+
print(f" Query Expansion: Retrieved {len(all_docs)} unique documents.")
|
| 149 |
+
return all_docs
|
| 150 |
+
return multi_query_retrieve
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_system_prompt(temperature: float) -> str:
|
| 154 |
+
if temperature <= 0.4:
|
| 155 |
+
return """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.
|
| 156 |
+
|
| 157 |
+
RESPONSE GUIDELINES:
|
| 158 |
+
|
| 159 |
+
**Structure & Formatting:**
|
| 160 |
+
- Start with a direct answer to the question
|
| 161 |
+
- Use **bold** for key terms, important concepts, and technical terminology
|
| 162 |
+
- Use bullet points (•) for lists, features, or multiple items
|
| 163 |
+
- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
|
| 164 |
+
- Use ### Headers to organize different sections or topics
|
| 165 |
+
- Add blank lines between sections for readability
|
| 166 |
+
|
| 167 |
+
**Source Citation:**
|
| 168 |
+
- Always cite information using: [Source: filename, Page: X] and cite it at the end of the entire answer only
|
| 169 |
+
- Place citations at the end of your final answer only
|
| 170 |
+
- Do not cite sources within the body of your answer
|
| 171 |
+
- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]
|
| 172 |
+
|
| 173 |
+
**Completeness:**
|
| 174 |
+
- Provide thorough, detailed answers using ALL relevant information from context
|
| 175 |
+
- Summarize and properly elaborate each point for increased clarity
|
| 176 |
+
- If the question has multiple parts, address each part clearly
|
| 177 |
+
|
| 178 |
+
**Accuracy:**
|
| 179 |
+
- ONLY use information from the provided context documents below
|
| 180 |
+
- If information is incomplete, state what IS available and what ISN'T
|
| 181 |
+
- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
|
| 182 |
+
- Never make assumptions or add information not in the context
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
|
| 186 |
+
{context}
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
Now answer the following question comprehensively using the context above:"""
|
| 191 |
+
|
| 192 |
+
elif temperature <= 0.8:
|
| 193 |
+
return """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.
|
| 194 |
+
|
| 195 |
+
RESPONSE GUIDELINES:
|
| 196 |
+
|
| 197 |
+
**Communication Style:**
|
| 198 |
+
- Present information in a clear, engaging manner
|
| 199 |
+
- Use **bold** for emphasis on important concepts
|
| 200 |
+
- Balance structure with natural flow
|
| 201 |
+
- Make complex topics accessible and interesting
|
| 202 |
+
|
| 203 |
+
**Content Approach:**
|
| 204 |
+
- Ground your response firmly in the provided context
|
| 205 |
+
- Add helpful explanations and connections between concepts
|
| 206 |
+
- Use analogies or examples when they help clarify ideas (but keep them brief)
|
| 207 |
+
- Organize information logically with headers (###) and lists where appropriate
|
| 208 |
+
|
| 209 |
+
**Source Attribution:**
|
| 210 |
+
- Cite sources at the end: [Source: filename, Page: X]
|
| 211 |
+
- Be transparent about what the documents do and don't contain
|
| 212 |
+
|
| 213 |
+
**Accuracy:**
|
| 214 |
+
- Base your answer on the context documents provided
|
| 215 |
+
- If information is partial, explain what's available
|
| 216 |
+
- Acknowledge gaps: "The documents don't cover this aspect"
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
{context}
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
Now answer the following question in an engaging yet accurate way:"""
|
| 225 |
+
|
| 226 |
+
else: # temperature > 0.8
|
| 227 |
+
# Creative BUT CLEAR prompt - REVISED VERSION
|
| 228 |
+
return """You are CogniChat, a creative document analyst who makes complex information clear, memorable, and engaging.
|
| 229 |
|
| 230 |
+
YOUR CORE MISSION: **CLARITY FIRST, CREATIVITY SECOND**
|
| 231 |
+
|
| 232 |
+
Make information easier to understand, not harder. Your creativity should illuminate, not obscure.
|
| 233 |
+
|
| 234 |
+
**CREATIVE CLARITY PRINCIPLES:**
|
| 235 |
+
|
| 236 |
+
1. **Simplify, Don't Complicate**
|
| 237 |
+
- Break down complex concepts into simple, digestible parts
|
| 238 |
+
- Use everyday language alongside technical terms
|
| 239 |
+
- Explain jargon immediately in plain English
|
| 240 |
+
- Short sentences for complex ideas, varied length for rhythm
|
| 241 |
+
|
| 242 |
+
2. **Smart Use of Examples & Analogies** (Use Sparingly!)
|
| 243 |
+
- Only use analogies when they genuinely make something clearer
|
| 244 |
+
- Keep analogies simple and relatable (everyday objects/experiences)
|
| 245 |
+
- Never use metaphors that require explanation themselves
|
| 246 |
+
- If you can explain it directly in simple terms, do that instead
|
| 247 |
+
|
| 248 |
+
3. **Engaging Structure**
|
| 249 |
+
- Start with the core answer in one clear sentence
|
| 250 |
+
- Use **bold** to highlight key takeaways
|
| 251 |
+
- Break information into logical chunks with ### headers
|
| 252 |
+
- Use bullet points for clarity, not decoration
|
| 253 |
+
- Add brief transition phrases to connect ideas smoothly
|
| 254 |
+
|
| 255 |
+
4. **Conversational Yet Precise**
|
| 256 |
+
- Write like you're explaining to a smart friend
|
| 257 |
+
- Use "you" and active voice to engage readers
|
| 258 |
+
- Ask occasional rhetorical questions only if they aid understanding
|
| 259 |
+
- Vary sentence length to maintain interest
|
| 260 |
+
- Use emojis sparingly (1-2 max) and only where they add clarity
|
| 261 |
+
|
| 262 |
+
5. **Visual Clarity**
|
| 263 |
+
- Strategic use of formatting: **bold** for key terms, *italics* for emphasis
|
| 264 |
+
- White space between sections for easy scanning
|
| 265 |
+
- Progressive disclosure: simple concepts first, details after
|
| 266 |
+
- Numbered lists for sequences, bullets for related items
|
| 267 |
+
|
| 268 |
+
**WHAT TO AVOID:**
|
| 269 |
+
- Flowery or overly descriptive language
|
| 270 |
+
- Complex metaphors that need their own explanation
|
| 271 |
+
- Long narrative storytelling that buries the facts
|
| 272 |
+
- Multiple rhetorical questions in a row
|
| 273 |
+
- Overuse of emojis or exclamation points
|
| 274 |
+
- Making simple things sound complicated
|
| 275 |
+
|
| 276 |
+
**ACCURACY BOUNDARIES:**
|
| 277 |
+
- Creative explanation and presentation of facts
|
| 278 |
+
- Simple, helpful examples from common knowledge
|
| 279 |
+
- Reorganizing information for better understanding
|
| 280 |
+
- Never invent facts not in the documents
|
| 281 |
+
- Don't contradict source material
|
| 282 |
+
- If info is missing, say so clearly and briefly
|
| 283 |
+
|
| 284 |
+
**Source Attribution:**
|
| 285 |
+
- End with: [Source: filename, Page: X]
|
| 286 |
+
- Keep it simple and clear
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
|
| 290 |
+
{context}
|
| 291 |
+
|
| 292 |
+
---
|
| 293 |
+
|
| 294 |
+
Now, explain the answer clearly and engagingly. Remember: if your grandmother couldn't understand it, simplify more:"""
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def create_rag_chain(
|
| 299 |
+
retriever,
|
| 300 |
+
get_session_history_func,
|
| 301 |
+
enable_query_expansion=True,
|
| 302 |
+
expansion_strategy="balanced",
|
| 303 |
+
model_name: str = "moonshotai/kimi-k2-instruct",
|
| 304 |
+
temperature: float = 0.2,
|
| 305 |
+
api_keys: Optional[List[str]] = None
|
| 306 |
+
):
|
| 307 |
+
if api_keys is None:
|
| 308 |
+
api_keys = load_api_keys_from_hf_secrets()
|
| 309 |
+
|
| 310 |
+
if not api_keys:
|
| 311 |
+
raise ValueError(
|
| 312 |
+
"No valid API keys found! Please set GROQ_API_KEY or GROQ_API_KEY_1, "
|
| 313 |
+
"GROQ_API_KEY_2, GROQ_API_KEY_3, GROQ_API_KEY_4 in your .env file"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
api_key_manager = GroqAPIKeyManager(api_keys)
|
| 317 |
+
|
| 318 |
+
print(f" RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")
|
| 319 |
+
|
| 320 |
+
if temperature <= 0.4:
|
| 321 |
+
creativity_mode = "FACTUAL & STRUCTURED"
|
| 322 |
+
elif temperature <= 0.8:
|
| 323 |
+
creativity_mode = "BALANCED & ENGAGING"
|
| 324 |
+
else:
|
| 325 |
+
creativity_mode = "CREATIVE & STORYTELLING"
|
| 326 |
+
print(f"Creativity Mode: {creativity_mode}")
|
| 327 |
+
|
| 328 |
+
llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
|
| 329 |
+
print(f"LLM initialized with API Key #{api_key_manager.current_index + 1}")
|
| 330 |
+
|
| 331 |
+
if enable_query_expansion:
|
| 332 |
+
print(f"RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
|
| 333 |
+
enhanced_retriever = create_multi_query_retriever(
|
| 334 |
+
base_retriever=retriever,
|
| 335 |
+
llm=llm,
|
| 336 |
+
strategy=expansion_strategy
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
enhanced_retriever = retriever
|
| 340 |
+
|
| 341 |
+
rewrite_template = """You are an expert at optimizing search queries for document retrieval.
|
| 342 |
+
|
| 343 |
+
Given the conversation history and a follow-up question, create a comprehensive standalone question that:
|
| 344 |
+
1. Incorporates all relevant context from the chat history
|
| 345 |
+
2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
|
| 346 |
+
3. Includes key technical terms and concepts that would help find relevant documents
|
| 347 |
+
4. Maintains the original intent, specificity, and detail level
|
| 348 |
+
5. If the question asks for comparison or multiple items, ensure all items are in the query
|
| 349 |
|
| 350 |
Chat History:
|
| 351 |
{chat_history}
|
| 352 |
|
| 353 |
Follow-up Question: {question}
|
| 354 |
+
|
| 355 |
+
Optimized Standalone Question:"""
|
| 356 |
rewrite_prompt = ChatPromptTemplate.from_messages([
|
| 357 |
("system", rewrite_template),
|
| 358 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 359 |
+
("human", "{question}")
|
| 360 |
])
|
| 361 |
query_rewriter = rewrite_prompt | llm | StrOutputParser()
|
| 362 |
+
|
| 363 |
+
def format_docs(docs):
|
| 364 |
+
if not docs:
|
| 365 |
+
return "No relevant documents found in the knowledge base."
|
| 366 |
+
|
| 367 |
+
formatted_parts = []
|
| 368 |
+
for i, doc in enumerate(docs, 1):
|
| 369 |
+
source = doc.metadata.get('source', 'Unknown Document')
|
| 370 |
+
page = doc.metadata.get('page', 'N/A')
|
| 371 |
+
rerank_score = doc.metadata.get('rerank_score')
|
| 372 |
+
content = doc.page_content.strip()
|
| 373 |
+
|
| 374 |
+
doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
|
| 375 |
+
metadata_line = f"Source: {source} | Page: {page}"
|
| 376 |
+
if rerank_score:
|
| 377 |
+
metadata_line += f" | Relevance: {rerank_score:.3f}"
|
| 378 |
+
|
| 379 |
+
formatted_parts.append(
|
| 380 |
+
f"{doc_header}\n"
|
| 381 |
+
f"{metadata_line}\n"
|
| 382 |
+
f"{'-'*60}\n"
|
| 383 |
+
f"{content}\n"
|
| 384 |
+
)
|
| 385 |
+
return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)
|
| 386 |
+
|
| 387 |
+
rag_template = get_system_prompt(temperature)
|
| 388 |
+
|
| 389 |
rag_prompt = ChatPromptTemplate.from_messages([
|
| 390 |
("system", rag_template),
|
| 391 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 392 |
("human", "{question}"),
|
| 393 |
])
|
| 394 |
+
|
| 395 |
+
rewriter_input = RunnableParallel({
|
| 396 |
+
"question": itemgetter("question"),
|
| 397 |
+
"chat_history": itemgetter("chat_history"),
|
| 398 |
+
})
|
| 399 |
+
|
| 400 |
+
retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs
|
| 401 |
+
|
| 402 |
+
conversational_rag_chain = RunnableParallel({
|
| 403 |
+
"context": retrieval_chain,
|
| 404 |
+
"question": itemgetter("question"),
|
| 405 |
+
"chat_history": itemgetter("chat_history"),
|
| 406 |
+
}) | rag_prompt | llm | StrOutputParser()
|
| 407 |
+
|
| 408 |
chain_with_memory = RunnableWithMessageHistory(
|
| 409 |
conversational_rag_chain,
|
| 410 |
get_session_history_func,
|
| 411 |
input_messages_key="question",
|
| 412 |
history_messages_key="chat_history",
|
| 413 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
+
print("RAG: Chain created successfully.")
|
| 416 |
+
print("\n" + api_key_manager.get_statistics())
|
| 417 |
+
|
| 418 |
+
return chain_with_memory, api_key_manager
|
templates/index.html
CHANGED
|
@@ -1,615 +1,885 @@
|
|
| 1 |
<!DOCTYPE html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
<head>
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
position: absolute;
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
width: 56px;
|
| 74 |
-
height: 56px;
|
| 75 |
-
border-radius: 50%;
|
| 76 |
-
border: 3px solid;
|
| 77 |
-
border-color: var(--primary) transparent;
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
@keyframes rotation {
|
| 81 |
-
0% { transform: rotate(0deg); }
|
| 82 |
-
100% { transform: rotate(360deg); }
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
/* Typing Indicator Animation */
|
| 86 |
-
.typing-indicator span {
|
| 87 |
-
height: 10px;
|
| 88 |
-
width: 10px;
|
| 89 |
-
background-color: #9E9E9E;
|
| 90 |
-
border-radius: 50%;
|
| 91 |
-
display: inline-block;
|
| 92 |
-
animation: bounce 1.4s infinite ease-in-out both;
|
| 93 |
-
}
|
| 94 |
-
.typing-indicator span:nth-child(1) { animation-delay: -0.32s; }
|
| 95 |
-
.typing-indicator span:nth-child(2) { animation-delay: -0.16s; }
|
| 96 |
-
@keyframes bounce {
|
| 97 |
-
0%, 80%, 100% { transform: scale(0); }
|
| 98 |
-
40% { transform: scale(1.0); }
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
/* Enhanced Markdown Styling for better readability and aesthetics */
|
| 102 |
-
.markdown-content p {
|
| 103 |
-
margin-bottom: 1rem;
|
| 104 |
-
line-height: 1.75;
|
| 105 |
-
}
|
| 106 |
-
.markdown-content h1, .markdown-content h2, .markdown-content h3, .markdown-content h4 {
|
| 107 |
-
font-family: 'Google Sans', sans-serif;
|
| 108 |
-
font-weight: 700;
|
| 109 |
-
margin-top: 1.75rem;
|
| 110 |
-
margin-bottom: 1rem;
|
| 111 |
-
line-height: 1.3;
|
| 112 |
-
}
|
| 113 |
-
.markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
|
| 114 |
-
.markdown-content h2 { font-size: 1.5em; }
|
| 115 |
-
.markdown-content h3 { font-size: 1.25em; }
|
| 116 |
-
.markdown-content h4 { font-size: 1.1em; }
|
| 117 |
-
.markdown-content ul, .markdown-content ol {
|
| 118 |
-
padding-left: 1.75rem;
|
| 119 |
-
margin-bottom: 1rem;
|
| 120 |
-
}
|
| 121 |
-
.markdown-content li {
|
| 122 |
-
margin-bottom: 0.5rem;
|
| 123 |
-
}
|
| 124 |
-
.dark .markdown-content ul > li::marker { color: var(--primary); }
|
| 125 |
-
.markdown-content ul > li::marker { color: var(--primary); }
|
| 126 |
-
.markdown-content a {
|
| 127 |
-
color: var(--primary);
|
| 128 |
-
text-decoration: none;
|
| 129 |
-
font-weight: 500;
|
| 130 |
-
border-bottom: 1px solid transparent;
|
| 131 |
-
transition: all 0.2s ease-in-out;
|
| 132 |
-
}
|
| 133 |
-
.markdown-content a:hover {
|
| 134 |
-
border-bottom-color: var(--primary-hover);
|
| 135 |
-
}
|
| 136 |
-
.markdown-content blockquote {
|
| 137 |
-
margin: 1.5rem 0;
|
| 138 |
-
padding-left: 1.5rem;
|
| 139 |
-
border-left: 4px solid var(--card-border);
|
| 140 |
-
color: #6c757d;
|
| 141 |
-
font-style: italic;
|
| 142 |
-
}
|
| 143 |
-
.dark .markdown-content blockquote {
|
| 144 |
-
color: #adb5bd;
|
| 145 |
-
}
|
| 146 |
-
.markdown-content hr {
|
| 147 |
-
border: none;
|
| 148 |
-
border-top: 1px solid var(--card-border);
|
| 149 |
-
margin: 2rem 0;
|
| 150 |
-
}
|
| 151 |
-
.markdown-content table {
|
| 152 |
-
width: 100%;
|
| 153 |
-
border-collapse: collapse;
|
| 154 |
-
margin: 1.5rem 0;
|
| 155 |
-
font-size: 0.9em;
|
| 156 |
-
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
| 157 |
-
border-radius: 8px;
|
| 158 |
-
overflow: hidden;
|
| 159 |
-
}
|
| 160 |
-
.markdown-content th, .markdown-content td {
|
| 161 |
border: 1px solid var(--card-border);
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
}
|
| 169 |
-
.markdown-content code {
|
| 170 |
-
background-color: rgba(0,0,0,0.05);
|
| 171 |
-
padding: 0.2rem 0.4rem;
|
| 172 |
-
border-radius: 0.25rem;
|
| 173 |
-
font-family: 'Roboto Mono', monospace;
|
| 174 |
-
font-size: 0.9em;
|
| 175 |
-
}
|
| 176 |
-
.dark .markdown-content code {
|
| 177 |
-
background-color: rgba(255,255,255,0.1);
|
| 178 |
-
}
|
| 179 |
-
.markdown-content pre {
|
| 180 |
-
position: relative;
|
| 181 |
-
background-color: #f8f9fa;
|
| 182 |
-
border: 1px solid var(--card-border);
|
| 183 |
-
border-radius: 0.5rem;
|
| 184 |
-
margin-bottom: 1rem;
|
| 185 |
-
}
|
| 186 |
-
.dark .markdown-content pre {
|
| 187 |
-
background-color: #2e2f32;
|
| 188 |
-
}
|
| 189 |
-
.markdown-content pre code {
|
| 190 |
-
background: none;
|
| 191 |
-
padding: 1rem;
|
| 192 |
-
display: block;
|
| 193 |
-
overflow-x: auto;
|
| 194 |
-
}
|
| 195 |
-
.markdown-content pre .copy-code-btn {
|
| 196 |
-
position: absolute;
|
| 197 |
-
top: 0.5rem;
|
| 198 |
-
right: 0.5rem;
|
| 199 |
-
background-color: #e8eaed;
|
| 200 |
-
border: 1px solid #dadce0;
|
| 201 |
-
color: #5f6368;
|
| 202 |
-
padding: 0.3rem 0.6rem;
|
| 203 |
-
border-radius: 0.25rem;
|
| 204 |
-
cursor: pointer;
|
| 205 |
-
opacity: 0;
|
| 206 |
-
transition: opacity 0.2s;
|
| 207 |
font-size: 0.8em;
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
/*
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
btn.disabled = false;
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>CogniChat - Chat with your Documents</title>
|
| 7 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 8 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
|
| 11 |
+
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
| 12 |
+
<style>
|
| 13 |
+
:root {
|
| 14 |
+
--background: #f0f4f9;
|
| 15 |
+
--foreground: #1f1f1f;
|
| 16 |
+
--primary: #1a73e8;
|
| 17 |
+
--primary-hover: #1867cf;
|
| 18 |
+
--card: #ffffff;
|
| 19 |
+
--card-border: #dadce0;
|
| 20 |
+
--input-bg: #e8f0fe;
|
| 21 |
+
--user-bubble: #d9e7ff;
|
| 22 |
+
--bot-bubble: #f1f3f4;
|
| 23 |
+
--select-bg: #ffffff;
|
| 24 |
+
--select-border: #dadce0;
|
| 25 |
+
--select-text: #1f1f1f;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.dark {
|
| 29 |
+
--background: #111827; /* Darker background */
|
| 30 |
+
--foreground: #e5e7eb;
|
| 31 |
+
--primary: #3b82f6; /* Adjusted primary blue */
|
| 32 |
+
--primary-hover: #60a5fa; /* Lighter hover blue */
|
| 33 |
+
--card: #1f2937; /* Dark card background */
|
| 34 |
+
--card-border: #4b5563; /* Greyer border */
|
| 35 |
+
--input-bg: #374151; /* Darker input background */
|
| 36 |
+
--user-bubble: #374151; /* Darker user bubble */
|
| 37 |
+
--bot-bubble: #374151; /* Darker bot bubble */
|
| 38 |
+
--select-bg: #374151;
|
| 39 |
+
--select-border: #6b7280;
|
| 40 |
+
--select-text: #f3f4f6;
|
| 41 |
+
--code-bg: #2d2d2d; /* Specific background for code blocks */
|
| 42 |
+
--code-text: #d4d4d4; /* Light grey text for code */
|
| 43 |
+
--copy-btn-bg: #4a4a4a;
|
| 44 |
+
--copy-btn-hover-bg: #5a5a5a;
|
| 45 |
+
--copy-btn-text: #e0e0e0;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
body {
|
| 49 |
+
font-family: 'Inter', 'Google Sans', 'Roboto', sans-serif;
|
| 50 |
+
background-color: var(--background);
|
| 51 |
+
color: var(--foreground);
|
| 52 |
+
overflow: hidden; /* Prevent body scroll */
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
#chat-window {
|
| 56 |
+
scroll-behavior: smooth; /* Ensure smooth programmatic scroll */
|
| 57 |
+
}
|
| 58 |
+
#chat-window::-webkit-scrollbar { width: 8px; }
|
| 59 |
+
#chat-window::-webkit-scrollbar-track { background: transparent; }
|
| 60 |
+
#chat-window::-webkit-scrollbar-thumb { background-color: #4b5563; border-radius: 20px; }
|
| 61 |
+
.dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }
|
| 62 |
+
|
| 63 |
+
.drop-zone--over {
|
| 64 |
+
border-color: var(--primary);
|
| 65 |
+
box-shadow: 0 0 20px rgba(59, 130, 246, 0.4);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
.loader {
|
| 69 |
+
width: 48px;
|
| 70 |
+
height: 48px;
|
| 71 |
+
border: 3px solid var(--card-border);
|
| 72 |
+
border-radius: 50%;
|
| 73 |
+
display: inline-block;
|
| 74 |
+
position: relative;
|
| 75 |
+
box-sizing: border-box;
|
| 76 |
+
animation: rotation 1s linear infinite;
|
| 77 |
+
}
|
| 78 |
+
.loader::after {
|
| 79 |
+
content: '';
|
| 80 |
+
box-sizing: border-box;
|
| 81 |
+
position: absolute;
|
| 82 |
+
left: 50%;
|
| 83 |
+
top: 50%;
|
| 84 |
+
transform: translate(-50%, -50%);
|
| 85 |
+
width: 56px;
|
| 86 |
+
height: 56px;
|
| 87 |
+
border-radius: 50%;
|
| 88 |
+
border: 3px solid;
|
| 89 |
+
border-color: var(--primary) transparent;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
@keyframes rotation {
|
| 93 |
+
0% { transform: rotate(0deg); }
|
| 94 |
+
100% { transform: rotate(360deg); }
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/* --- Updated Typing Indicator --- */
|
| 98 |
+
.typing-indicator {
|
| 99 |
+
display: inline-flex; /* Changed to inline-flex */
|
| 100 |
+
align-items: center;
|
| 101 |
+
padding: 8px 0; /* Add some vertical padding */
|
| 102 |
+
}
|
| 103 |
+
.typing-indicator span {
|
| 104 |
+
height: 8px; /* Slightly smaller dots */
|
| 105 |
+
width: 8px;
|
| 106 |
+
margin: 0 2px;
|
| 107 |
+
background-color: #9E9E9E;
|
| 108 |
+
border-radius: 50%;
|
| 109 |
+
opacity: 0; /* Start invisible */
|
| 110 |
+
animation: typing-pulse 1.4s infinite ease-in-out;
|
| 111 |
+
}
|
| 112 |
+
.typing-indicator span:nth-child(1) { animation-delay: 0s; }
|
| 113 |
+
.typing-indicator span:nth-child(2) { animation-delay: 0.2s; }
|
| 114 |
+
.typing-indicator span:nth-child(3) { animation-delay: 0.4s; }
|
| 115 |
+
|
| 116 |
+
@keyframes typing-pulse {
|
| 117 |
+
0%, 100% { opacity: 0; transform: scale(0.7); }
|
| 118 |
+
50% { opacity: 1; transform: scale(1); }
|
| 119 |
+
}
|
| 120 |
+
/* --- End Typing Indicator --- */
|
| 121 |
+
|
| 122 |
+
/* --- Updated Markdown Styling --- */
|
| 123 |
+
.markdown-content { /* Base styles for the content area */
|
| 124 |
+
line-height: 1.75;
|
| 125 |
+
}
|
| 126 |
+
.markdown-content p { margin-bottom: 1rem; }
|
| 127 |
+
.markdown-content h1, .markdown-content h2, .markdown-content h3,
|
| 128 |
+
.markdown-content h4, .markdown-content h5, .markdown-content h6 {
|
| 129 |
+
font-weight: 600;
|
| 130 |
+
margin-top: 1.5rem;
|
| 131 |
+
margin-bottom: 0.75rem;
|
| 132 |
+
line-height: 1.3;
|
| 133 |
+
}
|
| 134 |
+
.markdown-content h1 { font-size: 1.5em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3rem;}
|
| 135 |
+
.markdown-content h2 { font-size: 1.25em; }
|
| 136 |
+
.markdown-content h3 { font-size: 1.1em; }
|
| 137 |
+
.markdown-content ul, .markdown-content ol { padding-left: 1.75rem; margin-bottom: 1rem; }
|
| 138 |
+
.markdown-content li { margin-bottom: 0.5rem; }
|
| 139 |
+
.markdown-content a { color: var(--primary); text-decoration: none; font-weight: 500; }
|
| 140 |
+
.markdown-content a:hover { text-decoration: underline; }
|
| 141 |
+
.markdown-content strong, .markdown-content b { font-weight: 600; } /* Ensure bold works */
|
| 142 |
+
.markdown-content blockquote {
|
| 143 |
+
border-left: 4px solid var(--card-border);
|
| 144 |
+
padding-left: 1rem;
|
| 145 |
+
margin-left: 0;
|
| 146 |
+
margin-bottom: 1rem;
|
| 147 |
+
color: #a0aec0; /* Lighter text for quotes */
|
| 148 |
+
}
|
| 149 |
+
/* --- Code Block Styling --- */
|
| 150 |
+
.markdown-content pre {
|
| 151 |
+
position: relative;
|
| 152 |
+
background-color: var(--code-bg);
|
| 153 |
+
border: 1px solid var(--card-border);
|
| 154 |
+
border-radius: 0.5rem;
|
| 155 |
+
margin-bottom: 1rem;
|
| 156 |
+
font-size: 0.9em;
|
| 157 |
+
color: var(--code-text);
|
| 158 |
+
overflow: hidden; /* Clip content to the rounded border; the inner <code> element scrolls horizontally */
|
| 159 |
+
}
|
| 160 |
+
.markdown-content pre code {
|
| 161 |
+
display: block;
|
| 162 |
+
padding: 1rem;
|
| 163 |
+
overflow-x: auto; /* Enable horizontal scroll on the code itself */
|
| 164 |
+
background: none !important; /* Override potential highlight.js background */
|
| 165 |
+
font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
|
| 166 |
+
white-space: pre; /* Ensure whitespace is preserved */
|
| 167 |
+
}
|
| 168 |
+
/* --- Copy Button Styling --- */
|
| 169 |
+
.markdown-content pre .copy-code-btn {
|
| 170 |
position: absolute;
|
| 171 |
+
top: 0.5rem;
|
| 172 |
+
right: 0.5rem;
|
| 173 |
+
background-color: var(--copy-btn-bg);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
border: 1px solid var(--card-border);
|
| 175 |
+
color: var(--copy-btn-text);
|
| 176 |
+
padding: 0.3rem 0.6rem;
|
| 177 |
+
border-radius: 0.25rem;
|
| 178 |
+
cursor: pointer;
|
| 179 |
+
opacity: 0; /* Initially hidden */
|
| 180 |
+
transition: opacity 0.2s, background-color 0.2s;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
font-size: 0.8em;
|
| 182 |
+
display: flex; /* For icon alignment */
|
| 183 |
+
align-items: center;
|
| 184 |
+
gap: 0.25rem;
|
| 185 |
+
}
|
| 186 |
+
.markdown-content pre .copy-code-btn:hover {
|
| 187 |
+
background-color: var(--copy-btn-hover-bg);
|
| 188 |
+
}
|
| 189 |
+
.markdown-content pre:hover .copy-code-btn {
|
| 190 |
+
opacity: 1; /* Show on hover */
|
| 191 |
+
}
|
| 192 |
+
/* --- Inline Code Styling --- */
|
| 193 |
+
.markdown-content code:not(pre code) {
|
| 194 |
+
background-color: rgba(110, 118, 129, 0.4);
|
| 195 |
+
padding: 0.2em 0.4em;
|
| 196 |
+
margin: 0 0.1em; /* Add slight horizontal margin */
|
| 197 |
+
font-size: 85%;
|
| 198 |
+
border-radius: 6px;
|
| 199 |
+
font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
|
| 200 |
+
}
|
| 201 |
+
/* --- End Markdown Styling --- */
|
| 202 |
+
|
| 203 |
+
.tts-button-loader {
|
| 204 |
+
width: 16px;
|
| 205 |
+
height: 16px;
|
| 206 |
+
border: 2px solid currentColor;
|
| 207 |
+
border-radius: 50%;
|
| 208 |
+
display: inline-block;
|
| 209 |
+
box-sizing: border-box;
|
| 210 |
+
animation: rotation 0.8s linear infinite;
|
| 211 |
+
border-bottom-color: transparent;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
/* --- Style for TTS controls container --- */
|
| 215 |
+
.tts-controls {
|
| 216 |
+
display: flex;
|
| 217 |
+
align-items: center;
|
| 218 |
+
gap: 0.5rem; /* Space between play and speed buttons */
|
| 219 |
+
margin-top: 0.5rem;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
/* --- Style for Speed Cycle Button --- */
|
| 223 |
+
.speed-cycle-btn {
|
| 224 |
+
padding: 0.25rem 0.6rem; /* Smaller padding */
|
| 225 |
+
font-size: 0.75rem; /* Smaller text */
|
| 226 |
+
background-color: #4b5563; /* Grey background */
|
| 227 |
+
color: #e5e7eb; /* Light text */
|
| 228 |
+
border-radius: 9999px; /* Pill shape */
|
| 229 |
+
border: none;
|
| 230 |
+
cursor: pointer;
|
| 231 |
+
transition: background-color 0.2s;
|
| 232 |
+
white-space: nowrap; /* Prevent text wrapping */
|
| 233 |
+
margin-top: 0.5rem;
|
| 234 |
+
}
|
| 235 |
+
.speed-cycle-btn:hover {
|
| 236 |
+
background-color: #1f0bb8e6; /* Translucent indigo on hover (not grey like the resting state) */
|
| 237 |
+
}
|
| 238 |
+
.speed-cycle-btn:disabled {
|
| 239 |
+
opacity: 0.5;
|
| 240 |
+
cursor: not-allowed;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
/* --- Select dropdown styles (kept for consistency if needed elsewhere) --- */
|
| 244 |
+
.select-wrapper {
|
| 245 |
+
position: relative;
|
| 246 |
+
}
|
| 247 |
+
.select-wrapper select {
|
| 248 |
+
background-color: var(--select-bg);
|
| 249 |
+
border: 1px solid var(--select-border);
|
| 250 |
+
color: var(--select-text);
|
| 251 |
+
padding: 0.75rem 2.5rem 0.75rem 1rem;
|
| 252 |
+
border-radius: 0.75rem;
|
| 253 |
+
font-size: 0.875rem;
|
| 254 |
+
width: 100%;
|
| 255 |
+
appearance: none;
|
| 256 |
+
-webkit-appearance: none;
|
| 257 |
+
transition: all 0.2s ease-in-out;
|
| 258 |
+
cursor: pointer;
|
| 259 |
+
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
|
| 260 |
+
background-position: right 0.75rem center;
|
| 261 |
+
background-repeat: no-repeat;
|
| 262 |
+
background-size: 1.25em 1.25em;
|
| 263 |
+
}
|
| 264 |
+
</style>
|
| 265 |
+
</head>
|
| 266 |
+
<body class="w-screen h-screen dark">
|
| 267 |
+
<main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
|
| 268 |
+
<div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
|
| 269 |
+
|
| 270 |
+
<header class="p-4 border-b border-[var(--card-border)] flex-shrink-0 flex justify-between items-center w-full">
|
| 271 |
+
<div class="w-1/4"></div> <div class="w-1/2 text-center">
|
| 272 |
+
<h1 class="text-xl font-medium tracking-wide">CogniChat ✨</h1>
|
| 273 |
+
<p id="chat-filename" class="text-xs text-gray-400 mt-1 truncate"></p>
|
| 274 |
+
</div>
|
| 275 |
+
<div id="chat-session-info" class="w-1/4 text-right text-xs space-y-1 pr-4">
|
| 276 |
+
</div>
|
| 277 |
+
</header>
|
| 278 |
+
|
| 279 |
+
<div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
|
| 280 |
+
<div id="chat-content" class="max-w-4xl mx-auto space-y-8"></div>
|
| 281 |
+
</div>
|
| 282 |
+
<div class="p-4 flex-shrink-0 bg-opacity-50 backdrop-blur-md border-t border-[var(--card-border)]">
|
| 283 |
+
<form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-lg border border-[var(--card-border)] focus-within:ring-2 focus-within:ring-[var(--primary)] transition-all">
|
| 284 |
+
<input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
|
| 285 |
+
<button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2.5 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed" title="Send">
|
| 286 |
+
<svg class="w-5 h-5" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-8.707l-3-3a1 1 0 00-1.414 1.414L10.586 9H7a1 1 0 100 2h3.586l-1.293 1.293a1 1 0 101.414 1.414l3-3a1 1 0 000-1.414z" clip-rule="evenodd"></path></svg>
|
| 287 |
+
</button>
|
| 288 |
+
</form>
|
| 289 |
+
</div>
|
| 290 |
+
</div>
|
| 291 |
+
|
| 292 |
+
<div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
|
| 293 |
+
<div class="text-center max-w-xl w-full">
|
| 294 |
+
<h1 class="text-5xl font-bold mb-3 tracking-tight">CogniChat ✨</h1>
|
| 295 |
+
<p class="text-lg text-gray-400 mb-8">Upload your documents to start a conversation.</p>
|
| 296 |
+
<div class="mb-8 p-5 bg-[var(--card)] rounded-2xl border border-[var(--card-border)] shadow-lg">
|
| 297 |
+
<div class="flex flex-col sm:flex-row items-center gap-6">
|
| 298 |
+
<div class="w-full sm:w-1/2">
|
| 299 |
+
<div class="flex items-center gap-2 mb-2">
|
| 300 |
+
<svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path d="M7 3a1 1 0 000 2h6a1 1 0 100-2H7zM4 7a1 1 0 011-1h10a1 1 0 110 2H5a1 1 0 01-1-1zM2 11a2 2 0 012-2h12a2 2 0 012 2v4a2 2 0 01-2 2H4a2 2 0 01-2-2v-4z" /></svg>
|
| 301 |
+
<label for="model-select" class="block text-sm font-medium text-gray-300">Model</label>
|
| 302 |
+
</div>
|
| 303 |
+
<div class="select-wrapper">
|
| 304 |
+
<select id="model-select" name="model_name">
|
| 305 |
+
<option value="moonshotai/kimi-k2-instruct" selected>Kimi Instruct</option>
|
| 306 |
+
<option value="openai/gpt-oss-20b">GPT OSS 20b</option>
|
| 307 |
+
<option value="llama-3.3-70b-versatile">Llama 3.3 70b</option>
|
| 308 |
+
<option value="llama-3.1-8b-instant">Llama 3.1 8b Instant</option>
|
| 309 |
+
</select>
|
| 310 |
+
</div>
|
| 311 |
+
</div>
|
| 312 |
+
<div class="w-full sm:w-1/2">
|
| 313 |
+
<div class="flex items-center gap-2 mb-2">
|
| 314 |
+
<svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M5.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7zM12 5.5a3.5 3.5 0 11-7 0 3.5 3.5 0 017 0zM14.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7z" clip-rule="evenodd" /></svg>
|
| 315 |
+
<label for="temperature-select" class="block text-sm font-medium text-gray-300">Mode</label>
|
| 316 |
+
</div>
|
| 317 |
+
<div class="select-wrapper">
|
| 318 |
+
<select id="temperature-select" name="temperature">
|
| 319 |
+
<option value="0.2" selected>0.2 - Precise</option>
|
| 320 |
+
<option value="0.4">0.4 - Confident</option>
|
| 321 |
+
<option value="0.6">0.6 - Balanced</option>
|
| 322 |
+
<option value="0.8">0.8 - Flexible</option>
|
| 323 |
+
<option value="1.0">1.0 - Creative</option>
|
| 324 |
+
</select>
|
| 325 |
+
</div>
|
| 326 |
+
</div>
|
| 327 |
+
</div>
|
| 328 |
+
<p class="text-xs text-gray-500 mt-4 text-center">Higher creativity modes may reduce factual accuracy.</p>
|
| 329 |
+
</div>
|
| 330 |
+
<div id="drop-zone" class="w-full text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer hover:bg-[var(--card)] hover:border-[var(--primary)]">
|
| 331 |
+
<div class="flex flex-col items-center justify-center pointer-events-none">
|
| 332 |
+
<svg class="mx-auto h-12 w-12 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
|
| 333 |
+
<p class="mt-4 text-sm font-medium text-gray-400">Drag & drop files or <span class="text-[var(--primary)] font-semibold">click to upload</span></p>
|
| 334 |
+
<p class="text-xs text-gray-400 mt-1">Supports PDF, DOCX, TXT</p>
|
| 335 |
+
<p id="file-name" class="mt-2 text-xs text-gray-500"></p>
|
| 336 |
+
</div>
|
| 337 |
+
<input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx" multiple>
|
| 338 |
+
</div>
|
| 339 |
+
</div>
|
| 340 |
+
</div>
|
| 341 |
+
|
| 342 |
+
<div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50">
|
| 343 |
+
<div class="loader"></div>
|
| 344 |
+
<p id="loading-text" class="mt-6 text-sm font-medium"></p>
|
| 345 |
+
<p id="loading-subtext" class="mt-2 text-xs text-gray-400"></p>
|
| 346 |
+
</div>
|
| 347 |
+
</main>
|
| 348 |
+
|
| 349 |
+
<script>
|
| 350 |
+
document.addEventListener('DOMContentLoaded', () => {
|
| 351 |
+
const uploadContainer = document.getElementById('upload-container');
|
| 352 |
+
const chatContainer = document.getElementById('chat-container');
|
| 353 |
+
const dropZone = document.getElementById('drop-zone');
|
| 354 |
+
const fileUploadInput = document.getElementById('file-upload');
|
| 355 |
+
const fileNameSpan = document.getElementById('file-name');
|
| 356 |
+
const loadingOverlay = document.getElementById('loading-overlay');
|
| 357 |
+
const loadingText = document.getElementById('loading-text');
|
| 358 |
+
const loadingSubtext = document.getElementById('loading-subtext');
|
| 359 |
+
const chatForm = document.getElementById('chat-form');
|
| 360 |
+
const chatInput = document.getElementById('chat-input');
|
| 361 |
+
const chatSubmitBtn = document.getElementById('chat-submit-btn');
|
| 362 |
+
const chatWindow = document.getElementById('chat-window');
|
| 363 |
+
const chatContent = document.getElementById('chat-content');
|
| 364 |
+
const modelSelect = document.getElementById('model-select');
|
| 365 |
+
const temperatureSelect = document.getElementById('temperature-select');
|
| 366 |
+
const chatFilename = document.getElementById('chat-filename');
|
| 367 |
+
const chatSessionInfo = document.getElementById('chat-session-info');
|
| 368 |
+
|
| 369 |
+
let sessionId = sessionStorage.getItem('cognichat_session_id');
|
| 370 |
+
let currentModelInfo = JSON.parse(sessionStorage.getItem('cognichat_model_info'));
|
| 371 |
+
|
| 372 |
+
marked.setOptions({
|
| 373 |
+
breaks: true,
|
| 374 |
+
gfm: true,
|
| 375 |
+
});
|
| 376 |
+
|
| 377 |
+
if (sessionId && currentModelInfo) {
|
| 378 |
+
console.log("Restoring session:", sessionId);
|
| 379 |
+
uploadContainer.classList.add('hidden');
|
| 380 |
+
chatContainer.classList.remove('hidden');
|
| 381 |
+
chatFilename.innerHTML = `Chatting with: <strong class="font-semibold">${sessionStorage.getItem('cognichat_filename') || 'documents'}</strong>`;
|
| 382 |
+
chatFilename.title = sessionStorage.getItem('cognichat_filename') || 'documents';
|
| 383 |
+
chatSessionInfo.innerHTML = `
|
| 384 |
+
<p>Model: ${currentModelInfo.simpleModelName}</p>
|
| 385 |
+
<p>Mode: ${currentModelInfo.mode}</p>
|
| 386 |
+
<button class="mt-1 text-xs text-blue-400 hover:text-blue-300 focus:outline-none" onclick="sessionStorage.clear(); location.reload();">New Chat</button>`;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
// --- File Upload Logic ---
|
| 391 |
+
dropZone.addEventListener('click', () => fileUploadInput.click());
|
| 392 |
+
|
| 393 |
+
['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
|
| 394 |
+
dropZone.addEventListener(eventName, e => {e.preventDefault(); e.stopPropagation();}, false);
|
| 395 |
+
document.body.addEventListener(eventName, e => {e.preventDefault(); e.stopPropagation();}, false);
|
| 396 |
+
});
|
| 397 |
+
['dragenter', 'dragover'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over')));
|
| 398 |
+
['dragleave', 'drop'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over')));
|
| 399 |
+
|
| 400 |
+
dropZone.addEventListener('drop', (e) => {
|
| 401 |
+
if (e.dataTransfer.files.length > 0) handleFiles(e.dataTransfer.files);
|
| 402 |
+
});
|
| 403 |
+
fileUploadInput.addEventListener('change', (e) => {
|
| 404 |
+
if (e.target.files.length > 0) handleFiles(e.target.files);
|
| 405 |
+
});
|
| 406 |
+
|
| 407 |
+
/**
 * Package the chosen files together with the current model/temperature
 * selections into a multipart payload, show the selected file names in the
 * drop zone, and hand the payload to uploadAndProcessFiles.
 *
 * @param {FileList|File[]} files - Files picked via the input or dropped on the zone.
 * @returns {Promise<void>} Resolves once the upload pipeline has finished.
 */
async function handleFiles(files) {
    const payload = new FormData();
    const selectedNames = [];

    // Every file goes under the same 'file' field name, in selection order.
    Array.from(files).forEach((entry) => {
        selectedNames.push(entry.name);
        payload.append('file', entry);
    });

    // Generation settings travel alongside the documents.
    payload.append('model_name', modelSelect.value);
    payload.append('temperature', temperatureSelect.value);

    fileNameSpan.textContent = `Selected: ${selectedNames.join(', ')}`;
    await uploadAndProcessFiles(payload);
}
|
| 418 |
+
|
| 419 |
+
/**
 * POST the prepared form data to /upload, persist the returned session
 * details in sessionStorage, and switch the UI from the upload view to the
 * chat view.
 *
 * The loading overlay is shown for the duration of the request and is always
 * hidden again (and the file input reset) in the finally block. On any
 * failure the error is logged, reported via alert(), and sessionStorage is
 * cleared.
 *
 * @param {FormData} formData - Multipart payload with 'file', 'model_name' and 'temperature' entries.
 * @returns {Promise<void>}
 */
async function uploadAndProcessFiles(formData) {
    // Server-returned strings (file names in particular originate from the
    // user's upload) must not be interpreted as markup when placed in innerHTML.
    const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
    const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);

    loadingOverlay.classList.remove('hidden');
    loadingText.textContent = `Processing document(s)...`;
    loadingSubtext.textContent = "Creating a knowledge base... this might take a minute 🧠";
    chatContent.innerHTML = ''; // Clear previous chat content on new upload

    try {
        const response = await fetch('/upload', { method: 'POST', body: formData });

        // An error response may carry a non-JSON body (e.g. an HTML error page);
        // report the HTTP status instead of an opaque JSON SyntaxError.
        let result;
        try {
            result = await response.json();
        } catch (parseError) {
            throw new Error(
                response.ok
                    ? 'Server returned an unreadable response.'
                    : `Upload failed with status ${response.status}.`
            );
        }
        if (!response.ok) throw new Error((result && result.message) || 'Unknown error occurred during upload.');

        sessionId = result.session_id;
        sessionStorage.setItem('cognichat_session_id', sessionId);

        // Compare option values directly rather than building a CSS selector
        // from a server-supplied string, which would throw on quote characters.
        const modelOption = Array.from(modelSelect.options).find((option) => option.value === result.model_name);
        const simpleModelName = modelOption ? modelOption.textContent : result.model_name;

        currentModelInfo = {
            model: result.model_name,
            mode: result.mode,
            simpleModelName: simpleModelName // Human-readable label from the dropdown when available
        };
        sessionStorage.setItem('cognichat_model_info', JSON.stringify(currentModelInfo)); // Store model info
        sessionStorage.setItem('cognichat_filename', result.filename); // Store filename

        chatFilename.innerHTML = `Chatting with: <strong class="font-semibold">${escapeHtml(result.filename)}</strong>`;
        chatFilename.title = result.filename;

        chatSessionInfo.innerHTML = `
<p>Model: ${escapeHtml(currentModelInfo.simpleModelName)}</p>
<p>Mode: ${escapeHtml(currentModelInfo.mode)}</p>
<button class="mt-1 text-xs text-blue-400 hover:text-blue-300 focus:outline-none" onclick="sessionStorage.clear(); location.reload();">New Chat</button>`;

        uploadContainer.classList.add('hidden');
        chatContainer.classList.remove('hidden');
        appendMessage("Hello! 👋 I've analyzed your documents. What would you like to know?", "bot", currentModelInfo);

    } catch (error) {
        console.error('Upload error:', error);
        alert(`Error processing files: ${error.message}`);
        sessionStorage.clear(); // Clear session if upload fails
    } finally {
        // Always restore the upload UI to a clean state, success or failure.
        loadingOverlay.classList.add('hidden');
        fileNameSpan.textContent = '';
        fileUploadInput.value = '';
    }
}
|
| 466 |
+
|
| 467 |
+
// --- Chat Logic (Using Server-Sent Events - UPDATED FOR STREAMING & INDICATOR) ---
|
| 468 |
+
chatForm.addEventListener('submit', async (e) => {
|
| 469 |
+
e.preventDefault();
|
| 470 |
+
const question = chatInput.value.trim();
|
| 471 |
+
if (!question || !sessionId) {
|
| 472 |
+
console.warn("Submit ignored: No question or session ID.");
|
| 473 |
+
return;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
appendMessage(question, 'user');
|
| 477 |
+
chatInput.value = '';
|
| 478 |
+
chatInput.disabled = true;
|
| 479 |
+
chatSubmitBtn.disabled = true;
|
| 480 |
+
|
| 481 |
+
let botMessageContainer;
|
| 482 |
+
let contentDiv;
|
| 483 |
+
let fullResponse = '';
|
| 484 |
+
let eventSource = null;
|
| 485 |
+
let inactivityTimeout = null;
|
| 486 |
+
let streamClosedCleanly = false; // Flag to check if stream ended normally vs error
|
| 487 |
+
let typingIndicatorElement = null; // Store indicator element
|
| 488 |
+
|
| 489 |
+
// Function to finalize chat (called on error, timeout, or successful completion)
|
| 490 |
+
/**
 * Tear down the current chat turn: close the SSE stream, cancel the
 * inactivity timer, remove any leftover typing indicator, render the final
 * Markdown (with copy buttons and TTS controls) when a non-error response
 * was received, and re-enable the input controls.
 *
 * Operates on the enclosing submit handler's state (eventSource,
 * inactivityTimeout, typingIndicatorElement, botMessageContainer,
 * contentDiv, fullResponse).
 *
 * @param {boolean} [isError=false] - Only used for the diagnostic log line.
 */
function finalizeChat(isError = false) {
    console.log(`Finalizing chat. Was error: ${isError}, Stream ended cleanly: ${streamClosedCleanly}`);

    // Release the streaming connection, if one is still open.
    if (eventSource) {
        eventSource.close();
        eventSource = null;
        console.log("SSE connection explicitly closed in finalizeChat.");
    }

    // Stop the pending inactivity watchdog.
    if (inactivityTimeout) {
        clearTimeout(inactivityTimeout);
        inactivityTimeout = null;
    }

    // Detach the typing indicator only if it is still attached somewhere.
    const indicatorHost = typingIndicatorElement ? typingIndicatorElement.parentNode : null;
    if (indicatorHost) {
        indicatorHost.removeChild(typingIndicatorElement);
        typingIndicatorElement = null;
    }

    if (botMessageContainer && contentDiv) {
        // A warning emoji in the bubble marks an error message; those are left as-is.
        const showsError = contentDiv.innerHTML.includes('⚠️');
        const shouldRenderFinal = Boolean(fullResponse) && !showsError;

        if (shouldRenderFinal) {
            // Re-parse the whole response so the final Markdown is consistent,
            // then decorate code blocks and attach the text-to-speech controls.
            contentDiv.innerHTML = marked.parse(fullResponse);
            contentDiv.querySelectorAll('pre').forEach(addCopyButton);
            addTextToSpeechControls(botMessageContainer, fullResponse);
        }

        scrollToBottom(true); // Keep the end of the conversation in view
    }

    // The user can always type again, whatever the outcome was.
    chatInput.disabled = false;
    chatSubmitBtn.disabled = false;
    chatInput.focus();
}
|
| 528 |
+
|
| 529 |
+
try {
|
| 530 |
+
// Create the bot message container *before* starting the stream
|
| 531 |
+
botMessageContainer = appendMessage('', 'bot', currentModelInfo); // Append empty bot message
|
| 532 |
+
contentDiv = botMessageContainer.querySelector('.markdown-content');
|
| 533 |
+
|
| 534 |
+
// Show typing indicator *inside* the contentDiv
|
| 535 |
+
typingIndicatorElement = showTypingIndicator();
|
| 536 |
+
if (contentDiv) {
|
| 537 |
+
contentDiv.appendChild(typingIndicatorElement);
|
| 538 |
+
scrollToBottom(true); // Scroll to show indicator
|
| 539 |
+
} else {
|
| 540 |
+
console.error("Could not find contentDiv to append typing indicator.");
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
// Establish SSE connection via GET request
|
| 545 |
+
const chatUrl = `/chat?question=${encodeURIComponent(question)}&session_id=${encodeURIComponent(sessionId)}`;
|
| 546 |
+
console.log("Connecting to SSE:", chatUrl);
|
| 547 |
+
eventSource = new EventSource(chatUrl);
|
| 548 |
+
|
| 549 |
+
eventSource.onopen = () => {
|
| 550 |
+
console.log("SSE Connection opened.");
|
| 551 |
+
// Remove indicator when connection opens and stream is about to start
|
| 552 |
+
if (typingIndicatorElement && typingIndicatorElement.parentNode) {
|
| 553 |
+
typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
|
| 554 |
+
typingIndicatorElement = null;
|
| 555 |
+
}
|
| 556 |
+
streamClosedCleanly = false; // Reset flag on new connection
|
| 557 |
+
};
|
| 558 |
+
|
| 559 |
+
eventSource.onmessage = (event) => {
|
| 560 |
+
// Remove indicator on first message just in case onopen didn't fire reliably
|
| 561 |
+
if (typingIndicatorElement && typingIndicatorElement.parentNode) {
|
| 562 |
+
typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
|
| 563 |
+
typingIndicatorElement = null;
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
// Reset inactivity timeout on each message
|
| 567 |
+
if (inactivityTimeout) clearTimeout(inactivityTimeout);
|
| 568 |
+
inactivityTimeout = setTimeout(() => {
|
| 569 |
+
console.log("Inactivity timeout triggered after message.");
|
| 570 |
+
streamClosedCleanly = true; // Assume normal end
|
| 571 |
+
finalizeChat(false);
|
| 572 |
+
}, 5000); // 5 seconds of inactivity
|
| 573 |
+
|
| 574 |
+
let data;
|
| 575 |
+
try {
|
| 576 |
+
data = JSON.parse(event.data);
|
| 577 |
+
} catch (parseError){
|
| 578 |
+
console.error("Failed to parse SSE data:", event.data, parseError);
|
| 579 |
+
contentDiv.innerHTML += `<p class="text-red-400 text-sm">Error receiving data chunk.</p>`;
|
| 580 |
+
return;
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
if (data.error) {
|
| 584 |
+
console.error('SSE Error from server:', data.error);
|
| 585 |
+
contentDiv.innerHTML = `<p class="text-red-500 font-semibold">⚠️ Server Error: ${data.error}</p>`;
|
| 586 |
+
streamClosedCleanly = false;
|
| 587 |
+
finalizeChat(true); // Pass true for error
|
| 588 |
+
return;
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
if (data.token !== undefined && data.token !== null) {
|
| 592 |
+
fullResponse += data.token;
|
| 593 |
+
// Update content by parsing the accumulated response
|
| 594 |
+
contentDiv.innerHTML = marked.parse(fullResponse);
|
| 595 |
+
scrollToBottom(); // Scroll smoothly as content arrives
|
| 596 |
+
}
|
| 597 |
+
};
|
| 598 |
+
|
| 599 |
+
eventSource.onerror = (error) => {
|
| 600 |
+
console.error('SSE connection error event:', error);
|
| 601 |
+
// Remove indicator on error
|
| 602 |
+
if (typingIndicatorElement && typingIndicatorElement.parentNode) {
|
| 603 |
+
typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
|
| 604 |
+
typingIndicatorElement = null;
|
| 605 |
+
}
|
| 606 |
+
// Don't show generic error if we received data and the stream likely just closed normally
|
| 607 |
+
if (!fullResponse && !streamClosedCleanly) { // Only show error if nothing received AND not already cleanly closed
|
| 608 |
+
const errorMsg = "⚠️ Connection error. Please try again.";
|
| 609 |
+
if (contentDiv) {
|
| 610 |
+
contentDiv.innerHTML = `<p class="text-red-500 font-semibold">${errorMsg}</p>`;
|
| 611 |
+
} else {
|
| 612 |
+
// Fallback if container wasn't created somehow
|
| 613 |
+
appendMessage(errorMsg, 'bot', currentModelInfo); // Pass model info here too
|
| 614 |
+
}
|
| 615 |
+
streamClosedCleanly = false;
|
| 616 |
+
} else if (!streamClosedCleanly) {
|
| 617 |
+
// If we received data, assume it's a normal closure misinterpreted as error
|
| 618 |
+
console.log("SSE connection closed (likely normal end detected by onerror).");
|
| 619 |
+
streamClosedCleanly = true; // Mark as clean closure NOW
|
| 620 |
+
} else {
|
| 621 |
+
console.log("SSE onerror event after stream already marked cleanly closed.")
|
| 622 |
+
}
|
| 623 |
+
finalizeChat(!streamClosedCleanly); // Finalize, indicate error if not clean
|
| 624 |
+
};
|
| 625 |
+
|
| 626 |
+
} catch (error) {
|
| 627 |
+
// For setup errors before SSE starts
|
| 628 |
+
console.error('Chat setup error:', error);
|
| 629 |
+
// Remove indicator on setup error
|
| 630 |
+
if (typingIndicatorElement && typingIndicatorElement.parentNode) {
|
| 631 |
+
typingIndicatorElement.parentNode.removeChild(typingIndicatorElement);
|
| 632 |
+
typingIndicatorElement = null;
|
| 633 |
+
}
|
| 634 |
+
if (botMessageContainer && contentDiv) {
|
| 635 |
+
contentDiv.innerHTML = `<p class="text-red-500 font-semibold">⚠️ Error starting chat: ${error.message}</p>`;
|
| 636 |
+
} else {
|
| 637 |
+
appendMessage(`Error starting chat: ${error.message}`, 'bot', currentModelInfo); // Pass model info
|
| 638 |
+
}
|
| 639 |
+
finalizeChat(true);
|
| 640 |
+
}
|
| 641 |
+
});
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
// --- UI Helper Functions ---
|
| 645 |
+
|
| 646 |
+
/**
 * Append a chat bubble to the conversation and return its body container.
 *
 * @param {string} text Markdown source of the message; may be empty so that a
 *     placeholder bubble can be created before a stream starts.
 * @param {string} sender 'user' for the user's bubble; anything else renders
 *     the bot bubble.
 * @param {?Object} [modelInfo=null] Optional object with `simpleModelName` and
 *     `mode`; falls back to `currentModelInfo` when omitted.
 * @returns {Element} The `.flex-1` element holding the sender label, the
 *     `.markdown-content` div and the `.tts-controls` div.
 */
function appendMessage(text, sender, modelInfo = null) {
    // Escape values that are interpolated into innerHTML as plain text.
    const htmlEntities = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
    const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => htmlEntities[ch]);

    const messageWrapper = document.createElement('div');
    const iconSVG = sender === 'user'
        ? `<div class="bg-blue-200 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1 self-start"><svg class="w-5 h-5 text-blue-700 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
        : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 self-start text-xl flex items-center justify-center w-10 h-10">✨</div>`;

    let senderHTML;
    if (sender === 'user') {
        senderHTML = '<p class="font-medium text-sm mb-1">You</p>';
    } else {
        let modelInfoHTML = '';
        const displayInfo = modelInfo || currentModelInfo;
        if (displayInfo && displayInfo.simpleModelName) {
            // Model name and mode are data, not markup: escape them so they
            // cannot inject HTML into the page.
            modelInfoHTML = `
                <span class="ml-2 text-xs font-normal text-gray-400">
                (Model: ${escapeHtml(displayInfo.simpleModelName)} | Mode: ${escapeHtml(displayInfo.mode)})
                </span>
            `;
        }
        senderHTML = `<div class="font-medium text-sm mb-1 flex items-center">CogniChat ${modelInfoHTML}</div>`;
    }

    messageWrapper.className = `flex items-start gap-3`;
    // SECURITY NOTE(review): marked.parse() does not sanitize its output, so raw
    // HTML in `text` (user input or model output) is inserted as-is. Pipe the
    // result through an HTML sanitizer before assigning it to innerHTML; none
    // is added here to avoid introducing a new dependency.
    // The markdown-content div is always created, even for empty text, so the
    // typing indicator has somewhere to live.
    messageWrapper.innerHTML = `
        ${iconSVG}
        <div class="flex-1 pt-1 min-w-0"> ${senderHTML}
        <div class="text-base markdown-content prose dark:prose-invert max-w-none">${text ? marked.parse(text) : ''}</div>
        <div class="tts-controls mt-2"></div>
        </div>
    `;
    chatContent.appendChild(messageWrapper);

    // Force scroll only for user messages or bot messages that already have content.
    if (sender === 'user' || text) {
        scrollToBottom(true);
    }

    // Return the container that holds the sender name and content div.
    return messageWrapper.querySelector('.flex-1');
}
|
| 685 |
+
|
| 686 |
+
// --- UPDATED showTypingIndicator ---
|
| 687 |
+
/**
 * Build the three-dot typing indicator element.
 *
 * The element is only created, never attached; the caller decides where to
 * insert it.
 *
 * @returns {HTMLDivElement} A detached `.typing-indicator` element.
 */
function showTypingIndicator() {
    return Object.assign(document.createElement('div'), {
        className: 'typing-indicator',
        innerHTML: '<span></span><span></span><span></span>',
    });
}
|
| 694 |
+
// --- End UPDATED showTypingIndicator ---
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
/**
 * Smooth-scroll the chat window to its end.
 *
 * Scrolling happens when `force` is true or when the view is already within
 * 150px of the bottom, so a user who scrolled up to read is not yanked back.
 *
 * @param {boolean} [force=false] Scroll regardless of the current position.
 */
function scrollToBottom(force = false) {
    const maxScrollTop = chatWindow.scrollHeight - chatWindow.clientHeight;
    const nearBottom = maxScrollTop <= chatWindow.scrollTop + 150; // Threshold

    if (!force && !nearBottom) {
        return;
    }

    // Defer to the next animation frame so layout is up to date when we scroll.
    requestAnimationFrame(() => {
        chatWindow.scrollTo({ top: chatWindow.scrollHeight, behavior: 'smooth' });
    });
}
|
| 709 |
+
|
| 710 |
+
/**
 * Attach a "Copy" button to a rendered code block.
 *
 * Does nothing if the block already has a button. Clicking copies the text of
 * the nested `<code>` element and briefly shows "Copied!" or "Error" before the
 * label returns to its idle state.
 *
 * @param {HTMLElement} pre The `<pre>` element to decorate.
 */
function addCopyButton(pre) {
    if (pre.querySelector('.copy-code-btn')) return;

    // Idle label markup, kept in one place so every reset restores the same content.
    const idleLabelHTML = `<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path></svg> Copy`;

    const button = document.createElement('button');
    button.className = 'copy-code-btn absolute top-2 right-2 p-1 rounded bg-[var(--copy-btn-bg)] text-[var(--copy-btn-text)] hover:bg-[var(--copy-btn-hover-bg)] transition-opacity duration-200 flex items-center gap-1 text-xs';
    button.innerHTML = idleLabelHTML;
    pre.style.position = 'relative'; // Anchor for the absolutely positioned button
    pre.appendChild(button);

    // Show a temporary status, then restore the idle label after 1.5s.
    const flashStatus = (label) => {
        button.textContent = label;
        setTimeout(() => { button.innerHTML = idleLabelHTML; }, 1500);
    };

    button.addEventListener('click', () => {
        const code = pre.querySelector('code')?.innerText || '';

        // navigator.clipboard is undefined outside secure contexts (plain HTTP);
        // turn that into a rejected promise so the user still sees "Error"
        // instead of the handler throwing silently.
        const copyAttempt = navigator.clipboard?.writeText
            ? navigator.clipboard.writeText(code)
            : Promise.reject(new Error('Clipboard API is not available in this context'));

        copyAttempt
            .then(() => flashStatus('Copied!'))
            .catch(err => {
                console.error('Failed to copy code: ', err);
                flashStatus('Error');
            });
    });
}
|
| 734 |
+
|
| 735 |
+
// --- TTS Functions (UPDATED FOR SPEED CYCLE) ---
|
| 736 |
+
// Playback rates offered by the speed button, in cycle order: normal, fast, slow.
const availableSpeeds = [1.0, 1.5, 0.75];

// Icon markup shown on the play/pause button.
const playIconSVG = `<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20"><path d="M6.3 2.841A1.5 1.5 0 004 4.11V15.89a1.5 1.5 0 002.3 1.269l9.344-5.89a1.5 1.5 0 000-2.538L6.3 2.84z"></path></svg>`;
const pauseIconSVG = `<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20"><path d="M5.75 4.75a.75.75 0 00-.75.75v9.5c0 .414.336.75.75.75h1.5a.75.75 0 00.75-.75v-9.5a.75.75 0 00-.75-.75h-1.5zm6.5 0a.75.75 0 00-.75.75v9.5c0 .414.336.75.75.75h1.5a.75.75 0 00.75-.75v-9.5a.75.75 0 00-.75-.75h-1.5z"></path></svg>`;

// Shared playback state: the Audio object in use (if any) and the play/pause
// button that currently owns it.
let currentAudio = null;
let currentPlayingButton = null;
|
| 741 |
+
|
| 742 |
+
// --- UPDATED: addTextToSpeechControls ---
|
| 743 |
+
/**
 * Add the "Listen" and speed-cycle buttons to a bot message.
 *
 * Nothing is added when the text is blank, when the bubble has no
 * `.tts-controls` container, or when the controls already exist.
 *
 * @param {Element} messageBubble Container returned by `appendMessage`.
 * @param {string} text Message text to be spoken.
 */
function addTextToSpeechControls(messageBubble, text) {
    if (!(text && text.trim())) return;

    const controlsHost = messageBubble.querySelector('.tts-controls');
    if (!controlsHost) return;
    if (controlsHost.querySelector('.speak-btn')) return; // Already populated

    // Play/pause button; it also carries the currently selected playback speed.
    const listenBtn = Object.assign(document.createElement('button'), {
        className: 'speak-btn mt-2 px-3 py-1.5 bg-blue-700 text-white rounded-full text-xs font-medium hover:bg-blue-800 transition-colors flex items-center gap-1.5 disabled:opacity-50 disabled:cursor-not-allowed',
        title: 'Listen to this message',
        innerHTML: `${playIconSVG} <span>Listen</span>`,
    });
    listenBtn.setAttribute('data-current-speed', '1.0');

    // Speed button; it stores the list of speeds it cycles through.
    const rateBtn = Object.assign(document.createElement('button'), {
        className: 'speed-cycle-btn',
        title: 'Cycle playback speed',
        textContent: 'Speed: 1x',
    });
    rateBtn.setAttribute('data-speeds', JSON.stringify(availableSpeeds));

    controlsHost.append(listenBtn, rateBtn);

    listenBtn.addEventListener('click', () => handleTTS(text, listenBtn));
    rateBtn.addEventListener('click', () => cycleSpeed(rateBtn, listenBtn));
}
|
| 766 |
+
|
| 767 |
+
// --- NEW: cycleSpeed ---
|
| 768 |
+
/**
 * Advance a message's playback speed to the next entry in its speed list.
 *
 * The new speed is stored on the play button, shown on the speed button, and
 * applied immediately if this message's audio is playing right now.
 *
 * @param {Element} speedBtn Speed button holding the `data-speeds` JSON list.
 * @param {Element} speakBtn Play/pause button holding `data-current-speed`.
 */
function cycleSpeed(speedBtn, speakBtn) {
    const rateOptions = JSON.parse(speedBtn.getAttribute('data-speeds'));
    const activeRate = parseFloat(speakBtn.getAttribute('data-current-speed'));

    // Wrap around to the first entry after the last one; an unknown rate
    // (indexOf === -1) also lands on the first entry.
    const upcomingIndex = (rateOptions.indexOf(activeRate) + 1) % rateOptions.length;
    const upcomingRate = rateOptions[upcomingIndex];

    speakBtn.setAttribute('data-current-speed', upcomingRate.toString());
    speedBtn.textContent = `Speed: ${upcomingRate}x`;

    // Only touch the live audio when it belongs to this button and is playing.
    const isLivePlayback = currentAudio && !currentAudio.paused && speakBtn === currentPlayingButton;
    if (isLivePlayback) {
        currentAudio.playbackRate = upcomingRate;
    }
}
|
| 786 |
+
|
| 787 |
+
// --- UPDATED: handleTTS ---
|
| 788 |
+
/**
 * Play, pause or resume text-to-speech audio for one message.
 *
 * Clicking the active button toggles pause/resume. Clicking any other button
 * stops whatever is playing, requests audio for `text` from `/tts`, and plays
 * it at the speed selected on that button.
 *
 * @param {string} text Message text to synthesize.
 * @param {HTMLButtonElement} button The play/pause button that was clicked.
 * @returns {Promise<void>}
 */
async function handleTTS(text, button) {
    if (!text || !text.trim()) return;

    // The selected speed lives on the button's data attribute.
    const selectedSpeed = parseFloat(button.getAttribute('data-current-speed')) || 1.0;

    if (button === currentPlayingButton) { // Toggle on the currently active button
        if (currentAudio && !currentAudio.paused) {
            currentAudio.pause();
            button.innerHTML = `${playIconSVG} <span>Listen</span>`;
        } else if (currentAudio && currentAudio.paused) {
            currentAudio.playbackRate = selectedSpeed; // Ensure speed is set on resume
            currentAudio.play().catch(e => {console.error("Audio resume error:", e); resetAllSpeakButtons();});
            button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
        }
        return;
    }

    // A different message is taking over: release the blob URL of the audio
    // that is about to be discarded so its data can be garbage-collected.
    if (currentAudio && currentAudio.src.startsWith('blob:')) {
        URL.revokeObjectURL(currentAudio.src);
    }

    resetAllSpeakButtons(); // Stop any other audio
    currentPlayingButton = button; // Mark this button as active

    // resetAllSpeakButtons() puts every button back to 1x, including this one.
    // Restore the speed the user picked so the label, the stored attribute and
    // the actual playback rate stay in agreement (also on later resume).
    const speedBtn = button.parentElement.querySelector('.speed-cycle-btn');
    button.setAttribute('data-current-speed', selectedSpeed.toString());
    if (speedBtn) {
        speedBtn.textContent = `Speed: ${selectedSpeed}x`;
        speedBtn.disabled = true; // No speed changes while loading
    }

    button.innerHTML = `<div class="tts-button-loader mr-1"></div> <span>Loading...</span>`;
    button.disabled = true;

    // True once another button has taken over (or everything was reset) while
    // this request was still in flight.
    const isStale = () => currentPlayingButton !== button;

    let audioUrl = null;
    const releaseAudioUrl = () => {
        if (audioUrl) {
            URL.revokeObjectURL(audioUrl);
            audioUrl = null;
        }
    };

    try {
        const response = await fetch('/tts', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ text: text })
        });
        if (!response.ok) throw new Error(`TTS generation failed (${response.status})`);
        const blob = await response.blob();
        if (!blob || blob.size === 0) throw new Error("Received empty audio blob.");

        // Drop the result if the user moved on to another message meanwhile;
        // otherwise two clips would play on top of each other.
        if (isStale()) return;

        audioUrl = URL.createObjectURL(blob);
        const audio = new Audio(audioUrl);
        currentAudio = audio;
        audio.playbackRate = selectedSpeed;

        await audio.play();
        if (isStale() || currentAudio !== audio) {
            audio.pause();
            releaseAudioUrl();
            return;
        }

        button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
        button.disabled = false;
        if (speedBtn) speedBtn.disabled = false;

        audio.onended = () => {
            releaseAudioUrl();
            // Only reset if this button was the one playing
            if (button === currentPlayingButton) resetAllSpeakButtons();
        };
        audio.onerror = (e) => {
            console.error('Audio object error:', e);
            releaseAudioUrl();
            alert('Error playing audio.');
            resetAllSpeakButtons();
        };

    } catch (error) {
        console.error('TTS Handling Error:', error);
        releaseAudioUrl();
        // A superseded request must not alert or reset the newer playback.
        if (isStale()) return;
        alert(`Failed to play audio: ${error.message}`);
        resetAllSpeakButtons(); // Reset on error
    }
}
|
| 855 |
+
|
| 856 |
+
// --- UPDATED: resetAllSpeakButtons ---
|
| 857 |
+
/**
 * Return every TTS control on the page to its idle state and stop playback.
 *
 * All play buttons show "Listen" at speed 1.0, all speed buttons show
 * "Speed: 1x", any current audio is paused and detached from its handlers,
 * and no button is marked as active afterwards.
 */
function resetAllSpeakButtons() {
    for (const listenBtn of document.querySelectorAll('.speak-btn')) {
        listenBtn.innerHTML = `${playIconSVG} <span>Listen</span>`;
        listenBtn.disabled = false;
        listenBtn.setAttribute('data-current-speed', '1.0');
    }

    for (const rateBtn of document.querySelectorAll('.speed-cycle-btn')) {
        rateBtn.textContent = 'Speed: 1x';
        rateBtn.disabled = false;
    }

    const audio = currentAudio;
    if (audio) {
        audio.pause();
        // Detach handlers so the discarded audio cannot fire callbacks later.
        audio.onended = null;
        audio.onerror = null;
        currentAudio = null;
    }

    currentPlayingButton = null;
}
|
| 876 |
+
|
| 877 |
+
// NOTE: the former resetSpecificButton helper is no longer needed; its behavior is integrated into resetAllSpeakButtons.
|
| 878 |
+
// --- End of TTS Functions ---
|
| 879 |
+
|
| 880 |
+
// ... (keep the optional highlight.js part if you have it)
|
| 881 |
+
|
| 882 |
+
});
|
| 883 |
+
</script>
|
| 884 |
+
</body>
|
| 885 |
+
</html>
|