File size: 3,854 Bytes
395651c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Tests for POST /api/v1/sessions/{session_id}/ocr_preview (auth + owner + merge)."""

from __future__ import annotations

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

# Set before importing the app (hence the E402 suppression on that import).
# NOTE(review): presumably this enables the "Authorization: Test <user-id>"
# scheme that the auth_headers fixture sends, and is read when app.main is
# imported — confirm against the auth dependency. setdefault keeps any value
# already present in the environment.
os.environ.setdefault("ALLOW_TEST_BYPASS", "true")

from app.main import app  # noqa: E402

# Fixed, well-formed UUID used as the {session_id} path parameter in every
# request below. Ownership is patched in the tests that get past auth, so
# presumably no real session with this id has to exist — confirm.
_VALID_SESSION_ID = "00000000-0000-0000-0000-000000000099"


@pytest.fixture
def auth_headers():
    """Return an ``Authorization`` header for the test-bypass user ``test-user-ocr-preview``."""
    return dict(Authorization="Test test-user-ocr-preview")


@pytest.mark.asyncio
async def test_ocr_preview_requires_auth():
    """An upload sent without any Authorization header is rejected with 401."""
    url = f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview"
    upload = {"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")}
    form = {"user_message": "hello"}

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post(url, files=upload, data=form)

    assert response.status_code == 401


@pytest.mark.asyncio
async def test_ocr_preview_forbidden_when_not_owner(auth_headers):
    """A caller who does not own the session gets 403 and OCR is never invoked.

    The orchestrator is patched as well, mirroring the oversized-file test, so
    the request never reaches a real OCR backend regardless of the order in
    which the route resolves its collaborators. It also lets us assert that
    OCR is skipped for a non-owner, not merely that the status code is 403.
    """
    mock_orch = MagicMock()
    mock_orch.ocr_agent.process_image = AsyncMock(return_value="")

    with (
        patch("app.routers.solve.session_owned_by_user", return_value=False),
        patch("app.routers.solve.get_orchestrator", return_value=mock_orch),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview",
                headers=auth_headers,
                files={"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")},
                data={"user_message": "note"},
            )
    assert res.status_code == 403
    # The ownership rejection must happen before any OCR work is done.
    mock_orch.ocr_agent.process_image.assert_not_called()


@pytest.mark.asyncio
async def test_ocr_preview_success_merges_draft(auth_headers):
    """An owner's upload returns the OCR text, the stripped note and their merged draft."""
    orchestrator = MagicMock()
    orchestrator.ocr_agent.process_image = AsyncMock(return_value="OCR_LINE")

    url = f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview"
    upload = {"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")}
    # Surrounding whitespace is expected to be stripped from the echoed note.
    form = {"user_message": "  my note  "}

    owner_patch = patch("app.routers.solve.session_owned_by_user", return_value=True)
    orchestrator_patch = patch("app.routers.solve.get_orchestrator", return_value=orchestrator)
    with owner_patch, orchestrator_patch:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(url, headers=auth_headers, files=upload, data=form)

    assert response.status_code == 200, response.text
    body = response.json()
    assert body["ocr_text"] == "OCR_LINE"
    assert body["user_message"] == "my note"
    assert body["combined_draft"] == "my note\n\nOCR_LINE"
    orchestrator.ocr_agent.process_image.assert_called_once()


@pytest.mark.asyncio
async def test_ocr_preview_rejects_oversized_file(auth_headers):
    """An 11 MiB upload is refused with 413 and OCR is never called.

    NOTE(review): 11 MiB is presumably just above the route's upload cap — confirm.
    """
    orchestrator = MagicMock()
    orchestrator.ocr_agent.process_image = AsyncMock(return_value="")

    oversized_payload = b"x" * (11 * 1024 * 1024)
    url = f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview"
    upload = {"file": ("huge.png", oversized_payload, "image/png")}

    owner_patch = patch("app.routers.solve.session_owned_by_user", return_value=True)
    orchestrator_patch = patch("app.routers.solve.get_orchestrator", return_value=orchestrator)
    with owner_patch, orchestrator_patch:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(url, headers=auth_headers, files=upload)

    assert response.status_code == 413
    orchestrator.ocr_agent.process_image.assert_not_called()


@pytest.mark.asyncio
async def test_ocr_preview_rejects_empty_file(auth_headers):
    """A zero-byte upload is rejected with 400 and OCR is never invoked.

    Like the oversized-file test, the orchestrator is patched so that no real
    OCR backend can be reached. We also verify that the empty-file validation
    short-circuits before any OCR call.
    """
    mock_orch = MagicMock()
    mock_orch.ocr_agent.process_image = AsyncMock(return_value="")

    with (
        patch("app.routers.solve.session_owned_by_user", return_value=True),
        patch("app.routers.solve.get_orchestrator", return_value=mock_orch),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview",
                headers=auth_headers,
                files={"file": ("empty.png", b"", "image/png")},
            )
    assert res.status_code == 400
    mock_orch.ocr_agent.process_image.assert_not_called()