Lucabr01 commited on
Commit
d864252
Β·
verified Β·
1 Parent(s): 68e74d1

Upload zpcodec/repair.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. zpcodec/repair.py +288 -0
zpcodec/repair.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LatentRepairTransformer: packet-loss concealment in the RVQ latent domain.
3
+
4
+ Receives z_q_masked [B, T, D] and frame_mask [B, T] (1 = received, 0 = missing),
5
+ returns z_q_repaired [B, T, D] where missing frames are reconstructed by
6
+ attending only to valid frames within a local window [t-past, t+future].
7
+
8
+ Selective substitution (replacing only missing frames in the output) is handled
9
+ upstream in ZPCodec._apply_repair, not here β€” this module always produces a
10
+ full-length output tensor.
11
+
12
+ Why operate in the latent domain rather than on tokens?
13
+ Tokens are discrete: a small error in the codec's estimate produces a
14
+ completely different codebook entry, with no gradient signal. Latent vectors
15
+ are continuous, so the transformer can produce soft interpolations between
16
+ neighbouring frames and the repair loss (L1 on z) provides a smooth gradient.
17
+ In a real RTP deployment this makes no difference to the transmitted bitstream.
18
+ """
19
+
20
+ import math
21
+ import typing as tp
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
+ def _build_local_attention_mask(
29
+ T: int,
30
+ past: int,
31
+ future: int,
32
+ device: torch.device,
33
+ ) -> torch.Tensor:
34
+ """
35
+ Returns a boolean mask [T, T] where True means the position is allowed.
36
+ Query t can attend to key s if -past <= s - t <= future.
37
+ This enforces a fixed-size local receptive field and a bounded lookahead.
38
+ """
39
+ idx = torch.arange(T, device=device)
40
+ delta = idx.unsqueeze(0) - idx.unsqueeze(1) # delta[t, s] = s - t
41
+ return (delta >= -past) & (delta <= future)
42
+
43
+
44
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Applies Rotary Position Embedding (RoPE) to queries or keys.
47
+ x: [B, H, T, D_head]
48
+ cos: [T, D_head]
49
+ sin: [T, D_head]
50
+ RoPE encodes relative positions directly in the dot product, making it
51
+ compatible with the local attention mask without requiring learned positional embeddings.
52
+ """
53
+ x1, x2 = x.chunk(2, dim=-1)
54
+ rotated = torch.cat([-x2, x1], dim=-1)
55
+ return x * cos + rotated * sin
56
+
57
+
58
+ class RotaryEmbedding(nn.Module):
59
+ """Precomputes and caches RoPE cos/sin tables up to max_seq_len."""
60
+ def __init__(self, dim_head: int, max_seq_len: int = 2048, base: float = 10000.0):
61
+ super().__init__()
62
+ assert dim_head % 2 == 0, "RoPE requires even dim_head"
63
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2).float() / dim_head))
64
+ t = torch.arange(max_seq_len).float()
65
+ freqs = torch.einsum('t,d->td', t, inv_freq) # [T, dim_head/2]
66
+ emb = torch.cat([freqs, freqs], dim=-1) # [T, dim_head]
67
+ self.register_buffer('cos_cached', emb.cos(), persistent=False)
68
+ self.register_buffer('sin_cached', emb.sin(), persistent=False)
69
+
70
+ def forward(self, T: int) -> tp.Tuple[torch.Tensor, torch.Tensor]:
71
+ return self.cos_cached[:T], self.sin_cached[:T]
72
+
73
+
74
+ class MaskedLocalAttention(nn.Module):
75
+ """
76
+ Multi-head self-attention with two simultaneous masks:
77
+ 1) Local window mask: query t can only attend within [t-past, t+future].
78
+ Keeps the receptive field bounded and latency predictable.
79
+ 2) Validity mask: keys from missing frames (frame_mask=0) are excluded,
80
+ so the transformer cannot "cheat" by attending to frames it doesn't have.
81
+
82
+ The combination forces the model to reconstruct missing frames exclusively
83
+ from neighbouring received frames β€” the same information available at inference.
84
+ """
85
+
86
+ def __init__(self, dim: int, num_heads: int = 4, past: int = 8, future: int = 2):
87
+ super().__init__()
88
+ assert dim % num_heads == 0
89
+ self.num_heads = num_heads
90
+ self.dim_head = dim // num_heads
91
+ self.scale = self.dim_head ** -0.5
92
+ self.past = past
93
+ self.future = future
94
+
95
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
96
+ self.to_out = nn.Linear(dim, dim, bias=False)
97
+ self.rope = RotaryEmbedding(self.dim_head)
98
+
99
+ def forward(self, x: torch.Tensor, frame_mask: torch.Tensor) -> torch.Tensor:
100
+ """
101
+ x: [B, T, D]
102
+ frame_mask: [B, T] 1 = received, 0 = missing
103
+ """
104
+ B, T, D = x.shape
105
+ H, Dh = self.num_heads, self.dim_head
106
+
107
+ qkv = self.to_qkv(x).reshape(B, T, 3, H, Dh).permute(2, 0, 3, 1, 4)
108
+ q, k, v = qkv.unbind(0) # [B, H, T, Dh]
109
+
110
+ cos, sin = self.rope(T)
111
+ q = _apply_rope(q, cos, sin)
112
+ k = _apply_rope(k, cos, sin)
113
+
114
+ scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B, H, T, T]
115
+
116
+ # Local window mask [T, T]: True where attention is allowed
117
+ local_allowed = _build_local_attention_mask(T, self.past, self.future, x.device)
118
+
119
+ # Validity mask: exclude keys from missing frames
120
+ # frame_mask [B, T] -> key_valid [B, 1, 1, T]
121
+ key_valid = frame_mask.bool().unsqueeze(1).unsqueeze(1)
122
+
123
+ allowed = local_allowed.unsqueeze(0).unsqueeze(0) & key_valid # [B, 1, T, T]
124
+
125
+ # Failsafe: if a query has no valid key in its window (e.g. a long burst
126
+ # of losses covers the entire local range), re-enable the diagonal so
127
+ # softmax(-inf, ...) doesn't produce NaN. The query then attends to itself
128
+ # (its learned missing_frame_embedding), which is a reasonable fallback.
129
+ any_valid = allowed.any(dim=-1, keepdim=True)
130
+ diag = torch.eye(T, dtype=torch.bool, device=x.device).unsqueeze(0).unsqueeze(0)
131
+ allowed = allowed | (~any_valid & diag)
132
+
133
+ scores = scores.masked_fill(~allowed, float('-inf'))
134
+ attn = F.softmax(scores, dim=-1)
135
+
136
+ out = torch.matmul(attn, v) # [B, H, T, Dh]
137
+ out = out.transpose(1, 2).reshape(B, T, D)
138
+ return self.to_out(out)
139
+
140
+
141
+ class TransformerBlock(nn.Module):
142
+ """Standard pre-norm transformer block (attention + FFN) with masked local attention."""
143
+ def __init__(self, dim: int, num_heads: int, ffn_mult: int, past: int, future: int):
144
+ super().__init__()
145
+ self.norm1 = nn.LayerNorm(dim)
146
+ self.attn = MaskedLocalAttention(dim, num_heads, past, future)
147
+ self.norm2 = nn.LayerNorm(dim)
148
+ self.ffn = nn.Sequential(
149
+ nn.Linear(dim, dim * ffn_mult),
150
+ nn.GELU(),
151
+ nn.Linear(dim * ffn_mult, dim),
152
+ )
153
+
154
+ def forward(self, x: torch.Tensor, frame_mask: torch.Tensor) -> torch.Tensor:
155
+ x = x + self.attn(self.norm1(x), frame_mask)
156
+ x = x + self.ffn(self.norm2(x))
157
+ return x
158
+
159
+
160
+ class LatentRepairTransformer(nn.Module):
161
+ """
162
+ Local inpainting of z_q after simulated packet loss on RVQ tokens.
163
+
164
+ Architecture:
165
+ missing_frame_embedding β€” learned placeholder substituted for lost frames
166
+ before the transformer sees the sequence
167
+ mask_embedding β€” additive token telling the model which frames
168
+ are received (1) vs missing (0)
169
+ in_proj β€” projects latent_dim -> hidden_dim
170
+ blocks β€” stack of MaskedLocalAttention + FFN layers
171
+ out_norm + out_proj β€” projects hidden_dim -> latent_dim
172
+
173
+ Args:
174
+ latent_dim: D of the RVQ latent (128 for ZPCodec).
175
+ hidden_dim: internal transformer width.
176
+ num_layers: number of transformer blocks.
177
+ num_heads: attention heads.
178
+ ffn_mult: FFN hidden size = hidden_dim * ffn_mult.
179
+ past: past frames in the local receptive field.
180
+ future: future frames (lookahead; each frame = 15ms latency cost).
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ latent_dim: int = 128,
186
+ hidden_dim: int = 256,
187
+ num_layers: int = 6,
188
+ num_heads: int = 4,
189
+ ffn_mult: int = 4,
190
+ past: int = 8,
191
+ future: int = 2,
192
+ ):
193
+ super().__init__()
194
+ self.latent_dim = latent_dim
195
+ self.past = past
196
+ self.future = future
197
+
198
+ # Learned placeholder for missing frames.
199
+ # Substituted into z_q where frame_mask == 0 before the forward pass.
200
+ self.missing_frame_embedding = nn.Parameter(torch.zeros(latent_dim))
201
+ nn.init.normal_(self.missing_frame_embedding, std=0.02)
202
+
203
+ # Additive embedding that signals received (1) vs missing (0) to the model.
204
+ # Lets the transformer distinguish genuine latent vectors from placeholders.
205
+ self.mask_embedding = nn.Embedding(2, hidden_dim)
206
+
207
+ self.in_proj = nn.Linear(latent_dim, hidden_dim)
208
+ self.blocks = nn.ModuleList([
209
+ TransformerBlock(hidden_dim, num_heads, ffn_mult, past, future)
210
+ for _ in range(num_layers)
211
+ ])
212
+ self.out_norm = nn.LayerNorm(hidden_dim)
213
+ self.out_proj = nn.Linear(hidden_dim, latent_dim)
214
+
215
+ def fill_missing(self, z_q: torch.Tensor, frame_mask: torch.Tensor) -> torch.Tensor:
216
+ """
217
+ Replace frames where frame_mask == 0 with the learned missing_frame_embedding.
218
+ z_q: [B, T, D]
219
+ frame_mask: [B, T]
220
+ """
221
+ B, T, D = z_q.shape
222
+ emb = self.missing_frame_embedding.view(1, 1, D).expand(B, T, D)
223
+ m = frame_mask.unsqueeze(-1).to(z_q.dtype)
224
+ return z_q * m + emb * (1.0 - m)
225
+
226
+ def forward(
227
+ self,
228
+ z_q_masked: torch.Tensor,
229
+ frame_mask: torch.Tensor,
230
+ ) -> torch.Tensor:
231
+ """
232
+ z_q_masked: [B, T, D] β€” missing frames must already contain
233
+ missing_frame_embedding; call fill_missing() first if needed.
234
+ frame_mask: [B, T] β€” 1 = received, 0 = missing.
235
+ Returns: [B, T, D] β€” full reconstructed sequence.
236
+ Selective substitution (only replacing missing frames in z_q)
237
+ is done upstream in ZPCodec._apply_repair.
238
+ """
239
+ x = self.in_proj(z_q_masked)
240
+ x = x + self.mask_embedding(frame_mask.long()) # inject received/missing signal
241
+
242
+ for block in self.blocks:
243
+ x = block(x, frame_mask)
244
+
245
+ x = self.out_norm(x)
246
+ return self.out_proj(x)
247
+
248
+ def forward_two_pass(
249
+ self,
250
+ z_q: torch.Tensor,
251
+ frame_mask: torch.Tensor,
252
+ ) -> torch.Tensor:
253
+ """
254
+ Two-pass forward that mimics streaming deployment behaviour.
255
+
256
+ In real-time streaming, when estimating a lost frame at time t, the
257
+ estimates for previously lost frames (s < t) are already in the buffer β€”
258
+ not the original missing_frame_embedding. Training with a single pass
259
+ creates a train/inference mismatch because the model never sees its own
260
+ estimates as context. Two passes close that gap.
261
+
262
+ Pass 1: standard forward with missing_emb as placeholder for all lost frames.
263
+ Produces initial rough estimates z_pass1.
264
+
265
+ Pass 2: update the buffer β€” lost frames now contain z_pass1 instead of
266
+ missing_emb. Re-run the transformer on the updated buffer to
267
+ produce refined estimates z_pass2.
268
+
269
+ z_q: [B, T, D] original quantized latent (NOT yet masked).
270
+ frame_mask: [B, T]
271
+ Returns: [B, T, D] refined estimates for missing frames.
272
+ Values at received-frame positions are arbitrary β€”
273
+ selective substitution happens upstream in ZPCodec._apply_repair.
274
+ """
275
+ # Pass 1: fill missing with placeholder, run transformer
276
+ z_masked_1 = self.fill_missing(z_q, frame_mask)
277
+ z_pass1 = self.forward(z_masked_1, frame_mask)
278
+
279
+ # Update buffer: received frames keep z_q, lost frames get pass-1 estimates
280
+ m = frame_mask.unsqueeze(-1).to(z_q.dtype)
281
+ z_buffer_updated = z_q * m + z_pass1 * (1.0 - m)
282
+
283
+ # Pass 2: re-run on updated buffer.
284
+ # Do NOT call fill_missing again β€” lost frames already contain z_pass1,
285
+ # which is exactly the buffer state we want to simulate.
286
+ z_pass2 = self.forward(z_buffer_updated, frame_mask)
287
+
288
+ return z_pass2