Instructions to use mhnakif/comfy2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use mhnakif/comfy2 with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("mhnakif/comfy2", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| """Tests for external cache provider API.""" | |
| import importlib.util | |
| import pytest | |
| from typing import Optional | |
def _torch_available() -> bool:
    """Return True when the ``torch`` package can be imported."""
    spec = importlib.util.find_spec("torch")
    return spec is not None
| from comfy_execution.cache_provider import ( | |
| CacheProvider, | |
| CacheContext, | |
| CacheValue, | |
| register_cache_provider, | |
| unregister_cache_provider, | |
| _get_cache_providers, | |
| _has_cache_providers, | |
| _clear_cache_providers, | |
| _serialize_cache_key, | |
| _contains_self_unequal, | |
| _estimate_value_size, | |
| _canonicalize, | |
| ) | |
class TestCanonicalize:
    """Exercise _canonicalize: canonical forms must be deterministic."""

    def test_frozenset_ordering_is_deterministic(self):
        """Equal frozensets canonicalize identically regardless of build order."""
        pairs = [("a", 1), ("b", 2), ("c", 3)]
        forward = frozenset(pairs)
        backward = frozenset(reversed(pairs))
        assert _canonicalize(forward) == _canonicalize(backward)

    def test_nested_frozenset_ordering(self):
        """Nesting must not break deterministic ordering."""
        wrapped_a = frozenset([("key", frozenset([1, 2, 3]))])
        wrapped_b = frozenset([("key", frozenset([3, 2, 1]))])
        assert _canonicalize(wrapped_a) == _canonicalize(wrapped_b)

    def test_dict_ordering(self):
        """Insertion order of a dict must not affect the canonical form."""
        first = {"z": 1, "a": 2, "m": 3}
        second = {"a": 2, "m": 3, "z": 1}
        assert _canonicalize(first) == _canonicalize(second)

    def test_tuple_preserved(self):
        """Tuples are tagged with the __tuple__ marker."""
        canonical = _canonicalize((1, 2, 3))
        assert canonical[0] == "__tuple__"

    def test_list_preserved(self):
        """Lists are canonicalized element by element."""
        canonical = _canonicalize([{"b": 2, "a": 1}, frozenset([3, 2, 1])])
        # element 0: a dict, so it carries the dict marker
        assert "__dict__" in canonical[0]
        # element 1: a frozenset, so it carries the frozenset marker
        assert canonical[1][0] == "__frozenset__"

    def test_primitives_include_type(self):
        """Primitives carry their type name so distinct types cannot collide."""
        expectations = [
            (42, ("int", 42)),
            (3.14, ("float", 3.14)),
            ("hello", ("str", "hello")),
            (True, ("bool", True)),
            (None, ("NoneType", None)),
        ]
        for raw, expected in expectations:
            assert _canonicalize(raw) == expected

    def test_int_and_str_distinguished(self):
        """int 7 and str '7' must produce different canonical forms."""
        assert _canonicalize(7) != _canonicalize("7")

    def test_bytes_converted(self):
        """Bytes are represented as a tagged hex string."""
        tagged = _canonicalize(b"\x00\xff")
        assert tagged[0] == "__bytes__"
        assert tagged[1] == "00ff"

    def test_set_ordering(self):
        """Mutable sets are sorted just like frozensets."""
        shuffled = _canonicalize({3, 1, 2})
        ordered = _canonicalize({1, 2, 3})
        assert shuffled == ordered
        assert shuffled[0] == "__set__"

    def test_unknown_type_raises(self):
        """Unknown types raise ValueError (fail-closed)."""
        class CustomObj:
            pass

        with pytest.raises(ValueError):
            _canonicalize(CustomObj())

    def test_object_with_value_attr_raises(self):
        """Objects exposing a .value attribute (Unhashable-like) still raise."""
        class FakeUnhashable:
            def __init__(self):
                self.value = float('nan')

        with pytest.raises(ValueError):
            _canonicalize(FakeUnhashable())
class TestSerializeCacheKey:
    """Exercise _serialize_cache_key: hashing must be deterministic."""

    def test_same_content_same_hash(self):
        """Structurally equal keys must hash identically."""
        def make_key():
            return frozenset([("node_1", frozenset([("input", "value")]))])

        assert _serialize_cache_key(make_key()) == _serialize_cache_key(make_key())

    def test_different_content_different_hash(self):
        """Distinct key content must yield distinct digests."""
        digest_a = _serialize_cache_key(frozenset([("node_1", "value_a")]))
        digest_b = _serialize_cache_key(frozenset([("node_1", "value_b")]))
        assert digest_a != digest_b

    def test_returns_hex_string(self):
        """The digest is a string of 64 hex chars (SHA256)."""
        digest = _serialize_cache_key(frozenset([("test", 123)]))
        assert isinstance(digest, str)
        assert len(digest) == 64  # SHA256 hex digest is 64 chars

    def test_complex_nested_structure(self):
        """Deeply nested keys hash the same on every call."""
        # frozensets can only hold hashables, so dict-like data is
        # modeled here as nested frozensets of tuples.
        key = frozenset([
            ("node_1", frozenset([
                ("input_a", ("tuple", "value")),
                ("input_b", frozenset([("nested", "dict")])),
            ])),
            ("node_2", frozenset([("param", 42)])),
        ])
        assert _serialize_cache_key(key) == _serialize_cache_key(key)

    def test_dict_in_cache_key(self):
        """Plain dicts passed directly are accepted and hash deterministically."""
        key = {"node_1": {"input": "value"}, "node_2": 42}
        first = _serialize_cache_key(key)
        second = _serialize_cache_key(key)
        assert first == second
        assert isinstance(first, str)
        assert len(first) == 64

    def test_unknown_type_returns_none(self):
        """Non-cacheable types yield None instead of raising (fail-closed)."""
        class CustomObj:
            pass

        assert _serialize_cache_key(CustomObj()) is None
class TestContainsSelfUnequal:
    """Exercise _contains_self_unequal (NaN-like value detection)."""

    def test_nan_float_detected(self):
        """A bare NaN is not equal to itself and must be flagged."""
        assert _contains_self_unequal(float('nan')) is True

    def test_regular_float_not_detected(self):
        """Ordinary floats compare equal to themselves."""
        for benign in (3.14, 0.0, -1.5):
            assert _contains_self_unequal(benign) is False

    def test_infinity_not_detected(self):
        """Positive and negative infinity each equal themselves."""
        assert _contains_self_unequal(float('inf')) is False
        assert _contains_self_unequal(float('-inf')) is False

    def test_nan_in_list(self):
        """A NaN hiding inside a list is found."""
        assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
        assert _contains_self_unequal([1, 2, 3, 4]) is False

    def test_nan_in_tuple(self):
        """A NaN hiding inside a tuple is found."""
        assert _contains_self_unequal((1, float('nan'))) is True
        assert _contains_self_unequal((1, 2, 3)) is False

    def test_nan_in_frozenset(self):
        """A NaN hiding inside a frozenset is found."""
        assert _contains_self_unequal(frozenset([1, float('nan')])) is True
        assert _contains_self_unequal(frozenset([1, 2, 3])) is False

    def test_nan_in_dict_value(self):
        """A NaN stored as a dict value is found."""
        assert _contains_self_unequal({"key": float('nan')}) is True
        assert _contains_self_unequal({"key": 42}) is False

    def test_nan_in_nested_structure(self):
        """Detection recurses through arbitrarily nested containers."""
        deep = {"level1": [{"level2": (1, 2, float('nan'))}]}
        assert _contains_self_unequal(deep) is True

    def test_non_numeric_types(self):
        """Strings, None, and bools are all self-equal."""
        for benign in ("string", None, True):
            assert _contains_self_unequal(benign) is False

    def test_object_with_nan_value_attribute(self):
        """Objects wrapping NaN in a .value attribute are flagged."""
        class NanWrapper:
            def __init__(self):
                self.value = float('nan')

        assert _contains_self_unequal(NanWrapper()) is True

    def test_custom_self_unequal_object(self):
        """Any object where not (x == x) is flagged."""
        class NeverEqual:
            def __eq__(self, other):
                return False

        assert _contains_self_unequal(NeverEqual()) is True
class TestEstimateValueSize:
    """Test _estimate_value_size utility function.

    Fix: the two tensor tests did a bare ``import torch`` even though this
    module defines ``_torch_available()`` for exactly this purpose — on a
    machine without torch they would ERROR at import instead of being
    skipped. They are now guarded with ``pytest.mark.skipif``.
    """

    def test_empty_outputs(self):
        """Empty outputs should have zero size."""
        value = CacheValue(outputs=[])
        assert _estimate_value_size(value) == 0

    @pytest.mark.skipif(not _torch_available(), reason="requires torch")
    def test_tensor_size_estimation(self):
        """Tensor size should be estimated correctly."""
        import torch

        # 1000 float32 elements = 4000 bytes
        tensor = torch.zeros(1000, dtype=torch.float32)
        value = CacheValue(outputs=[[tensor]])
        size = _estimate_value_size(value)
        assert size == 4000

    @pytest.mark.skipif(not _torch_available(), reason="requires torch")
    def test_nested_tensor_in_dict(self):
        """Tensors nested in dicts should be counted."""
        import torch

        tensor = torch.zeros(100, dtype=torch.float32)  # 400 bytes
        value = CacheValue(outputs=[[{"samples": tensor}]])
        size = _estimate_value_size(value)
        assert size == 400
class TestProviderRegistry:
    """Exercise cache provider registration and retrieval."""

    def setup_method(self):
        """Start every test with an empty registry."""
        _clear_cache_providers()

    def teardown_method(self):
        """Leave the registry empty for the next test."""
        _clear_cache_providers()

    def test_register_provider(self):
        """A registered provider is discoverable via the registry helpers."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        assert _has_cache_providers() is True
        registered = _get_cache_providers()
        assert len(registered) == 1
        assert registered[0] is mock

    def test_unregister_provider(self):
        """Unregistering removes the provider again."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        unregister_cache_provider(mock)
        assert _has_cache_providers() is False

    def test_multiple_providers(self):
        """The registry accepts any number of distinct providers."""
        first, second = MockCacheProvider(), MockCacheProvider()
        register_cache_provider(first)
        register_cache_provider(second)
        assert len(_get_cache_providers()) == 2

    def test_duplicate_registration_ignored(self):
        """Registering the same instance twice is a no-op."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        register_cache_provider(mock)  # second call must be ignored
        assert len(_get_cache_providers()) == 1

    def test_clear_providers(self):
        """_clear_cache_providers empties the registry entirely."""
        for provider in (MockCacheProvider(), MockCacheProvider()):
            register_cache_provider(provider)
        _clear_cache_providers()
        assert _has_cache_providers() is False
        assert len(_get_cache_providers()) == 0
class TestCacheContext:
    """Exercise the CacheContext dataclass."""

    def test_context_creation(self):
        """All constructor fields are stored verbatim."""
        fake_hash = "a" * 64
        ctx = CacheContext(
            node_id="node-456",
            class_type="KSampler",
            cache_key_hash=fake_hash,
        )
        assert ctx.node_id == "node-456"
        assert ctx.class_type == "KSampler"
        assert ctx.cache_key_hash == fake_hash
class TestCacheValue:
    """Exercise the CacheValue dataclass."""

    def test_value_creation(self):
        """The outputs payload is stored as given."""
        payload = [[{"samples": "tensor_data"}]]
        assert CacheValue(outputs=payload).outputs == payload
class MockCacheProvider(CacheProvider):
    """Recording stub provider: logs every call and always reports a miss."""

    def __init__(self):
        # Call histories inspected by the registry tests.
        self.lookups = []
        self.stores = []

    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Record the lookup context and signal a cache miss."""
        self.lookups.append(context)
        return None

    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Record the (context, value) pair handed to the provider."""
        self.stores.append((context, value))