notmax123 committed on
Commit
cd6425e
·
verified ·
1 Parent(s): 9650d7d

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
backbone.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92d02eac3c6f9b2c4b347e87d18c825b5a5e44158c341ce62714f20324cc74b5
3
+ size 132644653
backbone_keys.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b465da29ba2cff4cfcb7c8ae7b70420bb88fd9d9d7a306e3a81cfee303297550
3
+ size 132592456
length_pred.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b31959fb99a04a7b907d63ed8edf5e202e484064e41ad3ba8723c7bf9fc04a8c
3
+ size 2055214
length_pred_style.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ab708ab77ea16e2d7bde0f906fb13b46feeba12802859bec1208cec8eed3ee0
3
+ size 1418679
obfuscate_onnx.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Secure ONNX Export Pipeline for Light-BlueTTS
4
+ ==============================================
5
+
6
+ Combines two obfuscation techniques:
7
+ 1. Self-Contained Output Scrambling - permanently scrambles layer weights
8
+ (Conv1d, Linear, Embedding) and injects inverse Gather nodes in the
9
+ computation graph. The model still produces identical outputs, but
10
+ stored weights are "poisoned". Uses Dynamic Zero anti-optimizer trick
11
+ to prevent ONNX Runtime constant folding.
12
+ 2. ONNX Name Obfuscation - randomizes all internal node/tensor/weight names
13
+ so the graph is unreadable in tools like Netron.
14
+
15
+ Usage:
16
+ python obfuscate_onnx.py --config hebrew/tts.json \\
17
+ --ttl_ckpt ckpt_step_580000.pt \\
18
+ --ae_ckpt ae_latest_newer.pt \\
19
+ --dp_ckpt duration_predictor_final.pt \\
20
+ --onnx_dir onnx_obfuscated
21
+ """
22
+
23
+ import os
24
+ import argparse
25
+ import glob as glob_mod
26
+ import random
27
+ import string
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import onnx
33
+ from onnx import numpy_helper as onnx_numpy_helper
34
+
35
+ # Model imports
36
+ from models.text2latent.text_encoder import TextEncoder
37
+ from models.text2latent.vf_estimator import VectorFieldEstimator
38
+ from models.autoencoder.latent_decoder import LatentDecoder1D
39
+ from models.text2latent.dp_network import DPNetwork
40
+ from models.text2latent.reference_encoder import ReferenceEncoder
41
+ from models.utils import load_ttl_config
42
+
43
+ # Reuse ONNX-safe MHA replacement and wrappers from export pipeline
44
+ from export_onnx import (
45
+ _replace_mha_with_safe,
46
+ VectorFieldEstimatorWrapper,
47
+ VectorFieldEstimatorKeysWrapper,
48
+ export_one,
49
+ )
50
+
51
+
52
+ # =====================================================================
53
+ # Part 1: Self-Contained Output Scrambling (Weight Poisoning)
54
+ # =====================================================================
55
+
56
class SelfScrambledConv1d(nn.Module):
    """
    Wrapper that stores a Conv1d (or a subclass such as CausalConv1d) with
    its output channels shuffled.

    Construction draws a random permutation and applies it, in place, to the
    wrapped layer's weight (and bias, when present) along the output-channel
    axis. ``forward`` runs the shuffled convolution and then gathers the
    channels back through the inverse permutation, so the caller receives
    them in the original order.

    "Dynamic Zero" anti-optimizer trick: the inverse-permutation indices are
    offset by ``first_activation * 0.0`` cast to int64. The offset is 0 for
    any finite activation, but it ties the gather indices to runtime data,
    which is intended to stop ONNX Runtime's constant folding from
    pre-computing the Gather.
    """

    def __init__(self, original_conv):
        super().__init__()
        self.conv = original_conv

        shuffle = torch.randperm(self.conv.out_channels)
        # The argsort of a permutation is its inverse permutation.
        unshuffle = torch.argsort(shuffle)
        self.register_buffer('inv_shuffle_indices', unshuffle)

        # Permanently permute the stored parameters along dim 0 (out channels).
        with torch.no_grad():
            self.conv.weight.data = self.conv.weight.data[shuffle, :, :]
            bias = self.conv.bias
            if bias is not None:
                bias.data = bias.data[shuffle]

    def forward(self, x):
        scrambled = self.conv(x)
        # Runtime-derived zero: keeps the gather indices data-dependent.
        runtime_zero = (scrambled.reshape(-1)[0] * 0.0).long()
        gather_idx = self.inv_shuffle_indices + runtime_zero
        return scrambled[:, gather_idx, :]
92
+
93
+
94
class SelfScrambledLinear(nn.Module):
    """
    nn.Linear counterpart of the scrambled-convolution wrapper.

    The wrapped layer's weight rows (and bias entries, when present) are
    permuted in place along the output-feature axis at construction time.
    ``forward`` applies the shuffled layer and immediately gathers the last
    dimension back through the inverse permutation, using the same
    "Dynamic Zero" offset (first activation * 0.0, cast to int64) to keep
    the gather indices dependent on runtime data.
    """

    def __init__(self, original_linear):
        super().__init__()
        self.linear = original_linear

        shuffle = torch.randperm(self.linear.out_features)
        # The argsort of a permutation is its inverse permutation.
        unshuffle = torch.argsort(shuffle)
        self.register_buffer('inv_shuffle_indices', unshuffle)

        # Permanently permute the stored parameters along dim 0 (out features).
        with torch.no_grad():
            self.linear.weight.data = self.linear.weight.data[shuffle, :]
            bias = self.linear.bias
            if bias is not None:
                bias.data = bias.data[shuffle]

    def forward(self, x):
        scrambled = self.linear(x)
        runtime_zero = (scrambled.reshape(-1)[0] * 0.0).long()
        gather_idx = self.inv_shuffle_indices + runtime_zero
        return scrambled[..., gather_idx]
122
+
123
+
124
class SelfScrambledEmbedding(nn.Module):
    """
    Wrapper that stores an nn.Embedding with its table rows shuffled.

    Construction permutes the rows of the embedding table in place (the
    num_embeddings axis). ``forward`` translates every incoming index
    through the inverse permutation before the lookup, so each token still
    retrieves its original vector.

    "Dynamic Zero" anti-optimizer trick: the remap table is offset by
    ``first_index * 0.0`` (computed in float, cast back to int64). The
    offset is always 0, but it ties the remap to the runtime input, which is
    intended to stop ONNX Runtime's constant folding from pre-computing the
    Gather remap.
    """

    def __init__(self, original_embedding):
        super().__init__()
        self.embedding = original_embedding

        shuffle = torch.randperm(self.embedding.num_embeddings)
        # The argsort of a permutation is its inverse permutation.
        unshuffle = torch.argsort(shuffle)
        self.register_buffer('inv_shuffle_indices', unshuffle)

        # Permanently permute the table rows.
        with torch.no_grad():
            self.embedding.weight.data = self.embedding.weight.data[shuffle, :]

    def forward(self, x):
        # Integer indices are cast to float only to derive the zero offset.
        runtime_zero = (x.float().reshape(-1)[0] * 0.0).long()
        remap_table = self.inv_shuffle_indices + runtime_zero
        return self.embedding(remap_table[x])
156
+
157
+
158
def scramble_all_layers(module, scramble_linear=False, scramble_embedding=False, _depth=0,
                        _seen=None):
    """
    Recursively replace layers with their SelfScrambled equivalents.

    Handles:
    - nn.Conv1d (groups==1, out_channels>1) -> SelfScrambledConv1d
    - nn.Linear (out_features>1, if scramble_linear=True) -> SelfScrambledLinear
    - nn.Embedding (num_embeddings>1, if scramble_embedding=True) -> SelfScrambledEmbedding

    Grouped/depthwise convolutions (groups>1) are not wrapped. Any child that
    is not wrapped is treated as a container and recursed into.

    Shared submodules: wrapping permutes a layer's parameters in place, so a
    layer instance that is registered under several parents must be replaced
    by the *same* wrapper everywhere. Otherwise the later parents would keep
    calling the raw, already-permuted layer without the inverse gather. The
    ``_seen`` memo (keyed by object identity) guarantees this and also keeps
    a shared container from being traversed, and its layers wrapped, twice.

    Args:
        module: root nn.Module; its children are replaced in place.
        scramble_linear: also wrap nn.Linear layers.
        scramble_embedding: also wrap nn.Embedding layers.
        _depth: recursion depth (internal bookkeeping, not otherwise used).
        _seen: internal memo mapping id(original child) to its replacement;
            callers should leave it as None.

    Returns:
        Number of layers newly wrapped by this call. A shared layer is
        counted once.
    """
    if _seen is None:
        _seen = {}

    count = 0
    for name, child in list(module.named_children()):
        key = id(child)
        if key in _seen:
            # Already handled through another parent: reuse the replacement.
            # For containers this is the container itself, i.e. a no-op.
            setattr(module, name, _seen[key])
            continue

        # Conv1d (including subclasses like CausalConv1d)
        if isinstance(child, nn.Conv1d) and child.groups == 1 and child.out_channels > 1:
            replacement = SelfScrambledConv1d(child)
        elif scramble_linear and isinstance(child, nn.Linear) and child.out_features > 1:
            replacement = SelfScrambledLinear(child)
        elif scramble_embedding and isinstance(child, nn.Embedding) and child.num_embeddings > 1:
            replacement = SelfScrambledEmbedding(child)
        else:
            # Recurse into containers (Sequential, ModuleList, custom blocks, etc.)
            _seen[key] = child
            count += scramble_all_layers(child, scramble_linear=scramble_linear,
                                         scramble_embedding=scramble_embedding,
                                         _depth=_depth + 1, _seen=_seen)
            continue

        # The memo holds the wrapper, which holds the child, so the id used
        # as the key stays valid for the whole traversal.
        _seen[key] = replacement
        setattr(module, name, replacement)
        count += 1
    return count
187
+
188
+
189
+ # =====================================================================
190
+ # Part 2: ONNX Name Obfuscation
191
+ # =====================================================================
192
+
193
+ def _random_name(length=12):
194
+ """Generate a random alphanumeric string."""
195
+ return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
196
+
197
+
198
def obfuscate_onnx_names(input_path, output_path=None, keep_io_names=True):
    """
    Load an ONNX model and replace all internal names recursively (including
    inside If/Loop subgraphs) with random strings.

    Renamed entities: value_info entries, initializers, graph inputs/outputs,
    node names, and node input/output tensor names. Every occurrence of one
    original name maps to the same random name, so graph connectivity is
    preserved. Empty names (omitted optional inputs) are left empty.

    Generated names are guaranteed to be distinct from each other and from
    the preserved I/O names; an unchecked collision would silently merge two
    different tensors.

    Args:
        input_path: path of the ONNX file to read.
        output_path: where to write the result; defaults to ``input_path``
            (overwrite in place).
        keep_io_names: when True, the top-level graph's input and output
            names are kept so inference code can keep addressing them.

    Returns:
        Number of distinct names that were replaced.
    """
    if output_path is None:
        output_path = input_path

    model = onnx.load(input_path)

    name_map = {}
    preserved = set()

    # 1. Preserve main I/O names so inference doesn't break
    if keep_io_names:
        for inp in model.graph.input:
            preserved.add(inp.name)
        for out in model.graph.output:
            preserved.add(out.name)

    # Names that must not be handed out again: preserved I/O names plus every
    # random name generated so far.
    taken = set(preserved)

    def remap(old_name):
        if old_name == "":
            return ""
        if old_name in preserved:
            return old_name
        new_name = name_map.get(old_name)
        if new_name is None:
            new_name = _random_name()
            while new_name in taken:
                new_name = _random_name()
            taken.add(new_name)
            name_map[old_name] = new_name
        return new_name

    def process_graph(g):
        """Recursively process a graph and all its subgraphs."""
        # Value info
        for vi in g.value_info:
            vi.name = remap(vi.name)

        # Initializers
        for init in g.initializer:
            init.name = remap(init.name)

        # Graph inputs and outputs (handles both main graph and subgraphs safely)
        for inp in g.input:
            inp.name = remap(inp.name)
        for out in g.output:
            out.name = remap(out.name)

        # Nodes
        for node in g.node:
            if node.name:
                node.name = remap(node.name)
            for i, n in enumerate(node.input):
                node.input[i] = remap(n)
            for i, n in enumerate(node.output):
                node.output[i] = remap(n)

            # Recurse into subgraphs carried as node attributes (If, Loop, etc.);
            # they share the same name map, so references to outer-scope
            # tensors stay consistent.
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    process_graph(attr.g)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    for sub_g in attr.graphs:
                        process_graph(sub_g)

    # Run the recursive obfuscator starting at the top-level graph
    process_graph(model.graph)

    onnx.save(model, output_path)
    return len(name_map)
261
+
262
+
263
+ # =====================================================================
264
+ # Part 2b: Shuffle Key Extraction (Optimizer-Proof)
265
+ # =====================================================================
266
+
267
def extract_shuffle_keys(onnx_path):
    """
    Extract inv_shuffle_indices arrays from ONNX initializers and convert
    them to dynamic graph inputs.

    Every top-level initializer whose name contains "inv_shuffle_indices" is
    removed from the model and, unless a graph input of that name already
    exists, declared as a graph input with the initializer's own element
    type and shape. The file at ``onnx_path`` is overwritten in place; it is
    left untouched when no matching initializer is found.

    Because the indices are no longer embedded as Constants, ONNX Runtime's
    graph optimizer cannot constant-fold the Gather nodes away. The model
    file stays permanently scrambled even if an attacker runs:

        sess_options.graph_optimization_level = ORT_ENABLE_ALL
        sess_options.optimized_model_filepath = "cracked.onnx"

    The extracted arrays must be fed at inference time via keys.npz.

    Args:
        onnx_path: path of the ONNX file to rewrite.

    Returns:
        dict of {input_name: numpy_array} for the extracted keys (empty when
        nothing matched).
    """
    model = onnx.load(onnx_path)
    graph = model.graph

    extracted = {}
    elem_types = {}
    to_remove = []

    for init in graph.initializer:
        if "inv_shuffle_indices" in init.name:
            extracted[init.name] = onnx_numpy_helper.to_array(init)
            # Remember the stored element type so the new graph input is
            # declared with the same type the consuming nodes were built for,
            # rather than assuming INT64.
            elem_types[init.name] = init.data_type
            to_remove.append(init)

    if not to_remove:
        return extracted

    # Remove from initializers (no longer a constant in the file)
    for init in to_remove:
        graph.initializer.remove(init)

    # Ensure each extracted key is registered as a dynamic graph input
    existing_inputs = {inp.name for inp in graph.input}
    for name, arr in extracted.items():
        if name not in existing_inputs:
            inp = onnx.helper.make_tensor_value_info(
                name, elem_types[name], list(arr.shape)
            )
            graph.input.append(inp)

    onnx.save(model, onnx_path)
    return extracted
314
+
315
+
316
+ # =====================================================================
317
+ # Part 3: Full Obfuscation Pipeline
318
+ # =====================================================================
319
+
320
def get_latest_ckpt(dir_path):
    """
    Return the path of the most recent checkpoint in ``dir_path``.

    Files named ``ckpt_step_<N>.pt`` take priority: the one with the largest
    integer ``<N>`` wins (names whose step cannot be parsed rank as -1).
    When no such file exists, the ``*.pt`` file with the newest modification
    time is returned, or None if the directory holds no ``.pt`` file.
    """
    def _step_of(path):
        try:
            tail = os.path.basename(path).split("ckpt_step_")[-1]
            return int(tail.split(".pt")[0])
        except Exception:
            return -1

    step_ckpts = glob_mod.glob(os.path.join(dir_path, "ckpt_step_*.pt"))
    if step_ckpts:
        # Stable ascending sort; the last entry has the highest step.
        return sorted(step_ckpts, key=_step_of)[-1]

    any_ckpts = glob_mod.glob(os.path.join(dir_path, "*.pt"))
    if not any_ckpts:
        return None
    return max(any_ckpts, key=os.path.getmtime)
333
+
334
+
335
def main():
    """
    Command-line entry point: build every Light-BlueTTS sub-model, scramble
    its layers, export it to ONNX, then post-process the exported files.

    Steps, in order:
      1. Parse arguments and load the TTS config.
      2. Resolve and load the TTL / AutoEncoder / Duration-Predictor
         checkpoints (a missing file is reported and that model is exported
         with whatever weights it was constructed with).
      3. For each of: reference encoder, text encoder, vector field estimator
         (two variants), vocoder, duration predictor (two variants) -
         construct, load weights, scramble layers, export with ``export_one``.
      4. Save ``uncond.npz`` (unconditional tokens and conditional style keys).
      5. Optionally (``--extract-keys``) move shuffle indices out of the
         models into ``keys.npz``.
      6. Unless ``--no-name-obfuscation`` is given, randomize internal ONNX
         names.
      7. Print a summary.

    Returns None; a missing config file aborts early after printing an error.
    """
    parser = argparse.ArgumentParser(
        description="Obfuscate Light-BlueTTS models: weight scrambling + ONNX name randomization"
    )
    parser.add_argument("--config", type=str, default="hebrew/tts.json",
                        help="Path to tts.json config")
    parser.add_argument("--onnx_dir", type=str, default="onnx_obfuscated",
                        help="Output directory for obfuscated ONNX models")
    parser.add_argument("--ckpt_dir", type=str, default=None,
                        help="Text2Latent checkpoint dir (auto-finds latest ckpt_step_*.pt)")
    parser.add_argument("--ttl_ckpt", type=str, default=None,
                        help="Explicit TTL checkpoint file (overrides --ckpt_dir)")
    parser.add_argument("--ae_ckpt", type=str, default="ae_latest_newer.pt",
                        help="AutoEncoder checkpoint (.pt)")
    parser.add_argument("--dp_ckpt", type=str, default="duration_predictor_final.pt",
                        help="Duration Predictor checkpoint (.pt)")
    parser.add_argument("--scramble-linear", action="store_true",
                        help="Also scramble nn.Linear layers (extra obfuscation)")
    parser.add_argument("--scramble-embedding", action="store_true",
                        help="Also scramble nn.Embedding layers (token table obfuscation)")
    parser.add_argument("--scramble-all", action="store_true",
                        help="Scramble all supported layer types (Conv1d + Linear + Embedding)")
    parser.add_argument("--no-name-obfuscation", action="store_true",
                        help="Skip ONNX name randomization (only do weight scrambling)")
    parser.add_argument("--extract-keys", action="store_true",
                        help="Also extract shuffle keys to keys.npz (defense-in-depth, not needed with Dynamic Zero)")
    args = parser.parse_args()

    device = "cpu"
    do_name_obfuscation = not args.no_name_obfuscation
    do_key_extraction = args.extract_keys

    # Resolve scrambling flags (--scramble-all enables everything)
    scramble_linear = args.scramble_linear or args.scramble_all
    scramble_embedding = args.scramble_embedding or args.scramble_all

    # ---- Load Config ----
    if not os.path.exists(args.config):
        print(f"[ERROR] Config not found: {args.config}")
        return
    cfg = load_ttl_config(args.config)
    print(f"[INFO] Loaded config: {args.config}")

    os.makedirs(args.onnx_dir, exist_ok=True)

    # ---- Find Checkpoints ----
    # TTL checkpoint
    ttl_ckpt = args.ttl_ckpt
    if ttl_ckpt is None and args.ckpt_dir:
        ttl_ckpt = get_latest_ckpt(args.ckpt_dir)
    if ttl_ckpt is None:
        # Try to find .pt files matching ckpt_step_*.pt in current dir.
        # Pick by numeric step: a plain lexicographic sort would rank
        # "ckpt_step_99.pt" above "ckpt_step_580000.pt".
        if glob_mod.glob("ckpt_step_*.pt"):
            ttl_ckpt = get_latest_ckpt(os.curdir)

    ae_ckpt = args.ae_ckpt
    dp_ckpt = args.dp_ckpt

    print(f"[INFO] TTL checkpoint: {ttl_ckpt or '(none - random weights)'}")
    print(f"[INFO] AE checkpoint: {ae_ckpt}")
    print(f"[INFO] DP checkpoint: {dp_ckpt}")

    # ---- Load Checkpoint State Dicts ----
    def _load_or_warn(path, label, missing):
        """Load a checkpoint, or warn and return ``missing`` when the file is absent."""
        if path and os.path.exists(path):
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            return torch.load(path, map_location=device)
        if path:
            print(f"[WARN] {label} checkpoint not found: {path} - no weights will be loaded from it")
        return missing

    t2l_state = _load_or_warn(ttl_ckpt, "TTL", {})
    ae_state = _load_or_warn(ae_ckpt, "AE", {})
    dp_state_raw = _load_or_warn(dp_ckpt, "DP", None)

    # ---- Dimensions from config ----
    vocab_size = cfg["vocab_size"]
    compressed_channels = cfg["compressed_channels"]
    latent_dim = cfg["latent_dim"]
    chunk_compress_factor = cfg["chunk_compress_factor"]
    te_d_model = cfg["te_d_model"]
    se_d_model = cfg["se_d_model"]
    se_n_style = cfg["se_n_style"]

    total_scrambled = 0
    exported_files = []

    # ==============================================================
    # 1. Reference Encoder
    # ==============================================================
    print("\n[1/5] Reference Encoder")
    ref_enc = ReferenceEncoder(
        in_channels=compressed_channels,
        d_model=se_d_model,
        hidden_dim=cfg["se_hidden_dim"],
        num_blocks=cfg["se_num_blocks"],
        num_tokens=se_n_style,
        num_heads=cfg["se_n_heads"],
    ).to(device).eval()
    if "reference_encoder" in t2l_state:
        ref_enc.load_state_dict(t2l_state["reference_encoder"], strict=True)
    _replace_mha_with_safe(ref_enc)

    n = scramble_all_layers(ref_enc, scramble_linear=scramble_linear, scramble_embedding=scramble_embedding)
    total_scrambled += n
    print(f" Scrambled {n} layers")

    # Dummy tensor sizes used only for tracing the exports.
    B, C_lat, T_audio_ref, T_text, T_lat = 1, compressed_channels, 256, 32, 100
    z_ref = torch.randn(B, C_lat, T_audio_ref, device=device)
    ref_mask = torch.ones(B, 1, T_audio_ref, device=device)

    ref_path = os.path.join(args.onnx_dir, "reference_encoder.onnx")
    export_one(ref_enc, ref_path, (z_ref, ref_mask),
               input_names=["z_ref", "mask"],
               output_names=["ref_values", "ref_keys"],
               dynamic_axes={"z_ref": {2: "T_ref_in"}, "mask": {2: "T_ref_in"}})
    exported_files.append(ref_path)

    # ==============================================================
    # 2. Text Encoder
    # ==============================================================
    print("\n[2/5] Text Encoder")
    text_enc = TextEncoder(
        vocab_size=vocab_size,
        d_model=te_d_model,
        n_conv_layers=cfg["te_convnext_layers"],
        n_attn_layers=cfg["te_attn_n_layers"],
        expansion_factor=cfg["te_expansion_factor"],
        p_dropout=cfg["te_attn_p_dropout"],
    ).to(device).eval()
    if "text_encoder" in t2l_state:
        text_enc.load_state_dict(t2l_state["text_encoder"], strict=True)

    n = scramble_all_layers(text_enc, scramble_linear=scramble_linear, scramble_embedding=scramble_embedding)
    total_scrambled += n
    print(f" Scrambled {n} layers")

    text_ids = torch.zeros(B, T_text, dtype=torch.long, device=device)
    text_mask = torch.ones(B, 1, T_text, device=device)
    style_ttl = torch.randn(B, se_n_style, se_d_model, device=device)

    te_path = os.path.join(args.onnx_dir, "text_encoder.onnx")
    export_one(text_enc, te_path, (text_ids, style_ttl, text_mask),
               input_names=["text_ids", "style_ttl", "text_mask"],
               output_names=["text_emb"],
               dynamic_axes={
                   "text_ids": {1: "T_text"}, "style_ttl": {1: "T_ref"},
                   "text_mask": {2: "T_text"}, "text_emb": {2: "T_text"},
               })
    exported_files.append(te_path)

    # ==============================================================
    # 3. Vector Field Estimator (two variants)
    # ==============================================================
    print("\n[3/5] Vector Field Estimator")
    vf = VectorFieldEstimator(
        in_channels=compressed_channels,
        out_channels=compressed_channels,
        hidden_channels=cfg["vf_hidden"],
        text_dim=cfg["vf_text_dim"],
        style_dim=cfg["vf_style_dim"],
        num_style_tokens=se_n_style,
        num_superblocks=cfg["vf_n_blocks"],
        time_embed_dim=cfg["vf_time_dim"],
        rope_gamma=cfg["vf_rotary_scale"],
    ).to(device).eval()
    if "vf_estimator" in t2l_state:
        vf.load_state_dict(t2l_state["vf_estimator"], strict=False)

    # Sync baked-in style key with text encoder
    with torch.no_grad():
        vf.style_key.copy_(text_enc.ref_keys)

    n = scramble_all_layers(vf, scramble_linear=scramble_linear, scramble_embedding=scramble_embedding)
    total_scrambled += n
    print(f" Scrambled {n} layers")

    noisy_latent = torch.randn(B, C_lat, T_lat, device=device)
    latent_mask = torch.ones(B, 1, T_lat, device=device)
    text_emb = torch.randn(B, se_d_model, T_text, device=device)
    current_step = torch.tensor([0.0], device=device)
    total_step = torch.tensor([1.0], device=device)

    # Variant A: no style_keys input
    vf_wrapped = VectorFieldEstimatorWrapper(vf)
    vf_path = os.path.join(args.onnx_dir, "vector_estimator.onnx")
    vf_inputs = (noisy_latent, text_emb, style_ttl, latent_mask, text_mask, current_step, total_step)
    vf_names = ["noisy_latent", "text_emb", "style_ttl", "latent_mask", "text_mask", "current_step", "total_step"]
    export_one(vf_wrapped, vf_path, vf_inputs,
               input_names=vf_names, output_names=["denoised_latent"],
               dynamic_axes={
                   "noisy_latent": {2: "T_lat"}, "text_emb": {2: "T_text"},
                   "style_ttl": {1: "T_ref"}, "latent_mask": {2: "T_lat"},
                   "text_mask": {2: "T_text"}, "denoised_latent": {2: "T_lat"},
               })
    exported_files.append(vf_path)

    # Variant B: with style_keys input (for CFG)
    style_keys_dummy = text_enc.ref_keys.expand(B, -1, -1).to(device)
    vf_keys_wrapped = VectorFieldEstimatorKeysWrapper(vf)
    vfk_path = os.path.join(args.onnx_dir, "vector_estimator_keys.onnx")
    vfk_inputs = (noisy_latent, text_emb, style_ttl, style_keys_dummy, latent_mask, text_mask, current_step, total_step)
    vfk_names = ["noisy_latent", "text_emb", "style_ttl", "style_keys", "latent_mask", "text_mask", "current_step", "total_step"]
    export_one(vf_keys_wrapped, vfk_path, vfk_inputs,
               input_names=vfk_names, output_names=["denoised_latent"],
               dynamic_axes={
                   "noisy_latent": {2: "T_lat"}, "text_emb": {2: "T_text"},
                   "style_ttl": {1: "T_ref"}, "style_keys": {1: "T_ref"},
                   "latent_mask": {2: "T_lat"}, "text_mask": {2: "T_text"},
                   "denoised_latent": {2: "T_lat"},
               })
    exported_files.append(vfk_path)

    # ==============================================================
    # 4. Vocoder (Latent Decoder)
    # ==============================================================
    print("\n[4/5] Vocoder")
    ae_dec_cfg = cfg["ae_dec_cfg"]
    vocoder = LatentDecoder1D(cfg=ae_dec_cfg).to(device).eval()
    if "decoder" in ae_state:
        vocoder.load_state_dict(ae_state["decoder"], strict=True)

    n = scramble_all_layers(vocoder, scramble_linear=scramble_linear, scramble_embedding=scramble_embedding)
    total_scrambled += n
    print(f" Scrambled {n} layers")

    C_dec = latent_dim
    latent_dec = torch.randn(B, C_dec, T_lat * chunk_compress_factor, device=device)
    voc_path = os.path.join(args.onnx_dir, "vocoder.onnx")
    export_one(vocoder, voc_path, (latent_dec,),
               input_names=["latent"], output_names=["waveform"],
               dynamic_axes={"latent": {2: "T_dec"}, "waveform": {2: "T_wav"}})
    exported_files.append(voc_path)

    # ==============================================================
    # 5. Duration Predictor (two variants)
    # ==============================================================
    print("\n[5/5] Duration Predictor")
    dp_style_tokens = cfg["dp_style_tokens"]
    dp_style_dim = cfg["dp_style_dim"]
    dp = DPNetwork(
        vocab_size=cfg["dp_vocab_size"],
        style_tokens=dp_style_tokens,
        style_dim=dp_style_dim,
    ).to(device).eval()

    # Prefer the dedicated DP checkpoint; otherwise fall back to DP weights
    # bundled inside the TTL checkpoint under either known key.
    if dp_state_raw is not None:
        ds = dp_state_raw
        if isinstance(ds, dict) and "state_dict" in ds:
            ds = ds["state_dict"]
        dp.load_state_dict(ds, strict=False)
    elif "dp_network" in t2l_state:
        dp.load_state_dict(t2l_state["dp_network"], strict=True)
    elif "dp_model" in t2l_state:
        dp.load_state_dict(t2l_state["dp_model"], strict=True)

    _replace_mha_with_safe(dp)

    n = scramble_all_layers(dp, scramble_linear=scramble_linear, scramble_embedding=scramble_embedding)
    total_scrambled += n
    print(f" Scrambled {n} layers")

    # Standard path (z_ref)
    dp_path = os.path.join(args.onnx_dir, "duration_predictor.onnx")
    dp_inputs = (text_ids, z_ref, text_mask, ref_mask)
    dp_names = ["text_ids", "z_ref", "text_mask", "ref_mask"]
    export_one(dp, dp_path, dp_inputs,
               input_names=dp_names, output_names=["duration"],
               dynamic_axes={
                   "text_ids": {1: "T_text"}, "text_mask": {2: "T_text"},
                   "z_ref": {2: "T_ref_audio"}, "ref_mask": {2: "T_ref_audio"},
               })
    exported_files.append(dp_path)

    # Style path (pre-computed style tokens)
    class DPStyleWrapper(nn.Module):
        """Wrap DPNetwork for the style_tokens input path (no z_ref)."""
        def __init__(self, dp_model):
            super().__init__()
            self.dp = dp_model

        def forward(self, text_ids, style_dp, text_mask):
            return self.dp(text_ids, text_mask=text_mask, style_tokens=style_dp)

    dp_style_wrapper = DPStyleWrapper(dp).eval()
    style_dp_dummy = torch.randn(B, dp_style_tokens, dp_style_dim, device=device)
    dp_style_path = os.path.join(args.onnx_dir, "duration_predictor_style.onnx")
    dp_style_inputs = (text_ids, style_dp_dummy, text_mask)
    dp_style_names = ["text_ids", "style_dp", "text_mask"]
    export_one(dp_style_wrapper, dp_style_path, dp_style_inputs,
               input_names=dp_style_names, output_names=["duration"],
               dynamic_axes={"text_ids": {1: "T_text"}, "text_mask": {2: "T_text"}})
    exported_files.append(dp_style_path)

    # ==============================================================
    # Unconditional Tokens (for CFG)
    # ==============================================================
    print("\nExporting uncond.npz...")
    uncond_data = {}
    for key in ("u_text", "u_ref", "u_keys"):
        if key in t2l_state:
            uncond_data[key] = t2l_state[key].cpu().numpy()
    with torch.no_grad():
        uncond_data["cond_keys"] = text_enc.ref_keys.cpu().numpy()
    if uncond_data:
        np.savez(os.path.join(args.onnx_dir, "uncond.npz"), **uncond_data)
        print(f"[OK] Saved uncond.npz")

    # ==============================================================
    # Extract Shuffle Keys (Optimizer-Proof)
    # ==============================================================
    # Runs before name obfuscation: extraction finds the key initializers by
    # their "inv_shuffle_indices" name.
    total_keys_extracted = 0
    if do_key_extraction:
        print("\nExtracting shuffle keys from models...")
        all_extracted_keys = {}
        for fpath in exported_files:
            model_name = os.path.splitext(os.path.basename(fpath))[0]
            keys = extract_shuffle_keys(fpath)
            for input_name, arr in keys.items():
                all_extracted_keys[f"{model_name}/{input_name}"] = arr
            if keys:
                print(f" {model_name}: {len(keys)} key arrays extracted")
                total_keys_extracted += len(keys)

        if all_extracted_keys:
            keys_path = os.path.join(args.onnx_dir, "keys.npz")
            np.savez(keys_path, **all_extracted_keys)
            print(f" Saved {total_keys_extracted} total keys to keys.npz")
    else:
        # Extraction is opt-in; name the flag that actually exists.
        print("\nSkipping shuffle key extraction (pass --extract-keys to enable).")

    # ==============================================================
    # Apply ONNX Name Obfuscation
    # ==============================================================
    total_names_randomized = 0
    if do_name_obfuscation:
        print("\nApplying ONNX name obfuscation...")
        for fpath in exported_files:
            n_names = obfuscate_onnx_names(fpath, fpath, keep_io_names=True)
            total_names_randomized += n_names
            print(f" {os.path.basename(fpath)}: {n_names} names randomized")

    # ==============================================================
    # Summary
    # ==============================================================
    print("\n" + "=" * 60)
    print("OBFUSCATION COMPLETE")
    print("=" * 60)
    print(f" Layers scrambled : {total_scrambled}")
    print(f" (Conv1d always, Linear={'ON' if scramble_linear else 'OFF'}, Embedding={'ON' if scramble_embedding else 'OFF'})")
    print(f" ONNX files exported : {len(exported_files)}")
    if do_key_extraction:
        print(f" Shuffle keys extracted : {total_keys_extracted}")
    if do_name_obfuscation:
        print(f" Internal names randomized: {total_names_randomized}")
    print(f" Output directory : {args.onnx_dir}/")
    print()
    print("Weight poisoning: Layer weights are permanently permuted.")
    print("Gather injection: Every scrambled layer has an inverse-Gather node.")
    if do_key_extraction:
        print("Key extraction : Shuffle indices moved to keys.npz (optimizer-proof).")
        print(" Models REQUIRE keys.npz at inference time.")
    else:
        print("WARNING: Shuffle keys are embedded as constants (vulnerable to ORT optimizer).")
    if do_name_obfuscation:
        print("Name obfuscation: All internal node/tensor names are random strings.")
        print("I/O tensor names are preserved for inference compatibility.")
    print("=" * 60)
697
+
698
# Run the export pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
reference_encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:775ed3896b688411d37934edc9b827dc20e676c11fae78baa77ad29bb1f1dbdb
3
+ size 24416182
stats.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcb4f3cf96356860dea9e161fe4f5704b19d76467826611e3548891ea58986c2
3
+ size 1920
text_encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c70b6d4983b0e96e157ac4f4672834c7034991dfdda40d1d55b3ed26edd55ed3
3
+ size 27745455
uncond.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3681ac7a4959a1217fd9af26a93fe653df5c3d7a74261d843cb2ebcd0fbef79c
3
+ size 155626
vocoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1df1cee4f4205ed5d10accec49d295a3521cdd7c04fcf0db3bfd926c78bd96d
3
+ size 101638298