Spaces:
Sleeping
Sleeping
initial commit
Browse files- .github/workflows/ci.yml +27 -0
- agent.py +1 -0
- ingestion.py +2 -2
- retriever.py +50 -22
- test_sources.py +1 -0
- tests/__init__.py +0 -0
- tests/test_integration.py +51 -0
- tests/test_unit.py +119 -0
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RAG Unit Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [main]
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
test:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
|
| 13 |
+
steps:
|
| 14 |
+
- uses: actions/checkout@v4
|
| 15 |
+
|
| 16 |
+
- name: Set up Python
|
| 17 |
+
uses: actions/setup-python@v5
|
| 18 |
+
with:
|
| 19 |
+
python-version: "3.11"
|
| 20 |
+
|
| 21 |
+
- name: Install dependencies
|
| 22 |
+
run: pip install -r requirements.txt
|
| 23 |
+
|
| 24 |
+
- name: Run unit tests only # ← integration tests are skipped here
|
| 25 |
+
env:
|
| 26 |
+
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} # add this in GitHub → Settings → Secrets
|
| 27 |
+
run: pytest tests/test_unit.py -v
|
agent.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from typing import TypedDict
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
from langchain_groq import ChatGroq
|
|
|
|
| 1 |
+
#agent.py
|
| 2 |
from typing import TypedDict
|
| 3 |
from langgraph.graph import StateGraph, END
|
| 4 |
from langchain_groq import ChatGroq
|
ingestion.py
CHANGED
|
@@ -80,8 +80,8 @@ def load_documents():
|
|
| 80 |
# ─────────────────────────────────────────────────────────────
|
| 81 |
def semantic_chunk(docs, filenames):
|
| 82 |
splitter = RecursiveCharacterTextSplitter(
|
| 83 |
-
chunk_size=
|
| 84 |
-
chunk_overlap=
|
| 85 |
separators=["\n\n", "\n", ". ", " "],
|
| 86 |
)
|
| 87 |
|
|
|
|
| 80 |
# ─────────────────────────────────────────────────────────────
|
| 81 |
def semantic_chunk(docs, filenames):
|
| 82 |
splitter = RecursiveCharacterTextSplitter(
|
| 83 |
+
chunk_size=CHUNK_SIZE,
|
| 84 |
+
chunk_overlap=CHUNK_OVERLAP,
|
| 85 |
separators=["\n\n", "\n", ". ", " "],
|
| 86 |
)
|
| 87 |
|
retriever.py
CHANGED
|
@@ -1,17 +1,20 @@
|
|
| 1 |
import pickle
|
| 2 |
import numpy as np
|
| 3 |
import faiss
|
| 4 |
-
from sentence_transformers import SentenceTransformer
|
| 5 |
from config import (
|
| 6 |
FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
|
| 7 |
SOURCES_PATH, EMBEDDER_PATH
|
| 8 |
)
|
| 9 |
|
| 10 |
-
_faiss_index
|
| 11 |
-
_bm25_index
|
| 12 |
-
_chunks
|
| 13 |
-
_sources
|
| 14 |
-
_model
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def indexes_loaded() -> bool:
|
|
@@ -19,44 +22,69 @@ def indexes_loaded() -> bool:
|
|
| 19 |
|
| 20 |
|
| 21 |
def load_indexes():
|
| 22 |
-
global _faiss_index, _bm25_index, _chunks, _sources, _model
|
| 23 |
_faiss_index = faiss.read_index(FAISS_INDEX_PATH)
|
| 24 |
-
with open(BM25_PATH,
|
| 25 |
-
with open(CHUNKS_PATH,
|
| 26 |
-
with open(SOURCES_PATH,"rb") as f: _sources = pickle.load(f)
|
| 27 |
-
_model
|
|
|
|
| 28 |
print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
|
| 29 |
|
| 30 |
|
| 31 |
def reload_indexes():
|
| 32 |
-
global _faiss_index, _bm25_index, _chunks, _sources, _model
|
| 33 |
-
_faiss_index = _bm25_index = _chunks = _sources = _model = None
|
| 34 |
load_indexes()
|
| 35 |
|
| 36 |
|
| 37 |
-
def _reciprocal_rank_fusion(lists: list, k: int = 60) ->
|
|
|
|
| 38 |
scores: dict = {}
|
| 39 |
for ranked_list in lists:
|
| 40 |
for rank, doc_id in enumerate(ranked_list):
|
| 41 |
scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
|
| 42 |
-
return
|
| 43 |
|
| 44 |
|
| 45 |
def hybrid_retrieve(query: str, top_k: int = 5) -> list:
|
| 46 |
if not indexes_loaded():
|
| 47 |
raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
|
| 48 |
|
|
|
|
| 49 |
q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
|
| 50 |
faiss.normalize_L2(q_emb)
|
| 51 |
-
_, dense_ids
|
| 52 |
-
dense_ranking
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
return [
|
| 60 |
-
{
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
]
|
|
|
|
| 1 |
import pickle
|
| 2 |
import numpy as np
|
| 3 |
import faiss
|
| 4 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 5 |
from config import (
|
| 6 |
FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
|
| 7 |
SOURCES_PATH, EMBEDDER_PATH
|
| 8 |
)
|
| 9 |
|
| 10 |
+
_faiss_index = None
|
| 11 |
+
_bm25_index = None
|
| 12 |
+
_chunks = None
|
| 13 |
+
_sources = None
|
| 14 |
+
_model = None
|
| 15 |
+
_reranker = None
|
| 16 |
+
|
| 17 |
+
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 18 |
|
| 19 |
|
| 20 |
def indexes_loaded() -> bool:
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def load_indexes():
|
| 25 |
+
global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker
|
| 26 |
_faiss_index = faiss.read_index(FAISS_INDEX_PATH)
|
| 27 |
+
with open(BM25_PATH, "rb") as f: _bm25_index = pickle.load(f)
|
| 28 |
+
with open(CHUNKS_PATH, "rb") as f: _chunks = pickle.load(f)
|
| 29 |
+
with open(SOURCES_PATH, "rb") as f: _sources = pickle.load(f)
|
| 30 |
+
_model = SentenceTransformer(EMBEDDER_PATH)
|
| 31 |
+
_reranker = CrossEncoder(RERANKER_MODEL) # ← reranker loads once
|
| 32 |
print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
|
| 33 |
|
| 34 |
|
| 35 |
def reload_indexes():
|
| 36 |
+
global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker
|
| 37 |
+
_faiss_index = _bm25_index = _chunks = _sources = _model = _reranker = None
|
| 38 |
load_indexes()
|
| 39 |
|
| 40 |
|
| 41 |
+
def _reciprocal_rank_fusion(lists: list, k: int = 60) -> dict:
|
| 42 |
+
"""Returns {doc_id: rrf_score} instead of just a sorted list."""
|
| 43 |
scores: dict = {}
|
| 44 |
for ranked_list in lists:
|
| 45 |
for rank, doc_id in enumerate(ranked_list):
|
| 46 |
scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
|
| 47 |
+
return scores # ← return the dict now
|
| 48 |
|
| 49 |
|
| 50 |
def hybrid_retrieve(query: str, top_k: int = 5) -> list:
|
| 51 |
if not indexes_loaded():
|
| 52 |
raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
|
| 53 |
|
| 54 |
+
# ── Dense retrieval (FAISS) ───────────────────────────────
|
| 55 |
q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
|
| 56 |
faiss.normalize_L2(q_emb)
|
| 57 |
+
_, dense_ids = _faiss_index.search(q_emb, top_k * 3)
|
| 58 |
+
dense_ranking = [int(i) for i in dense_ids[0] if i >= 0]
|
| 59 |
+
|
| 60 |
+
# ── Sparse retrieval (BM25) ───────────────────────────────
|
| 61 |
+
bm25_scores = _bm25_index.get_scores(query.lower().split())
|
| 62 |
+
sparse_ranking = np.argsort(bm25_scores)[::-1][: top_k * 3].tolist()
|
| 63 |
+
|
| 64 |
+
# ── Fusion (RRF) — now returns score dict ─────────────────
|
| 65 |
+
rrf_scores = _reciprocal_rank_fusion([dense_ranking, sparse_ranking])
|
| 66 |
+
fused_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[: top_k * 2]
|
| 67 |
|
| 68 |
+
# ── Cross-encoder reranking ───────────────────────────────
|
| 69 |
+
# The cross-encoder scores each (query, chunk) pair together
|
| 70 |
+
# much more accurately than embedding similarity alone
|
| 71 |
+
candidates = [(query, _chunks[i]) for i in fused_ids]
|
| 72 |
+
ce_scores = _reranker.predict(candidates) # shape: (len(candidates),)
|
| 73 |
|
| 74 |
+
# Sort by cross-encoder score, keep top_k
|
| 75 |
+
ranked = sorted(
|
| 76 |
+
zip(fused_ids, ce_scores),
|
| 77 |
+
key=lambda x: x[1],
|
| 78 |
+
reverse=True,
|
| 79 |
+
)[:top_k]
|
| 80 |
|
| 81 |
return [
|
| 82 |
+
{
|
| 83 |
+
"chunk": _chunks[i],
|
| 84 |
+
"source": _sources[i],
|
| 85 |
+
"chunk_id": i,
|
| 86 |
+
"rrf_score": round(float(rrf_scores[i]), 4),
|
| 87 |
+
"ce_score": round(float(score), 4), # ← reranker confidence
|
| 88 |
+
}
|
| 89 |
+
for i, score in ranked
|
| 90 |
]
|
test_sources.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from retriever import load_indexes, _sources
|
| 2 |
|
| 3 |
load_indexes()
|
|
|
|
| 1 |
+
# test_sources.py
|
| 2 |
from retriever import load_indexes, _sources
|
| 3 |
|
| 4 |
load_indexes()
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_integration.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_integration.py
|
| 2 |
+
# Run with: pytest tests/test_integration.py -v -m integration
|
| 3 |
+
# These call real APIs — don't run in CI automatically.
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
pytestmark = pytest.mark.integration # tag so CI can skip these
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_groq_connection_live():
|
| 11 |
+
from langchain_groq import ChatGroq
|
| 12 |
+
from langchain_core.messages import HumanMessage
|
| 13 |
+
from config import GROQ_API_KEY, GROQ_MODEL
|
| 14 |
+
llm = ChatGroq(model=GROQ_MODEL, temperature=0, api_key=GROQ_API_KEY)
|
| 15 |
+
r = llm.invoke([HumanMessage(content="Reply with just the word OK")])
|
| 16 |
+
assert len(r.content) > 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_full_pipeline_live():
|
| 20 |
+
"""Ingests a tiny doc, retrieves, runs agent — end to end."""
|
| 21 |
+
import os
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# Write test doc
|
| 25 |
+
Path("./docs").mkdir(exist_ok=True)
|
| 26 |
+
test_file = Path("./docs/_pytest_temp.txt")
|
| 27 |
+
test_file.write_text(
|
| 28 |
+
"The Eiffel Tower is in Paris, France. "
|
| 29 |
+
"It was built in 1889. It is 330 metres tall."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from ingestion import run_ingestion
|
| 34 |
+
from retriever import load_indexes, hybrid_retrieve
|
| 35 |
+
from agent import run_rag_agent
|
| 36 |
+
|
| 37 |
+
run_ingestion()
|
| 38 |
+
load_indexes()
|
| 39 |
+
|
| 40 |
+
results = hybrid_retrieve("How tall is the Eiffel Tower?", top_k=3)
|
| 41 |
+
assert len(results) > 0
|
| 42 |
+
assert "ce_score" in results[0] # reranker ran
|
| 43 |
+
|
| 44 |
+
answer, retries, verdict = run_rag_agent(
|
| 45 |
+
"How tall is the Eiffel Tower?", results
|
| 46 |
+
)
|
| 47 |
+
assert "330" in answer or "metres" in answer.lower()
|
| 48 |
+
assert verdict in {"PASS", "FAIL"}
|
| 49 |
+
|
| 50 |
+
finally:
|
| 51 |
+
test_file.unlink(missing_ok=True) # always clean up
|
tests/test_unit.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_unit.py
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
# ── RRF logic ─────────────────────────────────────────────────────────────────
|
| 5 |
+
|
| 6 |
+
def test_rrf_prefers_doc_appearing_in_both_lists():
|
| 7 |
+
from retriever import _reciprocal_rank_fusion
|
| 8 |
+
scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 0, 1]])
|
| 9 |
+
# doc 2 is rank-0 in sparse and rank-2 in dense → should beat doc 1
|
| 10 |
+
assert scores[2] > scores[1]
|
| 11 |
+
|
| 12 |
+
def test_rrf_returns_all_docs():
|
| 13 |
+
from retriever import _reciprocal_rank_fusion
|
| 14 |
+
scores = _reciprocal_rank_fusion([[0, 1], [1, 2]])
|
| 15 |
+
assert set(scores.keys()) == {0, 1, 2}
|
| 16 |
+
|
| 17 |
+
def test_rrf_scores_are_positive():
|
| 18 |
+
from retriever import _reciprocal_rank_fusion
|
| 19 |
+
scores = _reciprocal_rank_fusion([[0, 1, 2]])
|
| 20 |
+
assert all(v > 0 for v in scores.values())
|
| 21 |
+
|
| 22 |
+
# ── Config sanity ─────────────────────────────────────────────────────────────
|
| 23 |
+
|
| 24 |
+
def test_config_values_are_sane():
|
| 25 |
+
from config import CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_RETRIES
|
| 26 |
+
assert CHUNK_SIZE > CHUNK_OVERLAP, "overlap must be smaller than chunk size"
|
| 27 |
+
assert TOP_K > 0, "TOP_K must be positive"
|
| 28 |
+
assert MAX_RETRIES >= 1, "need at least 1 retry"
|
| 29 |
+
|
| 30 |
+
def test_groq_api_key_present(monkeypatch):
|
| 31 |
+
# patch so we don't need a real key in CI
|
| 32 |
+
monkeypatch.setenv("GROQ_API_KEY", "gsk_fakekeyfortesting1234567890")
|
| 33 |
+
import importlib, config
|
| 34 |
+
importlib.reload(config) # re-reads env
|
| 35 |
+
assert len(config.GROQ_API_KEY) > 10
|
| 36 |
+
|
| 37 |
+
# ── Agent routing logic ───────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
def test_route_returns_done_on_pass():
|
| 40 |
+
from agent import route_after_validation
|
| 41 |
+
state = {"validation_result": "PASS", "retry_count": 0}
|
| 42 |
+
assert route_after_validation(state) == "done"
|
| 43 |
+
|
| 44 |
+
def test_route_returns_retry_on_fail_within_limit():
|
| 45 |
+
from agent import route_after_validation
|
| 46 |
+
state = {"validation_result": "FAIL", "retry_count": 0}
|
| 47 |
+
assert route_after_validation(state) == "retry"
|
| 48 |
+
|
| 49 |
+
def test_route_returns_done_when_retries_exhausted():
|
| 50 |
+
from agent import route_after_validation
|
| 51 |
+
state = {"validation_result": "FAIL", "retry_count": 3}
|
| 52 |
+
assert route_after_validation(state) == "done"
|
| 53 |
+
|
| 54 |
+
def test_increment_retry_node():
|
| 55 |
+
from agent import increment_retry_node
|
| 56 |
+
result = increment_retry_node({"retry_count": 1})
|
| 57 |
+
assert result["retry_count"] == 2
|
| 58 |
+
|
| 59 |
+
# ── Retriever output shape (mocked indexes) ───────────────────────────────────
|
| 60 |
+
|
| 61 |
+
@pytest.fixture
|
| 62 |
+
def mock_indexes(monkeypatch):
|
| 63 |
+
"""Patches all globals in retriever so no files need to exist."""
|
| 64 |
+
import numpy as np
|
| 65 |
+
import retriever
|
| 66 |
+
|
| 67 |
+
# Fake chunks and sources
|
| 68 |
+
fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
|
| 69 |
+
fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]
|
| 70 |
+
|
| 71 |
+
# Fake FAISS index that always returns ids [0, 1, 2]
|
| 72 |
+
class FakeFaiss:
|
| 73 |
+
ntotal = 3
|
| 74 |
+
def search(self, vec, k):
|
| 75 |
+
ids = np.array([[0, 1, 2]])
|
| 76 |
+
return None, ids
|
| 77 |
+
|
| 78 |
+
# Fake BM25 that returns fixed, descending scores
|
| 79 |
+
class FakeBM25:
|
| 80 |
+
def get_scores(self, tokens):
|
| 81 |
+
return np.array([0.9, 0.5, 0.3])
|
| 82 |
+
|
| 83 |
+
# Fake embedder
|
| 84 |
+
class FakeModel:
|
| 85 |
+
def encode(self, texts, convert_to_numpy=True):
|
| 86 |
+
return np.random.rand(len(texts), 384).astype("float32")
|
| 87 |
+
|
| 88 |
+
# Fake cross-encoder
|
| 89 |
+
class FakeReranker:
|
| 90 |
+
def predict(self, pairs):
|
| 91 |
+
return np.array([0.9, 0.7, 0.5][: len(pairs)])
|
| 92 |
+
|
| 93 |
+
monkeypatch.setattr(retriever, "_faiss_index", FakeFaiss())
|
| 94 |
+
monkeypatch.setattr(retriever, "_bm25_index", FakeBM25())
|
| 95 |
+
monkeypatch.setattr(retriever, "_chunks", fake_chunks)
|
| 96 |
+
monkeypatch.setattr(retriever, "_sources", fake_sources)
|
| 97 |
+
monkeypatch.setattr(retriever, "_model", FakeModel())
|
| 98 |
+
monkeypatch.setattr(retriever, "_reranker", FakeReranker())
|
| 99 |
+
return fake_chunks
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def test_hybrid_retrieve_returns_top_k(mock_indexes):
|
| 103 |
+
from retriever import hybrid_retrieve
|
| 104 |
+
results = hybrid_retrieve("Where is Paris?", top_k=2)
|
| 105 |
+
assert len(results) == 2
|
| 106 |
+
|
| 107 |
+
def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
|
| 108 |
+
from retriever import hybrid_retrieve
|
| 109 |
+
result = hybrid_retrieve("Where is Paris?", top_k=1)[0]
|
| 110 |
+
assert "chunk" in result
|
| 111 |
+
assert "source" in result
|
| 112 |
+
assert "rrf_score" in result
|
| 113 |
+
assert "ce_score" in result
|
| 114 |
+
|
| 115 |
+
def test_hybrid_retrieve_scores_are_floats(mock_indexes):
|
| 116 |
+
from retriever import hybrid_retrieve
|
| 117 |
+
result = hybrid_retrieve("test", top_k=1)[0]
|
| 118 |
+
assert isinstance(result["rrf_score"], float)
|
| 119 |
+
assert isinstance(result["ce_score"], float)
|