File size: 3,388 Bytes
df1664c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""

Patch ComfyUI's Anima attention to use PyTorch's MATH SDPA backend on H200/ZeroGPU.



Env-only TORCH_CUDNN_SDPA_ENABLED=0 is not always enough; cuDNN can still be chosen

and fail with [cudnn_frontend] No valid execution plans built.



Idempotent: skips if marker ANIMA_FORCE_MATH_SDPA is already present.

"""
from __future__ import annotations

import os
import re
import sys

# Sentinel string written into the patched file; the patch function treats its
# presence in the target as "already patched".
MARKER = "ANIMA_FORCE_MATH_SDPA"

# Source snippet spliced into the target right after its
# `import torch.nn.functional as F` line. Every line is separated from the
# next by one blank line, and the snippet both starts and ends with a blank
# line. The import is guarded so the patched module still loads when
# `torch.nn.attention` is unavailable (both names then become None).
HEADER = "\n\n".join(
    (
        "",
        f"# {MARKER} — injected by anima-gradio-zerogpu-space (force MATH SDPA; H200 / cuDNN frontend)",
        "try:",
        "    from torch.nn.attention import SDPBackend, sdpa_kernel",
        "except ImportError:",
        "    sdpa_kernel = None  # type: ignore[assignment]",
        "    SDPBackend = None  # type: ignore[assignment]",
        "",
    )
)


def apply_comfy_anima_sdp_math_patch(comfy_root: str) -> bool:
    """Rewrite ComfyUI's Anima model so its SDPA call runs under the MATH backend.

    The target is ``<comfy_root>/comfy/ldm/anima/model.py``. Two edits are
    prepared in memory and written back only if both succeed:

    1. ``HEADER`` (a guarded import of ``sdpa_kernel`` / ``SDPBackend``) is
       inserted right after the first ``import torch.nn.functional as F``.
    2. The ``attn_output = F.scaled_dot_product_attention(...)`` line is
       replaced by the same call wrapped in ``with sdpa_kernel(SDPBackend.MATH)``,
       with a plain-call fallback for when the guarded import yielded ``None``.

    Progress and failures are reported with ``print`` (failures on stderr).

    Args:
        comfy_root: Directory that contains the ``comfy`` package.

    Returns:
        True if the file was patched now or already contains ``MARKER``.
        False if the file is missing or unreadable, the anchor import is
        absent, the SDPA line does not occur exactly once, or the write
        failed. On any False result the file on disk is left untouched.
    """
    rel = os.path.join("comfy", "ldm", "anima", "model.py")
    path = os.path.join(comfy_root, rel)
    if not os.path.isfile(path):
        print(f"[patch] skip: {path} not found", flush=True)
        return False
    try:
        with open(path, encoding="utf-8") as f:
            text = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError; without it a
        # non-UTF-8 file would raise instead of honouring the bool contract.
        print(f"[patch] read failed {path}: {e}", file=sys.stderr, flush=True)
        return False
    if MARKER in text:
        print(f"[patch] already applied: {rel}", flush=True)
        return True

    # Insert imports after `import torch.nn.functional as F`
    anchor = "import torch.nn.functional as F"
    if anchor not in text:
        print(f"[patch] anchor not found in {rel}", file=sys.stderr, flush=True)
        return False
    if "from torch.nn.attention import" in text:
        # Informational only: the guarded import is still injected below.
        print(f"[patch] unexpected: attention imports already in {rel}", flush=True)

    text = text.replace(anchor, anchor + HEADER, 1)

    # Replace SDPA call in Attention.forward (single occurrence in upstream).
    # `[ \t]*` rather than `\s*`: `\s` also matches newlines, which would let
    # the captured indent swallow preceding blank lines and the trailing part
    # eat following ones.
    pattern = re.compile(
        r"^([ \t]*)attn_output = F\.scaled_dot_product_attention\("
        r"query_states, key_states, value_states, attn_mask=mask\)[ \t]*$",
        re.MULTILINE,
    )

    def _repl(m: re.Match[str]) -> str:
        # Re-emit the call at the original indent, once inside the MATH
        # kernel context and once as a fallback when the import failed.
        ind = m.group(1)
        return (
            f"{ind}# {MARKER} (wrap)\n"
            f"{ind}if sdpa_kernel is not None and SDPBackend is not None:\n"
            f"{ind}    with sdpa_kernel(SDPBackend.MATH):\n"
            f"{ind}        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)\n"
            f"{ind}else:\n"
            f"{ind}    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)"
        )

    # No `count=` limit: every match is counted so that an unexpected second
    # SDPA call is detected and rejected instead of being silently left
    # unpatched. Nothing is written unless exactly one line matched.
    new_text, n = pattern.subn(_repl, text)
    if n != 1:
        print(
            f"[patch] SDPA line not found or multiple matches in {rel} (n={n})",
            file=sys.stderr,
            flush=True,
        )
        return False
    try:
        with open(path, "w", encoding="utf-8", newline="\n") as f:
            f.write(new_text)
    except OSError as e:
        print(f"[patch] write failed {path}: {e}", file=sys.stderr, flush=True)
        return False
    print(f"[patch] applied MATH SDPA wrap: {rel}", flush=True)
    return True