Zero-Ping / zpcodec /GilbertElliot.py
Lucabr01's picture
Upload zpcodec/GilbertElliot.py with huggingface_hub
b9fd30d verified
"""
Gilbert-Elliott packet-loss simulator in the latent frame domain.
Losses are simulated on latent frames rather than directly on the discrete tokens: masking tokens would make it
harder for the transformer to learn how to repair the latent vectors. In a real deployment (with RTP packets) this choice makes no difference.
The classic model (Hasslinger & Hohlfeld 2008) has two states G/B with:
p = P(G -> B) probability of entering the Bad state
r = P(B -> G) probability of leaving the Bad state
k = P(no loss | G) typically ~1 (Good channel, nearly clean)
h = P(no loss | B) typically ~0.5 (Bad channel, lossy)
Stationary probabilities:
pi_G = r / (p + r)
pi_B = p / (p + r)
Expected error rate:
p_E = (1-k) * pi_G + (1-h) * pi_B
Difference from the baseline implementation (which operates on the waveform):
here we directly generate a frame_mask [B, T'] over latent frames.
Packetization over latent frames is derived from packet_ms, sample_rate, and hop_length.
"""
import math
from dataclasses import dataclass, field
import typing as tp
import numpy as np
import torch
@dataclass
class GilbertElliottConfig:
    """
    Parameters of the two-state Gilbert-Elliott loss model.

    Defaults describe a channel that degrades rarely, recovers quickly,
    and therefore produces short loss bursts.

    Attributes:
        p: P(G -> B), probability of entering the Bad state; must be in (0, 1).
        r: P(B -> G), probability of leaving the Bad state; must be in (0, 1).
        k: P(no loss | Good state); must be in [0, 1].
        h: P(no loss | Bad state); must be in [0, 1].
        packet_ms: duration of one packet in milliseconds; must be > 0.

    Raises:
        ValueError: if any parameter is outside its valid range.
    """
    p: float = 0.005  # G -> B transition probability
    r: float = 0.5  # B -> G transition probability
    k: float = 0.999  # P(no loss | Good state)
    h: float = 0.5  # P(no loss | Bad state)
    packet_ms: float = 15.0  # packet duration in milliseconds

    def stationary_loss_rate(self) -> float:
        """Return the long-run expected loss rate p_E.

        p_E = (1 - k) * pi_G + (1 - h) * pi_B, where pi_G = r / (p + r)
        and pi_B = p / (p + r) are the stationary state probabilities.
        """
        pi_G = self.r / (self.p + self.r)
        pi_B = self.p / (self.p + self.r)
        return (1.0 - self.k) * pi_G + (1.0 - self.h) * pi_B

    def __post_init__(self):
        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would let invalid configs through silently.
        # Each check is phrased as `not (valid range)` so NaN is rejected too.
        if not 0.0 < self.p < 1.0:
            raise ValueError("p must be in (0, 1)")
        if not 0.0 < self.r < 1.0:
            raise ValueError("r must be in (0, 1)")
        if not 0.0 <= self.k <= 1.0:
            raise ValueError("k must be in [0, 1]")
        if not 0.0 <= self.h <= 1.0:
            raise ValueError("h must be in [0, 1]")
        if not self.packet_ms > 0:
            raise ValueError("packet_ms must be positive")
class GilbertElliottSimulator:
    """
    Generates a frame_mask [B, T'] (1 = received, 0 = missing) where T' is the
    number of latent frames produced by the codec. If a packet spans multiple
    consecutive frames (packet_ms > frame_ms), all frames in that packet are
    dropped together.

    Args:
        config: Gilbert-Elliott model parameters.
        sample_rate: audio sample rate (Hz); must be positive.
        hop_length: total encoder downsampling factor (1 latent frame
                    = hop_length samples); must be positive. For Zero-Ping
                    with ratios=[8,5,3,2] this is 240 -> 15ms at 16kHz.

    Raises:
        ValueError: if sample_rate or hop_length is not positive.
    """

    def __init__(
        self,
        config: "GilbertElliottConfig",
        sample_rate: int = 16000,
        hop_length: int = 240,  # one latent frame = 15 ms at 16 kHz
    ):
        # Without these checks a zero value raises ZeroDivisionError below and
        # a negative value silently degrades to one frame per packet.
        if sample_rate <= 0:
            raise ValueError("sample_rate must be positive")
        if hop_length <= 0:
            raise ValueError("hop_length must be positive")
        self.config = config
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        frame_ms = 1000.0 * hop_length / sample_rate
        # Latent frames carried by one packet, rounded to the nearest integer;
        # packets shorter than a frame still map to one frame per packet.
        self.frames_per_packet = max(1, round(config.packet_ms / frame_ms))

    def _sample_one(self, T: int, rng: np.random.Generator) -> np.ndarray:
        """Generate a single float32 frame_mask of length T.

        For each packet, the current channel state decides whether the packet
        is lost (probability 1 - k in Good, 1 - h in Bad). The Markov state
        then transitions before the next packet.
        """
        cfg = self.config
        n_packets = math.ceil(T / self.frames_per_packet)
        # Initialize state by sampling from the stationary distribution
        pi_G = cfg.r / (cfg.p + cfg.r)
        state_good = bool(rng.random() < pi_G)
        mask = np.ones(T, dtype=np.float32)
        for i in range(n_packets):
            no_loss_prob = cfg.k if state_good else cfg.h
            if rng.random() >= no_loss_prob:
                # Drop all frames belonging to this packet; the last packet
                # may be partial when T is not a multiple of frames_per_packet.
                start = i * self.frames_per_packet
                end = min(start + self.frames_per_packet, T)
                mask[start:end] = 0.0
            # Markov state transition: G -> B with prob p, B -> G with prob r.
            if state_good:
                state_good = rng.random() >= cfg.p
            else:
                state_good = rng.random() < cfg.r
        return mask

    def sample_frame_mask(
        self,
        batch_size: int,
        num_frames: int,
        device: tp.Optional[torch.device] = None,
        seed: tp.Optional[int] = None,
    ) -> torch.Tensor:
        """
        Return a float32 tensor [B, T'] with values in {0, 1},
        where 1 = frame received and 0 = frame lost.

        Sequences in the batch are drawn one after another from a single
        generator, so a fixed ``seed`` reproduces the whole batch.

        Args:
            batch_size: number of sequences B (>= 0).
            num_frames: number of latent frames T' per sequence (>= 0).
            device: if given, the result is moved to this device.
            seed: seed for ``np.random.default_rng``; None uses fresh entropy.

        Raises:
            ValueError: if batch_size or num_frames is negative.
        """
        if batch_size < 0:
            raise ValueError("batch_size must be non-negative")
        if num_frames < 0:
            raise ValueError("num_frames must be non-negative")
        rng = np.random.default_rng(seed)
        if batch_size == 0:
            # np.stack() rejects an empty list, so build the empty batch directly.
            masks = np.ones((0, num_frames), dtype=np.float32)
        else:
            masks = np.stack(
                [self._sample_one(num_frames, rng) for _ in range(batch_size)],
                axis=0,
            )
        out = torch.from_numpy(masks)
        if device is not None:
            out = out.to(device)
        return out