File size: 3,624 Bytes
31a2688
 
9612292
31a2688
 
 
 
 
9612292
31a2688
 
 
 
 
 
 
 
 
 
9612292
 
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9612292
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9612292
31a2688
 
 
 
 
 
 
9612292
 
31a2688
9612292
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Tests for the cross-encoder reranker."""

from unittest.mock import MagicMock

import numpy as np
import pytest

from src.models import DocumentChunk, QueryResult
from src.retrieval.reranker import Reranker, _sigmoid


def _make_result(text: str, score: float) -> QueryResult:
    """Build a QueryResult for tests from a chunk text and a score.

    The wrapped chunk always uses the fixed ids ``"c1"`` / ``"d1"`` and the
    result is tagged with source ``"test"``; only *text* and *score* vary.
    """
    return QueryResult(
        chunk=DocumentChunk(chunk_id="c1", document_id="d1", text=text),
        score=score,
        source="test",
    )


@pytest.fixture
def reranker() -> Reranker:
    """Provide a Reranker whose underlying model is a MagicMock."""
    # A fresh mock per test keeps call records from leaking between tests.
    return Reranker(model=MagicMock())


class TestRerank:
    """Tests for Reranker.rerank.

    Where the model is expected to be called, ``reranker._model.predict`` is
    replaced with a mock returning fixed raw scores, so both the expected
    ordering and the sigmoid-normalized output scores are known in advance.
    """

    def test_rerank_reorders_by_cross_encoder_score(self, reranker: Reranker) -> None:
        """Reranked order should follow cross-encoder scores, not original order."""
        results = [
            _make_result("low relevance", score=0.9),
            _make_result("high relevance", score=0.1),
            _make_result("mid relevance", score=0.5),
        ]
        # Cross-encoder assigns: low->0.1, high->0.9, mid->0.5
        reranker._model.predict = MagicMock(return_value=np.array([0.1, 0.9, 0.5]))

        reranked = reranker.rerank("test query", results, top_k=3)

        assert len(reranked) == 3
        assert [r.chunk.text for r in reranked] == [
            "high relevance",
            "mid relevance",
            "low relevance",
        ]
        assert all(r.source == "reranked" for r in reranked)
        # The raw scores above already lie in [0, 1], so a range check alone
        # would pass even if normalization were skipped; pin the exact
        # sigmoid values, best score first.
        for item, raw in zip(reranked, (0.9, 0.5, 0.1)):
            assert item.score == pytest.approx(_sigmoid(raw))
        assert all(0.0 <= r.score <= 1.0 for r in reranked)

    def test_rerank_respects_top_k(self, reranker: Reranker) -> None:
        """Only top_k results should be returned."""
        results = [_make_result(f"doc{i}", score=0.5) for i in range(5)]
        reranker._model.predict = MagicMock(return_value=np.array([0.1, 0.5, 0.9, 0.3, 0.7]))

        reranked = reranker.rerank("query", results, top_k=2)

        # doc2 (0.9) and doc4 (0.7) are the two highest-scored inputs.
        assert len(reranked) == 2
        assert reranked[0].chunk.text == "doc2"
        assert reranked[1].chunk.text == "doc4"
        assert reranked[0].score >= reranked[1].score

    def test_rerank_empty_list(self, reranker: Reranker) -> None:
        """Empty input should return empty list without calling the model."""
        reranked = reranker.rerank("query", [], top_k=5)

        assert reranked == []
        reranker._model.predict.assert_not_called()

    def test_rerank_single_result(self, reranker: Reranker) -> None:
        """A single result should be returned with sigmoid-normalized score."""
        results = [_make_result("only doc", score=0.3)]
        reranker._model.predict = MagicMock(return_value=np.array([0.85]))

        reranked = reranker.rerank("query", results, top_k=1)

        assert len(reranked) == 1
        assert reranked[0].chunk.text == "only doc"
        assert reranked[0].score == pytest.approx(_sigmoid(0.85))
        assert 0.0 <= reranked[0].score <= 1.0
        assert reranked[0].source == "reranked"

    def test_rerank_negative_scores_normalized(self, reranker: Reranker) -> None:
        """Negative cross-encoder scores must be normalized to [0, 1]."""
        results = [
            _make_result("bad", score=0.5),
            _make_result("worse", score=0.5),
        ]
        reranker._model.predict = MagicMock(return_value=np.array([-2.0, -5.0]))

        reranked = reranker.rerank("query", results, top_k=2)

        # -2.0 outranks -5.0, so "bad" must come before "worse".
        assert [r.chunk.text for r in reranked] == ["bad", "worse"]
        assert all(0.0 <= r.score <= 1.0 for r in reranked)
        assert reranked[0].score == pytest.approx(_sigmoid(-2.0))
        assert reranked[1].score == pytest.approx(_sigmoid(-5.0))