Spaces:
Paused
Paused
| from typing import Callable, Optional | |
| from .tripo2_transformer import Tripo2DiTModel | |
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Return the existing attention processor unchanged.

    This is the identity callback used as the default for both the
    self-attention and cross-attention hooks of
    ``set_transformer_attn_processor``.

    Args:
        name: Name of the attention processor (unused here).
        hidden_size: Hidden size passed by the caller (unused here).
        cross_attention_dim: Cross-attention dimension passed by the
            caller, or ``None`` (unused here).
        ori_attn_proc: The processor currently installed under ``name``.

    Returns:
        ``ori_attn_proc`` itself, the very same object.
    """
    # Only the last argument matters; the others exist so that custom
    # callbacks with the same signature can make use of them.
    return ori_attn_proc
def set_transformer_attn_processor(
    transformer: Tripo2DiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
) -> None:
    """Rebuild the transformer's attention processors through callbacks.

    Every entry of ``transformer.attn_processors`` is routed by the suffix
    of its name:

    * ``attn1.processor``   -> ``set_self_attn_proc_func`` with
      ``cross_attention_dim=None``
    * ``attn2.processor``   -> ``set_cross_attn_proc_func`` with
      ``transformer.config.cross_attention_dim``
    * ``attn2_2.processor`` -> ``set_cross_attn_proc_func`` with
      ``transformer.config.cross_attention_2_dim``

    Each callback is called as
    ``func(name, hidden_size, cross_attention_dim, current_processor)``,
    where ``hidden_size`` is ``transformer.config.width``, and must return
    the processor to install under ``name``. The resulting mapping is
    handed to ``transformer.set_attn_processor``.

    Args:
        transformer: Model exposing ``attn_processors``, ``config`` and
            ``set_attn_processor``.
        set_self_attn_proc_func: Callback for self-attention processors.
        set_cross_attn_proc_func: Callback for both kinds of
            cross-attention processors.

    Returns:
        None. The transformer is updated in place.
    """
    config = transformer.config
    # Loop-invariant, so it is read once instead of once per processor.
    hidden_size = config.width

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        if name.endswith("attn1.processor"):
            # self attention
            attn_procs[name] = set_self_attn_proc_func(
                name, hidden_size, None, attn_processor
            )
        elif name.endswith("attn2.processor"):
            # cross attention
            # The config attribute is read only when such a layer exists.
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, config.cross_attention_dim, attn_processor
            )
        elif name.endswith("attn2_2.processor"):
            # cross attention 2
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, config.cross_attention_2_dim, attn_processor
            )
        else:
            # Keep processors of unrecognised layers as they are, so the
            # mapping stays complete instead of silently dropping entries.
            # NOTE(review): presumably ``set_attn_processor`` follows the
            # diffusers convention of rejecting a dict whose size differs
            # from the number of attention layers; confirm against
            # Tripo2DiTModel.
            attn_procs[name] = attn_processor

    transformer.set_attn_processor(attn_procs)