LTX2.3-Studio / tests /test_modes.py
techfreakworm's picture
ci: run unit tests + ruff lint on every push
3ea399a unverified
"""Unit tests for modes.py — MODE_REGISTRY and parameterize_fn correctness."""
import pytest
import modes
import workflow
def test_mode_dataclass_has_expected_fields():
"""Mode dataclass exposes the expected attribute set."""
fields = {"name", "label", "icon", "parameterize_fn", "stage_map"}
actual = set(modes.Mode.__dataclass_fields__.keys())
assert fields == actual
def test_mode_registry_is_a_dict():
"""MODE_REGISTRY exists and is a dict (entries added in Tasks 11–12)."""
assert isinstance(modes.MODE_REGISTRY, dict)
def test_t2v_parameterize_produces_valid_patches(canonical_inputs):
inputs = canonical_inputs["t2v"]
mode = modes.MODE_REGISTRY["t2v"]
patches = mode.parameterize_fn(inputs)
# All patches must be (node_id: int, widget_index: int, value: Any)
for node_id, widget_index, _value in patches:
assert isinstance(node_id, int)
assert isinstance(widget_index, int)
# Apply patches to a real template; result must validate.
wf = workflow.load_template("t2v")
for patch in patches:
workflow.set_input(wf, *patch)
workflow.validate(wf)
def test_i2v_parameterize_uses_image_path(canonical_inputs):
inputs = canonical_inputs["i2v"]
mode = modes.MODE_REGISTRY["i2v"]
patches = mode.parameterize_fn(inputs)
values = [p[2] for p in patches]
assert inputs["image"] in values
def test_t2v_and_i2v_in_registry():
"""T2V and I2V exist in MODE_REGISTRY (full completeness in Task 12)."""
assert "t2v" in modes.MODE_REGISTRY
assert "i2v" in modes.MODE_REGISTRY
@pytest.mark.parametrize("mode_name", ["a2v", "lipsync", "keyframe", "style"])
def test_remaining_modes_parameterize_validates(mode_name, canonical_inputs):
inputs = canonical_inputs[mode_name]
mode = modes.MODE_REGISTRY[mode_name]
patches = mode.parameterize_fn(inputs)
assert len(patches) > 0
wf = workflow.load_template(mode_name)
for patch in patches:
workflow.set_input(wf, *patch)
workflow.validate(wf)
def test_a2v_parameterize_passes_audio_path(canonical_inputs):
patches = modes.MODE_REGISTRY["a2v"].parameterize_fn(canonical_inputs["a2v"])
assert canonical_inputs["a2v"]["audio"] in [p[2] for p in patches]
def test_lipsync_parameterize_passes_image_and_audio(canonical_inputs):
patches = modes.MODE_REGISTRY["lipsync"].parameterize_fn(canonical_inputs["lipsync"])
values = [p[2] for p in patches]
assert canonical_inputs["lipsync"]["image"] in values
assert canonical_inputs["lipsync"]["audio"] in values
def test_keyframe_parameterize_passes_two_frames(canonical_inputs):
patches = modes.MODE_REGISTRY["keyframe"].parameterize_fn(canonical_inputs["keyframe"])
values = [p[2] for p in patches]
assert canonical_inputs["keyframe"]["first_frame"] in values
assert canonical_inputs["keyframe"]["last_frame"] in values
def test_style_parameterize_passes_input_video(canonical_inputs):
patches = modes.MODE_REGISTRY["style"].parameterize_fn(canonical_inputs["style"])
assert canonical_inputs["style"]["input_video"] in [p[2] for p in patches]
def test_mode_registry_has_all_six_keys():
"""All six modes are in the registry now."""
assert set(modes.MODE_REGISTRY.keys()) == {
"t2v",
"a2v",
"i2v",
"lipsync",
"keyframe",
"style",
}
def test_each_mode_has_required_attributes():
for name, mode in modes.MODE_REGISTRY.items():
assert mode.name == name
assert mode.label # non-empty
assert mode.icon # non-empty
assert callable(mode.parameterize_fn)
assert isinstance(mode.stage_map, list) and len(mode.stage_map) > 0