Reconstructed PyTorch code by Gemini 3 Pro
#2 opened by hr16
Getting the ONNX graph in text
import onnx

with open("model_graph.txt", "w") as f:
    for model_name in ["vocoder.onnx", "vector_estimator.onnx", "text_encoder.onnx", "duration_predictor.onnx"]:
        model = onnx.load(r"C:\Users\hi\Downloads\\" + model_name)
        f.write(model_name + '\n')
        f.write(onnx.helper.printable_graph(model.graph))
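It can also help to dump the declared input/output shapes next to each graph. This only walks the standard onnx protobuf fields; adjust the paths to wherever the models live:

for model_name in ["vocoder.onnx", "vector_estimator.onnx", "text_encoder.onnx", "duration_predictor.onnx"]:
    model = onnx.load(model_name)
    print(model_name)
    for vi in list(model.graph.input) + list(model.graph.output):
        # dim_param is the symbolic name (e.g. "batch_size"), dim_value the static size
        dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
        print(" ", vi.name, dims)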
Result from Gemini 3 Pro
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# ==============================================================================
# Shared Modules
# ==============================================================================

class LayerNormChannelFirst(nn.Module):
    """
    LayerNorm for (B, C, L) input.
    """
    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        return self.norm(x.transpose(1, 2)).transpose(1, 2)


class ConvNeXtBlock(nn.Module):
    """
    1D ConvNeXt block adapted for TTS.
    Graph trace: Pad -> DWConv -> Norm -> PWConv1 -> GELU -> PWConv2 -> Scale -> Add
    """
    def __init__(self, dim, intermediate_dim, kernel_size=7, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation // 2
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding,
                                groups=dim, dilation=dilation)
        self.norm = LayerNormChannelFirst(dim)
        self.pwconv1 = nn.Conv1d(dim, intermediate_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv1d(intermediate_dim, dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.ones(1, dim, 1) * 1e-6)

    def forward(self, x):
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        return residual + (self.gamma * x)
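Quick sanity check for the shared blocks; the shapes below are illustrative, not read from the graphs. Both modules keep the (B, C, L) layout:

x = torch.randn(2, 512, 100)                     # (batch, channels, frames)
block = ConvNeXtBlock(512, 2048, kernel_size=7, dilation=2)
assert block(x).shape == x.shape                  # residual block preserves the shape
assert LayerNormChannelFirst(512)(x).shape == x.shape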
class RelativeAttention(nn.Module):
    """
    Self-attention with relative positional embeddings.
    Can optionally modulate Key/Value with a style vector (used by the Vector Estimator).
    """
    def __init__(self, dim, heads=8, window_size=4, style_dim=None):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.conv_q = nn.Conv1d(dim, dim, 1)
        self.conv_k = nn.Conv1d(dim, dim, 1)
        self.conv_v = nn.Conv1d(dim, dim, 1)
        self.conv_o = nn.Conv1d(dim, dim, 1)

        # Relative position embeddings (for keys and values)
        self.emb_rel_k = nn.Parameter(torch.randn(1, window_size * 2 + 1, self.head_dim))
        self.emb_rel_v = nn.Parameter(torch.randn(1, window_size * 2 + 1, self.head_dim))
        self.window_size = window_size

        # Style modulation projections (if used)
        if style_dim is not None:
            self.style_k = nn.Linear(style_dim, dim)
            self.style_v = nn.Linear(style_dim, dim)
        else:
            self.style_k = None
            self.style_v = None

    def forward(self, x, mask=None, style_vec=None):
        B, C, T = x.shape
        q = self.conv_q(x).view(B, self.heads, self.head_dim, T).transpose(2, 3)  # (B, H, T, D)
        k = self.conv_k(x)
        v = self.conv_v(x)

        # Style injection (Vector Estimator logic):
        # the style vector is added to the K and V projections before reshaping.
        if style_vec is not None and self.style_k is not None:
            # style_vec: (B, style_dim)
            s_k = self.style_k(style_vec).unsqueeze(-1)  # (B, dim, 1)
            s_v = self.style_v(style_vec).unsqueeze(-1)
            k = k + s_k
            v = v + s_v

        k = k.view(B, self.heads, self.head_dim, T).transpose(2, 3)
        v = v.view(B, self.heads, self.head_dim, T).transpose(2, 3)

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Relative position bias (simplified).
        # A full implementation indexes the relative embeddings according to T and
        # adds the resulting bias to the scores; see the sketch after this class.
        # scores = scores + relative_bias

        if mask is not None:
            # mask: (B, 1, T)
            mask_expanded = mask.unsqueeze(1)  # (B, 1, 1, T)
            scores = scores.masked_fill(mask_expanded == 0, -1e4)

        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(2, 3).reshape(B, C, T)
        return self.conv_o(out)
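The placeholder above skips the relative-position bias. A minimal sketch of the missing logic, assuming the usual VITS-style scheme (the helper names, padding details, and wiring are guesses, not read from the graph):

def _get_relative_embeddings(rel_emb, length, window_size):
    # rel_emb: (1, 2*window_size + 1, head_dim); pad/slice to (1, 2*length - 1, head_dim)
    pad_len = max(length - (window_size + 1), 0)
    if pad_len > 0:
        rel_emb = F.pad(rel_emb, (0, 0, pad_len, pad_len))
    start = max(window_size + 1 - length, 0)
    return rel_emb[:, start:start + 2 * length - 1]

def _relative_to_absolute(x):
    # x: (B, H, T, 2*T - 1) -> (B, H, T, T), shifting each query row into place
    b, h, t, _ = x.shape
    x = F.pad(x, (0, 1)).view(b, h, t * 2 * t)
    x = F.pad(x, (0, t - 1)).view(b, h, t + 1, 2 * t - 1)
    return x[:, :, :t, t - 1:]

# Hypothetical wiring inside RelativeAttention.forward, before masking:
#   rel_k = _get_relative_embeddings(self.emb_rel_k, T, self.window_size)
#   rel_logits = torch.matmul(q * self.scale, rel_k.unsqueeze(0).transpose(-2, -1))
#   scores = scores + _relative_to_absolute(rel_logits)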
class SpeechPromptedAttention(nn.Module):
    """
    Cross-attention for the Text Encoder.
    Query = text, Key/Value = style (speech prompt).
    """
    def __init__(self, dim, style_dim=256, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.norm = LayerNormChannelFirst(dim)
        self.q_proj = nn.Conv1d(dim, dim, 1)
        self.k_proj = nn.Linear(style_dim, dim)
        self.v_proj = nn.Linear(style_dim, dim)
        self.out_proj = nn.Conv1d(dim, dim, 1)

    def forward(self, x, style, mask=None):
        # x:     (B, C, T_text)
        # style: (B, T_style, style_dim)
        residual = x
        x = self.norm(x)
        B, C, T = x.shape

        q = self.q_proj(x).view(B, self.heads, self.head_dim, T).transpose(2, 3)  # (B, H, T, D)
        # K, V from style: (B, S, style_dim) -> (B, S, dim) -> (B, H, S, D)
        k = self.k_proj(style).view(B, -1, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(style).view(B, -1, self.heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Cross-attention over the style prompt is usually left unmasked
        # unless the style sequence itself is padded.
        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v)  # (B, H, T, D)
        out = out.transpose(2, 3).reshape(B, C, T)
        out = self.out_proj(out)
        if mask is not None:
            out = out * mask
        return residual + out
# ==============================================================================
# 1. Vocoder (vocoder.onnx)
# ==============================================================================

class Vocoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('normalizer_scale', torch.tensor(1.0))
        self.latent_mean = nn.Parameter(torch.zeros(1, 24, 1))
        self.latent_std = nn.Parameter(torch.ones(1, 24, 1))

        self.embed = nn.Conv1d(24, 512, 7, padding=3)

        # 10 blocks with the dilation pattern from the graph: [1, 2, 4, 1, 2, 4, 1, 2, 4, 1]
        dilations = [1, 2, 4, 1, 2, 4, 1, 2, 4, 1]
        self.blocks = nn.ModuleList([
            ConvNeXtBlock(512, 2048, kernel_size=7, dilation=d)
            for d in dilations
        ])

        self.final_norm = nn.BatchNorm1d(512)
        self.head_layer1 = nn.Conv1d(512, 2048, 3, padding=1)
        self.head_act = nn.PReLU()
        self.head_layer2 = nn.Conv1d(2048, 1, 1)

    def forward(self, latent):
        # Denormalize
        x = latent / self.normalizer_scale
        x = x * self.latent_std + self.latent_mean

        x = self.embed(x)
        for block in self.blocks:
            x = block(x)

        x = self.final_norm(x)
        x = self.head_layer1(x)
        x = self.head_act(x)
        x = self.head_layer2(x)
        return x
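Shape check for the vocoder as reconstructed (the frame count 256 is arbitrary). Note there is no upsampling in this reconstruction, so the output length equals the input latent length:

voc = Vocoder().eval()
latent = torch.randn(1, 24, 256)          # (B, latent_dim, T)
with torch.no_grad():
    audio = voc(latent)                    # (1, 1, 256): same time length as the input
print(audio.shape)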
# ==============================================================================
# 2. Vector Estimator (vector_estimator.onnx)
# ==============================================================================

class TimePositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.Mish(),  # Graph shows Softplus -> Tanh -> Mul, i.e. Mish
            nn.Linear(dim * 4, 512)
        )

    def forward(self, t):
        # t: (B,) step value; encoded with the standard sinusoidal embedding
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.mlp(emb)
class VectorEstimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Conv1d(64, 512, 1)  # latent dim 64 inferred from proj_out
        self.time_encoder = TimePositionalEmbedding(64)

        # Main backbone.
        # The ONNX graph unrolls blocks 0-23:
        #   blocks 1, 7, 13, 19 are time projections (Linear layers);
        #   the remaining blocks are ConvNeXt stacks or style-conditioned attention.
        self.layers = nn.ModuleList()
        self.time_projs = nn.ModuleList()

        # Block 0: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 1: Time injection 0
        self.time_projs.append(nn.Linear(512, 512))
        # Blocks 2-5: ConvNeXt -> Attention -> ConvNeXt -> Attention (style-conditioned)
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        # Block 6: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 7: Time injection 1
        self.time_projs.append(nn.Linear(512, 512))
        # Blocks 8-11: ConvNeXt -> Attention -> ConvNeXt -> Attention
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        # Block 12: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 13: Time injection 2
        self.time_projs.append(nn.Linear(512, 512))
        # Blocks 14-17: ConvNeXt -> Attention -> ConvNeXt -> Attention
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        # Block 18: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 19: Time injection 3
        self.time_projs.append(nn.Linear(512, 512))
        # Blocks 20-23: ConvNeXt -> Attention -> ConvNeXt -> Attention
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        self.layers.append(RelativeAttention(512, style_dim=256))

        # Final block
        self.last_convnext = nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)])
        self.proj_out = nn.Conv1d(512, 64, 1)  # latent dim output
    def forward(self, noisy_latent, text_emb, style_ttl, latent_mask, text_mask, current_step, total_step):
        # Note: text_emb and text_mask are graph inputs, but in the trace the attention
        # blocks mostly consume transposed latent_mask/text_mask for masking;
        # style_ttl is the main conditioning signal.
        x = self.proj_in(noisy_latent) * latent_mask
        t_emb = self.time_encoder(current_step)

        # The ONNX graph unrolls the 24 blocks manually; the calls below replay that order:
        #   ConvNeXt x4 -> time injection -> (ConvNeXt -> Attn) x2, repeated four times.

        # 1. ConvNeXt x4
        x = self.layers[0](x) * latent_mask
        # 2. Time injection 0
        x = x + self.time_projs[0](t_emb).unsqueeze(-1)
        # 3. ConvNeXt -> Attn -> ConvNeXt -> Attn
        x = self.layers[1](x) * latent_mask
        x = self.layers[2](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[3](x) * latent_mask
        x = self.layers[4](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        # 4. ConvNeXt x4
        x = self.layers[5](x) * latent_mask
        # 5. Time injection 1
        x = x + self.time_projs[1](t_emb).unsqueeze(-1)
        # 6. ConvNeXt -> Attn -> ConvNeXt -> Attn
        x = self.layers[6](x) * latent_mask
        x = self.layers[7](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[8](x) * latent_mask
        x = self.layers[9](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        # 7. ConvNeXt x4
        x = self.layers[10](x) * latent_mask
        # 8. Time injection 2
        x = x + self.time_projs[2](t_emb).unsqueeze(-1)
        # 9. ConvNeXt -> Attn -> ConvNeXt -> Attn
        x = self.layers[11](x) * latent_mask
        x = self.layers[12](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[13](x) * latent_mask
        x = self.layers[14](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        # 10. ConvNeXt x4
        x = self.layers[15](x) * latent_mask
        # 11. Time injection 3
        x = x + self.time_projs[3](t_emb).unsqueeze(-1)
        # 12. ConvNeXt -> Attn -> ConvNeXt -> Attn
        x = self.layers[16](x) * latent_mask
        x = self.layers[17](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[18](x) * latent_mask
        x = self.layers[19](x, mask=latent_mask, style_vec=style_ttl) * latent_mask

        # Final
        x = self.last_convnext(x) * latent_mask
        v_pred = self.proj_out(x) * latent_mask

        # The graph itself folds in an Euler step:
        #   denoised = noisy_latent + v_pred / total_step
        # We return the raw vector field v_pred for flexibility.
        return v_pred
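Given the Euler-step note above, a minimal driver loop might look like the sketch below. This is only a guess at how the exported graph is meant to be stepped: the function name, the noise initialization, and whether current_step is a raw index or a normalized time are all assumptions.

@torch.no_grad()
def sample_latent(estimator, text_emb, style_ttl, latent_mask, text_mask,
                  total_steps=10, latent_dim=64):
    # Hypothetical driver: start from noise and integrate the predicted
    # vector field with fixed-size Euler steps (x <- x + v / total_steps).
    B, T = latent_mask.shape[0], latent_mask.shape[-1]
    x = torch.randn(B, latent_dim, T) * latent_mask
    total = torch.tensor(float(total_steps))
    for step in range(total_steps):
        t = torch.full((B,), float(step))
        v = estimator(x, text_emb, style_ttl, latent_mask, text_mask, t, total)
        x = x + v / total_steps
    return x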
# ==============================================================================
# 3. Text Encoder (text_encoder.onnx)
# ==============================================================================

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_embedder = nn.Embedding(163, 256)

        # 6 ConvNeXt blocks
        self.convnext = nn.ModuleList([
            ConvNeXtBlock(256, 1024, 5) for _ in range(6)
        ])

        # 4 attention encoder blocks (self-attention + FFN)
        self.attn_encoder_layers = nn.ModuleList([
            RelativeAttention(256) for _ in range(4)
        ])
        self.attn_ffn_layers = nn.ModuleList([
            nn.Sequential(
                LayerNormChannelFirst(256),
                nn.Conv1d(256, 1024, 1),
                nn.ReLU(),  # graph shows Relu
                nn.Conv1d(1024, 256, 1)
            ) for _ in range(4)
        ])

        # 2 speech-prompted blocks (cross-attention over the style prompt)
        self.speech_prompted_attn = nn.ModuleList([
            SpeechPromptedAttention(256, style_dim=256) for _ in range(2)
        ])

        self.proj_out = nn.Conv1d(256, 256, 1)  # output projection; the graph shows a final output layer here, possibly an identity
        self.norm = LayerNormChannelFirst(256)

    def forward(self, text_ids, style_ttl, text_mask):
        x = self.text_embedder(text_ids).transpose(1, 2) * text_mask

        # ConvNeXt stack
        for block in self.convnext:
            x = block(x) * text_mask

        # Attention encoder stack.
        # Graph trace per layer: Input -> Conv(q, k, v) -> Attn -> Add(Input, Out) -> Norm -> FFN -> Add -> Norm.
        # The norm placement below is simplified (a single shared LayerNorm).
        for attn, ffn in zip(self.attn_encoder_layers, self.attn_ffn_layers):
            res = x
            x = attn(x, mask=text_mask)
            x = res + x
            x = self.norm(x)
            res = x
            x = ffn(x) * text_mask
            x = res + x
            x = self.norm(x)

        # Speech-prompted (cross) attention
        for block in self.speech_prompted_attn:
            x = block(x, style=style_ttl, mask=text_mask)

        x = self.norm(x) * text_mask
        return x
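Input shapes implied by this reconstruction (the text and style-prompt lengths are illustrative):

enc = TextEncoder().eval()
text_ids = torch.randint(0, 163, (1, 50))           # (B, T_text) token ids
text_mask = torch.ones(1, 1, 50)                     # (B, 1, T_text)
style_ttl = torch.randn(1, 32, 256)                  # (B, T_style, 256) speech-prompt features
with torch.no_grad():
    text_emb = enc(text_ids, style_ttl, text_mask)   # (1, 256, 50)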
# ==============================================================================
# 4. Duration Predictor (duration_predictor.onnx)
# ==============================================================================

class DurationPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(163, 64)
        # Global sentence token
        self.sentence_token = nn.Parameter(torch.randn(1, 64, 1))

        # Sentence encoder: 2 attention layers
        self.attn_layers = nn.ModuleList([
            RelativeAttention(64, heads=8, window_size=4) for _ in range(2)
        ])
        self.attn_ffn = nn.ModuleList([
            nn.Sequential(
                LayerNormChannelFirst(64),
                nn.Conv1d(64, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, 64, 1)
            ) for _ in range(2)
        ])

        # 6 ConvNeXt blocks
        self.convnext_stack = nn.ModuleList([
            ConvNeXtBlock(64, 256, kernel_size=5) for _ in range(6)
        ])
        self.proj_out = nn.Conv1d(64, 64, 1)

        # Predictor MLP: input 64 (text) + 128 (style) = 192 -> 128 -> 1
        self.mlp = nn.Sequential(
            nn.Linear(192, 128),
            nn.PReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, text_ids, style_dp, text_mask):
        # 1. Embedding
        x = self.embedding(text_ids).transpose(1, 2) * text_mask

        # 2. Prepend the learnable sentence token (the graph concatenates it at the start)
        B = x.shape[0]
        token = self.sentence_token.expand(B, -1, -1)
        x = torch.cat([token, x], dim=2)
        # Extend the mask to cover the token
        mask_pad = F.pad(text_mask, (1, 0), value=1.0)

        # 3. Attention encoder
        for attn, ffn in zip(self.attn_layers, self.attn_ffn):
            res = x
            x = attn(x, mask=mask_pad)
            x = res + x
            # Norm placement is implied to mirror the text encoder
            res = x
            x = ffn(x) * mask_pad
            x = res + x

        # 4. ConvNeXt stack
        for block in self.convnext_stack:
            x = block(x) * mask_pad

        # 5. Remove the sentence token and project
        x = x[:, :, 1:]
        x = self.proj_out(x) * text_mask

        # 6. Predictor head
        # (B, C, T) -> (B, T, C)
        x = x.transpose(1, 2)
        # The ONNX input is style_dp[FLOAT, batch_size x 8 x 16], i.e. 128 values per item;
        # flatten it and broadcast over the text length.
        style = style_dp.reshape(B, 128).unsqueeze(1).expand(-1, x.shape[1], -1)
        combined = torch.cat([x, style], dim=-1)  # 64 + 128 = 192
        log_dur = self.mlp(combined)
        duration = torch.exp(log_dur).squeeze(-1)
        return duration * text_mask.squeeze(1)
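Downstream, the predicted per-token durations would typically be rounded and used to expand the text-encoder output to frame rate before the vector estimator runs. That stage is not present in these four ONNX files, so the length regulator below is purely an assumed sketch of the surrounding pipeline:

def length_regulate(text_emb, durations):
    # text_emb:  (B, C, T_text); durations: (B, T_text) in frames
    # Repeat each token's features according to its rounded duration, then pad across the batch.
    frames = durations.round().long().clamp(min=0)
    expanded = [emb.repeat_interleave(f, dim=-1) for emb, f in zip(text_emb, frames)]
    return torch.nn.utils.rnn.pad_sequence(
        [e.transpose(0, 1) for e in expanded], batch_first=True
    ).transpose(1, 2)  # (B, C, T_frames)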