GradientDescent2718 commited on
Commit
4603cf0
·
verified ·
1 Parent(s): fe10b77

Upload example/ls_eend_step_model.py

Browse files
Files changed (1) hide show
  1. example/ls_eend_step_model.py +384 -0
example/ls_eend_step_model.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ if TYPE_CHECKING:
12
+ from ls_eend_runtime import LSEENDInferenceEngine
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class StepStateLayout:
17
+ input_dim: int
18
+ full_output_dim: int
19
+ real_output_dim: int
20
+ encoder_layers: int
21
+ decoder_layers: int
22
+ encoder_dim: int
23
+ num_heads: int
24
+ key_dim: int
25
+ head_dim: int
26
+ encoder_conv_cache_len: int
27
+ top_buffer_len: int
28
+ conv_delay: int
29
+ max_nspks: int
30
+
31
+
32
+ def build_state_layout(engine: Any) -> StepStateLayout:
33
+ model = engine.model
34
+ params = engine.config["model"]["params"]
35
+ n_units = int(params["n_units"])
36
+ n_heads = int(params["n_heads"])
37
+ max_nspks = int(engine.decode_max_nspks)
38
+ encoder_conv_cache_len = int(params["conv_kernel_size"]) - 1
39
+ top_buffer_len = 2 * int(params["conv_delay"]) + 1
40
+ return StepStateLayout(
41
+ input_dim=(2 * engine.config["data"]["context_recp"] + 1) * engine.config["data"]["feat"]["n_mels"],
42
+ full_output_dim=max_nspks,
43
+ real_output_dim=max(0, max_nspks - 2),
44
+ encoder_layers=int(params["enc_n_layers"]),
45
+ decoder_layers=int(params["dec_n_layers"]),
46
+ encoder_dim=n_units,
47
+ num_heads=n_heads,
48
+ key_dim=n_units // n_heads,
49
+ head_dim=n_units // n_heads,
50
+ encoder_conv_cache_len=encoder_conv_cache_len,
51
+ top_buffer_len=top_buffer_len,
52
+ conv_delay=int(params["conv_delay"]),
53
+ max_nspks=max_nspks,
54
+ )
55
+
56
+
57
+ def initial_state_tensors(layout: StepStateLayout, dtype: np.dtype = np.float32) -> dict[str, np.ndarray]:
58
+ return {
59
+ "enc_ret_kv": np.zeros(
60
+ (layout.encoder_layers, 1, layout.num_heads, layout.key_dim, layout.head_dim),
61
+ dtype=dtype,
62
+ ),
63
+ "enc_ret_scale": np.zeros((layout.encoder_layers, 1, layout.num_heads), dtype=dtype),
64
+ "enc_conv_cache": np.zeros(
65
+ (layout.encoder_layers, 1, layout.encoder_conv_cache_len, layout.encoder_dim),
66
+ dtype=dtype,
67
+ ),
68
+ "dec_ret_kv": np.zeros(
69
+ (layout.decoder_layers, layout.max_nspks, layout.num_heads, layout.key_dim, layout.head_dim),
70
+ dtype=dtype,
71
+ ),
72
+ "dec_ret_scale": np.zeros(
73
+ (layout.decoder_layers, layout.max_nspks, layout.num_heads),
74
+ dtype=dtype,
75
+ ),
76
+ "top_buffer": np.zeros((1, layout.top_buffer_len, layout.encoder_dim), dtype=dtype),
77
+ }
78
+
79
+
80
+ def _as_rank3_scalar(value: torch.Tensor, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
81
+ return value.to(device=device, dtype=dtype).reshape(1, 1, 1)
82
+
83
+
84
+ def _safe_l2_normalize(x: torch.Tensor, dim: int) -> torch.Tensor:
85
+ # 1e-12 underflows to zero in fp16 CoreML execution and can produce NaNs
86
+ # during warmup frames when an embedding or attractor vector is exactly zero.
87
+ return x / torch.norm(x, dim=dim, keepdim=True).clamp_min(1e-4)
88
+
89
+
90
+ class OnlineStepModule(torch.nn.Module):
91
+ """Single online LS-EEND step with explicit state tensors for export/runtime backends."""
92
+
93
+ def __init__(self, model: torch.nn.Module, layout: StepStateLayout) -> None:
94
+ super().__init__()
95
+ self.model = model
96
+ self.layout = layout
97
+ self.encoder_decay = torch.exp(
98
+ self.model.enc.encoder.layers[0].sequential[1].module.ret_pos.decay
99
+ ).float()
100
+ self.decoder_decay = torch.exp(
101
+ self.model.dec.attractor_decoder.layers[0].ret_pos1.decay
102
+ ).float()
103
+
104
+ def forward(
105
+ self,
106
+ frame: torch.Tensor,
107
+ enc_ret_kv: torch.Tensor,
108
+ enc_ret_scale: torch.Tensor,
109
+ enc_conv_cache: torch.Tensor,
110
+ dec_ret_kv: torch.Tensor,
111
+ dec_ret_scale: torch.Tensor,
112
+ top_buffer: torch.Tensor,
113
+ ingest: torch.Tensor,
114
+ decode: torch.Tensor,
115
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
116
+ dtype = frame.dtype
117
+ device = frame.device
118
+ ingest_scalar = _as_rank3_scalar(ingest, dtype, device)
119
+ decode_scalar = _as_rank3_scalar(decode, dtype, device)
120
+ ingest_vec = ingest.to(device=device, dtype=dtype).reshape(1, 1)
121
+ decode_vec = decode.to(device=device, dtype=dtype).reshape(1, 1)
122
+
123
+ x = self.model.enc.encoder.input_projection(frame)
124
+ x = self.model.enc.encoder.layer_norm(x)
125
+
126
+ new_enc_ret_kv = []
127
+ new_enc_ret_scale = []
128
+ new_enc_conv_cache = []
129
+
130
+ for layer_index, layer in enumerate(self.model.enc.encoder.layers):
131
+ old_kv = enc_ret_kv[layer_index]
132
+ old_scale = enc_ret_scale[layer_index]
133
+ old_conv = enc_conv_cache[layer_index]
134
+ x, candidate_kv, candidate_scale, candidate_conv = self._encoder_layer_step(
135
+ layer=layer,
136
+ x=x,
137
+ old_kv=old_kv,
138
+ old_scale=old_scale,
139
+ old_conv_cache=old_conv,
140
+ )
141
+ blended_kv = old_kv + (candidate_kv - old_kv) * ingest_scalar.unsqueeze(-1)
142
+ blended_scale = old_scale + (candidate_scale - old_scale) * ingest_vec
143
+ blended_conv = old_conv + (candidate_conv - old_conv) * ingest_scalar
144
+ new_enc_ret_kv.append(blended_kv)
145
+ new_enc_ret_scale.append(blended_scale)
146
+ new_enc_conv_cache.append(blended_conv)
147
+
148
+ appended_encoder_frame = x * ingest_scalar
149
+ top_buffer = torch.cat([top_buffer[:, 1:, :], appended_encoder_frame], dim=1)
150
+
151
+ emb = F.conv1d(
152
+ top_buffer.transpose(1, 2),
153
+ self.model.cnn.weight,
154
+ self.model.cnn.bias,
155
+ ).transpose(1, 2)
156
+ emb = _safe_l2_normalize(emb, dim=-1)
157
+
158
+ logits, candidate_dec_ret_kv, candidate_dec_ret_scale = self._decoder_step(
159
+ emb=emb,
160
+ dec_ret_kv=dec_ret_kv,
161
+ dec_ret_scale=dec_ret_scale,
162
+ )
163
+
164
+ new_dec_ret_kv = dec_ret_kv + (candidate_dec_ret_kv - dec_ret_kv) * decode_scalar.unsqueeze(-1)
165
+ new_dec_ret_scale = dec_ret_scale + (candidate_dec_ret_scale - dec_ret_scale) * decode_vec.unsqueeze(-1)
166
+
167
+ logits = logits * decode_scalar
168
+
169
+ return (
170
+ logits,
171
+ torch.stack(new_enc_ret_kv, dim=0),
172
+ torch.stack(new_enc_ret_scale, dim=0),
173
+ torch.stack(new_enc_conv_cache, dim=0),
174
+ new_dec_ret_kv,
175
+ new_dec_ret_scale,
176
+ top_buffer,
177
+ )
178
+
179
+ def _encoder_layer_step(
180
+ self,
181
+ layer: torch.nn.Module,
182
+ x: torch.Tensor,
183
+ old_kv: torch.Tensor,
184
+ old_scale: torch.Tensor,
185
+ old_conv_cache: torch.Tensor,
186
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
187
+ ff1 = layer.sequential[0]
188
+ attn = layer.sequential[1].module
189
+ conv = layer.sequential[2].module
190
+ ff2 = layer.sequential[3]
191
+ final_norm = layer.sequential[4]
192
+
193
+ x = ff1.module(x) * ff1.module_factor + x * ff1.input_factor
194
+ attn_input = attn.layer_norm(x)
195
+ attn_output, candidate_kv, candidate_scale = self._retention_recurrent(
196
+ retention_module=attn.self_attn,
197
+ x=attn_input,
198
+ old_kv=old_kv,
199
+ old_scale=old_scale,
200
+ decay=self.encoder_decay,
201
+ )
202
+ x = x + attn.dropout(attn_output)
203
+ conv_output, candidate_conv = self._conformer_conv_step(conv, x, old_conv_cache)
204
+ x = x + conv_output
205
+ x = ff2.module(x) * ff2.module_factor + x * ff2.input_factor
206
+ return final_norm(x), candidate_kv, candidate_scale, candidate_conv
207
+
208
+ def _conformer_conv_step(
209
+ self,
210
+ conv_module: torch.nn.Module,
211
+ x: torch.Tensor,
212
+ old_cache: torch.Tensor,
213
+ ) -> tuple[torch.Tensor, torch.Tensor]:
214
+ modules = conv_module.sequential
215
+
216
+ current = modules[0](x)
217
+ current = modules[1](current)
218
+ current = modules[2](current)
219
+ current = modules[3](current)
220
+
221
+ cache = old_cache.transpose(1, 2)
222
+ depthwise_window = torch.cat([cache, current], dim=2)
223
+ depthwise_conv = modules[4].conv
224
+ depthwise = F.conv1d(
225
+ depthwise_window,
226
+ depthwise_conv.weight,
227
+ depthwise_conv.bias,
228
+ stride=depthwise_conv.stride,
229
+ padding=0,
230
+ dilation=depthwise_conv.dilation,
231
+ groups=depthwise_conv.groups,
232
+ )
233
+ candidate_cache = depthwise_window[:, :, -self.layout.encoder_conv_cache_len :].transpose(1, 2)
234
+
235
+ depthwise = modules[5](depthwise)
236
+ depthwise = modules[6](depthwise)
237
+ depthwise = modules[7](depthwise)
238
+ depthwise = modules[8](depthwise)
239
+ return depthwise.transpose(1, 2), candidate_cache
240
+
241
+ def _decoder_step(
242
+ self,
243
+ emb: torch.Tensor,
244
+ dec_ret_kv: torch.Tensor,
245
+ dec_ret_scale: torch.Tensor,
246
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
247
+ pos_enc = self.model.dec.pos_enc(emb, self.layout.max_nspks)
248
+ repeated_emb = emb.unsqueeze(dim=2).repeat(1, 1, self.layout.max_nspks, 1)
249
+ attractors = self.model.dec.convert(torch.cat([repeated_emb, pos_enc], dim=-1))
250
+
251
+ new_dec_ret_kv = []
252
+ new_dec_ret_scale = []
253
+ for layer_index, layer in enumerate(self.model.dec.attractor_decoder.layers):
254
+ attractors, candidate_kv, candidate_scale = self._fusion_layer_step(
255
+ layer=layer,
256
+ src=attractors,
257
+ old_kv=dec_ret_kv[layer_index],
258
+ old_scale=dec_ret_scale[layer_index],
259
+ )
260
+ new_dec_ret_kv.append(candidate_kv)
261
+ new_dec_ret_scale.append(candidate_scale)
262
+
263
+ if self.model.dec.attractor_decoder.norm is not None:
264
+ attractors = self.model.dec.attractor_decoder.norm(attractors)
265
+ attractors = _safe_l2_normalize(attractors, dim=-1)
266
+ logits = torch.matmul(emb.unsqueeze(dim=-2), attractors.transpose(-1, -2)).squeeze(dim=-2)
267
+ return logits, torch.stack(new_dec_ret_kv, dim=0), torch.stack(new_dec_ret_scale, dim=0)
268
+
269
+ def _fusion_layer_step(
270
+ self,
271
+ layer: torch.nn.Module,
272
+ src: torch.Tensor,
273
+ old_kv: torch.Tensor,
274
+ old_scale: torch.Tensor,
275
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
276
+ batch_size, time_steps, speaker_count, feat_dim = src.shape
277
+ x = src.transpose(1, 2).reshape(batch_size * speaker_count, time_steps, feat_dim)
278
+
279
+ if layer.norm_first:
280
+ time_input = layer.norm11(x)
281
+ time_output, candidate_kv, candidate_scale = self._retention_recurrent(
282
+ retention_module=layer.self_attn1,
283
+ x=time_input,
284
+ old_kv=old_kv,
285
+ old_scale=old_scale,
286
+ decay=self.decoder_decay,
287
+ )
288
+ x = x + layer.dropout11(time_output)
289
+ else:
290
+ time_output, candidate_kv, candidate_scale = self._retention_recurrent(
291
+ retention_module=layer.self_attn1,
292
+ x=x,
293
+ old_kv=old_kv,
294
+ old_scale=old_scale,
295
+ decay=self.decoder_decay,
296
+ )
297
+ x = layer.norm11(x + layer.dropout11(time_output))
298
+
299
+ x = x.reshape(batch_size, speaker_count, time_steps, feat_dim).transpose(1, 2)
300
+ x = x.reshape(batch_size * time_steps, speaker_count, feat_dim)
301
+
302
+ if layer.norm_first:
303
+ x = x + self._speaker_attention(layer.self_attn2, layer.norm21(x))
304
+ x = x + layer._ff_block(layer.norm22(x))
305
+ else:
306
+ x = layer.norm21(x + self._speaker_attention(layer.self_attn2, x))
307
+ x = layer.norm22(x + layer._ff_block(x))
308
+
309
+ return x.reshape(batch_size, time_steps, speaker_count, feat_dim), candidate_kv, candidate_scale
310
+
311
+ def _retention_recurrent(
312
+ self,
313
+ retention_module: torch.nn.Module,
314
+ x: torch.Tensor,
315
+ old_kv: torch.Tensor,
316
+ old_scale: torch.Tensor,
317
+ decay: torch.Tensor,
318
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
319
+ batch_size, target_length, _ = x.shape
320
+ q = retention_module.q_proj(x)
321
+ k = retention_module.k_proj(x)
322
+ v = retention_module.v_proj(x)
323
+ g = retention_module.g_proj(x)
324
+
325
+ k = k * retention_module.scaling
326
+ q = q.view(batch_size, target_length, retention_module.num_heads, retention_module.key_dim).transpose(1, 2)
327
+ k = k.view(batch_size, target_length, retention_module.num_heads, retention_module.key_dim).transpose(1, 2)
328
+ v = v.view(batch_size, retention_module.num_heads, retention_module.head_dim, 1)
329
+
330
+ qr = q
331
+ kr = k
332
+ kv = kr * v
333
+
334
+ decay = decay.to(device=x.device, dtype=x.dtype).reshape(1, retention_module.num_heads)
335
+ candidate_scale = old_scale * decay + 1.0
336
+ blend = (old_scale.sqrt() * decay / candidate_scale.sqrt()).unsqueeze(-1).unsqueeze(-1)
337
+ candidate_kv = old_kv * blend + kv / candidate_scale.sqrt().unsqueeze(-1).unsqueeze(-1)
338
+
339
+ output = torch.sum(qr * candidate_kv, dim=3)
340
+ output = retention_module.group_norm(output).reshape(
341
+ batch_size, target_length, retention_module.head_dim * retention_module.num_heads
342
+ )
343
+ output = retention_module.gate_fn(g) * output
344
+ output = retention_module.out_proj(output)
345
+ return output, candidate_kv, candidate_scale
346
+
347
+ def _speaker_attention(self, attention: torch.nn.MultiheadAttention, x: torch.Tensor) -> torch.Tensor:
348
+ batch_size, seq_len, embed_dim = x.shape
349
+ head_dim = embed_dim // attention.num_heads
350
+ q_weight, k_weight, v_weight = attention.in_proj_weight.chunk(3, dim=0)
351
+ q_bias, k_bias, v_bias = attention.in_proj_bias.chunk(3, dim=0)
352
+
353
+ q = F.linear(x, q_weight, q_bias)
354
+ k = F.linear(x, k_weight, k_bias)
355
+ v = F.linear(x, v_weight, v_bias)
356
+
357
+ q = q.view(batch_size, seq_len, attention.num_heads, head_dim).transpose(1, 2)
358
+ k = k.view(batch_size, seq_len, attention.num_heads, head_dim).transpose(1, 2)
359
+ v = v.view(batch_size, seq_len, attention.num_heads, head_dim).transpose(1, 2)
360
+
361
+ attn = torch.matmul(q, k.transpose(-2, -1)) / (head_dim**0.5)
362
+ attn = torch.softmax(attn, dim=-1)
363
+ out = torch.matmul(attn, v)
364
+ out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
365
+ return F.linear(out, attention.out_proj.weight, attention.out_proj.bias)
366
+
367
+
368
+ def load_step_module(
369
+ checkpoint_path: Path,
370
+ config_path: Path,
371
+ device: str = "cpu",
372
+ ) -> tuple[OnlineStepModule, StepStateLayout, "LSEENDInferenceEngine"]:
373
+ from ls_eend_runtime import LSEENDInferenceEngine
374
+
375
+ engine = LSEENDInferenceEngine(
376
+ checkpoint_path=checkpoint_path,
377
+ config_path=config_path,
378
+ device=device,
379
+ )
380
+ engine.model = engine.model.float().to(torch.device(device))
381
+ engine.model.eval()
382
+ layout = build_state_layout(engine)
383
+ module = OnlineStepModule(engine.model, layout).to(torch.device(device)).eval()
384
+ return module, layout, engine