File size: 3,774 Bytes
5befce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dd3a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

import pytest

from src import config
from src.errors import UserFacingError
from src.validation import validate_and_clamp


def test_defaults_match_rdbt() -> None:
    """Passing the canonical default settings should return them unchanged.

    Every field supplied below is already in range, so validation must not
    alter any of the checked attributes.
    """
    inputs = dict(
        prompt="a test prompt",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    result = validate_and_clamp(**inputs)

    expected = {
        "width": 1024,
        "height": 1024,
        "steps": 16,
        "cfg": 1.0,
        "sampler_name": "euler_ancestral",
        "scheduler": "simple",
        "denoise": 1.0,
        "negative_prompt": "",
    }
    # Checked in insertion order, mirroring one assert per attribute.
    for attr, want in expected.items():
        assert getattr(result, attr) == want


def test_clamp_width_height() -> None:
    """Out-of-range dimensions are clamped into bounds and a warning is recorded.

    A width of 100 is expected to become 512 and a height of 3000 to become
    2048. Presumably these are the configured min/max; confirm against
    ``config``.
    """
    kwargs = {
        "prompt": "x",
        "negative_prompt": "",
        "width": 100,
        "height": 3000,
        "steps": 16,
        "cfg": 1.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "simple",
        "denoise": 1.0,
    }
    clamped = validate_and_clamp(**kwargs)

    assert (clamped.width, clamped.height) == (512, 2048)
    # Clamping must not happen silently.
    assert clamped.warnings


def test_euler_a_alias() -> None:
    """The shorthand sampler name ``euler_a`` is rewritten to ``euler_ancestral``.

    The rewrite must be surfaced through a warning that mentions "euler"
    (matched case-insensitively).
    """
    kwargs = dict(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_a",
        scheduler="simple",
        denoise=1.0,
    )
    normalized = validate_and_clamp(**kwargs)

    assert normalized.sampler_name == "euler_ancestral"
    euler_warnings = [msg for msg in normalized.warnings if "euler" in msg.lower()]
    assert euler_warnings


def test_empty_prompt_rejected() -> None:
    """A whitespace-only prompt must be rejected with ``UserFacingError``."""
    blank_prompt_kwargs = dict(
        prompt="  ",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    # Only the call under test sits inside the raises-context.
    with pytest.raises(UserFacingError):
        validate_and_clamp(**blank_prompt_kwargs)


def test_cfg_respects_step() -> None:
    """A CFG value between grid points is snapped to the nearest step.

    The expected value assumes the CFG range starts at 1.0 with a 0.1 step.
    Those numbers are not asserted here; confirm them against ``config``.
    Under that assumption: round((1.23 - 1.0) / 0.1) * 0.1 + 1.0 == 1.2.
    """
    p = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.23,  # deliberately off-grid so it has to snap
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    # pytest.approx states the float tolerance explicitly. It also reports
    # expected vs. actual on failure, unlike a hand-rolled abs() check.
    assert p.cfg == pytest.approx(1.2, abs=0.01)


def test_fixed_seed_is_used() -> None:
    """With randomization disabled, an in-range seed is passed through as-is."""
    locked_seed = 12345
    kwargs = dict(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
        seed=locked_seed,
        randomize_seed=False,
    )
    result = validate_and_clamp(**kwargs)
    assert result.seed == locked_seed


def test_randomized_seed_is_in_range() -> None:
    """With randomization enabled, the resulting seed lies within the configured bounds.

    Only the range is checked. The test does not assert that the supplied
    seed was actually replaced.
    """
    kwargs = {
        "prompt": "x",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "steps": 16,
        "cfg": 1.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "simple",
        "denoise": 1.0,
        "seed": 12345,
        "randomize_seed": True,
    }
    seed = validate_and_clamp(**kwargs).seed

    assert seed >= config.MIN_SEED
    assert seed <= config.MAX_SEED


def test_seed_clamps_when_locked() -> None:
    """A locked seed above ``config.MAX_SEED`` is clamped and a warning is recorded.

    The warning only needs to mention "seed". The match is case-insensitive,
    like the sampler-alias test, so a message such as "Seed clamped ..."
    is also accepted.
    """
    p = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
        seed=config.MAX_SEED + 1,
        randomize_seed=False,
    )
    assert p.seed == config.MAX_SEED
    assert any("seed" in w.lower() for w in p.warnings)