"""Unit tests for Comfy Anima MATH SDPA file patch (no Comfy / GPU)."""

import os
import tempfile

from src.patch_comfy_anima import MARKER, apply_comfy_anima_sdp_math_patch

# Model source whose SDPA call is split across several lines. The patch
# pattern only matches the one-line upstream shape, so this layout must be
# rejected. The name is kept for backward compatibility; it is the
# "multi-line" fixture, not the minimal accepted one (see _layout_minimal).
MIN_MODEL_PY = '''\
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=mask
        )
        return attn_output
'''


def _layout_minimal():
    """Upstream uses one-line call; our patch pattern expects that shape."""
    return """\
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
        return attn_output
"""


def _write_anima_model(root, text):
    """Create ``<root>/comfy/ldm/anima/model.py`` containing *text*.

    Returns the path of the written file. This mirrors the Comfy checkout
    layout that the patch function is pointed at via *root*.
    """
    model_dir = os.path.join(root, "comfy", "ldm", "anima")
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "model.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return path


def _read_text(path):
    """Return the UTF-8 contents of *path*."""
    with open(path, encoding="utf-8") as f:
        return f.read()


def test_apply_patch_injects_markers_and_sdpa_wrap():
    with tempfile.TemporaryDirectory() as tmp:
        path = _write_anima_model(tmp, _layout_minimal())

        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        out = _read_text(path)
        assert MARKER in out
        assert "sdpa_kernel" in out
        assert "SDPBackend.MATH" in out
        assert "scaled_dot_product_attention" in out

        # Idempotent: a second run still reports success AND leaves the
        # already-patched file untouched (no double-wrapping / re-injection).
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        assert _read_text(path) == out


def test_apply_patch_fails_on_multiline_sdpa_call():
    with tempfile.TemporaryDirectory() as tmp:
        path = _write_anima_model(tmp, MIN_MODEL_PY)

        assert apply_comfy_anima_sdp_math_patch(tmp) is False
        # A rejected file must not be partially patched.
        assert _read_text(path) == MIN_MODEL_PY