# pocket-tts-spanish-streaming / onnx_streaming_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# =============================================================================
# Transformer Modules (Backbone & Mimi)
# =============================================================================
class ONNXStreamingMultiheadAttention(nn.Module):
"""ONNX-friendly Streaming Attention using Packed KV strategy.
Compatible with KevinAHM's export structure.
State Tuple:
0: KV Cache [2, B, H, MaxT, D]
1: Empty State [0] (Placeholder)
2: Step [1] (Int64)
"""
def __init__(self, original_attn):
super().__init__()
self.embed_dim = original_attn.embed_dim
self.num_heads = original_attn.num_heads
self.rope = original_attn.rope
self.in_proj = original_attn.in_proj
self.out_proj = original_attn.out_proj
        # in_proj / out_proj are shared with the original module, so their
        # weights and biases are reused directly; no explicit copy is needed.
def forward(self, x, state_kv, state_empty, state_step):
# x: [B, T, D]
# state_kv: [2, B, MaxT, H, D_head] or [2, B, H, MaxT, D_head]
B, T, _ = x.shape
# 1. Project QKV
projected = self.in_proj(x)
d = self.embed_dim // self.num_heads
packed = projected.view(B, T, 3, self.num_heads, d)
q, k, v = torch.unbind(packed, dim=2) # [B, T, H, D_head]
# 2. Get current step
current_step = state_step[0]
# 3. Apply RoPE
q, k = self.rope(q, k, offset=current_step)
# 4. Update KV Cache
past_k = state_kv[0]
past_v = state_kv[1]
        # Detect cache layout: [B, H, MaxT, D_head] (H-major) vs
        # [B, MaxT, H, D_head] (T-major). Comparing dimensions against
        # num_heads is ambiguous when B or MaxT happens to equal num_heads,
        # but the export script uses MAX_SEQ_LEN=500 with 8 or 16 heads and
        # B=1, so the checks below are unambiguous in practice.
        is_h_major = (past_k.shape[1] == self.num_heads)
        if past_k.shape[2] == self.num_heads:
            is_h_major = False  # time dimension first: [B, MaxT, H, D_head]
        if past_k.shape[1] == self.num_heads and past_k.shape[2] >= T:
            is_h_major = True   # heads first with a large time dimension
        if not is_h_major:
            # T-major layout: [B, MaxT, H, D_head].
            # Sliding-window update: the buffer holds keys/values for logical
            # times [step - MaxT + 1 .. step]; the new input covers
            # [step + 1 .. step + T]. Concatenate along time and keep the
            # last MaxT entries, so the buffer ends up covering
            # [step + T - MaxT + 1 .. step + T]. Negative slicing of this
            # form generally exports cleanly on recent torch versions.
            MaxT = past_k.shape[1]
            print(f"[ONNX Export] Using Sliding Window (T-major) MaxT={MaxT}")
# Concat
k_cat = torch.cat([past_k, k], dim=1)
v_cat = torch.cat([past_v, v], dim=1)
# Slice
present_k = k_cat[:, -MaxT:, :, :]
present_v = v_cat[:, -MaxT:, :, :]
# For Attention (SDPA expects [B, H, T, D])
# q: [B, T, H, D] -> [B, H, T, D]
q_h = q.transpose(1, 2)
k_h = present_k.transpose(1, 2) # [B, H, MaxT, D]
v_h = present_v.transpose(1, 2)
            # Mask setup, shape [1, 1, T, MaxT]:
            # buffer times cover [step + T - MaxT + 1 .. step + T] and query
            # times cover [step + 1 .. step + T]; a query at time q_t may
            # attend to a key at time k_t iff k_t <= q_t.
            window_start = current_step + T - MaxT + 1
            # Buffer times as int64, broadcast to (1, 1, 1, MaxT)
            buf_rng = torch.arange(MaxT, device=state_step.device, dtype=torch.int64).view(1, 1, 1, MaxT)
            buf_times_mask = window_start.view(1, 1, 1, 1) + buf_rng
            # Query times, broadcast to (1, 1, T, 1)
            q_rng = torch.arange(T, device=state_step.device, dtype=torch.int64).view(1, 1, T, 1)
            q_times_mask = (current_step + 1).view(1, 1, 1, 1) + q_rng
else:
# H-major: [B, H, MaxT, D]
# Same logic, dim=2 is Time
MaxT = past_k.shape[2]
print(f"[ONNX Export] Using Sliding Window (H-major) MaxT={MaxT}")
k_h_in = k.transpose(1, 2) # [B, H, T, D]
v_h_in = v.transpose(1, 2)
k_cat = torch.cat([past_k, k_h_in], dim=2)
v_cat = torch.cat([past_v, v_h_in], dim=2)
present_k = k_cat[:, :, -MaxT:, :]
present_v = v_cat[:, :, -MaxT:, :]
# For Attention
q_h = q.transpose(1, 2)
k_h = present_k
v_h = present_v
# Mask setup
window_start = current_step + T - MaxT + 1
buf_rng = torch.arange(MaxT, device=state_step.device, dtype=torch.int64).view(1, 1, 1, MaxT)
buf_times_mask = window_start.view(1, 1, 1, 1) + buf_rng
q_rng = torch.arange(T, device=state_step.device, dtype=torch.int64).view(1, 1, T, 1)
q_times_mask = (current_step + 1).view(1, 1, 1, 1) + q_rng
# Repack
present_kv = torch.stack([present_k, present_v], dim=0)
        # Attention mask.
        # Causal: k_t <= q_t.
        # Padding: query times start at current_step + 1, so the first real
        # token has logical time 1; zero-initialized buffer slots all carry
        # times <= 0 and must be masked out.
        causal_mask = (buf_times_mask <= q_times_mask)  # [1, 1, T, MaxT]
        valid_mask = (buf_times_mask >= 1)
mask_bool = causal_mask & valid_mask
attn_bias = torch.zeros_like(mask_bool, dtype=q.dtype)
attn_bias.masked_fill_(~mask_bool, float('-inf'))
        # Shapes entering SDPA: q_h [B, H, T, D_head], k_h / v_h
        # [B, H, MaxT, D_head]; the output is [B, H, T, D_head].
out = F.scaled_dot_product_attention(q_h, k_h, v_h, attn_mask=attn_bias)
# Reshape Output
# out: [B, H, T, D] -> [B, T, H, D] -> [B, T, E]
out = out.transpose(1, 2).reshape(B, T, self.embed_dim)
out = self.out_proj(out)
new_step_val = current_step + T
present_step = new_step_val.view(1)
return out, present_kv, state_empty, present_step
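
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the export path): initializing the
# per-layer state tuple that ONNXStreamingMultiheadAttention.forward expects.
# MAX_SEQ_LEN=500 mirrors the export-script assumption noted above; the
# helper name is hypothetical.
def _example_init_attention_state(batch, num_heads, head_dim,
                                  max_seq_len=500, dtype=torch.float32):
    kv = torch.zeros(2, batch, num_heads, max_seq_len, head_dim, dtype=dtype)  # H-major KV cache
    empty = torch.zeros(0, dtype=dtype)        # unused placeholder slot
    step = torch.zeros(1, dtype=torch.int64)   # tokens processed so far
    return kv, empty, step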
class ONNXStreamingTransformer(nn.Module):
"""Wraps Backbone Transformer."""
def __init__(self, original_transformer):
super().__init__()
self.layers = nn.ModuleList()
for layer in original_transformer.layers:
self.layers.append(ONNXTransformerLayer(layer))
def forward(self, x, states):
# states is a flat list of tensors: 3 per layer
# [kv_0, empty_0, step_0, kv_1, empty_1, step_1, ...]
new_states = []
for i, layer in enumerate(self.layers):
base_idx = i * 3
skv = states[base_idx]
sempty = states[base_idx+1]
sstep = states[base_idx+2]
x, nkv, nempty, nstep = layer(x, skv, sempty, sstep)
new_states.extend([nkv, nempty, nstep])
return x, new_states
class ONNXTransformerLayer(nn.Module):
def __init__(self, layer):
super().__init__()
self.norm1 = layer.norm1
self.norm2 = layer.norm2
self.linear1 = layer.linear1
self.linear2 = layer.linear2
self.self_attn = ONNXStreamingMultiheadAttention(layer.self_attn)
self.layer_scale_1 = getattr(layer, 'layer_scale_1', None)
self.layer_scale_2 = getattr(layer, 'layer_scale_2', None)
        # forward() below hard-codes the pre-norm path; the flag is kept for reference.
        self.pre_norm = getattr(layer, 'pre_norm', True)
def forward(self, x, skv, sempty, sstep):
# Pre-Norm
residual = x
x = self.norm1(x)
x, nkv, nempty, nstep = self.self_attn(x, skv, sempty, sstep)
if self.layer_scale_1 is not None:
x = self.layer_scale_1(x)
x = residual + x
residual = x
x = self.norm2(x)
x = self.linear2(F.gelu(self.linear1(x)))
if self.layer_scale_2 is not None:
x = self.layer_scale_2(x)
x = residual + x
return x, nkv, nempty, nstep
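
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the export path): assembling the flat state
# list (3 tensors per layer) that ONNXStreamingTransformer.forward consumes,
# then running one streaming step. Uses the hypothetical
# _example_init_attention_state helper defined above.
def _example_transformer_step(transformer, x, num_heads, head_dim, max_seq_len=500):
    states = []
    for _ in transformer.layers:
        states.extend(_example_init_attention_state(
            x.shape[0], num_heads, head_dim, max_seq_len, dtype=x.dtype))
    y, new_states = transformer(x, states)
    # Subsequent calls thread new_states back in as the next call's states.
    return y, new_states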
class ONNXStreamingMimiTransformer(nn.Module):
"""Wraps Mimi Decoder Transformer.
"""
def __init__(self, original_transformer):
super().__init__()
self.transformer = ONNXStreamingTransformer(original_transformer.transformer)
self.input_proj = original_transformer.input_proj
self.output_proj = original_transformer.output_projs[0]
def forward(self, x, states):
if self.input_proj is not None:
x = self.input_proj(x)
x = x.transpose(1, 2) # [B, T, D] for transformer
x, new_states = self.transformer(x, states)
x = x.transpose(1, 2) # Back to [B, D, T]
x = self.output_proj(x)
return x, new_states
# =============================================================================
# Mimi Decoder Convolution Modules
# =============================================================================
class ONNXStreamingConv1d(nn.Module):
"""Stateless wrapper for StreamingConv1d that returns new state."""
def __init__(self, conv_module):
super().__init__()
self.conv = conv_module.conv
self.pad_mode = conv_module.pad_mode
self._stride_val = conv_module._stride
self._kernel_size_val = conv_module._kernel_size
self._effective_kernel_size_val = conv_module._effective_kernel_size
self.in_channels = conv_module.conv.in_channels
        # Left context carried between calls: kernel_size - stride samples.
        self.state_size = self._kernel_size_val - self._stride_val
def forward(self, x, state_prev, state_first):
# state_prev: [B, C, K-S]
# state_first: [B] (Bool/Int)
B, C, T = x.shape
S = self._stride_val
TP = self.state_size
if TP > 0:
if self.pad_mode == "replicate":
init_val = x[..., :1].expand(-1, -1, TP)
is_first = state_first.view(B, 1, 1).expand(-1, C, TP)
effective_prev = torch.where(is_first > 0.5, init_val, state_prev)
x_padded = torch.cat([effective_prev, x], dim=-1)
else:
x_padded = torch.cat([state_prev, x], dim=-1)
else:
x_padded = x
y = self.conv(x_padded)
if TP > 0:
new_prev = x_padded[..., -TP:]
if self.pad_mode == "replicate":
new_first = torch.zeros_like(state_first) # False
else:
new_first = state_first
else:
new_prev = state_prev
new_first = state_first
return y, new_prev, new_first
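
# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the export path): streaming the conv
# in chunks should match one offline, left-padded convolution, assuming chunk
# lengths are multiples of the stride and replicate left-padding.
# _FakeStreamingConv1d is a hypothetical stand-in exposing only the attributes
# that ONNXStreamingConv1d reads from pocket_tts's StreamingConv1d.
class _FakeStreamingConv1d:
    def __init__(self, c_in, c_out, kernel_size, stride):
        self.conv = nn.Conv1d(c_in, c_out, kernel_size, stride=stride)
        self.pad_mode = "replicate"
        self._stride = stride
        self._kernel_size = kernel_size
        self._effective_kernel_size = kernel_size  # assumes dilation == 1

def _example_check_streaming_conv():
    torch.manual_seed(0)
    K, S = 7, 2
    wrapper = ONNXStreamingConv1d(_FakeStreamingConv1d(4, 8, K, S))
    x = torch.randn(1, 4, 12)
    # Offline reference: replicate-pad K - S samples on the left, run once.
    ref = wrapper.conv(F.pad(x, (K - S, 0), mode="replicate"))
    # Streaming: two chunks of 6 samples (a multiple of S), threading the state.
    prev = torch.zeros(1, 4, K - S)
    first = torch.ones(1, dtype=torch.bool)
    outs = []
    for chunk in x.split(6, dim=-1):
        y, prev, first = wrapper(chunk, prev, first)
        outs.append(y)
    assert torch.allclose(torch.cat(outs, dim=-1), ref, atol=1e-5)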
class ONNXStreamingConvTranspose1d(nn.Module):
def __init__(self, convtr_module):
super().__init__()
self.convtr = convtr_module.convtr
self._stride_val = convtr_module._stride
self._kernel_size_val = convtr_module._kernel_size
self.out_channels = convtr_module.convtr.out_channels
self.state_size = self._kernel_size_val - self._stride_val
    def forward(self, x, state_partial):
        # x: [B, C_in, T_in]
        # state_partial: [B, C_out, K - S] overlap buffer carried between calls
        # ConvTranspose1d emits L_out = (T_in - 1) * S + K samples; only
        # T_in * S of them are final, and the trailing K - S samples overlap
        # the head of the next call's output.
        y = self.convtr(x)
        # y: [B, C_out, L_out]
        PT = state_partial.shape[-1]
        if PT > 0:
            # Overlap-add: the previous call's tail (length PT = K - S)
            # aligns with the head of the current output. Since
            # L_out - PT = T_in * S >= S > 0, y always covers the buffer.
            y_overlap = y[..., :PT] + state_partial
            y_combined = torch.cat([y_overlap, y[..., PT:]], dim=-1)
            # Emit S samples per input frame; the rest becomes the new buffer.
            T_in = x.shape[-1]
            S = self._stride_val
            T_valid = T_in * S
            y_out = y_combined[..., :T_valid]
            new_partial = y_combined[..., T_valid:]
            # ConvTranspose1d adds its bias to every output sample, including
            # the buffered tail. The next call's y adds the bias again, so
            # subtract it from the buffer to avoid counting it twice.
            if self.convtr.bias is not None:
                bias_view = self.convtr.bias.view(1, -1, 1)
                new_partial = new_partial - bias_view
            return y_out, new_partial
        else:
            # kernel_size <= stride: frames do not overlap, nothing to buffer.
            return y, state_partial
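
# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the export path): the overlap-add
# streaming above should reproduce the offline ConvTranspose1d output for the
# samples it emits. _FakeStreamingConvTr1d is a hypothetical stand-in exposing
# only the attributes ONNXStreamingConvTranspose1d reads.
class _FakeStreamingConvTr1d:
    def __init__(self, c_in, c_out, kernel_size, stride):
        self.convtr = nn.ConvTranspose1d(c_in, c_out, kernel_size, stride=stride)
        self._stride = stride
        self._kernel_size = kernel_size

def _example_check_streaming_convtr():
    torch.manual_seed(0)
    K, S = 8, 4
    wrapper = ONNXStreamingConvTranspose1d(_FakeStreamingConvTr1d(4, 8, K, S))
    x = torch.randn(1, 4, 10)
    full = wrapper.convtr(x)             # [1, 8, (10 - 1) * S + K]
    partial = torch.zeros(1, 8, K - S)   # initial (empty) overlap buffer
    outs = []
    for chunk in x.split(5, dim=-1):
        y, partial = wrapper(chunk, partial)
        outs.append(y)
    stream = torch.cat(outs, dim=-1)     # 10 * S emitted samples
    # The emitted samples match the offline output; the remaining K - S
    # samples still live in `partial` (with the bias subtracted).
    assert torch.allclose(stream, full[..., :stream.shape[-1]], atol=1e-5)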
class ONNXSeanetBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        from pocket_tts.modules.conv import StreamingConv1d
        self.layers = nn.ModuleList()
        for layer in block.block:
            if isinstance(layer, StreamingConv1d):
                self.layers.append(ONNXStreamingConv1d(layer))
            else:
                self.layers.append(layer)
        self.shortcut = None
        if hasattr(block, 'shortcut') and block.shortcut:
            if isinstance(block.shortcut, StreamingConv1d):
                self.shortcut = ONNXStreamingConv1d(block.shortcut)
            else:
                self.shortcut = block.shortcut
def forward(self, x, state_iter):
new_states = []
out = x
for layer in self.layers:
if isinstance(layer, ONNXStreamingConv1d):
eff_K = layer._effective_kernel_size_val
S = layer._stride_val
if (eff_K - S) > 0:
s_prev = next(state_iter)
s_first = next(state_iter)
out, ns_prev, ns_first = layer(out, s_prev, s_first)
new_states.extend([ns_prev, ns_first])
                else:
                    # No left context needed (K == S): the layer still threads a
                    # single (empty) buffer state; synthesize a local "first" flag.
                    s_prev = next(state_iter)
                    s_first = torch.zeros(1, dtype=torch.bool, device=out.device)
                    out, ns_prev, ns_first = layer(out, s_prev, s_first)
                    new_states.append(ns_prev)
elif isinstance(layer, ONNXStreamingConvTranspose1d):
s_prev = next(state_iter)
s_first = next(state_iter)
out, ns_prev = layer(out, s_prev)
                # The transpose conv carries no "first" flag; pass s_first
                # through unchanged as a placeholder to keep the layout fixed.
new_states.append(ns_prev)
new_states.append(s_first)
else:
out = layer(out)
if self.shortcut:
if isinstance(self.shortcut, ONNXStreamingConv1d):
s_prev = next(state_iter)
s_first = next(state_iter)
short, ns_prev, ns_first = self.shortcut(x, s_prev, s_first)
new_states.extend([ns_prev, ns_first])
x = short
else:
x = self.shortcut(x)
return x + out, new_states
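
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the export path): assembling the flat state
# list that ONNXSeanetBlock.forward consumes, in consumption order. Assumes
# dilation == 1 (so the effective kernel size equals the kernel size) and that
# the block contains only ONNXStreamingConv1d and stateless layers; the
# transpose-conv branch of forward() is not covered here.
def _example_init_seanet_block_states(block, batch=1, dtype=torch.float32):
    states = []
    convs = [l for l in block.layers if isinstance(l, ONNXStreamingConv1d)]
    if isinstance(block.shortcut, ONNXStreamingConv1d):
        convs.append(block.shortcut)  # shortcut states are consumed last
    for conv in convs:
        if conv._effective_kernel_size_val - conv._stride_val > 0:
            states.append(torch.zeros(batch, conv.in_channels, conv.state_size, dtype=dtype))
            states.append(torch.ones(batch, dtype=torch.bool))  # "first step" flag
        else:
            # K == S: forward() still pops a single (empty) buffer state.
            states.append(torch.zeros(batch, conv.in_channels, 0, dtype=dtype))
    return iter(states)

# Usage: out, new_states = block(x, _example_init_seanet_block_states(block))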