Reconstructed PyTorch code by Gemini 3 Pro

#2
by hr16 - opened

Getting the ONNX graph in text

import onnx
import os

with open("model_graph.txt", "w") as f:
    for model_name in ["vocoder.onnx", "vector_estimator.onnx", "text_encoder.onnx", "duration_predictor.onnx"]:
        model = onnx.load(os.path.join(r"C:\Users\hi\Downloads", model_name))
        f.write(model_name + '\n')
        f.write(onnx.helper.printable_graph(model.graph))
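        # Optional extra, not in the original snippet: also dump the declared
        # input/output tensors so their shapes are explicit in the text dump.
        for io_kind, io_list in (("inputs", model.graph.input), ("outputs", model.graph.output)):
            f.write("\n" + io_kind + ":\n")
            for t in io_list:
                dims = [d.dim_param or d.dim_value for d in t.type.tensor_type.shape.dim]
                f.write("  {}: {}\n".format(t.name, dims))
        f.write("\n\n")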

Result from Gemini 3 Pro

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ==============================================================================
# Shared Modules
# ==============================================================================

class LayerNormChannelFirst(nn.Module):
    """
    LayerNorm for (B, C, L) input.
    """
    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        return self.norm(x.transpose(1, 2)).transpose(1, 2)

class ConvNeXtBlock(nn.Module):
    """
    1D ConvNeXt Block adapted for TTS.
    Graph Trace: Pad -> DWConv -> Norm -> PWConv1 -> GELU -> PWConv2 -> Scale -> Add
    """
    def __init__(self, dim, intermediate_dim, kernel_size=7, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation // 2
        
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding, 
                                groups=dim, dilation=dilation)
        self.norm = LayerNormChannelFirst(dim)
        self.pwconv1 = nn.Conv1d(dim, intermediate_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv1d(intermediate_dim, dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.ones(1, dim, 1) * 1e-6)

    def forward(self, x):
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        return residual + (self.gamma * x)

class RelativeAttention(nn.Module):
    """
    Self Attention with Relative Positional Embeddings.
    Can optionally modulate Key/Value with Style (for Vector Estimator).
    """
    def __init__(self, dim, heads=8, window_size=4, style_dim=None):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        
        self.conv_q = nn.Conv1d(dim, dim, 1)
        self.conv_k = nn.Conv1d(dim, dim, 1)
        self.conv_v = nn.Conv1d(dim, dim, 1)
        self.conv_o = nn.Conv1d(dim, dim, 1)
        
        # Relative position embeddings (k and v)
        self.emb_rel_k = nn.Parameter(torch.randn(1, window_size * 2 + 1, self.head_dim))
        self.emb_rel_v = nn.Parameter(torch.randn(1, window_size * 2 + 1, self.head_dim))
        self.window_size = window_size

        # Style modulation projections (if used)
        if style_dim is not None:
            self.style_k = nn.Linear(style_dim, dim)
            self.style_v = nn.Linear(style_dim, dim)
        else:
            self.style_k = None
            self.style_v = None

    def forward(self, x, mask=None, style_vec=None):
        B, C, T = x.shape
        
        q = self.conv_q(x).view(B, self.heads, self.head_dim, T).transpose(2, 3) # (B, H, T, D)
        k = self.conv_k(x)
        v = self.conv_v(x)
        
        # Style injection (Vector Estimator logic):
        # the style vector is projected and added to the K/V conv outputs before
        # they are reshaped into heads.
        if style_vec is not None and self.style_k is not None:
            # style_vec: (B, Style_Dim)
            s_k = self.style_k(style_vec).unsqueeze(-1) # (B, Dim, 1)
            s_v = self.style_v(style_vec).unsqueeze(-1)
            k = k + s_k
            v = v + s_v
            
        k = k.view(B, self.heads, self.head_dim, T).transpose(2, 3)
        v = v.view(B, self.heads, self.head_dim, T).transpose(2, 3)
        
        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # Relative position bias (omitted in this reconstruction).
        # A full implementation would gather emb_rel_k / emb_rel_v by relative offset
        # (clamped to +/- window_size, as in VITS-style attention) and add the
        # resulting logits to the scores, i.e. scores = scores + relative_bias.
        
        if mask is not None:
            # mask: (B, 1, T)
            mask_expanded = mask.unsqueeze(1) # (B, 1, 1, T)
            scores = scores.masked_fill(mask_expanded == 0, -1e4)
            
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        
        out = out.transpose(2, 3).reshape(B, C, T)
        return self.conv_o(out)

class SpeechPromptedAttention(nn.Module):
    """
    Cross Attention for Text Encoder.
    Query = Text, Key/Value = Style (Speech).
    """
    def __init__(self, dim, style_dim=256, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        
        self.norm = LayerNormChannelFirst(dim)
        self.q_proj = nn.Conv1d(dim, dim, 1)
        self.k_proj = nn.Linear(style_dim, dim)
        self.v_proj = nn.Linear(style_dim, dim)
        self.out_proj = nn.Conv1d(dim, dim, 1)

    def forward(self, x, style, mask=None):
        # x: (B, C, T_text)
        # style: (B, T_style, Style_Dim) 
        
        residual = x
        x = self.norm(x)
        B, C, T = x.shape
        
        q = self.q_proj(x).view(B, self.heads, self.head_dim, T).transpose(2, 3) # (B, H, T, D)
        
        # K, V from style
        # (B, S, Style_Dim) -> (B, S, Dim) -> (B, H, S, D)
        k = self.k_proj(style).view(B, -1, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(style).view(B, -1, self.heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # We usually don't mask cross-attention over style unless style has padding
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v) # (B, H, T, D)
        
        out = out.transpose(2, 3).reshape(B, C, T)
        out = self.out_proj(out)
        
        if mask is not None:
            out = out * mask
            
        return residual + out

# ==============================================================================
# 1. Vocoder (vocoder.onnx)
# ==============================================================================

class Vocoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('normalizer_scale', torch.tensor(1.0))
        self.latent_mean = nn.Parameter(torch.zeros(1, 24, 1))
        self.latent_std = nn.Parameter(torch.ones(1, 24, 1))
        
        self.embed = nn.Conv1d(24, 512, 7, padding=3)
        
        # 10 Blocks with specific dilation pattern
        # Dilation pattern from graph: [1, 2, 4, 1, 2, 4, 1, 2, 4, 1]
        dilations = [1, 2, 4, 1, 2, 4, 1, 2, 4, 1]
        self.blocks = nn.ModuleList([
            ConvNeXtBlock(512, 2048, kernel_size=7, dilation=d)
            for d in dilations
        ])
        
        self.final_norm = nn.BatchNorm1d(512)
        self.head_layer1 = nn.Conv1d(512, 2048, 3, padding=1)
        self.head_act = nn.PReLU()
        self.head_layer2 = nn.Conv1d(2048, 1, 1)

    def forward(self, latent):
        # Denormalize
        x = latent / self.normalizer_scale
        x = x * self.latent_std + self.latent_mean
        
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
            
        x = self.final_norm(x)
        x = self.head_layer1(x)
        x = self.head_act(x)
        x = self.head_layer2(x)
        return x
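
# Hedged usage sketch, not part of the reconstructed graph: a quick shape check.
# The 24-channel latent and the 100-frame length are arbitrary; since every conv
# above is stride-1, the output waveform has the same length as the input latent.
def _demo_vocoder():
    voc = Vocoder().eval()
    latent = torch.randn(1, 24, 100)   # (B, latent_channels, T)
    with torch.no_grad():
        wav = voc(latent)
    print(wav.shape)                   # torch.Size([1, 1, 100])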

# ==============================================================================
# 2. Vector Estimator (vector_estimator.onnx)
# ==============================================================================

class TimePositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.Mish(), # Graph: Softplus -> Tanh pattern, i.e. Mish(x) = x * tanh(softplus(x))
            nn.Linear(dim * 4, 512)
        )

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.mlp(emb)
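
# Hedged check, not from the graph: the sinusoidal time embedding maps a (B,) step
# tensor to a (B, 512) conditioning vector; dim=64 matches how VectorEstimator below
# instantiates it.
def _demo_time_embedding():
    t_emb = TimePositionalEmbedding(64)(torch.tensor([0.5]))
    print(t_emb.shape)                 # torch.Size([1, 512])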

class VectorEstimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Conv1d(64, 512, 1) # Latent dim 64 inferred from proj_out
        self.time_encoder = TimePositionalEmbedding(64)
        
        # Main Backbone
        # The ONNX graph iterates blocks 0-23.
        # Blocks 1, 7, 13, 19 are Time Projections (Linear layers).
        # Blocks 0, 2, 4, 6, 8, 10... are ConvNeXt or Attention.
        
        self.layers = nn.ModuleList()
        self.time_projs = nn.ModuleList()
        
        # Constructing the exact block sequence seen in the graph
        
        # Block 0: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 1: Time Injection 0
        self.time_projs.append(nn.Linear(512, 512))
        
        # Block 2: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 3: Attention (Style Conditioned)
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 4: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 5: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 6: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 7: Time Injection 1
        self.time_projs.append(nn.Linear(512, 512))
        
        # Block 8: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 9: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 10: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 11: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 12: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 13: Time Injection 2
        self.time_projs.append(nn.Linear(512, 512))
        
        # Block 14: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 15: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 16: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 17: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 18: ConvNeXt x4
        self.layers.append(nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)]))
        # Block 19: Time Injection 3
        self.time_projs.append(nn.Linear(512, 512))
        
        # Block 20: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 21: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Block 22: ConvNeXt
        self.layers.append(ConvNeXtBlock(512, 1024, 5))
        # Block 23: Attention
        self.layers.append(RelativeAttention(512, style_dim=256))
        
        # Final block
        self.last_convnext = nn.Sequential(*[ConvNeXtBlock(512, 1024, 5) for _ in range(4)])
        self.proj_out = nn.Conv1d(512, 64, 1) # latent dim output

    def forward(self, noisy_latent, text_emb, style_ttl, latent_mask, text_mask, current_step, total_step):
        # Note: text_emb and text_mask are graph inputs, but in this trace they only
        # appear (via Transpose) inside the attention masking logic; style_ttl carries
        # the main conditioning. They are kept in the signature for parity with ONNX.
        
        x = self.proj_in(noisy_latent) * latent_mask
        t_emb = self.time_encoder(current_step)
        
        # Process layers.
        # The ONNX graph is fully unrolled; the block-by-block execution below mirrors it.
        # Logic map (ONNX block numbering):
        #   Group 1: Block 0 (Conv x4) -> Block 1 (Time 0)
        #   Group 2: Blocks 2-5 (Conv -> Attn -> Conv -> Attn)
        #   Group 3: Block 6 (Conv x4) -> Block 7 (Time 1)
        #   ... and so on through Block 23.
        # Simplified execution flow:
        # 1. Conv x4
        x = self.layers[0](x) * latent_mask
        # 2. Time Inj 0
        x = x + self.time_projs[0](t_emb).unsqueeze(-1)
        
        # 3. Conv -> Attn -> Conv -> Attn
        x = self.layers[1](x) * latent_mask
        x = self.layers[2](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[3](x) * latent_mask
        x = self.layers[4](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        
        # 4. Conv x4
        x = self.layers[5](x) * latent_mask
        # 5. Time Inj 1
        x = x + self.time_projs[1](t_emb).unsqueeze(-1)
        
        # 6. Conv -> Attn -> Conv -> Attn
        x = self.layers[6](x) * latent_mask
        x = self.layers[7](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[8](x) * latent_mask
        x = self.layers[9](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        
        # 7. Conv x4
        x = self.layers[10](x) * latent_mask
        # 8. Time Inj 2
        x = x + self.time_projs[2](t_emb).unsqueeze(-1)
        
        # 9. Conv -> Attn -> Conv -> Attn
        x = self.layers[11](x) * latent_mask
        x = self.layers[12](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[13](x) * latent_mask
        x = self.layers[14](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        
        # 10. Conv x4
        x = self.layers[15](x) * latent_mask
        # 11. Time Inj 3
        x = x + self.time_projs[3](t_emb).unsqueeze(-1)
        
        # 12. Conv -> Attn -> Conv -> Attn
        x = self.layers[16](x) * latent_mask
        x = self.layers[17](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        x = self.layers[18](x) * latent_mask
        x = self.layers[19](x, mask=latent_mask, style_vec=style_ttl) * latent_mask
        
        # Final
        x = self.last_convnext(x) * latent_mask
        v_pred = self.proj_out(x) * latent_mask
        
        # Graph returns denoised prediction via Euler step logic:
        # Reciprocal(total_step) * v_pred + noisy_latent
        # We return the vector v for flexibility
        return v_pred
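
# Hedged sampling sketch, not part of the reconstructed graph: the ONNX model folds
# the Euler update (noisy_latent + v_pred / total_step) into its output, while this
# class returns the raw vector field, so an equivalent outer loop would look roughly
# like this. All shapes, the step count, and the style/text tensors are placeholders.
def _demo_vector_estimator_sampling(total_steps=16):
    est = VectorEstimator().eval()
    B, T_lat, T_txt = 1, 100, 20
    x = torch.randn(B, 64, T_lat)                    # start from noise
    text_emb = torch.randn(B, 512, T_txt)            # unused by this forward(), kept for parity
    style_ttl = torch.randn(B, 256)
    latent_mask = torch.ones(B, 1, T_lat)
    text_mask = torch.ones(B, 1, T_txt)
    with torch.no_grad():
        for i in range(total_steps):
            step = torch.full((B,), float(i))
            v = est(x, text_emb, style_ttl, latent_mask, text_mask,
                    step, torch.tensor(float(total_steps)))
            x = x + v / total_steps                  # Euler step from the graph's output logic
    return x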

# ==============================================================================
# 3. Text Encoder (text_encoder.onnx)
# ==============================================================================

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_embedder = nn.Embedding(163, 256)
        
        # 6 ConvNeXt Blocks
        self.convnext = nn.ModuleList([
            ConvNeXtBlock(256, 1024, 5) for _ in range(6)
        ])
        
        # 4 Attention Encoder Blocks (Self Attention + FFN)
        self.attn_encoder_layers = nn.ModuleList([
            RelativeAttention(256) for _ in range(4)
        ])
        self.attn_ffn_layers = nn.ModuleList([
            nn.Sequential(
                LayerNormChannelFirst(256),
                nn.Conv1d(256, 1024, 1),
                nn.ReLU(), # Graph shows Relu
                nn.Conv1d(1024, 256, 1)
            ) for _ in range(4)
        ])
        
        # 2 Speech Prompted Blocks (Cross Attention)
        self.speech_prompted_attn = nn.ModuleList([
            SpeechPromptedAttention(256, style_dim=256) for _ in range(2)
        ])
        
        # The graph shows output-layer logic here; whether this is a real 256->256
        # projection or effectively an identity is unclear. Defined for completeness
        # but not applied in forward() below.
        self.proj_out = nn.Conv1d(256, 256, 1)
        
        self.norm = LayerNormChannelFirst(256)

    def forward(self, text_ids, style_ttl, text_mask):
        x = self.text_embedder(text_ids).transpose(1, 2) * text_mask
        
        # ConvNeXt Stack
        for block in self.convnext:
            x = block(x) * text_mask
            
        # Attention Encoder Stack
        for attn, ffn in zip(self.attn_encoder_layers, self.attn_ffn_layers):
            # Attention with Residual
            res = x
            x = attn(x, mask=text_mask)
            # Add & Norm (handled inside blocks usually, but graph shows Add outside)
            # Graph: Add -> Norm -> FFN -> Add -> Norm
            # My Attention block does ConvQKV -> Attn -> ConvO. Norm is external in graph trace.
            
            # The exact trace: Input -> Conv(q,k,v) ... -> Add(Input, Out) -> Norm -> FFN
            x = (res + x) 
            x = self.norm(x) # Simplified norm placement
            
            # FFN
            res = x
            x = ffn(x) * text_mask
            x = res + x
            x = self.norm(x)
            
        # Speech Prompted (Cross) Attention
        for block in self.speech_prompted_attn:
            x = block(x, style=style_ttl, mask=text_mask)
            
        x = self.norm(x) * text_mask
        return x
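
# Hedged usage sketch, not part of the reconstructed graph: shape check for the text
# encoder. The (B, T_style, 256) style prompt is a dummy; the real prompt comes from
# the speech-prompt pipeline, which is not covered by this graph dump.
def _demo_text_encoder():
    enc = TextEncoder().eval()
    B, T, T_style = 1, 20, 50
    text_ids = torch.randint(0, 163, (B, T))
    style_ttl = torch.randn(B, T_style, 256)
    text_mask = torch.ones(B, 1, T)
    with torch.no_grad():
        out = enc(text_ids, style_ttl, text_mask)
    print(out.shape)                   # torch.Size([1, 256, 20])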

# ==============================================================================
# 4. Duration Predictor (duration_predictor.onnx)
# ==============================================================================

class DurationPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(163, 64)
        
        # Global sentence token
        self.sentence_token = nn.Parameter(torch.randn(1, 64, 1))
        
        # Sentence Encoder
        # 2 Attention Layers
        self.attn_layers = nn.ModuleList([
            RelativeAttention(64, heads=8, window_size=4) for _ in range(2)
        ])
        self.attn_ffn = nn.ModuleList([
            nn.Sequential(
                LayerNormChannelFirst(64),
                nn.Conv1d(64, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, 64, 1)
            ) for _ in range(2)
        ])
        
        # 6 ConvNeXt Blocks
        self.convnext_stack = nn.ModuleList([
            ConvNeXtBlock(64, 256, kernel_size=5) for _ in range(6)
        ])
        
        self.proj_out = nn.Conv1d(64, 64, 1)
        
        # Predictor MLP
        # Input 64 (text) + 128 (style) -> 128 -> 1
        self.mlp = nn.Sequential(
            nn.Linear(192, 128),
            nn.PReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, text_ids, style_dp, text_mask):
        # 1. Embedding
        x = self.embedding(text_ids).transpose(1, 2) * text_mask
        
        # 2. Prepend sentence token
        # Graph: concatenates a learnable token at the start of the sequence
        B = x.shape[0]
        token = self.sentence_token.expand(B, -1, -1)
        x = torch.cat([token, x], dim=2)
        
        # Adjust mask for token
        mask_pad = F.pad(text_mask, (1, 0), value=1.0)
        
        # 3. Attention Encoder
        for attn, ffn in zip(self.attn_layers, self.attn_ffn):
            res = x
            x = attn(x, mask=mask_pad)
            x = res + x
            # A norm is implied here (mirroring the text encoder) but omitted in this
            # simplified reconstruction.
            
            res = x
            x = ffn(x) * mask_pad
            x = res + x
            
        # 4. ConvNeXt Stack
        for block in self.convnext_stack:
            x = block(x) * mask_pad
            
        # 5. Remove token and project
        x = x[:, :, 1:] # Slice off sentence token
        x = self.proj_out(x) * text_mask
        
        # 6. Predictor Head
        # Move channels last: (B, C, T) -> (B, T, C)
        x = x.transpose(1, 2)
        
        # Expand style: the ONNX input is declared as style_dp[FLOAT, batch_size x 8 x 16],
        # i.e. 128 values per item; flatten to (B, 128) and broadcast to (B, T, 128).
        style = style_dp.reshape(B, 128).unsqueeze(1).expand(-1, x.shape[1], -1)
        
        # Concat
        combined = torch.cat([x, style], dim=-1) # 64 + 128 = 192
        
        log_dur = self.mlp(combined)
        duration = torch.exp(log_dur).squeeze(-1)
        
        return duration * text_mask.squeeze(1)
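
# Hedged usage sketch, not part of the reconstructed graph: shape check for the
# duration predictor. Token ids, mask, and the (B, 8, 16) style_dp tensor are dummies;
# the vocabulary size 163 matches the embedding table above.
def _demo_duration_predictor():
    dp = DurationPredictor().eval()
    B, T = 1, 12
    text_ids = torch.randint(0, 163, (B, T))
    style_dp = torch.randn(B, 8, 16)   # matches the batch_size x 8 x 16 ONNX input
    text_mask = torch.ones(B, 1, T)
    with torch.no_grad():
        durations = dp(text_ids, style_dp, text_mask)
    print(durations.shape)             # torch.Size([1, 12]), exp of predicted log-durations per token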
