Max Saavedra commited on
Commit
10aaf26
·
1 Parent(s): 87f9e22

RAG techniques and benchmark

Browse files
README.md CHANGED
@@ -19,6 +19,7 @@ MemoriaLM is a full-stack RAG application inspired by NotebookLM. It allows auth
19
  - Upload documents (`.pdf`, `.pptx`, `.txt`) and URLs
20
  - Chat with uploaded content using RAG
21
  - Generate study artifacts (reports, quizzes, podcasts)
 
22
 
23
  ### Architecture
24
  - **Frontend**: Gradio UI (frontend/app.py)
@@ -78,3 +79,9 @@ Run all tests:
78
  ```sh
79
  pytest tests/
80
  ```
 
 
 
 
 
 
 
19
  - Upload documents (`.pdf`, `.pptx`, `.txt`) and URLs
20
  - Chat with uploaded content using RAG
21
  - Generate study artifacts (reports, quizzes, podcasts)
22
+ - Retrieval modes for comparison: `topk` and `rerank`
23
 
24
  ### Architecture
25
  - **Frontend**: Gradio UI (frontend/app.py)
 
79
  ```sh
80
  pytest tests/
81
  ```
82
+
83
+ ## RAG Technique Comparison
84
+ See `docs/rag_techniques.md` for:
85
+ - implemented retrieval modes (`topk` vs `rerank`)
86
+ - benchmark command (`scripts/rag_benchmark.py`)
87
+ - results table template for the project deliverable
backend/api/artifacts.py CHANGED
@@ -1,39 +1,50 @@
1
- from fastapi import APIRouter, HTTPException, Query
2
- from fastapi.responses import FileResponse
3
-
4
- from backend.models.schemas import ArtifactGenerateOut, ArtifactGenerateRequest, ArtifactListOut
5
- from backend.modules.artifacts import (
6
  generate_podcast,
7
  generate_quiz,
8
  generate_report,
9
- list_artifacts,
10
- resolve_artifact_path,
11
- )
12
- from backend.services.storage import NotebookStore
 
13
 
14
  router = APIRouter()
15
  store = NotebookStore(base_dir="data")
16
 
17
 
18
- @router.get("/{notebook_id}/artifacts", response_model=ArtifactListOut)
19
- def list_notebook_artifacts(notebook_id: str, user_id: str = Query(...)) -> ArtifactListOut:
20
- try:
21
- return list_artifacts(store, user_id=user_id, notebook_id=notebook_id)
22
- except FileNotFoundError:
23
- raise HTTPException(status_code=404, detail="Notebook not found")
 
 
 
 
 
24
  except ValueError as exc:
25
  raise HTTPException(status_code=400, detail=str(exc)) from exc
26
 
27
 
28
- @router.post("/{notebook_id}/artifacts/report", response_model=ArtifactGenerateOut)
29
- def create_report_artifact(notebook_id: str, payload: ArtifactGenerateRequest) -> ArtifactGenerateOut:
30
- try:
31
- return generate_report(
32
- store,
33
- user_id=payload.user_id,
34
- notebook_id=notebook_id,
35
- prompt=payload.prompt,
36
- )
 
 
 
 
 
37
  except FileNotFoundError:
38
  raise HTTPException(status_code=404, detail="Notebook not found")
39
  except ValueError as exc:
@@ -42,15 +53,20 @@ def create_report_artifact(notebook_id: str, payload: ArtifactGenerateRequest) -
42
  raise HTTPException(status_code=500, detail=f"Report generation failed: {exc}") from exc
43
 
44
 
45
- @router.post("/{notebook_id}/artifacts/quiz", response_model=ArtifactGenerateOut)
46
- def create_quiz_artifact(notebook_id: str, payload: ArtifactGenerateRequest) -> ArtifactGenerateOut:
47
- try:
48
- return generate_quiz(
49
- store,
50
- user_id=payload.user_id,
51
- notebook_id=notebook_id,
52
- prompt=payload.prompt,
53
- num_questions=payload.num_questions,
 
 
 
 
 
54
  )
55
  except FileNotFoundError:
56
  raise HTTPException(status_code=404, detail="Notebook not found")
@@ -60,15 +76,20 @@ def create_quiz_artifact(notebook_id: str, payload: ArtifactGenerateRequest) ->
60
  raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc
61
 
62
 
63
- @router.post("/{notebook_id}/artifacts/podcast", response_model=ArtifactGenerateOut)
64
- def create_podcast_artifact(notebook_id: str, payload: ArtifactGenerateRequest) -> ArtifactGenerateOut:
65
- try:
66
- return generate_podcast(
67
- store,
68
- user_id=payload.user_id,
69
- notebook_id=notebook_id,
70
- prompt=payload.prompt,
71
- )
 
 
 
 
 
72
  except FileNotFoundError:
73
  raise HTTPException(status_code=404, detail="Notebook not found")
74
  except ValueError as exc:
@@ -79,18 +100,20 @@ def create_podcast_artifact(notebook_id: str, payload: ArtifactGenerateRequest)
79
 
80
  @router.get("/{notebook_id}/artifacts/download")
81
  def download_artifact(
82
- notebook_id: str,
83
- user_id: str = Query(...),
84
- artifact_type: str = Query(...),
85
- filename: str = Query(...),
86
- ):
87
- try:
88
- path = resolve_artifact_path(
89
- store,
90
- user_id=user_id,
91
- notebook_id=notebook_id,
92
- artifact_type=artifact_type,
93
- filename=filename,
 
 
94
  )
95
  media_type = "audio/mpeg" if path.suffix.lower() == ".mp3" else "text/markdown"
96
  return FileResponse(path=path, filename=path.name, media_type=media_type)
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Query
2
+ from fastapi.responses import FileResponse
3
+
4
+ from backend.models.schemas import ArtifactGenerateOut, ArtifactGenerateRequest, ArtifactListOut
5
+ from backend.modules.artifacts import (
6
  generate_podcast,
7
  generate_quiz,
8
  generate_report,
9
+ list_artifacts,
10
+ resolve_artifact_path,
11
+ )
12
+ from backend.services.auth import User, enforce_user_match, get_current_user
13
+ from backend.services.storage import NotebookStore
14
 
15
  router = APIRouter()
16
  store = NotebookStore(base_dir="data")
17
 
18
 
19
+ @router.get("/{notebook_id}/artifacts", response_model=ArtifactListOut)
20
+ def list_notebook_artifacts(
21
+ notebook_id: str,
22
+ user_id: str | None = Query(default=None),
23
+ current_user: User = Depends(get_current_user),
24
+ ) -> ArtifactListOut:
25
+ try:
26
+ enforce_user_match(current_user, user_id)
27
+ return list_artifacts(store, user_id=current_user.user_id, notebook_id=notebook_id)
28
+ except FileNotFoundError:
29
+ raise HTTPException(status_code=404, detail="Notebook not found")
30
  except ValueError as exc:
31
  raise HTTPException(status_code=400, detail=str(exc)) from exc
32
 
33
 
34
+ @router.post("/{notebook_id}/artifacts/report", response_model=ArtifactGenerateOut)
35
+ def create_report_artifact(
36
+ notebook_id: str,
37
+ payload: ArtifactGenerateRequest,
38
+ current_user: User = Depends(get_current_user),
39
+ ) -> ArtifactGenerateOut:
40
+ try:
41
+ enforce_user_match(current_user, payload.user_id)
42
+ return generate_report(
43
+ store,
44
+ user_id=current_user.user_id,
45
+ notebook_id=notebook_id,
46
+ prompt=payload.prompt,
47
+ )
48
  except FileNotFoundError:
49
  raise HTTPException(status_code=404, detail="Notebook not found")
50
  except ValueError as exc:
 
53
  raise HTTPException(status_code=500, detail=f"Report generation failed: {exc}") from exc
54
 
55
 
56
+ @router.post("/{notebook_id}/artifacts/quiz", response_model=ArtifactGenerateOut)
57
+ def create_quiz_artifact(
58
+ notebook_id: str,
59
+ payload: ArtifactGenerateRequest,
60
+ current_user: User = Depends(get_current_user),
61
+ ) -> ArtifactGenerateOut:
62
+ try:
63
+ enforce_user_match(current_user, payload.user_id)
64
+ return generate_quiz(
65
+ store,
66
+ user_id=current_user.user_id,
67
+ notebook_id=notebook_id,
68
+ prompt=payload.prompt,
69
+ num_questions=payload.num_questions,
70
  )
71
  except FileNotFoundError:
72
  raise HTTPException(status_code=404, detail="Notebook not found")
 
76
  raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc
77
 
78
 
79
+ @router.post("/{notebook_id}/artifacts/podcast", response_model=ArtifactGenerateOut)
80
+ def create_podcast_artifact(
81
+ notebook_id: str,
82
+ payload: ArtifactGenerateRequest,
83
+ current_user: User = Depends(get_current_user),
84
+ ) -> ArtifactGenerateOut:
85
+ try:
86
+ enforce_user_match(current_user, payload.user_id)
87
+ return generate_podcast(
88
+ store,
89
+ user_id=current_user.user_id,
90
+ notebook_id=notebook_id,
91
+ prompt=payload.prompt,
92
+ )
93
  except FileNotFoundError:
94
  raise HTTPException(status_code=404, detail="Notebook not found")
95
  except ValueError as exc:
 
100
 
101
  @router.get("/{notebook_id}/artifacts/download")
102
  def download_artifact(
103
+ notebook_id: str,
104
+ user_id: str | None = Query(default=None),
105
+ artifact_type: str = Query(...),
106
+ filename: str = Query(...),
107
+ current_user: User = Depends(get_current_user),
108
+ ):
109
+ try:
110
+ enforce_user_match(current_user, user_id)
111
+ path = resolve_artifact_path(
112
+ store,
113
+ user_id=current_user.user_id,
114
+ notebook_id=notebook_id,
115
+ artifact_type=artifact_type,
116
+ filename=filename,
117
  )
118
  media_type = "audio/mpeg" if path.suffix.lower() == ".mp3" else "text/markdown"
119
  return FileResponse(path=path, filename=path.name, media_type=media_type)
backend/api/chat.py CHANGED
@@ -1,7 +1,8 @@
1
- from fastapi import APIRouter, HTTPException, Query
2
 
3
  from backend.models.schemas import ChatHistoryOut, ChatMessageOut, ChatRequest, ChatResponseOut
4
  from backend.modules.rag import answer_notebook_question, get_chat_history
 
5
  from backend.services.storage import NotebookStore
6
 
7
  router = APIRouter()
@@ -9,9 +10,9 @@ store = NotebookStore(base_dir="data")
9
 
10
 
11
  @router.get("/{notebook_id}/chat", response_model=ChatHistoryOut)
12
- def chat_history(notebook_id: str, user_id: str = Query(...)) -> ChatHistoryOut:
13
  try:
14
- messages = get_chat_history(store, user_id=user_id, notebook_id=notebook_id)
15
  return ChatHistoryOut(messages=[ChatMessageOut(**m) for m in messages])
16
  except FileNotFoundError:
17
  raise HTTPException(status_code=404, detail="Notebook not found")
@@ -20,14 +21,20 @@ def chat_history(notebook_id: str, user_id: str = Query(...)) -> ChatHistoryOut:
20
 
21
 
22
  @router.post("/{notebook_id}/chat", response_model=ChatResponseOut)
23
- def chat(notebook_id: str, payload: ChatRequest) -> ChatResponseOut:
 
 
 
 
24
  try:
 
25
  result = answer_notebook_question(
26
  store,
27
- user_id=payload.user_id,
28
  notebook_id=notebook_id,
29
  message=payload.message,
30
  top_k=payload.top_k,
 
31
  )
32
  return ChatResponseOut(**result)
33
  except FileNotFoundError:
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
 
3
  from backend.models.schemas import ChatHistoryOut, ChatMessageOut, ChatRequest, ChatResponseOut
4
  from backend.modules.rag import answer_notebook_question, get_chat_history
5
+ from backend.services.auth import User, enforce_user_match, get_current_user
6
  from backend.services.storage import NotebookStore
7
 
8
  router = APIRouter()
 
10
 
11
 
12
  @router.get("/{notebook_id}/chat", response_model=ChatHistoryOut)
13
+ def chat_history(notebook_id: str, current_user: User = Depends(get_current_user)) -> ChatHistoryOut:
14
  try:
15
+ messages = get_chat_history(store, user_id=current_user.user_id, notebook_id=notebook_id)
16
  return ChatHistoryOut(messages=[ChatMessageOut(**m) for m in messages])
17
  except FileNotFoundError:
18
  raise HTTPException(status_code=404, detail="Notebook not found")
 
21
 
22
 
23
  @router.post("/{notebook_id}/chat", response_model=ChatResponseOut)
24
+ def chat(
25
+ notebook_id: str,
26
+ payload: ChatRequest,
27
+ current_user: User = Depends(get_current_user),
28
+ ) -> ChatResponseOut:
29
  try:
30
+ enforce_user_match(current_user, payload.user_id)
31
  result = answer_notebook_question(
32
  store,
33
+ user_id=current_user.user_id,
34
  notebook_id=notebook_id,
35
  message=payload.message,
36
  top_k=payload.top_k,
37
+ retrieval_mode=payload.retrieval_mode,
38
  )
39
  return ChatResponseOut(**result)
40
  except FileNotFoundError:
backend/api/notebooks.py CHANGED
@@ -1,8 +1,9 @@
1
  from typing import List
2
 
3
- from fastapi import APIRouter, HTTPException, Query
4
 
5
  from backend.models.schemas import NotebookCreate, NotebookOut, NotebookRename
 
6
  from backend.services.storage import NotebookStore
7
 
8
  router = APIRouter()
@@ -10,25 +11,29 @@ store = NotebookStore(base_dir="data")
10
 
11
 
12
  @router.get("/", response_model=List[NotebookOut])
13
- def list_notebooks(user_id: str = Query(...)) -> List[NotebookOut]:
14
  try:
15
- return store.list(user_id)
16
  except ValueError as exc:
17
  raise HTTPException(status_code=400, detail=str(exc)) from exc
18
 
19
 
20
  @router.post("/", response_model=NotebookOut)
21
- def create_notebook(payload: NotebookCreate) -> NotebookOut:
 
 
 
22
  try:
23
- return store.create(payload)
 
24
  except ValueError as exc:
25
  raise HTTPException(status_code=400, detail=str(exc)) from exc
26
 
27
 
28
  @router.get("/{notebook_id}", response_model=NotebookOut)
29
- def get_notebook(notebook_id: str, user_id: str = Query(...)) -> NotebookOut:
30
  try:
31
- notebook = store.get(user_id, notebook_id)
32
  except ValueError as exc:
33
  raise HTTPException(status_code=400, detail=str(exc)) from exc
34
 
@@ -38,9 +43,14 @@ def get_notebook(notebook_id: str, user_id: str = Query(...)) -> NotebookOut:
38
 
39
 
40
  @router.patch("/{notebook_id}", response_model=NotebookOut)
41
- def rename_notebook(notebook_id: str, payload: NotebookRename) -> NotebookOut:
 
 
 
 
42
  try:
43
- notebook = store.rename(payload.user_id, notebook_id, payload.name)
 
44
  except ValueError as exc:
45
  raise HTTPException(status_code=400, detail=str(exc)) from exc
46
 
@@ -50,9 +60,9 @@ def rename_notebook(notebook_id: str, payload: NotebookRename) -> NotebookOut:
50
 
51
 
52
  @router.delete("/{notebook_id}")
53
- def delete_notebook(notebook_id: str, user_id: str = Query(...)) -> dict:
54
  try:
55
- deleted = store.delete(user_id, notebook_id)
56
  except ValueError as exc:
57
  raise HTTPException(status_code=400, detail=str(exc)) from exc
58
 
 
1
  from typing import List
2
 
3
+ from fastapi import APIRouter, Depends, HTTPException
4
 
5
  from backend.models.schemas import NotebookCreate, NotebookOut, NotebookRename
6
+ from backend.services.auth import User, enforce_user_match, get_current_user
7
  from backend.services.storage import NotebookStore
8
 
9
  router = APIRouter()
 
11
 
12
 
13
  @router.get("/", response_model=List[NotebookOut])
14
+ def list_notebooks(current_user: User = Depends(get_current_user)) -> List[NotebookOut]:
15
  try:
16
+ return store.list(current_user.user_id)
17
  except ValueError as exc:
18
  raise HTTPException(status_code=400, detail=str(exc)) from exc
19
 
20
 
21
  @router.post("/", response_model=NotebookOut)
22
+ def create_notebook(
23
+ payload: NotebookCreate,
24
+ current_user: User = Depends(get_current_user),
25
+ ) -> NotebookOut:
26
  try:
27
+ enforce_user_match(current_user, payload.user_id)
28
+ return store.create(NotebookCreate(user_id=current_user.user_id, name=payload.name))
29
  except ValueError as exc:
30
  raise HTTPException(status_code=400, detail=str(exc)) from exc
31
 
32
 
33
  @router.get("/{notebook_id}", response_model=NotebookOut)
34
+ def get_notebook(notebook_id: str, current_user: User = Depends(get_current_user)) -> NotebookOut:
35
  try:
36
+ notebook = store.get(current_user.user_id, notebook_id)
37
  except ValueError as exc:
38
  raise HTTPException(status_code=400, detail=str(exc)) from exc
39
 
 
43
 
44
 
45
  @router.patch("/{notebook_id}", response_model=NotebookOut)
46
+ def rename_notebook(
47
+ notebook_id: str,
48
+ payload: NotebookRename,
49
+ current_user: User = Depends(get_current_user),
50
+ ) -> NotebookOut:
51
  try:
52
+ enforce_user_match(current_user, payload.user_id)
53
+ notebook = store.rename(current_user.user_id, notebook_id, payload.name)
54
  except ValueError as exc:
55
  raise HTTPException(status_code=400, detail=str(exc)) from exc
56
 
 
60
 
61
 
62
  @router.delete("/{notebook_id}")
63
+ def delete_notebook(notebook_id: str, current_user: User = Depends(get_current_user)) -> dict:
64
  try:
65
+ deleted = store.delete(current_user.user_id, notebook_id)
66
  except ValueError as exc:
67
  raise HTTPException(status_code=400, detail=str(exc)) from exc
68
 
backend/api/sources.py CHANGED
@@ -1,7 +1,10 @@
1
- from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
 
 
2
 
3
  from backend.models.schemas import SourceListOut, SourceOut, UrlIngestRequest
4
  from backend.modules.ingestion import ingest_uploaded_bytes, ingest_url, list_ingested_sources
 
5
  from backend.services.storage import NotebookStore
6
 
7
  router = APIRouter()
@@ -9,9 +12,9 @@ store = NotebookStore(base_dir="data")
9
 
10
 
11
  @router.get("/{notebook_id}/sources", response_model=SourceListOut)
12
- def list_sources(notebook_id: str, user_id: str = Query(...)) -> SourceListOut:
13
  try:
14
- items = list_ingested_sources(store, user_id=user_id, notebook_id=notebook_id)
15
  return SourceListOut(sources=items)
16
  except FileNotFoundError:
17
  raise HTTPException(status_code=404, detail="Notebook not found")
@@ -20,11 +23,16 @@ def list_sources(notebook_id: str, user_id: str = Query(...)) -> SourceListOut:
20
 
21
 
22
  @router.post("/{notebook_id}/sources/url", response_model=SourceOut)
23
- def ingest_source_url(notebook_id: str, payload: UrlIngestRequest) -> SourceOut:
 
 
 
 
24
  try:
 
25
  return ingest_url(
26
  store,
27
- user_id=payload.user_id,
28
  notebook_id=notebook_id,
29
  url=str(payload.url),
30
  source_name=payload.source_name,
@@ -40,14 +48,16 @@ def ingest_source_url(notebook_id: str, payload: UrlIngestRequest) -> SourceOut:
40
  @router.post("/{notebook_id}/sources/upload", response_model=SourceOut)
41
  async def upload_source_file(
42
  notebook_id: str,
43
- user_id: str = Form(...),
44
  file: UploadFile = File(...),
 
45
  ) -> SourceOut:
46
  try:
 
47
  content = await file.read()
48
  return ingest_uploaded_bytes(
49
  store,
50
- user_id=user_id,
51
  notebook_id=notebook_id,
52
  filename=file.filename or "upload.txt",
53
  content=content,
 
1
+ from typing import Optional
2
+
3
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
4
 
5
  from backend.models.schemas import SourceListOut, SourceOut, UrlIngestRequest
6
  from backend.modules.ingestion import ingest_uploaded_bytes, ingest_url, list_ingested_sources
7
+ from backend.services.auth import User, enforce_user_match, get_current_user
8
  from backend.services.storage import NotebookStore
9
 
10
  router = APIRouter()
 
12
 
13
 
14
  @router.get("/{notebook_id}/sources", response_model=SourceListOut)
15
+ def list_sources(notebook_id: str, current_user: User = Depends(get_current_user)) -> SourceListOut:
16
  try:
17
+ items = list_ingested_sources(store, user_id=current_user.user_id, notebook_id=notebook_id)
18
  return SourceListOut(sources=items)
19
  except FileNotFoundError:
20
  raise HTTPException(status_code=404, detail="Notebook not found")
 
23
 
24
 
25
  @router.post("/{notebook_id}/sources/url", response_model=SourceOut)
26
+ def ingest_source_url(
27
+ notebook_id: str,
28
+ payload: UrlIngestRequest,
29
+ current_user: User = Depends(get_current_user),
30
+ ) -> SourceOut:
31
  try:
32
+ enforce_user_match(current_user, payload.user_id)
33
  return ingest_url(
34
  store,
35
+ user_id=current_user.user_id,
36
  notebook_id=notebook_id,
37
  url=str(payload.url),
38
  source_name=payload.source_name,
 
48
  @router.post("/{notebook_id}/sources/upload", response_model=SourceOut)
49
  async def upload_source_file(
50
  notebook_id: str,
51
+ user_id: Optional[str] = Form(default=None),
52
  file: UploadFile = File(...),
53
+ current_user: User = Depends(get_current_user),
54
  ) -> SourceOut:
55
  try:
56
+ enforce_user_match(current_user, user_id)
57
  content = await file.read()
58
  return ingest_uploaded_bytes(
59
  store,
60
+ user_id=current_user.user_id,
61
  notebook_id=notebook_id,
62
  filename=file.filename or "upload.txt",
63
  content=content,
backend/models/schemas.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Optional
2
 
3
  from pydantic import BaseModel, HttpUrl
4
 
@@ -46,6 +46,7 @@ class ChatRequest(BaseModel):
46
  user_id: str
47
  message: str
48
  top_k: int = 5
 
49
 
50
 
51
  class CitationOut(BaseModel):
 
1
+ from typing import List, Literal, Optional
2
 
3
  from pydantic import BaseModel, HttpUrl
4
 
 
46
  user_id: str
47
  message: str
48
  top_k: int = 5
49
+ retrieval_mode: Literal["topk", "rerank"] = "topk"
50
 
51
 
52
  class CitationOut(BaseModel):
backend/modules/artifacts.py CHANGED
@@ -1,17 +1,12 @@
1
- import audioop
2
- import json
3
- import math
4
- import os
5
- import re
6
- import wave
7
- import torch
8
- import io
9
- import soundfile as sf
10
- from transformers import AutoTokenizer, VitsModel
11
- from datetime import datetime, timezone
12
- from io import BytesIO
13
- from pathlib import Path
14
- from typing import Dict, List, Optional, Tuple
15
 
16
  from backend.models.schemas import (
17
  ArtifactFileOut,
@@ -27,8 +22,35 @@ import logging
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
30
- _vits_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
31
- _tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def _now() -> str:
34
  return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
@@ -209,13 +231,14 @@ def clean_transcript_for_tts(transcript: str) -> str:
209
  return " ".join(lines)
210
 
211
 
212
- def _synthesize_podcast_mp3(transcript_text: str) -> bytes:
213
- tts_text = clean_transcript_for_tts(transcript_text)[:1800]
214
-
215
- inputs = _tokenizer(tts_text, return_tensors="pt")
216
-
217
- with torch.no_grad():
218
- waveform = _vits_model(**inputs).waveform.squeeze().cpu().numpy()
 
219
 
220
  wav_buffer = io.BytesIO()
221
  sf.write(wav_buffer, waveform, 16000, format="WAV")
 
1
+ import audioop
2
+ import json
3
+ import re
4
+ import wave
5
+ import io
6
+ from datetime import datetime, timezone
7
+ from io import BytesIO
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional
 
 
 
 
 
10
 
11
  from backend.models.schemas import (
12
  ArtifactFileOut,
 
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
+ try:
26
+ import soundfile as sf
27
+ except ImportError:
28
+ sf = None
29
+
30
+ try:
31
+ import torch
32
+ except ImportError:
33
+ torch = None
34
+
35
+ try:
36
+ from transformers import AutoTokenizer, VitsModel
37
+ except ImportError:
38
+ AutoTokenizer = None
39
+ VitsModel = None
40
+
41
+ _vits_model = None
42
+ _tokenizer = None
43
+
44
+
45
+ def _get_tts_model():
46
+ global _vits_model, _tokenizer
47
+ if _vits_model is not None and _tokenizer is not None:
48
+ return _vits_model, _tokenizer
49
+ if torch is None or sf is None or AutoTokenizer is None or VitsModel is None:
50
+ raise RuntimeError("TTS dependencies are not installed")
51
+ _vits_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
52
+ _tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
53
+ return _vits_model, _tokenizer
54
 
55
  def _now() -> str:
56
  return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
 
231
  return " ".join(lines)
232
 
233
 
234
+ def _synthesize_podcast_mp3(transcript_text: str) -> bytes:
235
+ tts_text = clean_transcript_for_tts(transcript_text)[:1800]
236
+ model, tokenizer = _get_tts_model()
237
+
238
+ inputs = tokenizer(tts_text, return_tensors="pt")
239
+
240
+ with torch.no_grad():
241
+ waveform = model(**inputs).waveform.squeeze().cpu().numpy()
242
 
243
  wav_buffer = io.BytesIO()
244
  sf.write(wav_buffer, waveform, 16000, format="WAV")
backend/modules/rag.py CHANGED
@@ -1,4 +1,5 @@
1
  from datetime import datetime, timezone
 
2
  from typing import Any, Dict, List
3
 
4
  from backend.services.embeddings import embedding_service
@@ -51,12 +52,60 @@ def retrieve_notebook_chunks(
51
  notebook_id: str,
52
  query: str,
53
  top_k: int = 5,
 
54
  ) -> List[Dict[str, Any]]:
55
  query_vecs = embedding_service.embed_texts([query])
56
  if not query_vecs:
57
  return []
58
  chroma = ChromaNotebookStore(store.chroma_dir(user_id, notebook_id))
59
- return chroma.query(query_vecs[0], k=top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  def _citation_objects(retrieved_chunks: List[Dict[str, Any]]) -> List[Dict[str, str]]:
@@ -87,6 +136,7 @@ def answer_notebook_question(
87
  notebook_id: str,
88
  message: str,
89
  top_k: int = 5,
 
90
  ) -> Dict[str, Any]:
91
  # Ensures the notebook exists and belongs to the user before retrieving.
92
  store.require_notebook_dir(user_id, notebook_id)
@@ -97,6 +147,7 @@ def answer_notebook_question(
97
  notebook_id=notebook_id,
98
  query=message,
99
  top_k=top_k,
 
100
  )
101
  prompt = build_rag_prompt(message, retrieved_chunks)
102
  answer = llm_service.generate(prompt)
 
1
  from datetime import datetime, timezone
2
+ import re
3
  from typing import Any, Dict, List
4
 
5
  from backend.services.embeddings import embedding_service
 
52
  notebook_id: str,
53
  query: str,
54
  top_k: int = 5,
55
+ retrieval_mode: str = "topk",
56
  ) -> List[Dict[str, Any]]:
57
  query_vecs = embedding_service.embed_texts([query])
58
  if not query_vecs:
59
  return []
60
  chroma = ChromaNotebookStore(store.chroma_dir(user_id, notebook_id))
61
+ mode = retrieval_mode.strip().lower()
62
+ if mode == "topk":
63
+ return chroma.query(query_vecs[0], k=top_k)
64
+ if mode == "rerank":
65
+ # Retrieve a wider pool first, then rerank using lexical overlap + vector signal.
66
+ pool_size = max(top_k * 4, top_k)
67
+ pool = chroma.query(query_vecs[0], k=pool_size)
68
+ return rerank_chunks(query=query, candidate_chunks=pool, top_k=top_k)
69
+ raise ValueError("Unsupported retrieval_mode")
70
+
71
+
72
+ def _tokenize(text: str) -> set[str]:
73
+ return {tok for tok in re.findall(r"[a-zA-Z0-9]+", text.lower()) if tok}
74
+
75
+
76
+ def _vector_relevance_from_distance(distance: Any) -> float:
77
+ if distance is None:
78
+ return 0.0
79
+ try:
80
+ dist = float(distance)
81
+ except (TypeError, ValueError):
82
+ return 0.0
83
+ # Chroma cosine distance: lower is better; map to 0..1 relevance-like value.
84
+ return 1.0 / (1.0 + max(0.0, dist))
85
+
86
+
87
+ def rerank_chunks(
88
+ *,
89
+ query: str,
90
+ candidate_chunks: List[Dict[str, Any]],
91
+ top_k: int,
92
+ alpha: float = 0.65,
93
+ ) -> List[Dict[str, Any]]:
94
+ if not candidate_chunks:
95
+ return []
96
+ query_tokens = _tokenize(query)
97
+ ranked: List[tuple[float, Dict[str, Any]]] = []
98
+ for chunk in candidate_chunks:
99
+ text = str(chunk.get("text", ""))
100
+ chunk_tokens = _tokenize(text)
101
+ lexical = (len(query_tokens & chunk_tokens) / len(query_tokens)) if query_tokens else 0.0
102
+ vector_rel = _vector_relevance_from_distance(chunk.get("distance"))
103
+ score = (alpha * vector_rel) + ((1.0 - alpha) * lexical)
104
+ enriched = dict(chunk)
105
+ enriched["rerank_score"] = score
106
+ ranked.append((score, enriched))
107
+ ranked.sort(key=lambda x: x[0], reverse=True)
108
+ return [item[1] for item in ranked[: max(1, top_k)]]
109
 
110
 
111
  def _citation_objects(retrieved_chunks: List[Dict[str, Any]]) -> List[Dict[str, str]]:
 
136
  notebook_id: str,
137
  message: str,
138
  top_k: int = 5,
139
+ retrieval_mode: str = "topk",
140
  ) -> Dict[str, Any]:
141
  # Ensures the notebook exists and belongs to the user before retrieving.
142
  store.require_notebook_dir(user_id, notebook_id)
 
147
  notebook_id=notebook_id,
148
  query=message,
149
  top_k=top_k,
150
+ retrieval_mode=retrieval_mode,
151
  )
152
  prompt = build_rag_prompt(message, retrieved_chunks)
153
  answer = llm_service.generate(prompt)
backend/services/auth.py CHANGED
@@ -1,4 +1,8 @@
1
  from dataclasses import dataclass
 
 
 
 
2
 
3
 
4
  @dataclass
@@ -7,6 +11,50 @@ class User:
7
  email: str
8
 
9
 
10
- def get_current_user() -> User:
11
- # Placeholder for HF OAuth integration
12
- return User(user_id="demo-user", email="demo@example.com")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass
2
+ import re
3
+ from typing import Optional
4
+
5
+ from fastapi import Header, HTTPException, Request
6
 
7
 
8
  @dataclass
 
11
  email: str
12
 
13
 
14
+ _USER_ID_PATTERN = re.compile(r"[A-Za-z0-9._@-]+$")
15
+
16
+
17
+ def _validate_user_id(value: str) -> str:
18
+ if not _USER_ID_PATTERN.fullmatch(value):
19
+ raise HTTPException(status_code=400, detail="Invalid authenticated user id")
20
+ return value
21
+
22
+
23
+ def _extract_user_id_from_request(request: Request) -> Optional[str]:
24
+ # These headers are common in reverse-proxy / OAuth fronted deployments.
25
+ header_candidates = [
26
+ request.headers.get("x-user-id"),
27
+ request.headers.get("x-hf-username"),
28
+ request.headers.get("x-forwarded-user"),
29
+ request.headers.get("remote-user"),
30
+ ]
31
+ for candidate in header_candidates:
32
+ if candidate and candidate.strip():
33
+ return candidate.strip()
34
+ return None
35
+
36
+
37
+ def get_current_user(
38
+ request: Request,
39
+ x_user_id: Optional[str] = Header(default=None, alias="X-User-Id"),
40
+ x_user_email: Optional[str] = Header(default=None, alias="X-User-Email"),
41
+ ) -> User:
42
+ user_id = (x_user_id or "").strip() or _extract_user_id_from_request(request)
43
+ if not user_id:
44
+ raise HTTPException(
45
+ status_code=401,
46
+ detail="Authentication required. Provide authenticated user context.",
47
+ )
48
+
49
+ user_id = _validate_user_id(user_id)
50
+ email = (x_user_email or f"{user_id}@local").strip()
51
+ return User(user_id=user_id, email=email)
52
+
53
+
54
+ def enforce_user_match(current_user: User, supplied_user_id: Optional[str]) -> None:
55
+ supplied = (supplied_user_id or "").strip()
56
+ if supplied and supplied != current_user.user_id:
57
+ raise HTTPException(
58
+ status_code=403,
59
+ detail="Authenticated user does not match supplied user_id",
60
+ )
docs/rag_techniques.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Techniques Comparison
2
+
3
+ This project currently supports two retrieval modes in chat:
4
+
5
+ 1. `topk`: baseline dense retrieval from Chroma using cosine distance.
6
+ 2. `rerank`: retrieves a larger candidate pool, then re-scores candidates using:
7
+ - vector relevance (from Chroma distance)
8
+ - lexical overlap with query terms
9
+
10
+ ## Where It Is Implemented
11
+
12
+ - Request model: `backend/models/schemas.py` (`ChatRequest.retrieval_mode`)
13
+ - Retrieval logic: `backend/modules/rag.py`
14
+ - Chat API: `backend/api/chat.py`
15
+
16
+ ## How To Run Benchmark
17
+
18
+ Prerequisites:
19
+ - Backend running
20
+ - Notebook has at least one ingested source
21
+
22
+ Example:
23
+
24
+ ```bash
25
+ python scripts/rag_benchmark.py \
26
+ --base-url http://127.0.0.1:8000 \
27
+ --user-id <user_id> \
28
+ --notebook-id <notebook_id> \
29
+ --query "Explain the key ideas in my notes" \
30
+ --top-k 5 \
31
+ --runs 5
32
+ ```
33
+
34
+ The script prints JSON with average/min/max latency and citation/chunk stats for both modes.
35
+
36
+ ## Report Template
37
+
38
+ Use this table in your class deliverable:
39
+
40
+ | Query | Mode | Avg Latency (ms) | Avg Citations | Notes on Retrieved Context |
41
+ |---|---:|---:|---:|---|
42
+ | Q1 | topk | | | |
43
+ | Q1 | rerank | | | |
44
+ | Q2 | topk | | | |
45
+ | Q2 | rerank | | | |
46
+
47
+ Recommended write-up points:
48
+ - How different the retrieved chunks were between `topk` and `rerank`
49
+ - Which mode produced less redundant context
50
+ - Latency tradeoff (rerank usually slightly slower)
frontend/app.py CHANGED
@@ -80,7 +80,17 @@ def _maybe_start_local_backend() -> None:
80
  time.sleep(0.2)
81
 
82
 
83
- def _api_request(method: str, path: str, *, params=None, json_body=None, files=None, data=None, timeout: int = 60):
 
 
 
 
 
 
 
 
 
 
84
  url = f"{BACKEND_URL}{path}"
85
  try:
86
  resp = requests.request(
@@ -90,6 +100,7 @@ def _api_request(method: str, path: str, *, params=None, json_body=None, files=N
90
  json=json_body,
91
  files=files,
92
  data=data,
 
93
  timeout=timeout,
94
  )
95
  except requests.RequestException as exc:
@@ -109,6 +120,11 @@ def _api_request(method: str, path: str, *, params=None, json_body=None, files=N
109
  return None
110
 
111
 
 
 
 
 
 
112
  def _format_notebook_choices(items: list[dict[str, Any]]):
113
  return [
114
  (f"{item.get('name', 'Untitled')} [{str(item.get('notebook_id', ''))[:8]}]", item.get("notebook_id"))
@@ -153,15 +169,25 @@ def load_notebooks(user_id: str):
153
  if not user_id:
154
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Enter a user ID."
155
  try:
156
- items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id})
157
  choices = _format_notebook_choices(items)
158
  selected = choices[0][1] if choices else None
159
  chat_history = []
160
  sources_payload = {"sources": []}
161
  if selected:
162
- chat = _api_request("GET", f"/api/notebooks/{selected}/chat", params={"user_id": user_id})
 
 
 
 
 
163
  chat_history = _messages_to_chatbot(chat.get("messages", []))
164
- sources_payload = _api_request("GET", f"/api/notebooks/{selected}/sources", params={"user_id": user_id})
 
 
 
 
 
165
  return gr.Dropdown(choices=choices, value=selected), items, gr.JSON(value=sources_payload), chat_history, ""
166
  except Exception as exc:
167
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], str(exc)
@@ -173,7 +199,12 @@ def create_notebook(user_id: str, notebook_name: str):
173
  if not user_id:
174
  return gr.Dropdown(choices=[], value=None), [], "", "Enter a user ID first."
175
  try:
176
- _api_request("POST", "/api/notebooks/", json_body={"user_id": user_id, "name": notebook_name})
 
 
 
 
 
177
  dropdown, notebooks, _sources, _chat, _status = load_notebooks(user_id)
178
  return dropdown, notebooks, "", f"Created notebook '{notebook_name}'."
179
  except Exception as exc:
@@ -190,8 +221,9 @@ def rename_notebook(user_id: str, notebook_id: str, notebook_name: str):
190
  "PATCH",
191
  f"/api/notebooks/{notebook_id}",
192
  json_body={"user_id": user_id, "name": notebook_name},
 
193
  )
194
- items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id})
195
  choices = _format_notebook_choices(items)
196
  return gr.Dropdown(choices=choices, value=notebook_id), items, f"Renamed notebook to '{notebook_name}'."
197
  except Exception as exc:
@@ -203,7 +235,12 @@ def delete_notebook(user_id: str, notebook_id: str):
203
  if not user_id or not notebook_id:
204
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Select a notebook to delete."
205
  try:
206
- _api_request("DELETE", f"/api/notebooks/{notebook_id}", params={"user_id": user_id})
 
 
 
 
 
207
  dropdown, notebooks, sources_json, chat_history, _ = load_notebooks(user_id)
208
  return dropdown, notebooks, sources_json, chat_history, "Deleted notebook."
209
  except Exception as exc:
@@ -215,8 +252,18 @@ def on_notebook_change(user_id: str, notebook_id: str):
215
  if not user_id or not notebook_id:
216
  return gr.JSON(value={"sources": []}), [], ""
217
  try:
218
- sources_payload = _api_request("GET", f"/api/notebooks/{notebook_id}/sources", params={"user_id": user_id})
219
- chat = _api_request("GET", f"/api/notebooks/{notebook_id}/chat", params={"user_id": user_id})
 
 
 
 
 
 
 
 
 
 
220
  return gr.JSON(value=sources_payload), _messages_to_chatbot(chat.get("messages", [])), ""
221
  except Exception as exc:
222
  return gr.JSON(value={"sources": []}), [], str(exc)
@@ -233,9 +280,15 @@ def upload_source(user_id: str, notebook_id: str, file_path: str):
233
  f"/api/notebooks/{notebook_id}/sources/upload",
234
  data={"user_id": user_id},
235
  files={"file": (os.path.basename(file_path), f)},
 
236
  timeout=180,
237
  )
238
- sources_payload = _api_request("GET", f"/api/notebooks/{notebook_id}/sources", params={"user_id": user_id})
 
 
 
 
 
239
  return gr.JSON(value=sources_payload), f"Ingested file: {payload.get('source_name', os.path.basename(file_path))}"
240
  except Exception as exc:
241
  return gr.JSON(value={"sources": []}), str(exc)
@@ -251,9 +304,15 @@ def ingest_url_source(user_id: str, notebook_id: str, url: str):
251
  "POST",
252
  f"/api/notebooks/{notebook_id}/sources/url",
253
  json_body={"user_id": user_id, "url": url},
 
254
  timeout=180,
255
  )
256
- sources_payload = _api_request("GET", f"/api/notebooks/{notebook_id}/sources", params={"user_id": user_id})
 
 
 
 
 
257
  return gr.JSON(value=sources_payload), f"Ingested URL: {payload.get('source_name', url)}"
258
  except Exception as exc:
259
  return gr.JSON(value={"sources": []}), str(exc)
@@ -272,6 +331,7 @@ def send_message(message: str, history, user_id: str, notebook_id: str):
272
  "POST",
273
  f"/api/notebooks/{notebook_id}/chat",
274
  json_body={"user_id": user_id, "message": message, "top_k": 5},
 
275
  timeout=180,
276
  )
277
  assistant_text = str(resp.get("answer", "")) + _format_citations(resp.get("citations"))
@@ -313,7 +373,12 @@ def refresh_artifacts(user_id: str, notebook_id: str):
313
  payload = _empty_artifacts_payload()
314
  return gr.JSON(value=payload), None, None, None, None, "Select a notebook first."
315
  try:
316
- payload = _api_request("GET", f"/api/notebooks/{notebook_id}/artifacts", params={"user_id": user_id})
 
 
 
 
 
317
  report, quiz, transcript, audio = _artifact_outputs_from_payload(payload)
318
  return gr.JSON(value=payload), report, quiz, transcript, audio, ""
319
  except Exception as exc:
@@ -336,6 +401,7 @@ def generate_report_artifact(user_id: str, notebook_id: str, artifact_prompt: st
336
  "POST",
337
  f"/api/notebooks/{notebook_id}/artifacts/report",
338
  json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
 
339
  timeout=180,
340
  )
341
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
@@ -359,6 +425,7 @@ def generate_quiz_artifact(user_id: str, notebook_id: str, artifact_prompt: str,
359
  "prompt": (artifact_prompt or "").strip() or None,
360
  "num_questions": int(num_questions),
361
  },
 
362
  timeout=180,
363
  )
364
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
@@ -378,6 +445,7 @@ def generate_podcast_artifact(user_id: str, notebook_id: str, artifact_prompt: s
378
  "POST",
379
  f"/api/notebooks/{notebook_id}/artifacts/podcast",
380
  json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
 
381
  timeout=240,
382
  )
383
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
 
80
  time.sleep(0.2)
81
 
82
 
83
+ def _api_request(
84
+ method: str,
85
+ path: str,
86
+ *,
87
+ params=None,
88
+ json_body=None,
89
+ files=None,
90
+ data=None,
91
+ headers=None,
92
+ timeout: int = 60,
93
+ ):
94
  url = f"{BACKEND_URL}{path}"
95
  try:
96
  resp = requests.request(
 
100
  json=json_body,
101
  files=files,
102
  data=data,
103
+ headers=headers,
104
  timeout=timeout,
105
  )
106
  except requests.RequestException as exc:
 
120
  return None
121
 
122
 
123
+ def _auth_headers(user_id: str | None) -> dict[str, str]:
124
+ uid = (user_id or "").strip()
125
+ return {"X-User-Id": uid} if uid else {}
126
+
127
+
128
  def _format_notebook_choices(items: list[dict[str, Any]]):
129
  return [
130
  (f"{item.get('name', 'Untitled')} [{str(item.get('notebook_id', ''))[:8]}]", item.get("notebook_id"))
 
169
  if not user_id:
170
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Enter a user ID."
171
  try:
172
+ items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id}, headers=_auth_headers(user_id))
173
  choices = _format_notebook_choices(items)
174
  selected = choices[0][1] if choices else None
175
  chat_history = []
176
  sources_payload = {"sources": []}
177
  if selected:
178
+ chat = _api_request(
179
+ "GET",
180
+ f"/api/notebooks/{selected}/chat",
181
+ params={"user_id": user_id},
182
+ headers=_auth_headers(user_id),
183
+ )
184
  chat_history = _messages_to_chatbot(chat.get("messages", []))
185
+ sources_payload = _api_request(
186
+ "GET",
187
+ f"/api/notebooks/{selected}/sources",
188
+ params={"user_id": user_id},
189
+ headers=_auth_headers(user_id),
190
+ )
191
  return gr.Dropdown(choices=choices, value=selected), items, gr.JSON(value=sources_payload), chat_history, ""
192
  except Exception as exc:
193
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], str(exc)
 
199
  if not user_id:
200
  return gr.Dropdown(choices=[], value=None), [], "", "Enter a user ID first."
201
  try:
202
+ _api_request(
203
+ "POST",
204
+ "/api/notebooks/",
205
+ json_body={"user_id": user_id, "name": notebook_name},
206
+ headers=_auth_headers(user_id),
207
+ )
208
  dropdown, notebooks, _sources, _chat, _status = load_notebooks(user_id)
209
  return dropdown, notebooks, "", f"Created notebook '{notebook_name}'."
210
  except Exception as exc:
 
221
  "PATCH",
222
  f"/api/notebooks/{notebook_id}",
223
  json_body={"user_id": user_id, "name": notebook_name},
224
+ headers=_auth_headers(user_id),
225
  )
226
+ items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id}, headers=_auth_headers(user_id))
227
  choices = _format_notebook_choices(items)
228
  return gr.Dropdown(choices=choices, value=notebook_id), items, f"Renamed notebook to '{notebook_name}'."
229
  except Exception as exc:
 
235
  if not user_id or not notebook_id:
236
  return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Select a notebook to delete."
237
  try:
238
+ _api_request(
239
+ "DELETE",
240
+ f"/api/notebooks/{notebook_id}",
241
+ params={"user_id": user_id},
242
+ headers=_auth_headers(user_id),
243
+ )
244
  dropdown, notebooks, sources_json, chat_history, _ = load_notebooks(user_id)
245
  return dropdown, notebooks, sources_json, chat_history, "Deleted notebook."
246
  except Exception as exc:
 
252
  if not user_id or not notebook_id:
253
  return gr.JSON(value={"sources": []}), [], ""
254
  try:
255
+ sources_payload = _api_request(
256
+ "GET",
257
+ f"/api/notebooks/{notebook_id}/sources",
258
+ params={"user_id": user_id},
259
+ headers=_auth_headers(user_id),
260
+ )
261
+ chat = _api_request(
262
+ "GET",
263
+ f"/api/notebooks/{notebook_id}/chat",
264
+ params={"user_id": user_id},
265
+ headers=_auth_headers(user_id),
266
+ )
267
  return gr.JSON(value=sources_payload), _messages_to_chatbot(chat.get("messages", [])), ""
268
  except Exception as exc:
269
  return gr.JSON(value={"sources": []}), [], str(exc)
 
280
  f"/api/notebooks/{notebook_id}/sources/upload",
281
  data={"user_id": user_id},
282
  files={"file": (os.path.basename(file_path), f)},
283
+ headers=_auth_headers(user_id),
284
  timeout=180,
285
  )
286
+ sources_payload = _api_request(
287
+ "GET",
288
+ f"/api/notebooks/{notebook_id}/sources",
289
+ params={"user_id": user_id},
290
+ headers=_auth_headers(user_id),
291
+ )
292
  return gr.JSON(value=sources_payload), f"Ingested file: {payload.get('source_name', os.path.basename(file_path))}"
293
  except Exception as exc:
294
  return gr.JSON(value={"sources": []}), str(exc)
 
304
  "POST",
305
  f"/api/notebooks/{notebook_id}/sources/url",
306
  json_body={"user_id": user_id, "url": url},
307
+ headers=_auth_headers(user_id),
308
  timeout=180,
309
  )
310
+ sources_payload = _api_request(
311
+ "GET",
312
+ f"/api/notebooks/{notebook_id}/sources",
313
+ params={"user_id": user_id},
314
+ headers=_auth_headers(user_id),
315
+ )
316
  return gr.JSON(value=sources_payload), f"Ingested URL: {payload.get('source_name', url)}"
317
  except Exception as exc:
318
  return gr.JSON(value={"sources": []}), str(exc)
 
331
  "POST",
332
  f"/api/notebooks/{notebook_id}/chat",
333
  json_body={"user_id": user_id, "message": message, "top_k": 5},
334
+ headers=_auth_headers(user_id),
335
  timeout=180,
336
  )
337
  assistant_text = str(resp.get("answer", "")) + _format_citations(resp.get("citations"))
 
373
  payload = _empty_artifacts_payload()
374
  return gr.JSON(value=payload), None, None, None, None, "Select a notebook first."
375
  try:
376
+ payload = _api_request(
377
+ "GET",
378
+ f"/api/notebooks/{notebook_id}/artifacts",
379
+ params={"user_id": user_id},
380
+ headers=_auth_headers(user_id),
381
+ )
382
  report, quiz, transcript, audio = _artifact_outputs_from_payload(payload)
383
  return gr.JSON(value=payload), report, quiz, transcript, audio, ""
384
  except Exception as exc:
 
401
  "POST",
402
  f"/api/notebooks/{notebook_id}/artifacts/report",
403
  json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
404
+ headers=_auth_headers(user_id),
405
  timeout=180,
406
  )
407
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
 
425
  "prompt": (artifact_prompt or "").strip() or None,
426
  "num_questions": int(num_questions),
427
  },
428
+ headers=_auth_headers(user_id),
429
  timeout=180,
430
  )
431
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
 
445
  "POST",
446
  f"/api/notebooks/{notebook_id}/artifacts/podcast",
447
  json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
448
+ headers=_auth_headers(user_id),
449
  timeout=240,
450
  )
451
  payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
scripts/rag_benchmark.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import time
4
+ from statistics import mean
5
+ from typing import Any
6
+
7
+ import requests
8
+
9
+
10
+ def _request(method: str, url: str, *, headers=None, params=None, json_body=None, timeout=180) -> dict[str, Any]:
11
+ resp = requests.request(
12
+ method=method,
13
+ url=url,
14
+ headers=headers,
15
+ params=params,
16
+ json=json_body,
17
+ timeout=timeout,
18
+ )
19
+ resp.raise_for_status()
20
+ if resp.content:
21
+ return resp.json()
22
+ return {}
23
+
24
+
25
+ def run_benchmark(
26
+ *,
27
+ base_url: str,
28
+ user_id: str,
29
+ notebook_id: str,
30
+ query: str,
31
+ top_k: int,
32
+ runs: int,
33
+ ) -> dict[str, Any]:
34
+ headers = {"X-User-Id": user_id}
35
+ modes = ["topk", "rerank"]
36
+ results: dict[str, Any] = {}
37
+
38
+ for mode in modes:
39
+ latencies_ms: list[float] = []
40
+ citations_count: list[int] = []
41
+ used_chunks: list[int] = []
42
+
43
+ for _ in range(runs):
44
+ start = time.perf_counter()
45
+ payload = _request(
46
+ "POST",
47
+ f"{base_url}/api/notebooks/{notebook_id}/chat",
48
+ headers=headers,
49
+ json_body={
50
+ "user_id": user_id,
51
+ "message": query,
52
+ "top_k": top_k,
53
+ "retrieval_mode": mode,
54
+ },
55
+ )
56
+ elapsed_ms = (time.perf_counter() - start) * 1000.0
57
+ latencies_ms.append(elapsed_ms)
58
+ citations_count.append(len(payload.get("citations", [])))
59
+ used_chunks.append(int(payload.get("used_chunks", 0)))
60
+
61
+ results[mode] = {
62
+ "avg_latency_ms": round(mean(latencies_ms), 2),
63
+ "min_latency_ms": round(min(latencies_ms), 2),
64
+ "max_latency_ms": round(max(latencies_ms), 2),
65
+ "avg_citations": round(mean(citations_count), 2),
66
+ "avg_used_chunks": round(mean(used_chunks), 2),
67
+ }
68
+ return results
69
+
70
+
71
+ def main() -> None:
72
+ parser = argparse.ArgumentParser(description="Benchmark topk vs rerank retrieval modes.")
73
+ parser.add_argument("--base-url", default="http://127.0.0.1:8000")
74
+ parser.add_argument("--user-id", required=True)
75
+ parser.add_argument("--notebook-id", required=True)
76
+ parser.add_argument("--query", required=True)
77
+ parser.add_argument("--top-k", type=int, default=5)
78
+ parser.add_argument("--runs", type=int, default=5)
79
+ args = parser.parse_args()
80
+
81
+ results = run_benchmark(
82
+ base_url=args.base_url.rstrip("/"),
83
+ user_id=args.user_id,
84
+ notebook_id=args.notebook_id,
85
+ query=args.query,
86
+ top_k=args.top_k,
87
+ runs=max(1, args.runs),
88
+ )
89
+ print(json.dumps(results, indent=2))
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
tests/test_api_artifacts.py CHANGED
@@ -12,6 +12,7 @@ from backend.services.storage import NotebookStore
12
 
13
 
14
  client = TestClient(app)
 
15
 
16
 
17
  def _seed_source(store: NotebookStore, user_id: str, notebook_id: str, source_id: str = "src_demo") -> None:
@@ -41,17 +42,23 @@ def test_artifact_endpoints_generate_list_and_download(monkeypatch, tmp_path: Pa
41
  report = client.post(
42
  f"/api/notebooks/{nb.notebook_id}/artifacts/report",
43
  json={"user_id": "u1", "prompt": "Focus on definitions"},
 
44
  )
45
  assert report.status_code == 200
46
 
47
  podcast = client.post(
48
  f"/api/notebooks/{nb.notebook_id}/artifacts/podcast",
49
  json={"user_id": "u1"},
 
50
  )
51
  assert podcast.status_code == 200
52
  podcast_audio_name = Path(podcast.json()["audio_path"]).name
53
 
54
- listed = client.get(f"/api/notebooks/{nb.notebook_id}/artifacts", params={"user_id": "u1"})
 
 
 
 
55
  assert listed.status_code == 200
56
  payload = listed.json()
57
  assert len(payload["reports"]) == 1
@@ -60,6 +67,7 @@ def test_artifact_endpoints_generate_list_and_download(monkeypatch, tmp_path: Pa
60
  dl = client.get(
61
  f"/api/notebooks/{nb.notebook_id}/artifacts/download",
62
  params={"user_id": "u1", "artifact_type": "podcast", "filename": podcast_audio_name},
 
63
  )
64
  assert dl.status_code == 200
65
  assert dl.content.startswith(b"ID3")
 
12
 
13
 
14
  client = TestClient(app)
15
+ AUTH_U1 = {"X-User-Id": "u1"}
16
 
17
 
18
  def _seed_source(store: NotebookStore, user_id: str, notebook_id: str, source_id: str = "src_demo") -> None:
 
42
  report = client.post(
43
  f"/api/notebooks/{nb.notebook_id}/artifacts/report",
44
  json={"user_id": "u1", "prompt": "Focus on definitions"},
45
+ headers=AUTH_U1,
46
  )
47
  assert report.status_code == 200
48
 
49
  podcast = client.post(
50
  f"/api/notebooks/{nb.notebook_id}/artifacts/podcast",
51
  json={"user_id": "u1"},
52
+ headers=AUTH_U1,
53
  )
54
  assert podcast.status_code == 200
55
  podcast_audio_name = Path(podcast.json()["audio_path"]).name
56
 
57
+ listed = client.get(
58
+ f"/api/notebooks/{nb.notebook_id}/artifacts",
59
+ params={"user_id": "u1"},
60
+ headers=AUTH_U1,
61
+ )
62
  assert listed.status_code == 200
63
  payload = listed.json()
64
  assert len(payload["reports"]) == 1
 
67
  dl = client.get(
68
  f"/api/notebooks/{nb.notebook_id}/artifacts/download",
69
  params={"user_id": "u1", "artifact_type": "podcast", "filename": podcast_audio_name},
70
+ headers=AUTH_U1,
71
  )
72
  assert dl.status_code == 200
73
  assert dl.content.startswith(b"ID3")
tests/test_api_chat.py CHANGED
@@ -8,6 +8,7 @@ from backend.services.storage import NotebookStore
8
 
9
 
10
  client = TestClient(app)
 
11
 
12
 
13
  def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
@@ -35,7 +36,8 @@ def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
35
 
36
  resp = client.post(
37
  f"/api/notebooks/{created.notebook_id}/chat",
38
- json={"user_id": "u1", "message": "Hi", "top_k": 3},
 
39
  )
40
  assert resp.status_code == 200
41
  data = resp.json()
@@ -61,7 +63,11 @@ def test_chat_history_endpoint_reads_jsonl(tmp_path):
61
  },
62
  )
63
 
64
- resp = client.get(f"/api/notebooks/{created.notebook_id}/chat", params={"user_id": "u1"})
 
 
 
 
65
  assert resp.status_code == 200
66
  payload = resp.json()
67
  assert len(payload["messages"]) == 2
 
8
 
9
 
10
  client = TestClient(app)
11
+ AUTH_U1 = {"X-User-Id": "u1"}
12
 
13
 
14
  def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
 
36
 
37
  resp = client.post(
38
  f"/api/notebooks/{created.notebook_id}/chat",
39
+ json={"user_id": "u1", "message": "Hi", "top_k": 3, "retrieval_mode": "rerank"},
40
+ headers=AUTH_U1,
41
  )
42
  assert resp.status_code == 200
43
  data = resp.json()
 
63
  },
64
  )
65
 
66
+ resp = client.get(
67
+ f"/api/notebooks/{created.notebook_id}/chat",
68
+ params={"user_id": "u1"},
69
+ headers=AUTH_U1,
70
+ )
71
  assert resp.status_code == 200
72
  payload = resp.json()
73
  assert len(payload["messages"]) == 2
tests/test_api_notebooks.py CHANGED
@@ -6,6 +6,8 @@ from backend.services.storage import NotebookStore
6
 
7
 
8
  client = TestClient(app)
 
 
9
 
10
 
11
  def setup_function():
@@ -16,7 +18,7 @@ def test_create_list_get_rename_delete_notebook(tmp_path):
16
  notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
17
 
18
  payload = {"user_id": "u1", "name": "Notebook 1"}
19
- resp = client.post("/api/notebooks/", json=payload)
20
  assert resp.status_code == 200
21
  data = resp.json()
22
  assert data["user_id"] == "u1"
@@ -25,33 +27,44 @@ def test_create_list_get_rename_delete_notebook(tmp_path):
25
 
26
  notebook_id = data["notebook_id"]
27
 
28
- resp_list = client.get("/api/notebooks/", params={"user_id": "u1"})
29
  assert resp_list.status_code == 200
30
  listed = resp_list.json()
31
  assert len(listed) == 1
32
  assert listed[0]["notebook_id"] == notebook_id
33
 
34
- resp_get = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"})
35
  assert resp_get.status_code == 200
36
  data_get = resp_get.json()
37
  assert data_get["notebook_id"] == notebook_id
38
 
39
- resp_wrong_user = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u2"})
40
  assert resp_wrong_user.status_code == 404
41
 
42
  resp_rename = client.patch(
43
  f"/api/notebooks/{notebook_id}",
44
  json={"user_id": "u1", "name": "Renamed"},
 
45
  )
46
  assert resp_rename.status_code == 200
47
  assert resp_rename.json()["name"] == "Renamed"
48
 
49
- resp_delete = client.delete(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"})
50
  assert resp_delete.status_code == 200
51
  assert resp_delete.json() == {"deleted": True}
52
 
53
 
54
  def test_get_missing_notebook(tmp_path):
55
  notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
56
- resp = client.get("/api/notebooks/missing", params={"user_id": "u1"})
57
  assert resp.status_code == 404
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  client = TestClient(app)
9
+ AUTH_U1 = {"X-User-Id": "u1"}
10
+ AUTH_U2 = {"X-User-Id": "u2"}
11
 
12
 
13
  def setup_function():
 
18
  notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
19
 
20
  payload = {"user_id": "u1", "name": "Notebook 1"}
21
+ resp = client.post("/api/notebooks/", json=payload, headers=AUTH_U1)
22
  assert resp.status_code == 200
23
  data = resp.json()
24
  assert data["user_id"] == "u1"
 
27
 
28
  notebook_id = data["notebook_id"]
29
 
30
+ resp_list = client.get("/api/notebooks/", params={"user_id": "u1"}, headers=AUTH_U1)
31
  assert resp_list.status_code == 200
32
  listed = resp_list.json()
33
  assert len(listed) == 1
34
  assert listed[0]["notebook_id"] == notebook_id
35
 
36
+ resp_get = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"}, headers=AUTH_U1)
37
  assert resp_get.status_code == 200
38
  data_get = resp_get.json()
39
  assert data_get["notebook_id"] == notebook_id
40
 
41
+ resp_wrong_user = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u2"}, headers=AUTH_U2)
42
  assert resp_wrong_user.status_code == 404
43
 
44
  resp_rename = client.patch(
45
  f"/api/notebooks/{notebook_id}",
46
  json={"user_id": "u1", "name": "Renamed"},
47
+ headers=AUTH_U1,
48
  )
49
  assert resp_rename.status_code == 200
50
  assert resp_rename.json()["name"] == "Renamed"
51
 
52
+ resp_delete = client.delete(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"}, headers=AUTH_U1)
53
  assert resp_delete.status_code == 200
54
  assert resp_delete.json() == {"deleted": True}
55
 
56
 
57
  def test_get_missing_notebook(tmp_path):
58
  notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
59
+ resp = client.get("/api/notebooks/missing", params={"user_id": "u1"}, headers=AUTH_U1)
60
  assert resp.status_code == 404
61
+
62
+
63
+ def test_rejects_mismatched_user_id_payload(tmp_path):
64
+ notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
65
+ resp = client.post(
66
+ "/api/notebooks/",
67
+ json={"user_id": "u2", "name": "Notebook 1"},
68
+ headers=AUTH_U1,
69
+ )
70
+ assert resp.status_code == 403
tests/test_rag.py CHANGED
@@ -59,3 +59,27 @@ def test_answer_notebook_question_persists_messages(monkeypatch, tmp_path):
59
  assert len(messages) == 2
60
  assert messages[0]["role"] == "user"
61
  assert messages[1]["role"] == "assistant"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  assert len(messages) == 2
60
  assert messages[0]["role"] == "user"
61
  assert messages[1]["role"] == "assistant"
62
+
63
+
64
+ def test_rerank_chunks_prefers_lexically_relevant_chunk():
65
+ candidates = [
66
+ {
67
+ "chunk_id": "a",
68
+ "text": "general overview and intro",
69
+ "metadata": {},
70
+ "distance": 0.15,
71
+ },
72
+ {
73
+ "chunk_id": "b",
74
+ "text": "transformer attention heads and query key value details",
75
+ "metadata": {},
76
+ "distance": 0.20,
77
+ },
78
+ ]
79
+ reranked = rag.rerank_chunks(
80
+ query="explain transformer attention",
81
+ candidate_chunks=candidates,
82
+ top_k=1,
83
+ )
84
+ assert len(reranked) == 1
85
+ assert reranked[0]["chunk_id"] == "b"