File size: 1,455 Bytes
9d7cf7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Callable, Optional

from .tripo2_transformer import Tripo2DiTModel


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Default processor factory: hand back the existing processor untouched.

    Matches the callback signature expected by
    ``set_transformer_attn_processor`` so it can be used as a no-op default.

    Args:
        name: Fully qualified processor name (unused).
        hidden_size: Hidden width of the attention layer (unused).
        cross_attention_dim: Cross-attention context width, or ``None`` for
            self attention (unused).
        ori_attn_proc: The processor currently installed on the layer.

    Returns:
        ``ori_attn_proc`` itself (same object, not a copy).
    """
    # The first three arguments exist only to satisfy the callback protocol.
    del name, hidden_size, cross_attention_dim
    return ori_attn_proc


def set_transformer_attn_processor(
    transformer: Tripo2DiTModel,
    set_self_attn_proc_func: Callable[
        [str, int, Optional[int], object], object
    ] = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable[
        [str, int, Optional[int], object], object
    ] = default_set_attn_proc_func,
) -> None:
    """Rebuild and install the attention processors of ``transformer``.

    Each entry of ``transformer.attn_processors`` is routed, by name suffix,
    to a callback that returns the processor to install. The resulting
    mapping is applied with ``transformer.set_attn_processor``.

    Callbacks are invoked as
    ``func(name, hidden_size, cross_attention_dim, original_processor)``,
    where ``hidden_size`` is ``transformer.config.width`` and:

    * ``...attn1.processor`` (self attention) -> ``set_self_attn_proc_func``
      with ``cross_attention_dim=None``;
    * ``...attn2.processor`` (cross attention) -> ``set_cross_attn_proc_func``
      with ``transformer.config.cross_attention_dim``;
    * ``...attn2_2.processor`` (second cross attention) ->
      ``set_cross_attn_proc_func`` with
      ``transformer.config.cross_attention_2_dim``.

    Processors whose names match none of these suffixes are kept unchanged.

    Args:
        transformer: Model exposing ``attn_processors``, ``config`` and
            ``set_attn_processor``.
        set_self_attn_proc_func: Factory for self-attention processors.
            Defaults to returning the existing processor.
        set_cross_attn_proc_func: Factory for cross-attention processors.
            Defaults to returning the existing processor.
    """
    # Same width for every layer; read it once instead of per iteration.
    hidden_size = transformer.config.width

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        if name.endswith("attn1.processor"):
            # self attention: no external context
            attn_procs[name] = set_self_attn_proc_func(
                name, hidden_size, None, attn_processor
            )
        elif name.endswith("attn2.processor"):
            # cross attention
            attn_procs[name] = set_cross_attn_proc_func(
                name,
                hidden_size,
                transformer.config.cross_attention_dim,
                attn_processor,
            )
        elif name.endswith("attn2_2.processor"):
            # cross attention 2 (config attribute read lazily, only when
            # such a layer actually exists)
            attn_procs[name] = set_cross_attn_proc_func(
                name,
                hidden_size,
                transformer.config.cross_attention_2_dim,
                attn_processor,
            )
        else:
            # Keep unrecognized processors so the mapping covers every layer.
            # NOTE(review): diffusers-style set_attn_processor raises when the
            # dict size differs from the number of attention layers — confirm
            # this applies to Tripo2DiTModel.
            attn_procs[name] = attn_processor

    transformer.set_attn_processor(attn_procs)