Spaces:
Sleeping
Sleeping
Commit ·
e81a689
1
Parent(s): 911c2cd
FAISS caching (Task 5), session persistence, streaming UX, suggestions as SSE event
Browse files- api/rag_engine.py +43 -14
- api/server.py +55 -2
- web/src/App.tsx +155 -11
api/rag_engine.py
CHANGED
|
@@ -318,29 +318,41 @@ def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
|
|
| 318 |
# ----------------------------
|
| 319 |
class VectorStore:
|
| 320 |
"""Simple in-memory vector store using FAISS (or fallback to list-based cosine similarity)."""
|
| 321 |
-
|
| 322 |
-
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
self.faiss = _safe_import_faiss()
|
| 324 |
self.index = None
|
| 325 |
self.chunks: List[Dict] = []
|
| 326 |
self.use_faiss = False
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
def build_index(self, chunks: List[Dict]):
|
| 329 |
"""Build FAISS index from chunks with embeddings."""
|
| 330 |
self.chunks = chunks or []
|
| 331 |
if not self.chunks:
|
| 332 |
return
|
| 333 |
-
|
| 334 |
# Filter chunks that have embeddings
|
| 335 |
chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
|
| 336 |
if not chunks_with_emb:
|
| 337 |
print("[rag_engine] No chunks with embeddings, using token-based retrieval")
|
| 338 |
return
|
| 339 |
-
|
| 340 |
if self.faiss is None:
|
| 341 |
print("[rag_engine] FAISS not available, using list-based cosine similarity")
|
| 342 |
return
|
| 343 |
-
|
| 344 |
try:
|
| 345 |
dim = len(chunks_with_emb[0]["embedding"])
|
| 346 |
# Use L2 (Euclidean) index for FAISS
|
|
@@ -354,7 +366,20 @@ class VectorStore:
|
|
| 354 |
except Exception as e:
|
| 355 |
print(f"[rag_engine] FAISS index build failed: {repr(e)}, using list-based")
|
| 356 |
self.use_faiss = False
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
def search(self, query_embedding: List[float], k: int) -> List[Tuple[float, Dict]]:
|
| 359 |
"""
|
| 360 |
Search top-k chunks by vector similarity.
|
|
@@ -362,11 +387,11 @@ class VectorStore:
|
|
| 362 |
"""
|
| 363 |
if not query_embedding or not self.chunks:
|
| 364 |
return []
|
| 365 |
-
|
| 366 |
chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
|
| 367 |
if not chunks_with_emb:
|
| 368 |
return []
|
| 369 |
-
|
| 370 |
if self.use_faiss and self.index is not None:
|
| 371 |
try:
|
| 372 |
import numpy as np
|
|
@@ -381,7 +406,7 @@ class VectorStore:
|
|
| 381 |
return results
|
| 382 |
except Exception as e:
|
| 383 |
print(f"[rag_engine] FAISS search error: {repr(e)}, fallback to list-based")
|
| 384 |
-
|
| 385 |
# Fallback: list-based cosine similarity
|
| 386 |
results: List[Tuple[float, Dict]] = []
|
| 387 |
for chunk in chunks_with_emb:
|
|
@@ -472,6 +497,7 @@ def retrieve_relevant_chunks(
|
|
| 472 |
allowed_doc_types: Optional[List[str]] = None,
|
| 473 |
use_vector_search: bool = True, # NEW: enable/disable vector search
|
| 474 |
vector_similarity_threshold: float = 0.3, # L2-based similarity: 1/(1+dist), rarely reaches 0.7
|
|
|
|
| 475 |
) -> Tuple[str, List[Dict]]:
|
| 476 |
"""
|
| 477 |
Enhanced retrieval with vector similarity + token overlap rerank.
|
|
@@ -526,14 +552,17 @@ def retrieve_relevant_chunks(
|
|
| 526 |
# Vector search path (if enabled and embeddings available)
|
| 527 |
# ----------------------------
|
| 528 |
chunks_with_emb = [c for c in filtered if c.get("embedding") is not None]
|
| 529 |
-
|
| 530 |
if use_vector_search and chunks_with_emb:
|
| 531 |
try:
|
| 532 |
query_emb = get_chunk_embedding(query)
|
| 533 |
if query_emb:
|
| 534 |
-
#
|
| 535 |
-
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
| 537 |
vector_results = store.search(query_emb, k=k * 3) # Get 3x candidates for reranker
|
| 538 |
|
| 539 |
# Filter by similarity threshold
|
|
|
|
| 318 |
# ----------------------------
|
| 319 |
class VectorStore:
|
| 320 |
"""Simple in-memory vector store using FAISS (or fallback to list-based cosine similarity)."""
|
| 321 |
+
|
| 322 |
+
def __init__(self, cached_index=None):
|
| 323 |
+
"""
|
| 324 |
+
Initialize VectorStore with optional pre-built index.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
cached_index: Optional dict with 'index', 'use_faiss', 'chunks' (for reuse)
|
| 328 |
+
"""
|
| 329 |
self.faiss = _safe_import_faiss()
|
| 330 |
self.index = None
|
| 331 |
self.chunks: List[Dict] = []
|
| 332 |
self.use_faiss = False
|
| 333 |
+
|
| 334 |
+
# If cached index provided, restore it
|
| 335 |
+
if cached_index:
|
| 336 |
+
self.index = cached_index.get("index")
|
| 337 |
+
self.use_faiss = cached_index.get("use_faiss", False)
|
| 338 |
+
self.chunks = cached_index.get("chunks", [])
|
| 339 |
+
|
| 340 |
def build_index(self, chunks: List[Dict]):
|
| 341 |
"""Build FAISS index from chunks with embeddings."""
|
| 342 |
self.chunks = chunks or []
|
| 343 |
if not self.chunks:
|
| 344 |
return
|
| 345 |
+
|
| 346 |
# Filter chunks that have embeddings
|
| 347 |
chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
|
| 348 |
if not chunks_with_emb:
|
| 349 |
print("[rag_engine] No chunks with embeddings, using token-based retrieval")
|
| 350 |
return
|
| 351 |
+
|
| 352 |
if self.faiss is None:
|
| 353 |
print("[rag_engine] FAISS not available, using list-based cosine similarity")
|
| 354 |
return
|
| 355 |
+
|
| 356 |
try:
|
| 357 |
dim = len(chunks_with_emb[0]["embedding"])
|
| 358 |
# Use L2 (Euclidean) index for FAISS
|
|
|
|
| 366 |
except Exception as e:
|
| 367 |
print(f"[rag_engine] FAISS index build failed: {repr(e)}, using list-based")
|
| 368 |
self.use_faiss = False
|
| 369 |
+
|
| 370 |
+
def get_cached(self) -> Optional[Dict]:
|
| 371 |
+
"""
|
| 372 |
+
Export index for caching in session.
|
| 373 |
+
Returns: dict with 'index', 'use_faiss', 'chunks' or None if not built.
|
| 374 |
+
"""
|
| 375 |
+
if self.index is None:
|
| 376 |
+
return None
|
| 377 |
+
return {
|
| 378 |
+
"index": self.index,
|
| 379 |
+
"use_faiss": self.use_faiss,
|
| 380 |
+
"chunks": self.chunks,
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
def search(self, query_embedding: List[float], k: int) -> List[Tuple[float, Dict]]:
|
| 384 |
"""
|
| 385 |
Search top-k chunks by vector similarity.
|
|
|
|
| 387 |
"""
|
| 388 |
if not query_embedding or not self.chunks:
|
| 389 |
return []
|
| 390 |
+
|
| 391 |
chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
|
| 392 |
if not chunks_with_emb:
|
| 393 |
return []
|
| 394 |
+
|
| 395 |
if self.use_faiss and self.index is not None:
|
| 396 |
try:
|
| 397 |
import numpy as np
|
|
|
|
| 406 |
return results
|
| 407 |
except Exception as e:
|
| 408 |
print(f"[rag_engine] FAISS search error: {repr(e)}, fallback to list-based")
|
| 409 |
+
|
| 410 |
# Fallback: list-based cosine similarity
|
| 411 |
results: List[Tuple[float, Dict]] = []
|
| 412 |
for chunk in chunks_with_emb:
|
|
|
|
| 497 |
allowed_doc_types: Optional[List[str]] = None,
|
| 498 |
use_vector_search: bool = True, # NEW: enable/disable vector search
|
| 499 |
vector_similarity_threshold: float = 0.3, # L2-based similarity: 1/(1+dist), rarely reaches 0.7
|
| 500 |
+
cached_index: Optional[Dict] = None, # NEW: pre-built FAISS index for caching
|
| 501 |
) -> Tuple[str, List[Dict]]:
|
| 502 |
"""
|
| 503 |
Enhanced retrieval with vector similarity + token overlap rerank.
|
|
|
|
| 552 |
# Vector search path (if enabled and embeddings available)
|
| 553 |
# ----------------------------
|
| 554 |
chunks_with_emb = [c for c in filtered if c.get("embedding") is not None]
|
| 555 |
+
|
| 556 |
if use_vector_search and chunks_with_emb:
|
| 557 |
try:
|
| 558 |
query_emb = get_chunk_embedding(query)
|
| 559 |
if query_emb:
|
| 560 |
+
# Use cached index if provided; otherwise build a new one
|
| 561 |
+
if cached_index:
|
| 562 |
+
store = VectorStore(cached_index=cached_index)
|
| 563 |
+
else:
|
| 564 |
+
store = VectorStore()
|
| 565 |
+
store.build_index(chunks_with_emb)
|
| 566 |
vector_results = store.search(query_emb, k=k * 3) # Get 3x candidates for reranker
|
| 567 |
|
| 568 |
# Filter by similarity threshold
|
api/server.py
CHANGED
|
@@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 15 |
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 16 |
from pydantic import BaseModel
|
| 17 |
|
| 18 |
-
from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
|
| 19 |
from api.syllabus_utils import extract_course_topics_from_file
|
| 20 |
from api.rag_engine import build_rag_chunks_from_file, retrieve_relevant_chunks
|
| 21 |
from api.clare_core import (
|
|
@@ -243,6 +243,21 @@ def _run_preload_in_background():
|
|
| 243 |
|
| 244 |
_run_preload_in_background()
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
def _get_session(session_id: str) -> Dict[str, Any]:
|
| 247 |
if session_id not in SESSIONS:
|
| 248 |
SESSIONS[session_id] = {
|
|
@@ -259,8 +274,13 @@ def _get_session(session_id: str) -> Dict[str, Any]:
|
|
| 259 |
"profile_bio": "",
|
| 260 |
"init_answers": {},
|
| 261 |
"init_dismiss_until": 0,
|
|
|
|
| 262 |
}
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
if "uploaded_files" not in SESSIONS[session_id]:
|
| 265 |
SESSIONS[session_id]["uploaded_files"] = []
|
| 266 |
|
|
@@ -268,6 +288,7 @@ def _get_session(session_id: str) -> Dict[str, Any]:
|
|
| 268 |
SESSIONS[session_id].setdefault("profile_bio", "")
|
| 269 |
SESSIONS[session_id].setdefault("init_answers", {})
|
| 270 |
SESSIONS[session_id].setdefault("init_dismiss_until", 0)
|
|
|
|
| 271 |
|
| 272 |
return SESSIONS[session_id]
|
| 273 |
|
|
@@ -794,12 +815,14 @@ async def chat(req: ChatReq):
|
|
| 794 |
log.debug("rag skipped - message too short")
|
| 795 |
rag_context_text, rag_used_chunks = "", []
|
| 796 |
else:
|
|
|
|
| 797 |
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
| 798 |
msg,
|
| 799 |
MODULE10_CHUNKS_CACHE + sess["rag_chunks"],
|
| 800 |
allowed_source_files=allowed_files,
|
| 801 |
allowed_doc_types=allowed_doc_types,
|
| 802 |
max_context_chars=2000,
|
|
|
|
| 803 |
)
|
| 804 |
log.debug("faiss rag | chunks_returned=%d | context_chars=%d", len(rag_used_chunks), len(rag_context_text))
|
| 805 |
if rag_used_chunks:
|
|
@@ -950,6 +973,29 @@ async def chat(req: ChatReq):
|
|
| 950 |
run_id=None,
|
| 951 |
)
|
| 952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
log.info("chat streamed | session=%s | chars=%d | total_ms=%.0f",
|
| 954 |
session_id, len(full_text), total_ms)
|
| 955 |
|
|
@@ -979,7 +1025,9 @@ async def quiz_start(req: QuizStartReq):
|
|
| 979 |
resolved_lang = detect_language(quiz_instruction, req.language_preference)
|
| 980 |
|
| 981 |
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
| 982 |
-
"Module 10 quiz",
|
|
|
|
|
|
|
| 983 |
)
|
| 984 |
|
| 985 |
# ✅ NEW: same hint for quiz start as well
|
|
@@ -1284,6 +1332,11 @@ async def upload(
|
|
| 1284 |
session_id, len(combined), MAX_UPLOAD_CHUNKS)
|
| 1285 |
combined = combined[:MAX_UPLOAD_CHUNKS]
|
| 1286 |
sess["rag_chunks"] = combined
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1287 |
except Exception as e:
|
| 1288 |
print(f"[upload] rag build error: {repr(e)}")
|
| 1289 |
new_chunks = []
|
|
|
|
| 15 |
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 16 |
from pydantic import BaseModel
|
| 17 |
|
| 18 |
+
from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL, async_client
|
| 19 |
from api.syllabus_utils import extract_course_topics_from_file
|
| 20 |
from api.rag_engine import build_rag_chunks_from_file, retrieve_relevant_chunks
|
| 21 |
from api.clare_core import (
|
|
|
|
| 243 |
|
| 244 |
_run_preload_in_background()
|
| 245 |
|
| 246 |
+
def _build_faiss_index(chunks: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
| 247 |
+
"""Build and cache FAISS index from chunks. Returns cached index dict or None."""
|
| 248 |
+
if not chunks:
|
| 249 |
+
return None
|
| 250 |
+
from api.rag_engine import VectorStore
|
| 251 |
+
try:
|
| 252 |
+
vs = VectorStore()
|
| 253 |
+
vs.build_index(chunks)
|
| 254 |
+
cached = vs.get_cached()
|
| 255 |
+
return cached
|
| 256 |
+
except Exception as e:
|
| 257 |
+
log.error("failed to build FAISS index: %r", e)
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
|
| 261 |
def _get_session(session_id: str) -> Dict[str, Any]:
|
| 262 |
if session_id not in SESSIONS:
|
| 263 |
SESSIONS[session_id] = {
|
|
|
|
| 274 |
"profile_bio": "",
|
| 275 |
"init_answers": {},
|
| 276 |
"init_dismiss_until": 0,
|
| 277 |
+
"faiss_index": None, # Cached FAISS index (built at init and on upload)
|
| 278 |
}
|
| 279 |
|
| 280 |
+
# Build initial FAISS index with MODULE10_CHUNKS_CACHE
|
| 281 |
+
initial_chunks = MODULE10_CHUNKS_CACHE
|
| 282 |
+
SESSIONS[session_id]["faiss_index"] = _build_faiss_index(initial_chunks)
|
| 283 |
+
|
| 284 |
if "uploaded_files" not in SESSIONS[session_id]:
|
| 285 |
SESSIONS[session_id]["uploaded_files"] = []
|
| 286 |
|
|
|
|
| 288 |
SESSIONS[session_id].setdefault("profile_bio", "")
|
| 289 |
SESSIONS[session_id].setdefault("init_answers", {})
|
| 290 |
SESSIONS[session_id].setdefault("init_dismiss_until", 0)
|
| 291 |
+
SESSIONS[session_id].setdefault("faiss_index", None)
|
| 292 |
|
| 293 |
return SESSIONS[session_id]
|
| 294 |
|
|
|
|
| 815 |
log.debug("rag skipped - message too short")
|
| 816 |
rag_context_text, rag_used_chunks = "", []
|
| 817 |
else:
|
| 818 |
+
# Use cached FAISS index if available (no rebuild on each query)
|
| 819 |
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
| 820 |
msg,
|
| 821 |
MODULE10_CHUNKS_CACHE + sess["rag_chunks"],
|
| 822 |
allowed_source_files=allowed_files,
|
| 823 |
allowed_doc_types=allowed_doc_types,
|
| 824 |
max_context_chars=2000,
|
| 825 |
+
cached_index=sess.get("faiss_index"),
|
| 826 |
)
|
| 827 |
log.debug("faiss rag | chunks_returned=%d | context_chars=%d", len(rag_used_chunks), len(rag_context_text))
|
| 828 |
if rag_used_chunks:
|
|
|
|
| 973 |
run_id=None,
|
| 974 |
)
|
| 975 |
|
| 976 |
+
# Generate follow-up suggestions (not blocking, sent after final message)
|
| 977 |
+
try:
|
| 978 |
+
log.debug("generating suggestions...")
|
| 979 |
+
suggestions = await asyncio.wait_for(
|
| 980 |
+
generate_suggested_questions(
|
| 981 |
+
user_message=msg,
|
| 982 |
+
assistant_reply=full_text,
|
| 983 |
+
language=resolved_lang,
|
| 984 |
+
model_name=model_name,
|
| 985 |
+
),
|
| 986 |
+
timeout=30.0, # Max 30 seconds for suggestions
|
| 987 |
+
)
|
| 988 |
+
log.debug("suggestions generated | count=%d | data=%r", len(suggestions) if suggestions else 0, suggestions)
|
| 989 |
+
if suggestions and len(suggestions) > 0:
|
| 990 |
+
yield f"data: {json.dumps({'suggested_questions': suggestions, 'type': 'suggestions', 'is_final': True})}\n\n"
|
| 991 |
+
log.info("suggestions sent | count=%d", len(suggestions))
|
| 992 |
+
else:
|
| 993 |
+
log.debug("no suggestions returned")
|
| 994 |
+
except asyncio.TimeoutError:
|
| 995 |
+
log.warning("suggestions generation timed out (>30s)")
|
| 996 |
+
except Exception as e:
|
| 997 |
+
log.warning("suggestions generation failed: %r", e)
|
| 998 |
+
|
| 999 |
log.info("chat streamed | session=%s | chars=%d | total_ms=%.0f",
|
| 1000 |
session_id, len(full_text), total_ms)
|
| 1001 |
|
|
|
|
| 1025 |
resolved_lang = detect_language(quiz_instruction, req.language_preference)
|
| 1026 |
|
| 1027 |
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
| 1028 |
+
"Module 10 quiz",
|
| 1029 |
+
MODULE10_CHUNKS_CACHE + sess["rag_chunks"],
|
| 1030 |
+
cached_index=sess.get("faiss_index"),
|
| 1031 |
)
|
| 1032 |
|
| 1033 |
# ✅ NEW: same hint for quiz start as well
|
|
|
|
| 1332 |
session_id, len(combined), MAX_UPLOAD_CHUNKS)
|
| 1333 |
combined = combined[:MAX_UPLOAD_CHUNKS]
|
| 1334 |
sess["rag_chunks"] = combined
|
| 1335 |
+
|
| 1336 |
+
# REBUILD FAISS index with merged chunks (MODULE10 + new uploads)
|
| 1337 |
+
all_chunks = MODULE10_CHUNKS_CACHE + sess["rag_chunks"]
|
| 1338 |
+
sess["faiss_index"] = _build_faiss_index(all_chunks)
|
| 1339 |
+
log.debug("[upload] rebuilt FAISS index with %d total chunks", len(all_chunks))
|
| 1340 |
except Exception as e:
|
| 1341 |
print(f"[upload] rag build error: {repr(e)}")
|
| 1342 |
new_chunks = []
|
web/src/App.tsx
CHANGED
|
@@ -166,6 +166,42 @@ function hydrateSavedChats(raw: any): SavedChat[] {
|
|
| 166 |
.filter(Boolean) as SavedChat[];
|
| 167 |
}
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
function App() {
|
| 171 |
const [isDarkMode, setIsDarkMode] = useState(() => {
|
|
@@ -173,7 +209,22 @@ function App() {
|
|
| 173 |
return saved === "dark" || (!saved && window.matchMedia("(prefers-color-scheme: dark)").matches);
|
| 174 |
});
|
| 175 |
|
| 176 |
-
const [user, setUser] = useState<User | null>(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
// -------------------------
|
|
@@ -343,6 +394,27 @@ function App() {
|
|
| 343 |
|
| 344 |
const [savedChats, setSavedChats] = useState<SavedChat[]>([]);
|
| 345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
// ✅ load saved chats after login
|
| 347 |
useEffect(() => {
|
| 348 |
if (!user?.login_id) return;
|
|
@@ -369,6 +441,47 @@ function App() {
|
|
| 369 |
}
|
| 370 |
}, [savedChats, user?.login_id]);
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
const [groupMembers] = useState<GroupMember[]>([
|
| 373 |
{ id: "clare", name: "Clare AI", email: "clare@ai.assistant", isAI: true },
|
| 374 |
{ id: "1", name: "Sarah Johnson", email: "sarah.j@university.edu" },
|
|
@@ -811,12 +924,47 @@ function App() {
|
|
| 811 |
try {
|
| 812 |
const docType = getCurrentDocTypeForChat();
|
| 813 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
const r = await apiChat({
|
| 815 |
session_id: user.session_id,
|
| 816 |
message: effectiveContent,
|
| 817 |
learning_mode: learningMode,
|
| 818 |
language_preference: mapLanguagePref(language),
|
| 819 |
doc_type: docType,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
});
|
| 821 |
|
| 822 |
const normalizeRefs = (raw: any): string[] => {
|
|
@@ -838,21 +986,17 @@ function App() {
|
|
| 838 |
|
| 839 |
const refs = normalizeRefs((r as any).refs ?? (r as any).references);
|
| 840 |
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
content: r.reply || "",
|
| 845 |
-
timestamp: new Date(),
|
| 846 |
references: refs.length ? refs : undefined,
|
| 847 |
-
sender: spaceType === "group" ? groupMembers.find((m) => m.isAI) : undefined,
|
| 848 |
suggestedQuestions: (r as any).suggested_questions?.length ? (r as any).suggested_questions : undefined,
|
| 849 |
};
|
| 850 |
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
else if (chatMode === "review") setReviewMessages((prev) => [...prev, assistantMessage]);
|
| 855 |
-
else setQuizMessages((prev) => [...prev, assistantMessage]);
|
| 856 |
} catch (e: any) {
|
| 857 |
setIsTyping(false);
|
| 858 |
toast.error(e?.message || "Something went wrong. Please try again.");
|
|
|
|
| 166 |
.filter(Boolean) as SavedChat[];
|
| 167 |
}
|
| 168 |
|
| 169 |
+
// ✅ localStorage helpers for ongoing session state (refresh persistence)
|
| 170 |
+
function sessionStorageKey(session_id: string) {
|
| 171 |
+
return `session_state::${session_id}`;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
function hydrateSessionState(raw: any) {
|
| 175 |
+
if (!raw) return null;
|
| 176 |
+
try {
|
| 177 |
+
return {
|
| 178 |
+
askMessages: Array.isArray(raw.askMessages)
|
| 179 |
+
? raw.askMessages.map((m: any) => ({
|
| 180 |
+
...m,
|
| 181 |
+
timestamp: m?.timestamp ? new Date(m.timestamp) : new Date(),
|
| 182 |
+
}))
|
| 183 |
+
: [],
|
| 184 |
+
reviewMessages: Array.isArray(raw.reviewMessages)
|
| 185 |
+
? raw.reviewMessages.map((m: any) => ({
|
| 186 |
+
...m,
|
| 187 |
+
timestamp: m?.timestamp ? new Date(m.timestamp) : new Date(),
|
| 188 |
+
}))
|
| 189 |
+
: [],
|
| 190 |
+
quizMessages: Array.isArray(raw.quizMessages)
|
| 191 |
+
? raw.quizMessages.map((m: any) => ({
|
| 192 |
+
...m,
|
| 193 |
+
timestamp: m?.timestamp ? new Date(m.timestamp) : new Date(),
|
| 194 |
+
}))
|
| 195 |
+
: [],
|
| 196 |
+
uploadedFiles: Array.isArray(raw.uploadedFiles) ? raw.uploadedFiles : [],
|
| 197 |
+
learningMode: raw.learningMode || "concept",
|
| 198 |
+
language: raw.language || "Auto",
|
| 199 |
+
};
|
| 200 |
+
} catch {
|
| 201 |
+
return null;
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
|
| 206 |
function App() {
|
| 207 |
const [isDarkMode, setIsDarkMode] = useState(() => {
|
|
|
|
| 209 |
return saved === "dark" || (!saved && window.matchMedia("(prefers-color-scheme: dark)").matches);
|
| 210 |
});
|
| 211 |
|
| 212 |
+
const [user, setUser] = useState<User | null>(() => {
|
| 213 |
+
// Restore user from localStorage on page load
|
| 214 |
+
try {
|
| 215 |
+
const saved = localStorage.getItem("user_session");
|
| 216 |
+
if (saved) {
|
| 217 |
+
const parsed = JSON.parse(saved);
|
| 218 |
+
return {
|
| 219 |
+
login_id: parsed.login_id,
|
| 220 |
+
session_id: parsed.session_id,
|
| 221 |
+
} as User;
|
| 222 |
+
}
|
| 223 |
+
} catch {
|
| 224 |
+
// ignore
|
| 225 |
+
}
|
| 226 |
+
return null;
|
| 227 |
+
});
|
| 228 |
|
| 229 |
|
| 230 |
// -------------------------
|
|
|
|
| 394 |
|
| 395 |
const [savedChats, setSavedChats] = useState<SavedChat[]>([]);
|
| 396 |
|
| 397 |
+
// ✅ persist user session to localStorage
|
| 398 |
+
useEffect(() => {
|
| 399 |
+
if (user?.login_id && user?.session_id) {
|
| 400 |
+
try {
|
| 401 |
+
localStorage.setItem("user_session", JSON.stringify({
|
| 402 |
+
login_id: user.login_id,
|
| 403 |
+
session_id: user.session_id,
|
| 404 |
+
}));
|
| 405 |
+
} catch {
|
| 406 |
+
// ignore
|
| 407 |
+
}
|
| 408 |
+
} else {
|
| 409 |
+
// Clear user session when logged out
|
| 410 |
+
try {
|
| 411 |
+
localStorage.removeItem("user_session");
|
| 412 |
+
} catch {
|
| 413 |
+
// ignore
|
| 414 |
+
}
|
| 415 |
+
}
|
| 416 |
+
}, [user?.login_id, user?.session_id]);
|
| 417 |
+
|
| 418 |
// ✅ load saved chats after login
|
| 419 |
useEffect(() => {
|
| 420 |
if (!user?.login_id) return;
|
|
|
|
| 441 |
}
|
| 442 |
}, [savedChats, user?.login_id]);
|
| 443 |
|
| 444 |
+
// ✅ restore session state from localStorage on login
|
| 445 |
+
useEffect(() => {
|
| 446 |
+
if (!user?.session_id) return;
|
| 447 |
+
try {
|
| 448 |
+
const raw = localStorage.getItem(sessionStorageKey(user.session_id));
|
| 449 |
+
if (!raw) return;
|
| 450 |
+
const state = hydrateSessionState(JSON.parse(raw));
|
| 451 |
+
if (!state) return;
|
| 452 |
+
|
| 453 |
+
// Restore session state
|
| 454 |
+
if (state.askMessages.length > 0) setAskMessages(state.askMessages);
|
| 455 |
+
if (state.reviewMessages.length > 0) setReviewMessages(state.reviewMessages);
|
| 456 |
+
if (state.quizMessages.length > 0) setQuizMessages(state.quizMessages);
|
| 457 |
+
if (state.uploadedFiles.length > 0) setUploadedFiles(state.uploadedFiles);
|
| 458 |
+
if (state.learningMode) setLearningMode(state.learningMode);
|
| 459 |
+
if (state.language) setLanguage(state.language);
|
| 460 |
+
} catch {
|
| 461 |
+
// ignore restore errors
|
| 462 |
+
}
|
| 463 |
+
}, [user?.session_id]);
|
| 464 |
+
|
| 465 |
+
// ✅ persist session state to localStorage whenever messages/state change
|
| 466 |
+
useEffect(() => {
|
| 467 |
+
if (!user?.session_id) return;
|
| 468 |
+
try {
|
| 469 |
+
localStorage.setItem(
|
| 470 |
+
sessionStorageKey(user.session_id),
|
| 471 |
+
JSON.stringify({
|
| 472 |
+
askMessages,
|
| 473 |
+
reviewMessages,
|
| 474 |
+
quizMessages,
|
| 475 |
+
uploadedFiles,
|
| 476 |
+
learningMode,
|
| 477 |
+
language,
|
| 478 |
+
})
|
| 479 |
+
);
|
| 480 |
+
} catch {
|
| 481 |
+
// ignore
|
| 482 |
+
}
|
| 483 |
+
}, [askMessages, reviewMessages, quizMessages, uploadedFiles, learningMode, language, user?.session_id]);
|
| 484 |
+
|
| 485 |
const [groupMembers] = useState<GroupMember[]>([
|
| 486 |
{ id: "clare", name: "Clare AI", email: "clare@ai.assistant", isAI: true },
|
| 487 |
{ id: "1", name: "Sarah Johnson", email: "sarah.j@university.edu" },
|
|
|
|
| 924 |
try {
|
| 925 |
const docType = getCurrentDocTypeForChat();
|
| 926 |
|
| 927 |
+
// Create message with empty content (will be filled as tokens arrive)
|
| 928 |
+
const messageId = (Date.now() + 1).toString();
|
| 929 |
+
const assistantMessage: Message = {
|
| 930 |
+
id: messageId,
|
| 931 |
+
role: "assistant",
|
| 932 |
+
content: "",
|
| 933 |
+
timestamp: new Date(),
|
| 934 |
+
references: undefined,
|
| 935 |
+
sender: spaceType === "group" ? groupMembers.find((m) => m.isAI) : undefined,
|
| 936 |
+
};
|
| 937 |
+
|
| 938 |
+
// Add empty message immediately so user sees typing indicator
|
| 939 |
+
if (chatMode === "ask") setAskMessages((prev) => [...prev, assistantMessage]);
|
| 940 |
+
else if (chatMode === "review") setReviewMessages((prev) => [...prev, assistantMessage]);
|
| 941 |
+
else setQuizMessages((prev) => [...prev, assistantMessage]);
|
| 942 |
+
|
| 943 |
+
// Hide typing indicator immediately (message will fill with tokens)
|
| 944 |
+
setIsTyping(false);
|
| 945 |
+
|
| 946 |
+
// Stream response with token callback
|
| 947 |
const r = await apiChat({
|
| 948 |
session_id: user.session_id,
|
| 949 |
message: effectiveContent,
|
| 950 |
learning_mode: learningMode,
|
| 951 |
language_preference: mapLanguagePref(language),
|
| 952 |
doc_type: docType,
|
| 953 |
+
}, (token: string) => {
|
| 954 |
+
// Update message content as tokens arrive
|
| 955 |
+
if (chatMode === "ask") {
|
| 956 |
+
setAskMessages((prev) =>
|
| 957 |
+
prev.map((m) => m.id === messageId ? { ...m, content: m.content + token } : m)
|
| 958 |
+
);
|
| 959 |
+
} else if (chatMode === "review") {
|
| 960 |
+
setReviewMessages((prev) =>
|
| 961 |
+
prev.map((m) => m.id === messageId ? { ...m, content: m.content + token } : m)
|
| 962 |
+
);
|
| 963 |
+
} else {
|
| 964 |
+
setQuizMessages((prev) =>
|
| 965 |
+
prev.map((m) => m.id === messageId ? { ...m, content: m.content + token } : m)
|
| 966 |
+
);
|
| 967 |
+
}
|
| 968 |
});
|
| 969 |
|
| 970 |
const normalizeRefs = (raw: any): string[] => {
|
|
|
|
| 986 |
|
| 987 |
const refs = normalizeRefs((r as any).refs ?? (r as any).references);
|
| 988 |
|
| 989 |
+
// Update message with final content, refs, and suggestions
|
| 990 |
+
const finalMessage: Message = {
|
| 991 |
+
...assistantMessage,
|
| 992 |
content: r.reply || "",
|
|
|
|
| 993 |
references: refs.length ? refs : undefined,
|
|
|
|
| 994 |
suggestedQuestions: (r as any).suggested_questions?.length ? (r as any).suggested_questions : undefined,
|
| 995 |
};
|
| 996 |
|
| 997 |
+
if (chatMode === "ask") setAskMessages((prev) => prev.map((m) => m.id === messageId ? finalMessage : m));
|
| 998 |
+
else if (chatMode === "review") setReviewMessages((prev) => prev.map((m) => m.id === messageId ? finalMessage : m));
|
| 999 |
+
else setQuizMessages((prev) => prev.map((m) => m.id === messageId ? finalMessage : m));
|
|
|
|
|
|
|
| 1000 |
} catch (e: any) {
|
| 1001 |
setIsTyping(false);
|
| 1002 |
toast.error(e?.message || "Something went wrong. Please try again.");
|