Spaces:
Running on Zero
Running on Zero
| """Unit tests for Comfy Anima MATH SDPA file patch (no Comfy / GPU).""" | |
| import os | |
| import tempfile | |
| from src.patch_comfy_anima import MARKER, apply_comfy_anima_sdp_math_patch | |
# model.py fixture whose F.scaled_dot_product_attention call is split across
# several lines. The patch pattern expects the one-line call shape produced by
# ``_layout_minimal``, so ``test_apply_patch_fails_on_multiline_sdpa_call``
# expects the patch to refuse this layout and return False.
# NOTE(review): despite the "MIN" name, this is NOT the layout the patcher
# accepts; consider renaming (e.g. MULTILINE_SDPA_MODEL_PY) in a follow-up.
MIN_MODEL_PY = '''\
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=mask
        )
        return attn_output
'''
| def _layout_minimal(): | |
| """Upstream uses one-line call; our patch pattern expects that shape.""" | |
| return """\ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class Attention(torch.nn.Module): | |
| def forward(self, query_states, key_states, value_states, mask): | |
| attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) | |
| return attn_output | |
| """ | |
def test_apply_patch_injects_markers_and_sdpa_wrap():
    """Patching a one-line SDPA call injects the marker and a MATH sdpa_kernel.

    Also checks idempotence: a second application must report success AND
    leave the already-patched file byte-for-byte unchanged.
    """
    with tempfile.TemporaryDirectory() as tmp:
        sub = os.path.join(tmp, "comfy", "ldm", "anima")
        os.makedirs(sub, exist_ok=True)
        path = os.path.join(sub, "model.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(_layout_minimal())
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            out = f.read()
        assert MARKER in out
        assert "sdpa_kernel" in out
        assert "SDPBackend.MATH" in out
        assert "scaled_dot_product_attention" in out
        # Idempotent: checking only the return value would miss a patcher that
        # wraps the call (or re-injects its imports/marker) a second time, so
        # also require the file content to be identical after re-applying.
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            assert f.read() == out
def test_apply_patch_fails_on_multiline_sdpa_call():
    """A multi-line SDPA call is not the expected shape, so patching fails."""
    with tempfile.TemporaryDirectory() as root:
        model_path = os.path.join(root, "comfy", "ldm", "anima", "model.py")
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        with open(model_path, "w", encoding="utf-8") as handle:
            handle.write(MIN_MODEL_PY)
        patched = apply_comfy_anima_sdp_math_patch(root)
        assert patched is False