File size: 2,873 Bytes
8fa7af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Tests for the source-result -> chunk -> top-k retrieval pipeline."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import research.retrieval as retrieval
from research.retrieval import chunk_markdown, rank_chunks_for_query
from research.types import ResearchChunk


class FakeTextEmbedding:
    """Deterministic stand-in for the embedding model used by the retrieval tests.

    Texts containing "backpropagation", "chain rule" or "target"
    (case-insensitive) map to ``[1.0, 0.0]``; every other text maps to
    ``[0.0, 1.0]``. Every text passed to ``embed`` is appended to
    ``seen_texts`` so tests can inspect what was embedded.
    """

    # Substrings that select the "relevant" embedding direction.
    _RELEVANT_TERMS = ("backpropagation", "chain rule", "target")

    def __init__(self):
        # Every text handed to embed(), in call order.
        self.seen_texts = []

    def embed(self, texts):
        """Yield one 2-d vector per input text.

        This is a generator: ``seen_texts`` is only updated once the result
        starts being iterated.

        Args:
            texts: a single string, or any iterable of strings (including a
                one-shot generator).

        Yields:
            ``[1.0, 0.0]`` for texts containing a relevant term, otherwise
            ``[0.0, 1.0]``.
        """
        # Materialise once: iterating `texts` twice (record, then embed) would
        # exhaust a generator, and a bare string is one document, not a
        # sequence of characters.
        batch = [texts] if isinstance(texts, str) else list(texts)
        self.seen_texts.extend(batch)
        for text in batch:
            lower = text.lower()
            if any(term in lower for term in self._RELEVANT_TERMS):
                yield [1.0, 0.0]
            else:
                yield [0.0, 1.0]


def _chunk(title: str, text: str) -> ResearchChunk:
    """Build a ResearchChunk with fixed test provenance and the given title/text."""
    provenance = dict(
        source="test",
        tool="fetch_docs",
        url="https://example.test",
    )
    return ResearchChunk(title=title, text=text, **provenance)


def test_chunking_then_top5_ranking():
    """Chunk a small markdown doc, pad it with filler, and check the ranking.

    Without passing ``top_k``, the ranker is expected to return five chunks
    ranked 1..5 with the "Chain Rule" section first.
    """
    markdown = """
# Intro
general overview
# Chain Rule
backpropagation chain rule gradients neural network
# History
unrelated history
"""
    sections = chunk_markdown(markdown, "Fallback")
    section_chunks = [_chunk(heading, body) for heading, body in sections]
    filler_chunks = [_chunk(f"Filler {n}", f"unrelated filler {n}") for n in range(8)]
    candidates = section_chunks + filler_chunks

    model = FakeTextEmbedding()
    top = rank_chunks_for_query(
        "backpropagation",
        "chain rule gradients",
        candidates,
        embedding_model=model,
    )

    assert len(top) == 5
    assert [c.rank for c in top] == list(range(1, 6))
    assert top[0].title == "Chain Rule"


def test_embedding_ranking_is_not_bm25():
    """A semantically matching chunk must outrank one that only repeats query words."""
    model = FakeTextEmbedding()
    candidates = [
        _chunk("Lexical Match", "query repeated query repeated lexical only"),
        _chunk("Embedding Match", "target concept with less lexical overlap"),
        _chunk("Other", "unrelated content"),
    ]

    top_two = rank_chunks_for_query(
        "query", "intent target", candidates, top_k=2, embedding_model=model
    )

    assert len(top_two) == 2
    winner = top_two[0]
    assert winner.title == "Embedding Match"


def test_preload_embedding_model_warms_runtime():
    """preload_embedding_model() should embed exactly one warmup text via the module model."""
    saved = retrieval._EMBEDDING_MODEL
    stub = FakeTextEmbedding()
    retrieval._EMBEDDING_MODEL = stub
    try:
        retrieval.preload_embedding_model()
    finally:
        # Always restore the module-level model so other tests are unaffected.
        retrieval._EMBEDDING_MODEL = saved
    assert stub.seen_texts == ["startup warmup"]


if __name__ == "__main__":
    # Minimal runner so the file can be executed directly without pytest.
    # Exits non-zero when any test fails so shells/CI can detect failures.
    import traceback

    tests = [
        test_chunking_then_top5_ranking,
        test_embedding_ranking_is_not_bm25,
        test_preload_embedding_model_warms_runtime,
    ]
    failed = []
    for test in tests:
        try:
            test()
        except Exception as exc:
            failed.append(test.__name__)
            # repr(): a bare `assert` raises AssertionError() whose str() is empty.
            print(f"FAIL: {test.__name__}: {exc!r}")
            traceback.print_exc()
    passed = len(tests) - len(failed)
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_retrieval ({passed}/{len(tests)})")
    if failed:
        sys.exit(1)