# MAX / test_rag_system.py
# Commit 3257ee5: Update RAG system: universal document chatbot, Gemini support, fix dependencies
"""
Tests for the RAG system.
Run: python3 -m pytest test_rag_system.py -v
"""
import os
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch, PropertyMock
# ─────────────────────────────────────────────
# Stubs for heavy dependencies (models are not loaded during tests)
# ─────────────────────────────────────────────
def _make_stub(name):
    """Register an empty placeholder module under ``name`` and return it.

    A fresh ``types.ModuleType`` is stored in ``sys.modules`` under the given
    name, replacing whatever entry was there before.
    """
    sys.modules[name] = placeholder = types.ModuleType(name)
    return placeholder
# Stub langchain_community: register the package and each listed submodule as
# an empty placeholder module in sys.modules.
for mod_name in [
    "langchain_community",
    "langchain_community.document_loaders",
    "langchain_community.embeddings",
    "langchain_community.llms",
    "langchain_community.retrievers",
]:
    _make_stub(mod_name)
# Runs after the loop, once the retrievers placeholder exists.
sys.modules["langchain_community.retrievers"].BM25Retriever = MagicMock()
# Stub langchain_chroma
lc_chroma = _make_stub("langchain_chroma")
lc_chroma.Chroma = MagicMock()
# Stub langchain_openai
lc_openai = _make_stub("langchain_openai")
lc_openai.ChatOpenAI = MagicMock()
# Stub langchain_text_splitters
lts = _make_stub("langchain_text_splitters")
lts.RecursiveCharacterTextSplitter = MagicMock()
# Stub langchain_core: same approach, package plus the listed submodules.
for mod_name in [
    "langchain_core",
    "langchain_core.prompts",
    "langchain_core.runnables",
    "langchain_core.output_parsers",
    "langchain_core.documents",
    "langchain_core.callbacks",
    "langchain_core.callbacks.manager",
]:
    _make_stub(mod_name)
# Attach mock classes to the placeholders created by the loop above.
sys.modules["langchain_core.prompts"].PromptTemplate = MagicMock()
sys.modules["langchain_core.runnables"].RunnablePassthrough = MagicMock()
sys.modules["langchain_core.output_parsers"].StrOutputParser = MagicMock()
# Document stub
class FakeDocument:
    """Minimal stand-in for a document: page text plus a metadata mapping."""

    def __init__(self, page_content="", metadata=None):
        # Any falsy metadata (None, empty mapping) is replaced by a new dict,
        # so instances never share a default mapping.
        if not metadata:
            metadata = {}
        self.metadata = metadata
        self.page_content = page_content
# Expose FakeDocument as the Document class of the stubbed langchain_core.documents.
sys.modules["langchain_core.documents"].Document = FakeDocument
# Stub langchain_huggingface
lc_hf = _make_stub("langchain_huggingface")
lc_hf.HuggingFaceEmbeddings = MagicMock()
# Stub sentence_transformers / torch
_make_stub("sentence_transformers")
_make_stub("torch")
# Stub PyPDFLoader
sys.modules["langchain_community.document_loaders"].PyPDFLoader = MagicMock()
# Stub pdfplumber
_make_stub("pdfplumber")
# Stub python-docx
_make_stub("docx")
# ── Now it is safe to import ──────────────────────────────────────────
# Put this file's directory first on sys.path; presumably rag_system lives
# next to this test file — confirm against the repository layout.
sys.path.insert(0, str(Path(__file__).parent))
import importlib, rag_system as _rag_module  # noqa: E402
# Patch the heavy objects at module level
_rag_module.HuggingFaceEmbeddings = MagicMock()
_rag_module.Chroma = MagicMock()
_rag_module.RecursiveCharacterTextSplitter = MagicMock()
_rag_module.PyPDFLoader = MagicMock()
from rag_system import RAGAnswer, RAGSystem  # noqa: E402
# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def _make_rag(provider="openai") -> RAGSystem:
    """Create a RAGSystem with fully stubbed dependencies.

    The embeddings, Chroma and text-splitter names on the rag module are
    patched while the instance is constructed, and ``OPENAI_API_KEY`` is set
    to a dummy value for the same duration. The returned instance's
    ``vectorstore`` is a mock whose collection count is 0.
    """
    # Mocked vector store reporting an empty collection.
    store = MagicMock()
    store._collection.count.return_value = 0
    fake_env = {"OPENAI_API_KEY": "sk-test"}
    with patch.object(_rag_module, "HuggingFaceEmbeddings"), \
            patch.object(_rag_module, "Chroma") as chroma_cls, \
            patch.object(_rag_module, "RecursiveCharacterTextSplitter"), \
            patch.dict(os.environ, fake_env):
        chroma_cls.return_value = store
        system = RAGSystem(llm_provider=provider)
        system.vectorstore = store
    return system
# ─────────────────────────────────────────────
# Test 1: RAGAnswer.__str__ — output formatting
# ─────────────────────────────────────────────
class TestRAGAnswer(unittest.TestCase):
    """Checks the text produced by ``str()`` on a RAGAnswer."""

    def test_str_basic(self):
        """The rendered text contains the answer, the file name and the page label."""
        rendered = str(
            RAGAnswer(
                answer="Π‘Ρ‚Π°Π²ΠΊΠ° 15%",
                sources=[{"file": "doc.pdf", "page": 3}],
            )
        )
        for fragment in ("Π‘Ρ‚Π°Π²ΠΊΠ° 15%", "doc.pdf", "стр. 3"):
            self.assertIn(fragment, rendered)

    def test_str_deduplicates_sources(self):
        """Identical sources must be printed only once."""
        repeated = {"file": "a.pdf", "page": 1}
        # The second entry is an equal but distinct dict (the duplicate).
        source_list = [repeated, dict(repeated), {"file": "b.pdf", "page": 2}]
        rendered = str(RAGAnswer(answer="ВСкст", sources=source_list))
        for file_name in ("a.pdf", "b.pdf"):
            self.assertEqual(rendered.count(file_name), 1)

    def test_str_no_sources(self):
        """With an empty source list the answer and the sources heading still appear."""
        message = "Π˜Π½Ρ„ΠΎΡ€ΠΌΠ°Ρ†ΠΈΡ Π½Π΅ Π½Π°ΠΉΠ΄Π΅Π½Π°"
        rendered = str(RAGAnswer(answer=message, sources=[]))
        self.assertIn(message, rendered)
        self.assertIn("Π˜ΡΡ‚ΠΎΡ‡Π½ΠΈΠΊΠΈ (Ρ€Π΅Π»Π΅Π²Π°Π½Ρ‚Π½ΠΎΡΡ‚ΡŒ):", rendered)
# ─────────────────────────────────────────────
# Test 2: _resolve_pdf_paths — path resolution
# ─────────────────────────────────────────────
class TestResolvePdfPaths(unittest.TestCase):
    """Checks how ``RAGSystem._resolve_pdf_paths`` turns a path into PDF file paths.

    Each test gets its own scratch directory. Removal is registered with
    ``addCleanup``, so the directory is deleted even when an assertion fails
    (a trailing ``shutil.rmtree`` would be skipped in that case).
    """

    def setUp(self):
        """Create a per-test temporary directory exposed as ``self.tmp_path``."""
        import tempfile

        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.tmp_path = Path(scratch.name)

    def test_single_pdf_file(self, tmp_path=None):
        """A path to one PDF file resolves to a list holding just that path.

        ``tmp_path`` is kept only for signature compatibility and is ignored.
        """
        pdf = self.tmp_path / "test.pdf"
        pdf.write_bytes(b"%PDF-1.4")
        result = RAGSystem._resolve_pdf_paths(str(pdf))
        self.assertEqual(result, [str(pdf)])

    def test_directory_with_pdfs(self):
        """A directory resolves to its PDF files; other files are left out."""
        (self.tmp_path / "a.pdf").write_bytes(b"%PDF")
        (self.tmp_path / "b.pdf").write_bytes(b"%PDF")
        (self.tmp_path / "notes.txt").write_text("not a pdf")
        result = RAGSystem._resolve_pdf_paths(str(self.tmp_path))
        names = {Path(p).name for p in result}
        self.assertIn("a.pdf", names)
        self.assertIn("b.pdf", names)
        self.assertNotIn("notes.txt", names)

    def test_nonexistent_path_returns_empty(self):
        """A path that does not exist resolves to an empty list."""
        result = RAGSystem._resolve_pdf_paths("/not/existing/path")
        self.assertEqual(result, [])

    def test_non_pdf_file_returns_empty(self):
        """A path to a non-PDF file resolves to an empty list."""
        txt = self.tmp_path / "readme.txt"
        txt.write_text("hello")
        result = RAGSystem._resolve_pdf_paths(str(txt))
        self.assertEqual(result, [])

    def test_nested_pdfs(self):
        """PDF files inside a subdirectory are found as well."""
        sub = self.tmp_path / "sub"
        sub.mkdir()
        (sub / "nested.pdf").write_bytes(b"%PDF")
        result = RAGSystem._resolve_pdf_paths(str(self.tmp_path))
        self.assertTrue(any("nested.pdf" in p for p in result))
# ─────────────────────────────────────────────
# Test 3: ask_question — correct operation of the chain
# ─────────────────────────────────────────────
class TestAskQuestion(unittest.TestCase):
    """Exercises ``RAGSystem.ask_question`` with mocked retrieval and a mocked chain."""

    def setUp(self):
        self.rag = _make_rag()

    def _install_chain(self, reply):
        """Make ``prompt | ... | ...`` evaluate to a mock whose ``invoke`` returns ``reply``."""
        chain = MagicMock()
        chain.invoke.return_value = reply
        # Build the chain: prompt | llm | parser => chain
        self.rag.prompt.__or__ = MagicMock(return_value=chain)
        chain.__or__ = MagicMock(return_value=chain)
        return chain

    def _retrieve(self, document, score):
        """Make ``similarity_search`` return a single (document, score) hit."""
        self.rag.similarity_search = MagicMock(return_value=[(document, score)])

    def test_returns_rag_answer(self):
        """A question with one retrieved document yields a RAGAnswer instance."""
        hit = FakeDocument(
            page_content="Максимальная ставка ΠΏΠΎ Π²ΠΊΠ»Π°Π΄Ρƒ составляСт 15%.",
            metadata={"source_file": "bank.pdf", "page": 0},
        )
        self._retrieve(hit, 0.85)
        with patch.object(_rag_module, "StrOutputParser"):
            self._install_chain("Π‘Ρ‚Π°Π²ΠΊΠ° составляСт 15%.")
            outcome = self.rag.ask_question("Какая ставка ΠΏΠΎ Π²ΠΊΠ»Π°Π΄Ρƒ?")
        self.assertIsInstance(outcome, RAGAnswer)

    def test_empty_vectorstore_returns_not_found(self):
        """No retrieved documents gives the fixed not-found answer and no sources."""
        self.rag.similarity_search = MagicMock(return_value=[])
        outcome = self.rag.ask_question("Какой курс Π΄ΠΎΠ»Π»Π°Ρ€Π°?")
        self.assertEqual(outcome.answer, "Π˜Π½Ρ„ΠΎΡ€ΠΌΠ°Ρ†ΠΈΡ Π½Π΅ Π½Π°ΠΉΠ΄Π΅Π½Π°")
        self.assertEqual(outcome.sources, [])

    def test_sources_page_number_incremented(self):
        """Pages are numbered from 0 — the reported page must be +1."""
        hit = FakeDocument(
            page_content="ВСкст Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Π°.",
            # page 4 -> must become 5
            metadata={"source_file": "contract.pdf", "page": 4},
        )
        self._retrieve(hit, 0.90)
        self._install_chain("ΠžΡ‚Π²Π΅Ρ‚ ΠΈΠ· Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Π°.")
        outcome = self.rag.ask_question("Вопрос?")
        self.assertEqual(outcome.sources[0]["page"], 5)

    def test_source_fallback_to_source_metadata(self):
        """Without source_file, the file name is taken from the 'source' field."""
        hit = FakeDocument(
            page_content="ВСкст.",
            metadata={"source": "/path/to/report.pdf", "page": 0},
        )
        self._retrieve(hit, 0.75)
        self._install_chain("ΠžΡ‚Π²Π΅Ρ‚.")
        outcome = self.rag.ask_question("Вопрос?")
        self.assertEqual(outcome.sources[0]["file"], "report.pdf")
# ─────────────────────────────────────────────
# Test 4: add_documents — PDF indexing
# ─────────────────────────────────────────────
class TestAddDocuments(unittest.TestCase):
    """Exercises ``RAGSystem.add_documents`` against a mocked vector store.

    Each test gets a scratch directory whose removal is registered with
    ``addCleanup``, so it is deleted even when an assertion fails. Module
    attributes are replaced only through ``patch.object`` so every
    replacement is undone when the test ends.
    """

    def setUp(self):
        """Build a stubbed RAGSystem and a per-test temporary directory."""
        import tempfile

        self.rag = _make_rag()
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.tmp = Path(scratch.name)

    def test_no_pdfs_returns_zero(self):
        """A path that yields no PDFs makes add_documents return 0."""
        count = self.rag.add_documents("/no/such/dir")
        self.assertEqual(count, 0)

    def test_pdf_is_split_and_indexed(self):
        """One loaded page split into one chunk is added to the store once."""
        pdf_path = self.tmp / "doc.pdf"
        pdf_path.write_bytes(b"%PDF-1.4")
        fake_page = FakeDocument(page_content="ВСкст Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Π°.", metadata={"source_file": "doc.pdf", "page": 0})
        fake_chunk = FakeDocument(page_content="ВСкст", metadata={"source_file": "doc.pdf", "page": 0})
        with patch.object(RAGSystem, "_load_pdf_with_tables", return_value=[fake_page]):
            self.rag.splitter.split_documents = MagicMock(return_value=[fake_chunk])
            self.rag.vectorstore.add_documents = MagicMock()
            self.rag.vectorstore._collection.count.return_value = 1
            result = self.rag.add_documents(str(self.tmp))
        self.assertEqual(result, 1)
        self.rag.vectorstore.add_documents.assert_called_once()

    def test_loader_exception_is_handled(self):
        """A loader that raises must not propagate; the call returns 0."""
        (self.tmp / "bad.pdf").write_bytes(b"not a pdf")
        loader_instance = MagicMock()
        loader_instance.load.side_effect = Exception("corrupt pdf")
        # patch.object restores the previous PyPDFLoader on exit, so the
        # failing loader cannot leak into tests that run afterwards.
        with patch.object(_rag_module, "PyPDFLoader", return_value=loader_instance):
            # Must not raise an exception
            result = self.rag.add_documents(str(self.tmp))
        self.assertEqual(result, 0)

    def test_source_file_metadata_is_set(self):
        """Every page must carry the source_file field."""
        pdf_path = self.tmp / "annual_report.pdf"
        pdf_path.write_bytes(b"%PDF-1.4")
        fake_page = FakeDocument(page_content="ВСкст.", metadata={"source_file": "annual_report.pdf", "page": 0})
        with patch.object(RAGSystem, "_load_pdf_with_tables", return_value=[fake_page]):
            self.rag.splitter.split_documents = MagicMock(return_value=[fake_page])
            self.rag.vectorstore.add_documents = MagicMock()
            self.rag.vectorstore._collection.count.return_value = 1
            self.rag.add_documents(str(self.tmp))
        # NOTE(review): fake_page is created with this value already set, so
        # this only shows add_documents does not overwrite it — confirm intent.
        self.assertEqual(fake_page.metadata["source_file"], "annual_report.pdf")
# ─────────────────────────────────────────────
# Test 5: get_stats — database statistics
# ─────────────────────────────────────────────
class TestGetStats(unittest.TestCase):
    """Checks the dictionary returned by ``RAGSystem.get_stats``."""

    def test_stats_returns_correct_keys(self):
        """The stats expose the expected keys, chunk count and collection name."""
        system = _make_rag()
        system.vectorstore._collection.count.return_value = 42
        report = system.get_stats()
        for key in ("collection", "total_chunks", "persist_dir"):
            self.assertIn(key, report)
        self.assertEqual(report["total_chunks"], 42)
        self.assertEqual(report["collection"], "rag_docs")
# ─────────────────────────────────────────────
# Test 6: LLM initialization — provider checks
# ─────────────────────────────────────────────
class TestInitLLM(unittest.TestCase):
    """Checks provider handling in ``RAGSystem._init_llm``."""

    # Positional arguments passed after the provider name in every call below.
    _LLM_ARGS = ("gpt-4o-mini", None, "llama3", "http://localhost:11434")

    def test_openai_requires_api_key(self):
        """With an empty environment the openai provider raises ValueError mentioning 'API'."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with self.assertRaises(ValueError) as caught:
                RAGSystem._init_llm("openai", *self._LLM_ARGS)
        self.assertIn("API", str(caught.exception))

    def test_openai_accepts_env_key(self):
        """With OPENAI_API_KEY set, no ValueError is raised."""
        fake_openai = MagicMock()
        fake_openai.ChatOpenAI = MagicMock(return_value=MagicMock())
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key"}), \
                patch.dict(sys.modules, {"langchain_openai": fake_openai}):
            # Must not raise ValueError about a missing key
            try:
                RAGSystem._init_llm("openai", *self._LLM_ARGS)
            except ValueError as err:
                self.fail(f"ValueError Π½Π΅ Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π±Ρ‹Ρ‚ΡŒ ΠΏΡ€ΠΈ Π½Π°Π»ΠΈΡ‡ΠΈΠΈ ΠΊΠ»ΡŽΡ‡Π°: {err}")

    def test_unknown_provider_raises(self):
        """An unrecognised provider raises ValueError naming that provider."""
        with self.assertRaises(ValueError) as caught:
            RAGSystem._init_llm("anthropic", "claude", *self._LLM_ARGS[1:])
        self.assertIn("anthropic", str(caught.exception))
# ─────────────────────────────────────────────
# Test 7: configuration — constants
# ─────────────────────────────────────────────
class TestConfiguration(unittest.TestCase):
    """Sanity checks on the module-level constants of the rag module."""

    def test_chunk_size_reasonable(self):
        """CHUNK_SIZE is positive and at most 4096."""
        size = _rag_module.CHUNK_SIZE
        self.assertGreater(size, 0)
        self.assertLessEqual(size, 4096)

    def test_chunk_overlap_less_than_chunk_size(self):
        """CHUNK_OVERLAP is strictly smaller than CHUNK_SIZE."""
        overlap, size = _rag_module.CHUNK_OVERLAP, _rag_module.CHUNK_SIZE
        self.assertLess(overlap, size)

    def test_top_k_positive(self):
        """TOP_K is greater than zero."""
        top_k = _rag_module.TOP_K
        self.assertGreater(top_k, 0)

    def test_embedding_model_is_multilingual(self):
        """Russian-language banking documents need a multilingual model."""
        model_name = _rag_module.EMBEDDING_MODEL
        self.assertIn("multilingual", model_name)

    def test_prompt_contains_required_variables(self):
        """The prompt template contains both the context and question placeholders."""
        template = _rag_module.STRICT_PROMPT_TEMPLATE
        for placeholder in ("{context}", "{question}"):
            self.assertIn(placeholder, template)
# When executed as a script, run every test in this module with verbose output.
if __name__ == "__main__":
    unittest.main(verbosity=2)