Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import numpy as np | |
| from researchmind.config import ResearchMindConfig | |
| from researchmind.retrieve import retrieve | |
| from researchmind.store import MemRAGStore | |
| def _fake_embed(monkeypatch): | |
| def fake_embed_texts(texts, *, model_name): | |
| out = [] | |
| for t in texts: | |
| if "photosynthesis" in t.lower(): | |
| out.append(np.array([1.0, 0.0], dtype=np.float32)) | |
| else: | |
| out.append(np.array([0.0, 1.0], dtype=np.float32)) | |
| return np.stack(out) | |
| monkeypatch.setattr("researchmind.retrieve.embed_texts", fake_embed_texts) | |
| def test_retrieve_ranks_by_similarity(tmp_path, monkeypatch): | |
| _fake_embed(monkeypatch) | |
| cfg = ResearchMindConfig( | |
| data_dir=tmp_path, | |
| embed_model="test", | |
| auto_search=False, | |
| top_k=1, | |
| max_context_chunks=8, | |
| chunk_size=512, | |
| chunk_overlap=128, | |
| ) | |
| store = MemRAGStore(cfg) | |
| store.set_embed_dim(2) | |
| store.add_document( | |
| source_type="test", | |
| uri="a", | |
| title="A", | |
| text="photosynthesis in plants", | |
| chunks=[("c1", 0, "photosynthesis in plants", np.array([1.0, 0.0], dtype=np.float32), {})], | |
| ) | |
| store.add_document( | |
| source_type="test", | |
| uri="b", | |
| title="B", | |
| text="fractions math", | |
| chunks=[("c2", 0, "fractions math", np.array([0.0, 1.0], dtype=np.float32), {})], | |
| ) | |
| hits = retrieve("photosynthesis", store, config=cfg, top_k=1, expand_neighbors=False) | |
| assert len(hits) == 1 | |
| assert "photosynthesis" in hits[0].text | |
| def test_retrieve_filters_by_session(tmp_path, monkeypatch): | |
| _fake_embed(monkeypatch) | |
| cfg = ResearchMindConfig( | |
| data_dir=tmp_path, | |
| embed_model="test", | |
| auto_search=False, | |
| top_k=2, | |
| max_context_chunks=8, | |
| chunk_size=512, | |
| chunk_overlap=128, | |
| ) | |
| store = MemRAGStore(cfg) | |
| store.set_embed_dim(2) | |
| sid_a = store.create_session(topic="a").id | |
| sid_b = store.create_session(topic="b").id | |
| store.add_document( | |
| source_type="test", | |
| uri="a", | |
| title="Plants", | |
| text="photosynthesis in plants", | |
| chunks=[("c1", 0, "photosynthesis in plants", np.array([1.0, 0.0], dtype=np.float32), {})], | |
| session_id=sid_a, | |
| ) | |
| store.add_document( | |
| source_type="test", | |
| uri="b", | |
| title="Math", | |
| text="fractions math", | |
| chunks=[("c2", 0, "fractions math", np.array([0.0, 1.0], dtype=np.float32), {})], | |
| session_id=sid_b, | |
| ) | |
| scoped = retrieve( | |
| "photosynthesis", | |
| store, | |
| config=cfg, | |
| top_k=2, | |
| expand_neighbors=False, | |
| session_id=sid_a, | |
| ) | |
| assert len(scoped) == 1 | |
| assert "photosynthesis" in scoped[0].text | |