| """ |
| ZPCodec: full codec model combining encoder, RVQ, optional repair, and decoder. |
| |
| Data flow: |
| waveform [B, 1, T] |
| -> ZPEncoder -> latent z [B, D, T'] |
| -> ResidualVQ -> quantized z_q [B, D, T'], indices, commit_loss |
| -> (GE simulator) -> frame_mask [B, T'] (training only, if use_repair=True) |
| -> LatentRepairTransformer -> z_q_post [B, D, T'] (missing frames concealed) |
| -> ZPDecoder -> waveform [B, 1, T_out] |
| |
| T' = T / hop_length (hop_length = prod(ratios) = 240 for ratios=[8,5,3,2] -> 15ms/frame) |
| |
| The repair module is optional (use_repair=False for stage 1 codec-only training). |
| The GE simulator is optional too: if no GilbertElliottConfig is provided, no |
| packet loss is simulated and frame_mask is never generated automatically. |
| """ |
|
|
| import typing as tp |
| from contextlib import contextmanager |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from vector_quantize_pytorch import ResidualVQ |
|
|
| from .components import ZPEncoder, ZPDecoder |
| from .repair import LatentRepairTransformer |
| from .GilbertElliot import GilbertElliottConfig, GilbertElliottSimulator |
|
|
|
|
@contextmanager
def temporarily_set(obj, attr: str, value):
    """Temporarily override ``obj.attr`` with ``value``.

    The attribute's current value is captured on entry and written back on
    exit, whether the ``with`` body completes normally or raises.
    Used to toggle quantize_dropout per-batch.

    Args:
        obj: object whose attribute is overridden (the attribute must exist).
        attr: name of the attribute to override.
        value: value the attribute holds for the duration of the block.
    """
    saved = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        # Always put the captured value back, even on error.
        setattr(obj, attr, saved)
|
|
|
|
class ZPCodec(nn.Module):
    """
    Full codec: encoder -> RVQ -> (repair) -> decoder.

    Stage 1 (codec pre-training): use_repair=False, no GilbertElliottConfig.
    Stage 2 (repair training):    use_repair=True, GilbertElliottConfig provided.
    Stage 3 (joint fine-tuning):  use_repair=True, GE curriculum via set_gilbert_elliott_config().

    Inference (eval mode): when use_repair=True but neither a frame_mask nor a
    GE simulator is available, forward() treats every frame as received and
    skips the repair module, mirroring decode(z_q, frame_mask=None).
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: tp.Optional[tp.Sequence[int]] = None,
        norm: str = 'weight_norm',
        causal: bool = True,
        num_quantizers: int = 9,
        codebook_size: int = 1024,
        sample_rate: int = 16000,

        use_repair: bool = False,
        repair_hidden_dim: int = 256,
        repair_num_layers: int = 4,
        repair_num_heads: int = 4,
        repair_ffn_mult: int = 2,
        repair_past: int = 8,
        repair_future: int = 2,
        repair_two_pass: bool = True,

        gilbert_elliott_config: tp.Optional["GilbertElliottConfig"] = None,
    ):
        """Build encoder, residual VQ, decoder and (optionally) the repair module.

        Args:
            channels, dimension, n_filters, ratios, norm, causal: forwarded
                unchanged to both ZPEncoder and ZPDecoder. ``ratios`` defaults
                to [8, 5, 3, 2] when None; its product is the hop length.
            num_quantizers, codebook_size: forwarded to ResidualVQ.
            sample_rate: stored and handed to the GE simulator.
            use_repair: if True, a LatentRepairTransformer is created.
            repair_*: forwarded to LatentRepairTransformer.
            repair_two_pass: selects forward_two_pass vs. the single-pass
                fill_missing + forward path in _apply_repair.
            gilbert_elliott_config: if given, a GE simulator is created
                immediately via set_gilbert_elliott_config().
        """
        super().__init__()

        # Avoid a shared mutable default: every instance gets its own list.
        ratios = [8, 5, 3, 2] if ratios is None else list(ratios)

        self.encoder = ZPEncoder(
            channels=channels,
            dimension=dimension,
            n_filters=n_filters,
            ratios=ratios,
            norm=norm,
            causal=causal,
        )
        self.rvq = ResidualVQ(
            dim=dimension,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            kmeans_init=True,
            kmeans_iters=10,
            use_cosine_sim=True,
            threshold_ema_dead_code=2,
            # Constructed with dropout enabled; _encode_raw overrides the flag
            # on every call.
            quantize_dropout=True,
            quantize_dropout_cutoff_index=5,
            quantize_dropout_multiple_of=1,
        )
        self.decoder = ZPDecoder(
            channels=channels,
            dimension=dimension,
            n_filters=n_filters,
            ratios=ratios,
            norm=norm,
            causal=causal,
        )

        self.sample_rate = sample_rate
        # Samples per latent frame: T' = T / hop_length.
        self.hop_length = int(np.prod(ratios))

        self.use_repair = use_repair
        self.repair_two_pass = repair_two_pass
        if use_repair:
            self.repair = LatentRepairTransformer(
                latent_dim=dimension,
                hidden_dim=repair_hidden_dim,
                num_layers=repair_num_layers,
                num_heads=repair_num_heads,
                ffn_mult=repair_ffn_mult,
                past=repair_past,
                future=repair_future,
            )
        else:
            self.repair = None

        self.ge_simulator: tp.Optional[GilbertElliottSimulator] = None
        if gilbert_elliott_config is not None:
            self.set_gilbert_elliott_config(gilbert_elliott_config)

    def set_gilbert_elliott_config(self, config: "GilbertElliottConfig") -> None:
        """Replace the GE simulator at runtime. Called between training stages
        to apply a harder loss curriculum without reloading the model."""
        self.ge_simulator = GilbertElliottSimulator(
            config=config,
            sample_rate=self.sample_rate,
            hop_length=self.hop_length,
        )

    def sample_frame_mask(
        self,
        batch_size: int,
        num_frames: int,
        device: tp.Optional[torch.device] = None,
        seed: tp.Optional[int] = None,
    ) -> torch.Tensor:
        """Expose the GE simulator directly. Useful when the same mask needs to
        be reused across multiple points (e.g. logging, loss weighting).

        Raises:
            AssertionError: if no GE simulator has been configured.
        """
        assert self.ge_simulator is not None, (
            "No GilbertElliottConfig configured. Call set_gilbert_elliott_config() first."
        )
        return self.ge_simulator.sample_frame_mask(
            batch_size, num_frames, device=device, seed=seed
        )

    def _encode_raw(self, x: torch.Tensor):
        """Encode waveform to quantized latent. Returns (z, z_q, indices, commit_loss).

        quantize_dropout is randomly toggled per-call during training to teach
        the decoder to handle a variable number of active quantizers (bitrate
        scalability). Outside training the flag is always forced to False.
        """
        z = self.encoder(x)
        # The RVQ is fed [B, T', D]; the encoder emits [B, D, T'].
        z_seq = z.permute(0, 2, 1)
        # Short-circuits in eval mode, so no random number is drawn there.
        use_dropout = self.training and (torch.rand(1).item() < 0.5)

        with temporarily_set(self.rvq, 'quantize_dropout', use_dropout):
            z_q, indices, commit_loss = self.rvq(z_seq)
        z_q = z_q.permute(0, 2, 1)
        return z, z_q, indices, commit_loss

    def _apply_repair(
        self,
        z_q: torch.Tensor,
        frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the repair transformer and selectively substitute only missing frames.

        z_q:        [B, D, T']
        frame_mask: [B, T']  1 = received, 0 = missing

        The transformer outputs a full [B, D, T'] tensor, but received frames are
        kept as-is from z_q — only positions where frame_mask == 0 are replaced.
        This means z_q_post == z_q on received frames by construction, which is
        important for latent_repair_loss (the mask isolates the useful gradient).

        Two-pass mode (repair_two_pass=True): mimics streaming deployment where
        previous repair estimates are already in the buffer when estimating frame t.
        See LatentRepairTransformer.forward_two_pass for the full explanation.

        Raises:
            AssertionError: if the model was built with use_repair=False.
        """
        assert self.repair is not None, "use_repair=False, repair not initialised"

        z_seq = z_q.permute(0, 2, 1)

        if self.repair_two_pass:
            z_repaired = self.repair.forward_two_pass(z_seq, frame_mask)
        else:
            # Single pass: pre-fill the missing frames, then run the transformer once.
            z_seq_filled = self.repair.fill_missing(z_seq, frame_mask)
            z_repaired = self.repair(z_seq_filled, frame_mask)

        # Blend: keep received frames from z_q, take repaired values elsewhere.
        m = frame_mask.unsqueeze(-1).to(z_seq.dtype)
        z_out = z_seq * m + z_repaired * (1.0 - m)
        return z_out.permute(0, 2, 1)

    def _get_frame_mask(
        self,
        z_q: torch.Tensor,
        frame_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the provided frame_mask, or sample one from the GE simulator.

        Raises:
            AssertionError: if no mask is given and no GE simulator is configured.
        """
        if frame_mask is not None:
            return frame_mask
        assert self.ge_simulator is not None, (
            "use_repair=True but no GilbertElliottConfig configured. "
            "Call set_gilbert_elliott_config() before training."
        )
        B, _, T_prime = z_q.shape
        return self.ge_simulator.sample_frame_mask(B, T_prime, device=z_q.device)

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Encode waveform to (z_q, indices). x: [B, 1, T]"""
        _, z_q, indices, _ = self._encode_raw(x)
        return z_q, indices

    def decode(
        self,
        z_q: torch.Tensor,
        frame_mask: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode quantized latent to waveform.

        z_q:        [B, D, T']
        frame_mask: [B, T'] optional; if provided and use_repair=True, runs repair first.
                    Ignored when use_repair=False.
        """
        if self.use_repair and frame_mask is not None:
            z_q = self._apply_repair(z_q, frame_mask)
        return self.decoder(z_q)

    def forward(
        self,
        x: torch.Tensor,
        frame_mask: tp.Optional[torch.Tensor] = None,
        return_intermediates: bool = False,
    ):
        """
        x:           [B, 1, T]
        frame_mask:  [B, T'] optional. If use_repair=True and None, it is
                     sampled automatically from the GE simulator when one is
                     configured.
        return_intermediates: if True, also returns z_q pre/post repair and the
                     effective frame_mask — required by latent_repair_loss
                     and ZPCodecTrainer.forward_codec during training.

        Returns:
            return_intermediates=False: (x_hat, commit_loss)
            return_intermediates=True:  (x_hat, commit_loss, z_q_pre, z_q_post, frame_mask)
            When repair is not applied (use_repair=False, or eval mode with no
            frame_mask and no GE simulator): z_q_pre == z_q_post and
            frame_mask == None.

        Raises:
            AssertionError: in training mode with use_repair=True when neither
                a frame_mask nor a GE simulator is available.
        """
        _, z_q_pre, _, commit_loss = self._encode_raw(x)
        commit_loss = commit_loss.mean()

        # In eval mode, "no mask and no simulator" means no packet loss, which
        # is how decode() already interprets frame_mask=None. Training still
        # goes through _get_frame_mask so a missing simulator fails loudly.
        no_loss_inference = (
            frame_mask is None
            and self.ge_simulator is None
            and not self.training
        )

        if not self.use_repair or no_loss_inference:
            z_q_post = z_q_pre
            frame_mask = None
        else:
            frame_mask = self._get_frame_mask(z_q_pre, frame_mask)
            z_q_post = self._apply_repair(z_q_pre, frame_mask)

        x_hat = self.decoder(z_q_post)

        if return_intermediates:
            return x_hat, commit_loss, z_q_pre, z_q_post, frame_mask
        return x_hat, commit_loss

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        device: str = "cpu",
        filename: str = "zpcodec_weights.pt",
        **hf_kwargs,
    ) -> "ZPCodec":
        """
        Load ZPCodec from a Hugging Face Hub repo or a local file.

        Args:
            model_id:  HF repo id (e.g. "yourname/zpcodec") OR a local path
                       to a .pt file OR a local directory containing filename.
            device:    "cpu" | "cuda" | "cuda:0" etc.
            filename:  name of the weights file inside the HF repo.
            **hf_kwargs: forwarded to huggingface_hub.hf_hub_download
                       (e.g. revision="main", token="hf_...").

        Returns:
            ZPCodec in eval mode.

        Raises:
            FileNotFoundError: model_id is a directory that lacks ``filename``.
            ImportError: a Hub download is needed but huggingface_hub is missing.
            ValueError: the checkpoint has neither 'config'+'model_state_dict'
                nor 'args'+'trainer'.
            RuntimeError: the state dict does not match the model exactly
                (raised by load_state_dict with strict=True).

        Examples:
            # From Hugging Face Hub
            model = ZPCodec.from_pretrained("yourname/zpcodec")

            # From a local .pt file
            model = ZPCodec.from_pretrained("./zpcodec_weights.pt")

            # With explicit device
            model = ZPCodec.from_pretrained("yourname/zpcodec", device="cuda")
        """
        import os

        # Resolve model_id: local file, local directory, else a Hub repo id.
        if os.path.isfile(model_id):
            ckpt_path = model_id
        elif os.path.isdir(model_id):
            ckpt_path = os.path.join(model_id, filename)
            if not os.path.isfile(ckpt_path):
                raise FileNotFoundError(
                    f"No '{filename}' found in directory '{model_id}'"
                )
        else:
            try:
                from huggingface_hub import hf_hub_download
            except ImportError as err:
                raise ImportError(
                    "huggingface_hub is required to download from the Hub.\n"
                    "Install with: pip install huggingface_hub"
                ) from err
            ckpt_path = hf_hub_download(
                repo_id=model_id, filename=filename, **hf_kwargs
            )

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # so a malicious checkpoint can execute code on load. Only load files
        # from sources you trust. Left unchanged because existing checkpoints
        # may contain objects the restricted loader rejects.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        if "config" in ckpt and "model_state_dict" in ckpt:
            # Release format: constructor kwargs stored alongside the weights.
            cfg = ckpt["config"]
            state_dict = ckpt["model_state_dict"]
        elif "args" in ckpt and "trainer" in ckpt:
            # Trainer checkpoint: codec weights live under the "codec." prefix
            # and the constructor kwargs are rebuilt from the training args.
            args = ckpt["args"]
            state_dict = {
                k[len("codec."):]: v
                for k, v in ckpt["trainer"].items()
                if k.startswith("codec.")
            }
            cfg = {
                "channels": 1, "dimension": args["dimension"],
                "n_filters": args["n_filters"], "ratios": [8, 5, 3, 2],
                "norm": "weight_norm", "causal": True,
                "num_quantizers": args["num_quantizers"],
                "codebook_size": args["codebook_size"], "sample_rate": 16000,
                "use_repair": True,
                "repair_hidden_dim": args["repair_hidden_dim"],
                "repair_num_layers": args["repair_num_layers"],
                "repair_num_heads": args["repair_num_heads"],
                "repair_ffn_mult": args["repair_ffn_mult"],
                "repair_past": args["repair_past"],
                "repair_future": args["repair_future"],
                "repair_two_pass": True,
            }
        else:
            raise ValueError(
                "Unrecognised checkpoint format. "
                "Expected keys: 'config'+'model_state_dict' or 'args'+'trainer'."
            )

        model = cls(**cfg)
        # strict=True raises RuntimeError on any missing or unexpected key, so
        # no separate check of the returned key lists is needed.
        model.load_state_dict(state_dict, strict=True)

        n_params = sum(p.numel() for p in model.parameters()) / 1e6
        info = ckpt.get("training_info", {})
        stoi = info.get("best_val_stoi", ckpt.get("best_val_metric", "?"))
        print(f"✓ ZPCodec loaded — {n_params:.1f}M params | best val STOI: {stoi}")

        model = model.to(device)
        model.eval()
        return model
|
|