| """ |
| Gilbert-Elliott packet-loss simulator in the latent frame domain. |
| I could have done it directly on the tokens, but it would have been more difficult for the transformer to learn |
| how to repair the latent vectors. In a real implementation scenario (with RTP packets) this doesn't change anything. |
| |
| The classic model (Hasslinger & Hohlfeld 2008) has two states G/B with: |
| p = P(G -> B) probability of entering the Bad state |
| r = P(B -> G) probability of leaving the Bad state |
| k = P(no loss | G) typically ~1 (Good channel, nearly clean) |
| h = P(no loss | B) typically ~0.5 (Bad channel, lossy) |
| |
| Stationary probabilities: |
| pi_G = r / (p + r) |
| pi_B = p / (p + r) |
| |
| Expected error rate: |
| p_E = (1-k) * pi_G + (1-h) * pi_B |
| |
| Difference from the baseline implementation (which operates on the waveform): |
| here we directly generate a frame_mask [B, T'] over latent frames. |
| Packetization over latent frames is derived from packet_ms, sample_rate, and hop_length. |
| """ |
|
|
| import math |
| from dataclasses import dataclass, field |
| import typing as tp |
|
|
| import numpy as np |
| import torch |
|
|
|
|
class _GilbertElliottConfigError(ValueError, AssertionError):
    """Raised when a GilbertElliottConfig field is outside its valid range.

    Derives from ValueError (the conventional type for a bad argument) and
    from AssertionError so that callers written against the earlier
    assert-based validation keep catching it.
    """


@dataclass
class GilbertElliottConfig:
    """
    Parameters of the two-state (Good/Bad) Gilbert-Elliott loss model.

    Defaults: channel degrades rarely, recovers quickly, short bursts.

    Raises:
        ValueError: if any field is outside its valid range. The raised
            exception is also an AssertionError for backward compatibility.
    """
    # P(Good -> Bad): probability of entering the Bad state; must be in (0, 1).
    p: float = 0.005
    # P(Bad -> Good): probability of leaving the Bad state; must be in (0, 1).
    r: float = 0.5
    # P(no loss | Good); must be in [0, 1].
    k: float = 0.999
    # P(no loss | Bad); must be in [0, 1].
    h: float = 0.5
    # Duration of one packet in milliseconds; must be positive.
    packet_ms: float = 15.0

    def stationary_loss_rate(self) -> float:
        """Return the expected loss rate under the stationary distribution.

        Computed as (1 - k) * pi_G + (1 - h) * pi_B, with
        pi_G = r / (p + r) and pi_B = p / (p + r).
        """
        pi_G = self.r / (self.p + self.r)
        pi_B = self.p / (self.p + self.r)
        return (1.0 - self.k) * pi_G + (1.0 - self.h) * pi_B

    def __post_init__(self) -> None:
        """Validate the fields right after the dataclass-generated __init__."""
        # Explicit raises instead of `assert`: assert statements are removed
        # when Python runs with -O, which would silently accept bad values.
        if not 0.0 < self.p < 1.0:
            raise _GilbertElliottConfigError("p must be in (0, 1)")
        if not 0.0 < self.r < 1.0:
            raise _GilbertElliottConfigError("r must be in (0, 1)")
        if not 0.0 <= self.k <= 1.0:
            raise _GilbertElliottConfigError("k must be in [0, 1]")
        if not 0.0 <= self.h <= 1.0:
            raise _GilbertElliottConfigError("h must be in [0, 1]")
        if not self.packet_ms > 0:
            raise _GilbertElliottConfigError("packet_ms must be positive")
|
|
|
|
class GilbertElliottSimulator:
    """
    Generates a frame_mask [B, T'] (1 = received, 0 = missing) where T' is the
    number of latent frames produced by the codec. If a packet spans multiple
    consecutive frames (packet_ms > frame_ms), all frames in that packet are
    dropped together.

    Args:
        config:      Gilbert-Elliott model parameters.
        sample_rate: audio sample rate (Hz).
        hop_length:  total encoder downsampling factor (1 latent frame
                     = hop_length samples). For Zero-Ping with
                     ratios=[8,5,3,2] this is 240 -> 15ms at 16kHz.

    Raises:
        ValueError: if sample_rate or hop_length is not positive.
    """

    def __init__(
        self,
        config: "GilbertElliottConfig",
        sample_rate: int = 16000,
        hop_length: int = 240,
    ):
        # Reject values that would otherwise fail with a bare
        # ZeroDivisionError or silently yield a meaningless frame duration.
        if sample_rate <= 0:
            raise ValueError("sample_rate must be positive")
        if hop_length <= 0:
            raise ValueError("hop_length must be positive")

        self.config = config
        self.sample_rate = sample_rate
        self.hop_length = hop_length

        # Duration of one latent frame in milliseconds.
        frame_ms = 1000.0 * hop_length / sample_rate
        # Number of whole latent frames covered by one packet, rounded to the
        # nearest integer and never fewer than one.
        self.frames_per_packet = max(1, round(config.packet_ms / frame_ms))

    def _sample_one(self, T: int, rng: np.random.Generator) -> np.ndarray:
        """Generate a single frame_mask of length T.

        Draws from ``rng`` in a fixed order: one draw for the initial state,
        then for every packet one draw for the loss decision followed by one
        draw for the state transition.
        """
        cfg = self.config
        frames_per_packet = self.frames_per_packet
        n_packets = math.ceil(T / frames_per_packet)

        # Start the chain from its stationary distribution.
        pi_G = cfg.r / (cfg.p + cfg.r)
        state_good = bool(rng.random() < pi_G)

        mask = np.ones(T, dtype=np.float32)

        for i in range(n_packets):
            # Loss decision for this packet, given the current channel state.
            no_loss_prob = cfg.k if state_good else cfg.h
            if rng.random() >= no_loss_prob:
                # Drop every frame of the packet; the last packet may be
                # shorter than frames_per_packet, hence the clamp to T.
                start = i * frames_per_packet
                end = min(start + frames_per_packet, T)
                mask[start:end] = 0.0

            # State transition for the next packet.
            if state_good:
                state_good = bool(rng.random() >= cfg.p)  # stay Good w.p. 1 - p
            else:
                state_good = bool(rng.random() < cfg.r)   # recover w.p. r

        return mask

    def sample_frame_mask(
        self,
        batch_size: int,
        num_frames: int,
        device: tp.Optional[torch.device] = None,
        seed: tp.Optional[int] = None,
    ) -> torch.Tensor:
        """
        Returns a float32 tensor [B, T'] with values in {0, 1},
        where 1 = frame received, 0 = frame lost.

        Each sequence in the batch runs its own chain; all chains draw, in
        batch order, from a single generator created from ``seed``.

        Args:
            batch_size: number of sequences B; 0 yields an empty [0, T'] tensor.
            num_frames: number of latent frames T' per sequence.
            device:     if given, the result is moved to this device.
            seed:       seed for numpy's default_rng; None means
                        non-deterministic.
        """
        rng = np.random.default_rng(seed)
        # Fill a preallocated array instead of np.stack(): stacking an empty
        # list raises, so batch_size == 0 used to crash.
        masks = np.empty((batch_size, num_frames), dtype=np.float32)
        for row in masks:
            row[:] = self._sample_one(num_frames, rng)

        out = torch.from_numpy(masks)
        if device is not None:
            out = out.to(device)
        return out
|
|