Spaces:
Running on Zero
Running on Zero
File size: 2,911 Bytes
"""Unit tests for models.py — MODEL_REGISTRY, walk_workflow_for_models, and ensure_models."""
import models
import workflow
def test_model_registry_resolves_known_files():
    """The distilled LTX checkpoint is registered with the expected repo and an empty subfolder."""
    entry = models.MODEL_REGISTRY["ltx-2.3-22b-distilled.safetensors"]
    assert entry.repo_id == "Lightricks/LTX-2.3"
    assert entry.subfolder == ""
def test_model_registry_includes_gemma_shards():
    """Shards 1..5 of the five-part model are registered under a gemma-3-12b-it repo."""
    shard_names = [f"model-{index:05d}-of-00005.safetensors" for index in range(1, 6)]
    for shard in shard_names:
        assert shard in models.MODEL_REGISTRY
        assert "gemma-3-12b-it" in models.MODEL_REGISTRY[shard].repo_id
def test_walk_workflow_for_models_finds_t2v_loaders():
    """Walking the "t2v" template reports a transformer file and a gemma encoder file."""
    template = workflow.load_template("t2v")
    required = models.walk_workflow_for_models(template)

    def _looks_like_transformer(filename):
        # T2V needs at minimum a transformer: distilled, dev fp8, or GGUF Q4.
        return (
            filename.endswith(".gguf")
            or "distilled.safetensors" in filename
            or "transformer_only" in filename
        )

    assert any(_looks_like_transformer(filename) for filename in required)
    # ...and a gemma text encoder (matched case-insensitively).
    assert any("gemma" in filename.lower() for filename in required)
def test_ensure_models_creates_symlinks_local(tmp_path, monkeypatch, fake_hf_cache):
    """In local mode, ensure_models creates symlinks from comfy/models -> HF cache.

    The Hub download is patched to raise, so only the fallback path
    (cache_dir.rglob) can satisfy the request.
    """
    checkpoint_name = "ltx-2.3-22b-distilled.safetensors"
    shard_name = "model-00001-of-00005.safetensors"
    models_root = tmp_path / "comfyui" / "models"

    def _offline_download(*_args, **_kwargs):
        # Stand-in for hf_hub_download that always fails.
        raise RuntimeError("offline test: forcing fallback to cache scan")

    monkeypatch.setenv("HF_HUB_CACHE", str(fake_hf_cache))
    monkeypatch.setattr(models, "_on_spaces", lambda: False)
    monkeypatch.setattr(models, "hf_hub_download", _offline_download)
    monkeypatch.setattr(models, "_comfy_models_dir", lambda: models_root)

    # Drain the iterable so every step of ensure_models actually runs.
    for _event in models.ensure_models({checkpoint_name, shard_name}):
        pass

    # Each requested file should now have a symlink in comfyui/models/<type>/
    checkpoint_link = models_root / "checkpoints" / checkpoint_name
    shard_link = models_root / "text_encoders" / "gemma-3-12b-it" / shard_name
    assert checkpoint_link.is_symlink()
    assert shard_link.is_symlink()
def test_ensure_models_skips_unregistered_files_with_warning(
    tmp_path, monkeypatch, fake_hf_cache, caplog
):
    """Files not in MODEL_REGISTRY are skipped (with warning), not raised.

    Only the absence of an exception and the presence of a log record naming
    the file are checked here.
    """
    import logging

    phantom_name = "nonexistent_phantom_file.safetensors"
    models_root = tmp_path / "comfyui" / "models"

    monkeypatch.setenv("HF_HUB_CACHE", str(fake_hf_cache))
    monkeypatch.setattr(models, "_on_spaces", lambda: False)
    monkeypatch.setattr(models, "_comfy_models_dir", lambda: models_root)

    with caplog.at_level(logging.WARNING):
        # Draining must not raise for an unregistered filename.
        for _event in models.ensure_models({phantom_name}):
            pass

    logged_messages = [record.message for record in caplog.records]
    assert any("nonexistent_phantom_file" in text for text in logged_messages)
|