```python
from typing import Optional, List

import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm


def setup_lynx_attention_layers(blocks, lynx_full, dim):
    """Attach Lynx adapter projections to each transformer block's attention modules."""
    if lynx_full:
        # Full variant: 5120-dim adapter features, adapters on every block.
        lynx_cross_dim = 5120
        lynx_layers = len(blocks)
    else:
        # Lite variant: 2048-dim adapter features, adapters on the first 20 blocks only.
        lynx_cross_dim = 2048
        lynx_layers = 20

    for i, block in enumerate(blocks):
        if i < lynx_layers:
            # Extra key/value projections for the adapter ("IP") tokens in cross-attention.
            block.cross_attn.to_k_ip = nn.Linear(lynx_cross_dim, dim, bias=lynx_full)
            block.cross_attn.to_v_ip = nn.Linear(lynx_cross_dim, dim, bias=lynx_full)
        else:
            block.cross_attn.to_k_ip = None
            block.cross_attn.to_v_ip = None

        if lynx_full:
            # Full variant: 16 learnable register tokens for the adapter context and
            # reference key/value projections added to self-attention.
            block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5)
            block.cross_attn.norm_rms_k = None
            block.self_attn.to_k_ref = nn.Linear(dim, dim, bias=True)
            block.self_attn.to_v_ref = nn.Linear(dim, dim, bias=True)
        else:
            # Lite variant: no registers; RMS-normalize the adapter keys instead.
            block.cross_attn.registers = None
            block.cross_attn.norm_rms_k = RMSNorm(dim, eps=1e-5, elementwise_affine=False)
```
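
The function only assumes that each block exposes `self_attn` and `cross_attn` modules it can attach new submodules to. The sketch below exercises it against minimal stand-in blocks; the `_Block` class, the block count of 30, and `dim = 3072` are illustrative assumptions, not the layout of any particular base transformer.

```python
# Minimal usage sketch. _Block, the block count, and dim are assumptions for
# illustration; in practice `blocks` would be the transformer blocks of the
# base DiT being adapted.
import torch.nn as nn


class _Block(nn.Module):
    """Stand-in block exposing the attention modules the setup function patches."""
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Module()
        self.cross_attn = nn.Module()


dim = 3072
blocks = nn.ModuleList(_Block() for _ in range(30))

# Lite variant: 2048-dim adapter projections on the first 20 blocks only,
# with an RMSNorm on the adapter keys and no register tokens.
setup_lynx_attention_layers(blocks, lynx_full=False, dim=dim)

assert blocks[0].cross_attn.to_k_ip.in_features == 2048
assert blocks[25].cross_attn.to_k_ip is None
assert blocks[0].cross_attn.registers is None
assert isinstance(blocks[0].cross_attn.norm_rms_k, RMSNorm)
```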