Spaces:
Sleeping
Sleeping
Max Saavedra commited on
Commit ·
10aaf26
1
Parent(s): 87f9e22
RAG techniques and benchmark
Browse files- README.md +7 -0
- backend/api/artifacts.py +77 -54
- backend/api/chat.py +12 -5
- backend/api/notebooks.py +21 -11
- backend/api/sources.py +17 -7
- backend/models/schemas.py +2 -1
- backend/modules/artifacts.py +46 -23
- backend/modules/rag.py +52 -1
- backend/services/auth.py +51 -3
- docs/rag_techniques.md +50 -0
- frontend/app.py +80 -12
- scripts/rag_benchmark.py +93 -0
- tests/test_api_artifacts.py +9 -1
- tests/test_api_chat.py +8 -2
- tests/test_api_notebooks.py +19 -6
- tests/test_rag.py +24 -0
README.md
CHANGED
|
@@ -19,6 +19,7 @@ MemoriaLM is a full-stack RAG application inspired by NotebookLM. It allows auth
|
|
| 19 |
- Upload documents (`.pdf`, `.pptx`, `.txt`) and URLs
|
| 20 |
- Chat with uploaded content using RAG
|
| 21 |
- Generate study artifacts (reports, quizzes, podcasts)
|
|
|
|
| 22 |
|
| 23 |
### Architecture
|
| 24 |
- **Frontend**: Gradio UI (frontend/app.py)
|
|
@@ -78,3 +79,9 @@ Run all tests:
|
|
| 78 |
```sh
|
| 79 |
pytest tests/
|
| 80 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
- Upload documents (`.pdf`, `.pptx`, `.txt`) and URLs
|
| 20 |
- Chat with uploaded content using RAG
|
| 21 |
- Generate study artifacts (reports, quizzes, podcasts)
|
| 22 |
+
- Retrieval modes for comparison: `topk` and `rerank`
|
| 23 |
|
| 24 |
### Architecture
|
| 25 |
- **Frontend**: Gradio UI (frontend/app.py)
|
|
|
|
| 79 |
```sh
|
| 80 |
pytest tests/
|
| 81 |
```
|
| 82 |
+
|
| 83 |
+
## RAG Technique Comparison
|
| 84 |
+
See `docs/rag_techniques.md` for:
|
| 85 |
+
- implemented retrieval modes (`topk` vs `rerank`)
|
| 86 |
+
- benchmark command (`scripts/rag_benchmark.py`)
|
| 87 |
+
- results table template for the project deliverable
|
backend/api/artifacts.py
CHANGED
|
@@ -1,39 +1,50 @@
|
|
| 1 |
-
from fastapi import APIRouter, HTTPException, Query
|
| 2 |
-
from fastapi.responses import FileResponse
|
| 3 |
-
|
| 4 |
-
from backend.models.schemas import ArtifactGenerateOut, ArtifactGenerateRequest, ArtifactListOut
|
| 5 |
-
from backend.modules.artifacts import (
|
| 6 |
generate_podcast,
|
| 7 |
generate_quiz,
|
| 8 |
generate_report,
|
| 9 |
-
list_artifacts,
|
| 10 |
-
resolve_artifact_path,
|
| 11 |
-
)
|
| 12 |
-
from backend.services.
|
|
|
|
| 13 |
|
| 14 |
router = APIRouter()
|
| 15 |
store = NotebookStore(base_dir="data")
|
| 16 |
|
| 17 |
|
| 18 |
-
@router.get("/{notebook_id}/artifacts", response_model=ArtifactListOut)
|
| 19 |
-
def list_notebook_artifacts(
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
except ValueError as exc:
|
| 25 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 26 |
|
| 27 |
|
| 28 |
-
@router.post("/{notebook_id}/artifacts/report", response_model=ArtifactGenerateOut)
|
| 29 |
-
def create_report_artifact(
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
except FileNotFoundError:
|
| 38 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
| 39 |
except ValueError as exc:
|
|
@@ -42,15 +53,20 @@ def create_report_artifact(notebook_id: str, payload: ArtifactGenerateRequest) -
|
|
| 42 |
raise HTTPException(status_code=500, detail=f"Report generation failed: {exc}") from exc
|
| 43 |
|
| 44 |
|
| 45 |
-
@router.post("/{notebook_id}/artifacts/quiz", response_model=ArtifactGenerateOut)
|
| 46 |
-
def create_quiz_artifact(
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
)
|
| 55 |
except FileNotFoundError:
|
| 56 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
@@ -60,15 +76,20 @@ def create_quiz_artifact(notebook_id: str, payload: ArtifactGenerateRequest) ->
|
|
| 60 |
raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc
|
| 61 |
|
| 62 |
|
| 63 |
-
@router.post("/{notebook_id}/artifacts/podcast", response_model=ArtifactGenerateOut)
|
| 64 |
-
def create_podcast_artifact(
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
except FileNotFoundError:
|
| 73 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
| 74 |
except ValueError as exc:
|
|
@@ -79,18 +100,20 @@ def create_podcast_artifact(notebook_id: str, payload: ArtifactGenerateRequest)
|
|
| 79 |
|
| 80 |
@router.get("/{notebook_id}/artifacts/download")
|
| 81 |
def download_artifact(
|
| 82 |
-
notebook_id: str,
|
| 83 |
-
user_id: str = Query(
|
| 84 |
-
artifact_type: str = Query(...),
|
| 85 |
-
filename: str = Query(...),
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
media_type = "audio/mpeg" if path.suffix.lower() == ".mp3" else "text/markdown"
|
| 96 |
return FileResponse(path=path, filename=path.name, media_type=media_type)
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 2 |
+
from fastapi.responses import FileResponse
|
| 3 |
+
|
| 4 |
+
from backend.models.schemas import ArtifactGenerateOut, ArtifactGenerateRequest, ArtifactListOut
|
| 5 |
+
from backend.modules.artifacts import (
|
| 6 |
generate_podcast,
|
| 7 |
generate_quiz,
|
| 8 |
generate_report,
|
| 9 |
+
list_artifacts,
|
| 10 |
+
resolve_artifact_path,
|
| 11 |
+
)
|
| 12 |
+
from backend.services.auth import User, enforce_user_match, get_current_user
|
| 13 |
+
from backend.services.storage import NotebookStore
|
| 14 |
|
| 15 |
router = APIRouter()
|
| 16 |
store = NotebookStore(base_dir="data")
|
| 17 |
|
| 18 |
|
| 19 |
+
@router.get("/{notebook_id}/artifacts", response_model=ArtifactListOut)
|
| 20 |
+
def list_notebook_artifacts(
|
| 21 |
+
notebook_id: str,
|
| 22 |
+
user_id: str | None = Query(default=None),
|
| 23 |
+
current_user: User = Depends(get_current_user),
|
| 24 |
+
) -> ArtifactListOut:
|
| 25 |
+
try:
|
| 26 |
+
enforce_user_match(current_user, user_id)
|
| 27 |
+
return list_artifacts(store, user_id=current_user.user_id, notebook_id=notebook_id)
|
| 28 |
+
except FileNotFoundError:
|
| 29 |
+
raise HTTPException(status_code=404, detail="Notebook not found")
|
| 30 |
except ValueError as exc:
|
| 31 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 32 |
|
| 33 |
|
| 34 |
+
@router.post("/{notebook_id}/artifacts/report", response_model=ArtifactGenerateOut)
|
| 35 |
+
def create_report_artifact(
|
| 36 |
+
notebook_id: str,
|
| 37 |
+
payload: ArtifactGenerateRequest,
|
| 38 |
+
current_user: User = Depends(get_current_user),
|
| 39 |
+
) -> ArtifactGenerateOut:
|
| 40 |
+
try:
|
| 41 |
+
enforce_user_match(current_user, payload.user_id)
|
| 42 |
+
return generate_report(
|
| 43 |
+
store,
|
| 44 |
+
user_id=current_user.user_id,
|
| 45 |
+
notebook_id=notebook_id,
|
| 46 |
+
prompt=payload.prompt,
|
| 47 |
+
)
|
| 48 |
except FileNotFoundError:
|
| 49 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
| 50 |
except ValueError as exc:
|
|
|
|
| 53 |
raise HTTPException(status_code=500, detail=f"Report generation failed: {exc}") from exc
|
| 54 |
|
| 55 |
|
| 56 |
+
@router.post("/{notebook_id}/artifacts/quiz", response_model=ArtifactGenerateOut)
|
| 57 |
+
def create_quiz_artifact(
|
| 58 |
+
notebook_id: str,
|
| 59 |
+
payload: ArtifactGenerateRequest,
|
| 60 |
+
current_user: User = Depends(get_current_user),
|
| 61 |
+
) -> ArtifactGenerateOut:
|
| 62 |
+
try:
|
| 63 |
+
enforce_user_match(current_user, payload.user_id)
|
| 64 |
+
return generate_quiz(
|
| 65 |
+
store,
|
| 66 |
+
user_id=current_user.user_id,
|
| 67 |
+
notebook_id=notebook_id,
|
| 68 |
+
prompt=payload.prompt,
|
| 69 |
+
num_questions=payload.num_questions,
|
| 70 |
)
|
| 71 |
except FileNotFoundError:
|
| 72 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
|
|
| 76 |
raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc
|
| 77 |
|
| 78 |
|
| 79 |
+
@router.post("/{notebook_id}/artifacts/podcast", response_model=ArtifactGenerateOut)
|
| 80 |
+
def create_podcast_artifact(
|
| 81 |
+
notebook_id: str,
|
| 82 |
+
payload: ArtifactGenerateRequest,
|
| 83 |
+
current_user: User = Depends(get_current_user),
|
| 84 |
+
) -> ArtifactGenerateOut:
|
| 85 |
+
try:
|
| 86 |
+
enforce_user_match(current_user, payload.user_id)
|
| 87 |
+
return generate_podcast(
|
| 88 |
+
store,
|
| 89 |
+
user_id=current_user.user_id,
|
| 90 |
+
notebook_id=notebook_id,
|
| 91 |
+
prompt=payload.prompt,
|
| 92 |
+
)
|
| 93 |
except FileNotFoundError:
|
| 94 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
| 95 |
except ValueError as exc:
|
|
|
|
| 100 |
|
| 101 |
@router.get("/{notebook_id}/artifacts/download")
|
| 102 |
def download_artifact(
|
| 103 |
+
notebook_id: str,
|
| 104 |
+
user_id: str | None = Query(default=None),
|
| 105 |
+
artifact_type: str = Query(...),
|
| 106 |
+
filename: str = Query(...),
|
| 107 |
+
current_user: User = Depends(get_current_user),
|
| 108 |
+
):
|
| 109 |
+
try:
|
| 110 |
+
enforce_user_match(current_user, user_id)
|
| 111 |
+
path = resolve_artifact_path(
|
| 112 |
+
store,
|
| 113 |
+
user_id=current_user.user_id,
|
| 114 |
+
notebook_id=notebook_id,
|
| 115 |
+
artifact_type=artifact_type,
|
| 116 |
+
filename=filename,
|
| 117 |
)
|
| 118 |
media_type = "audio/mpeg" if path.suffix.lower() == ".mp3" else "text/markdown"
|
| 119 |
return FileResponse(path=path, filename=path.name, media_type=media_type)
|
backend/api/chat.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
from fastapi import APIRouter,
|
| 2 |
|
| 3 |
from backend.models.schemas import ChatHistoryOut, ChatMessageOut, ChatRequest, ChatResponseOut
|
| 4 |
from backend.modules.rag import answer_notebook_question, get_chat_history
|
|
|
|
| 5 |
from backend.services.storage import NotebookStore
|
| 6 |
|
| 7 |
router = APIRouter()
|
|
@@ -9,9 +10,9 @@ store = NotebookStore(base_dir="data")
|
|
| 9 |
|
| 10 |
|
| 11 |
@router.get("/{notebook_id}/chat", response_model=ChatHistoryOut)
|
| 12 |
-
def chat_history(notebook_id: str,
|
| 13 |
try:
|
| 14 |
-
messages = get_chat_history(store, user_id=user_id, notebook_id=notebook_id)
|
| 15 |
return ChatHistoryOut(messages=[ChatMessageOut(**m) for m in messages])
|
| 16 |
except FileNotFoundError:
|
| 17 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
@@ -20,14 +21,20 @@ def chat_history(notebook_id: str, user_id: str = Query(...)) -> ChatHistoryOut:
|
|
| 20 |
|
| 21 |
|
| 22 |
@router.post("/{notebook_id}/chat", response_model=ChatResponseOut)
|
| 23 |
-
def chat(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
try:
|
|
|
|
| 25 |
result = answer_notebook_question(
|
| 26 |
store,
|
| 27 |
-
user_id=
|
| 28 |
notebook_id=notebook_id,
|
| 29 |
message=payload.message,
|
| 30 |
top_k=payload.top_k,
|
|
|
|
| 31 |
)
|
| 32 |
return ChatResponseOut(**result)
|
| 33 |
except FileNotFoundError:
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
|
| 3 |
from backend.models.schemas import ChatHistoryOut, ChatMessageOut, ChatRequest, ChatResponseOut
|
| 4 |
from backend.modules.rag import answer_notebook_question, get_chat_history
|
| 5 |
+
from backend.services.auth import User, enforce_user_match, get_current_user
|
| 6 |
from backend.services.storage import NotebookStore
|
| 7 |
|
| 8 |
router = APIRouter()
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
@router.get("/{notebook_id}/chat", response_model=ChatHistoryOut)
|
| 13 |
+
def chat_history(notebook_id: str, current_user: User = Depends(get_current_user)) -> ChatHistoryOut:
|
| 14 |
try:
|
| 15 |
+
messages = get_chat_history(store, user_id=current_user.user_id, notebook_id=notebook_id)
|
| 16 |
return ChatHistoryOut(messages=[ChatMessageOut(**m) for m in messages])
|
| 17 |
except FileNotFoundError:
|
| 18 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@router.post("/{notebook_id}/chat", response_model=ChatResponseOut)
|
| 24 |
+
def chat(
|
| 25 |
+
notebook_id: str,
|
| 26 |
+
payload: ChatRequest,
|
| 27 |
+
current_user: User = Depends(get_current_user),
|
| 28 |
+
) -> ChatResponseOut:
|
| 29 |
try:
|
| 30 |
+
enforce_user_match(current_user, payload.user_id)
|
| 31 |
result = answer_notebook_question(
|
| 32 |
store,
|
| 33 |
+
user_id=current_user.user_id,
|
| 34 |
notebook_id=notebook_id,
|
| 35 |
message=payload.message,
|
| 36 |
top_k=payload.top_k,
|
| 37 |
+
retrieval_mode=payload.retrieval_mode,
|
| 38 |
)
|
| 39 |
return ChatResponseOut(**result)
|
| 40 |
except FileNotFoundError:
|
backend/api/notebooks.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
-
from fastapi import APIRouter,
|
| 4 |
|
| 5 |
from backend.models.schemas import NotebookCreate, NotebookOut, NotebookRename
|
|
|
|
| 6 |
from backend.services.storage import NotebookStore
|
| 7 |
|
| 8 |
router = APIRouter()
|
|
@@ -10,25 +11,29 @@ store = NotebookStore(base_dir="data")
|
|
| 10 |
|
| 11 |
|
| 12 |
@router.get("/", response_model=List[NotebookOut])
|
| 13 |
-
def list_notebooks(
|
| 14 |
try:
|
| 15 |
-
return store.list(user_id)
|
| 16 |
except ValueError as exc:
|
| 17 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 18 |
|
| 19 |
|
| 20 |
@router.post("/", response_model=NotebookOut)
|
| 21 |
-
def create_notebook(
|
|
|
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
-
|
|
|
|
| 24 |
except ValueError as exc:
|
| 25 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 26 |
|
| 27 |
|
| 28 |
@router.get("/{notebook_id}", response_model=NotebookOut)
|
| 29 |
-
def get_notebook(notebook_id: str,
|
| 30 |
try:
|
| 31 |
-
notebook = store.get(user_id, notebook_id)
|
| 32 |
except ValueError as exc:
|
| 33 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 34 |
|
|
@@ -38,9 +43,14 @@ def get_notebook(notebook_id: str, user_id: str = Query(...)) -> NotebookOut:
|
|
| 38 |
|
| 39 |
|
| 40 |
@router.patch("/{notebook_id}", response_model=NotebookOut)
|
| 41 |
-
def rename_notebook(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
try:
|
| 43 |
-
|
|
|
|
| 44 |
except ValueError as exc:
|
| 45 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 46 |
|
|
@@ -50,9 +60,9 @@ def rename_notebook(notebook_id: str, payload: NotebookRename) -> NotebookOut:
|
|
| 50 |
|
| 51 |
|
| 52 |
@router.delete("/{notebook_id}")
|
| 53 |
-
def delete_notebook(notebook_id: str,
|
| 54 |
try:
|
| 55 |
-
deleted = store.delete(user_id, notebook_id)
|
| 56 |
except ValueError as exc:
|
| 57 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 58 |
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 4 |
|
| 5 |
from backend.models.schemas import NotebookCreate, NotebookOut, NotebookRename
|
| 6 |
+
from backend.services.auth import User, enforce_user_match, get_current_user
|
| 7 |
from backend.services.storage import NotebookStore
|
| 8 |
|
| 9 |
router = APIRouter()
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@router.get("/", response_model=List[NotebookOut])
|
| 14 |
+
def list_notebooks(current_user: User = Depends(get_current_user)) -> List[NotebookOut]:
|
| 15 |
try:
|
| 16 |
+
return store.list(current_user.user_id)
|
| 17 |
except ValueError as exc:
|
| 18 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 19 |
|
| 20 |
|
| 21 |
@router.post("/", response_model=NotebookOut)
|
| 22 |
+
def create_notebook(
|
| 23 |
+
payload: NotebookCreate,
|
| 24 |
+
current_user: User = Depends(get_current_user),
|
| 25 |
+
) -> NotebookOut:
|
| 26 |
try:
|
| 27 |
+
enforce_user_match(current_user, payload.user_id)
|
| 28 |
+
return store.create(NotebookCreate(user_id=current_user.user_id, name=payload.name))
|
| 29 |
except ValueError as exc:
|
| 30 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 31 |
|
| 32 |
|
| 33 |
@router.get("/{notebook_id}", response_model=NotebookOut)
|
| 34 |
+
def get_notebook(notebook_id: str, current_user: User = Depends(get_current_user)) -> NotebookOut:
|
| 35 |
try:
|
| 36 |
+
notebook = store.get(current_user.user_id, notebook_id)
|
| 37 |
except ValueError as exc:
|
| 38 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 39 |
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
@router.patch("/{notebook_id}", response_model=NotebookOut)
|
| 46 |
+
def rename_notebook(
|
| 47 |
+
notebook_id: str,
|
| 48 |
+
payload: NotebookRename,
|
| 49 |
+
current_user: User = Depends(get_current_user),
|
| 50 |
+
) -> NotebookOut:
|
| 51 |
try:
|
| 52 |
+
enforce_user_match(current_user, payload.user_id)
|
| 53 |
+
notebook = store.rename(current_user.user_id, notebook_id, payload.name)
|
| 54 |
except ValueError as exc:
|
| 55 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 56 |
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
@router.delete("/{notebook_id}")
|
| 63 |
+
def delete_notebook(notebook_id: str, current_user: User = Depends(get_current_user)) -> dict:
|
| 64 |
try:
|
| 65 |
+
deleted = store.delete(current_user.user_id, notebook_id)
|
| 66 |
except ValueError as exc:
|
| 67 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 68 |
|
backend/api/sources.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
| 1 |
-
from
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from backend.models.schemas import SourceListOut, SourceOut, UrlIngestRequest
|
| 4 |
from backend.modules.ingestion import ingest_uploaded_bytes, ingest_url, list_ingested_sources
|
|
|
|
| 5 |
from backend.services.storage import NotebookStore
|
| 6 |
|
| 7 |
router = APIRouter()
|
|
@@ -9,9 +12,9 @@ store = NotebookStore(base_dir="data")
|
|
| 9 |
|
| 10 |
|
| 11 |
@router.get("/{notebook_id}/sources", response_model=SourceListOut)
|
| 12 |
-
def list_sources(notebook_id: str,
|
| 13 |
try:
|
| 14 |
-
items = list_ingested_sources(store, user_id=user_id, notebook_id=notebook_id)
|
| 15 |
return SourceListOut(sources=items)
|
| 16 |
except FileNotFoundError:
|
| 17 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
@@ -20,11 +23,16 @@ def list_sources(notebook_id: str, user_id: str = Query(...)) -> SourceListOut:
|
|
| 20 |
|
| 21 |
|
| 22 |
@router.post("/{notebook_id}/sources/url", response_model=SourceOut)
|
| 23 |
-
def ingest_source_url(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
try:
|
|
|
|
| 25 |
return ingest_url(
|
| 26 |
store,
|
| 27 |
-
user_id=
|
| 28 |
notebook_id=notebook_id,
|
| 29 |
url=str(payload.url),
|
| 30 |
source_name=payload.source_name,
|
|
@@ -40,14 +48,16 @@ def ingest_source_url(notebook_id: str, payload: UrlIngestRequest) -> SourceOut:
|
|
| 40 |
@router.post("/{notebook_id}/sources/upload", response_model=SourceOut)
|
| 41 |
async def upload_source_file(
|
| 42 |
notebook_id: str,
|
| 43 |
-
user_id: str = Form(
|
| 44 |
file: UploadFile = File(...),
|
|
|
|
| 45 |
) -> SourceOut:
|
| 46 |
try:
|
|
|
|
| 47 |
content = await file.read()
|
| 48 |
return ingest_uploaded_bytes(
|
| 49 |
store,
|
| 50 |
-
user_id=user_id,
|
| 51 |
notebook_id=notebook_id,
|
| 52 |
filename=file.filename or "upload.txt",
|
| 53 |
content=content,
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
| 4 |
|
| 5 |
from backend.models.schemas import SourceListOut, SourceOut, UrlIngestRequest
|
| 6 |
from backend.modules.ingestion import ingest_uploaded_bytes, ingest_url, list_ingested_sources
|
| 7 |
+
from backend.services.auth import User, enforce_user_match, get_current_user
|
| 8 |
from backend.services.storage import NotebookStore
|
| 9 |
|
| 10 |
router = APIRouter()
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
@router.get("/{notebook_id}/sources", response_model=SourceListOut)
|
| 15 |
+
def list_sources(notebook_id: str, current_user: User = Depends(get_current_user)) -> SourceListOut:
|
| 16 |
try:
|
| 17 |
+
items = list_ingested_sources(store, user_id=current_user.user_id, notebook_id=notebook_id)
|
| 18 |
return SourceListOut(sources=items)
|
| 19 |
except FileNotFoundError:
|
| 20 |
raise HTTPException(status_code=404, detail="Notebook not found")
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@router.post("/{notebook_id}/sources/url", response_model=SourceOut)
|
| 26 |
+
def ingest_source_url(
|
| 27 |
+
notebook_id: str,
|
| 28 |
+
payload: UrlIngestRequest,
|
| 29 |
+
current_user: User = Depends(get_current_user),
|
| 30 |
+
) -> SourceOut:
|
| 31 |
try:
|
| 32 |
+
enforce_user_match(current_user, payload.user_id)
|
| 33 |
return ingest_url(
|
| 34 |
store,
|
| 35 |
+
user_id=current_user.user_id,
|
| 36 |
notebook_id=notebook_id,
|
| 37 |
url=str(payload.url),
|
| 38 |
source_name=payload.source_name,
|
|
|
|
| 48 |
@router.post("/{notebook_id}/sources/upload", response_model=SourceOut)
|
| 49 |
async def upload_source_file(
|
| 50 |
notebook_id: str,
|
| 51 |
+
user_id: Optional[str] = Form(default=None),
|
| 52 |
file: UploadFile = File(...),
|
| 53 |
+
current_user: User = Depends(get_current_user),
|
| 54 |
) -> SourceOut:
|
| 55 |
try:
|
| 56 |
+
enforce_user_match(current_user, user_id)
|
| 57 |
content = await file.read()
|
| 58 |
return ingest_uploaded_bytes(
|
| 59 |
store,
|
| 60 |
+
user_id=current_user.user_id,
|
| 61 |
notebook_id=notebook_id,
|
| 62 |
filename=file.filename or "upload.txt",
|
| 63 |
content=content,
|
backend/models/schemas.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import List, Optional
|
| 2 |
|
| 3 |
from pydantic import BaseModel, HttpUrl
|
| 4 |
|
|
@@ -46,6 +46,7 @@ class ChatRequest(BaseModel):
|
|
| 46 |
user_id: str
|
| 47 |
message: str
|
| 48 |
top_k: int = 5
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
class CitationOut(BaseModel):
|
|
|
|
| 1 |
+
from typing import List, Literal, Optional
|
| 2 |
|
| 3 |
from pydantic import BaseModel, HttpUrl
|
| 4 |
|
|
|
|
| 46 |
user_id: str
|
| 47 |
message: str
|
| 48 |
top_k: int = 5
|
| 49 |
+
retrieval_mode: Literal["topk", "rerank"] = "topk"
|
| 50 |
|
| 51 |
|
| 52 |
class CitationOut(BaseModel):
|
backend/modules/artifacts.py
CHANGED
|
@@ -1,17 +1,12 @@
|
|
| 1 |
-
import audioop
|
| 2 |
-
import json
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
import
|
| 10 |
-
from transformers import AutoTokenizer, VitsModel
|
| 11 |
-
from datetime import datetime, timezone
|
| 12 |
-
from io import BytesIO
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
from typing import Dict, List, Optional, Tuple
|
| 15 |
|
| 16 |
from backend.models.schemas import (
|
| 17 |
ArtifactFileOut,
|
|
@@ -27,8 +22,35 @@ import logging
|
|
| 27 |
logging.basicConfig(level=logging.INFO)
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def _now() -> str:
|
| 34 |
return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
|
|
@@ -209,13 +231,14 @@ def clean_transcript_for_tts(transcript: str) -> str:
|
|
| 209 |
return " ".join(lines)
|
| 210 |
|
| 211 |
|
| 212 |
-
def _synthesize_podcast_mp3(transcript_text: str) -> bytes:
|
| 213 |
-
tts_text = clean_transcript_for_tts(transcript_text)[:1800]
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
|
| 220 |
wav_buffer = io.BytesIO()
|
| 221 |
sf.write(wav_buffer, waveform, 16000, format="WAV")
|
|
|
|
| 1 |
+
import audioop
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import wave
|
| 5 |
+
import io
|
| 6 |
+
from datetime import datetime, timezone
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
from backend.models.schemas import (
|
| 12 |
ArtifactFileOut,
|
|
|
|
| 22 |
logging.basicConfig(level=logging.INFO)
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
+
try:
|
| 26 |
+
import soundfile as sf
|
| 27 |
+
except ImportError:
|
| 28 |
+
sf = None
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
import torch
|
| 32 |
+
except ImportError:
|
| 33 |
+
torch = None
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from transformers import AutoTokenizer, VitsModel
|
| 37 |
+
except ImportError:
|
| 38 |
+
AutoTokenizer = None
|
| 39 |
+
VitsModel = None
|
| 40 |
+
|
| 41 |
+
_vits_model = None
|
| 42 |
+
_tokenizer = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _get_tts_model():
|
| 46 |
+
global _vits_model, _tokenizer
|
| 47 |
+
if _vits_model is not None and _tokenizer is not None:
|
| 48 |
+
return _vits_model, _tokenizer
|
| 49 |
+
if torch is None or sf is None or AutoTokenizer is None or VitsModel is None:
|
| 50 |
+
raise RuntimeError("TTS dependencies are not installed")
|
| 51 |
+
_vits_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
| 52 |
+
_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
| 53 |
+
return _vits_model, _tokenizer
|
| 54 |
|
| 55 |
def _now() -> str:
|
| 56 |
return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
|
|
|
|
| 231 |
return " ".join(lines)
|
| 232 |
|
| 233 |
|
| 234 |
+
def _synthesize_podcast_mp3(transcript_text: str) -> bytes:
|
| 235 |
+
tts_text = clean_transcript_for_tts(transcript_text)[:1800]
|
| 236 |
+
model, tokenizer = _get_tts_model()
|
| 237 |
+
|
| 238 |
+
inputs = tokenizer(tts_text, return_tensors="pt")
|
| 239 |
+
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
waveform = model(**inputs).waveform.squeeze().cpu().numpy()
|
| 242 |
|
| 243 |
wav_buffer = io.BytesIO()
|
| 244 |
sf.write(wav_buffer, waveform, 16000, format="WAV")
|
backend/modules/rag.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from datetime import datetime, timezone
|
|
|
|
| 2 |
from typing import Any, Dict, List
|
| 3 |
|
| 4 |
from backend.services.embeddings import embedding_service
|
|
@@ -51,12 +52,60 @@ def retrieve_notebook_chunks(
|
|
| 51 |
notebook_id: str,
|
| 52 |
query: str,
|
| 53 |
top_k: int = 5,
|
|
|
|
| 54 |
) -> List[Dict[str, Any]]:
|
| 55 |
query_vecs = embedding_service.embed_texts([query])
|
| 56 |
if not query_vecs:
|
| 57 |
return []
|
| 58 |
chroma = ChromaNotebookStore(store.chroma_dir(user_id, notebook_id))
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def _citation_objects(retrieved_chunks: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
@@ -87,6 +136,7 @@ def answer_notebook_question(
|
|
| 87 |
notebook_id: str,
|
| 88 |
message: str,
|
| 89 |
top_k: int = 5,
|
|
|
|
| 90 |
) -> Dict[str, Any]:
|
| 91 |
# Ensures the notebook exists and belongs to the user before retrieving.
|
| 92 |
store.require_notebook_dir(user_id, notebook_id)
|
|
@@ -97,6 +147,7 @@ def answer_notebook_question(
|
|
| 97 |
notebook_id=notebook_id,
|
| 98 |
query=message,
|
| 99 |
top_k=top_k,
|
|
|
|
| 100 |
)
|
| 101 |
prompt = build_rag_prompt(message, retrieved_chunks)
|
| 102 |
answer = llm_service.generate(prompt)
|
|
|
|
| 1 |
from datetime import datetime, timezone
|
| 2 |
+
import re
|
| 3 |
from typing import Any, Dict, List
|
| 4 |
|
| 5 |
from backend.services.embeddings import embedding_service
|
|
|
|
| 52 |
notebook_id: str,
|
| 53 |
query: str,
|
| 54 |
top_k: int = 5,
|
| 55 |
+
retrieval_mode: str = "topk",
|
| 56 |
) -> List[Dict[str, Any]]:
|
| 57 |
query_vecs = embedding_service.embed_texts([query])
|
| 58 |
if not query_vecs:
|
| 59 |
return []
|
| 60 |
chroma = ChromaNotebookStore(store.chroma_dir(user_id, notebook_id))
|
| 61 |
+
mode = retrieval_mode.strip().lower()
|
| 62 |
+
if mode == "topk":
|
| 63 |
+
return chroma.query(query_vecs[0], k=top_k)
|
| 64 |
+
if mode == "rerank":
|
| 65 |
+
# Retrieve a wider pool first, then rerank using lexical overlap + vector signal.
|
| 66 |
+
pool_size = max(top_k * 4, top_k)
|
| 67 |
+
pool = chroma.query(query_vecs[0], k=pool_size)
|
| 68 |
+
return rerank_chunks(query=query, candidate_chunks=pool, top_k=top_k)
|
| 69 |
+
raise ValueError("Unsupported retrieval_mode")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _tokenize(text: str) -> set[str]:
|
| 73 |
+
return {tok for tok in re.findall(r"[a-zA-Z0-9]+", text.lower()) if tok}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _vector_relevance_from_distance(distance: Any) -> float:
|
| 77 |
+
if distance is None:
|
| 78 |
+
return 0.0
|
| 79 |
+
try:
|
| 80 |
+
dist = float(distance)
|
| 81 |
+
except (TypeError, ValueError):
|
| 82 |
+
return 0.0
|
| 83 |
+
# Chroma cosine distance: lower is better; map to 0..1 relevance-like value.
|
| 84 |
+
return 1.0 / (1.0 + max(0.0, dist))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def rerank_chunks(
|
| 88 |
+
*,
|
| 89 |
+
query: str,
|
| 90 |
+
candidate_chunks: List[Dict[str, Any]],
|
| 91 |
+
top_k: int,
|
| 92 |
+
alpha: float = 0.65,
|
| 93 |
+
) -> List[Dict[str, Any]]:
|
| 94 |
+
if not candidate_chunks:
|
| 95 |
+
return []
|
| 96 |
+
query_tokens = _tokenize(query)
|
| 97 |
+
ranked: List[tuple[float, Dict[str, Any]]] = []
|
| 98 |
+
for chunk in candidate_chunks:
|
| 99 |
+
text = str(chunk.get("text", ""))
|
| 100 |
+
chunk_tokens = _tokenize(text)
|
| 101 |
+
lexical = (len(query_tokens & chunk_tokens) / len(query_tokens)) if query_tokens else 0.0
|
| 102 |
+
vector_rel = _vector_relevance_from_distance(chunk.get("distance"))
|
| 103 |
+
score = (alpha * vector_rel) + ((1.0 - alpha) * lexical)
|
| 104 |
+
enriched = dict(chunk)
|
| 105 |
+
enriched["rerank_score"] = score
|
| 106 |
+
ranked.append((score, enriched))
|
| 107 |
+
ranked.sort(key=lambda x: x[0], reverse=True)
|
| 108 |
+
return [item[1] for item in ranked[: max(1, top_k)]]
|
| 109 |
|
| 110 |
|
| 111 |
def _citation_objects(retrieved_chunks: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
|
|
| 136 |
notebook_id: str,
|
| 137 |
message: str,
|
| 138 |
top_k: int = 5,
|
| 139 |
+
retrieval_mode: str = "topk",
|
| 140 |
) -> Dict[str, Any]:
|
| 141 |
# Ensures the notebook exists and belongs to the user before retrieving.
|
| 142 |
store.require_notebook_dir(user_id, notebook_id)
|
|
|
|
| 147 |
notebook_id=notebook_id,
|
| 148 |
query=message,
|
| 149 |
top_k=top_k,
|
| 150 |
+
retrieval_mode=retrieval_mode,
|
| 151 |
)
|
| 152 |
prompt = build_rag_prompt(message, retrieved_chunks)
|
| 153 |
answer = llm_service.generate(prompt)
|
backend/services/auth.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
@dataclass
|
|
@@ -7,6 +11,50 @@ class User:
|
|
| 7 |
email: str
|
| 8 |
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
+
import re
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from fastapi import Header, HTTPException, Request
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass
|
|
|
|
| 11 |
email: str
|
| 12 |
|
| 13 |
|
| 14 |
+
_USER_ID_PATTERN = re.compile(r"[A-Za-z0-9._@-]+$")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _validate_user_id(value: str) -> str:
|
| 18 |
+
if not _USER_ID_PATTERN.fullmatch(value):
|
| 19 |
+
raise HTTPException(status_code=400, detail="Invalid authenticated user id")
|
| 20 |
+
return value
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _extract_user_id_from_request(request: Request) -> Optional[str]:
|
| 24 |
+
# These headers are common in reverse-proxy / OAuth fronted deployments.
|
| 25 |
+
header_candidates = [
|
| 26 |
+
request.headers.get("x-user-id"),
|
| 27 |
+
request.headers.get("x-hf-username"),
|
| 28 |
+
request.headers.get("x-forwarded-user"),
|
| 29 |
+
request.headers.get("remote-user"),
|
| 30 |
+
]
|
| 31 |
+
for candidate in header_candidates:
|
| 32 |
+
if candidate and candidate.strip():
|
| 33 |
+
return candidate.strip()
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_current_user(
|
| 38 |
+
request: Request,
|
| 39 |
+
x_user_id: Optional[str] = Header(default=None, alias="X-User-Id"),
|
| 40 |
+
x_user_email: Optional[str] = Header(default=None, alias="X-User-Email"),
|
| 41 |
+
) -> User:
|
| 42 |
+
user_id = (x_user_id or "").strip() or _extract_user_id_from_request(request)
|
| 43 |
+
if not user_id:
|
| 44 |
+
raise HTTPException(
|
| 45 |
+
status_code=401,
|
| 46 |
+
detail="Authentication required. Provide authenticated user context.",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
user_id = _validate_user_id(user_id)
|
| 50 |
+
email = (x_user_email or f"{user_id}@local").strip()
|
| 51 |
+
return User(user_id=user_id, email=email)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def enforce_user_match(current_user: User, supplied_user_id: Optional[str]) -> None:
|
| 55 |
+
supplied = (supplied_user_id or "").strip()
|
| 56 |
+
if supplied and supplied != current_user.user_id:
|
| 57 |
+
raise HTTPException(
|
| 58 |
+
status_code=403,
|
| 59 |
+
detail="Authenticated user does not match supplied user_id",
|
| 60 |
+
)
|
docs/rag_techniques.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RAG Techniques Comparison
|
| 2 |
+
|
| 3 |
+
This project currently supports two retrieval modes in chat:
|
| 4 |
+
|
| 5 |
+
1. `topk`: baseline dense retrieval from Chroma using cosine distance.
|
| 6 |
+
2. `rerank`: retrieves a larger candidate pool, then re-scores candidates using:
|
| 7 |
+
- vector relevance (from Chroma distance)
|
| 8 |
+
- lexical overlap with query terms
|
| 9 |
+
|
| 10 |
+
## Where It Is Implemented
|
| 11 |
+
|
| 12 |
+
- Request model: `backend/models/schemas.py` (`ChatRequest.retrieval_mode`)
|
| 13 |
+
- Retrieval logic: `backend/modules/rag.py`
|
| 14 |
+
- Chat API: `backend/api/chat.py`
|
| 15 |
+
|
| 16 |
+
## How To Run Benchmark
|
| 17 |
+
|
| 18 |
+
Prerequisites:
|
| 19 |
+
- Backend running
|
| 20 |
+
- Notebook has at least one ingested source
|
| 21 |
+
|
| 22 |
+
Example:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
python scripts/rag_benchmark.py \
|
| 26 |
+
--base-url http://127.0.0.1:8000 \
|
| 27 |
+
--user-id <user_id> \
|
| 28 |
+
--notebook-id <notebook_id> \
|
| 29 |
+
--query "Explain the key ideas in my notes" \
|
| 30 |
+
--top-k 5 \
|
| 31 |
+
--runs 5
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
The script prints JSON with average/min/max latency and citation/chunk stats for both modes.
|
| 35 |
+
|
| 36 |
+
## Report Template
|
| 37 |
+
|
| 38 |
+
Use this table in your class deliverable:
|
| 39 |
+
|
| 40 |
+
| Query | Mode | Avg Latency (ms) | Avg Citations | Notes on Retrieved Context |
|
| 41 |
+
|---|---:|---:|---:|---|
|
| 42 |
+
| Q1 | topk | | | |
|
| 43 |
+
| Q1 | rerank | | | |
|
| 44 |
+
| Q2 | topk | | | |
|
| 45 |
+
| Q2 | rerank | | | |
|
| 46 |
+
|
| 47 |
+
Recommended write-up points:
|
| 48 |
+
- How different the retrieved chunks were between `topk` and `rerank`
|
| 49 |
+
- Which mode produced less redundant context
|
| 50 |
+
- Latency tradeoff (rerank usually slightly slower)
|
frontend/app.py
CHANGED
|
@@ -80,7 +80,17 @@ def _maybe_start_local_backend() -> None:
|
|
| 80 |
time.sleep(0.2)
|
| 81 |
|
| 82 |
|
| 83 |
-
def _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
url = f"{BACKEND_URL}{path}"
|
| 85 |
try:
|
| 86 |
resp = requests.request(
|
|
@@ -90,6 +100,7 @@ def _api_request(method: str, path: str, *, params=None, json_body=None, files=N
|
|
| 90 |
json=json_body,
|
| 91 |
files=files,
|
| 92 |
data=data,
|
|
|
|
| 93 |
timeout=timeout,
|
| 94 |
)
|
| 95 |
except requests.RequestException as exc:
|
|
@@ -109,6 +120,11 @@ def _api_request(method: str, path: str, *, params=None, json_body=None, files=N
|
|
| 109 |
return None
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def _format_notebook_choices(items: list[dict[str, Any]]):
|
| 113 |
return [
|
| 114 |
(f"{item.get('name', 'Untitled')} [{str(item.get('notebook_id', ''))[:8]}]", item.get("notebook_id"))
|
|
@@ -153,15 +169,25 @@ def load_notebooks(user_id: str):
|
|
| 153 |
if not user_id:
|
| 154 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Enter a user ID."
|
| 155 |
try:
|
| 156 |
-
items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id})
|
| 157 |
choices = _format_notebook_choices(items)
|
| 158 |
selected = choices[0][1] if choices else None
|
| 159 |
chat_history = []
|
| 160 |
sources_payload = {"sources": []}
|
| 161 |
if selected:
|
| 162 |
-
chat = _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
chat_history = _messages_to_chatbot(chat.get("messages", []))
|
| 164 |
-
sources_payload = _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
return gr.Dropdown(choices=choices, value=selected), items, gr.JSON(value=sources_payload), chat_history, ""
|
| 166 |
except Exception as exc:
|
| 167 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], str(exc)
|
|
@@ -173,7 +199,12 @@ def create_notebook(user_id: str, notebook_name: str):
|
|
| 173 |
if not user_id:
|
| 174 |
return gr.Dropdown(choices=[], value=None), [], "", "Enter a user ID first."
|
| 175 |
try:
|
| 176 |
-
_api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
dropdown, notebooks, _sources, _chat, _status = load_notebooks(user_id)
|
| 178 |
return dropdown, notebooks, "", f"Created notebook '{notebook_name}'."
|
| 179 |
except Exception as exc:
|
|
@@ -190,8 +221,9 @@ def rename_notebook(user_id: str, notebook_id: str, notebook_name: str):
|
|
| 190 |
"PATCH",
|
| 191 |
f"/api/notebooks/{notebook_id}",
|
| 192 |
json_body={"user_id": user_id, "name": notebook_name},
|
|
|
|
| 193 |
)
|
| 194 |
-
items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id})
|
| 195 |
choices = _format_notebook_choices(items)
|
| 196 |
return gr.Dropdown(choices=choices, value=notebook_id), items, f"Renamed notebook to '{notebook_name}'."
|
| 197 |
except Exception as exc:
|
|
@@ -203,7 +235,12 @@ def delete_notebook(user_id: str, notebook_id: str):
|
|
| 203 |
if not user_id or not notebook_id:
|
| 204 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Select a notebook to delete."
|
| 205 |
try:
|
| 206 |
-
_api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
dropdown, notebooks, sources_json, chat_history, _ = load_notebooks(user_id)
|
| 208 |
return dropdown, notebooks, sources_json, chat_history, "Deleted notebook."
|
| 209 |
except Exception as exc:
|
|
@@ -215,8 +252,18 @@ def on_notebook_change(user_id: str, notebook_id: str):
|
|
| 215 |
if not user_id or not notebook_id:
|
| 216 |
return gr.JSON(value={"sources": []}), [], ""
|
| 217 |
try:
|
| 218 |
-
sources_payload = _api_request(
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
return gr.JSON(value=sources_payload), _messages_to_chatbot(chat.get("messages", [])), ""
|
| 221 |
except Exception as exc:
|
| 222 |
return gr.JSON(value={"sources": []}), [], str(exc)
|
|
@@ -233,9 +280,15 @@ def upload_source(user_id: str, notebook_id: str, file_path: str):
|
|
| 233 |
f"/api/notebooks/{notebook_id}/sources/upload",
|
| 234 |
data={"user_id": user_id},
|
| 235 |
files={"file": (os.path.basename(file_path), f)},
|
|
|
|
| 236 |
timeout=180,
|
| 237 |
)
|
| 238 |
-
sources_payload = _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
return gr.JSON(value=sources_payload), f"Ingested file: {payload.get('source_name', os.path.basename(file_path))}"
|
| 240 |
except Exception as exc:
|
| 241 |
return gr.JSON(value={"sources": []}), str(exc)
|
|
@@ -251,9 +304,15 @@ def ingest_url_source(user_id: str, notebook_id: str, url: str):
|
|
| 251 |
"POST",
|
| 252 |
f"/api/notebooks/{notebook_id}/sources/url",
|
| 253 |
json_body={"user_id": user_id, "url": url},
|
|
|
|
| 254 |
timeout=180,
|
| 255 |
)
|
| 256 |
-
sources_payload = _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
return gr.JSON(value=sources_payload), f"Ingested URL: {payload.get('source_name', url)}"
|
| 258 |
except Exception as exc:
|
| 259 |
return gr.JSON(value={"sources": []}), str(exc)
|
|
@@ -272,6 +331,7 @@ def send_message(message: str, history, user_id: str, notebook_id: str):
|
|
| 272 |
"POST",
|
| 273 |
f"/api/notebooks/{notebook_id}/chat",
|
| 274 |
json_body={"user_id": user_id, "message": message, "top_k": 5},
|
|
|
|
| 275 |
timeout=180,
|
| 276 |
)
|
| 277 |
assistant_text = str(resp.get("answer", "")) + _format_citations(resp.get("citations"))
|
|
@@ -313,7 +373,12 @@ def refresh_artifacts(user_id: str, notebook_id: str):
|
|
| 313 |
payload = _empty_artifacts_payload()
|
| 314 |
return gr.JSON(value=payload), None, None, None, None, "Select a notebook first."
|
| 315 |
try:
|
| 316 |
-
payload = _api_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
report, quiz, transcript, audio = _artifact_outputs_from_payload(payload)
|
| 318 |
return gr.JSON(value=payload), report, quiz, transcript, audio, ""
|
| 319 |
except Exception as exc:
|
|
@@ -336,6 +401,7 @@ def generate_report_artifact(user_id: str, notebook_id: str, artifact_prompt: st
|
|
| 336 |
"POST",
|
| 337 |
f"/api/notebooks/{notebook_id}/artifacts/report",
|
| 338 |
json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
|
|
|
|
| 339 |
timeout=180,
|
| 340 |
)
|
| 341 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
|
@@ -359,6 +425,7 @@ def generate_quiz_artifact(user_id: str, notebook_id: str, artifact_prompt: str,
|
|
| 359 |
"prompt": (artifact_prompt or "").strip() or None,
|
| 360 |
"num_questions": int(num_questions),
|
| 361 |
},
|
|
|
|
| 362 |
timeout=180,
|
| 363 |
)
|
| 364 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
|
@@ -378,6 +445,7 @@ def generate_podcast_artifact(user_id: str, notebook_id: str, artifact_prompt: s
|
|
| 378 |
"POST",
|
| 379 |
f"/api/notebooks/{notebook_id}/artifacts/podcast",
|
| 380 |
json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
|
|
|
|
| 381 |
timeout=240,
|
| 382 |
)
|
| 383 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
|
|
|
| 80 |
time.sleep(0.2)
|
| 81 |
|
| 82 |
|
| 83 |
+
def _api_request(
|
| 84 |
+
method: str,
|
| 85 |
+
path: str,
|
| 86 |
+
*,
|
| 87 |
+
params=None,
|
| 88 |
+
json_body=None,
|
| 89 |
+
files=None,
|
| 90 |
+
data=None,
|
| 91 |
+
headers=None,
|
| 92 |
+
timeout: int = 60,
|
| 93 |
+
):
|
| 94 |
url = f"{BACKEND_URL}{path}"
|
| 95 |
try:
|
| 96 |
resp = requests.request(
|
|
|
|
| 100 |
json=json_body,
|
| 101 |
files=files,
|
| 102 |
data=data,
|
| 103 |
+
headers=headers,
|
| 104 |
timeout=timeout,
|
| 105 |
)
|
| 106 |
except requests.RequestException as exc:
|
|
|
|
| 120 |
return None
|
| 121 |
|
| 122 |
|
| 123 |
+
def _auth_headers(user_id: str | None) -> dict[str, str]:
|
| 124 |
+
uid = (user_id or "").strip()
|
| 125 |
+
return {"X-User-Id": uid} if uid else {}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
def _format_notebook_choices(items: list[dict[str, Any]]):
|
| 129 |
return [
|
| 130 |
(f"{item.get('name', 'Untitled')} [{str(item.get('notebook_id', ''))[:8]}]", item.get("notebook_id"))
|
|
|
|
| 169 |
if not user_id:
|
| 170 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Enter a user ID."
|
| 171 |
try:
|
| 172 |
+
items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id}, headers=_auth_headers(user_id))
|
| 173 |
choices = _format_notebook_choices(items)
|
| 174 |
selected = choices[0][1] if choices else None
|
| 175 |
chat_history = []
|
| 176 |
sources_payload = {"sources": []}
|
| 177 |
if selected:
|
| 178 |
+
chat = _api_request(
|
| 179 |
+
"GET",
|
| 180 |
+
f"/api/notebooks/{selected}/chat",
|
| 181 |
+
params={"user_id": user_id},
|
| 182 |
+
headers=_auth_headers(user_id),
|
| 183 |
+
)
|
| 184 |
chat_history = _messages_to_chatbot(chat.get("messages", []))
|
| 185 |
+
sources_payload = _api_request(
|
| 186 |
+
"GET",
|
| 187 |
+
f"/api/notebooks/{selected}/sources",
|
| 188 |
+
params={"user_id": user_id},
|
| 189 |
+
headers=_auth_headers(user_id),
|
| 190 |
+
)
|
| 191 |
return gr.Dropdown(choices=choices, value=selected), items, gr.JSON(value=sources_payload), chat_history, ""
|
| 192 |
except Exception as exc:
|
| 193 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], str(exc)
|
|
|
|
| 199 |
if not user_id:
|
| 200 |
return gr.Dropdown(choices=[], value=None), [], "", "Enter a user ID first."
|
| 201 |
try:
|
| 202 |
+
_api_request(
|
| 203 |
+
"POST",
|
| 204 |
+
"/api/notebooks/",
|
| 205 |
+
json_body={"user_id": user_id, "name": notebook_name},
|
| 206 |
+
headers=_auth_headers(user_id),
|
| 207 |
+
)
|
| 208 |
dropdown, notebooks, _sources, _chat, _status = load_notebooks(user_id)
|
| 209 |
return dropdown, notebooks, "", f"Created notebook '{notebook_name}'."
|
| 210 |
except Exception as exc:
|
|
|
|
| 221 |
"PATCH",
|
| 222 |
f"/api/notebooks/{notebook_id}",
|
| 223 |
json_body={"user_id": user_id, "name": notebook_name},
|
| 224 |
+
headers=_auth_headers(user_id),
|
| 225 |
)
|
| 226 |
+
items = _api_request("GET", "/api/notebooks/", params={"user_id": user_id}, headers=_auth_headers(user_id))
|
| 227 |
choices = _format_notebook_choices(items)
|
| 228 |
return gr.Dropdown(choices=choices, value=notebook_id), items, f"Renamed notebook to '{notebook_name}'."
|
| 229 |
except Exception as exc:
|
|
|
|
| 235 |
if not user_id or not notebook_id:
|
| 236 |
return gr.Dropdown(choices=[], value=None), [], gr.JSON(value={"sources": []}), [], "Select a notebook to delete."
|
| 237 |
try:
|
| 238 |
+
_api_request(
|
| 239 |
+
"DELETE",
|
| 240 |
+
f"/api/notebooks/{notebook_id}",
|
| 241 |
+
params={"user_id": user_id},
|
| 242 |
+
headers=_auth_headers(user_id),
|
| 243 |
+
)
|
| 244 |
dropdown, notebooks, sources_json, chat_history, _ = load_notebooks(user_id)
|
| 245 |
return dropdown, notebooks, sources_json, chat_history, "Deleted notebook."
|
| 246 |
except Exception as exc:
|
|
|
|
| 252 |
if not user_id or not notebook_id:
|
| 253 |
return gr.JSON(value={"sources": []}), [], ""
|
| 254 |
try:
|
| 255 |
+
sources_payload = _api_request(
|
| 256 |
+
"GET",
|
| 257 |
+
f"/api/notebooks/{notebook_id}/sources",
|
| 258 |
+
params={"user_id": user_id},
|
| 259 |
+
headers=_auth_headers(user_id),
|
| 260 |
+
)
|
| 261 |
+
chat = _api_request(
|
| 262 |
+
"GET",
|
| 263 |
+
f"/api/notebooks/{notebook_id}/chat",
|
| 264 |
+
params={"user_id": user_id},
|
| 265 |
+
headers=_auth_headers(user_id),
|
| 266 |
+
)
|
| 267 |
return gr.JSON(value=sources_payload), _messages_to_chatbot(chat.get("messages", [])), ""
|
| 268 |
except Exception as exc:
|
| 269 |
return gr.JSON(value={"sources": []}), [], str(exc)
|
|
|
|
| 280 |
f"/api/notebooks/{notebook_id}/sources/upload",
|
| 281 |
data={"user_id": user_id},
|
| 282 |
files={"file": (os.path.basename(file_path), f)},
|
| 283 |
+
headers=_auth_headers(user_id),
|
| 284 |
timeout=180,
|
| 285 |
)
|
| 286 |
+
sources_payload = _api_request(
|
| 287 |
+
"GET",
|
| 288 |
+
f"/api/notebooks/{notebook_id}/sources",
|
| 289 |
+
params={"user_id": user_id},
|
| 290 |
+
headers=_auth_headers(user_id),
|
| 291 |
+
)
|
| 292 |
return gr.JSON(value=sources_payload), f"Ingested file: {payload.get('source_name', os.path.basename(file_path))}"
|
| 293 |
except Exception as exc:
|
| 294 |
return gr.JSON(value={"sources": []}), str(exc)
|
|
|
|
| 304 |
"POST",
|
| 305 |
f"/api/notebooks/{notebook_id}/sources/url",
|
| 306 |
json_body={"user_id": user_id, "url": url},
|
| 307 |
+
headers=_auth_headers(user_id),
|
| 308 |
timeout=180,
|
| 309 |
)
|
| 310 |
+
sources_payload = _api_request(
|
| 311 |
+
"GET",
|
| 312 |
+
f"/api/notebooks/{notebook_id}/sources",
|
| 313 |
+
params={"user_id": user_id},
|
| 314 |
+
headers=_auth_headers(user_id),
|
| 315 |
+
)
|
| 316 |
return gr.JSON(value=sources_payload), f"Ingested URL: {payload.get('source_name', url)}"
|
| 317 |
except Exception as exc:
|
| 318 |
return gr.JSON(value={"sources": []}), str(exc)
|
|
|
|
| 331 |
"POST",
|
| 332 |
f"/api/notebooks/{notebook_id}/chat",
|
| 333 |
json_body={"user_id": user_id, "message": message, "top_k": 5},
|
| 334 |
+
headers=_auth_headers(user_id),
|
| 335 |
timeout=180,
|
| 336 |
)
|
| 337 |
assistant_text = str(resp.get("answer", "")) + _format_citations(resp.get("citations"))
|
|
|
|
| 373 |
payload = _empty_artifacts_payload()
|
| 374 |
return gr.JSON(value=payload), None, None, None, None, "Select a notebook first."
|
| 375 |
try:
|
| 376 |
+
payload = _api_request(
|
| 377 |
+
"GET",
|
| 378 |
+
f"/api/notebooks/{notebook_id}/artifacts",
|
| 379 |
+
params={"user_id": user_id},
|
| 380 |
+
headers=_auth_headers(user_id),
|
| 381 |
+
)
|
| 382 |
report, quiz, transcript, audio = _artifact_outputs_from_payload(payload)
|
| 383 |
return gr.JSON(value=payload), report, quiz, transcript, audio, ""
|
| 384 |
except Exception as exc:
|
|
|
|
| 401 |
"POST",
|
| 402 |
f"/api/notebooks/{notebook_id}/artifacts/report",
|
| 403 |
json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
|
| 404 |
+
headers=_auth_headers(user_id),
|
| 405 |
timeout=180,
|
| 406 |
)
|
| 407 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
|
|
|
| 425 |
"prompt": (artifact_prompt or "").strip() or None,
|
| 426 |
"num_questions": int(num_questions),
|
| 427 |
},
|
| 428 |
+
headers=_auth_headers(user_id),
|
| 429 |
timeout=180,
|
| 430 |
)
|
| 431 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
|
|
|
| 445 |
"POST",
|
| 446 |
f"/api/notebooks/{notebook_id}/artifacts/podcast",
|
| 447 |
json_body={"user_id": user_id, "prompt": (artifact_prompt or "").strip() or None},
|
| 448 |
+
headers=_auth_headers(user_id),
|
| 449 |
timeout=240,
|
| 450 |
)
|
| 451 |
payload_json, report, quiz, transcript, audio, _ = refresh_artifacts(user_id, notebook_id)
|
scripts/rag_benchmark.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from statistics import mean
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _request(method: str, url: str, *, headers=None, params=None, json_body=None, timeout=180) -> dict[str, Any]:
|
| 11 |
+
resp = requests.request(
|
| 12 |
+
method=method,
|
| 13 |
+
url=url,
|
| 14 |
+
headers=headers,
|
| 15 |
+
params=params,
|
| 16 |
+
json=json_body,
|
| 17 |
+
timeout=timeout,
|
| 18 |
+
)
|
| 19 |
+
resp.raise_for_status()
|
| 20 |
+
if resp.content:
|
| 21 |
+
return resp.json()
|
| 22 |
+
return {}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_benchmark(
|
| 26 |
+
*,
|
| 27 |
+
base_url: str,
|
| 28 |
+
user_id: str,
|
| 29 |
+
notebook_id: str,
|
| 30 |
+
query: str,
|
| 31 |
+
top_k: int,
|
| 32 |
+
runs: int,
|
| 33 |
+
) -> dict[str, Any]:
|
| 34 |
+
headers = {"X-User-Id": user_id}
|
| 35 |
+
modes = ["topk", "rerank"]
|
| 36 |
+
results: dict[str, Any] = {}
|
| 37 |
+
|
| 38 |
+
for mode in modes:
|
| 39 |
+
latencies_ms: list[float] = []
|
| 40 |
+
citations_count: list[int] = []
|
| 41 |
+
used_chunks: list[int] = []
|
| 42 |
+
|
| 43 |
+
for _ in range(runs):
|
| 44 |
+
start = time.perf_counter()
|
| 45 |
+
payload = _request(
|
| 46 |
+
"POST",
|
| 47 |
+
f"{base_url}/api/notebooks/{notebook_id}/chat",
|
| 48 |
+
headers=headers,
|
| 49 |
+
json_body={
|
| 50 |
+
"user_id": user_id,
|
| 51 |
+
"message": query,
|
| 52 |
+
"top_k": top_k,
|
| 53 |
+
"retrieval_mode": mode,
|
| 54 |
+
},
|
| 55 |
+
)
|
| 56 |
+
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
| 57 |
+
latencies_ms.append(elapsed_ms)
|
| 58 |
+
citations_count.append(len(payload.get("citations", [])))
|
| 59 |
+
used_chunks.append(int(payload.get("used_chunks", 0)))
|
| 60 |
+
|
| 61 |
+
results[mode] = {
|
| 62 |
+
"avg_latency_ms": round(mean(latencies_ms), 2),
|
| 63 |
+
"min_latency_ms": round(min(latencies_ms), 2),
|
| 64 |
+
"max_latency_ms": round(max(latencies_ms), 2),
|
| 65 |
+
"avg_citations": round(mean(citations_count), 2),
|
| 66 |
+
"avg_used_chunks": round(mean(used_chunks), 2),
|
| 67 |
+
}
|
| 68 |
+
return results
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def main() -> None:
|
| 72 |
+
parser = argparse.ArgumentParser(description="Benchmark topk vs rerank retrieval modes.")
|
| 73 |
+
parser.add_argument("--base-url", default="http://127.0.0.1:8000")
|
| 74 |
+
parser.add_argument("--user-id", required=True)
|
| 75 |
+
parser.add_argument("--notebook-id", required=True)
|
| 76 |
+
parser.add_argument("--query", required=True)
|
| 77 |
+
parser.add_argument("--top-k", type=int, default=5)
|
| 78 |
+
parser.add_argument("--runs", type=int, default=5)
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
results = run_benchmark(
|
| 82 |
+
base_url=args.base_url.rstrip("/"),
|
| 83 |
+
user_id=args.user_id,
|
| 84 |
+
notebook_id=args.notebook_id,
|
| 85 |
+
query=args.query,
|
| 86 |
+
top_k=args.top_k,
|
| 87 |
+
runs=max(1, args.runs),
|
| 88 |
+
)
|
| 89 |
+
print(json.dumps(results, indent=2))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
tests/test_api_artifacts.py
CHANGED
|
@@ -12,6 +12,7 @@ from backend.services.storage import NotebookStore
|
|
| 12 |
|
| 13 |
|
| 14 |
client = TestClient(app)
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def _seed_source(store: NotebookStore, user_id: str, notebook_id: str, source_id: str = "src_demo") -> None:
|
|
@@ -41,17 +42,23 @@ def test_artifact_endpoints_generate_list_and_download(monkeypatch, tmp_path: Pa
|
|
| 41 |
report = client.post(
|
| 42 |
f"/api/notebooks/{nb.notebook_id}/artifacts/report",
|
| 43 |
json={"user_id": "u1", "prompt": "Focus on definitions"},
|
|
|
|
| 44 |
)
|
| 45 |
assert report.status_code == 200
|
| 46 |
|
| 47 |
podcast = client.post(
|
| 48 |
f"/api/notebooks/{nb.notebook_id}/artifacts/podcast",
|
| 49 |
json={"user_id": "u1"},
|
|
|
|
| 50 |
)
|
| 51 |
assert podcast.status_code == 200
|
| 52 |
podcast_audio_name = Path(podcast.json()["audio_path"]).name
|
| 53 |
|
| 54 |
-
listed = client.get(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
assert listed.status_code == 200
|
| 56 |
payload = listed.json()
|
| 57 |
assert len(payload["reports"]) == 1
|
|
@@ -60,6 +67,7 @@ def test_artifact_endpoints_generate_list_and_download(monkeypatch, tmp_path: Pa
|
|
| 60 |
dl = client.get(
|
| 61 |
f"/api/notebooks/{nb.notebook_id}/artifacts/download",
|
| 62 |
params={"user_id": "u1", "artifact_type": "podcast", "filename": podcast_audio_name},
|
|
|
|
| 63 |
)
|
| 64 |
assert dl.status_code == 200
|
| 65 |
assert dl.content.startswith(b"ID3")
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
client = TestClient(app)
|
| 15 |
+
AUTH_U1 = {"X-User-Id": "u1"}
|
| 16 |
|
| 17 |
|
| 18 |
def _seed_source(store: NotebookStore, user_id: str, notebook_id: str, source_id: str = "src_demo") -> None:
|
|
|
|
| 42 |
report = client.post(
|
| 43 |
f"/api/notebooks/{nb.notebook_id}/artifacts/report",
|
| 44 |
json={"user_id": "u1", "prompt": "Focus on definitions"},
|
| 45 |
+
headers=AUTH_U1,
|
| 46 |
)
|
| 47 |
assert report.status_code == 200
|
| 48 |
|
| 49 |
podcast = client.post(
|
| 50 |
f"/api/notebooks/{nb.notebook_id}/artifacts/podcast",
|
| 51 |
json={"user_id": "u1"},
|
| 52 |
+
headers=AUTH_U1,
|
| 53 |
)
|
| 54 |
assert podcast.status_code == 200
|
| 55 |
podcast_audio_name = Path(podcast.json()["audio_path"]).name
|
| 56 |
|
| 57 |
+
listed = client.get(
|
| 58 |
+
f"/api/notebooks/{nb.notebook_id}/artifacts",
|
| 59 |
+
params={"user_id": "u1"},
|
| 60 |
+
headers=AUTH_U1,
|
| 61 |
+
)
|
| 62 |
assert listed.status_code == 200
|
| 63 |
payload = listed.json()
|
| 64 |
assert len(payload["reports"]) == 1
|
|
|
|
| 67 |
dl = client.get(
|
| 68 |
f"/api/notebooks/{nb.notebook_id}/artifacts/download",
|
| 69 |
params={"user_id": "u1", "artifact_type": "podcast", "filename": podcast_audio_name},
|
| 70 |
+
headers=AUTH_U1,
|
| 71 |
)
|
| 72 |
assert dl.status_code == 200
|
| 73 |
assert dl.content.startswith(b"ID3")
|
tests/test_api_chat.py
CHANGED
|
@@ -8,6 +8,7 @@ from backend.services.storage import NotebookStore
|
|
| 8 |
|
| 9 |
|
| 10 |
client = TestClient(app)
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
|
|
@@ -35,7 +36,8 @@ def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
|
|
| 35 |
|
| 36 |
resp = client.post(
|
| 37 |
f"/api/notebooks/{created.notebook_id}/chat",
|
| 38 |
-
json={"user_id": "u1", "message": "Hi", "top_k": 3},
|
|
|
|
| 39 |
)
|
| 40 |
assert resp.status_code == 200
|
| 41 |
data = resp.json()
|
|
@@ -61,7 +63,11 @@ def test_chat_history_endpoint_reads_jsonl(tmp_path):
|
|
| 61 |
},
|
| 62 |
)
|
| 63 |
|
| 64 |
-
resp = client.get(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
assert resp.status_code == 200
|
| 66 |
payload = resp.json()
|
| 67 |
assert len(payload["messages"]) == 2
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
client = TestClient(app)
|
| 11 |
+
AUTH_U1 = {"X-User-Id": "u1"}
|
| 12 |
|
| 13 |
|
| 14 |
def test_chat_endpoint_success_with_mock(monkeypatch, tmp_path):
|
|
|
|
| 36 |
|
| 37 |
resp = client.post(
|
| 38 |
f"/api/notebooks/{created.notebook_id}/chat",
|
| 39 |
+
json={"user_id": "u1", "message": "Hi", "top_k": 3, "retrieval_mode": "rerank"},
|
| 40 |
+
headers=AUTH_U1,
|
| 41 |
)
|
| 42 |
assert resp.status_code == 200
|
| 43 |
data = resp.json()
|
|
|
|
| 63 |
},
|
| 64 |
)
|
| 65 |
|
| 66 |
+
resp = client.get(
|
| 67 |
+
f"/api/notebooks/{created.notebook_id}/chat",
|
| 68 |
+
params={"user_id": "u1"},
|
| 69 |
+
headers=AUTH_U1,
|
| 70 |
+
)
|
| 71 |
assert resp.status_code == 200
|
| 72 |
payload = resp.json()
|
| 73 |
assert len(payload["messages"]) == 2
|
tests/test_api_notebooks.py
CHANGED
|
@@ -6,6 +6,8 @@ from backend.services.storage import NotebookStore
|
|
| 6 |
|
| 7 |
|
| 8 |
client = TestClient(app)
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def setup_function():
|
|
@@ -16,7 +18,7 @@ def test_create_list_get_rename_delete_notebook(tmp_path):
|
|
| 16 |
notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
|
| 17 |
|
| 18 |
payload = {"user_id": "u1", "name": "Notebook 1"}
|
| 19 |
-
resp = client.post("/api/notebooks/", json=payload)
|
| 20 |
assert resp.status_code == 200
|
| 21 |
data = resp.json()
|
| 22 |
assert data["user_id"] == "u1"
|
|
@@ -25,33 +27,44 @@ def test_create_list_get_rename_delete_notebook(tmp_path):
|
|
| 25 |
|
| 26 |
notebook_id = data["notebook_id"]
|
| 27 |
|
| 28 |
-
resp_list = client.get("/api/notebooks/", params={"user_id": "u1"})
|
| 29 |
assert resp_list.status_code == 200
|
| 30 |
listed = resp_list.json()
|
| 31 |
assert len(listed) == 1
|
| 32 |
assert listed[0]["notebook_id"] == notebook_id
|
| 33 |
|
| 34 |
-
resp_get = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"})
|
| 35 |
assert resp_get.status_code == 200
|
| 36 |
data_get = resp_get.json()
|
| 37 |
assert data_get["notebook_id"] == notebook_id
|
| 38 |
|
| 39 |
-
resp_wrong_user = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u2"})
|
| 40 |
assert resp_wrong_user.status_code == 404
|
| 41 |
|
| 42 |
resp_rename = client.patch(
|
| 43 |
f"/api/notebooks/{notebook_id}",
|
| 44 |
json={"user_id": "u1", "name": "Renamed"},
|
|
|
|
| 45 |
)
|
| 46 |
assert resp_rename.status_code == 200
|
| 47 |
assert resp_rename.json()["name"] == "Renamed"
|
| 48 |
|
| 49 |
-
resp_delete = client.delete(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"})
|
| 50 |
assert resp_delete.status_code == 200
|
| 51 |
assert resp_delete.json() == {"deleted": True}
|
| 52 |
|
| 53 |
|
| 54 |
def test_get_missing_notebook(tmp_path):
|
| 55 |
notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
|
| 56 |
-
resp = client.get("/api/notebooks/missing", params={"user_id": "u1"})
|
| 57 |
assert resp.status_code == 404
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
client = TestClient(app)
|
| 9 |
+
AUTH_U1 = {"X-User-Id": "u1"}
|
| 10 |
+
AUTH_U2 = {"X-User-Id": "u2"}
|
| 11 |
|
| 12 |
|
| 13 |
def setup_function():
|
|
|
|
| 18 |
notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
|
| 19 |
|
| 20 |
payload = {"user_id": "u1", "name": "Notebook 1"}
|
| 21 |
+
resp = client.post("/api/notebooks/", json=payload, headers=AUTH_U1)
|
| 22 |
assert resp.status_code == 200
|
| 23 |
data = resp.json()
|
| 24 |
assert data["user_id"] == "u1"
|
|
|
|
| 27 |
|
| 28 |
notebook_id = data["notebook_id"]
|
| 29 |
|
| 30 |
+
resp_list = client.get("/api/notebooks/", params={"user_id": "u1"}, headers=AUTH_U1)
|
| 31 |
assert resp_list.status_code == 200
|
| 32 |
listed = resp_list.json()
|
| 33 |
assert len(listed) == 1
|
| 34 |
assert listed[0]["notebook_id"] == notebook_id
|
| 35 |
|
| 36 |
+
resp_get = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"}, headers=AUTH_U1)
|
| 37 |
assert resp_get.status_code == 200
|
| 38 |
data_get = resp_get.json()
|
| 39 |
assert data_get["notebook_id"] == notebook_id
|
| 40 |
|
| 41 |
+
resp_wrong_user = client.get(f"/api/notebooks/{notebook_id}", params={"user_id": "u2"}, headers=AUTH_U2)
|
| 42 |
assert resp_wrong_user.status_code == 404
|
| 43 |
|
| 44 |
resp_rename = client.patch(
|
| 45 |
f"/api/notebooks/{notebook_id}",
|
| 46 |
json={"user_id": "u1", "name": "Renamed"},
|
| 47 |
+
headers=AUTH_U1,
|
| 48 |
)
|
| 49 |
assert resp_rename.status_code == 200
|
| 50 |
assert resp_rename.json()["name"] == "Renamed"
|
| 51 |
|
| 52 |
+
resp_delete = client.delete(f"/api/notebooks/{notebook_id}", params={"user_id": "u1"}, headers=AUTH_U1)
|
| 53 |
assert resp_delete.status_code == 200
|
| 54 |
assert resp_delete.json() == {"deleted": True}
|
| 55 |
|
| 56 |
|
| 57 |
def test_get_missing_notebook(tmp_path):
|
| 58 |
notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
|
| 59 |
+
resp = client.get("/api/notebooks/missing", params={"user_id": "u1"}, headers=AUTH_U1)
|
| 60 |
assert resp.status_code == 404
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_rejects_mismatched_user_id_payload(tmp_path):
|
| 64 |
+
notebooks_api.store = NotebookStore(base_dir=str(tmp_path))
|
| 65 |
+
resp = client.post(
|
| 66 |
+
"/api/notebooks/",
|
| 67 |
+
json={"user_id": "u2", "name": "Notebook 1"},
|
| 68 |
+
headers=AUTH_U1,
|
| 69 |
+
)
|
| 70 |
+
assert resp.status_code == 403
|
tests/test_rag.py
CHANGED
|
@@ -59,3 +59,27 @@ def test_answer_notebook_question_persists_messages(monkeypatch, tmp_path):
|
|
| 59 |
assert len(messages) == 2
|
| 60 |
assert messages[0]["role"] == "user"
|
| 61 |
assert messages[1]["role"] == "assistant"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
assert len(messages) == 2
|
| 60 |
assert messages[0]["role"] == "user"
|
| 61 |
assert messages[1]["role"] == "assistant"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_rerank_chunks_prefers_lexically_relevant_chunk():
|
| 65 |
+
candidates = [
|
| 66 |
+
{
|
| 67 |
+
"chunk_id": "a",
|
| 68 |
+
"text": "general overview and intro",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"distance": 0.15,
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"chunk_id": "b",
|
| 74 |
+
"text": "transformer attention heads and query key value details",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"distance": 0.20,
|
| 77 |
+
},
|
| 78 |
+
]
|
| 79 |
+
reranked = rag.rerank_chunks(
|
| 80 |
+
query="explain transformer attention",
|
| 81 |
+
candidate_chunks=candidates,
|
| 82 |
+
top_k=1,
|
| 83 |
+
)
|
| 84 |
+
assert len(reranked) == 1
|
| 85 |
+
assert reranked[0]["chunk_id"] == "b"
|