Spaces:
Running
Running
| """ | |
| Tests for db.py using a temporary in-memory/temp database. | |
| """ | |
| import os | |
| import tempfile | |
| import pytest | |
| import pytest_asyncio | |
| # Override DB_PATH before importing db | |
@pytest.fixture
def tmp_db(monkeypatch, tmp_path):
    """Redirect the application to a per-test SQLite file and return its path.

    The path ``<tmp_path>/test.db`` is written to the ``DB_PATH`` environment
    variable and to the ``DB_PATH`` attribute of both ``app.config`` and
    ``app.db``. All three overrides go through ``monkeypatch``, so they are
    reverted when the requesting test finishes.

    Returns:
        The temporary database path as a ``str``.
    """
    db_path = str(tmp_path / "test.db")
    monkeypatch.setenv("DB_PATH", db_path)

    # Also patch the module-level constants: the environment variable alone
    # does not change a value a module has already bound to its own name.
    import app.config as cfg
    monkeypatch.setattr(cfg, "DB_PATH", db_path)

    import app.db as db_mod
    monkeypatch.setattr(db_mod, "DB_PATH", db_path)

    return db_path
@pytest.mark.asyncio
async def test_init_db_creates_tables(tmp_db):
    """init_db() must create the three tables the application relies on."""
    import aiosqlite

    import app.db as db

    await db.init_db()

    # Read the schema through an independent connection to the same file.
    async with aiosqlite.connect(tmp_db) as conn:
        cur = await conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        tables = {r[0] for r in await cur.fetchall()}

    expected = {"interactions", "paper_qdrant_map", "paper_metadata"}
    missing = expected - tables
    assert not missing, f"tables not created: {sorted(missing)}"
@pytest.mark.asyncio
async def test_log_and_retrieve_interactions(tmp_db):
    """Every interaction logged for a user is returned by get_user_interactions()."""
    import app.db as db

    await db.init_db()
    await db.log_interaction("user1", "1706.03762", "save", source="search")
    await db.log_interaction("user1", "2302.11382", "not_interested", source="search")

    rows = await db.get_user_interactions("user1")

    assert len(rows) == 2
    assert {r["paper_id"] for r in rows} == {"1706.03762", "2302.11382"}
@pytest.mark.asyncio
async def test_filter_interactions_by_event_type(tmp_db):
    """The event_types argument restricts results to the listed event types."""
    import app.db as db

    await db.init_db()
    # One interaction per event type so the filter has something to exclude.
    for paper_id, event_type in (
        ("aaa", "save"),
        ("bbb", "not_interested"),
        ("ccc", "click"),
    ):
        await db.log_interaction("u2", paper_id, event_type)

    saves = await db.get_user_interactions("u2", event_types=["save"])

    assert len(saves) == 1
    assert saves[0]["paper_id"] == "aaa"
@pytest.mark.asyncio
async def test_qdrant_id_roundtrip(tmp_db):
    """A saved Qdrant id is read back unchanged; an unknown paper yields None."""
    import app.db as db

    await db.init_db()
    await db.save_qdrant_id("1706.03762", 42)

    stored = await db.get_qdrant_id("1706.03762")
    absent = await db.get_qdrant_id("unknown")

    assert stored == 42
    assert absent is None
@pytest.mark.asyncio
async def test_qdrant_id_batch(tmp_db):
    """Batch lookup returns only the ids that exist and omits unknown papers."""
    import app.db as db

    await db.init_db()
    known = {"a1": 10, "a2": 20}
    for paper_id, qdrant_id in known.items():
        await db.save_qdrant_id(paper_id, qdrant_id)

    # "a3" was never saved and must be absent from the mapping.
    result = await db.get_qdrant_ids_batch(["a1", "a2", "a3"])

    assert result == {"a1": 10, "a2": 20}
@pytest.mark.asyncio
async def test_metadata_cache_roundtrip(tmp_db):
    """Metadata written with cache_metadata() is readable via get_cached_metadata()."""
    import app.db as db

    await db.init_db()
    paper = {
        "arxiv_id": "1706.03762",
        "title": "Attention Is All You Need",
        "abstract": "The dominant sequence transduction models...",
        "authors": '["Vaswani", "Shazeer"]',
        "category": "cs.CL",
        "published": "2017-06-12",
    }

    await db.cache_metadata(paper)
    cached = await db.get_cached_metadata("1706.03762")

    assert cached is not None
    assert cached["title"] == "Attention Is All You Need"
    assert cached["category"] == "cs.CL"
@pytest.mark.asyncio
async def test_metadata_cache_batch(tmp_db):
    """Batch metadata lookup includes cached papers and leaves out unknown ids."""
    import app.db as db

    await db.init_db()
    # Seed paper0..paper2; paper1 is cached but deliberately not requested.
    for i in range(3):
        await db.cache_metadata({
            "arxiv_id": f"paper{i}",
            "title": f"Title {i}",
            "abstract": "...",
            "authors": "[]",
            "category": "cs.LG",
            "published": "2023-01-01",
        })

    result = await db.get_cached_metadata_batch(["paper0", "paper2", "paper99"])

    assert "paper0" in result
    assert "paper2" in result
    assert "paper99" not in result
| # ── Phase 4.3: cache_turso_metadata_batch ──────────────────────────────────── | |
@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_writes_all(tmp_db):
    """Turso dicts should be written to paper_metadata verbatim."""
    import app.db as db

    await db.init_db()
    attention = {
        "arxiv_id": "1706.03762",
        "title": "Attention Is All You Need",
        "abstract": "Transformers.",
        "authors": '["Vaswani"]',
        "category": "cs.CL",
        "published": "2017-06-12",
        "year": 2017,
        "citation_count": 50000,
    }
    # The second row omits citation_count, so the batch mixes key sets.
    another = {
        "arxiv_id": "2001.00001",
        "title": "Another Paper",
        "abstract": "...",
        "authors": '["Smith"]',
        "category": "cs.CV",
        "published": "2020-01-01",
        "year": 2020,
    }

    await db.cache_turso_metadata_batch([attention, another])

    cached = await db.get_cached_metadata("1706.03762")
    assert cached is not None
    assert cached["title"] == "Attention Is All You Need"
    assert cached["category"] == "cs.CL"

    cached2 = await db.get_cached_metadata("2001.00001")
    assert cached2 is not None
    assert cached2["category"] == "cs.CV"
@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_empty(tmp_db):
    """Empty input must not crash."""
    import app.db as db

    await db.init_db()

    # The test passes as long as this call does not raise.
    await db.cache_turso_metadata_batch([])
@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_skips_missing_arxiv_id(tmp_db):
    """Rows without arxiv_id should be skipped, others persisted."""
    import app.db as db

    await db.init_db()
    without_id = {"title": "No ID", "category": "cs.LG"}  # missing arxiv_id
    with_id = {
        "arxiv_id": "good.123",
        "title": "Good",
        "category": "cs.AI",
        "abstract": "",
        "authors": "[]",
        "published": "2024-01-01",
    }

    # The invalid row comes first, so it must not abort the rest of the batch.
    await db.cache_turso_metadata_batch([without_id, with_id])

    cached = await db.get_cached_metadata("good.123")
    assert cached is not None
    assert cached["title"] == "Good"
@pytest.mark.asyncio
async def test_cache_turso_metadata_batch_upserts(tmp_db):
    """Second write for same arxiv_id should overwrite the first."""
    import app.db as db

    await db.init_db()
    paper_v1 = {"arxiv_id": "p1", "title": "V1", "category": "cs.LG",
                "abstract": "", "authors": "[]", "published": "2024-01-01"}
    paper_v2 = {"arxiv_id": "p1", "title": "V2", "category": "cs.CV",
                "abstract": "", "authors": "[]", "published": "2024-01-01"}

    # Two separate batches, so the second write hits an existing row.
    await db.cache_turso_metadata_batch([paper_v1])
    await db.cache_turso_metadata_batch([paper_v2])

    cached = await db.get_cached_metadata("p1")
    # Guard first so a missing row fails as an assertion, not a TypeError.
    assert cached is not None
    assert cached["title"] == "V2"
    assert cached["category"] == "cs.CV"
| # ── Phase 4.3: get_suppressed_categories ────────────────────────────────────── | |
@pytest.mark.asyncio
async def test_suppressed_empty_for_new_user(tmp_db):
    """A user with no recorded interactions has an empty suppression set."""
    import app.db as db

    await db.init_db()

    suppressed = await db.get_suppressed_categories("never-dismissed")

    assert suppressed == set()
@pytest.mark.asyncio
async def test_suppressed_below_threshold_not_returned(tmp_db):
    """Two dismissals in one category (< threshold=3) should NOT suppress."""
    import app.db as db

    await db.init_db()

    # Seed metadata so each dismissed paper has a category.
    for i, aid in enumerate(["p1", "p2"]):
        await db.cache_metadata({
            "arxiv_id": aid, "title": f"t{i}", "abstract": "",
            "authors": "[]", "category": "cs.CV", "published": "2024-01-01",
        })

    # Two dismissals - below threshold=3
    await db.log_interaction("u1", "p1", "not_interested")
    await db.log_interaction("u1", "p2", "not_interested")

    result = await db.get_suppressed_categories("u1")
    assert "cs.CV" not in result
@pytest.mark.asyncio
async def test_suppressed_at_threshold_returned(tmp_db):
    """Three dismissals in same category should suppress that category."""
    import app.db as db

    await db.init_db()
    paper_ids = ["p1", "p2", "p3"]

    # Seed all metadata first, then record one dismissal per paper.
    for i, aid in enumerate(paper_ids):
        await db.cache_metadata({
            "arxiv_id": aid, "title": f"t{i}", "abstract": "",
            "authors": "[]", "category": "physics.optics", "published": "2024-01-01",
        })
    for aid in paper_ids:
        await db.log_interaction("u1", aid, "not_interested")

    result = await db.get_suppressed_categories("u1")
    assert "physics.optics" in result
@pytest.mark.asyncio
async def test_suppressed_only_counts_not_interested(tmp_db):
    """Saves should NOT count toward suppression."""
    import app.db as db

    await db.init_db()
    paper_ids = ["p1", "p2", "p3"]

    for aid in paper_ids:
        await db.cache_metadata({
            "arxiv_id": aid, "title": "t", "abstract": "",
            "authors": "[]", "category": "cs.CL", "published": "2024-01-01",
        })

    # 3 saves (not dismissals) in same category
    for aid in paper_ids:
        await db.log_interaction("u1", aid, "save")

    result = await db.get_suppressed_categories("u1")
    assert "cs.CL" not in result
@pytest.mark.asyncio
async def test_suppressed_partitions_categories(tmp_db):
    """Different categories should be independent."""
    import app.db as db

    await db.init_db()

    # 3 dismissals in cs.AI, 1 in cs.LG
    for aid in ["a1", "a2", "a3"]:
        await db.cache_metadata({
            "arxiv_id": aid, "title": "t", "abstract": "",
            "authors": "[]", "category": "cs.AI", "published": "2024-01-01",
        })
        await db.log_interaction("u1", aid, "not_interested")

    await db.cache_metadata({
        "arxiv_id": "lone", "title": "t", "abstract": "",
        "authors": "[]", "category": "cs.LG", "published": "2024-01-01",
    })
    await db.log_interaction("u1", "lone", "not_interested")

    result = await db.get_suppressed_categories("u1")
    assert "cs.AI" in result
    assert "cs.LG" not in result
@pytest.mark.asyncio
async def test_suppressed_ignores_other_users(tmp_db):
    """One user's dismissals must not affect another user's suppressions."""
    import app.db as db

    await db.init_db()

    # Only userA dismisses anything; userB never interacts.
    for aid in ["p1", "p2", "p3"]:
        await db.cache_metadata({
            "arxiv_id": aid, "title": "t", "abstract": "",
            "authors": "[]", "category": "cs.CV", "published": "2024-01-01",
        })
        await db.log_interaction("userA", aid, "not_interested")

    result_a = await db.get_suppressed_categories("userA")
    result_b = await db.get_suppressed_categories("userB")

    assert "cs.CV" in result_a
    assert result_b == set()
@pytest.mark.asyncio
async def test_suppressed_empty_category_excluded(tmp_db):
    """Papers with empty category string should not produce a '' suppression."""
    import app.db as db

    await db.init_db()

    # Three dismissals, all on papers whose category is the empty string.
    for aid in ["e1", "e2", "e3"]:
        await db.cache_metadata({
            "arxiv_id": aid, "title": "t", "abstract": "",
            "authors": "[]", "category": "", "published": "2024-01-01",
        })
        await db.log_interaction("u1", aid, "not_interested")

    result = await db.get_suppressed_categories("u1")
    assert "" not in result
@pytest.mark.asyncio
async def test_suppressed_custom_threshold(tmp_db):
    """Threshold=2 should trigger at 2 dismissals."""
    import app.db as db

    await db.init_db()

    for aid in ["x1", "x2"]:
        await db.cache_metadata({
            "arxiv_id": aid, "title": "t", "abstract": "",
            "authors": "[]", "category": "math.NT", "published": "2024-01-01",
        })
        await db.log_interaction("u1", aid, "not_interested")

    # The same two dismissals are evaluated against two different thresholds.
    result = await db.get_suppressed_categories("u1", threshold=2)
    assert "math.NT" in result

    result_high = await db.get_suppressed_categories("u1", threshold=5)
    assert "math.NT" not in result_high
| # ── Phase 4.5: Instrumentation columns ─────────────────────────────────────── | |
@pytest.mark.asyncio
async def test_instrumentation_columns_exist(tmp_db):
    """The interactions table should have ranker_version, candidate_source, cluster_id columns."""
    import aiosqlite

    import app.db as db

    await db.init_db()

    async with aiosqlite.connect(tmp_db) as conn:
        cur = await conn.execute("PRAGMA table_info(interactions)")
        # Index 1 of each table_info row is the column name.
        columns = {row[1] for row in await cur.fetchall()}

    assert "ranker_version" in columns
    assert "candidate_source" in columns
    assert "cluster_id" in columns
@pytest.mark.asyncio
async def test_log_interaction_stores_instrumentation_fields(tmp_db):
    """log_interaction should persist ranker_version, candidate_source, cluster_id."""
    import aiosqlite

    import app.db as db

    await db.init_db()
    await db.log_interaction(
        user_id="u1",
        paper_id="p1",
        event_type="save",
        source="recommendation",
        ranker_version="v4.1_test",
        candidate_source="cluster_0",
        cluster_id=0,
    )

    # Read the raw row back through a separate connection.
    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT ranker_version, candidate_source, cluster_id FROM interactions WHERE paper_id = 'p1'"
        )
        fetched = await cur.fetchone()

    # Guard first so a missing row fails as an assertion, not a TypeError.
    assert fetched is not None
    row = dict(fetched)
    assert row["ranker_version"] == "v4.1_test"
    assert row["candidate_source"] == "cluster_0"
    assert row["cluster_id"] == 0
@pytest.mark.asyncio
async def test_log_interaction_instrumentation_defaults_to_null(tmp_db):
    """Omitting instrumentation fields should store NULLs (backward compat)."""
    import aiosqlite

    import app.db as db

    await db.init_db()
    # No ranker_version / candidate_source / cluster_id arguments are passed.
    await db.log_interaction("u1", "p2", "save", source="search")

    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT ranker_version, candidate_source, cluster_id FROM interactions WHERE paper_id = 'p2'"
        )
        fetched = await cur.fetchone()

    # Guard first so a missing row fails as an assertion, not a TypeError.
    assert fetched is not None
    row = dict(fetched)
    assert row["ranker_version"] is None
    assert row["candidate_source"] is None
    assert row["cluster_id"] is None
@pytest.mark.asyncio
async def test_migration_idempotent(tmp_db):
    """Calling init_db() twice must not crash (ALTER TABLE migration is safe)."""
    import app.db as db

    await db.init_db()
    # Second call - migration should be idempotent. The test passes as long
    # as it does not raise.
    await db.init_db()
@pytest.mark.asyncio
async def test_instrumentation_exploration_tag(tmp_db):
    """Exploration papers should be stored with candidate_source='exploration'."""
    import aiosqlite

    import app.db as db

    await db.init_db()
    await db.log_interaction(
        user_id="u1",
        paper_id="explore_paper",
        event_type="save",
        source="recommendation",
        ranker_version="v4.1_quota_hungarian_suppression",
        candidate_source="exploration",
        cluster_id=None,  # passed explicitly as None for this interaction
    )

    async with aiosqlite.connect(tmp_db) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT candidate_source, cluster_id FROM interactions WHERE paper_id = 'explore_paper'"
        )
        fetched = await cur.fetchone()

    # Guard first so a missing row fails as an assertion, not a TypeError.
    assert fetched is not None
    row = dict(fetched)
    assert row["candidate_source"] == "exploration"
    assert row["cluster_id"] is None