anima-gradio-zerogpu-space / tests /test_patch_comfy_anima.py
JSCPPProgrammer's picture
Bootstrap Anima MATH SDPA patch; wire patch in comfy_backend; add tests
df1664c verified
"""Unit tests for Comfy Anima MATH SDPA file patch (no Comfy / GPU)."""
import os
import tempfile
from src.patch_comfy_anima import MARKER, apply_comfy_anima_sdp_math_patch
# Fixture model.py whose SDPA call is spread over several lines. The patch
# pattern expects the one-line shape, so applying the patch to this content
# is expected to report failure. Lines deliberately carry no indentation; in
# this file the text is only written to disk and handed to the patcher.
MIN_MODEL_PY = "".join(
    line + "\n"
    for line in (
        "import torch",
        "import torch.nn as nn",
        "import torch.nn.functional as F",
        "class Attention(torch.nn.Module):",
        "def forward(self, query_states, key_states, value_states, mask):",
        "attn_output = F.scaled_dot_product_attention(",
        "query_states, key_states, value_states, attn_mask=mask",
        ")",
        "return attn_output",
    )
)
def _layout_minimal():
"""Upstream uses one-line call; our patch pattern expects that shape."""
return """\
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(torch.nn.Module):
def forward(self, query_states, key_states, value_states, mask):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
return attn_output
"""
def test_apply_patch_injects_markers_and_sdpa_wrap():
    """Patching a one-line SDPA call injects the marker and MATH sdpa_kernel.

    Also checks idempotence: a second application must still report success
    and must leave the already-patched file byte-for-byte unchanged.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # The patcher receives the root directory and is expected to patch
        # comfy/ldm/anima/model.py beneath it.
        sub = os.path.join(tmp, "comfy", "ldm", "anima")
        os.makedirs(sub, exist_ok=True)
        path = os.path.join(sub, "model.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(_layout_minimal())

        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            out = f.read()
        assert MARKER in out
        assert "sdpa_kernel" in out
        assert "SDPBackend.MATH" in out
        # The original SDPA call must still be present after patching.
        assert "scaled_dot_product_attention" in out

        # Idempotent: re-applying reports success AND does not touch the file
        # (no second wrap, no duplicated marker). Checking only the return
        # value would let a non-idempotent rewrite slip through.
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            assert f.read() == out
def test_apply_patch_fails_on_multiline_sdpa_call():
    """A multi-line SDPA call does not match the one-line pattern, so the patch reports failure."""
    with tempfile.TemporaryDirectory() as root:
        model_dir = os.path.join(root, "comfy", "ldm", "anima")
        os.makedirs(model_dir, exist_ok=True)
        with open(os.path.join(model_dir, "model.py"), "w", encoding="utf-8") as fh:
            fh.write(MIN_MODEL_PY)
        patched = apply_comfy_anima_sdp_math_patch(root)
        assert patched is False