File size: 2,129 Bytes
df1664c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Unit tests for Comfy Anima MATH SDPA file patch (no Comfy / GPU)."""

import os
import tempfile

from src.patch_comfy_anima import MARKER, apply_comfy_anima_sdp_math_patch

# Fixture model.py whose ``F.scaled_dot_product_attention`` call is split
# across several lines. The patch pattern expects the one-line call shape
# (see ``_layout_minimal``), so patching this source is expected to fail.
# NOTE(review): despite the name, this is the *multi-line* variant, not the
# layout the patch accepts.
MIN_MODEL_PY = '''\

import torch

import torch.nn as nn

import torch.nn.functional as F





class Attention(torch.nn.Module):

    def forward(self, query_states, key_states, value_states, mask):

        attn_output = F.scaled_dot_product_attention(

            query_states, key_states, value_states, attn_mask=mask

        )

        return attn_output

'''


def _layout_minimal():
    """Upstream uses one-line call; our patch pattern expects that shape."""
    return """\

import torch

import torch.nn as nn

import torch.nn.functional as F





class Attention(torch.nn.Module):

    def forward(self, query_states, key_states, value_states, mask):

        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)

        return attn_output

"""


def test_apply_patch_injects_markers_and_sdpa_wrap():
    """Patching the one-line SDPA layout injects the wrap and is idempotent.

    Builds a fake Comfy tree (``comfy/ldm/anima/model.py``) inside a
    temporary directory, applies the patch, and checks that the marker plus
    the ``sdpa_kernel`` / ``SDPBackend.MATH`` text were written. A second
    application must still report success and leave the file unchanged.
    """
    with tempfile.TemporaryDirectory() as tmp:
        sub = os.path.join(tmp, "comfy", "ldm", "anima")
        os.makedirs(sub, exist_ok=True)
        path = os.path.join(sub, "model.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(_layout_minimal())

        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            patched = f.read()
        assert MARKER in patched
        assert "sdpa_kernel" in patched
        assert "SDPBackend.MATH" in patched
        assert "scaled_dot_product_attention" in patched

        # Idempotent: checking only the return value would let a patch that
        # re-wraps the call on every run slip through, so also require the
        # already-patched file to come back byte-for-byte identical.
        assert apply_comfy_anima_sdp_math_patch(tmp) is True
        with open(path, encoding="utf-8") as f:
            assert f.read() == patched


def test_apply_patch_fails_on_multiline_sdpa_call():
    """The patch reports failure when the SDPA call spans several lines.

    ``MIN_MODEL_PY`` splits ``F.scaled_dot_product_attention(...)`` across
    lines, while the patch pattern expects the one-line call shape, so
    ``apply_comfy_anima_sdp_math_patch`` must return ``False``.
    """
    with tempfile.TemporaryDirectory() as root:
        model_path = os.path.join(root, "comfy", "ldm", "anima", "model.py")
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        with open(model_path, "w", encoding="utf-8") as handle:
            handle.write(MIN_MODEL_PY)
        outcome = apply_comfy_anima_sdp_math_patch(root)
        assert outcome is False