"""Unit tests for the Comfy Anima MATH SDPA file patch (no Comfy / GPU needed)."""
import os
import tempfile
from src.patch_comfy_anima import MARKER, apply_comfy_anima_sdp_math_patch
# Fixture model.py whose F.scaled_dot_product_attention call is split across
# several lines. The patch pattern only recognises the one-line upstream form
# (see _layout_minimal), so test_apply_patch_fails_on_multiline_sdpa_call
# expects apply_comfy_anima_sdp_math_patch to reject this text.
MIN_MODEL_PY = '''\
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(torch.nn.Module):
def forward(self, query_states, key_states, value_states, mask):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=mask
)
return attn_output
'''
def _layout_minimal():
"""Upstream uses one-line call; our patch pattern expects that shape."""
return """\
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(torch.nn.Module):
def forward(self, query_states, key_states, value_states, mask):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
return attn_output
"""
def test_apply_patch_injects_markers_and_sdpa_wrap():
    """Patching a one-line SDPA call injects the marker and a MATH sdpa_kernel.

    Also checks that a second application is a true no-op. It must still
    report success AND leave the already-patched file byte-identical, so the
    call is never wrapped twice.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # The patcher receives the root dir; model.py lives at comfy/ldm/anima/.
        sub = os.path.join(tmp, "comfy", "ldm", "anima")
        os.makedirs(sub, exist_ok=True)
        path = os.path.join(sub, "model.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(_layout_minimal())

        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            out = f.read()
        assert MARKER in out
        assert "sdpa_kernel" in out
        assert "SDPBackend.MATH" in out
        # The original SDPA call must still be present after patching.
        assert "scaled_dot_product_attention" in out

        # Idempotent: re-applying succeeds and does not modify the file again.
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            assert f.read() == out
def test_apply_patch_fails_on_multiline_sdpa_call():
    """A multi-line SDPA call does not match the patch pattern.

    The patcher must report failure and leave model.py exactly as written.
    A rejected patch should not leave a partially edited file behind (for
    example, injected imports or the marker without the wrapped call).
    """
    with tempfile.TemporaryDirectory() as tmp:
        sub = os.path.join(tmp, "comfy", "ldm", "anima")
        os.makedirs(sub, exist_ok=True)
        path = os.path.join(sub, "model.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(MIN_MODEL_PY)

        assert apply_comfy_anima_sdp_math_patch(tmp) is False
        # Failure must be side-effect free: the file content is unchanged.
        with open(path, encoding="utf-8") as f:
            assert f.read() == MIN_MODEL_PY
|