rahulrb99 committed on
Commit
5ccf778
·
unverified ·
2 Parent(s): 07f5acaa524c65

Merge pull request #15 from tAnboyy/feature/citations-retrieval-isolation

Browse files
app.py CHANGED
@@ -256,6 +256,7 @@ def _initial_load(profile: gr.OAuthProfile | None = None):
256
  gr.update(visible=bool(user_id)),
257
  gr.update(visible=not bool(user_id)),
258
  source_status,
 
259
  )
260
 
261
 
@@ -650,29 +651,51 @@ def _chat_history_to_pairs(messages: list[dict]) -> list[tuple[str, str]]:
650
  return pairs
651
 
652
 
653
- def _load_chat_history(notebook_id) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
654
- """Load chat for notebook. Returns (history_pairs, history_pairs) for State and Chatbot."""
 
 
655
  if not notebook_id:
656
  return [], []
657
- messages = load_chat(notebook_id)
658
  pairs = _chat_history_to_pairs(messages)
659
  return pairs, pairs
660
 
661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  def _on_chat_submit(query, notebook_id, chat_history, profile: gr.OAuthProfile | None):
663
- """Handle chat submit: call RAG, return updated history."""
664
  if not notebook_id:
665
- return "", chat_history, "Select a notebook first."
666
  if not query or not query.strip():
667
- return "", chat_history, "Enter a message."
668
  user_id = _user_id(profile)
669
  if not user_id:
670
- return "", chat_history, "Please sign in first."
671
  try:
672
- answer, updated = rag_chat(notebook_id, query.strip(), chat_history)
673
- return "", updated, ""
 
 
674
  except Exception as e:
675
- return "", chat_history, f"Error: {e}"
676
 
677
  def _get_quiz_pdfs(source_type, notebook_id):
678
  if source_type != "PDF":
@@ -755,6 +778,7 @@ with gr.Blocks(
755
  selected_notebook_id = gr.State(None)
756
  chat_history_state = gr.State([])
757
  quiz_state = gr.State([])
 
758
 
759
  with gr.Group(elem_classes=["section-card", "manager-card"]):
760
  gr.Markdown("**Notebook Manager**", elem_classes=["section-title"])
@@ -891,6 +915,9 @@ with gr.Blocks(
891
  )
892
  chat_submit_btn = gr.Button("Send", variant="primary")
893
  chat_status = gr.Markdown("", elem_classes=["status"])
 
 
 
894
 
895
  with gr.Group(elem_classes=["section-card", "artifacts-card"]):
896
  gr.Markdown("**Artifacts**", elem_classes=["section-title"])
@@ -945,21 +972,21 @@ with gr.Blocks(
945
  demo.load(
946
  _initial_load,
947
  inputs=None,
948
- outputs=[nb_state, selected_notebook_id, notebook_status, auth_text, auth_info_row, app_content, login_container, source_status],
949
  api_name=False,
950
  )
951
  demo.load(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
952
  demo.load(_load_sources, inputs=[selected_notebook_id], outputs=[sources_display], api_name=False)
953
  demo.load(_selected_notebook_text, inputs=[selected_notebook_id, nb_state], outputs=[selected_notebook_md], api_name=False)
954
 
955
- def _on_notebook_select_for_chat(notebook_id):
956
- hist, _ = _load_chat_history(notebook_id)
957
- return hist, hist
958
 
959
  selected_notebook_id.change(
960
  _on_notebook_select_for_chat,
961
- inputs=[selected_notebook_id],
962
- outputs=[chat_history_state, chatbot],
963
  api_name=False,
964
  )
965
  selected_notebook_id.change(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
@@ -1077,7 +1104,7 @@ with gr.Blocks(
1077
  chat_submit_btn.click(
1078
  _on_chat_submit,
1079
  inputs=[chat_input, selected_notebook_id, chat_history_state],
1080
- outputs=[chat_input, chat_history_state, chat_status],
1081
  api_name=False,
1082
  ).then(
1083
  lambda h: (h, h),
 
256
  gr.update(visible=bool(user_id)),
257
  gr.update(visible=not bool(user_id)),
258
  source_status,
259
+ user_id,
260
  )
261
 
262
 
 
651
  return pairs
652
 
653
 
654
def _load_chat_history(notebook_id, user_id: str | None) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Fetch a notebook's stored conversation as (user, assistant) pairs.

    The same list is returned twice so that one copy can feed the chat State
    and the other the Chatbot component. ``user_id`` is forwarded to
    ``load_chat``, which returns nothing for notebooks the user does not own.
    Without a ``notebook_id`` no lookup is made and two empty lists come back.
    """
    if notebook_id:
        pairs = _chat_history_to_pairs(load_chat(notebook_id, user_id))
        return pairs, pairs
    return [], []
663
 
664
 
665
+ def _format_citations(chunks: list[dict]) -> str:
666
+ """Format retrieved chunks for citation display."""
667
+ if not chunks:
668
+ return ""
669
+ lines = ["**Sources cited:**", ""]
670
+ for i, c in enumerate(chunks, 1):
671
+ meta = c.get("metadata") or {}
672
+ source_label = meta.get("url") or meta.get("file_name") or meta.get("file_path") or "Source"
673
+ content = (c.get("content") or "")[:300]
674
+ if len(c.get("content") or "") > 300:
675
+ content += "..."
676
+ content = content.replace("\n", " ")
677
+ lines.append(f"**[{i}]** *{source_label}*")
678
+ lines.append(f"> {content}")
679
+ lines.append("")
680
+ return "\n".join(lines)
681
+
682
+
683
def _on_chat_submit(query, notebook_id, chat_history, profile: gr.OAuthProfile | None):
    """Run one chat turn through RAG and produce the UI updates.

    Returns a 5-tuple: cleared input text, chat history, status message,
    citations Markdown, and a visibility update for the citations accordion.
    On any validation failure or exception the history is returned unchanged,
    the status carries the message, and the citations are cleared and hidden.
    """

    def _rejected(message):
        # Shared shape for every non-success outcome.
        return "", chat_history, message, "", gr.update(visible=False)

    if not notebook_id:
        return _rejected("Select a notebook first.")
    if not (query and query.strip()):
        return _rejected("Enter a message.")
    user_id = _user_id(profile)
    if not user_id:
        return _rejected("Please sign in first.")
    try:
        _answer, new_history, chunks = rag_chat(notebook_id, query.strip(), chat_history, user_id=user_id)
        citations_text = _format_citations(chunks)
        # The accordion is shown only when at least one chunk was retrieved.
        return "", new_history, "", citations_text, gr.update(visible=bool(chunks))
    except Exception as e:
        return _rejected(f"Error: {e}")
699
 
700
  def _get_quiz_pdfs(source_type, notebook_id):
701
  if source_type != "PDF":
 
778
  selected_notebook_id = gr.State(None)
779
  chat_history_state = gr.State([])
780
  quiz_state = gr.State([])
781
+ user_id_state = gr.State(None)
782
 
783
  with gr.Group(elem_classes=["section-card", "manager-card"]):
784
  gr.Markdown("**Notebook Manager**", elem_classes=["section-title"])
 
915
  )
916
  chat_submit_btn = gr.Button("Send", variant="primary")
917
  chat_status = gr.Markdown("", elem_classes=["status"])
918
+ citations_display = gr.Accordion("📎 Sources cited (from last response)", open=True, visible=False)
919
+ with citations_display:
920
+ citations_md = gr.Markdown("", elem_classes=["status"])
921
 
922
  with gr.Group(elem_classes=["section-card", "artifacts-card"]):
923
  gr.Markdown("**Artifacts**", elem_classes=["section-title"])
 
972
  demo.load(
973
  _initial_load,
974
  inputs=None,
975
+ outputs=[nb_state, selected_notebook_id, notebook_status, auth_text, auth_info_row, app_content, login_container, source_status, user_id_state],
976
  api_name=False,
977
  )
978
  demo.load(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
979
  demo.load(_load_sources, inputs=[selected_notebook_id], outputs=[sources_display], api_name=False)
980
  demo.load(_selected_notebook_text, inputs=[selected_notebook_id, nb_state], outputs=[selected_notebook_md], api_name=False)
981
 
982
def _on_notebook_select_for_chat(notebook_id, user_id):
    """Reload chat for the newly selected notebook and reset the citations panel.

    Returns the history twice (for the chat State and the Chatbot), an empty
    citations string, and an update that hides the citations accordion.
    """
    history = _load_chat_history(notebook_id, user_id)[0]
    hide_citations = gr.update(visible=False)
    return history, history, "", hide_citations
985
 
986
  selected_notebook_id.change(
987
  _on_notebook_select_for_chat,
988
+ inputs=[selected_notebook_id, user_id_state],
989
+ outputs=[chat_history_state, chatbot, citations_md, citations_display],
990
  api_name=False,
991
  )
992
  selected_notebook_id.change(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
 
1104
  chat_submit_btn.click(
1105
  _on_chat_submit,
1106
  inputs=[chat_input, selected_notebook_id, chat_history_state],
1107
+ outputs=[chat_input, chat_history_state, chat_status, citations_md, citations_display],
1108
  api_name=False,
1109
  ).then(
1110
  lambda h: (h, h),
backend/chat_service.py CHANGED
@@ -3,8 +3,28 @@
3
  from backend.db import supabase
4
 
5
 
6
- def save_message(notebook_id: str, role: str, content: str) -> None:
7
- """Append a message to the messages table."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  supabase.table("messages").insert({
9
  "notebook_id": notebook_id,
10
  "role": role,
@@ -12,8 +32,10 @@ def save_message(notebook_id: str, role: str, content: str) -> None:
12
  }).execute()
13
 
14
 
15
- def load_chat(notebook_id: str) -> list[dict]:
16
- """Load chat history. Returns [{role, content, created_at}, ...]."""
 
 
17
  result = (
18
  supabase.table("messages")
19
  .select("role, content, created_at")
 
3
  from backend.db import supabase
4
 
5
 
6
def _notebook_belongs_to_user(notebook_id: str, user_id: str | None) -> bool:
    """Check that a ``notebooks`` row exists with this id and this ``user_id``.

    Returns False when either argument is empty, when no matching row is
    found, or when the lookup itself fails. A failed lookup is treated as
    "not owned" (fail closed) but is logged, so that a database outage can be
    told apart from a genuine ownership mismatch.
    """
    import logging

    if not user_id or not notebook_id:
        return False
    try:
        result = (
            supabase.table("notebooks")
            .select("id")
            .eq("id", notebook_id)
            .eq("user_id", user_id)
            .limit(1)
            .execute()
        )
        return len(result.data or []) > 0
    except Exception:
        # Deliberately fail closed, but leave a trace instead of swallowing.
        logging.getLogger(__name__).exception(
            "Ownership check failed for notebook %s", notebook_id
        )
        return False
22
+
23
+
24
+ def save_message(notebook_id: str, user_id: str | None, role: str, content: str) -> None:
25
+ """Append a message only if the notebook belongs to the user."""
26
+ if not _notebook_belongs_to_user(notebook_id, user_id):
27
+ return
28
  supabase.table("messages").insert({
29
  "notebook_id": notebook_id,
30
  "role": role,
 
32
  }).execute()
33
 
34
 
35
+ def load_chat(notebook_id: str, user_id: str | None) -> list[dict]:
36
+ """Load chat history only if the notebook belongs to the user. Returns [] if not owned."""
37
+ if not _notebook_belongs_to_user(notebook_id, user_id):
38
+ return []
39
  result = (
40
  supabase.table("messages")
41
  .select("role, content, created_at")
backend/chunking.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Shared chunking utilities for RAG ingestion."""

import re

DEFAULT_CHUNK_SIZE = 512
DEFAULT_CHUNK_OVERLAP = 80
MIN_CHUNK_SIZE = 100


def _split_into_sentences(text: str) -> list[str]:
    """Split text on sentence boundaries (rough heuristic).

    A boundary is any whitespace run that follows ``.``, ``!`` or ``?``.
    Returns the non-empty, stripped pieces; empty input yields ``[]``.
    """
    text = re.sub(r"\n+", "\n", text.strip())
    if not text:
        return []
    parts = re.split(r"(?<=[.!?])\s+", text)
    return [p.strip() for p in parts if p.strip()]


def _split_into_units(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Break raw text into packing units: paragraphs -> sentences -> windows.

    Paragraphs (blank-line separated) are split on the RAW text, before any
    whitespace normalisation, so a paragraph break is always a unit boundary.
    Each sentence has its internal whitespace collapsed. A sentence longer
    than twice ``chunk_size`` is cut into character windows.
    """
    units: list[str] = []
    for paragraph in re.split(r"\n\s*\n", text):
        for sentence in _split_into_sentences(paragraph):
            sentence = " ".join(sentence.split())
            if len(sentence) > chunk_size * 2:
                units.extend(chunk_text_fallback(sentence, chunk_size, overlap))
            else:
                units.append(sentence)
    return units


def chunk_text_semantic(
    text: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> list[str]:
    """
    Semantic chunking: split on paragraphs first, then sentences.
    Preserves context better than blind character splits.

    Units are packed greedily into chunks of roughly ``chunk_size`` characters.
    When a chunk is closed, trailing units that fit within ``overlap``
    characters are repeated at the start of the next chunk.

    No text is ever dropped: a pending chunk shorter than ``MIN_CHUNK_SIZE``
    (or ``chunk_size`` if that is smaller) is merged into the following chunk
    instead of being emitted on its own, so a chunk may exceed ``chunk_size``.

    Args:
        text: Raw text; blank lines mark paragraph breaks.
        chunk_size: Target chunk length in characters.
        overlap: Maximum number of characters carried over between chunks.

    Returns:
        List of chunk strings in document order; ``[]`` for blank input.
    """
    if not text or not text.strip():
        return []

    units = _split_into_units(text, chunk_size, overlap)
    min_len = min(MIN_CHUNK_SIZE, chunk_size)

    chunks: list[str] = []
    current: list[str] = []
    current_len = 0  # sum of len(unit) + 1 over `current`

    for unit in units:
        unit_len = len(unit) + 1
        # current_len - 1 is the length of " ".join(current).
        if current and current_len + unit_len > chunk_size and current_len - 1 >= min_len:
            chunks.append(" ".join(current))
            carried: list[str] = []
            carried_len = 0
            # Never carry the first unit of the chunk just emitted: that would
            # re-emit the whole chunk and the window would stop advancing.
            for item in reversed(current[1:]):
                if carried_len + len(item) + 1 > overlap:
                    break
                carried.insert(0, item)
                carried_len += len(item) + 1
            current = carried
            current_len = carried_len
        current.append(unit)
        current_len += unit_len

    if current:
        chunks.append(" ".join(current))
    return chunks


def chunk_text_fallback(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Character-based chunking when semantic splitting fails.

    Whitespace is collapsed, then fixed-size windows of ``chunk_size``
    characters are taken every ``chunk_size - overlap`` characters (at least
    1). Windowing stops once a window reaches the end of the text, so the last
    chunk is never a mere suffix of the previous one.
    """
    clean = " ".join(text.split())
    if not clean:
        return []
    step = max(1, chunk_size - overlap)
    chunks = []
    for start in range(0, len(clean), step):
        chunks.append(clean[start:start + chunk_size])
        if start + chunk_size >= len(clean):
            # Text end reached; a further window would only repeat this tail.
            break
    return chunks
backend/embedding_service.py CHANGED
@@ -1,10 +1,16 @@
1
- """Shared embedding service - 384-dim vectors for RAG (ingestion + retrieval). Uses MiniLM for low memory."""
 
 
2
 
3
  from sentence_transformers import SentenceTransformer
4
 
5
- _MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
6
  _model = None
7
 
 
 
 
8
 
9
  def _get_model() -> SentenceTransformer:
10
  """Lazy-load the embedding model."""
@@ -14,17 +20,23 @@ def _get_model() -> SentenceTransformer:
14
  return _model
15
 
16
 
 
 
 
 
17
  def encode(texts: list[str], task: str = "search_document") -> list[list[float]]:
18
  """
19
  Embed texts. Returns list of 384-dim vectors.
20
 
21
  Args:
22
  texts: List of strings to embed.
23
- task: Unused (MiniLM doesn't need prefix); kept for API compatibility.
24
  """
25
  if not texts:
26
  return []
27
 
28
  model = _get_model()
 
 
29
  embeddings = model.encode(texts, show_progress_bar=False)
30
  return [e.tolist() for e in embeddings]
 
1
+ """Shared embedding service - 384-dim vectors for RAG (ingestion + retrieval)."""
2
+
3
+ import os
4
 
5
  from sentence_transformers import SentenceTransformer
6
 
7
+ # all-MiniLM-L6-v2 (default) or BAAI/bge-small-en-v1.5 for better quality (both 384 dims)
8
+ _MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
9
  _model = None
10
 
11
+ # BGE models: add prefix only to queries, not to documents
12
+ _BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "
13
+
14
 
15
  def _get_model() -> SentenceTransformer:
16
  """Lazy-load the embedding model."""
 
20
  return _model
21
 
22
 
23
def _is_bge_model() -> bool:
    """Report whether the configured model name contains "bge" (case-insensitive)."""
    lowered_name = _MODEL_NAME.lower()
    return "bge" in lowered_name
25
+
26
+
27
def encode(texts: list[str], task: str = "search_document") -> list[list[float]]:
    """
    Embed texts and return one vector (as a plain list of floats) per input.

    Args:
        texts: List of strings to embed. An empty list returns [] without
            loading the model.
        task: "search_query" for queries, "search_document" for documents.
            Only BGE models are affected: their queries get the BGE query
            prefix prepended, documents are embedded as-is.
    """
    if not texts:
        return []

    model = _get_model()
    needs_prefix = _is_bge_model() and task == "search_query"
    inputs = [_BGE_QUERY_PREFIX + text for text in texts] if needs_prefix else texts
    vectors = model.encode(inputs, show_progress_bar=False)
    return [vector.tolist() for vector in vectors]
backend/ingestion_service.py CHANGED
@@ -4,14 +4,15 @@ from pathlib import Path
4
 
5
  from pypdf import PdfReader
6
 
 
7
  from backend.db import supabase
8
  from backend.embedding_service import encode as embed_texts
9
 
10
  import requests
11
  from bs4 import BeautifulSoup
12
 
13
- DEFAULT_CHUNK_SIZE = 1200
14
- DEFAULT_CHUNK_OVERLAP = 200
15
 
16
 
17
  def _extract_pdf_text(pdf_path: Path) -> str:
@@ -23,20 +24,8 @@ def _extract_pdf_text(pdf_path: Path) -> str:
23
 
24
 
25
  def _chunk_text(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE, overlap: int = DEFAULT_CHUNK_OVERLAP) -> list[str]:
26
- clean = " ".join(text.split())
27
- if not clean:
28
- return []
29
-
30
- chunks: list[str] = []
31
- start = 0
32
- step = max(1, chunk_size - overlap)
33
-
34
- while start < len(clean):
35
- end = min(len(clean), start + chunk_size)
36
- chunks.append(clean[start:end])
37
- start += step
38
-
39
- return chunks
40
 
41
 
42
  def ingest_pdf_chunks(notebook_id: str, source_id: str, pdf_path: Path) -> int:
 
4
 
5
  from pypdf import PdfReader
6
 
7
+ from backend.chunking import chunk_text_semantic, chunk_text_fallback
8
  from backend.db import supabase
9
  from backend.embedding_service import encode as embed_texts
10
 
11
  import requests
12
  from bs4 import BeautifulSoup
13
 
14
+ DEFAULT_CHUNK_SIZE = 512
15
+ DEFAULT_CHUNK_OVERLAP = 80
16
 
17
 
18
  def _extract_pdf_text(pdf_path: Path) -> str:
 
24
 
25
 
26
def _chunk_text(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE, overlap: int = DEFAULT_CHUNK_OVERLAP) -> list[str]:
    """Chunk ``text`` by delegating to the shared ``chunk_text_semantic`` helper."""
    pieces = chunk_text_semantic(text, chunk_size, overlap)
    return pieces
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def ingest_pdf_chunks(notebook_id: str, source_id: str, pdf_path: Path) -> int:
backend/ingestion_txt.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  from datetime import datetime
9
  from uuid import uuid4
10
 
 
11
  from backend.db import supabase
12
  from backend.storage import save_file, get_sources_path
13
 
@@ -82,35 +83,24 @@ def _create_source_record(
82
  "storage_path": storage_path,
83
  }).execute()
84
 
85
- # Chunking
86
  def chunk_text(text: str, source_id: str, notebook_id: str, filename: str = "") -> list[dict]:
87
- words = text.split()
88
- chunk_size = 400
89
- overlap = 40
90
  chunks = []
91
- i = 0
92
-
93
- # Calculate total chunks upfront
94
- total_chunks = max(1, (len(words) + chunk_size - overlap - 1) // (chunk_size - overlap))
95
-
96
- while i < len(words):
97
- chunk_words = words[i:i + chunk_size]
98
- content = " ".join(chunk_words)
99
  chunks.append({
100
  "id": str(uuid4()),
101
  "source_id": source_id,
102
  "notebook_id": notebook_id,
103
  "content": content,
104
- "chunk_index": len(chunks),
105
  "metadata": {
106
- "word_count": len(chunk_words),
107
  "file_name": filename,
108
- "chunk_index": len(chunks),
109
- "total_chunks": total_chunks,
110
  }
111
  })
112
- i += chunk_size - overlap
113
-
114
  return chunks
115
 
116
 
 
8
  from datetime import datetime
9
  from uuid import uuid4
10
 
11
+ from backend.chunking import chunk_text_semantic
12
  from backend.db import supabase
13
  from backend.storage import save_file, get_sources_path
14
 
 
83
  "storage_path": storage_path,
84
  }).execute()
85
 
86
+ # Chunking - use semantic chunking for better retrieval (aligned with PDF/URL)
87
def chunk_text(text: str, source_id: str, notebook_id: str, filename: str = "") -> list[dict]:
    """Split ``text`` semantically and wrap each piece as a chunk record.

    Every record gets a fresh UUID ``id``, the given ``source_id`` and
    ``notebook_id``, the piece as ``content``, its zero-based ``chunk_index``,
    and ``metadata`` holding the word count, file name, index and total count.
    """
    pieces = chunk_text_semantic(text, chunk_size=512, overlap=80)
    total = len(pieces)
    return [
        {
            "id": str(uuid4()),
            "source_id": source_id,
            "notebook_id": notebook_id,
            "content": piece,
            "chunk_index": position,
            "metadata": {
                "word_count": len(piece.split()),
                "file_name": filename,
                "chunk_index": position,
                "total_chunks": total,
            },
        }
        for position, piece in enumerate(pieces)
    ]
105
 
106
 
backend/rag_service.py CHANGED
@@ -25,14 +25,16 @@ def _validate_citations(text: str, num_chunks: int) -> str:
25
  return re.sub(r"\[(\d+)\]", replace_citation, text)
26
 
27
 
28
- def rag_chat(notebook_id: str, query: str, chat_history: list) -> tuple[str, list]:
29
  """
30
  RAG chat: retrieve chunks, build prompt, call LLM, persist, return answer and updated history.
31
 
32
  chat_history: list of [user_msg, assistant_msg] pairs (Gradio Chatbot format).
33
- Returns: (assistant_reply, updated_history).
 
 
34
  """
35
- save_message(notebook_id, "user", query)
36
 
37
  chunks = retrieve_chunks(notebook_id, query, top_k=TOP_K)
38
 
@@ -72,7 +74,7 @@ def rag_chat(notebook_id: str, query: str, chat_history: list) -> tuple[str, lis
72
  except Exception as e:
73
  answer = f"Error calling model: {e}"
74
 
75
- save_message(notebook_id, "assistant", answer)
76
 
77
  updated_history = chat_history + [[query, answer]]
78
- return answer, updated_history
 
25
  return re.sub(r"\[(\d+)\]", replace_citation, text)
26
 
27
 
28
+ def rag_chat(notebook_id: str, query: str, chat_history: list, user_id: str | None = None) -> tuple[str, list, list[dict]]:
29
  """
30
  RAG chat: retrieve chunks, build prompt, call LLM, persist, return answer and updated history.
31
 
32
  chat_history: list of [user_msg, assistant_msg] pairs (Gradio Chatbot format).
33
+ user_id: for ownership validation; messages are only saved if notebook belongs to user.
34
+ Returns: (assistant_reply, updated_history, chunks).
35
+ chunks: list of dicts with id, content, metadata, similarity for citation display.
36
  """
37
+ save_message(notebook_id, user_id, "user", query)
38
 
39
  chunks = retrieve_chunks(notebook_id, query, top_k=TOP_K)
40
 
 
74
  except Exception as e:
75
  answer = f"Error calling model: {e}"
76
 
77
+ save_message(notebook_id, user_id, "assistant", answer)
78
 
79
  updated_history = chat_history + [[query, answer]]
80
+ return answer, updated_history, chunks
backend/retrieval_service.py CHANGED
@@ -1,32 +1,72 @@
1
- """Retrieval service - vector similarity search for RAG."""
 
 
2
 
3
  from backend.db import supabase
4
  from backend.embedding_service import encode
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def retrieve_chunks(notebook_id: str, query: str, top_k: int = 5) -> list[dict]:
 
 
 
 
 
 
 
 
8
  """
9
  Retrieve top-k chunks for a query, filtered by notebook_id.
 
10
 
11
  Returns list of dicts with keys: id, content, metadata, similarity.
12
  """
13
  if not query or not query.strip():
14
  return []
15
 
16
- query_embedding = encode([query.strip()], task="search_query")[0]
 
 
 
17
 
18
  try:
19
  result = supabase.rpc(
20
  "match_chunks",
21
  {
22
  "query_embedding": query_embedding,
23
- "match_count": top_k,
24
  "p_notebook_id": notebook_id,
25
  },
26
  ).execute()
27
 
28
  rows = result.data or []
29
- return [
30
  {
31
  "id": str(r["id"]),
32
  "content": r["content"],
@@ -35,5 +75,16 @@ def retrieve_chunks(notebook_id: str, query: str, top_k: int = 5) -> list[dict]:
35
  }
36
  for r in rows
37
  ]
 
 
 
 
 
 
 
 
 
 
 
38
  except Exception:
39
  return []
 
1
+ """Retrieval service - vector similarity search for RAG with optional reranking."""
2
+
3
+ import os
4
 
5
  from backend.db import supabase
6
  from backend.embedding_service import encode
7
 
8
+ # Retrieve more candidates for reranking; final count after rerank/filter
9
+ RETRIEVE_TOP_K = int(os.getenv("RETRIEVE_TOP_K", "12"))
10
+ FINAL_TOP_K = int(os.getenv("FINAL_TOP_K", "5"))
11
+ SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.2"))
12
+ USE_RERANKER = os.getenv("USE_RERANKER", "true").lower() in ("true", "1", "yes")
13
+
14
+ _reranker = None
15
+
16
+
17
def _get_reranker():
    """Lazy-load the cross-encoder reranker.

    The model is created on first use and cached in the module-level
    ``_reranker``. If importing or constructing it fails, the failure is
    logged once, ``_reranker`` is set to ``False`` so no further attempts are
    made, and ``None`` is returned.

    Returns:
        The cached ``CrossEncoder`` instance, or ``None`` when reranking is
        unavailable.
    """
    global _reranker
    if _reranker is None:
        try:
            from sentence_transformers import CrossEncoder
            _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
        except Exception:
            import logging

            # Best-effort feature: keep running without it, but say why.
            logging.getLogger(__name__).warning(
                "Cross-encoder reranker unavailable; reranking disabled.",
                exc_info=True,
            )
            _reranker = False  # Disabled on failure
    return _reranker if _reranker else None
27
+
28
+
29
def _rerank_chunks(query: str, chunks: list[dict], top_k: int) -> list[dict]:
    """Order chunks by cross-encoder score for ``query`` and keep the best ``top_k``.

    When no reranker is available or ``chunks`` is empty, the first ``top_k``
    chunks are returned in their incoming order.
    """
    reranker = _get_reranker()
    if not (reranker and chunks):
        return chunks[:top_k]

    scores = reranker.predict([(query, chunk["content"]) for chunk in chunks])
    # Highest score first; the sort is stable, so ties keep their incoming order.
    ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
    return [chunk for chunk, _score in ranked[:top_k]]
41
+
42
+
43
+ def retrieve_chunks(notebook_id: str, query: str, top_k: int = None) -> list[dict]:
44
  """
45
  Retrieve top-k chunks for a query, filtered by notebook_id.
46
+ Uses two-stage retrieval: vector search -> optional rerank -> similarity filter.
47
 
48
  Returns list of dicts with keys: id, content, metadata, similarity.
49
  """
50
  if not query or not query.strip():
51
  return []
52
 
53
+ top_k = top_k or FINAL_TOP_K
54
+ query_clean = query.strip()
55
+
56
+ query_embedding = encode([query_clean], task="search_query")[0]
57
 
58
  try:
59
  result = supabase.rpc(
60
  "match_chunks",
61
  {
62
  "query_embedding": query_embedding,
63
+ "match_count": RETRIEVE_TOP_K,
64
  "p_notebook_id": notebook_id,
65
  },
66
  ).execute()
67
 
68
  rows = result.data or []
69
+ chunks = [
70
  {
71
  "id": str(r["id"]),
72
  "content": r["content"],
 
75
  }
76
  for r in rows
77
  ]
78
+
79
+ # Filter by similarity threshold
80
+ chunks = [c for c in chunks if c["similarity"] >= SIMILARITY_THRESHOLD]
81
+
82
+ # Rerank for better precision
83
+ if USE_RERANKER and len(chunks) > top_k:
84
+ chunks = _rerank_chunks(query_clean, chunks, top_k)
85
+ else:
86
+ chunks = chunks[:top_k]
87
+
88
+ return chunks
89
  except Exception:
90
  return []