# anima-gradio-zerogpu-space / tests/test_validation.py
# Last commit: 7dd3a70 "feat: add seed controls" (JSCPPProgrammer, verified)
from __future__ import annotations
import pytest
from src import config
from src.errors import UserFacingError
from src.validation import validate_and_clamp
def test_defaults_match_rdbt() -> None:
    """Reference generation settings pass through validation unchanged.

    NOTE(review): "rdbt" presumably names the reference preset these values
    come from — confirm against config.
    """
    expected = {
        "width": 1024,
        "height": 1024,
        "steps": 16,
        "cfg": 1.0,
        "sampler_name": "euler_ancestral",
        "scheduler": "simple",
        "denoise": 1.0,
        "negative_prompt": "",
    }
    params = validate_and_clamp(
        prompt="a test prompt",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    # Checked in insertion order; the attribute name is reported on failure.
    for attr, wanted in expected.items():
        assert getattr(params, attr) == wanted, attr
def test_clamp_width_height() -> None:
    """Out-of-range dimensions are clamped into bounds and a warning is recorded."""
    too_narrow, too_tall = 100, 3000
    result = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=too_narrow,
        height=too_tall,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    # 512 / 2048 are presumably the configured min/max dimensions — see config.
    assert (result.width, result.height) == (512, 2048)
    assert result.warnings
def test_euler_a_alias() -> None:
    """The shorthand sampler "euler_a" is normalized to "euler_ancestral" with a warning."""
    result = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_a",
        scheduler="simple",
        denoise=1.0,
    )
    assert result.sampler_name == "euler_ancestral"
    # Lazy generator so the search stops at the first matching warning.
    lowered = (message.lower() for message in result.warnings)
    assert any("euler" in text for text in lowered)
def test_empty_prompt_rejected() -> None:
    """A whitespace-only prompt is rejected with a UserFacingError."""
    blank_prompt = " "
    with pytest.raises(UserFacingError):
        validate_and_clamp(
            prompt=blank_prompt,
            negative_prompt="",
            width=1024,
            height=1024,
            steps=16,
            cfg=1.0,
            batch_size=1,
            sampler_name="euler_ancestral",
            scheduler="simple",
            denoise=1.0,
        )
def test_cfg_respects_step() -> None:
    """An off-grid CFG value is snapped onto the configured CFG step grid.

    Assumes the grid starts at 1.0 with a 0.1 step (per config — confirm), so
    round((1.23 - 1.0) / 0.1) * 0.1 + 1.0 == 1.2.
    """
    p = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.23,  # deliberately off-grid; expected to snap
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
    )
    # pytest.approx absorbs float representation error (e.g. 1.2000000000000002)
    # while still failing on a value that was not actually snapped to the grid.
    assert p.cfg == pytest.approx(1.2)
def test_fixed_seed_is_used() -> None:
    """A locked (non-randomized) in-range seed is passed through verbatim."""
    requested_seed = 12345
    outcome = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
        seed=requested_seed,
        randomize_seed=False,
    )
    assert outcome.seed == requested_seed
def test_randomized_seed_is_in_range() -> None:
    """With randomize_seed=True the resulting seed lies in [MIN_SEED, MAX_SEED].

    The explicit seed argument is supplied but not asserted on here.
    """
    outcome = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
        seed=12345,
        randomize_seed=True,
    )
    lower, upper = config.MIN_SEED, config.MAX_SEED
    assert lower <= outcome.seed <= upper
def test_seed_clamps_when_locked() -> None:
    """A locked seed above config.MAX_SEED is clamped to MAX_SEED with a warning.

    The warning check is case-insensitive, matching test_euler_a_alias, so a
    message such as "Seed clamped ..." is not rejected on capitalization alone.
    """
    p = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=16,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="simple",
        denoise=1.0,
        seed=config.MAX_SEED + 1,  # one past the upper bound
        randomize_seed=False,
    )
    assert p.seed == config.MAX_SEED
    assert any("seed" in w.lower() for w in p.warnings)