lesson-agent-dev / libs /researchmind /tests /test_retrieve.py
MSG
Feat/research tab agent skills (#5)
e7fd66f
Raw
History Blame Contribute Delete
2.8 kB
from __future__ import annotations
import numpy as np
from researchmind.config import ResearchMindConfig
from researchmind.retrieve import retrieve
from researchmind.store import MemRAGStore
def _fake_embed(monkeypatch):
    """Patch ``researchmind.retrieve.embed_texts`` with a deterministic stub.

    The stub returns one 2-D ``float32`` row per input text: ``[1, 0]`` when
    the text contains "photosynthesis" (case-insensitive), ``[0, 1]``
    otherwise. The rows are stacked into a single array.
    """
    on_topic = (1.0, 0.0)
    off_topic = (0.0, 1.0)

    def fake_embed_texts(texts, *, model_name):
        # ``model_name`` is accepted to match the keyword the stub is called
        # with; it does not influence the returned vectors.
        rows = [
            np.array(
                on_topic if "photosynthesis" in text.lower() else off_topic,
                dtype=np.float32,
            )
            for text in texts
        ]
        return np.stack(rows)

    monkeypatch.setattr("researchmind.retrieve.embed_texts", fake_embed_texts)
def test_retrieve_ranks_by_similarity(tmp_path, monkeypatch):
    """A "photosynthesis" query with ``top_k=1`` yields only the matching chunk."""
    _fake_embed(monkeypatch)

    settings = {
        "data_dir": tmp_path,
        "embed_model": "test",
        "auto_search": False,
        "top_k": 1,
        "max_context_chunks": 8,
        "chunk_size": 512,
        "chunk_overlap": 128,
    }
    cfg = ResearchMindConfig(**settings)
    store = MemRAGStore(cfg)
    store.set_embed_dim(2)

    # Two single-chunk documents with orthogonal stored vectors; the first
    # one points the same way as the stubbed embedding of the query.
    fixtures = (
        ("a", "A", "c1", "photosynthesis in plants", [1.0, 0.0]),
        ("b", "B", "c2", "fractions math", [0.0, 1.0]),
    )
    for uri, title, tag, body, vector in fixtures:
        # NOTE(review): chunk tuple looks like (id, index, text, vector,
        # metadata) — confirm against MemRAGStore.add_document.
        chunk = (tag, 0, body, np.array(vector, dtype=np.float32), {})
        store.add_document(
            source_type="test",
            uri=uri,
            title=title,
            text=body,
            chunks=[chunk],
        )

    hits = retrieve(
        "photosynthesis",
        store,
        config=cfg,
        top_k=1,
        expand_neighbors=False,
    )

    assert len(hits) == 1
    assert "photosynthesis" in hits[0].text
def test_retrieve_filters_by_session(tmp_path, monkeypatch):
    """Passing ``session_id`` limits hits to documents added under that session.

    Both documents are stored and ``top_k=2`` would allow two hits, yet only
    the chunk added with the first session's id is expected back.
    """
    _fake_embed(monkeypatch)

    settings = {
        "data_dir": tmp_path,
        "embed_model": "test",
        "auto_search": False,
        "top_k": 2,
        "max_context_chunks": 8,
        "chunk_size": 512,
        "chunk_overlap": 128,
    }
    cfg = ResearchMindConfig(**settings)
    store = MemRAGStore(cfg)
    store.set_embed_dim(2)

    # Create both sessions before adding any document.
    sid_a = store.create_session(topic="a").id
    sid_b = store.create_session(topic="b").id

    fixtures = (
        ("a", "Plants", "c1", "photosynthesis in plants", [1.0, 0.0], sid_a),
        ("b", "Math", "c2", "fractions math", [0.0, 1.0], sid_b),
    )
    for uri, title, tag, body, vector, owner in fixtures:
        # NOTE(review): chunk tuple looks like (id, index, text, vector,
        # metadata) — confirm against MemRAGStore.add_document.
        chunk = (tag, 0, body, np.array(vector, dtype=np.float32), {})
        store.add_document(
            source_type="test",
            uri=uri,
            title=title,
            text=body,
            chunks=[chunk],
            session_id=owner,
        )

    query_options = {
        "config": cfg,
        "top_k": 2,
        "expand_neighbors": False,
        "session_id": sid_a,
    }
    scoped = retrieve("photosynthesis", store, **query_options)

    assert len(scoped) == 1
    assert "photosynthesis" in scoped[0].text