# tests/test_unit.py
import pytest
# ── RRF logic ────────────────────────────────────────────────────────────────
def test_rrf_prefers_doc_appearing_in_both_lists():
    """Doc 2 must outscore doc 1 when the two rankings below are fused."""
    from retriever import _reciprocal_rank_fusion

    rankings = [[0, 1, 2], [2, 0, 1]]
    fused = _reciprocal_rank_fusion(rankings)

    # Doc 2 is last in the first ranking but first in the second, while
    # doc 1 is second and then last, so doc 2 should come out ahead.
    assert fused[2] > fused[1]
def test_rrf_returns_all_docs():
    """Every doc id present in any input ranking appears in the result."""
    from retriever import _reciprocal_rank_fusion

    expected_ids = {0, 1, 2}
    fused = _reciprocal_rank_fusion([[0, 1], [1, 2]])

    assert set(fused.keys()) == expected_ids
def test_rrf_scores_are_positive():
    """Each fused score for a single three-doc ranking is strictly positive."""
    from retriever import _reciprocal_rank_fusion

    fused = _reciprocal_rank_fusion([[0, 1, 2]])

    for score in fused.values():
        assert score > 0
# ── Config sanity ────────────────────────────────────────────────────────────
def test_config_values_are_sane():
    """Check basic invariants between the chunking/retrieval settings."""
    from config import CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_RETRIES

    # The overlap has to be strictly smaller than the chunk itself.
    assert CHUNK_OVERLAP < CHUNK_SIZE, "overlap must be smaller than chunk size"
    assert 0 < TOP_K, "TOP_K must be positive"
    assert 1 <= MAX_RETRIES, "need at least 1 retry"
def test_groq_api_key_present(monkeypatch):
    """With a placeholder key in the environment, a reloaded ``config``
    exposes a ``GROQ_API_KEY`` longer than 10 characters.
    """
    placeholder_key = "gsk_fakekeyfortesting1234567890"

    # Set the variable before config is imported so CI needs no real key.
    monkeypatch.setenv("GROQ_API_KEY", placeholder_key)

    import importlib
    import config

    # Reload so module-level values are recomputed against the patched env.
    importlib.reload(config)

    key_length = len(config.GROQ_API_KEY)
    assert key_length > 10
# ── Agent routing logic ──────────────────────────────────────────────────────
def test_route_returns_done_on_pass():
    """A PASS verdict with no retries used routes to "done"."""
    from agent import route_after_validation

    passing_state = dict(validation_result="PASS", retry_count=0)
    next_step = route_after_validation(passing_state)

    assert next_step == "done"
def test_route_returns_retry_on_fail_within_limit():
    """A FAIL verdict with retry_count 0 routes to "retry"."""
    from agent import route_after_validation

    failing_state = dict(validation_result="FAIL", retry_count=0)
    next_step = route_after_validation(failing_state)

    assert next_step == "retry"
def test_route_returns_done_when_retries_exhausted():
    """A FAIL verdict with retry_count 3 routes to "done".

    NOTE(review): 3 is presumably at or above the agent's retry limit --
    confirm against the configured maximum.
    """
    from agent import route_after_validation

    exhausted_state = dict(validation_result="FAIL", retry_count=3)
    next_step = route_after_validation(exhausted_state)

    assert next_step == "done"
def test_increment_retry_node():
    """The node's returned ``retry_count`` is 2 when it is given 1."""
    from agent import increment_retry_node

    state_before = {"retry_count": 1}
    update = increment_retry_node(state_before)

    assert update["retry_count"] == 2
def test_parse_validation_score_accepts_score_out_of_100():
    """The text "85/100" parses to the integer 85."""
    from agent import _parse_validation_score

    # Second argument is 0 -- presumably a fallback value; confirm in agent.
    parsed = _parse_validation_score("85/100", 0)

    assert parsed == 85
def test_agent_returns_best_attempt_when_validation_fails(monkeypatch):
    """When the graph ends on FAIL, the reply is built from the best attempt.

    ``agent._rag_graph`` is swapped for a stub whose final state has a weak
    last answer (score 40) and a stronger stored best answer (score 70). The
    text returned by ``run_rag_agent`` must carry the caveat sentence plus the
    best attempt's score, reason and answer.
    """
    import agent

    # Fields the stub graph layers on top of whatever initial state it gets.
    failed_run_fields = {
        "answer": "weak final answer",
        "retry_count": 3,
        "validation_result": "FAIL",
        "validation_score": 40,
        "fail_reason": "Not supported by context",
        "best_answer": "best available answer",
        "best_validation_score": 70,
        "best_fail_reason": "Partially supported by context",
    }

    class FakeGraph:
        """Stub graph that reports a failed run without doing any work."""

        def invoke(self, init_state):
            final_state = dict(init_state)
            final_state.update(failed_run_fields)
            return final_state

    monkeypatch.setattr(agent, "_rag_graph", FakeGraph())

    result = agent.run_rag_agent("q", [{"chunk": "c", "source": "s"}])
    answer, retries, verdict = (
        result[key] for key in ("answer", "retries_used", "validation")
    )

    expected_fragments = (
        "I could not fully validate a confident answer",
        "validation score: 70/100",
        "Partially supported by context",
        "best available answer",
    )
    for fragment in expected_fragments:
        assert fragment in answer
    assert retries == 3
    assert verdict == "FAIL"
# ── Retriever output shape (mocked indexes) ──────────────────────────────────
@pytest.fixture
def mock_indexes(monkeypatch):
    """Patch retriever's module-level index objects so no files need to exist.

    Replaces ``_collection``, ``_bm25_index``, ``_chunks``, ``_sources``,
    ``_model`` and ``_reranker`` on the ``retriever`` module with small
    in-memory fakes built around three fixed chunks.

    Returns:
        The list of fake chunk texts.
    """
    import numpy as np
    import retriever

    # Fake chunks and sources
    fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
    fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]

    class FakeCollection:
        """Vector-store stand-in that ignores the query embedding."""

        def count(self):
            return len(fake_chunks)

        def query(self, query_embeddings, n_results, include):
            # Returns the same nested-list shape ChromaDB returns.
            return {
                "documents": [fake_chunks[:n_results]],
                "metadatas": [[{"source": s} for s in fake_sources[:n_results]]],
                "distances": [[0.1, 0.2, 0.3][:n_results]],
            }

    class FakeBM25:
        """BM25 stand-in returning fixed, descending scores for the chunks."""

        def get_scores(self, tokens):
            return np.array([0.9, 0.5, 0.3])

    class FakeModel:
        """Embedder stand-in producing one 384-wide float32 row per text."""

        def encode(self, texts, convert_to_numpy=True):
            # Fixed seed: the fixture must hand out the same data on every
            # run so a failure can always be reproduced.
            rng = np.random.default_rng(0)
            return rng.random((len(texts), 384), dtype=np.float32)

    class FakeReranker:
        """Cross-encoder stand-in returning fixed, descending scores."""

        def predict(self, pairs):
            return np.array([0.9, 0.7, 0.5][: len(pairs)])

    monkeypatch.setattr(retriever, "_collection", FakeCollection())
    monkeypatch.setattr(retriever, "_bm25_index", FakeBM25())
    monkeypatch.setattr(retriever, "_chunks", fake_chunks)
    monkeypatch.setattr(retriever, "_sources", fake_sources)
    monkeypatch.setattr(retriever, "_model", FakeModel())
    monkeypatch.setattr(retriever, "_reranker", FakeReranker())
    return fake_chunks
def test_hybrid_retrieve_returns_top_k(mock_indexes):
    """Asking for two results against the mocked indexes yields exactly two."""
    from retriever import hybrid_retrieve

    requested = 2
    hits = hybrid_retrieve("Where is Paris?", top_k=requested)

    assert len(hits) == requested
def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
    """The first hit exposes the chunk, source and both score fields."""
    from retriever import hybrid_retrieve

    hits = hybrid_retrieve("Where is Paris?", top_k=1)
    first_hit = hits[0]

    for required_key in ("chunk", "source", "rrf_score", "ce_score"):
        assert required_key in first_hit
def test_hybrid_retrieve_scores_are_floats(mock_indexes):
    """Both score fields on the first hit are ``float`` instances."""
    from retriever import hybrid_retrieve

    hits = hybrid_retrieve("test", top_k=1)
    first_hit = hits[0]

    for score_key in ("rrf_score", "ce_score"):
        assert isinstance(first_hit[score_key], float)
|