LeonardoMdSA commited on
Commit
0c6c6fd
·
1 Parent(s): 5366fc0

added unit tests and integration tests

Browse files
Dockerfile CHANGED
@@ -7,7 +7,7 @@ RUN apt-get update && \
7
 
8
  WORKDIR /app
9
 
10
- # Copy only what HF Space needs (exclude localhost folder)
11
  COPY multi_doc_chat ./multi_doc_chat
12
  COPY templates ./templates
13
  COPY static ./static
@@ -22,9 +22,8 @@ RUN pip install --no-cache-dir -r requirements.txt
22
  # Create folders for models/index
23
  RUN mkdir -p /app/models /app/faiss_index
24
 
25
- # HF Spaces port
26
- ENV PORT=7860
27
- EXPOSE 7860
28
 
29
  # Entrypoint
30
  CMD ["python", "app.py"]
 
7
 
8
  WORKDIR /app
9
 
10
+ # Copy only what HF Space needs
11
  COPY multi_doc_chat ./multi_doc_chat
12
  COPY templates ./templates
13
  COPY static ./static
 
22
  # Create folders for models/index
23
  RUN mkdir -p /app/models /app/faiss_index
24
 
25
+ # HuggingFace Spaces provides PORT automatically
26
+ ENV PORT=$PORT
 
27
 
28
  # Entrypoint
29
  CMD ["python", "app.py"]
app.py CHANGED
@@ -50,7 +50,7 @@ else:
50
  logger.warning("Embedding model not loaded")
51
 
52
  # --- FastAPI app ---
53
- app = FastAPI(title="MultiDocChat", version="0.1.0")
54
  app.add_middleware(
55
  CORSMiddleware,
56
  allow_origins=["*"],
@@ -154,5 +154,5 @@ def list_sessions():
154
  # --- Run server ---
155
  if __name__ == "__main__":
156
  import uvicorn
157
- port = int(os.getenv("PORT", "7860"))
158
  uvicorn.run("app:app", host="0.0.0.0", port=port)
 
50
  logger.warning("Embedding model not loaded")
51
 
52
  # --- FastAPI app ---
53
+ app = FastAPI(title="RAG Solution", version="0.1.0")
54
  app.add_middleware(
55
  CORSMiddleware,
56
  allow_origins=["*"],
 
154
  # --- Run server ---
155
  if __name__ == "__main__":
156
  import uvicorn
157
+ port = int(os.getenv("PORT"))
158
  uvicorn.run("app:app", host="0.0.0.0", port=port)
localhost/main.py CHANGED
@@ -50,7 +50,7 @@ else:
50
  logger.warning("Embedding model not loaded")
51
 
52
  # --- FastAPI app ---
53
- app = FastAPI(title="MultiDocChat", version="0.1.0")
54
  app.add_middleware(
55
  CORSMiddleware,
56
  allow_origins=["*"],
 
50
  logger.warning("Embedding model not loaded")
51
 
52
  # --- FastAPI app ---
53
+ app = FastAPI(title="RAG Solution", version="0.1.0")
54
  app.add_middleware(
55
  CORSMiddleware,
56
  allow_origins=["*"],
tests/__init__.py ADDED
File without changes
tests/conftest.py DELETED
@@ -1,149 +0,0 @@
1
- import os
2
- import io
3
- import types
4
- import json
5
- import shutil
6
- import pathlib
7
- import sys
8
- import pytest
9
-
10
- os.environ.setdefault("PYTHONPATH", str(pathlib.Path(__file__).resolve().parents[1] / "multi_doc_chat"))
11
- os.environ.setdefault("GROQ_API_KEY", "dummy")
12
- os.environ.setdefault("GOOGLE_API_KEY", "dummy")
13
- os.environ.setdefault("LLM_PROVIDER", "google")
14
-
15
- from fastapi.testclient import TestClient
16
-
17
- # Ensure repository root is importable for `import main`
18
- ROOT = pathlib.Path(__file__).resolve().parents[1]
19
- if str(ROOT) not in sys.path:
20
- sys.path.insert(0, str(ROOT))
21
-
22
- import localhost.main as main
23
-
24
-
25
- @pytest.fixture
26
- def client():
27
- return TestClient(main.app)
28
-
29
-
30
- @pytest.fixture
31
- def clear_sessions():
32
- main.SESSIONS.clear()
33
- yield
34
- main.SESSIONS.clear()
35
-
36
-
37
- @pytest.fixture
38
- def tmp_dirs(tmp_path: pathlib.Path):
39
- data_dir = tmp_path / "data"
40
- faiss_dir = tmp_path / "faiss_index"
41
- data_dir.mkdir(parents=True, exist_ok=True)
42
- faiss_dir.mkdir(parents=True, exist_ok=True)
43
- cwd = pathlib.Path.cwd()
44
- try:
45
- # Point working directories used by app code to tmp ones by chdir
46
- os.chdir(tmp_path)
47
- yield {"data": data_dir, "faiss": faiss_dir}
48
- finally:
49
- os.chdir(cwd)
50
-
51
-
52
- class _StubEmbeddings:
53
- def embed_query(self, text: str):
54
- return [0.0, 0.1, 0.2]
55
-
56
- def embed_documents(self, texts):
57
- return [[0.0, 0.1, 0.2] for _ in texts]
58
-
59
- def __call__(self, text: str):
60
- return [0.0, 0.1, 0.2]
61
-
62
-
63
- class _StubLLM:
64
- def invoke(self, input):
65
- return "stubbed answer"
66
-
67
-
68
- @pytest.fixture
69
- def stub_model_loader(monkeypatch):
70
- # Patch both module paths to cover imports via `utils.model_loader` and `multi_doc_chat.utils.model_loader`
71
- import utils.model_loader as ml_mod
72
- from multi_doc_chat.utils import model_loader as ml_mod2
73
-
74
- class FakeApiKeyMgr:
75
- def __init__(self):
76
- self.api_keys = {"GROQ_API_KEY": "x", "GOOGLE_API_KEY": "y"}
77
-
78
- def get(self, key: str) -> str:
79
- return self.api_keys[key]
80
-
81
- class FakeModelLoader:
82
- def __init__(self):
83
- self.api_key_mgr = FakeApiKeyMgr()
84
- self.config = {
85
- "embedding_model": {"model_name": "fake-embed"},
86
- "llm": {
87
- "google": {
88
- "provider": "google",
89
- "model_name": "fake-llm",
90
- "temperature": 0.0,
91
- "max_output_tokens": 128,
92
- }
93
- },
94
- }
95
-
96
- def load_embeddings(self):
97
- return _StubEmbeddings()
98
-
99
- def load_llm(self):
100
- return _StubLLM()
101
-
102
- monkeypatch.setattr(ml_mod, "ApiKeyManager", FakeApiKeyMgr)
103
- monkeypatch.setattr(ml_mod, "ModelLoader", FakeModelLoader)
104
- monkeypatch.setattr(ml_mod2, "ApiKeyManager", FakeApiKeyMgr)
105
- monkeypatch.setattr(ml_mod2, "ModelLoader", FakeModelLoader)
106
-
107
- # Also patch the already-imported symbols used in modules under test
108
- import multi_doc_chat.src.document_ingestion.data_ingestion as di
109
- import multi_doc_chat.src.document_chat.retrieval as r
110
- monkeypatch.setattr(di, "ModelLoader", FakeModelLoader)
111
- monkeypatch.setattr(r, "ModelLoader", FakeModelLoader)
112
- yield FakeModelLoader
113
-
114
-
115
- @pytest.fixture
116
- def stub_ingestor(monkeypatch):
117
- import multi_doc_chat.src.document_ingestion.data_ingestion as di
118
-
119
- class FakeIngestor:
120
- def __init__(self, use_session_dirs=True, **kwargs):
121
- self.use_session = use_session_dirs
122
- self.session_id = "sess_test"
123
-
124
- def built_retriver(self, uploaded_files, **kwargs):
125
- return None
126
-
127
- monkeypatch.setattr(di, "ChatIngestor", FakeIngestor)
128
- monkeypatch.setattr(main, "ChatIngestor", FakeIngestor)
129
- yield FakeIngestor
130
-
131
-
132
- @pytest.fixture
133
- def stub_rag(monkeypatch):
134
- import multi_doc_chat.src.document_chat.retrieval as r
135
-
136
- class FakeRAG:
137
- def __init__(self, session_id=None, retriever=None):
138
- self.session_id = session_id
139
- self.retriever = retriever
140
-
141
- def load_retriever_from_faiss(self, index_path, **kwargs):
142
- return None
143
-
144
- def invoke(self, user_input, chat_history=None):
145
- return "stubbed answer"
146
-
147
- monkeypatch.setattr(r, "ConversationalRAG", FakeRAG)
148
- monkeypatch.setattr(main, "ConversationalRAG", FakeRAG)
149
- yield FakeRAG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/integration/test_chat_route.py DELETED
@@ -1,49 +0,0 @@
1
- import pytest
2
-
3
-
4
- def test_chat_invalid_session_returns_400(client, clear_sessions, stub_rag):
5
- body = {"session_id": "nope", "message": "hi"}
6
- resp = client.post("/chat", json=body)
7
- assert resp.status_code == 400
8
- assert "Invalid or expired" in resp.json()["detail"]
9
-
10
-
11
- def test_chat_empty_message_returns_400(client, clear_sessions, stub_rag):
12
- sid = "sess_test"
13
- import localhost.main as main
14
- main.SESSIONS[sid] = []
15
- body = {"session_id": sid, "message": " "}
16
- resp = client.post("/chat", json=body)
17
- assert resp.status_code == 400
18
- assert "Message cannot be empty" in resp.json()["detail"]
19
-
20
-
21
- def test_chat_success_returns_answer_and_appends_history(client, clear_sessions, stub_rag):
22
- sid = "sess_test"
23
- import localhost.main as main
24
- main.SESSIONS[sid] = []
25
- body = {"session_id": sid, "message": "Hello"}
26
- resp = client.post("/chat", json=body)
27
- assert resp.status_code == 200
28
- assert resp.json()["answer"] == "stubbed answer"
29
- assert len(main.SESSIONS[sid]) == 2
30
-
31
-
32
- def test_chat_failure_returns_500(client, clear_sessions, monkeypatch):
33
- sid = "sess_test"
34
- import localhost.main as main
35
- main.SESSIONS[sid] = []
36
-
37
- import localhost.main as main
38
-
39
- class BoomRAG:
40
- def __init__(self, session_id=None):
41
- pass
42
- def load_retriever_from_faiss(self, *a, **k):
43
- from multi_doc_chat.exception.custom_exception import DocumentPortalException
44
- raise DocumentPortalException("fail load", None)
45
-
46
- monkeypatch.setattr(main, "ConversationalRAG", BoomRAG)
47
- resp = client.post("/chat", json={"session_id": sid, "message": "hi"})
48
- assert resp.status_code == 500
49
- assert "fail load" in resp.json()["detail"].lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/integration/test_rag_service_flow.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from multi_doc_chat.rag_service import create_rag_service
3
+ from unittest.mock import MagicMock
4
+ import numpy as np
5
+
6
+ class FakeEmbedder:
7
+ def encode(self, texts, show_progress_bar=False):
8
+ return np.zeros((len(texts), 768), dtype="float32")
9
+
10
+ @pytest.mark.asyncio
11
+ async def test_rag_service_basic_flow():
12
+ rag = create_rag_service(faiss_dir="tests/faiss_test_index")
13
+
14
+ # patch embedder
15
+ rag.loader.embedder = FakeEmbedder()
16
+
17
+ # patch FAISS index to fake but correct shapes
18
+ class FakeIndex:
19
+ def search(self, q_vec, top_k):
20
+ # Return dummy distances and indices
21
+ return np.zeros((1, top_k)), np.zeros((1, top_k), dtype=int)
22
+ rag.index = FakeIndex()
23
+
24
+ # add docs to memory
25
+ rag.documents.extend(["Hello world", "Another chunk"])
26
+
27
+ answer = rag.query("Hello?")
28
+ assert isinstance(answer, str)
tests/integration/test_upload_route.py DELETED
@@ -1,36 +0,0 @@
1
- import io
2
- import pytest
3
-
4
-
5
- def test_upload_success_returns_session_and_indexed(client, clear_sessions, stub_ingestor, tmp_dirs):
6
- files = {"files": ("note.txt", io.BytesIO(b"hello world"), "text/plain")}
7
- resp = client.post("/upload", files=files)
8
- assert resp.status_code == 200
9
- data = resp.json()
10
- assert data["indexed"] is True
11
- assert data["session_id"]
12
-
13
-
14
- def test_upload_no_files_validation_error(client, clear_sessions, stub_ingestor):
15
- # Without files FastAPI validation will yield 422; send empty list to hit our 400
16
- resp = client.post("/upload", files=[])
17
- assert resp.status_code == 422
18
-
19
-
20
- def test_upload_ingestor_failure_returns_500(client, clear_sessions, monkeypatch, tmp_dirs):
21
- import multi_doc_chat.src.document_ingestion.data_ingestion as di
22
- import localhost.main as main
23
-
24
- class Boom:
25
- def __init__(self, *a, **k):
26
- self.session_id = "sess_test"
27
- def built_retriver(self, *a, **k):
28
- from multi_doc_chat.exception.custom_exception import DocumentPortalException
29
- raise DocumentPortalException("boom", None)
30
-
31
- monkeypatch.setattr(di, "ChatIngestor", Boom)
32
- monkeypatch.setattr(main, "ChatIngestor", Boom)
33
- files = {"files": ("note.txt", io.BytesIO(b"hello world"), "text/plain")}
34
- resp = client.post("/upload", files=files)
35
- assert resp.status_code == 500
36
- assert "boom" in resp.json()["detail"].lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/run_evaluations.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ if __name__ == "__main__":
4
+ # Run all tests
5
+ pytest.main(["-v", "tests"])
tests/unit/test_data_ingestion.py CHANGED
@@ -1,44 +1,44 @@
1
- import pathlib
2
- import pytest
3
- from langchain.schema import Document
4
-
5
- from multi_doc_chat.src.document_ingestion.data_ingestion import (
6
- generate_session_id,
7
- ChatIngestor,
8
- FaissManager,
9
- )
10
-
11
-
12
- def test_generate_session_id_format_and_uniqueness():
13
- a = generate_session_id()
14
- b = generate_session_id()
15
- assert a != b
16
- assert a.startswith("session_") and b.startswith("session_")
17
- # Rough pattern check: session_YYYYMMDD_HHMMSS_XXXXXXXX -> 4 parts
18
- assert len(a.split("_")) == 4
19
-
20
-
21
- def test_chat_ingestor_resolve_dir_uses_session_dirs(tmp_dirs, stub_model_loader):
22
- ing = ChatIngestor(temp_base="data", faiss_base="faiss_index", use_session_dirs=True)
23
- assert ing.session_id
24
- assert str(ing.temp_dir).endswith(ing.session_id)
25
- assert str(ing.faiss_dir).endswith(ing.session_id)
26
-
27
-
28
- def test_split_chunks_respect_size_and_overlap(tmp_dirs, stub_model_loader):
29
- ing = ChatIngestor(temp_base="data", faiss_base="faiss_index", use_session_dirs=True)
30
- docs = [Document(page_content="A" * 1200, metadata={"source": "x.txt"})]
31
- chunks = ing._split(docs, chunk_size=500, chunk_overlap=100)
32
- assert len(chunks) >= 2
33
- # spot check boundaries
34
- assert len(chunks[0].page_content) <= 500
35
-
36
-
37
- def test_faiss_manager_add_documents_idempotent(tmp_dirs, stub_model_loader):
38
- fm = FaissManager(index_dir=pathlib.Path("faiss_index/test"))
39
- fm.load_or_create(texts=["hello", "world"], metadatas=[{"source": "a"}, {"source": "b"}])
40
- docs = [Document(page_content="hello", metadata={"source": "a"})]
41
- first = fm.add_documents(docs)
42
- second = fm.add_documents(docs)
43
- assert first >= 0
44
- assert second == 0
 
1
+ import pytest
2
+ from multi_doc_chat.rag_service import create_rag_service
3
+ from multi_doc_chat.src.document_ingestion import data_ingestion as di
4
+ from io import BytesIO
5
+ import asyncio
6
+ from unittest.mock import MagicMock
7
+ import numpy as np
8
+
9
+ class DummyUploadFile:
10
+ def __init__(self, name, content):
11
+ self.filename = name
12
+ self.file = BytesIO(content.encode("utf-8"))
13
+ self._content = content.encode("utf-8")
14
+ def read(self):
15
+ return self._content
16
+
17
+ # Patch async file reading
18
+ import multi_doc_chat.utils.document_ops as doc_ops
19
+ async def async_read_text_fileobj(fileobj):
20
+ return await asyncio.to_thread(doc_ops.read_text_fileobj, fileobj)
21
+ di.read_text_fileobj = async_read_text_fileobj
22
+
23
+ class FakeEmbedder:
24
+ def encode(self, texts, show_progress_bar=False):
25
+ return np.zeros((len(texts), 768), dtype="float32")
26
+
27
+ @pytest.mark.asyncio
28
+ async def test_ingest_txt_file():
29
+ rag_service = create_rag_service(faiss_dir="tests/faiss_test_index")
30
+
31
+ # patch embedder
32
+ rag_service.loader.embedder = FakeEmbedder()
33
+
34
+ # patch FAISS completely
35
+ rag_service.index = MagicMock()
36
+ rag_service.index.add = MagicMock()
37
+
38
+ txt_file = DummyUploadFile("example.txt", "Hello world! This is a test.")
39
+
40
+ session_id = await di.ingest_upload_files([txt_file], rag_service)
41
+
42
+ assert session_id == "default"
43
+ assert len(rag_service.documents) > 0
44
+ assert any("Hello world" in chunk for chunk in rag_service.documents)
tests/unit/test_retrieval.py DELETED
@@ -1,13 +0,0 @@
1
- import pathlib
2
- import pytest
3
-
4
- from multi_doc_chat.src.document_chat.retrieval import ConversationalRAG
5
- from exception.custom_exception import DocumentPortalException
6
-
7
-
8
- def test_conversationalrag_error_handling(tmp_dirs, stub_model_loader):
9
- rag = ConversationalRAG(session_id="s1")
10
- with pytest.raises(DocumentPortalException):
11
- rag.invoke("hello")
12
- with pytest.raises(DocumentPortalException):
13
- rag.load_retriever_from_faiss(index_path="faiss_index/does_not_exist")