import torch
import torch.nn as nn

from .text_encoder import ConvNeXtWrapper


class ReferenceEncoder(nn.Module):
    """Encode a reference feature sequence into a fixed set of style tokens.

    The input is projected to ``d_model`` channels with a 1x1 convolution,
    passed through a ``ConvNeXtWrapper`` stack, and then read out by
    ``num_tokens`` learned query vectors via two successive multi-head
    cross-attention passes.
    """

    def __init__(
        self,
        in_channels: int = 144,
        d_model: int = 256,
        hidden_dim: int = 1024,
        num_blocks: int = 6,
        num_tokens: int = 50,
        num_heads: int = 2,
        kernel_size: int = 5,
        dilation_lst: list = None,
        prototype_dim: int = 256,
        n_units: int = 256,
        style_value_dim: int = 256,
    ):
        """Build the convolutional trunk, learned queries and attention layers.

        Args:
            in_channels: Channel count of the reference input.
            d_model: Channel width after the input projection; also the
                key/value dimension seen by both attention layers.
            hidden_dim: Must be a multiple of ``d_model``; the quotient is
                handed to ``ConvNeXtWrapper`` as ``expansion_factor``.
            num_blocks: Passed to ``ConvNeXtWrapper`` as ``n_layers``.
            num_tokens: Number of learned query vectors (output tokens).
            num_heads: Head count for both attention layers.
            kernel_size: Forwarded to ``ConvNeXtWrapper``.
            dilation_lst: Forwarded to ``ConvNeXtWrapper`` unchanged.
            prototype_dim: Dimension of each learned query vector.
            n_units: Embedding dimension of the attention layers.
            style_value_dim: Dimension of each returned token.

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``d_model``.
        """
        super().__init__()
        self.d_model = d_model
        self.num_tokens = num_tokens

        # The expansion factor is an integer quotient; refuse widths that
        # would otherwise be silently truncated by the floor division.
        if hidden_dim % d_model:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by d_model ({d_model})")
        expansion = hidden_dim // d_model

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation and state_dict key order stay stable.
        self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.convnext = ConvNeXtWrapper(
            d_model,
            n_layers=num_blocks,
            expansion_factor=expansion,
            kernel_size=kernel_size,
            dilation_lst=dilation_lst,
        )

        # Learned queries, initialised with small Gaussian noise.
        self.ref_keys = nn.Parameter(torch.randn(num_tokens, prototype_dim) * 0.02)

        # Linear adapters are only needed when the dimensions differ.
        if prototype_dim == n_units:
            self.q_proj = nn.Identity()
        else:
            self.q_proj = nn.Linear(prototype_dim, n_units)

        if n_units == style_value_dim:
            self.out_proj = nn.Identity()
        else:
            self.out_proj = nn.Linear(n_units, style_value_dim)

        # Both attention layers share the same configuration.
        cross_attn_cfg = dict(
            embed_dim=n_units,
            num_heads=num_heads,
            kdim=d_model,
            vdim=d_model,
            batch_first=True,
        )
        self.attn1 = nn.MultiheadAttention(**cross_attn_cfg)
        self.attn2 = nn.MultiheadAttention(**cross_attn_cfg)

    def forward(self, z_ref: torch.Tensor, mask: torch.Tensor = None):
        """Return the style tokens extracted from ``z_ref``.

        Args:
            z_ref: Reference features with the batch on dim 0 and
                ``in_channels`` on dim 1 (input to a ``Conv1d``).
            mask: Optional mask, also forwarded to the ConvNeXt stack.
                Positions where it equals 0 are excluded from attention.
                NOTE(review): presumably shaped (batch, 1, time) so that
                ``squeeze(1)`` yields (batch, time) -- confirm with callers.

        Returns:
            Tensor of shape (batch, num_tokens, style_value_dim).
        """
        batch_size = z_ref.size(0)

        features = self.convnext(self.input_proj(z_ref), mask=mask)
        # Attention layers are batch_first, so move channels to the last dim.
        memory = features.transpose(1, 2)

        # key_padding_mask is True where a position must be ignored.
        padding = None if mask is None else (mask.squeeze(1) == 0)

        # Share the learned queries across the batch without copying them.
        queries = self.q_proj(self.ref_keys.unsqueeze(0).expand(batch_size, -1, -1))

        attn_args = dict(
            key=memory,
            value=memory,
            key_padding_mask=padding,
            need_weights=False,
        )
        first_pass, _ = self.attn1(query=queries, **attn_args)
        # Residual connection: the second pass queries with the refined tokens.
        refined, _ = self.attn2(query=queries + first_pass, **attn_args)
        return self.out_proj(refined)

    @staticmethod
    def remap_legacy_state_dict(state_dict: dict) -> dict:
        """Translate a pre-refactor checkpoint to the current parameter names.

        Entries whose key contains ``.norm_q.``, ``.norm_kv.``, ``.ffn.`` or
        ``pos_emb.`` are discarded, and keys starting with
        ``attn_layers.{0,1}.attn.`` are renamed to ``attn1.`` / ``attn2.``.
        All other entries are copied through unchanged. The input mapping is
        not modified.

        Args:
            state_dict: Mapping of legacy parameter names to values.

        Returns:
            A new dict keyed by the current parameter names.
        """
        renamed_prefixes = (
            ("attn_layers.0.attn.", "attn1."),
            ("attn_layers.1.attn.", "attn2."),
        )
        obsolete_markers = (".norm_q.", ".norm_kv.", ".ffn.", "pos_emb.")

        def translate(name):
            # Only the first matching legacy prefix is rewritten.
            for legacy, current in renamed_prefixes:
                if name.startswith(legacy):
                    return current + name[len(legacy):]
            return name

        return {
            translate(name): value
            for name, value in state_dict.items()
            if not any(marker in name for marker in obsolete_markers)
        }