File size: 17,430 Bytes
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1815b1f
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f09ce8f
 
 
 
 
 
 
 
 
 
 
 
 
 
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1815b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
"""
Integration tests for artifact API endpoints.

Uses FastAPI TestClient with an in-memory SQLite database so no external
services are required.
"""
from __future__ import annotations

import pathlib
import sys
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

ROOT = pathlib.Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from data.db import Base, get_db
from data import crud
import app as app_module
from app import app


# ── Database fixtures ─────────────────────────────────────────────────────────


@pytest.fixture()
def db_engine(tmp_path):
    """Fresh SQLite database for each test."""
    db_file = tmp_path / "test.db"
    engine = create_engine(
        f"sqlite:///{db_file}",
        connect_args={"check_same_thread": False},
    )
    # Import all models so Base knows about them
    import data.models  # noqa: F401
    Base.metadata.create_all(bind=engine)
    yield engine
    Base.metadata.drop_all(bind=engine)
    engine.dispose()


@pytest.fixture()
def db_session(db_engine):
    """Single session shared across one test."""
    Session = sessionmaker(autocommit=False, autoflush=False, bind=db_engine)
    session = Session()
    yield session
    session.close()


@pytest.fixture()
def client(db_session):
    """TestClient with database overridden to the test session."""

    def _override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    with TestClient(app, raise_server_exceptions=True) as c:
        yield c
    app.dependency_overrides.clear()


@pytest.fixture()
def notebook(db_session):
    """A test user + notebook ready in the test database."""
    crud.get_or_create_user(db=db_session, user_id=1)
    nb = crud.create_notebook(db=db_session, owner_user_id=1, title="Test Notebook")
    return nb


# ── Quiz endpoint tests ───────────────────────────────────────────────────────


class TestQuizEndpoint:
    def test_generate_quiz_success(self, client, notebook):
        """POST quiz returns 200 with status=ready and content populated."""
        mock_result = {
            "questions": [
                {
                    "id": 1,
                    "question": "What is ML?",
                    "options": ["A) a", "B) b", "C) c", "D) d"],
                    "correct_answer": "A",
                    "explanation": "Because.",
                    "difficulty": "easy",
                    "topic": "ML",
                }
            ],
            "metadata": {"num_questions": 1},
        }

        with patch("app.QuizGenerator") as MockGen:
            mock_gen = MagicMock()
            mock_gen.generate_quiz.return_value = mock_result
            mock_gen.format_quiz_markdown.return_value = "# Quiz\n\n## Questions\n\n## Answer Key\n"
            mock_gen.save_quiz.return_value = "/tmp/quiz.md"
            MockGen.return_value = mock_gen

            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/quiz",
                json={"num_questions": 1, "difficulty": "easy"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["type"] == "quiz"
        assert data["status"] == "ready"
        assert data["content"] is not None
        assert "## Answer Key" in data["content"]
        assert data["file_path"] == "/tmp/quiz.md"

    def test_generate_quiz_no_content_fails(self, client, notebook):
        """POST quiz marks artifact as failed when generator reports no content."""
        with patch("app.QuizGenerator") as MockGen:
            mock_gen = MagicMock()
            mock_gen.generate_quiz.return_value = {
                "error": "No content found in notebook.",
                "questions": [],
                "metadata": {},
            }
            MockGen.return_value = mock_gen

            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/quiz",
                json={},
            )

        assert resp.status_code == 200
        assert resp.json()["status"] == "failed"
        assert resp.json()["error_message"] is not None

    def test_generate_quiz_unknown_notebook_404(self, client):
        """POST quiz for a non-existent notebook returns 404."""
        resp = client.post("/notebooks/9999/artifacts/quiz", json={})
        assert resp.status_code == 404

    def test_generate_quiz_stores_metadata(self, client, notebook):
        """Artifact metadata reflects the request parameters."""
        mock_result = {"questions": [], "metadata": {}}
        with patch("app.QuizGenerator") as MockGen:
            mock_gen = MagicMock()
            mock_gen.generate_quiz.return_value = {
                "error": "No content.",
                "questions": [],
                "metadata": {},
            }
            MockGen.return_value = mock_gen

            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/quiz",
                json={"num_questions": 7, "difficulty": "hard", "title": "Hard Quiz"},
            )

        data = resp.json()
        assert data["title"] == "Hard Quiz"
        assert data["metadata"]["num_questions"] == 7
        assert data["metadata"]["difficulty"] == "hard"


# ── Report endpoint tests ─────────────────────────────────────────────────────


class TestReportEndpoint:
    def test_generate_report_success(self, client, notebook):
        mock_result = {"content": "# Report\n\nGenerated report content.", "detail_level": "medium"}
        with patch("app.ReportGenerator") as MockGen:
            mock_gen = MagicMock()
            mock_gen.generate_report.return_value = mock_result
            mock_gen.save_report.return_value = "/tmp/report.md"
            MockGen.return_value = mock_gen

            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/report",
                json={"detail_level": "medium"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["type"] == "report"
        assert data["status"] == "ready"
        assert data["content"].startswith("# Report")
        assert data["file_path"] == "/tmp/report.md"

    def test_generate_report_no_content_fails(self, client, notebook):
        with patch("app.ReportGenerator") as MockGen:
            mock_gen = MagicMock()
            mock_gen.generate_report.return_value = {"error": "No content found in notebook."}
            MockGen.return_value = mock_gen

            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/report",
                json={"detail_level": "short"},
            )

        assert resp.status_code == 200
        assert resp.json()["status"] == "failed"
        assert "No content" in str(resp.json()["error_message"])

    def test_generate_report_invalid_detail_level_400(self, client, notebook):
        resp = client.post(
            f"/notebooks/{notebook.id}/artifacts/report",
            json={"detail_level": "super-long"},
        )
        assert resp.status_code == 400

    def test_generate_report_provider_config_error_returns_failed_artifact(self, client, notebook):
        """Provider init/runtime errors should not bubble as HTTP 500."""
        with patch("app.ReportGenerator", side_effect=ValueError("missing provider credentials")):
            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/report",
                json={"detail_level": "medium"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["type"] == "report"
        assert data["status"] == "failed"
        assert "missing provider credentials" in str(data.get("error_message"))

    def test_generate_report_unknown_notebook_404(self, client):
        resp = client.post("/notebooks/9999/artifacts/report", json={"detail_level": "medium"})
        assert resp.status_code == 404


# ── Podcast endpoint tests ────────────────────────────────────────────────────


class TestPodcastEndpoint:
    def test_generate_podcast_returns_pending(self, client, notebook):
        """POST podcast returns 200 with status=pending immediately."""
        with patch("app._run_podcast_background"):
            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/podcast",
                json={"duration": "5min"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["type"] == "podcast"
        assert data["status"] == "pending"
        assert data["metadata"]["duration"] == "5min"

    def test_generate_podcast_unknown_notebook_404(self, client):
        """POST podcast for a non-existent notebook returns 404."""
        resp = client.post("/notebooks/9999/artifacts/podcast", json={})
        assert resp.status_code == 404

    def test_generate_podcast_stores_topic(self, client, notebook):
        """topic_focus is persisted in artifact metadata."""
        with patch("app._run_podcast_background"):
            resp = client.post(
                f"/notebooks/{notebook.id}/artifacts/podcast",
                json={"duration": "10min", "topic_focus": "neural nets"},
            )

        data = resp.json()
        assert data["metadata"]["topic_focus"] == "neural nets"

    def test_background_podcast_failure_persists_transcript(self, notebook, db_session):
        """If audio fails but transcript exists, artifact should be failed with transcript content."""
        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="podcast",
            metadata={"duration": "5min"},
        )

        result_payload = {
            "error": "Transcript generated but audio synthesis failed for all segments.",
            "transcript": [{"speaker": "Alex", "text": "Intro text"}],
            "audio_path": None,
            "metadata": {"tts_provider": "elevenlabs"},
        }

        with patch("app.PodcastGenerator") as MockGen, patch("app.SessionLocal", return_value=db_session):
            mock_gen = MagicMock()
            mock_gen.generate_podcast.return_value = result_payload
            mock_gen.format_transcript_markdown.return_value = "# Podcast Transcript\n\n**Alex:** Intro text"
            mock_gen.save_transcript.return_value = "/tmp/transcript.md"
            MockGen.return_value = mock_gen

            app_module._run_podcast_background(
                artifact_id=artifact.id,
                user_id=1,
                notebook_id=notebook.id,
                duration="5min",
                topic_focus=None,
            )

        updated = crud.get_artifact(db_session, artifact.id)
        assert updated is not None
        assert updated.status == "failed"
        assert updated.content is not None
        assert "Podcast Transcript" in updated.content
        assert updated.error_message is not None
        assert "audio synthesis failed" in updated.error_message.lower()


# ── List artifacts tests ──────────────────────────────────────────────────────


class TestListArtifacts:
    def test_list_empty(self, client, notebook):
        """GET /artifacts returns empty list when none exist."""
        resp = client.get(f"/notebooks/{notebook.id}/artifacts")
        assert resp.status_code == 200
        assert resp.json() == []

    def test_list_returns_all(self, client, notebook, db_session):
        """GET /artifacts returns all artifacts for the notebook."""
        crud.create_artifact(db=db_session, notebook_id=notebook.id, artifact_type="quiz")
        crud.create_artifact(db=db_session, notebook_id=notebook.id, artifact_type="podcast")

        resp = client.get(f"/notebooks/{notebook.id}/artifacts")
        assert resp.status_code == 200
        assert len(resp.json()) == 2

    def test_list_filter_by_type(self, client, notebook, db_session):
        """GET /artifacts?artifact_type=quiz returns only quizzes."""
        crud.create_artifact(db=db_session, notebook_id=notebook.id, artifact_type="quiz")
        crud.create_artifact(db=db_session, notebook_id=notebook.id, artifact_type="podcast")

        resp = client.get(f"/notebooks/{notebook.id}/artifacts?artifact_type=quiz")
        data = resp.json()
        assert len(data) == 1
        assert data[0]["type"] == "quiz"

    def test_list_unknown_notebook_404(self, client):
        """GET /artifacts for unknown notebook returns 404."""
        resp = client.get("/notebooks/9999/artifacts")
        assert resp.status_code == 404


# ── Get single artifact tests ─────────────────────────────────────────────────


class TestGetArtifact:
    def test_get_existing_artifact(self, client, notebook, db_session):
        """GET /artifacts/{id} returns the correct artifact."""
        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="quiz",
            title="My Quiz",
        )

        resp = client.get(f"/notebooks/{notebook.id}/artifacts/{artifact.id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == artifact.id
        assert data["title"] == "My Quiz"
        assert data["type"] == "quiz"

    def test_get_nonexistent_artifact_404(self, client, notebook):
        """GET /artifacts/9999 returns 404."""
        resp = client.get(f"/notebooks/{notebook.id}/artifacts/9999")
        assert resp.status_code == 404

    def test_get_artifact_wrong_notebook_404(self, client, db_session):
        """GET /artifacts/{id} returns 404 if artifact belongs to a different notebook."""
        crud.get_or_create_user(db=db_session, user_id=1)
        nb1 = crud.create_notebook(db=db_session, owner_user_id=1, title="NB1")
        nb2 = crud.create_notebook(db=db_session, owner_user_id=1, title="NB2")
        artifact = crud.create_artifact(db=db_session, notebook_id=nb1.id, artifact_type="quiz")

        # Try to access artifact via the wrong notebook
        resp = client.get(f"/notebooks/{nb2.id}/artifacts/{artifact.id}")
        assert resp.status_code == 404


# ── Audio download tests ──────────────────────────────────────────────────────


class TestPodcastAudio:
    def test_audio_pending_returns_409(self, client, notebook, db_session):
        """GET /audio for a pending podcast returns 409."""
        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="podcast",
        )

        resp = client.get(f"/notebooks/{notebook.id}/artifacts/{artifact.id}/audio")
        assert resp.status_code == 409

    def test_audio_quiz_artifact_returns_400(self, client, notebook, db_session):
        """GET /audio for a quiz artifact returns 400 (wrong type)."""
        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="quiz",
        )
        crud.update_artifact(db_session, artifact.id, status="ready")

        resp = client.get(f"/notebooks/{notebook.id}/artifacts/{artifact.id}/audio")
        assert resp.status_code == 400

    def test_audio_ready_no_file_returns_404(self, client, notebook, db_session, tmp_path):
        """GET /audio returns 404 when audio file is missing from disk."""
        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="podcast",
        )
        crud.update_artifact(
            db_session,
            artifact.id,
            status="ready",
            file_path=str(tmp_path / "nonexistent.mp3"),
        )

        resp = client.get(f"/notebooks/{notebook.id}/artifacts/{artifact.id}/audio")
        assert resp.status_code == 404

    def test_audio_ready_with_file_returns_200(self, client, notebook, db_session, tmp_path):
        """GET /audio returns 200 and file content when audio exists."""
        audio_file = tmp_path / "podcast.mp3"
        audio_file.write_bytes(b"fake mp3 content")

        artifact = crud.create_artifact(
            db=db_session,
            notebook_id=notebook.id,
            artifact_type="podcast",
        )
        crud.update_artifact(
            db_session,
            artifact.id,
            status="ready",
            file_path=str(audio_file),
        )

        resp = client.get(f"/notebooks/{notebook.id}/artifacts/{artifact.id}/audio")
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("audio/mpeg")
        assert resp.content == b"fake mp3 content"