File size: 5,490 Bytes
dc4e6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Functional tests — POST /generate/async endpoint
==================================================
Tests for the async (RQ-queue) document generation endpoint.

Key behaviours exercised:
  • Schema validation (422) for missing / bad fields
  • 404 when request_id not in Supabase
  • 503 behaviour when Redis queue is unavailable (if applicable)
  • Response contract when request is queued successfully
  • Same prompt_params validation as /generate/pdf (shared schema)
"""
import copy
import pytest
import requests
from tests.conftest import (
    BASE_URL, TIMEOUT, SEED_IMAGE_URL,
    NONEXISTENT_REQUEST_ID, MINIMAL_GENERATE_PAYLOAD,
)

ENDPOINT = f"{BASE_URL}/generate/async"


def make_payload(**overrides):
    """Build a request body from MINIMAL_GENERATE_PAYLOAD with *overrides* applied.

    The baseline is deep-copied first, so callers may mutate the returned
    dict without affecting the shared module-level payload. Keys passed in
    *overrides* replace the corresponding top-level keys of the baseline.
    """
    baseline = copy.deepcopy(MINIMAL_GENERATE_PAYLOAD)
    return {**baseline, **overrides}


def make_prompt_override(**kw):
    """Return a private copy of the baseline ``prompt_params`` with *kw* merged in.

    Intended to be passed as ``make_payload(prompt_params=...)`` when a test
    needs to vary a single prompt parameter.
    """
    defaults = copy.deepcopy(MINIMAL_GENERATE_PAYLOAD["prompt_params"])
    return {**defaults, **kw}


# ---------------------------------------------------------------------------
# 1. Schema / Input Validation
# ---------------------------------------------------------------------------

class TestGenerateAsyncInputValidation:
    """Malformed request bodies must be rejected with HTTP 422.

    FastAPI must reject each of these bodies before any business logic
    (Supabase lookup, queueing) runs, so the only acceptable outcome is a
    validation error.
    """

    @staticmethod
    def _status_for(http, body):
        """POST *body* to the async endpoint and return just the status code."""
        response = http.post(ENDPOINT, json=body, timeout=TIMEOUT)
        return response.status_code

    def test_missing_request_id_returns_422(self, http):
        # Everything except request_id is present and valid.
        body = dict(
            seed_images=[SEED_IMAGE_URL],
            prompt_params=MINIMAL_GENERATE_PAYLOAD["prompt_params"],
        )
        assert self._status_for(http, body) == 422

    def test_empty_seed_images_returns_422(self, http):
        assert self._status_for(http, make_payload(seed_images=[])) == 422

    def test_too_many_seed_images_returns_422(self, http):
        # Eleven copies of a valid URL: only the count should be rejected.
        eleven_images = [SEED_IMAGE_URL for _ in range(11)]
        assert self._status_for(http, make_payload(seed_images=eleven_images)) == 422

    def test_invalid_seed_image_url_returns_422(self, http):
        bad_urls = ["not-a-url"]
        assert self._status_for(http, make_payload(seed_images=bad_urls)) == 422

    def test_num_solutions_below_min_returns_422(self, http):
        too_few = make_prompt_override(num_solutions=0)
        assert self._status_for(http, make_payload(prompt_params=too_few)) == 422

    def test_num_solutions_above_max_returns_422(self, http):
        too_many = make_prompt_override(num_solutions=6)
        assert self._status_for(http, make_payload(prompt_params=too_many)) == 422

    def test_empty_body_returns_422(self, http):
        assert self._status_for(http, {}) == 422


# ---------------------------------------------------------------------------
# 2. Business-logic (valid schema, unknown request_id → 404 or 503)
# ---------------------------------------------------------------------------

class TestGenerateAsyncBusinessLogic:
    """
    Schema-valid requests that fail in business logic, not in validation.

    With a valid schema but nonexistent request_id the API should:
      • Return 404 if Redis is available (request_id lookup fails first), OR
      • Return 503 if Redis is unavailable (queue not initialised)
    Both are acceptable non-422 responses.
    """

    def test_nonexistent_request_id_is_not_422(self, http):
        r = http.post(ENDPOINT, json=MINIMAL_GENERATE_PAYLOAD, timeout=TIMEOUT)
        assert r.status_code != 422, (
            f"Valid schema must not produce 422, got {r.status_code}"
        )

    def test_nonexistent_request_id_returns_404_or_503(self, http):
        r = http.post(ENDPOINT, json=MINIMAL_GENERATE_PAYLOAD, timeout=TIMEOUT)
        assert r.status_code in (404, 503), (
            f"Expected 404 (no request) or 503 (no Redis), got {r.status_code}: {r.text}"
        )

    def test_error_response_is_json(self, http):
        """The error returned for an unknown request_id is served as JSON."""
        r = http.post(ENDPOINT, json=MINIMAL_GENERATE_PAYLOAD, timeout=TIMEOUT)
        # Precondition: this test inspects an *error* body. A non-error status
        # means the unknown-request_id premise no longer holds.
        assert r.status_code >= 400, (
            f"Expected an error status, got {r.status_code}: {r.text}"
        )
        assert "application/json" in r.headers.get("Content-Type", "")

    def test_error_response_has_detail(self, http):
        """The error body for an unknown request_id carries a 'detail' key."""
        r = http.post(ENDPOINT, json=MINIMAL_GENERATE_PAYLOAD, timeout=TIMEOUT)
        assert r.status_code >= 400, (
            f"Expected an error status, got {r.status_code}: {r.text}"
        )
        # r.json() raises a ValueError subclass on a non-JSON body; turn that
        # into a readable test failure instead of an opaque decode traceback.
        try:
            body = r.json()
        except ValueError as exc:
            raise AssertionError(
                f"Error body is not valid JSON ({r.status_code}): {r.text!r}"
            ) from exc
        assert "detail" in body, f"Error body must have 'detail'. Got: {body}"

    def test_swagger_string_tokens_not_422(self, http):
        # "string" is the placeholder Swagger UI fills in for str fields.
        payload = make_payload(
            google_drive_token="string",
            google_drive_refresh_token="string",
        )
        r = http.post(ENDPOINT, json=payload, timeout=TIMEOUT)
        assert r.status_code != 422

    def test_none_google_tokens_accepted(self, http):
        payload = make_payload(google_drive_token=None, google_drive_refresh_token=None)
        r = http.post(ENDPOINT, json=payload, timeout=TIMEOUT)
        assert r.status_code != 422

    def test_num_solutions_boundary_values_schema_valid(self, http):
        """Both ends of the accepted num_solutions range pass validation."""
        # Check every boundary before asserting, so that a failure at one end
        # does not hide the result for the other.
        rejected = []
        for n in (1, 5):
            pp = make_prompt_override(num_solutions=n)
            r = http.post(ENDPOINT, json=make_payload(prompt_params=pp), timeout=TIMEOUT)
            if r.status_code == 422:
                rejected.append(n)
        assert not rejected, (
            f"num_solutions values {rejected} should be schema-valid"
        )

    def test_missing_prompt_params_uses_defaults(self, http):
        """Omitting prompt_params entirely must not trigger a 422."""
        # Keep every other field (notably seed_images, which the schema
        # validates) so a 422 here can only come from the missing prompt_params.
        payload = make_payload(request_id=NONEXISTENT_REQUEST_ID)
        payload.pop("prompt_params", None)
        r = http.post(ENDPOINT, json=payload, timeout=TIMEOUT)
        assert r.status_code != 422, (
            f"Omitting prompt_params must not produce 422, got {r.status_code}: {r.text}"
        )