# anima-gradio-zerogpu-space/src/patch_comfy_anima.py
"""
Patch ComfyUI's Anima attention to use PyTorch's MATH SDPA backend on H200/ZeroGPU.
Env-only TORCH_CUDNN_SDPA_ENABLED=0 is not always enough; cuDNN can still be chosen
and fail with [cudnn_frontend] No valid execution plans built.
Idempotent: skips if marker ANIMA_FORCE_MATH_SDPA is already present.
"""
from __future__ import annotations
import os
import re
import sys
MARKER = "ANIMA_FORCE_MATH_SDPA"
HEADER = """
# {m} — injected by anima-gradio-zerogpu-space (force MATH SDPA; H200 / cuDNN frontend)
try:
from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:
sdpa_kernel = None # type: ignore[assignment]
SDPBackend = None # type: ignore[assignment]
""".format(
m=MARKER,
)
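# Note: torch.nn.attention (SDPBackend / sdpa_kernel) only exists on newer
# PyTorch releases (roughly 2.3+). On older builds the ImportError branch in the
# injected shim leaves both names as None, and the wrapper added below falls
# back to the plain, unwrapped SDPA call.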
def apply_comfy_anima_sdp_math_patch(comfy_root: str) -> bool:
"""
Patch comfy/ldm/anima/model.py. Returns True if applied or already patched.
Returns False if file missing or patch could not be applied.
"""
rel = os.path.join("comfy", "ldm", "anima", "model.py")
path = os.path.join(comfy_root, rel)
if not os.path.isfile(path):
print(f"[patch] skip: {path} not found", flush=True)
return False
try:
with open(path, encoding="utf-8") as f:
text = f.read()
except OSError as e:
print(f"[patch] read failed {path}: {e}", file=sys.stderr, flush=True)
return False
if MARKER in text:
print(f"[patch] already applied: {rel}", flush=True)
return True
# Insert imports after `import torch.nn.functional as F`
anchor = "import torch.nn.functional as F"
if anchor not in text:
print(f"[patch] anchor not found in {rel}", file=sys.stderr, flush=True)
return False
if "from torch.nn.attention import" in text:
print(f"[patch] unexpected: attention imports already in {rel}", flush=True)
text = text.replace(anchor, anchor + HEADER, 1)
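    # At this point the target module begins, schematically:
    #     import torch.nn.functional as F
    #     # ANIMA_FORCE_MATH_SDPA ... (injected header comment)
    #     try:
    #         from torch.nn.attention import SDPBackend, sdpa_kernel
    #     except ImportError:
    #         ...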
    # Replace the SDPA call in Attention.forward (a single occurrence upstream).
    pattern = re.compile(
        r"^(\s*)attn_output = F\.scaled_dot_product_attention\("
        r"query_states, key_states, value_states, attn_mask=mask\)\s*$",
        re.MULTILINE,
    )
    def _repl(m: re.Match[str]) -> str:
        ind = m.group(1)
        return (
            f"{ind}# {MARKER} (wrap)\n"
            f"{ind}if sdpa_kernel is not None and SDPBackend is not None:\n"
            f"{ind}    with sdpa_kernel(SDPBackend.MATH):\n"
            f"{ind}        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)\n"
            f"{ind}else:\n"
            f"{ind}    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)"
        )
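    # For illustration, the substitution turns the upstream line
    #     attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
    # into (original leading whitespace preserved via group(1)):
    #     # ANIMA_FORCE_MATH_SDPA (wrap)
    #     if sdpa_kernel is not None and SDPBackend is not None:
    #         with sdpa_kernel(SDPBackend.MATH):
    #             attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
    #     else:
    #         attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)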
    # Deliberately unbounded subn: if upstream ever gains a second matching call
    # site, n > 1 and we refuse to write rather than silently patch one of them.
    new_text, n = pattern.subn(_repl, text)
    if n != 1:
        print(
            f"[patch] SDPA line not found or multiple matches in {rel} (n={n})",
            file=sys.stderr,
            flush=True,
        )
        return False
    try:
        with open(path, "w", encoding="utf-8", newline="\n") as f:
            f.write(new_text)
    except OSError as e:
        print(f"[patch] write failed {path}: {e}", file=sys.stderr, flush=True)
        return False
    print(f"[patch] applied MATH SDPA wrap: {rel}", flush=True)
    return True
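
if __name__ == "__main__":
    # Minimal manual smoke test; a sketch only. The Space actually invokes
    # apply_comfy_anima_sdp_math_patch from comfy_backend at startup. COMFY_ROOT
    # is an assumed environment variable for local runs, not part of the Space.
    root = os.environ.get("COMFY_ROOT", ".")
    raise SystemExit(0 if apply_comfy_anima_sdp_math_patch(root) else 1)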