Lucabr01 commited on
Commit
41a06ac
·
verified ·
1 Parent(s): cbac10e

Upload zpcodec/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. zpcodec/model.py +386 -0
zpcodec/model.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ZPCodec: full codec model combining encoder, RVQ, optional repair, and decoder.
3
+
4
+ Data flow:
5
+ waveform [B, 1, T]
6
+ -> ZPEncoder -> latent z [B, D, T']
7
+ -> ResidualVQ -> quantized z_q [B, D, T'], indices, commit_loss
8
+ -> (GE simulator) -> frame_mask [B, T'] (training only, if use_repair=True)
9
+ -> LatentRepairTransformer -> z_q_post [B, D, T'] (missing frames concealed)
10
+ -> ZPDecoder -> waveform [B, 1, T_out]
11
+
12
+ T' = T / hop_length (hop_length = prod(ratios) = 240 for ratios=[8,5,3,2] -> 15ms/frame)
13
+
14
+ The repair module is optional (use_repair=False for stage 1 codec-only training).
15
+ The GE simulator is optional too: if no GilbertElliottConfig is provided, no
16
+ packet loss is simulated and frame_mask is never generated automatically.
17
+ """
18
+
19
+ import typing as tp
20
+ from contextlib import contextmanager
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ from vector_quantize_pytorch import ResidualVQ
26
+
27
+ from .components import ZPEncoder, ZPDecoder
28
+ from .repair import LatentRepairTransformer
29
+ from .GilbertElliot import GilbertElliottConfig, GilbertElliottSimulator
30
+
31
+
32
+ @contextmanager
33
+ def temporarily_set(obj, attr: str, value):
34
+ """Context manager that sets obj.attr = value for the duration of the block,
35
+ then restores the original value. Used to toggle quantize_dropout per-batch."""
36
+ original = getattr(obj, attr)
37
+ setattr(obj, attr, value)
38
+ try:
39
+ yield
40
+ finally:
41
+ setattr(obj, attr, original)
42
+
43
+
44
+ class ZPCodec(nn.Module):
45
+ """
46
+ Full codec: encoder -> RVQ -> (repair) -> decoder.
47
+
48
+ Stage 1 (codec pre-training): use_repair=False, no GilbertElliottConfig.
49
+ Stage 2 (repair training): use_repair=True, GilbertElliottConfig provided.
50
+ Stage 3 (joint fine-tuning): use_repair=True, GE curriculum via set_gilbert_elliott_config().
51
+ """
52
+ def __init__(
53
+ self,
54
+ channels: int = 1,
55
+ dimension: int = 128,
56
+ n_filters: int = 32,
57
+ ratios: tp.List[int] = [8, 5, 3, 2],
58
+ norm: str = 'weight_norm',
59
+ causal: bool = True,
60
+ num_quantizers: int = 9,
61
+ codebook_size: int = 1024,
62
+ sample_rate: int = 16000,
63
+ # --- Repair module ---
64
+ use_repair: bool = False,
65
+ repair_hidden_dim: int = 256,
66
+ repair_num_layers: int = 4,
67
+ repair_num_heads: int = 4,
68
+ repair_ffn_mult: int = 2,
69
+ repair_past: int = 8,
70
+ repair_future: int = 2,
71
+ repair_two_pass: bool = True,
72
+ # --- Packet loss simulation ---
73
+ gilbert_elliott_config: tp.Optional[GilbertElliottConfig] = None,
74
+ ):
75
+ super().__init__()
76
+ self.encoder = ZPEncoder(
77
+ channels=channels,
78
+ dimension=dimension,
79
+ n_filters=n_filters,
80
+ ratios=ratios,
81
+ norm=norm,
82
+ causal=causal,
83
+ )
84
+ self.rvq = ResidualVQ(
85
+ dim=dimension,
86
+ num_quantizers=num_quantizers,
87
+ codebook_size=codebook_size,
88
+ kmeans_init=True,
89
+ kmeans_iters=10,
90
+ use_cosine_sim=True, # prop to improved RVQGAN's paper
91
+ threshold_ema_dead_code=2,
92
+ quantize_dropout=True,
93
+ quantize_dropout_cutoff_index=5, # first 5 quantizers are always active -
94
+ # theoretically with 5 quant active we can switch to 3kbps. But this was not my focus for that project...
95
+ quantize_dropout_multiple_of=1,
96
+ )
97
+ self.decoder = ZPDecoder(
98
+ channels=channels,
99
+ dimension=dimension,
100
+ n_filters=n_filters,
101
+ ratios=ratios,
102
+ norm=norm,
103
+ causal=causal,
104
+ )
105
+
106
+ self.sample_rate = sample_rate
107
+ self.hop_length = int(np.prod(ratios)) # 240 for ratios=[8,5,3,2]
108
+
109
+ self.use_repair = use_repair
110
+ self.repair_two_pass = repair_two_pass
111
+ if use_repair:
112
+ self.repair = LatentRepairTransformer(
113
+ latent_dim=dimension,
114
+ hidden_dim=repair_hidden_dim,
115
+ num_layers=repair_num_layers,
116
+ num_heads=repair_num_heads,
117
+ ffn_mult=repair_ffn_mult,
118
+ past=repair_past,
119
+ future=repair_future,
120
+ )
121
+ else:
122
+ self.repair = None
123
+
124
+ self.ge_simulator: tp.Optional[GilbertElliottSimulator] = None
125
+ if gilbert_elliott_config is not None:
126
+ self.set_gilbert_elliott_config(gilbert_elliott_config)
127
+
128
+
129
+ # Runtime configuration of the packet-loss simulator
130
+ def set_gilbert_elliott_config(self, config: GilbertElliottConfig) -> None:
131
+ """Replace the GE simulator at runtime. Called between training stages
132
+ to apply a harder loss curriculum without reloading the model."""
133
+ self.ge_simulator = GilbertElliottSimulator(
134
+ config=config,
135
+ sample_rate=self.sample_rate,
136
+ hop_length=self.hop_length,
137
+ )
138
+
139
+ def sample_frame_mask(
140
+ self,
141
+ batch_size: int,
142
+ num_frames: int,
143
+ device: tp.Optional[torch.device] = None,
144
+ seed: tp.Optional[int] = None,
145
+ ) -> torch.Tensor:
146
+ """Expose the GE simulator directly. Useful when the same mask needs to
147
+ be reused across multiple points (e.g. logging, loss weighting)."""
148
+ assert self.ge_simulator is not None, (
149
+ "No GilbertElliottConfig configured. Call set_gilbert_elliott_config() first."
150
+ )
151
+ return self.ge_simulator.sample_frame_mask(
152
+ batch_size, num_frames, device=device, seed=seed
153
+ )
154
+
155
+ # Encoding
156
+ def _encode_raw(self, x: torch.Tensor):
157
+ """Encode waveform to quantized latent. Returns (z, z_q, indices, commit_loss).
158
+ quantize_dropout is randomly toggled per-call during training to teach
159
+ the decoder to handle a variable number of active quantizers (bitrate scalability)."""
160
+ z = self.encoder(x) # [B, D, T']
161
+ z_seq = z.permute(0, 2, 1) # [B, T', D] — RVQ expects (B, T, D)
162
+ use_dropout = self.training and (torch.rand(1).item() < 0.5) # dropout applied only 50% of the time, this improve the
163
+ # quality at full kbps. Citing the improved RVQGAN paper.
164
+ with temporarily_set(self.rvq, 'quantize_dropout', use_dropout):
165
+ z_q, indices, commit_loss = self.rvq(z_seq)
166
+ z_q = z_q.permute(0, 2, 1) # [B, D, T']
167
+ return z, z_q, indices, commit_loss
168
+
169
+
170
+ # Repair
171
+ def _apply_repair(
172
+ self,
173
+ z_q: torch.Tensor,
174
+ frame_mask: torch.Tensor,
175
+ ) -> torch.Tensor:
176
+ """Run the repair transformer and selectively substitute only missing frames.
177
+
178
+ z_q: [B, D, T']
179
+ frame_mask: [B, T'] 1 = received, 0 = missing
180
+
181
+ The transformer outputs a full [B, D, T'] tensor, but received frames are
182
+ kept as-is from z_q — only positions where frame_mask == 0 are replaced.
183
+ This means z_q_post == z_q on received frames by construction, which is
184
+ important for latent_repair_loss (the mask isolates the useful gradient).
185
+
186
+ Two-pass mode (repair_two_pass=True): mimics streaming deployment where
187
+ previous repair estimates are already in the buffer when estimating frame t.
188
+ See LatentRepairTransformer.forward_two_pass for the full explanation.
189
+ """
190
+ assert self.repair is not None, "use_repair=False, repair not initialised"
191
+
192
+ z_seq = z_q.permute(0, 2, 1) # [B, T', D]
193
+
194
+ if self.repair_two_pass:
195
+ z_repaired = self.repair.forward_two_pass(z_seq, frame_mask)
196
+ else:
197
+ # Single-pass fallback
198
+ z_seq_filled = self.repair.fill_missing(z_seq, frame_mask)
199
+ z_repaired = self.repair(z_seq_filled, frame_mask)
200
+
201
+ # Selective substitution: keep received frames from z_q, replace missing ones
202
+ m = frame_mask.unsqueeze(-1).to(z_seq.dtype) # [B, T', 1]
203
+ z_out = z_seq * m + z_repaired * (1.0 - m)
204
+ return z_out.permute(0, 2, 1) # [B, D, T']
205
+
206
+ def _get_frame_mask(
207
+ self,
208
+ z_q: torch.Tensor,
209
+ frame_mask: tp.Optional[torch.Tensor],
210
+ ) -> torch.Tensor:
211
+ """Return the provided frame_mask, or sample one from the GE simulator."""
212
+ if frame_mask is not None:
213
+ return frame_mask
214
+ assert self.ge_simulator is not None, (
215
+ "use_repair=True but no GilbertElliottConfig configured. "
216
+ "Call set_gilbert_elliott_config() before training."
217
+ )
218
+ B, _, T_prime = z_q.shape
219
+ return self.ge_simulator.sample_frame_mask(B, T_prime, device=z_q.device)
220
+
221
+ # Public encode / decode API
222
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
223
+ """Encode waveform to (z_q, indices). x: [B, 1, T]"""
224
+ _, z_q, indices, _ = self._encode_raw(x)
225
+ return z_q, indices
226
+
227
+ def decode(
228
+ self,
229
+ z_q: torch.Tensor,
230
+ frame_mask: tp.Optional[torch.Tensor] = None,
231
+ ) -> torch.Tensor:
232
+ """Decode quantized latent to waveform.
233
+ z_q: [B, D, T']
234
+ frame_mask: [B, T'] optional; if provided and use_repair=True, runs repair first.
235
+ """
236
+ if self.use_repair and frame_mask is not None:
237
+ z_q = self._apply_repair(z_q, frame_mask)
238
+ return self.decoder(z_q)
239
+
240
+ # Training forward
241
+ def forward(
242
+ self,
243
+ x: torch.Tensor,
244
+ frame_mask: tp.Optional[torch.Tensor] = None,
245
+ return_intermediates: bool = False,
246
+ ):
247
+ """
248
+ x: [B, 1, T]
249
+ frame_mask: [B, T'] optional. If use_repair=True and None,
250
+ sampled automatically from the GE simulator.
251
+ return_intermediates: if True, also returns z_q pre/post repair and the
252
+ effective frame_mask — required by latent_repair_loss
253
+ and ZPCodecTrainer.forward_codec during training.
254
+
255
+ Returns:
256
+ return_intermediates=False: (x_hat, commit_loss)
257
+ return_intermediates=True: (x_hat, commit_loss, z_q_pre, z_q_post, frame_mask)
258
+ When use_repair=False: z_q_pre == z_q_post and frame_mask == None.
259
+ """
260
+ _, z_q_pre, _, commit_loss = self._encode_raw(x)
261
+ commit_loss = commit_loss.mean()
262
+
263
+ if self.use_repair:
264
+ frame_mask = self._get_frame_mask(z_q_pre, frame_mask)
265
+ z_q_post = self._apply_repair(z_q_pre, frame_mask)
266
+ else:
267
+ z_q_post = z_q_pre
268
+ frame_mask = None
269
+
270
+ x_hat = self.decoder(z_q_post)
271
+
272
+ if return_intermediates:
273
+ return x_hat, commit_loss, z_q_pre, z_q_post, frame_mask
274
+ return x_hat, commit_loss
275
+
276
+ # ------------------------------------------------------------------
277
+ # from_pretrained — load from Hugging Face Hub or local path
278
+ # ------------------------------------------------------------------
279
+ @classmethod
280
+ def from_pretrained(
281
+ cls,
282
+ model_id: str,
283
+ device: str = "cpu",
284
+ filename: str = "zpcodec_weights.pt",
285
+ **hf_kwargs,
286
+ ) -> "ZPCodec":
287
+ """
288
+ Load ZPCodec from a Hugging Face Hub repo or a local file.
289
+
290
+ Args:
291
+ model_id: HF repo id (e.g. "yourname/zpcodec") OR a local path
292
+ to a .pt file OR a local directory containing filename.
293
+ device: "cpu" | "cuda" | "cuda:0" etc.
294
+ filename: name of the weights file inside the HF repo.
295
+ **hf_kwargs: forwarded to huggingface_hub.hf_hub_download
296
+ (e.g. revision="main", token="hf_...").
297
+
298
+ Returns:
299
+ ZPCodec in eval mode.
300
+
301
+ Examples:
302
+ # From Hugging Face Hub
303
+ model = ZPCodec.from_pretrained("yourname/zpcodec")
304
+
305
+ # From a local .pt file
306
+ model = ZPCodec.from_pretrained("./zpcodec_weights.pt")
307
+
308
+ # With explicit device
309
+ model = ZPCodec.from_pretrained("yourname/zpcodec", device="cuda")
310
+ """
311
+ import os
312
+ import torch
313
+
314
+ # Resolve checkpoint path: local file, local dir, or HF Hub
315
+ if os.path.isfile(model_id):
316
+ ckpt_path = model_id
317
+ elif os.path.isdir(model_id):
318
+ ckpt_path = os.path.join(model_id, filename)
319
+ if not os.path.isfile(ckpt_path):
320
+ raise FileNotFoundError(
321
+ f"No '{filename}' found in directory '{model_id}'"
322
+ )
323
+ else:
324
+ # Treat as a Hugging Face Hub repo id
325
+ try:
326
+ from huggingface_hub import hf_hub_download
327
+ except ImportError:
328
+ raise ImportError(
329
+ "huggingface_hub is required to download from the Hub.\n"
330
+ "Install with: pip install huggingface_hub"
331
+ )
332
+ ckpt_path = hf_hub_download(
333
+ repo_id=model_id, filename=filename, **hf_kwargs
334
+ )
335
+
336
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
337
+
338
+ # Support both clean checkpoints (with 'config' key) and raw
339
+ # full-trainer checkpoints (with 'args' key) for backward compat
340
+ if "config" in ckpt:
341
+ cfg = ckpt["config"]
342
+ state_dict = ckpt["model_state_dict"]
343
+ elif "args" in ckpt and "trainer" in ckpt:
344
+ # Full trainer checkpoint — extract codec weights and config
345
+ args = ckpt["args"]
346
+ state_dict = {
347
+ k[len("codec."):]: v
348
+ for k, v in ckpt["trainer"].items()
349
+ if k.startswith("codec.")
350
+ }
351
+ cfg = {
352
+ "channels": 1, "dimension": args["dimension"],
353
+ "n_filters": args["n_filters"], "ratios": [8, 5, 3, 2],
354
+ "norm": "weight_norm", "causal": True,
355
+ "num_quantizers": args["num_quantizers"],
356
+ "codebook_size": args["codebook_size"], "sample_rate": 16000,
357
+ "use_repair": True,
358
+ "repair_hidden_dim": args["repair_hidden_dim"],
359
+ "repair_num_layers": args["repair_num_layers"],
360
+ "repair_num_heads": args["repair_num_heads"],
361
+ "repair_ffn_mult": args["repair_ffn_mult"],
362
+ "repair_past": args["repair_past"],
363
+ "repair_future": args["repair_future"],
364
+ "repair_two_pass": True,
365
+ }
366
+ else:
367
+ raise ValueError(
368
+ "Unrecognised checkpoint format. "
369
+ "Expected keys: 'config'+'model_state_dict' or 'args'+'trainer'."
370
+ )
371
+
372
+ model = cls(**cfg)
373
+ missing, unexpected = model.load_state_dict(state_dict, strict=True)
374
+ if missing:
375
+ raise RuntimeError(f"Missing keys: {missing[:5]}")
376
+ if unexpected:
377
+ raise RuntimeError(f"Unexpected keys: {unexpected[:5]}")
378
+
379
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
380
+ info = ckpt.get("training_info", {})
381
+ stoi = info.get("best_val_stoi", ckpt.get("best_val_metric", "?"))
382
+ print(f"✓ ZPCodec loaded — {n_params:.1f}M params | best val STOI: {stoi}")
383
+
384
+ model = model.to(device)
385
+ model.eval()
386
+ return model