""" Tests for POST /api/v1/documents/upload/batch — issue #435. """ import io import uuid from unittest.mock import MagicMock, patch import pytest from app.models import Document # ── helpers ────────────────────────────────────────────────────────────────── def _fake_txt_file(name: str = "test.txt", content: bytes = b"hello world") -> tuple: """Return a multipart files tuple accepted by httpx TestClient.""" return ("files", (name, io.BytesIO(content), "text/plain")) def _patch_validate(monkeypatch, tmp_path, content: bytes = b"hello world") -> None: """Make validate_upload write content to a real temp file and return its path.""" import tempfile, shutil async def fake_validate(file): tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", dir=tmp_path) tmp.write(content) tmp.close() return tmp.name monkeypatch.setattr("app.routes.documents.validate_upload", fake_validate) def _patch_celery(monkeypatch) -> MagicMock: """Stub out Celery so tests never touch Redis.""" mock_task = MagicMock() mock_task.id = f"celery-{uuid.uuid4().hex}" monkeypatch.setattr( "app.routes.documents.process_document", MagicMock(delay=MagicMock(return_value=mock_task)), ) return mock_task # ── tests ───────────────────────────────────────────────────────────────────── def test_batch_upload_single_file(client, auth_headers, monkeypatch, tmp_path): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("report.txt")], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 202 payload = response.json() assert payload["total"] == 1 assert len(payload["documents"]) == 1 assert payload["documents"][0]["original_name"] == "report.txt" assert payload["documents"][0]["status"] == "pending" assert len(payload["task_ids"]) == 1 assert payload["failed"] == [] def test_batch_upload_multiple_files(client, auth_headers, monkeypatch, tmp_path): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[ _fake_txt_file("a.txt"), _fake_txt_file("b.txt"), _fake_txt_file("c.txt"), ], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 202 payload = response.json() assert payload["total"] == 3 assert len(payload["documents"]) == 3 assert len(payload["task_ids"]) == 3 assert payload["failed"] == [] def test_batch_upload_rejects_bad_extension(client, auth_headers, monkeypatch, tmp_path): """A .exe file should land in failed[], not crash the whole batch.""" _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[ _fake_txt_file("good.txt"), ("files", ("bad.exe", io.BytesIO(b"binary"), "application/octet-stream")), ], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 202 payload = response.json() assert payload["total"] == 1 assert payload["documents"][0]["original_name"] == "good.txt" assert "bad.exe" in payload["failed"] def test_batch_upload_all_files_fail_returns_400(client, auth_headers, monkeypatch, tmp_path): """When every file fails, the endpoint should return 400.""" _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[ ("files", ("bad1.exe", io.BytesIO(b"x"), "application/octet-stream")), ("files", ("bad2.exe", io.BytesIO(b"y"), "application/octet-stream")), ], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 400 def test_batch_upload_requires_auth(client): response = client.post( "/api/v1/documents/upload/batch", files=[_fake_txt_file("test.txt")], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code in (401, 403) def test_batch_upload_invalid_chunk_size(client, auth_headers, monkeypatch, tmp_path): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("test.txt")], data={"chunk_size": "50", "chunk_overlap": "10"}, ) assert response.status_code == 400 def test_batch_upload_invalid_chunk_overlap(client, auth_headers, monkeypatch, tmp_path): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("test.txt")], data={"chunk_size": "500", "chunk_overlap": "600"}, ) assert response.status_code == 400 def test_batch_upload_celery_fallback_uses_background_task(client, auth_headers, monkeypatch, tmp_path): """When Celery is unavailable, tasks should fall back gracefully.""" _patch_validate(monkeypatch, tmp_path) # Make Celery raise so the fallback branch is taken monkeypatch.setattr( "app.routes.documents.process_document", MagicMock(delay=MagicMock(side_effect=Exception("Redis down"))), ) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("fallback.txt")], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 202 payload = response.json() assert payload["total"] == 1 assert payload["task_ids"][0].startswith("local_") def test_batch_upload_document_persisted_in_db(client, auth_headers, monkeypatch, tmp_path, db_session): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("persisted.txt")], data={"chunk_size": "1000", "chunk_overlap": "200"}, ) assert response.status_code == 202 doc_id = response.json()["documents"][0]["id"] doc = db_session.get(Document, doc_id) assert doc is not None assert doc.original_name == "persisted.txt" assert doc.status == "pending" assert doc.chunk_size == 1000 assert doc.chunk_overlap == 200 def test_batch_upload_chunk_settings_stored(client, auth_headers, monkeypatch, tmp_path, db_session): _patch_validate(monkeypatch, tmp_path) _patch_celery(monkeypatch) response = client.post( "/api/v1/documents/upload/batch", headers=auth_headers, files=[_fake_txt_file("chunked.txt")], data={"chunk_size": "800", "chunk_overlap": "100"}, ) assert response.status_code == 202 doc_id = response.json()["documents"][0]["id"] doc = db_session.get(Document, doc_id) assert doc.chunk_size == 800 assert doc.chunk_overlap == 100