Spaces:
Sleeping
Sleeping
File size: 6,411 Bytes
# tests/test_unit.py
import pytest
# ── RRF logic ─────────────────────────────────────────────────────────────────
def test_rrf_prefers_doc_appearing_in_both_lists():
    """Doc 2 must outscore doc 1 after fusing two three-document rankings."""
    from retriever import _reciprocal_rank_fusion as fuse

    first_ranking = [0, 1, 2]
    second_ranking = [2, 0, 1]
    fused = fuse([first_ranking, second_ranking])

    # Doc 2 is last in the first ranking but first in the second, while
    # doc 1 is second and then last, so doc 2 is expected to score higher.
    assert fused[2] > fused[1]
def test_rrf_returns_all_docs():
    """The fused scores must be keyed by exactly the ids seen in the inputs."""
    from retriever import _reciprocal_rank_fusion as fuse

    rankings = [[0, 1], [1, 2]]
    expected_ids = {0, 1, 2}

    fused = fuse(rankings)

    assert set(fused.keys()) == expected_ids
def test_rrf_scores_are_positive():
    """Every fused score for a single ranking must be strictly positive."""
    from retriever import _reciprocal_rank_fusion as fuse

    fused = fuse([[0, 1, 2]])

    for score in fused.values():
        assert score > 0
# ── Config sanity ─────────────────────────────────────────────────────────────
def test_config_values_are_sane():
    """Check basic invariants on the chunking, retrieval and retry settings."""
    from config import (
        CHUNK_SIZE as chunk_size,
        CHUNK_OVERLAP as chunk_overlap,
        TOP_K as top_k,
        MAX_RETRIES as max_retries,
    )

    assert chunk_overlap < chunk_size, "overlap must be smaller than chunk size"
    assert 0 < top_k, "TOP_K must be positive"
    assert 1 <= max_retries, "need at least 1 retry"
def test_groq_api_key_present(monkeypatch):
    """After a reload, ``config.GROQ_API_KEY`` must hold a key longer than 10 chars."""
    import importlib

    # A placeholder key is injected so CI does not need a real credential.
    placeholder_key = "gsk_fakekeyfortesting1234567890"
    monkeypatch.setenv("GROQ_API_KEY", placeholder_key)

    import config

    # Reload so the module re-reads the environment patched above.
    reloaded = importlib.reload(config)

    assert len(reloaded.GROQ_API_KEY) > 10
# ── Agent routing logic ───────────────────────────────────────────────────────
def test_route_returns_done_on_pass():
    """A PASS verdict with no retries used must route to ``"done"``."""
    from agent import route_after_validation as route

    passing_state = dict(validation_result="PASS", retry_count=0)

    assert route(passing_state) == "done"
def test_route_returns_retry_on_fail_within_limit():
    """A FAIL verdict with no retries used yet must route to ``"retry"``."""
    from agent import route_after_validation as route

    failing_state = dict(validation_result="FAIL", retry_count=0)

    assert route(failing_state) == "retry"
def test_route_returns_done_when_retries_exhausted():
    """A FAIL verdict after three retries must route to ``"done"``."""
    from agent import route_after_validation as route

    # NOTE(review): 3 presumably meets or exceeds the agent's retry limit;
    # confirm against the configured maximum.
    exhausted_state = dict(validation_result="FAIL", retry_count=3)

    assert route(exhausted_state) == "done"
def test_increment_retry_node():
    """The node must return a state whose ``retry_count`` went from 1 to 2."""
    from agent import increment_retry_node as bump_retry

    updated_state = bump_retry({"retry_count": 1})

    assert updated_state["retry_count"] == 2
def test_parse_validation_score_accepts_score_out_of_100():
    """A score written as ``"85/100"`` must parse to the integer 85."""
    from agent import _parse_validation_score as parse_score

    # The second argument is presumably a fallback value — confirm in agent.
    parsed = parse_score("85/100", 0)

    assert parsed == 85
def test_agent_returns_best_attempt_when_validation_fails(monkeypatch):
    """A run that ends on FAIL must still surface the best recorded attempt.

    The compiled graph is replaced with a stub whose ``invoke`` returns a
    failed final state carrying a separate, higher-scoring "best" attempt.
    """
    import agent

    # Fields the stub graph layers on top of whatever initial state it receives.
    final_fields = {
        "answer": "weak final answer",
        "retry_count": 3,
        "validation_result": "FAIL",
        "validation_score": 40,
        "fail_reason": "Not supported by context",
        "best_answer": "best available answer",
        "best_validation_score": 70,
        "best_fail_reason": "Partially supported by context",
    }

    class StubGraph:
        def invoke(self, init_state):
            return {**init_state, **final_fields}

    monkeypatch.setattr(agent, "_rag_graph", StubGraph())

    outcome = agent.run_rag_agent("q", [{"chunk": "c", "source": "s"}])
    answer, retries, verdict = outcome

    expected_fragments = (
        "I could not fully validate a confident answer",
        "validation score: 70/100",
        "Partially supported by context",
        "best available answer",
    )
    for fragment in expected_fragments:
        assert fragment in answer

    assert retries == 3
    assert verdict == "FAIL"
# ── Retriever output shape (mocked indexes) ───────────────────────────────────
@pytest.fixture
def mock_indexes(monkeypatch):
    """Replace the retriever module's index globals with in-memory fakes.

    Patches ``_collection``, ``_bm25_index``, ``_chunks``, ``_sources``,
    ``_model`` and ``_reranker`` on the ``retriever`` module so tests that use
    this fixture need no index files on disk. The patches are applied through
    ``monkeypatch``, so they are undone when the test finishes.

    Returns:
        The list of fake chunk strings installed as ``retriever._chunks``.
    """
    import numpy as np
    import retriever

    # Three tiny chunks, all attributed to the same source document.
    fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
    fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]

    class FakeCollection:
        """Vector-store stand-in that always returns the first chunks."""

        def count(self):
            return len(fake_chunks)

        def query(self, query_embeddings, n_results, include):
            # Nested single-query lists, intended to mirror the shape of a
            # ChromaDB query result. The embeddings themselves are ignored.
            return {
                "documents": [fake_chunks[:n_results]],
                "metadatas": [[{"source": s} for s in fake_sources[:n_results]]],
                "distances": [[0.1, 0.2, 0.3][:n_results]],
            }

    class FakeBM25:
        """Sparse scorer returning fixed, descending scores for any query."""

        def get_scores(self, tokens):
            return np.array([0.9, 0.5, 0.3])

    class FakeModel:
        """Embedder stand-in producing reproducible 384-dim float32 vectors."""

        def encode(self, texts, convert_to_numpy=True):
            # A seeded generator keeps the fixture deterministic across runs;
            # values stay in [0, 1) like the unseeded np.random.rand did.
            rng = np.random.default_rng(0)
            return rng.random((len(texts), 384)).astype("float32")

    class FakeReranker:
        """Cross-encoder stand-in scoring pairs in descending input order."""

        def predict(self, pairs):
            return np.array([0.9, 0.7, 0.5][: len(pairs)])

    replacements = {
        "_collection": FakeCollection(),
        "_bm25_index": FakeBM25(),
        "_chunks": fake_chunks,
        "_sources": fake_sources,
        "_model": FakeModel(),
        "_reranker": FakeReranker(),
    }
    for attr_name, fake in replacements.items():
        monkeypatch.setattr(retriever, attr_name, fake)

    return fake_chunks
def test_hybrid_retrieve_returns_top_k(mock_indexes):
    """Asking for ``top_k=2`` must yield exactly two results."""
    from retriever import hybrid_retrieve as retrieve

    requested = 2
    hits = retrieve("Where is Paris?", top_k=requested)

    assert len(hits) == requested
def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
    """The first result must expose the chunk, source and both score keys."""
    from retriever import hybrid_retrieve as retrieve

    top_hit = retrieve("Where is Paris?", top_k=1)[0]

    for required_key in ("chunk", "source", "rrf_score", "ce_score"):
        assert required_key in top_hit
def test_hybrid_retrieve_scores_are_floats(mock_indexes):
    """Both score fields on the first result must be ``float`` instances."""
    from retriever import hybrid_retrieve as retrieve

    top_hit = retrieve("test", top_k=1)[0]

    for score_key in ("rrf_score", "ce_score"):
        assert isinstance(top_hit[score_key], float)
|