from typing import Optional, List

import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm


def setup_lynx_attention_layers(blocks, lynx_full, dim):
    """Attach Lynx adapter modules to the model's transformer blocks.

    When ``lynx_full`` is True, every block gets IP key/value projections from
    a 5120-dim condition, 16 learnable register tokens, and reference key/value
    projections on self-attention. Otherwise, only the first 20 blocks get IP
    projections (from a 2048-dim condition), and each block gets an RMSNorm for
    the projected adapter keys instead of register tokens.
    """
    if lynx_full:
        lynx_cross_dim = 5120
        lynx_layers = len(blocks)
    else:
        lynx_cross_dim = 2048
        lynx_layers = 20

    for i, block in enumerate(blocks):
        if i < lynx_layers:
            # Extra key/value projections for the IP (adapter) condition.
            block.cross_attn.to_k_ip = nn.Linear(lynx_cross_dim, dim, bias=lynx_full)
            block.cross_attn.to_v_ip = nn.Linear(lynx_cross_dim, dim, bias=lynx_full)
        else:
            block.cross_attn.to_k_ip = None
            block.cross_attn.to_v_ip = None

        if lynx_full:
            # Learnable register tokens for the cross-attention condition.
            block.cross_attn.registers = nn.Parameter(
                torch.randn(1, 16, lynx_cross_dim) / dim**0.5
            )
            block.cross_attn.norm_rms_k = None
            # Reference key/value projections on self-attention.
            block.self_attn.to_k_ref = nn.Linear(dim, dim, bias=True)
            block.self_attn.to_v_ref = nn.Linear(dim, dim, bias=True)
        else:
            block.cross_attn.registers = None
            # Normalize the projected adapter keys instead of using registers.
            block.cross_attn.norm_rms_k = RMSNorm(dim, eps=1e-5, elementwise_affine=False)
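

# Minimal usage sketch (illustrative, not from the original module): the dummy
# classes below only mimic the attribute layout setup_lynx_attention_layers
# expects (a `cross_attn` and a `self_attn` module per block); the block
# structure, hidden size, and block count are assumptions for demonstration.
if __name__ == "__main__":
    class _DummyAttn(nn.Module):
        pass

    class _DummyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.cross_attn = _DummyAttn()
            self.self_attn = _DummyAttn()

    dim = 3072
    blocks = [_DummyBlock() for _ in range(4)]
    setup_lynx_attention_layers(blocks, lynx_full=True, dim=dim)

    # With lynx_full=True every block now carries 5120-dim IP projections,
    # register tokens, and reference projections on self-attention.
    assert blocks[0].cross_attn.to_k_ip.in_features == 5120
    assert blocks[-1].self_attn.to_v_ref.out_features == dim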