File size: 15,116 Bytes
d5a6f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d5f0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
"""
Tests for db.py, run against a temporary on-disk SQLite database file (a fresh one per test).
"""
import os
import tempfile
import pytest
import pytest_asyncio

# Override DB_PATH before importing db
@pytest.fixture(autouse=True)
def tmp_db(monkeypatch, tmp_path):
    """Redirect the application to a throwaway database file for each test.

    The path ``<tmp_path>/test.db`` is exported through the ``DB_PATH``
    environment variable and is also written onto the ``DB_PATH`` attribute of
    both ``app.config`` and ``app.db``. The environment variable is set before
    either module is imported here, and that ordering is kept deliberately.

    Returns:
        The database file path as a string.
    """
    database_file = str(tmp_path / "test.db")
    monkeypatch.setenv("DB_PATH", database_file)

    # Patch the already-evaluated module-level constants as well.
    import app.config as config_module
    monkeypatch.setattr(config_module, "DB_PATH", database_file)
    import app.db as db_module
    monkeypatch.setattr(db_module, "DB_PATH", database_file)

    return database_file


@pytest.mark.asyncio
async def test_init_db_creates_tables(tmp_db):
    """init_db() should create the three tables this suite relies on."""
    import aiosqlite
    import app.db as db

    await db.init_db()

    # Read the schema back through an independent connection to the same file.
    async with aiosqlite.connect(tmp_db) as connection:
        cursor = await connection.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        schema_rows = await cursor.fetchall()

    table_names = {schema_row[0] for schema_row in schema_rows}
    for expected_table in ("interactions", "paper_qdrant_map", "paper_metadata"):
        assert expected_table in table_names


@pytest.mark.asyncio
async def test_log_and_retrieve_interactions(tmp_db):
    """Both interactions logged for a user should come back from the lookup."""
    import app.db as db

    await db.init_db()

    logged_events = (
        ("1706.03762", "save"),
        ("2302.11382", "not_interested"),
    )
    for paper_id, event_type in logged_events:
        await db.log_interaction("user1", paper_id, event_type, source="search")

    interactions = await db.get_user_interactions("user1")
    assert len(interactions) == 2

    returned_ids = {entry["paper_id"] for entry in interactions}
    assert {"1706.03762", "2302.11382"} <= returned_ids


@pytest.mark.asyncio
async def test_filter_interactions_by_event_type(tmp_db):
    """Filtering by event_types=["save"] should return only the saved paper."""
    import app.db as db

    await db.init_db()

    # One interaction per event type, all for the same user.
    for paper_id, event_type in (
        ("aaa", "save"),
        ("bbb", "not_interested"),
        ("ccc", "click"),
    ):
        await db.log_interaction("u2", paper_id, event_type)

    saved_only = await db.get_user_interactions("u2", event_types=["save"])
    assert len(saved_only) == 1
    first_saved = saved_only[0]
    assert first_saved["paper_id"] == "aaa"


@pytest.mark.asyncio
async def test_qdrant_id_roundtrip(tmp_db):
    """A saved Qdrant id is returned for its paper; unknown papers give None."""
    import app.db as db

    await db.init_db()
    await db.save_qdrant_id("1706.03762", 42)

    stored_id = await db.get_qdrant_id("1706.03762")
    assert stored_id == 42

    unknown_id = await db.get_qdrant_id("unknown")
    assert unknown_id is None


@pytest.mark.asyncio
async def test_qdrant_id_batch(tmp_db):
    """The batch lookup should return only the ids that were actually saved."""
    import app.db as db

    await db.init_db()

    saved_mapping = {"a1": 10, "a2": 20}
    for arxiv_id, qdrant_id in saved_mapping.items():
        await db.save_qdrant_id(arxiv_id, qdrant_id)

    # "a3" was never saved, so it must be absent from the returned mapping.
    looked_up = await db.get_qdrant_ids_batch(["a1", "a2", "a3"])
    assert looked_up == {"a1": 10, "a2": 20}


@pytest.mark.asyncio
async def test_metadata_cache_roundtrip(tmp_db):
    """Metadata written with cache_metadata() should be readable by arXiv id."""
    import app.db as db

    await db.init_db()

    attention_paper = {
        "arxiv_id": "1706.03762",
        "title": "Attention Is All You Need",
        "abstract": "The dominant sequence transduction models...",
        "authors": '["Vaswani", "Shazeer"]',
        "category": "cs.CL",
        "published": "2017-06-12",
    }
    await db.cache_metadata(attention_paper)

    fetched = await db.get_cached_metadata("1706.03762")
    assert fetched is not None

    expected_fields = (
        ("title", "Attention Is All You Need"),
        ("category", "cs.CL"),
    )
    for field_name, expected_value in expected_fields:
        assert fetched[field_name] == expected_value


@pytest.mark.asyncio
async def test_metadata_cache_batch(tmp_db):
    """The batch lookup should include cached ids and omit unknown ones."""
    import app.db as db

    await db.init_db()

    generated_papers = [
        {
            "arxiv_id": f"paper{index}",
            "title": f"Title {index}",
            "abstract": "...",
            "authors": "[]",
            "category": "cs.LG",
            "published": "2023-01-01",
        }
        for index in range(3)
    ]
    for generated in generated_papers:
        await db.cache_metadata(generated)

    found = await db.get_cached_metadata_batch(["paper0", "paper2", "paper99"])
    for cached_id in ("paper0", "paper2"):
        assert cached_id in found
    # "paper99" was never cached.
    assert "paper99" not in found


# ── Phase 4.3: cache_turso_metadata_batch ────────────────────────────────────

@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_writes_all(tmp_db):
    """Every dict passed to cache_turso_metadata_batch() should be readable afterwards."""
    import app.db as db

    await db.init_db()

    transformer_paper = {
        "arxiv_id": "1706.03762",
        "title": "Attention Is All You Need",
        "abstract": "Transformers.",
        "authors": '["Vaswani"]',
        "category": "cs.CL",
        "published": "2017-06-12",
        "year": 2017,
        "citation_count": 50000,
    }
    # This second record intentionally has no "citation_count" key.
    vision_paper = {
        "arxiv_id": "2001.00001",
        "title": "Another Paper",
        "abstract": "...",
        "authors": '["Smith"]',
        "category": "cs.CV",
        "published": "2020-01-01",
        "year": 2020,
    }
    await db.cache_turso_metadata_batch([transformer_paper, vision_paper])

    first_cached = await db.get_cached_metadata("1706.03762")
    assert first_cached is not None
    assert first_cached["title"] == "Attention Is All You Need"
    assert first_cached["category"] == "cs.CL"

    second_cached = await db.get_cached_metadata("2001.00001")
    assert second_cached is not None
    assert second_cached["category"] == "cs.CV"


@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_empty(tmp_db):
    """Passing an empty list must complete without raising."""
    import app.db as db

    await db.init_db()

    nothing_to_cache = []
    # The test passes as long as this call does not raise.
    await db.cache_turso_metadata_batch(nothing_to_cache)


@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_skips_missing_arxiv_id(tmp_db):
    """A row lacking arxiv_id must not stop the valid row from being stored."""
    import app.db as db

    await db.init_db()

    # Deliberately has no "arxiv_id" key.
    row_without_id = {"title": "No ID", "category": "cs.LG"}
    valid_row = {"arxiv_id": "good.123", "title": "Good", "category": "cs.AI",
                 "abstract": "", "authors": "[]", "published": "2024-01-01"}

    await db.cache_turso_metadata_batch([row_without_id, valid_row])

    stored = await db.get_cached_metadata("good.123")
    assert stored is not None
    assert stored["title"] == "Good"


@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_upserts(tmp_db):
    """Second write for same arxiv_id should overwrite the first."""
    import app.db as db
    await db.init_db()
    paper_v1 = {"arxiv_id": "p1", "title": "V1", "category": "cs.LG",
                "abstract": "", "authors": "[]", "published": "2024-01-01"}
    paper_v2 = {"arxiv_id": "p1", "title": "V2", "category": "cs.CV",
                "abstract": "", "authors": "[]", "published": "2024-01-01"}
    # Two separate batch calls so the second one hits an existing row.
    await db.cache_turso_metadata_batch([paper_v1])
    await db.cache_turso_metadata_batch([paper_v2])
    cached = await db.get_cached_metadata("p1")
    # Guard first, as the sibling tests do, so a missing row is reported as a
    # clear assertion failure rather than a TypeError from subscripting None.
    assert cached is not None
    assert cached["title"] == "V2"
    assert cached["category"] == "cs.CV"


# ── Phase 4.3: get_suppressed_categories ──────────────────────────────────────

@pytest.mark.asyncio
async def test_suppressed_empty_for_new_user(tmp_db):
    """A user with no logged interactions should get an empty set back."""
    import app.db as db

    await db.init_db()

    suppressed = await db.get_suppressed_categories("never-dismissed")
    assert suppressed == set()


@pytest.mark.asyncio
async def test_suppressed_below_threshold_not_returned(tmp_db):
    """Two dismissals in one category must not suppress it.

    The lookup uses the default threshold of get_suppressed_categories(),
    which this suite expects to be 3.
    """
    import app.db as db

    await db.init_db()
    dismissed_ids = ("p1", "p2")

    # Cache metadata first so each dismissed paper has a category on record.
    for index, arxiv_id in enumerate(dismissed_ids):
        await db.cache_metadata({
            "arxiv_id": arxiv_id, "title": f"t{index}", "abstract": "",
            "authors": "[]", "category": "cs.CV", "published": "2024-01-01",
        })

    # Only two dismissals - one short of the expected threshold of 3.
    for arxiv_id in dismissed_ids:
        await db.log_interaction("u1", arxiv_id, "not_interested")

    suppressed = await db.get_suppressed_categories("u1")
    assert "cs.CV" not in suppressed


@pytest.mark.asyncio
async def test_suppressed_at_threshold_returned(tmp_db):
    """Three dismissals in the same category should suppress that category."""
    import app.db as db

    await db.init_db()
    dismissed_ids = ("p1", "p2", "p3")

    # All three papers share one category.
    for index, arxiv_id in enumerate(dismissed_ids):
        await db.cache_metadata({
            "arxiv_id": arxiv_id, "title": f"t{index}", "abstract": "",
            "authors": "[]", "category": "physics.optics", "published": "2024-01-01",
        })

    for arxiv_id in dismissed_ids:
        await db.log_interaction("u1", arxiv_id, "not_interested")

    suppressed = await db.get_suppressed_categories("u1")
    assert "physics.optics" in suppressed


@pytest.mark.asyncio
async def test_suppressed_only_counts_not_interested(tmp_db):
    """"save" events must not contribute to category suppression."""
    import app.db as db

    await db.init_db()
    saved_ids = ("p1", "p2", "p3")

    for arxiv_id in saved_ids:
        await db.cache_metadata({
            "arxiv_id": arxiv_id, "title": "t", "abstract": "",
            "authors": "[]", "category": "cs.CL", "published": "2024-01-01",
        })

    # Three events in one category, but they are saves rather than dismissals.
    for arxiv_id in saved_ids:
        await db.log_interaction("u1", arxiv_id, "save")

    suppressed = await db.get_suppressed_categories("u1")
    assert "cs.CL" not in suppressed


@pytest.mark.asyncio
async def test_suppressed_partitions_categories(tmp_db):
    """Dismissal counts are tracked per category, independently of each other."""
    import app.db as db

    def metadata_for(arxiv_id, category):
        # Build a minimal metadata record; only the id and category vary.
        return {
            "arxiv_id": arxiv_id, "title": "t", "abstract": "",
            "authors": "[]", "category": category, "published": "2024-01-01",
        }

    await db.init_db()

    # Three dismissed papers in cs.AI ...
    for arxiv_id in ("a1", "a2", "a3"):
        await db.cache_metadata(metadata_for(arxiv_id, "cs.AI"))
        await db.log_interaction("u1", arxiv_id, "not_interested")

    # ... and a single dismissed paper in cs.LG.
    await db.cache_metadata(metadata_for("lone", "cs.LG"))
    await db.log_interaction("u1", "lone", "not_interested")

    suppressed = await db.get_suppressed_categories("u1")
    assert "cs.AI" in suppressed
    assert "cs.LG" not in suppressed


@pytest.mark.asyncio
async def test_suppressed_ignores_other_users(tmp_db):
    """Dismissals logged by one user must not leak into another user's result."""
    import app.db as db

    await db.init_db()

    # Only userA dismisses anything; userB never interacts.
    for arxiv_id in ("p1", "p2", "p3"):
        record = {
            "arxiv_id": arxiv_id, "title": "t", "abstract": "",
            "authors": "[]", "category": "cs.CV", "published": "2024-01-01",
        }
        await db.cache_metadata(record)
        await db.log_interaction("userA", arxiv_id, "not_interested")

    suppressed_for_a = await db.get_suppressed_categories("userA")
    suppressed_for_b = await db.get_suppressed_categories("userB")
    assert "cs.CV" in suppressed_for_a
    assert suppressed_for_b == set()


@pytest.mark.asyncio
async def test_suppressed_empty_category_excluded(tmp_db):
    """Dismissed papers whose category is "" must not yield a "" suppression."""
    import app.db as db

    await db.init_db()

    for arxiv_id in ("e1", "e2", "e3"):
        uncategorised = {
            "arxiv_id": arxiv_id, "title": "t", "abstract": "",
            "authors": "[]", "category": "", "published": "2024-01-01",
        }
        await db.cache_metadata(uncategorised)
        await db.log_interaction("u1", arxiv_id, "not_interested")

    suppressed = await db.get_suppressed_categories("u1")
    assert "" not in suppressed


@pytest.mark.asyncio
async def test_suppressed_custom_threshold(tmp_db):
    """An explicit threshold decides whether two dismissals are enough."""
    import app.db as db

    await db.init_db()

    for arxiv_id in ("x1", "x2"):
        number_theory_paper = {
            "arxiv_id": arxiv_id, "title": "t", "abstract": "",
            "authors": "[]", "category": "math.NT", "published": "2024-01-01",
        }
        await db.cache_metadata(number_theory_paper)
        await db.log_interaction("u1", arxiv_id, "not_interested")

    # Two dismissals reach a threshold of 2 ...
    with_low_threshold = await db.get_suppressed_categories("u1", threshold=2)
    assert "math.NT" in with_low_threshold

    # ... but fall short of a threshold of 5.
    with_high_threshold = await db.get_suppressed_categories("u1", threshold=5)
    assert "math.NT" not in with_high_threshold


# ── Phase 4.5: Instrumentation columns ───────────────────────────────────────

@pytest.mark.asyncio
async def test_instrumentation_columns_exist(tmp_db):
    """After init_db(), interactions must expose the three instrumentation columns."""
    import aiosqlite
    import app.db as db

    await db.init_db()

    async with aiosqlite.connect(tmp_db) as connection:
        cursor = await connection.execute("PRAGMA table_info(interactions)")
        column_rows = await cursor.fetchall()

    # Index 1 of each PRAGMA table_info row is the column name.
    column_names = {column_row[1] for column_row in column_rows}
    for expected_column in ("ranker_version", "candidate_source", "cluster_id"):
        assert expected_column in column_names


@pytest.mark.asyncio
async def test_log_interaction_stores_instrumentation_fields(tmp_db):
    """log_interaction should persist ranker_version, candidate_source, cluster_id."""
    import app.db as db
    import aiosqlite
    await db.init_db()
    await db.log_interaction(
        user_id="u1",
        paper_id="p1",
        event_type="save",
        source="recommendation",
        ranker_version="v4.1_test",
        candidate_source="cluster_0",
        cluster_id=0,
    )
    # Read the row back directly so the stored column values are checked,
    # not whatever the app-level read helpers choose to return.
    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT ranker_version, candidate_source, cluster_id FROM interactions WHERE paper_id = 'p1'"
        )
        fetched = await cur.fetchone()
    # fetchone() yields None when nothing matched; fail with a clear message
    # instead of letting dict(None) raise an unrelated TypeError.
    assert fetched is not None, "no interactions row was stored for paper_id 'p1'"
    row = dict(fetched)
    assert row["ranker_version"] == "v4.1_test"
    assert row["candidate_source"] == "cluster_0"
    assert row["cluster_id"] == 0


@pytest.mark.asyncio
async def test_log_interaction_instrumentation_defaults_to_null(tmp_db):
    """Omitting instrumentation fields should store NULLs (backward compat)."""
    import app.db as db
    import aiosqlite
    await db.init_db()
    # Called without ranker_version, candidate_source or cluster_id.
    await db.log_interaction("u1", "p2", "save", source="search")
    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT ranker_version, candidate_source, cluster_id FROM interactions WHERE paper_id = 'p2'"
        )
        fetched = await cur.fetchone()
    # Without this guard a missing row would surface as a TypeError from
    # dict(None), which hides the real problem (nothing was written).
    assert fetched is not None, "no interactions row was stored for paper_id 'p2'"
    row = dict(fetched)
    assert row["ranker_version"] is None
    assert row["candidate_source"] is None
    assert row["cluster_id"] is None


@pytest.mark.asyncio
async def test_migration_idempotent(tmp_db):
    """Running init_db() a second time on the same database must not raise."""
    import app.db as db

    # First pass creates the schema; the second pass re-runs against it and
    # must not raise. The test passes if both calls complete.
    for _attempt in range(2):
        await db.init_db()


@pytest.mark.asyncio
async def test_instrumentation_exploration_tag(tmp_db):
    """Exploration papers should be stored with candidate_source='exploration'."""
    import app.db as db
    import aiosqlite
    await db.init_db()
    # cluster_id is passed as None explicitly for this exploration interaction.
    await db.log_interaction(
        user_id="u1",
        paper_id="explore_paper",
        event_type="save",
        source="recommendation",
        ranker_version="v4.1_quota_hungarian_suppression",
        candidate_source="exploration",
        cluster_id=None,
    )
    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT candidate_source, cluster_id FROM interactions WHERE paper_id = 'explore_paper'"
        )
        fetched = await cur.fetchone()
    # Report a missing row as an assertion failure rather than the TypeError
    # that dict(None) would raise.
    assert fetched is not None, "no interactions row was stored for paper_id 'explore_paper'"
    row = dict(fetched)
    assert row["candidate_source"] == "exploration"
    assert row["cluster_id"] is None