File size: 6,411 Bytes
b689b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee5d4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b689b3f
 
 
 
 
 
 
 
 
 
 
 
3c72c9d
 
 
 
 
 
 
 
 
 
 
b689b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c72c9d
b689b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c72c9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# tests/test_unit.py
import pytest

# ── RRF logic ─────────────────────────────────────────────────────────────────

def test_rrf_prefers_doc_appearing_in_both_lists():
    """A doc ranked first in one list should outscore a doc never ranked first.

    Both docs 1 and 2 appear in both rankings; what distinguishes them is that
    doc 2 holds a top position in one of them.
    """
    from retriever import _reciprocal_rank_fusion
    scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 0, 1]])
    # doc 2 is rank-2 in the first list and rank-0 in the second;
    # doc 1 is rank-1 in the first list and rank-2 in the second.
    # Assuming standard RRF (sum of 1/(k + rank) over lists, k > 0), the shared
    # rank-2 terms cancel and 1/(k + 0) > 1/(k + 1), so doc 2 must win.
    assert scores[2] > scores[1]

def test_rrf_returns_all_docs():
    """Every doc id present in any input ranking must receive a fused score."""
    from retriever import _reciprocal_rank_fusion
    rankings = [[0, 1], [1, 2]]
    fused = _reciprocal_rank_fusion(rankings)
    scored_ids = set(fused.keys())
    assert scored_ids == {0, 1, 2}

def test_rrf_scores_are_positive():
    """Fusing even a single ranking must yield strictly positive scores."""
    from retriever import _reciprocal_rank_fusion
    fused = _reciprocal_rank_fusion([[0, 1, 2]])
    for score in fused.values():
        assert score > 0

# ── Config sanity ─────────────────────────────────────────────────────────────

def test_config_values_are_sane():
    """Guard against nonsensical chunking / retrieval / retry constants in config."""
    from config import CHUNK_OVERLAP, CHUNK_SIZE, MAX_RETRIES, TOP_K

    # Overlap must leave room for new content in each chunk.
    assert CHUNK_OVERLAP < CHUNK_SIZE, "overlap must be smaller than chunk size"
    assert 0 < TOP_K, "TOP_K must be positive"
    assert 1 <= MAX_RETRIES, "need at least 1 retry"

def test_groq_api_key_present(monkeypatch):
    """config exposes GROQ_API_KEY read from the environment when (re)loaded."""
    fake_key = "gsk_fakekeyfortesting1234567890"
    # Use a fake key so CI needs no real credentials. It is set BEFORE config is
    # imported, in case this is the first import and config reads the key at
    # import time.
    monkeypatch.setenv("GROQ_API_KEY", fake_key)

    import importlib
    import config

    # Re-execute the module so it picks up the patched environment.
    importlib.reload(config)
    # NOTE(review): monkeypatch restores the env var afterwards, but the
    # reloaded config module keeps the fake key for later tests.
    assert len(config.GROQ_API_KEY) > 10

# ── Agent routing logic ───────────────────────────────────────────────────────

def test_route_returns_done_on_pass():
    """A PASS verdict on the first attempt routes straight to 'done'."""
    from agent import route_after_validation
    passing_state = dict(validation_result="PASS", retry_count=0)
    next_step = route_after_validation(passing_state)
    assert next_step == "done"

def test_route_returns_retry_on_fail_within_limit():
    """A FAIL verdict with no retries used yet routes to 'retry'."""
    from agent import route_after_validation
    failing_state = dict(validation_result="FAIL", retry_count=0)
    next_step = route_after_validation(failing_state)
    assert next_step == "retry"

def test_route_returns_done_when_retries_exhausted():
    """A FAIL verdict after 3 retries stops the loop with 'done'."""
    from agent import route_after_validation
    # NOTE(review): 3 presumably equals (or exceeds) the agent's retry limit;
    # this test will need updating if that limit is raised.
    exhausted_state = dict(validation_result="FAIL", retry_count=3)
    next_step = route_after_validation(exhausted_state)
    assert next_step == "done"

def test_increment_retry_node():
    """increment_retry_node returns a retry_count one higher than its input."""
    from agent import increment_retry_node
    starting_count = 1
    updated = increment_retry_node({"retry_count": starting_count})
    assert updated["retry_count"] == starting_count + 1

def test_parse_validation_score_accepts_score_out_of_100():
    """An 'N/100' validator reply parses to the integer N."""
    from agent import _parse_validation_score
    # Second argument is presumably a fallback score; it is not exercised here.
    parsed = _parse_validation_score("85/100", 0)
    assert parsed == 85

def test_agent_returns_best_attempt_when_validation_fails(monkeypatch):
    """When validation ultimately fails, run_rag_agent surfaces the best-scoring
    attempt together with its score and failure reason."""
    import agent

    # Final graph state: the last answer scored 40, but an earlier one scored 70.
    final_overrides = {
        "answer": "weak final answer",
        "retry_count": 3,
        "validation_result": "FAIL",
        "validation_score": 40,
        "fail_reason": "Not supported by context",
        "best_answer": "best available answer",
        "best_validation_score": 70,
        "best_fail_reason": "Partially supported by context",
    }

    class FakeGraph:
        # Echo the initial state with the canned final values layered on top.
        def invoke(self, init_state):
            return {**init_state, **final_overrides}

    monkeypatch.setattr(agent, "_rag_graph", FakeGraph())
    retrieved = [{"chunk": "c", "source": "s"}]
    answer, retries, verdict = agent.run_rag_agent("q", retrieved)

    expected_fragments = (
        "I could not fully validate a confident answer",
        "validation score: 70/100",
        "Partially supported by context",
        "best available answer",
    )
    for fragment in expected_fragments:
        assert fragment in answer
    assert retries == 3
    assert verdict == "FAIL"

# ── Retriever output shape (mocked indexes) ───────────────────────────────────

@pytest.fixture
def mock_indexes(monkeypatch):
    """Patch every index/model global in ``retriever`` with in-memory fakes.

    No index files, embedding model, or cross-encoder need to exist. Returns
    the list of fake chunk texts so tests can compare against it.
    """
    import numpy as np
    import retriever

    # Fake corpus: three chunks, all attributed to the same source file.
    fake_chunks  = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
    fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]

    class FakeCollection:
        """Minimal stand-in for the vector-store collection."""

        def count(self):
            return len(fake_chunks)

        def query(self, query_embeddings, n_results, include):
            # Returns the same nested-list shape ChromaDB returns
            # (one inner list per query).
            return {
                "documents": [fake_chunks[:n_results]],
                "metadatas": [[{"source": s} for s in fake_sources[:n_results]]],
                "distances": [[0.1, 0.2, 0.3][:n_results]],
            }

    class FakeBM25:
        """BM25 stand-in with fixed (not uniform) per-chunk scores."""

        def get_scores(self, tokens):
            # One score per fake chunk, descending: chunk 0 > chunk 1 > chunk 2.
            return np.array([0.9, 0.5, 0.3])

    class FakeModel:
        """Embedder stand-in returning deterministic 384-dim float32 vectors."""

        def encode(self, texts, convert_to_numpy=True, **_kwargs):
            # A fixed-seed local generator keeps runs reproducible. Unlike
            # np.random.rand, it does not advance numpy's global RNG, which
            # other tests may rely on. Extra keyword args are accepted and ignored.
            rng = np.random.default_rng(0)
            return rng.random((len(texts), 384), dtype=np.float32)

    class FakeReranker:
        """Cross-encoder stand-in: descending scores, one per input pair (max 3)."""

        def predict(self, pairs):
            return np.array([0.9, 0.7, 0.5][: len(pairs)])

    monkeypatch.setattr(retriever, "_collection",  FakeCollection())
    monkeypatch.setattr(retriever, "_bm25_index",  FakeBM25())
    monkeypatch.setattr(retriever, "_chunks",      fake_chunks)
    monkeypatch.setattr(retriever, "_sources",     fake_sources)
    monkeypatch.setattr(retriever, "_model",       FakeModel())
    monkeypatch.setattr(retriever, "_reranker",    FakeReranker())
    return fake_chunks


def test_hybrid_retrieve_returns_top_k(mock_indexes):
    """hybrid_retrieve returns exactly top_k results."""
    from retriever import hybrid_retrieve
    wanted = 2
    hits = hybrid_retrieve("Where is Paris?", top_k=wanted)
    assert len(hits) == wanted

def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
    """Each hit exposes the chunk, its source, and both fusion/rerank scores."""
    from retriever import hybrid_retrieve
    hits = hybrid_retrieve("Where is Paris?", top_k=1)
    top_hit = hits[0]
    for required_key in ("chunk", "source", "rrf_score", "ce_score"):
        assert required_key in top_hit

def test_hybrid_retrieve_scores_are_floats(mock_indexes):
    """Both scores on a hit are Python floats (or float subclasses)."""
    from retriever import hybrid_retrieve
    hits = hybrid_retrieve("test", top_k=1)
    top_hit = hits[0]
    for score_key in ("rrf_score", "ce_score"):
        assert isinstance(top_hit[score_key], float)