Spaces:
Sleeping
Sleeping
| # tests/test_unit.py | |
| import pytest | |
| # ββ RRF logic βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_rrf_prefers_doc_appearing_in_both_lists(): | |
| from retriever import _reciprocal_rank_fusion | |
| scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 0, 1]]) | |
| # doc 2 is rank-0 in sparse and rank-2 in dense β should beat doc 1 | |
| assert scores[2] > scores[1] | |
| def test_rrf_returns_all_docs(): | |
| from retriever import _reciprocal_rank_fusion | |
| scores = _reciprocal_rank_fusion([[0, 1], [1, 2]]) | |
| assert set(scores.keys()) == {0, 1, 2} | |
| def test_rrf_scores_are_positive(): | |
| from retriever import _reciprocal_rank_fusion | |
| scores = _reciprocal_rank_fusion([[0, 1, 2]]) | |
| assert all(v > 0 for v in scores.values()) | |
| # ββ Config sanity βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_config_values_are_sane(): | |
| from config import CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_RETRIES | |
| assert CHUNK_SIZE > CHUNK_OVERLAP, "overlap must be smaller than chunk size" | |
| assert TOP_K > 0, "TOP_K must be positive" | |
| assert MAX_RETRIES >= 1, "need at least 1 retry" | |
| def test_groq_api_key_present(monkeypatch): | |
| # patch so we don't need a real key in CI | |
| monkeypatch.setenv("GROQ_API_KEY", "gsk_fakekeyfortesting1234567890") | |
| import importlib, config | |
| importlib.reload(config) # re-reads env | |
| assert len(config.GROQ_API_KEY) > 10 | |
| # ββ Agent routing logic βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_route_returns_done_on_pass(): | |
| from agent import route_after_validation | |
| state = {"validation_result": "PASS", "retry_count": 0} | |
| assert route_after_validation(state) == "done" | |
| def test_route_returns_retry_on_fail_within_limit(): | |
| from agent import route_after_validation | |
| state = {"validation_result": "FAIL", "retry_count": 0} | |
| assert route_after_validation(state) == "retry" | |
| def test_route_returns_done_when_retries_exhausted(): | |
| from agent import route_after_validation | |
| state = {"validation_result": "FAIL", "retry_count": 3} | |
| assert route_after_validation(state) == "done" | |
| def test_increment_retry_node(): | |
| from agent import increment_retry_node | |
| result = increment_retry_node({"retry_count": 1}) | |
| assert result["retry_count"] == 2 | |
| def test_parse_validation_score_accepts_score_out_of_100(): | |
| from agent import _parse_validation_score | |
| assert _parse_validation_score("85/100", 0) == 85 | |
| def test_agent_returns_best_attempt_when_validation_fails(monkeypatch): | |
| import agent | |
| class FakeGraph: | |
| def invoke(self, init_state): | |
| return { | |
| **init_state, | |
| "answer": "weak final answer", | |
| "retry_count": 3, | |
| "validation_result": "FAIL", | |
| "validation_score": 40, | |
| "fail_reason": "Not supported by context", | |
| "best_answer": "best available answer", | |
| "best_validation_score": 70, | |
| "best_fail_reason": "Partially supported by context", | |
| } | |
| monkeypatch.setattr(agent, "_rag_graph", FakeGraph()) | |
| answer, retries, verdict = agent.run_rag_agent("q", [{"chunk": "c", "source": "s"}]) | |
| assert "I could not fully validate a confident answer" in answer | |
| assert "validation score: 70/100" in answer | |
| assert "Partially supported by context" in answer | |
| assert "best available answer" in answer | |
| assert retries == 3 | |
| assert verdict == "FAIL" | |
| # ββ Retriever output shape (mocked indexes) βββββββββββββββββββββββββββββββββββ | |
| def mock_indexes(monkeypatch): | |
| """Patches all globals in retriever so no files need to exist.""" | |
| import numpy as np | |
| import retriever | |
| # Fake chunks and sources | |
| fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."] | |
| fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"] | |
| class FakeCollection: | |
| def count(self): | |
| return len(fake_chunks) | |
| def query(self, query_embeddings, n_results, include): | |
| # Returns the same shape ChromaDB returns | |
| return { | |
| "documents": [fake_chunks[:n_results]], | |
| "metadatas": [[{"source": s} for s in fake_sources[:n_results]]], | |
| "distances": [[0.1, 0.2, 0.3][:n_results]], | |
| } | |
| # Fake BM25 that returns uniform scores | |
| class FakeBM25: | |
| def get_scores(self, tokens): | |
| return np.array([0.9, 0.5, 0.3]) | |
| # Fake embedder | |
| class FakeModel: | |
| def encode(self, texts, convert_to_numpy=True): | |
| return np.random.rand(len(texts), 384).astype("float32") | |
| # Fake cross-encoder | |
| class FakeReranker: | |
| def predict(self, pairs): | |
| return np.array([0.9, 0.7, 0.5][: len(pairs)]) | |
| monkeypatch.setattr(retriever, "_collection", FakeCollection()) | |
| monkeypatch.setattr(retriever, "_bm25_index", FakeBM25()) | |
| monkeypatch.setattr(retriever, "_chunks", fake_chunks) | |
| monkeypatch.setattr(retriever, "_sources", fake_sources) | |
| monkeypatch.setattr(retriever, "_model", FakeModel()) | |
| monkeypatch.setattr(retriever, "_reranker", FakeReranker()) | |
| return fake_chunks | |
| def test_hybrid_retrieve_returns_top_k(mock_indexes): | |
| from retriever import hybrid_retrieve | |
| results = hybrid_retrieve("Where is Paris?", top_k=2) | |
| assert len(results) == 2 | |
| def test_hybrid_retrieve_result_has_required_keys(mock_indexes): | |
| from retriever import hybrid_retrieve | |
| result = hybrid_retrieve("Where is Paris?", top_k=1)[0] | |
| assert "chunk" in result | |
| assert "source" in result | |
| assert "rrf_score" in result | |
| assert "ce_score" in result | |
| def test_hybrid_retrieve_scores_are_floats(mock_indexes): | |
| from retriever import hybrid_retrieve | |
| result = hybrid_retrieve("test", top_k=1)[0] | |
| assert isinstance(result["rrf_score"], float) | |
| assert isinstance(result["ce_score"], float) | |