| """ | |
| Patch ComfyUI's Anima attention to use PyTorch's MATH SDPA backend on H200/ZeroGPU. | |
| Env-only TORCH_CUDNN_SDPA_ENABLED=0 is not always enough; cuDNN can still be chosen | |
| and fail with [cudnn_frontend] No valid execution plans built. | |
| Idempotent: skips if marker ANIMA_FORCE_MATH_SDPA is already present. | |
| """ | |

from __future__ import annotations

import os
import re
import sys

MARKER = "ANIMA_FORCE_MATH_SDPA"

HEADER = """
# {m} — injected by anima-gradio-zerogpu-space (force MATH SDPA; H200 / cuDNN frontend)
try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:
    sdpa_kernel = None  # type: ignore[assignment]
    SDPBackend = None  # type: ignore[assignment]
""".format(
    m=MARKER,
)
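
# The try/except in HEADER keeps a patched model.py importable on PyTorch builds
# that lack torch.nn.attention (sdpa_kernel is a newer API); the injected wrapper
# then falls through to the plain SDPA call.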


def apply_comfy_anima_sdp_math_patch(comfy_root: str) -> bool:
    """
    Patch comfy/ldm/anima/model.py. Returns True if the patch was applied or was
    already present. Returns False if the file is missing or the patch could not
    be applied.
    """
    rel = os.path.join("comfy", "ldm", "anima", "model.py")
    path = os.path.join(comfy_root, rel)
    if not os.path.isfile(path):
        print(f"[patch] skip: {path} not found", flush=True)
        return False

    try:
        with open(path, encoding="utf-8") as f:
            text = f.read()
    except OSError as e:
        print(f"[patch] read failed {path}: {e}", file=sys.stderr, flush=True)
        return False

    if MARKER in text:
        print(f"[patch] already applied: {rel}", flush=True)
        return True

    # Insert imports after `import torch.nn.functional as F`.
    anchor = "import torch.nn.functional as F"
    if anchor not in text:
        print(f"[patch] anchor not found in {rel}", file=sys.stderr, flush=True)
        return False
    if "from torch.nn.attention import" in text:
        print(f"[patch] unexpected: attention imports already in {rel}", flush=True)
    text = text.replace(anchor, anchor + HEADER, 1)
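    # Only the first occurrence of the anchor is rewritten, and HEADER begins
    # with a newline, so the injected block lands on its own lines immediately
    # after the import.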

    # Replace the SDPA call in Attention.forward (single occurrence upstream).
    pattern = re.compile(
        r"^(\s*)attn_output = F\.scaled_dot_product_attention\("
        r"query_states, key_states, value_states, attn_mask=mask\)\s*$",
        re.MULTILINE,
    )
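    # ^(\s*) captures the call's indentation for reuse in _repl below;
    # re.MULTILINE makes ^ and $ match per line, pinning exactly one source line.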

    def _repl(m: re.Match[str]) -> str:
        ind = m.group(1)
        return (
            f"{ind}# {MARKER} (wrap)\n"
            f"{ind}if sdpa_kernel is not None and SDPBackend is not None:\n"
            f"{ind}    with sdpa_kernel(SDPBackend.MATH):\n"
            f"{ind}        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)\n"
            f"{ind}else:\n"
            f"{ind}    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)"
        )
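    # The SDPA call is duplicated in the else branch so behavior is unchanged
    # when torch.nn.attention is unavailable: same invocation, just unwrapped.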

    # subn without a count so n reflects the true match count (count=1 would cap
    # n at 1 and make the "multiple matches" check below unreachable).
    new_text, n = pattern.subn(_repl, text)
    if n != 1:
        print(
            f"[patch] SDPA line not found or multiple matches in {rel} (n={n})",
            file=sys.stderr,
            flush=True,
        )
        return False

    try:
        with open(path, "w", encoding="utf-8", newline="\n") as f:
            f.write(new_text)
    except OSError as e:
        print(f"[patch] write failed {path}: {e}", file=sys.stderr, flush=True)
        return False

    print(f"[patch] applied MATH SDPA wrap: {rel}", flush=True)
    return True
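

# --- Usage sketch (not part of the original module; names below are assumptions) ---
# Run this before ComfyUI imports comfy.ldm.anima.model, since the patch rewrites
# the source on disk. The argv/COMFY_ROOT lookup is illustrative; adapt it to
# however the Space locates its ComfyUI checkout.
if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("COMFY_ROOT", "ComfyUI")
    sys.exit(0 if apply_comfy_anima_sdp_math_patch(root) else 1)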