Spaces:
Sleeping
Sleeping
Merge pull request #15 from tAnboyy/feature/citations-retrieval-isolation
Browse files- app.py +44 -17
- backend/chat_service.py +26 -4
- backend/chunking.py +81 -0
- backend/embedding_service.py +15 -3
- backend/ingestion_service.py +5 -16
- backend/ingestion_txt.py +8 -18
- backend/rag_service.py +7 -5
- backend/retrieval_service.py +56 -5
app.py
CHANGED
|
@@ -256,6 +256,7 @@ def _initial_load(profile: gr.OAuthProfile | None = None):
|
|
| 256 |
gr.update(visible=bool(user_id)),
|
| 257 |
gr.update(visible=not bool(user_id)),
|
| 258 |
source_status,
|
|
|
|
| 259 |
)
|
| 260 |
|
| 261 |
|
|
@@ -650,29 +651,51 @@ def _chat_history_to_pairs(messages: list[dict]) -> list[tuple[str, str]]:
|
|
| 650 |
return pairs
|
| 651 |
|
| 652 |
|
| 653 |
-
def _load_chat_history(notebook_id) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
| 654 |
-
"""Load chat for notebook. Returns (history_pairs, history_pairs) for State and Chatbot.
|
|
|
|
|
|
|
| 655 |
if not notebook_id:
|
| 656 |
return [], []
|
| 657 |
-
messages = load_chat(notebook_id)
|
| 658 |
pairs = _chat_history_to_pairs(messages)
|
| 659 |
return pairs, pairs
|
| 660 |
|
| 661 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
def _on_chat_submit(query, notebook_id, chat_history, profile: gr.OAuthProfile | None):
|
| 663 |
-
"""Handle chat submit: call RAG, return updated history."""
|
| 664 |
if not notebook_id:
|
| 665 |
-
return "", chat_history, "Select a notebook first."
|
| 666 |
if not query or not query.strip():
|
| 667 |
-
return "", chat_history, "Enter a message."
|
| 668 |
user_id = _user_id(profile)
|
| 669 |
if not user_id:
|
| 670 |
-
return "", chat_history, "Please sign in first."
|
| 671 |
try:
|
| 672 |
-
answer, updated = rag_chat(notebook_id, query.strip(), chat_history)
|
| 673 |
-
|
|
|
|
|
|
|
| 674 |
except Exception as e:
|
| 675 |
-
return "", chat_history, f"Error: {e}"
|
| 676 |
|
| 677 |
def _get_quiz_pdfs(source_type, notebook_id):
|
| 678 |
if source_type != "PDF":
|
|
@@ -755,6 +778,7 @@ with gr.Blocks(
|
|
| 755 |
selected_notebook_id = gr.State(None)
|
| 756 |
chat_history_state = gr.State([])
|
| 757 |
quiz_state = gr.State([])
|
|
|
|
| 758 |
|
| 759 |
with gr.Group(elem_classes=["section-card", "manager-card"]):
|
| 760 |
gr.Markdown("**Notebook Manager**", elem_classes=["section-title"])
|
|
@@ -891,6 +915,9 @@ with gr.Blocks(
|
|
| 891 |
)
|
| 892 |
chat_submit_btn = gr.Button("Send", variant="primary")
|
| 893 |
chat_status = gr.Markdown("", elem_classes=["status"])
|
|
|
|
|
|
|
|
|
|
| 894 |
|
| 895 |
with gr.Group(elem_classes=["section-card", "artifacts-card"]):
|
| 896 |
gr.Markdown("**Artifacts**", elem_classes=["section-title"])
|
|
@@ -945,21 +972,21 @@ with gr.Blocks(
|
|
| 945 |
demo.load(
|
| 946 |
_initial_load,
|
| 947 |
inputs=None,
|
| 948 |
-
outputs=[nb_state, selected_notebook_id, notebook_status, auth_text, auth_info_row, app_content, login_container, source_status],
|
| 949 |
api_name=False,
|
| 950 |
)
|
| 951 |
demo.load(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
|
| 952 |
demo.load(_load_sources, inputs=[selected_notebook_id], outputs=[sources_display], api_name=False)
|
| 953 |
demo.load(_selected_notebook_text, inputs=[selected_notebook_id, nb_state], outputs=[selected_notebook_md], api_name=False)
|
| 954 |
|
| 955 |
-
def _on_notebook_select_for_chat(notebook_id):
|
| 956 |
-
hist, _ = _load_chat_history(notebook_id)
|
| 957 |
-
return hist, hist
|
| 958 |
|
| 959 |
selected_notebook_id.change(
|
| 960 |
_on_notebook_select_for_chat,
|
| 961 |
-
inputs=[selected_notebook_id],
|
| 962 |
-
outputs=[chat_history_state, chatbot],
|
| 963 |
api_name=False,
|
| 964 |
)
|
| 965 |
selected_notebook_id.change(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
|
|
@@ -1077,7 +1104,7 @@ with gr.Blocks(
|
|
| 1077 |
chat_submit_btn.click(
|
| 1078 |
_on_chat_submit,
|
| 1079 |
inputs=[chat_input, selected_notebook_id, chat_history_state],
|
| 1080 |
-
outputs=[chat_input, chat_history_state, chat_status],
|
| 1081 |
api_name=False,
|
| 1082 |
).then(
|
| 1083 |
lambda h: (h, h),
|
|
|
|
| 256 |
gr.update(visible=bool(user_id)),
|
| 257 |
gr.update(visible=not bool(user_id)),
|
| 258 |
source_status,
|
| 259 |
+
user_id,
|
| 260 |
)
|
| 261 |
|
| 262 |
|
|
|
|
| 651 |
return pairs
|
| 652 |
|
| 653 |
|
| 654 |
+
def _load_chat_history(notebook_id, user_id: str | None) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
| 655 |
+
"""Load chat for notebook. Returns (history_pairs, history_pairs) for State and Chatbot.
|
| 656 |
+
Only loads if notebook belongs to user (ownership validation).
|
| 657 |
+
"""
|
| 658 |
if not notebook_id:
|
| 659 |
return [], []
|
| 660 |
+
messages = load_chat(notebook_id, user_id)
|
| 661 |
pairs = _chat_history_to_pairs(messages)
|
| 662 |
return pairs, pairs
|
| 663 |
|
| 664 |
|
| 665 |
+
def _format_citations(chunks: list[dict]) -> str:
|
| 666 |
+
"""Format retrieved chunks for citation display."""
|
| 667 |
+
if not chunks:
|
| 668 |
+
return ""
|
| 669 |
+
lines = ["**Sources cited:**", ""]
|
| 670 |
+
for i, c in enumerate(chunks, 1):
|
| 671 |
+
meta = c.get("metadata") or {}
|
| 672 |
+
source_label = meta.get("url") or meta.get("file_name") or meta.get("file_path") or "Source"
|
| 673 |
+
content = (c.get("content") or "")[:300]
|
| 674 |
+
if len(c.get("content") or "") > 300:
|
| 675 |
+
content += "..."
|
| 676 |
+
content = content.replace("\n", " ")
|
| 677 |
+
lines.append(f"**[{i}]** *{source_label}*")
|
| 678 |
+
lines.append(f"> {content}")
|
| 679 |
+
lines.append("")
|
| 680 |
+
return "\n".join(lines)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
def _on_chat_submit(query, notebook_id, chat_history, profile: gr.OAuthProfile | None):
|
| 684 |
+
"""Handle chat submit: call RAG, return updated history and citations."""
|
| 685 |
if not notebook_id:
|
| 686 |
+
return "", chat_history, "Select a notebook first.", "", gr.update(visible=False)
|
| 687 |
if not query or not query.strip():
|
| 688 |
+
return "", chat_history, "Enter a message.", "", gr.update(visible=False)
|
| 689 |
user_id = _user_id(profile)
|
| 690 |
if not user_id:
|
| 691 |
+
return "", chat_history, "Please sign in first.", "", gr.update(visible=False)
|
| 692 |
try:
|
| 693 |
+
answer, updated, chunks = rag_chat(notebook_id, query.strip(), chat_history, user_id=user_id)
|
| 694 |
+
citations_md = _format_citations(chunks)
|
| 695 |
+
accordion_update = gr.update(visible=True) if chunks else gr.update(visible=False)
|
| 696 |
+
return "", updated, "", citations_md, accordion_update
|
| 697 |
except Exception as e:
|
| 698 |
+
return "", chat_history, f"Error: {e}", "", gr.update(visible=False)
|
| 699 |
|
| 700 |
def _get_quiz_pdfs(source_type, notebook_id):
|
| 701 |
if source_type != "PDF":
|
|
|
|
| 778 |
selected_notebook_id = gr.State(None)
|
| 779 |
chat_history_state = gr.State([])
|
| 780 |
quiz_state = gr.State([])
|
| 781 |
+
user_id_state = gr.State(None)
|
| 782 |
|
| 783 |
with gr.Group(elem_classes=["section-card", "manager-card"]):
|
| 784 |
gr.Markdown("**Notebook Manager**", elem_classes=["section-title"])
|
|
|
|
| 915 |
)
|
| 916 |
chat_submit_btn = gr.Button("Send", variant="primary")
|
| 917 |
chat_status = gr.Markdown("", elem_classes=["status"])
|
| 918 |
+
citations_display = gr.Accordion("📎 Sources cited (from last response)", open=True, visible=False)
|
| 919 |
+
with citations_display:
|
| 920 |
+
citations_md = gr.Markdown("", elem_classes=["status"])
|
| 921 |
|
| 922 |
with gr.Group(elem_classes=["section-card", "artifacts-card"]):
|
| 923 |
gr.Markdown("**Artifacts**", elem_classes=["section-title"])
|
|
|
|
| 972 |
demo.load(
|
| 973 |
_initial_load,
|
| 974 |
inputs=None,
|
| 975 |
+
outputs=[nb_state, selected_notebook_id, notebook_status, auth_text, auth_info_row, app_content, login_container, source_status, user_id_state],
|
| 976 |
api_name=False,
|
| 977 |
)
|
| 978 |
demo.load(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
|
| 979 |
demo.load(_load_sources, inputs=[selected_notebook_id], outputs=[sources_display], api_name=False)
|
| 980 |
demo.load(_selected_notebook_text, inputs=[selected_notebook_id, nb_state], outputs=[selected_notebook_md], api_name=False)
|
| 981 |
|
| 982 |
+
def _on_notebook_select_for_chat(notebook_id, user_id):
|
| 983 |
+
hist, _ = _load_chat_history(notebook_id, user_id)
|
| 984 |
+
return hist, hist, "", gr.update(visible=False)
|
| 985 |
|
| 986 |
selected_notebook_id.change(
|
| 987 |
_on_notebook_select_for_chat,
|
| 988 |
+
inputs=[selected_notebook_id, user_id_state],
|
| 989 |
+
outputs=[chat_history_state, chatbot, citations_md, citations_display],
|
| 990 |
api_name=False,
|
| 991 |
)
|
| 992 |
selected_notebook_id.change(_list_uploaded_pdfs, inputs=[selected_notebook_id], outputs=[uploaded_pdf_dd], api_name=False)
|
|
|
|
| 1104 |
chat_submit_btn.click(
|
| 1105 |
_on_chat_submit,
|
| 1106 |
inputs=[chat_input, selected_notebook_id, chat_history_state],
|
| 1107 |
+
outputs=[chat_input, chat_history_state, chat_status, citations_md, citations_display],
|
| 1108 |
api_name=False,
|
| 1109 |
).then(
|
| 1110 |
lambda h: (h, h),
|
backend/chat_service.py
CHANGED
|
@@ -3,8 +3,28 @@
|
|
| 3 |
from backend.db import supabase
|
| 4 |
|
| 5 |
|
| 6 |
-
def
|
| 7 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
supabase.table("messages").insert({
|
| 9 |
"notebook_id": notebook_id,
|
| 10 |
"role": role,
|
|
@@ -12,8 +32,10 @@ def save_message(notebook_id: str, role: str, content: str) -> None:
|
|
| 12 |
}).execute()
|
| 13 |
|
| 14 |
|
| 15 |
-
def load_chat(notebook_id: str) -> list[dict]:
|
| 16 |
-
"""Load chat history. Returns [
|
|
|
|
|
|
|
| 17 |
result = (
|
| 18 |
supabase.table("messages")
|
| 19 |
.select("role, content, created_at")
|
|
|
|
| 3 |
from backend.db import supabase
|
| 4 |
|
| 5 |
|
| 6 |
+
def _notebook_belongs_to_user(notebook_id: str, user_id: str | None) -> bool:
|
| 7 |
+
"""Verify the notebook is owned by the user. Returns False if user_id is None."""
|
| 8 |
+
if not user_id or not notebook_id:
|
| 9 |
+
return False
|
| 10 |
+
try:
|
| 11 |
+
result = (
|
| 12 |
+
supabase.table("notebooks")
|
| 13 |
+
.select("id")
|
| 14 |
+
.eq("id", notebook_id)
|
| 15 |
+
.eq("user_id", user_id)
|
| 16 |
+
.limit(1)
|
| 17 |
+
.execute()
|
| 18 |
+
)
|
| 19 |
+
return len(result.data or []) > 0
|
| 20 |
+
except Exception:
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def save_message(notebook_id: str, user_id: str | None, role: str, content: str) -> None:
|
| 25 |
+
"""Append a message only if the notebook belongs to the user."""
|
| 26 |
+
if not _notebook_belongs_to_user(notebook_id, user_id):
|
| 27 |
+
return
|
| 28 |
supabase.table("messages").insert({
|
| 29 |
"notebook_id": notebook_id,
|
| 30 |
"role": role,
|
|
|
|
| 32 |
}).execute()
|
| 33 |
|
| 34 |
|
| 35 |
+
def load_chat(notebook_id: str, user_id: str | None) -> list[dict]:
|
| 36 |
+
"""Load chat history only if the notebook belongs to the user. Returns [] if not owned."""
|
| 37 |
+
if not _notebook_belongs_to_user(notebook_id, user_id):
|
| 38 |
+
return []
|
| 39 |
result = (
|
| 40 |
supabase.table("messages")
|
| 41 |
.select("role, content, created_at")
|
backend/chunking.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared chunking utilities for RAG ingestion."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
DEFAULT_CHUNK_SIZE = 512
|
| 6 |
+
DEFAULT_CHUNK_OVERLAP = 80
|
| 7 |
+
MIN_CHUNK_SIZE = 100
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _split_into_sentences(text: str) -> list[str]:
|
| 11 |
+
"""Split text on sentence boundaries (rough heuristic)."""
|
| 12 |
+
text = re.sub(r"\n+", "\n", text.strip())
|
| 13 |
+
if not text:
|
| 14 |
+
return []
|
| 15 |
+
parts = re.split(r"(?<=[.!?])\s+", text)
|
| 16 |
+
return [p.strip() for p in parts if p.strip()]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def chunk_text_semantic(
|
| 20 |
+
text: str,
|
| 21 |
+
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
| 22 |
+
overlap: int = DEFAULT_CHUNK_OVERLAP,
|
| 23 |
+
) -> list[str]:
|
| 24 |
+
"""
|
| 25 |
+
Semantic chunking: split on paragraphs first, then sentences.
|
| 26 |
+
Preserves context better than blind character splits.
|
| 27 |
+
"""
|
| 28 |
+
text = " ".join(text.split())
|
| 29 |
+
if not text:
|
| 30 |
+
return []
|
| 31 |
+
|
| 32 |
+
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
|
| 33 |
+
if len(paragraphs) <= 1:
|
| 34 |
+
sentences = _split_into_sentences(text)
|
| 35 |
+
if not sentences:
|
| 36 |
+
sentences = [text]
|
| 37 |
+
if len(sentences) == 1 and len(sentences[0]) > chunk_size * 2:
|
| 38 |
+
return chunk_text_fallback(text, chunk_size, overlap)
|
| 39 |
+
paragraphs = sentences
|
| 40 |
+
|
| 41 |
+
chunks = []
|
| 42 |
+
current_chunk = []
|
| 43 |
+
current_len = 0
|
| 44 |
+
|
| 45 |
+
for para in paragraphs:
|
| 46 |
+
para_len = len(para) + 1
|
| 47 |
+
if current_len + para_len > chunk_size and current_chunk:
|
| 48 |
+
chunk_text = " ".join(current_chunk)
|
| 49 |
+
if len(chunk_text) >= MIN_CHUNK_SIZE:
|
| 50 |
+
chunks.append(chunk_text)
|
| 51 |
+
overlap_len = 0
|
| 52 |
+
overlap_items = []
|
| 53 |
+
for item in reversed(current_chunk):
|
| 54 |
+
if overlap_len + len(item) + 1 <= overlap:
|
| 55 |
+
overlap_items.insert(0, item)
|
| 56 |
+
overlap_len += len(item) + 1
|
| 57 |
+
else:
|
| 58 |
+
break
|
| 59 |
+
current_chunk = overlap_items
|
| 60 |
+
current_len = overlap_len
|
| 61 |
+
current_chunk.append(para)
|
| 62 |
+
current_len += para_len
|
| 63 |
+
|
| 64 |
+
if current_chunk:
|
| 65 |
+
chunks.append(" ".join(current_chunk))
|
| 66 |
+
return chunks
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def chunk_text_fallback(text: str, chunk_size: int, overlap: int) -> list[str]:
|
| 70 |
+
"""Character-based chunking when semantic splitting fails."""
|
| 71 |
+
clean = " ".join(text.split())
|
| 72 |
+
if not clean:
|
| 73 |
+
return []
|
| 74 |
+
chunks = []
|
| 75 |
+
start = 0
|
| 76 |
+
step = max(1, chunk_size - overlap)
|
| 77 |
+
while start < len(clean):
|
| 78 |
+
end = min(len(clean), start + chunk_size)
|
| 79 |
+
chunks.append(clean[start:end])
|
| 80 |
+
start += step
|
| 81 |
+
return chunks
|
backend/embedding_service.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
| 1 |
-
"""Shared embedding service - 384-dim vectors for RAG (ingestion + retrieval).
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
_model = None
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def _get_model() -> SentenceTransformer:
|
| 10 |
"""Lazy-load the embedding model."""
|
|
@@ -14,17 +20,23 @@ def _get_model() -> SentenceTransformer:
|
|
| 14 |
return _model
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def encode(texts: list[str], task: str = "search_document") -> list[list[float]]:
|
| 18 |
"""
|
| 19 |
Embed texts. Returns list of 384-dim vectors.
|
| 20 |
|
| 21 |
Args:
|
| 22 |
texts: List of strings to embed.
|
| 23 |
-
task:
|
| 24 |
"""
|
| 25 |
if not texts:
|
| 26 |
return []
|
| 27 |
|
| 28 |
model = _get_model()
|
|
|
|
|
|
|
| 29 |
embeddings = model.encode(texts, show_progress_bar=False)
|
| 30 |
return [e.tolist() for e in embeddings]
|
|
|
|
| 1 |
+
"""Shared embedding service - 384-dim vectors for RAG (ingestion + retrieval)."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
|
| 7 |
+
# all-MiniLM-L6-v2 (default) or BAAI/bge-small-en-v1.5 for better quality (both 384 dims)
|
| 8 |
+
_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
| 9 |
_model = None
|
| 10 |
|
| 11 |
+
# BGE models: add prefix only to queries, not to documents
|
| 12 |
+
_BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "
|
| 13 |
+
|
| 14 |
|
| 15 |
def _get_model() -> SentenceTransformer:
|
| 16 |
"""Lazy-load the embedding model."""
|
|
|
|
| 20 |
return _model
|
| 21 |
|
| 22 |
|
| 23 |
+
def _is_bge_model() -> bool:
|
| 24 |
+
return "bge" in _MODEL_NAME.lower()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
def encode(texts: list[str], task: str = "search_document") -> list[list[float]]:
|
| 28 |
"""
|
| 29 |
Embed texts. Returns list of 384-dim vectors.
|
| 30 |
|
| 31 |
Args:
|
| 32 |
texts: List of strings to embed.
|
| 33 |
+
task: "search_query" for queries, "search_document" for documents. BGE uses prefixes.
|
| 34 |
"""
|
| 35 |
if not texts:
|
| 36 |
return []
|
| 37 |
|
| 38 |
model = _get_model()
|
| 39 |
+
if _is_bge_model() and task == "search_query":
|
| 40 |
+
texts = [_BGE_QUERY_PREFIX + t for t in texts]
|
| 41 |
embeddings = model.encode(texts, show_progress_bar=False)
|
| 42 |
return [e.tolist() for e in embeddings]
|
backend/ingestion_service.py
CHANGED
|
@@ -4,14 +4,15 @@ from pathlib import Path
|
|
| 4 |
|
| 5 |
from pypdf import PdfReader
|
| 6 |
|
|
|
|
| 7 |
from backend.db import supabase
|
| 8 |
from backend.embedding_service import encode as embed_texts
|
| 9 |
|
| 10 |
import requests
|
| 11 |
from bs4 import BeautifulSoup
|
| 12 |
|
| 13 |
-
DEFAULT_CHUNK_SIZE =
|
| 14 |
-
DEFAULT_CHUNK_OVERLAP =
|
| 15 |
|
| 16 |
|
| 17 |
def _extract_pdf_text(pdf_path: Path) -> str:
|
|
@@ -23,20 +24,8 @@ def _extract_pdf_text(pdf_path: Path) -> str:
|
|
| 23 |
|
| 24 |
|
| 25 |
def _chunk_text(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE, overlap: int = DEFAULT_CHUNK_OVERLAP) -> list[str]:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
return []
|
| 29 |
-
|
| 30 |
-
chunks: list[str] = []
|
| 31 |
-
start = 0
|
| 32 |
-
step = max(1, chunk_size - overlap)
|
| 33 |
-
|
| 34 |
-
while start < len(clean):
|
| 35 |
-
end = min(len(clean), start + chunk_size)
|
| 36 |
-
chunks.append(clean[start:end])
|
| 37 |
-
start += step
|
| 38 |
-
|
| 39 |
-
return chunks
|
| 40 |
|
| 41 |
|
| 42 |
def ingest_pdf_chunks(notebook_id: str, source_id: str, pdf_path: Path) -> int:
|
|
|
|
| 4 |
|
| 5 |
from pypdf import PdfReader
|
| 6 |
|
| 7 |
+
from backend.chunking import chunk_text_semantic, chunk_text_fallback
|
| 8 |
from backend.db import supabase
|
| 9 |
from backend.embedding_service import encode as embed_texts
|
| 10 |
|
| 11 |
import requests
|
| 12 |
from bs4 import BeautifulSoup
|
| 13 |
|
| 14 |
+
DEFAULT_CHUNK_SIZE = 512
|
| 15 |
+
DEFAULT_CHUNK_OVERLAP = 80
|
| 16 |
|
| 17 |
|
| 18 |
def _extract_pdf_text(pdf_path: Path) -> str:
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def _chunk_text(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE, overlap: int = DEFAULT_CHUNK_OVERLAP) -> list[str]:
|
| 27 |
+
"""Semantic chunking via shared utility."""
|
| 28 |
+
return chunk_text_semantic(text, chunk_size, overlap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def ingest_pdf_chunks(notebook_id: str, source_id: str, pdf_path: Path) -> int:
|
backend/ingestion_txt.py
CHANGED
|
@@ -8,6 +8,7 @@ import re
|
|
| 8 |
from datetime import datetime
|
| 9 |
from uuid import uuid4
|
| 10 |
|
|
|
|
| 11 |
from backend.db import supabase
|
| 12 |
from backend.storage import save_file, get_sources_path
|
| 13 |
|
|
@@ -82,35 +83,24 @@ def _create_source_record(
|
|
| 82 |
"storage_path": storage_path,
|
| 83 |
}).execute()
|
| 84 |
|
| 85 |
-
# Chunking
|
| 86 |
def chunk_text(text: str, source_id: str, notebook_id: str, filename: str = "") -> list[dict]:
|
| 87 |
-
|
| 88 |
-
chunk_size = 400
|
| 89 |
-
overlap = 40
|
| 90 |
chunks = []
|
| 91 |
-
i
|
| 92 |
-
|
| 93 |
-
# Calculate total chunks upfront
|
| 94 |
-
total_chunks = max(1, (len(words) + chunk_size - overlap - 1) // (chunk_size - overlap))
|
| 95 |
-
|
| 96 |
-
while i < len(words):
|
| 97 |
-
chunk_words = words[i:i + chunk_size]
|
| 98 |
-
content = " ".join(chunk_words)
|
| 99 |
chunks.append({
|
| 100 |
"id": str(uuid4()),
|
| 101 |
"source_id": source_id,
|
| 102 |
"notebook_id": notebook_id,
|
| 103 |
"content": content,
|
| 104 |
-
"chunk_index":
|
| 105 |
"metadata": {
|
| 106 |
-
"word_count": len(
|
| 107 |
"file_name": filename,
|
| 108 |
-
"chunk_index":
|
| 109 |
-
"total_chunks":
|
| 110 |
}
|
| 111 |
})
|
| 112 |
-
i += chunk_size - overlap
|
| 113 |
-
|
| 114 |
return chunks
|
| 115 |
|
| 116 |
|
|
|
|
| 8 |
from datetime import datetime
|
| 9 |
from uuid import uuid4
|
| 10 |
|
| 11 |
+
from backend.chunking import chunk_text_semantic
|
| 12 |
from backend.db import supabase
|
| 13 |
from backend.storage import save_file, get_sources_path
|
| 14 |
|
|
|
|
| 83 |
"storage_path": storage_path,
|
| 84 |
}).execute()
|
| 85 |
|
| 86 |
+
# Chunking - use semantic chunking for better retrieval (aligned with PDF/URL)
|
| 87 |
def chunk_text(text: str, source_id: str, notebook_id: str, filename: str = "") -> list[dict]:
|
| 88 |
+
content_chunks = chunk_text_semantic(text, chunk_size=512, overlap=80)
|
|
|
|
|
|
|
| 89 |
chunks = []
|
| 90 |
+
for i, content in enumerate(content_chunks):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
chunks.append({
|
| 92 |
"id": str(uuid4()),
|
| 93 |
"source_id": source_id,
|
| 94 |
"notebook_id": notebook_id,
|
| 95 |
"content": content,
|
| 96 |
+
"chunk_index": i,
|
| 97 |
"metadata": {
|
| 98 |
+
"word_count": len(content.split()),
|
| 99 |
"file_name": filename,
|
| 100 |
+
"chunk_index": i,
|
| 101 |
+
"total_chunks": len(content_chunks),
|
| 102 |
}
|
| 103 |
})
|
|
|
|
|
|
|
| 104 |
return chunks
|
| 105 |
|
| 106 |
|
backend/rag_service.py
CHANGED
|
@@ -25,14 +25,16 @@ def _validate_citations(text: str, num_chunks: int) -> str:
|
|
| 25 |
return re.sub(r"\[(\d+)\]", replace_citation, text)
|
| 26 |
|
| 27 |
|
| 28 |
-
def rag_chat(notebook_id: str, query: str, chat_history: list) -> tuple[str, list]:
|
| 29 |
"""
|
| 30 |
RAG chat: retrieve chunks, build prompt, call LLM, persist, return answer and updated history.
|
| 31 |
|
| 32 |
chat_history: list of [user_msg, assistant_msg] pairs (Gradio Chatbot format).
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
"""
|
| 35 |
-
save_message(notebook_id, "user", query)
|
| 36 |
|
| 37 |
chunks = retrieve_chunks(notebook_id, query, top_k=TOP_K)
|
| 38 |
|
|
@@ -72,7 +74,7 @@ def rag_chat(notebook_id: str, query: str, chat_history: list) -> tuple[str, lis
|
|
| 72 |
except Exception as e:
|
| 73 |
answer = f"Error calling model: {e}"
|
| 74 |
|
| 75 |
-
save_message(notebook_id, "assistant", answer)
|
| 76 |
|
| 77 |
updated_history = chat_history + [[query, answer]]
|
| 78 |
-
return answer, updated_history
|
|
|
|
| 25 |
return re.sub(r"\[(\d+)\]", replace_citation, text)
|
| 26 |
|
| 27 |
|
| 28 |
+
def rag_chat(notebook_id: str, query: str, chat_history: list, user_id: str | None = None) -> tuple[str, list, list[dict]]:
|
| 29 |
"""
|
| 30 |
RAG chat: retrieve chunks, build prompt, call LLM, persist, return answer and updated history.
|
| 31 |
|
| 32 |
chat_history: list of [user_msg, assistant_msg] pairs (Gradio Chatbot format).
|
| 33 |
+
user_id: for ownership validation; messages are only saved if notebook belongs to user.
|
| 34 |
+
Returns: (assistant_reply, updated_history, chunks).
|
| 35 |
+
chunks: list of dicts with id, content, metadata, similarity for citation display.
|
| 36 |
"""
|
| 37 |
+
save_message(notebook_id, user_id, "user", query)
|
| 38 |
|
| 39 |
chunks = retrieve_chunks(notebook_id, query, top_k=TOP_K)
|
| 40 |
|
|
|
|
| 74 |
except Exception as e:
|
| 75 |
answer = f"Error calling model: {e}"
|
| 76 |
|
| 77 |
+
save_message(notebook_id, user_id, "assistant", answer)
|
| 78 |
|
| 79 |
updated_history = chat_history + [[query, answer]]
|
| 80 |
+
return answer, updated_history, chunks
|
backend/retrieval_service.py
CHANGED
|
@@ -1,32 +1,72 @@
|
|
| 1 |
-
"""Retrieval service - vector similarity search for RAG."""
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from backend.db import supabase
|
| 4 |
from backend.embedding_service import encode
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
Retrieve top-k chunks for a query, filtered by notebook_id.
|
|
|
|
| 10 |
|
| 11 |
Returns list of dicts with keys: id, content, metadata, similarity.
|
| 12 |
"""
|
| 13 |
if not query or not query.strip():
|
| 14 |
return []
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
try:
|
| 19 |
result = supabase.rpc(
|
| 20 |
"match_chunks",
|
| 21 |
{
|
| 22 |
"query_embedding": query_embedding,
|
| 23 |
-
"match_count":
|
| 24 |
"p_notebook_id": notebook_id,
|
| 25 |
},
|
| 26 |
).execute()
|
| 27 |
|
| 28 |
rows = result.data or []
|
| 29 |
-
|
| 30 |
{
|
| 31 |
"id": str(r["id"]),
|
| 32 |
"content": r["content"],
|
|
@@ -35,5 +75,16 @@ def retrieve_chunks(notebook_id: str, query: str, top_k: int = 5) -> list[dict]:
|
|
| 35 |
}
|
| 36 |
for r in rows
|
| 37 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
except Exception:
|
| 39 |
return []
|
|
|
|
| 1 |
+
"""Retrieval service - vector similarity search for RAG with optional reranking."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
from backend.db import supabase
|
| 6 |
from backend.embedding_service import encode
|
| 7 |
|
| 8 |
+
# Retrieve more candidates for reranking; final count after rerank/filter
|
| 9 |
+
RETRIEVE_TOP_K = int(os.getenv("RETRIEVE_TOP_K", "12"))
|
| 10 |
+
FINAL_TOP_K = int(os.getenv("FINAL_TOP_K", "5"))
|
| 11 |
+
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.2"))
|
| 12 |
+
USE_RERANKER = os.getenv("USE_RERANKER", "true").lower() in ("true", "1", "yes")
|
| 13 |
+
|
| 14 |
+
_reranker = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _get_reranker():
|
| 18 |
+
"""Lazy-load cross-encoder reranker."""
|
| 19 |
+
global _reranker
|
| 20 |
+
if _reranker is None:
|
| 21 |
+
try:
|
| 22 |
+
from sentence_transformers import CrossEncoder
|
| 23 |
+
_reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
|
| 24 |
+
except Exception:
|
| 25 |
+
_reranker = False # Disabled on failure
|
| 26 |
+
return _reranker if _reranker else None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _rerank_chunks(query: str, chunks: list[dict], top_k: int) -> list[dict]:
|
| 30 |
+
"""Rerank chunks using cross-encoder; return top_k."""
|
| 31 |
+
model = _get_reranker()
|
| 32 |
+
if not model or not chunks:
|
| 33 |
+
return chunks[:top_k]
|
| 34 |
|
| 35 |
+
pairs = [(query, c["content"]) for c in chunks]
|
| 36 |
+
scores = model.predict(pairs)
|
| 37 |
+
scored = list(zip(chunks, scores))
|
| 38 |
+
scored.sort(key=lambda x: x[1], reverse=True)
|
| 39 |
+
reranked = [c for c, _ in scored[:top_k]]
|
| 40 |
+
return reranked
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def retrieve_chunks(notebook_id: str, query: str, top_k: int = None) -> list[dict]:
|
| 44 |
"""
|
| 45 |
Retrieve top-k chunks for a query, filtered by notebook_id.
|
| 46 |
+
Uses two-stage retrieval: vector search -> optional rerank -> similarity filter.
|
| 47 |
|
| 48 |
Returns list of dicts with keys: id, content, metadata, similarity.
|
| 49 |
"""
|
| 50 |
if not query or not query.strip():
|
| 51 |
return []
|
| 52 |
|
| 53 |
+
top_k = top_k or FINAL_TOP_K
|
| 54 |
+
query_clean = query.strip()
|
| 55 |
+
|
| 56 |
+
query_embedding = encode([query_clean], task="search_query")[0]
|
| 57 |
|
| 58 |
try:
|
| 59 |
result = supabase.rpc(
|
| 60 |
"match_chunks",
|
| 61 |
{
|
| 62 |
"query_embedding": query_embedding,
|
| 63 |
+
"match_count": RETRIEVE_TOP_K,
|
| 64 |
"p_notebook_id": notebook_id,
|
| 65 |
},
|
| 66 |
).execute()
|
| 67 |
|
| 68 |
rows = result.data or []
|
| 69 |
+
chunks = [
|
| 70 |
{
|
| 71 |
"id": str(r["id"]),
|
| 72 |
"content": r["content"],
|
|
|
|
| 75 |
}
|
| 76 |
for r in rows
|
| 77 |
]
|
| 78 |
+
|
| 79 |
+
# Filter by similarity threshold
|
| 80 |
+
chunks = [c for c in chunks if c["similarity"] >= SIMILARITY_THRESHOLD]
|
| 81 |
+
|
| 82 |
+
# Rerank for better precision
|
| 83 |
+
if USE_RERANKER and len(chunks) > top_k:
|
| 84 |
+
chunks = _rerank_chunks(query_clean, chunks, top_k)
|
| 85 |
+
else:
|
| 86 |
+
chunks = chunks[:top_k]
|
| 87 |
+
|
| 88 |
+
return chunks
|
| 89 |
except Exception:
|
| 90 |
return []
|