|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
class ONNXStreamingMultiheadAttention(nn.Module):
    """ONNX-friendly Streaming Attention using Packed KV strategy.

    Compatible with KevinAHM's export structure.

    State Tuple:
        0: KV Cache [2, B, H, MaxT, D]  (or [2, B, MaxT, H, D]; layout is
           auto-detected in forward)
        1: Empty State [0] (Placeholder)
        2: Step [1] (Int64) — number of frames processed so far
    """

    def __init__(self, original_attn):
        # Shares (does not copy) the original module's projections and rope.
        super().__init__()
        self.embed_dim = original_attn.embed_dim
        self.num_heads = original_attn.num_heads
        self.rope = original_attn.rope
        self.in_proj = original_attn.in_proj
        self.out_proj = original_attn.out_proj

        # NOTE(review): in_proj/out_proj above are the very same module
        # objects, so these re-assignments bind identical Parameters and are
        # effectively no-ops. Kept as-is for export parity.
        self.in_proj.weight = original_attn.in_proj.weight
        self.in_proj.bias = original_attn.in_proj.bias
        self.out_proj.weight = original_attn.out_proj.weight
        self.out_proj.bias = original_attn.out_proj.bias

    def forward(self, x, state_kv, state_empty, state_step):
        """One streaming attention step over a chunk of T frames.

        Args:
            x: input activations [B, T, embed_dim].
            state_kv: stacked K/V cache, either T-major [2, B, MaxT, H, D]
                or H-major [2, B, H, MaxT, D] (detected heuristically below).
            state_empty: unused placeholder, returned untouched.
            state_step: int64 [1] frame counter.

        Returns:
            (out [B, T, embed_dim], present_kv, state_empty, present_step)
        """
        B, T, _ = x.shape

        # Packed QKV projection: [B, T, 3*E] -> [B, T, 3, H, D] -> q/k/v.
        projected = self.in_proj(x)
        d = self.embed_dim // self.num_heads
        packed = projected.view(B, T, 3, self.num_heads, d)
        q, k, v = torch.unbind(packed, dim=2)

        # Scalar (0-d) step tensor; used as the rotary offset for this chunk.
        current_step = state_step[0]

        q, k = self.rope(q, k, offset=current_step)

        past_k = state_kv[0]
        past_v = state_kv[1]

        # Heuristic cache-layout detection: H-major means [B, H, MaxT, D],
        # T-major means [B, MaxT, H, D]. The checks below disambiguate by
        # matching dims against num_heads.
        # NOTE(review): this is ambiguous when MaxT == num_heads — confirm
        # callers never allocate a cache with that shape.
        is_h_major = (past_k.shape[1] == self.num_heads)
        if past_k.shape[1] == 1 and self.num_heads != 1:
            is_h_major = False
        if past_k.shape[2] == self.num_heads:
            is_h_major = False

        # Final override: dim 1 matches heads and dim 2 can hold the chunk.
        if past_k.shape[1] == self.num_heads and past_k.shape[2] >= T:
            is_h_major = True

        # NOTE(review): step_val is computed but never used below.
        step_val = current_step.view(1, 1, 1, 1).to(torch.int64)

        if not is_h_major:
            # --- T-major cache: [B, MaxT, H, D] ---
            # NOTE(review): H_dim is never used.
            H_dim = past_k.shape[2]
            MaxT = past_k.shape[1]
            # Trace-time marker; fires on every (traced) forward call.
            print(f"[ONNX Export] Using Sliding Window (T-major) MaxT={MaxT}")

            # Append the new frames along time, then keep only the last MaxT
            # slots (fixed-size sliding window).
            k_cat = torch.cat([past_k, k], dim=1)
            v_cat = torch.cat([past_v, v], dim=1)

            present_k = k_cat[:, -MaxT:, :, :]
            present_v = v_cat[:, -MaxT:, :, :]

            # SDPA wants [B, H, T, D].
            q_h = q.transpose(1, 2)
            k_h = present_k.transpose(1, 2)
            v_h = present_v.transpose(1, 2)

            # Absolute time of each cache slot after this step: the newest
            # slot (index MaxT-1) lands at time current_step + T.
            window_start = current_step + T - MaxT + 1

            buf_rng = torch.arange(MaxT, device=state_step.device, dtype=torch.int64).view(1, 1, 1, MaxT)
            buf_times_mask = window_start.view(1, 1, 1, 1) + buf_rng

            # Query times run current_step+1 .. current_step+T (1-based).
            q_rng = torch.arange(T, device=state_step.device, dtype=torch.int64).view(1, 1, T, 1)
            q_times_mask = (current_step + 1).view(1, 1, 1, 1) + q_rng

        else:
            # --- H-major cache: [B, H, MaxT, D] ---
            MaxT = past_k.shape[2]
            # Trace-time marker; fires on every (traced) forward call.
            print(f"[ONNX Export] Using Sliding Window (H-major) MaxT={MaxT}")

            # Bring the new frames to [B, H, T, D] before appending on time.
            k_h_in = k.transpose(1, 2)
            v_h_in = v.transpose(1, 2)

            k_cat = torch.cat([past_k, k_h_in], dim=2)
            v_cat = torch.cat([past_v, v_h_in], dim=2)

            present_k = k_cat[:, :, -MaxT:, :]
            present_v = v_cat[:, :, -MaxT:, :]

            # Already head-major; no transpose needed for k/v.
            q_h = q.transpose(1, 2)
            k_h = present_k
            v_h = present_v

            # Same sliding-window timestamp bookkeeping as the T-major branch.
            window_start = current_step + T - MaxT + 1

            buf_rng = torch.arange(MaxT, device=state_step.device, dtype=torch.int64).view(1, 1, 1, MaxT)
            buf_times_mask = window_start.view(1, 1, 1, 1) + buf_rng

            q_rng = torch.arange(T, device=state_step.device, dtype=torch.int64).view(1, 1, T, 1)
            q_times_mask = (current_step + 1).view(1, 1, 1, 1) + q_rng

        # Updated cache, stacked [2, ...] in the same layout it arrived in.
        present_kv = torch.stack([present_k, present_v], dim=0)

        # Causal: a query may only attend to slots at or before its own time.
        causal_mask = (buf_times_mask <= q_times_mask)
        # Slots with negative time predate any real frame (cache padding).
        # NOTE(review): with 1-based query times, a slot at time 0 would also
        # be padding yet passes this test — confirm the intended convention.
        valid_mask = (buf_times_mask >= 0)

        mask_bool = causal_mask & valid_mask

        # Additive mask: 0 where allowed, -inf where disallowed.
        attn_bias = torch.zeros_like(mask_bool, dtype=q.dtype)
        attn_bias.masked_fill_(~mask_bool, float('-inf'))

        out = F.scaled_dot_product_attention(q_h, k_h, v_h, attn_mask=attn_bias)

        # [B, H, T, D] -> [B, T, E], then final projection.
        out = out.transpose(1, 2).reshape(B, T, self.embed_dim)
        out = self.out_proj(out)

        # Advance the frame counter by the chunk length.
        new_step_val = current_step + T
        present_step = new_step_val.view(1)

        return out, present_kv, state_empty, present_step
|
|
|
|
|
class ONNXStreamingTransformer(nn.Module):
    """Wraps Backbone Transformer."""

    def __init__(self, original_transformer):
        super().__init__()
        # Wrap every source layer in its ONNX-friendly streaming equivalent.
        self.layers = nn.ModuleList(
            ONNXTransformerLayer(layer) for layer in original_transformer.layers
        )

    def forward(self, x, states):
        """Run all layers, threading each layer's 3 states through the flat
        `states` list (kv, empty, step per layer) and collecting the updated
        states in the same order."""
        out_states = []
        for idx, layer in enumerate(self.layers):
            offset = 3 * idx
            x, kv, empty, step = layer(
                x, states[offset], states[offset + 1], states[offset + 2]
            )
            out_states += [kv, empty, step]
        return x, out_states
|
|
|
class ONNXTransformerLayer(nn.Module):
    """Single pre-norm transformer layer with explicit streaming attention
    state (kv cache, placeholder, step counter) passed in and out."""

    def __init__(self, layer):
        super().__init__()
        self.norm1 = layer.norm1
        self.norm2 = layer.norm2
        self.linear1 = layer.linear1
        self.linear2 = layer.linear2
        self.self_attn = ONNXStreamingMultiheadAttention(layer.self_attn)
        self.layer_scale_1 = getattr(layer, 'layer_scale_1', None)
        self.layer_scale_2 = getattr(layer, 'layer_scale_2', None)
        # NOTE: forward always applies pre-norm; flag kept for parity.
        self.pre_norm = getattr(layer, 'pre_norm', True)

    def forward(self, x, skv, sempty, sstep):
        # Attention sub-block (pre-norm residual).
        attn_out, nkv, nempty, nstep = self.self_attn(
            self.norm1(x), skv, sempty, sstep
        )
        if self.layer_scale_1 is not None:
            attn_out = self.layer_scale_1(attn_out)
        x = x + attn_out

        # Feed-forward sub-block (pre-norm residual, GELU MLP).
        ff = self.linear2(F.gelu(self.linear1(self.norm2(x))))
        if self.layer_scale_2 is not None:
            ff = self.layer_scale_2(ff)

        return x + ff, nkv, nempty, nstep
|
|
|
|
|
class ONNXStreamingMimiTransformer(nn.Module):
    """Wraps Mimi Decoder Transformer.
    """

    def __init__(self, original_transformer):
        super().__init__()
        self.transformer = ONNXStreamingTransformer(original_transformer.transformer)
        self.input_proj = original_transformer.input_proj
        self.output_proj = original_transformer.output_projs[0]

    def forward(self, x, states):
        # Input arrives channels-first [B, C, T]; the transformer operates on
        # [B, T, C], so transpose around it.
        if self.input_proj is not None:
            x = self.input_proj(x)
        hidden = x.transpose(1, 2)
        hidden, new_states = self.transformer(hidden, states)
        out = self.output_proj(hidden.transpose(1, 2))
        return out, new_states
|
|
|
|
|
|
|
|
|
|
|
|
|
class ONNXStreamingConv1d(nn.Module):
    """Stateless wrapper for StreamingConv1d that returns new state."""

    def __init__(self, conv_module):
        super().__init__()
        self.conv = conv_module.conv
        self.pad_mode = conv_module.pad_mode
        self._stride_val = conv_module._stride
        self._kernel_size_val = conv_module._kernel_size
        self._effective_kernel_size_val = conv_module._effective_kernel_size
        self.in_channels = conv_module.conv.in_channels

        # Number of trailing input samples carried over between chunks.
        # NOTE(review): uses the raw kernel size; for dilated convs the
        # effective kernel size could differ — confirm against callers.
        self.state_size = self._kernel_size_val - self._stride_val

    def forward(self, x, state_prev, state_first):
        """Convolve one chunk, left-padding with carried samples.

        Args:
            x: input chunk [B, C, T].
            state_prev: carried left-context samples [B, C, state_size].
            state_first: per-batch first-chunk flag (nonzero = first).

        Returns:
            (y, new_prev, new_first)
        """
        batch, channels, _ = x.shape
        carry = self.state_size

        if carry <= 0:
            # Kernel never spans chunk boundaries: pure pass-through of state.
            return self.conv(x), state_prev, state_first

        if self.pad_mode == "replicate":
            # On the very first chunk, left-pad by repeating the first sample
            # instead of using the (uninitialized) carried state.
            replicated = x[..., :1].expand(-1, -1, carry)
            first_flag = state_first.view(batch, 1, 1).expand(-1, channels, carry)
            left = torch.where(first_flag > 0.5, replicated, state_prev)
        else:
            left = state_prev

        padded = torch.cat([left, x], dim=-1)
        y = self.conv(padded)

        # Carry the last `carry` padded samples into the next call; clear the
        # first-chunk flag once replicate padding has been applied.
        next_prev = padded[..., -carry:]
        if self.pad_mode == "replicate":
            next_first = torch.zeros_like(state_first)
        else:
            next_first = state_first

        return y, next_prev, next_first
|
|
|
class ONNXStreamingConvTranspose1d(nn.Module):
    """Stateless overlap-add wrapper around ConvTranspose1d for streaming."""

    def __init__(self, convtr_module):
        super().__init__()
        self.convtr = convtr_module.convtr
        self._stride_val = convtr_module._stride
        self._kernel_size_val = convtr_module._kernel_size
        self.out_channels = convtr_module.convtr.out_channels
        # Tail samples of each chunk's output that overlap the next chunk.
        self.state_size = self._kernel_size_val - self._stride_val

    def forward(self, x, state_partial):
        """Overlap-add streaming transposed convolution.

        Args:
            x: input chunk [B, C_in, T_in].
            state_partial: pending (bias-free) tail sums from the previous
                chunk, [B, C_out, state_size]; a last dim of 0 disables
                streaming and the raw convtr output is returned.

        Returns:
            (y_out, new_partial): T_in * stride finished output samples and
            the updated bias-free tail to add onto the next chunk.
        """
        y = self.convtr(x)
        overlap = state_partial.shape[-1]

        if overlap == 0:
            # No carried state: behave exactly like the plain module.
            return y, state_partial

        # Add the pending partial sums onto the head of this chunk's output.
        head = y[..., :overlap] + state_partial
        combined = torch.cat([head, y[..., overlap:]], dim=-1)

        # Only T_in * stride samples are final; the tail still awaits
        # contributions from future input.
        finished = x.shape[-1] * self._stride_val
        y_out = combined[..., :finished]
        new_partial = combined[..., finished:]

        if self.convtr.bias is not None:
            # The tail will pick the bias up again from the next chunk's raw
            # output, so strip it once here to avoid double-counting.
            new_partial = new_partial - self.convtr.bias.view(1, -1, 1)

        return y_out, new_partial
|
|
|
class ONNXSeanetBlock(nn.Module):
    """Residual SEANet block whose streaming convs carry explicit state.

    Fix: the `pocket_tts` import was previously executed inside the layer
    loop (once per layer) and again in the shortcut branch; it is now
    imported once at the top of __init__.
    """

    def __init__(self, block):
        super().__init__()
        # Imported once here (function-scope to avoid a hard module-level
        # dependency), instead of once per layer as before.
        from pocket_tts.modules.conv import StreamingConv1d

        self.layers = nn.ModuleList()
        for layer in block.block:
            if isinstance(layer, StreamingConv1d):
                self.layers.append(ONNXStreamingConv1d(layer))
            else:
                self.layers.append(layer)

        self.shortcut = None
        if hasattr(block, 'shortcut') and block.shortcut:
            if isinstance(block.shortcut, StreamingConv1d):
                self.shortcut = ONNXStreamingConv1d(block.shortcut)
            else:
                self.shortcut = block.shortcut

    def forward(self, x, state_iter):
        """Apply the block, pulling per-layer streaming states from state_iter.

        Args:
            x: input [B, C, T].
            state_iter: iterator yielding state tensors in layer order; each
                stateful conv consumes its state slot(s) here and a matching
                updated state is appended to `new_states` in the same order.

        Returns:
            (shortcut(x) + branch(x), new_states)
        """
        new_states = []
        out = x
        for layer in self.layers:
            if isinstance(layer, ONNXStreamingConv1d):
                eff_K = layer._effective_kernel_size_val
                S = layer._stride_val
                if (eff_K - S) > 0:
                    # Stateful conv: consumes (prev_samples, first_flag).
                    s_prev = next(state_iter)
                    s_first = next(state_iter)
                    out, ns_prev, ns_first = layer(out, s_prev, s_first)
                    new_states.extend([ns_prev, ns_first])
                else:
                    # Kernel never spans chunk boundaries: only a prev-state
                    # slot is tracked; a dummy first-flag is synthesized and
                    # not re-emitted.
                    s_prev = next(state_iter)
                    s_first = torch.zeros(1, dtype=torch.bool, device=out.device)
                    out, ns_prev, ns_first = layer(out, s_prev, s_first)
                    new_states.append(ns_prev)
            elif isinstance(layer, ONNXStreamingConvTranspose1d):
                # Transposed conv keeps an overlap-add partial; its second
                # state slot is passed through unchanged.
                s_prev = next(state_iter)
                s_first = next(state_iter)
                out, ns_prev = layer(out, s_prev)
                new_states.append(ns_prev)
                new_states.append(s_first)
            else:
                # Stateless layer (activation, norm, ...).
                out = layer(out)

        if self.shortcut:
            if isinstance(self.shortcut, ONNXStreamingConv1d):
                s_prev = next(state_iter)
                s_first = next(state_iter)
                short, ns_prev, ns_first = self.shortcut(x, s_prev, s_first)
                new_states.extend([ns_prev, ns_first])
                x = short
            else:
                x = self.shortcut(x)

        return x + out, new_states
|
|
|
|
|