Hitan2004 commited on
Commit
b689b3f
Β·
1 Parent(s): d0245ab

initial commit

Browse files
.github/workflows/ci.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: RAG Unit Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v5
18
+ with:
19
+ python-version: "3.11"
20
+
21
+ - name: Install dependencies
22
+ run: pip install -r requirements.txt
23
+
24
+ - name: Run unit tests only # ← integration tests are skipped here
25
+ env:
26
+ GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} # add this in GitHub β†’ Settings β†’ Secrets
27
+ run: pytest tests/test_unit.py -v
agent.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import TypedDict
2
  from langgraph.graph import StateGraph, END
3
  from langchain_groq import ChatGroq
 
1
+ #agent.py
2
  from typing import TypedDict
3
  from langgraph.graph import StateGraph, END
4
  from langchain_groq import ChatGroq
ingestion.py CHANGED
@@ -80,8 +80,8 @@ def load_documents():
80
  # ─────────────────────────────────────────────────────────────
81
  def semantic_chunk(docs, filenames):
82
  splitter = RecursiveCharacterTextSplitter(
83
- chunk_size=300, # smaller chunks β†’ better retrieval
84
- chunk_overlap=80,
85
  separators=["\n\n", "\n", ". ", " "],
86
  )
87
 
 
80
  # ─────────────────────────────────────────────────────────────
81
  def semantic_chunk(docs, filenames):
82
  splitter = RecursiveCharacterTextSplitter(
83
+ chunk_size=CHUNK_SIZE,
84
+ chunk_overlap=CHUNK_OVERLAP,
85
  separators=["\n\n", "\n", ". ", " "],
86
  )
87
 
retriever.py CHANGED
@@ -1,17 +1,20 @@
1
  import pickle
2
  import numpy as np
3
  import faiss
4
- from sentence_transformers import SentenceTransformer
5
  from config import (
6
  FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
7
  SOURCES_PATH, EMBEDDER_PATH
8
  )
9
 
10
- _faiss_index = None
11
- _bm25_index = None
12
- _chunks = None
13
- _sources = None
14
- _model = None
 
 
 
15
 
16
 
17
  def indexes_loaded() -> bool:
@@ -19,44 +22,69 @@ def indexes_loaded() -> bool:
19
 
20
 
21
  def load_indexes():
22
- global _faiss_index, _bm25_index, _chunks, _sources, _model
23
  _faiss_index = faiss.read_index(FAISS_INDEX_PATH)
24
- with open(BM25_PATH, "rb") as f: _bm25_index = pickle.load(f)
25
- with open(CHUNKS_PATH, "rb") as f: _chunks = pickle.load(f)
26
- with open(SOURCES_PATH,"rb") as f: _sources = pickle.load(f)
27
- _model = SentenceTransformer(EMBEDDER_PATH)
 
28
  print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
29
 
30
 
31
  def reload_indexes():
32
- global _faiss_index, _bm25_index, _chunks, _sources, _model
33
- _faiss_index = _bm25_index = _chunks = _sources = _model = None
34
  load_indexes()
35
 
36
 
37
- def _reciprocal_rank_fusion(lists: list, k: int = 60) -> list:
 
38
  scores: dict = {}
39
  for ranked_list in lists:
40
  for rank, doc_id in enumerate(ranked_list):
41
  scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
42
- return sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
43
 
44
 
45
  def hybrid_retrieve(query: str, top_k: int = 5) -> list:
46
  if not indexes_loaded():
47
  raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
48
 
 
49
  q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
50
  faiss.normalize_L2(q_emb)
51
- _, dense_ids = _faiss_index.search(q_emb, top_k * 3)
52
- dense_ranking = [int(i) for i in dense_ids[0] if i >= 0]
 
 
 
 
 
 
 
 
53
 
54
- bm25_scores = _bm25_index.get_scores(query.lower().split())
55
- sparse_ranking = np.argsort(bm25_scores)[::-1][:top_k * 3].tolist()
 
 
 
56
 
57
- merged = _reciprocal_rank_fusion([dense_ranking, sparse_ranking])[:top_k]
 
 
 
 
 
58
 
59
  return [
60
- {"chunk": _chunks[i], "source": _sources[i], "chunk_id": i}
61
- for i in merged
 
 
 
 
 
 
62
  ]
 
1
  import pickle
2
  import numpy as np
3
  import faiss
4
+ from sentence_transformers import SentenceTransformer, CrossEncoder
5
  from config import (
6
  FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
7
  SOURCES_PATH, EMBEDDER_PATH
8
  )
9
 
10
+ _faiss_index = None
11
+ _bm25_index = None
12
+ _chunks = None
13
+ _sources = None
14
+ _model = None
15
+ _reranker = None
16
+
17
+ RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
18
 
19
 
20
  def indexes_loaded() -> bool:
 
22
 
23
 
24
  def load_indexes():
25
+ global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker
26
  _faiss_index = faiss.read_index(FAISS_INDEX_PATH)
27
+ with open(BM25_PATH, "rb") as f: _bm25_index = pickle.load(f)
28
+ with open(CHUNKS_PATH, "rb") as f: _chunks = pickle.load(f)
29
+ with open(SOURCES_PATH, "rb") as f: _sources = pickle.load(f)
30
+ _model = SentenceTransformer(EMBEDDER_PATH)
31
+ _reranker = CrossEncoder(RERANKER_MODEL) # ← reranker loads once
32
  print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
33
 
34
 
35
  def reload_indexes():
36
+ global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker
37
+ _faiss_index = _bm25_index = _chunks = _sources = _model = _reranker = None
38
  load_indexes()
39
 
40
 
41
+ def _reciprocal_rank_fusion(lists: list, k: int = 60) -> dict:
42
+ """Returns {doc_id: rrf_score} instead of just a sorted list."""
43
  scores: dict = {}
44
  for ranked_list in lists:
45
  for rank, doc_id in enumerate(ranked_list):
46
  scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
47
+ return scores # ← return the dict now
48
 
49
 
50
  def hybrid_retrieve(query: str, top_k: int = 5) -> list:
51
  if not indexes_loaded():
52
  raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
53
 
54
+ # ── Dense retrieval (FAISS) ───────────────────────────────
55
  q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
56
  faiss.normalize_L2(q_emb)
57
+ _, dense_ids = _faiss_index.search(q_emb, top_k * 3)
58
+ dense_ranking = [int(i) for i in dense_ids[0] if i >= 0]
59
+
60
+ # ── Sparse retrieval (BM25) ───────────────────────────────
61
+ bm25_scores = _bm25_index.get_scores(query.lower().split())
62
+ sparse_ranking = np.argsort(bm25_scores)[::-1][: top_k * 3].tolist()
63
+
64
+ # ── Fusion (RRF) β€” now returns score dict ─────────────────
65
+ rrf_scores = _reciprocal_rank_fusion([dense_ranking, sparse_ranking])
66
+ fused_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[: top_k * 2]
67
 
68
+ # ── Cross-encoder reranking ───────────────────────────────
69
+ # The cross-encoder scores each (query, chunk) pair together
70
+ # much more accurately than embedding similarity alone
71
+ candidates = [(query, _chunks[i]) for i in fused_ids]
72
+ ce_scores = _reranker.predict(candidates) # shape: (len(candidates),)
73
 
74
+ # Sort by cross-encoder score, keep top_k
75
+ ranked = sorted(
76
+ zip(fused_ids, ce_scores),
77
+ key=lambda x: x[1],
78
+ reverse=True,
79
+ )[:top_k]
80
 
81
  return [
82
+ {
83
+ "chunk": _chunks[i],
84
+ "source": _sources[i],
85
+ "chunk_id": i,
86
+ "rrf_score": round(float(rrf_scores[i]), 4),
87
+ "ce_score": round(float(score), 4), # ← reranker confidence
88
+ }
89
+ for i, score in ranked
90
  ]
test_sources.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from retriever import load_indexes, _sources
2
 
3
  load_indexes()
 
1
+ #test_source.py
2
  from retriever import load_indexes, _sources
3
 
4
  load_indexes()
tests/__init__.py ADDED
File without changes
tests/test_integration.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_integration.py
2
+ # Run with: pytest tests/test_integration.py -v -m integration
3
+ # These call real APIs β€” don't run in CI automatically.
4
+
5
+ import pytest
6
+
7
+ pytestmark = pytest.mark.integration # tag so CI can skip these
8
+
9
+
10
+ def test_groq_connection_live():
11
+ from langchain_groq import ChatGroq
12
+ from langchain_core.messages import HumanMessage
13
+ from config import GROQ_API_KEY, GROQ_MODEL
14
+ llm = ChatGroq(model=GROQ_MODEL, temperature=0, api_key=GROQ_API_KEY)
15
+ r = llm.invoke([HumanMessage(content="Reply with just the word OK")])
16
+ assert len(r.content) > 0
17
+
18
+
19
+ def test_full_pipeline_live():
20
+ """Ingests a tiny doc, retrieves, runs agent β€” end to end."""
21
+ import os
22
+ from pathlib import Path
23
+
24
+ # Write test doc
25
+ Path("./docs").mkdir(exist_ok=True)
26
+ test_file = Path("./docs/_pytest_temp.txt")
27
+ test_file.write_text(
28
+ "The Eiffel Tower is in Paris, France. "
29
+ "It was built in 1889. It is 330 metres tall."
30
+ )
31
+
32
+ try:
33
+ from ingestion import run_ingestion
34
+ from retriever import load_indexes, hybrid_retrieve
35
+ from agent import run_rag_agent
36
+
37
+ run_ingestion()
38
+ load_indexes()
39
+
40
+ results = hybrid_retrieve("How tall is the Eiffel Tower?", top_k=3)
41
+ assert len(results) > 0
42
+ assert "ce_score" in results[0] # reranker ran
43
+
44
+ answer, retries, verdict = run_rag_agent(
45
+ "How tall is the Eiffel Tower?", results
46
+ )
47
+ assert "330" in answer or "metres" in answer.lower()
48
+ assert verdict in {"PASS", "FAIL"}
49
+
50
+ finally:
51
+ test_file.unlink(missing_ok=True) # always clean up
tests/test_unit.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_unit.py
2
+ import pytest
3
+
4
+ # ── RRF logic ─────────────────────────────────────────────────────────────────
5
+
6
+ def test_rrf_prefers_doc_appearing_in_both_lists():
7
+ from retriever import _reciprocal_rank_fusion
8
+ scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 0, 1]])
9
+ # doc 2 is rank-0 in sparse and rank-2 in dense β†’ should beat doc 1
10
+ assert scores[2] > scores[1]
11
+
12
+ def test_rrf_returns_all_docs():
13
+ from retriever import _reciprocal_rank_fusion
14
+ scores = _reciprocal_rank_fusion([[0, 1], [1, 2]])
15
+ assert set(scores.keys()) == {0, 1, 2}
16
+
17
+ def test_rrf_scores_are_positive():
18
+ from retriever import _reciprocal_rank_fusion
19
+ scores = _reciprocal_rank_fusion([[0, 1, 2]])
20
+ assert all(v > 0 for v in scores.values())
21
+
22
+ # ── Config sanity ─────────────────────────────────────────────────────────────
23
+
24
+ def test_config_values_are_sane():
25
+ from config import CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_RETRIES
26
+ assert CHUNK_SIZE > CHUNK_OVERLAP, "overlap must be smaller than chunk size"
27
+ assert TOP_K > 0, "TOP_K must be positive"
28
+ assert MAX_RETRIES >= 1, "need at least 1 retry"
29
+
30
+ def test_groq_api_key_present(monkeypatch):
31
+ # patch so we don't need a real key in CI
32
+ monkeypatch.setenv("GROQ_API_KEY", "gsk_fakekeyfortesting1234567890")
33
+ import importlib, config
34
+ importlib.reload(config) # re-reads env
35
+ assert len(config.GROQ_API_KEY) > 10
36
+
37
+ # ── Agent routing logic ───────────────────────────────────────────────────────
38
+
39
+ def test_route_returns_done_on_pass():
40
+ from agent import route_after_validation
41
+ state = {"validation_result": "PASS", "retry_count": 0}
42
+ assert route_after_validation(state) == "done"
43
+
44
+ def test_route_returns_retry_on_fail_within_limit():
45
+ from agent import route_after_validation
46
+ state = {"validation_result": "FAIL", "retry_count": 0}
47
+ assert route_after_validation(state) == "retry"
48
+
49
+ def test_route_returns_done_when_retries_exhausted():
50
+ from agent import route_after_validation
51
+ state = {"validation_result": "FAIL", "retry_count": 3}
52
+ assert route_after_validation(state) == "done"
53
+
54
+ def test_increment_retry_node():
55
+ from agent import increment_retry_node
56
+ result = increment_retry_node({"retry_count": 1})
57
+ assert result["retry_count"] == 2
58
+
59
+ # ── Retriever output shape (mocked indexes) ───────────────────────────────────
60
+
61
+ @pytest.fixture
62
+ def mock_indexes(monkeypatch):
63
+ """Patches all globals in retriever so no files need to exist."""
64
+ import numpy as np
65
+ import retriever
66
+
67
+ # Fake chunks and sources
68
+ fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
69
+ fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]
70
+
71
+ # Fake FAISS index that always returns ids [0, 1, 2]
72
+ class FakeFaiss:
73
+ ntotal = 3
74
+ def search(self, vec, k):
75
+ ids = np.array([[0, 1, 2]])
76
+ return None, ids
77
+
78
+ # Fake BM25 that returns uniform scores
79
+ class FakeBM25:
80
+ def get_scores(self, tokens):
81
+ return np.array([0.9, 0.5, 0.3])
82
+
83
+ # Fake embedder
84
+ class FakeModel:
85
+ def encode(self, texts, convert_to_numpy=True):
86
+ return np.random.rand(len(texts), 384).astype("float32")
87
+
88
+ # Fake cross-encoder
89
+ class FakeReranker:
90
+ def predict(self, pairs):
91
+ return np.array([0.9, 0.7, 0.5][: len(pairs)])
92
+
93
+ monkeypatch.setattr(retriever, "_faiss_index", FakeFaiss())
94
+ monkeypatch.setattr(retriever, "_bm25_index", FakeBM25())
95
+ monkeypatch.setattr(retriever, "_chunks", fake_chunks)
96
+ monkeypatch.setattr(retriever, "_sources", fake_sources)
97
+ monkeypatch.setattr(retriever, "_model", FakeModel())
98
+ monkeypatch.setattr(retriever, "_reranker", FakeReranker())
99
+ return fake_chunks
100
+
101
+
102
+ def test_hybrid_retrieve_returns_top_k(mock_indexes):
103
+ from retriever import hybrid_retrieve
104
+ results = hybrid_retrieve("Where is Paris?", top_k=2)
105
+ assert len(results) == 2
106
+
107
+ def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
108
+ from retriever import hybrid_retrieve
109
+ result = hybrid_retrieve("Where is Paris?", top_k=1)[0]
110
+ assert "chunk" in result
111
+ assert "source" in result
112
+ assert "rrf_score" in result
113
+ assert "ce_score" in result
114
+
115
+ def test_hybrid_retrieve_scores_are_floats(mock_indexes):
116
+ from retriever import hybrid_retrieve
117
+ result = hybrid_retrieve("test", top_k=1)[0]
118
+ assert isinstance(result["rrf_score"], float)
119
+ assert isinstance(result["ce_score"], float)