"""
Patch ComfyUI's Anima attention to use PyTorch's MATH SDPA backend on H200/ZeroGPU.

Setting ``TORCH_CUDNN_SDPA_ENABLED=0`` in the environment is not always
enough: the cuDNN backend can still be selected and then fails with
``[cudnn_frontend] No valid execution plans built``. This module rewrites the
SDPA call in ``comfy/ldm/anima/model.py`` so that it runs inside
``sdpa_kernel(SDPBackend.MATH)``.

Idempotent: the file is left untouched if the marker ``ANIMA_FORCE_MATH_SDPA``
is already present.
"""

from __future__ import annotations

import contextlib
import os
import re
import shutil
import sys
import tempfile

MARKER = "ANIMA_FORCE_MATH_SDPA"

# Import block inserted right after the top-level ``import torch.nn.functional as F``.
# The ImportError fallback keeps the patched module importable on torch builds
# without ``torch.nn.attention``; the wrapped call then takes the plain-SDPA branch.
HEADER = """
# {m} — injected by anima-gradio-zerogpu-space (force MATH SDPA; H200 / cuDNN frontend)
try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:
    sdpa_kernel = None  # type: ignore[assignment]
    SDPBackend = None  # type: ignore[assignment]
""".format(
    m=MARKER,
)

# Anchor line for HEADER. Anchored at column 0 so an indented or commented-out
# occurrence cannot receive the module-level HEADER; ``\b`` rejects ``as Fx``.
_ANCHOR_RE = re.compile(r"^import torch\.nn\.functional as F\b.*$", re.MULTILINE)

# The SDPA call in Attention.forward. ``[ \t]`` (not ``\s``) so the captured
# indentation never swallows preceding blank lines and the match never eats
# the line break that follows the call.
_SDPA_CALL_RE = re.compile(
    r"^([ \t]*)attn_output = F\.scaled_dot_product_attention\("
    r"query_states, key_states, value_states, attn_mask=mask\)[ \t]*$",
    re.MULTILINE,
)

# Source text of the call, re-emitted verbatim inside both wrapper branches.
_SDPA_CALL = (
    "attn_output = F.scaled_dot_product_attention("
    "query_states, key_states, value_states, attn_mask=mask)"
)


def _wrap_sdpa_call(match: re.Match[str]) -> str:
    """Return the MATH-backend-wrapped replacement for one matched SDPA line."""
    ind = match.group(1)
    return (
        f"{ind}# {MARKER} (wrap)\n"
        f"{ind}if sdpa_kernel is not None and SDPBackend is not None:\n"
        f"{ind}    with sdpa_kernel(SDPBackend.MATH):\n"
        f"{ind}        {_SDPA_CALL}\n"
        f"{ind}else:\n"
        f"{ind}    {_SDPA_CALL}"
    )


def _write_text_atomically(path: str, text: str) -> None:
    """
    Replace the contents of *path* with *text* without a half-written window.

    The text goes to a temporary file in the same directory, which is then
    renamed over *path* with ``os.replace``. A failed write therefore never
    truncates the original. Permission bits of *path* are copied when allowed.

    Raises:
        OSError: if the temporary file cannot be created, written or renamed.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(prefix=".anima_patch_", suffix=".tmp", dir=directory)
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="\n") as f:
            f.write(text)
        # Best effort: keep the original mode instead of mkstemp's 0600.
        with contextlib.suppress(OSError):
            shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
        raise


def apply_comfy_anima_sdp_math_patch(comfy_root: str) -> bool:
    """
    Patch ``comfy/ldm/anima/model.py`` under *comfy_root* to force MATH SDPA.

    Both edits are made in memory and written back only if both succeed:
    the HEADER import block is inserted after the top-level
    ``import torch.nn.functional as F`` line, and the single
    ``attn_output = F.scaled_dot_product_attention(...)`` line is wrapped in
    ``sdpa_kernel(SDPBackend.MATH)``.

    Args:
        comfy_root: directory that contains the ``comfy/`` package.

    Returns:
        True if the patch was applied or the marker was already present.
        False if the file is missing, unreadable (I/O or UTF-8 decode error)
        or unwritable; if the anchor import is absent; or if the SDPA line
        matches zero or several times. In every False case the file on disk
        is left unchanged.
    """
    rel = os.path.join("comfy", "ldm", "anima", "model.py")
    path = os.path.join(comfy_root, rel)
    if not os.path.isfile(path):
        print(f"[patch] skip: {path} not found", flush=True)
        return False

    try:
        with open(path, encoding="utf-8") as f:
            text = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"[patch] read failed {path}: {e}", file=sys.stderr, flush=True)
        return False

    if MARKER in text:
        print(f"[patch] already applied: {rel}", flush=True)
        return True

    anchor = _ANCHOR_RE.search(text)
    if anchor is None:
        print(f"[patch] anchor not found in {rel}", file=sys.stderr, flush=True)
        return False
    if "from torch.nn.attention import" in text:
        # HEADER is still inserted (it only rebinds SDPBackend / sdpa_kernel);
        # reported so that an upstream change is noticed.
        print(f"[patch] unexpected: attention imports already in {rel}", file=sys.stderr, flush=True)
    cut = anchor.end()
    text = text[:cut] + HEADER + text[cut:]

    # No ``count=1``: every match is counted so a duplicated call is rejected
    # instead of silently patching only the first occurrence.
    new_text, n = _SDPA_CALL_RE.subn(_wrap_sdpa_call, text)
    if n != 1:
        print(
            f"[patch] SDPA line not found or multiple matches in {rel} (n={n})",
            file=sys.stderr,
            flush=True,
        )
        return False

    try:
        _write_text_atomically(path, new_text)
    except OSError as e:
        print(f"[patch] write failed {path}: {e}", file=sys.stderr, flush=True)
        return False

    print(f"[patch] applied MATH SDPA wrap: {rel}", flush=True)
    return True