NotebookLMClone / tests /test_chat_citations.py
github-actions[bot]
Sync from GitHub e2e802be5157aa05d1251459f529eb7eb4242ef2
dba1a8e
"""
Integration tests for citation persistence in chat threads.
"""
from __future__ import annotations
import pathlib
import sys
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Put the directory one level above this file's folder at the front of
# sys.path so the `app` and `data` imports that follow can be resolved.
ROOT = pathlib.Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from app import app
from data.db import Base, get_db
@pytest.fixture()
def db_engine(tmp_path):
    """Yield an engine backed by a per-test SQLite file with all tables created.

    The database file is created under pytest's ``tmp_path``. At teardown
    every table known to ``Base.metadata`` is dropped and the engine is
    disposed.
    """
    db_path = tmp_path.joinpath("test_chat_citations.db")
    sqlite_url = f"sqlite:///{db_path}"
    # check_same_thread=False lifts sqlite3's same-thread restriction;
    # presumably needed because the app under test may use the connection
    # from a different thread than the one running the test.
    sqlite_engine = create_engine(
        sqlite_url,
        connect_args={"check_same_thread": False},
    )

    # Imported only for its side effects; presumably this registers the ORM
    # models on Base.metadata so create_all() sees them (confirm in data/models).
    import data.models  # noqa: F401

    Base.metadata.create_all(bind=sqlite_engine)

    yield sqlite_engine

    # Teardown: remove the schema and release pooled connections.
    Base.metadata.drop_all(bind=sqlite_engine)
    sqlite_engine.dispose()
@pytest.fixture()
def db_session(db_engine):
    """Yield a single ORM session bound to ``db_engine`` and close it at teardown."""
    make_session = sessionmaker(
        bind=db_engine,
        autoflush=False,
        autocommit=False,
    )
    orm_session = make_session()

    yield orm_session

    orm_session.close()
@pytest.fixture()
def client(db_session, monkeypatch):
    """Yield a ``TestClient`` for ``app`` whose ``get_db`` dependency yields ``db_session``.

    Sets ``AUTH_MODE`` and ``APP_SESSION_SECRET`` through ``monkeypatch`` for
    the duration of the test, installs a ``get_db`` override on the shared
    ``app`` object, and clears the app's dependency overrides at teardown.
    """
    monkeypatch.setenv("AUTH_MODE", "dev")
    monkeypatch.setenv("APP_SESSION_SECRET", "chat-citations-test-secret")

    def _override_get_db():
        # Hand every request the same session the test fixtures use.
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    try:
        with TestClient(app, raise_server_exceptions=True) as c:
            yield c
    finally:
        # `app` is a module-level object shared by every test. Clean up even
        # when entering or exiting the TestClient context raises, so the
        # override cannot leak into later tests.
        app.dependency_overrides.clear()
def test_thread_messages_include_persisted_citations(client):
    """Citations returned by the chat endpoint also appear when listing thread messages.

    Creates a notebook, a source and a thread through the API, posts one chat
    question with ``app.query_notebook_chunks`` and
    ``app.generate_chat_completion`` patched to fixed return values, then
    checks the citation both in the chat response and in the assistant
    message returned by the thread's messages endpoint.
    """
    # --- Arrange: notebook, source and thread created through the API ---
    notebook_resp = client.post("/notebooks", json={"title": "Citation Notebook"})
    assert notebook_resp.status_code == 200
    notebook_id = int(notebook_resp.json()["id"])

    source_body = {
        "type": "text",
        "title": "Lecture Notes",
        "status": "ready",
    }
    source_resp = client.post(f"/notebooks/{notebook_id}/sources", json=source_body)
    assert source_resp.status_code == 200
    source_id = int(source_resp.json()["id"])

    thread_resp = client.post(
        f"/notebooks/{notebook_id}/threads",
        json={"title": "Q&A"},
    )
    assert thread_resp.status_code == 200
    thread_id = int(thread_resp.json()["id"])

    # A single retrieval hit whose metadata points at the source created above.
    chunk_metadata = {
        "source_id": str(source_id),
        "source_title": "Lecture Notes",
        "chunk_index": 0,
    }
    retrieval_rows = [
        {
            "chunk_id": "chunk-1",
            "score": 0.12,
            "document": "Neural networks learn from examples.",
            "metadata": chunk_metadata,
        }
    ]

    # --- Act: ask a question with retrieval and completion patched ---
    retrieval_patch = patch("app.query_notebook_chunks", return_value=retrieval_rows)
    completion_patch = patch(
        "app.generate_chat_completion",
        return_value="They learn from examples in the data.",
    )
    with retrieval_patch:
        with completion_patch:
            chat_resp = client.post(
                f"/threads/{thread_id}/chat",
                params={"notebook_id": notebook_id},
                json={"question": "How do neural networks learn?", "top_k": 5},
            )

    # --- Assert: the chat response carries the citation ---
    assert chat_resp.status_code == 200
    chat_citations = chat_resp.json()["citations"]
    assert len(chat_citations) == 1
    assert int(chat_citations[0]["source_id"]) == source_id

    # --- Assert: the citation is still there when messages are listed ---
    messages_resp = client.get(
        f"/threads/{thread_id}/messages",
        params={"notebook_id": notebook_id},
    )
    assert messages_resp.status_code == 200

    # Pick the first assistant message, or None when there is none.
    assistant_message = None
    for message in messages_resp.json():
        if message["role"] == "assistant":
            assistant_message = message
            break
    assert assistant_message is not None

    stored_citations = assistant_message["citations"]
    assert len(stored_citations) == 1
    stored_citation = stored_citations[0]
    assert int(stored_citation["source_id"]) == source_id
    assert stored_citation["source_title"] == "Lecture Notes"