Spaces:
Paused
Paused
File size: 1,455 Bytes
9d7cf7f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | from typing import Callable, Optional
from .tripo2_transformer import Tripo2DiTModel
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Identity processor factory: keep the attention processor already installed.

    This is the default hook used by ``set_transformer_attn_processor``. It
    ignores the layer name, hidden size and cross-attention dimension.

    Args:
        name: Key of the processor in ``attn_processors`` (unused).
        hidden_size: Width of the attention layer (unused).
        cross_attention_dim: Context dimension, or ``None`` for self
            attention (unused).
        ori_attn_proc: The processor currently installed on the layer.

    Returns:
        ``ori_attn_proc``, unchanged (the very same object).
    """
    unchanged_processor = ori_attn_proc
    return unchanged_processor
def set_transformer_attn_processor(
    transformer: Tripo2DiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
) -> None:
    """Rebuild every attention processor of ``transformer`` and install the result.

    Each entry of ``transformer.attn_processors`` is routed by its name suffix:

    * ``...attn1.processor``   -> ``set_self_attn_proc_func`` with
      ``cross_attention_dim=None`` (self attention);
    * ``...attn2.processor``   -> ``set_cross_attn_proc_func`` with
      ``transformer.config.cross_attention_dim``;
    * ``...attn2_2.processor`` -> ``set_cross_attn_proc_func`` with
      ``transformer.config.cross_attention_2_dim``.

    Factories are called as
    ``func(name, hidden_size, cross_attention_dim, current_processor)``.
    ``hidden_size`` is ``transformer.config.width``, and the return value
    becomes the new processor for ``name``. Processors whose names match none
    of the suffixes are kept unchanged, so the mapping passed to
    ``transformer.set_attn_processor`` covers every entry of
    ``attn_processors``.

    Args:
        transformer: Model exposing ``attn_processors``, ``config`` and
            ``set_attn_processor``.
        set_self_attn_proc_func: Factory for self-attention processors.
        set_cross_attn_proc_func: Factory for both cross-attention processors.
    """
    # Same width for every layer; read it once instead of once per layer.
    hidden_size = transformer.config.width

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        if name.endswith("attn1.processor"):
            # self attention
            attn_procs[name] = set_self_attn_proc_func(
                name, hidden_size, None, attn_processor
            )
        elif name.endswith("attn2.processor"):
            # cross attention
            cross_attention_dim = transformer.config.cross_attention_dim
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, cross_attention_dim, attn_processor
            )
        elif name.endswith("attn2_2.processor"):
            # cross attention 2
            cross_attention_dim = transformer.config.cross_attention_2_dim
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, cross_attention_dim, attn_processor
            )
        else:
            # Previously unmatched processors were dropped, so the dict no
            # longer covered every layer. diffusers' set_attn_processor raises
            # ValueError on such a size mismatch (NOTE(review): confirm for
            # Tripo2DiTModel). Keep the existing processor instead.
            attn_procs[name] = attn_processor
    transformer.set_attn_processor(attn_procs)
|