File size: 4,728 Bytes
395651c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Tests for POST /api/v1/sessions/{session_id}/solve_multipart."""

from __future__ import annotations

import os
from unittest.mock import MagicMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

os.environ.setdefault("ALLOW_TEST_BYPASS", "true")

from app.main import app  # noqa: E402
from app.models.schemas import SolveResponse  # noqa: E402

_VALID_SESSION_ID = "00000000-0000-0000-0000-000000000088"

# PNG signature + padding (>= 12 bytes) for magic check in validate_chat_image_bytes
_VALID_PNG_BODY = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32


@pytest.fixture
def auth_headers():
    """Return request headers carrying the test user's Authorization value.

    The ``Test <user>`` scheme presumably relies on ALLOW_TEST_BYPASS being
    enabled before the app is imported (see the env default above) — TODO
    confirm against the auth dependency.
    """
    token = "Test test-user-solve-mp"
    return {"Authorization": token}


@pytest.mark.asyncio
async def test_solve_multipart_requires_auth():
    """A multipart solve request with no Authorization header gets 401."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        res = await client.post(
            f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart",
            files={"file": ("t.png", _VALID_PNG_BODY, "image/png")},
            data={"text": "hi"},
        )
    # Attach the response body so an unexpected status is diagnosable.
    assert res.status_code == 401, res.text


@pytest.mark.asyncio
async def test_solve_multipart_forbidden(auth_headers):
    """An authenticated caller gets 403 when the session is not theirs.

    Ownership is forced to ``False`` by patching
    ``app.routers.solve.session_owned_by_user``.
    """
    with patch("app.routers.solve.session_owned_by_user", return_value=False):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart",
                headers=auth_headers,
                files={"file": ("t.png", _VALID_PNG_BODY, "image/png")},
                data={"text": "hi"},
            )
    # Attach the response body so an unexpected status is diagnosable.
    assert res.status_code == 403, res.text


@pytest.mark.asyncio
async def test_solve_multipart_empty_text(auth_headers):
    """Whitespace-only ``text`` is rejected with 400.

    The session is patched as owned and the image passes the PNG magic
    check, so the blank text is the only thing wrong with the request.
    """
    with patch("app.routers.solve.session_owned_by_user", return_value=True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart",
                headers=auth_headers,
                files={"file": ("t.png", _VALID_PNG_BODY, "image/png")},
                data={"text": "   "},
            )
    # Attach the response body so an unexpected status is diagnosable.
    assert res.status_code == 400, res.text


@pytest.mark.asyncio
async def test_solve_multipart_bad_magic(auth_headers):
    """A file without the PNG signature is rejected with 400.

    The upload declares ``image/png`` and the text is non-empty, so the
    rejection has to come from inspecting the file bytes, not the declared
    content type.
    """
    with patch("app.routers.solve.session_owned_by_user", return_value=True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart",
                headers=auth_headers,
                files={"file": ("t.png", b"not-a-real-png!!", "image/png")},
                data={"text": "problem text"},
            )
    # Attach the response body so an unexpected status is diagnosable.
    assert res.status_code == 400, res.text


@pytest.mark.asyncio
async def test_solve_multipart_upload_then_enqueue(auth_headers):
    """A valid multipart solve uploads the image, then enqueues the job.

    The upload helper and ``_enqueue_solve_common`` are both patched. The
    test checks that:

    - both receive the job id the endpoint returns;
    - the upload receives the posted file bytes;
    - the enqueued request carries the stripped text and the uploaded
      image's public URL;
    - the message metadata records the attachment.
    """
    up = {
        "public_url": "https://example.test/bucket/sessions/s1/image_v1_j.png",
        "storage_path": f"sessions/{_VALID_SESSION_ID}/image_v1_job.png",
        "version": 1,
        "session_asset_id": "00000000-0000-0000-0000-000000000099",
    }
    captured = {}

    def fake_enqueue(supabase, background_tasks, session_id, user_id, uid, request, message_metadata, job_id):
        # Record what the router hands to the enqueue step and echo the job
        # id back so the endpoint response can be correlated with it.
        captured["metadata"] = message_metadata
        captured["job_id"] = job_id
        captured["request"] = request
        return SolveResponse(job_id=job_id, status="processing")

    with (
        patch("app.routers.solve.session_owned_by_user", return_value=True),
        patch("app.routers.solve.upload_session_chat_image", return_value=up) as up_mock,
        patch("app.routers.solve._enqueue_solve_common", side_effect=fake_enqueue),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            res = await client.post(
                f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart",
                headers=auth_headers,
                files={"file": ("t.png", _VALID_PNG_BODY, "image/png")},
                data={"text": "  my problem  "},
            )
    assert res.status_code == 200, res.text
    data = res.json()
    assert data["status"] == "processing"
    jid = data["job_id"]
    assert jid

    # Upload step: called exactly once. Its first three positional args are
    # checked as (session id, job id, image bytes).
    up_mock.assert_called_once()
    call_args = up_mock.call_args[0]
    assert call_args[0] == _VALID_SESSION_ID
    assert call_args[1] == jid
    # Compare content, not just length, so a truncated or altered payload
    # of the same size is still caught.
    assert call_args[2] == _VALID_PNG_BODY

    # Enqueue step: fail with a clear message (not a KeyError) if the router
    # never reached it, and check it got the same job id as the upload.
    assert captured, "_enqueue_solve_common was never called"
    assert captured["job_id"] == jid

    metadata = captured["metadata"]
    att = metadata.get("attachment")
    assert att, f"message metadata has no attachment: {metadata!r}"
    assert att.get("size_bytes") == len(_VALID_PNG_BODY)
    assert att.get("public_url") == up["public_url"]

    # The router strips surrounding whitespace from the submitted text.
    assert captured["request"].text == "my problem"
    assert captured["request"].image_url == up["public_url"]