import torch
import torch.nn as nn
from typing import Optional

from .text_encoder import ConvNeXtWrapper


class ReferenceEncoder(nn.Module):
    """Encode a variable-length reference feature sequence into a fixed set of style tokens."""

    def __init__(
        self,
        in_channels: int = 144,
        d_model: int = 256,
        hidden_dim: int = 1024,
        num_blocks: int = 6,
        num_tokens: int = 50,
        num_heads: int = 2,
        kernel_size: int = 5,
        dilation_lst: Optional[list] = None,
        prototype_dim: int = 256,
        n_units: int = 256,
        style_value_dim: int = 256,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_tokens = num_tokens
        if hidden_dim % d_model != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by d_model ({d_model})"
            )
        mlp_ratio = hidden_dim // d_model
        self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.convnext = ConvNeXtWrapper(
            d_model,
            n_layers=num_blocks,
            expansion_factor=mlp_ratio,
            kernel_size=kernel_size,
            dilation_lst=dilation_lst,
        )
        # Learned query prototypes; small init keeps early attention near-uniform.
        self.ref_keys = nn.Parameter(torch.randn(num_tokens, prototype_dim) * 0.02)
        self.q_proj = nn.Linear(prototype_dim, n_units) if prototype_dim != n_units else nn.Identity()
        self.out_proj = nn.Linear(n_units, style_value_dim) if n_units != style_value_dim else nn.Identity()
        # Two stacked cross-attention layers: queries are the learned tokens,
        # keys/values are the ConvNeXt features (hence kdim/vdim = d_model).
        self.attn1 = nn.MultiheadAttention(
            embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
        )
        self.attn2 = nn.MultiheadAttention(
            embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
        )
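
    # Design note: the two attention layers act as a Perceiver-style resampler.
    # A fixed number of learned query tokens cross-attend into the variable-length
    # reference features (with a residual connection in between), so the output
    # size is independent of the input length T.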
    def forward(self, z_ref: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """z_ref: (B, in_channels, T); mask: (B, 1, T) with 1 marking valid frames."""
        B = z_ref.shape[0]
        x = self.input_proj(z_ref)  # (B, d_model, T)
        x = self.convnext(x, mask=mask)
        kv = x.transpose(1, 2)  # (B, T, d_model) for batch_first attention
        key_padding_mask = None
        if mask is not None:
            # nn.MultiheadAttention expects True at positions that should be ignored.
            key_padding_mask = mask.squeeze(1) == 0
        q0 = self.ref_keys.unsqueeze(0).expand(B, -1, -1)  # (B, num_tokens, prototype_dim)
        q0 = self.q_proj(q0)
        q1, _ = self.attn1(query=q0, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
        q2 = q0 + q1  # residual around the first attention layer
        out, _ = self.attn2(query=q2, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
        return self.out_proj(out)  # (B, num_tokens, style_value_dim)

    @staticmethod
    def remap_legacy_state_dict(state_dict: dict) -> dict:
        """Remap pre-refactor checkpoints (per-layer pre-norm + FFN) onto the current layout."""
        remapped = {}
        legacy_prefix_map = {
            "attn_layers.0.attn.": "attn1.",
            "attn_layers.1.attn.": "attn2.",
        }
        # Parameters whose modules no longer exist (norms, FFNs, positional embedding)
        # are dropped; everything else is renamed via the prefix map above.
        drop_substrings = (".norm_q.", ".norm_kv.", ".ffn.", "pos_emb.")
        for k, v in state_dict.items():
            if any(s in k for s in drop_substrings):
                continue
            new_key = k
            for old, new in legacy_prefix_map.items():
                if new_key.startswith(old):
                    new_key = new + new_key[len(old):]
                    break
            remapped[new_key] = v
        return remapped
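

# Minimal usage sketch (illustrative only; it must run inside the package, since
# ConvNeXtWrapper is a relative import). Shapes follow the constructor defaults
# above; the checkpoint file name in the remap example is hypothetical.
if __name__ == "__main__":
    enc = ReferenceEncoder()
    z_ref = torch.randn(2, 144, 120)  # (B, in_channels, T)
    mask = torch.ones(2, 1, 120)      # 1 = valid frame
    mask[1, :, 80:] = 0               # second item is padded past frame 80
    style = enc(z_ref, mask=mask)
    print(style.shape)                # torch.Size([2, 50, 256])

    # Loading a pre-refactor checkpoint (hypothetical file name):
    # legacy = torch.load("ref_encoder_legacy.pt", map_location="cpu")
    # enc.load_state_dict(ReferenceEncoder.remap_legacy_state_dict(legacy), strict=False)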