thanrl committed on
Commit
490a4f3
·
verified ·
1 Parent(s): ec1dc1f

Update modeling_gigaam.py

Browse files
Files changed (1) hide show
  1. modeling_gigaam.py +423 -8
modeling_gigaam.py CHANGED
@@ -5,6 +5,8 @@ import os
5
  import sys
6
  import warnings
7
  from abc import ABC, abstractmethod
 
 
8
  from pathlib import Path
9
  from subprocess import CalledProcessError, run
10
  from typing import Any, Dict, List, Optional, Tuple, Union
@@ -35,6 +37,144 @@ _PIPELINE = None
35
  ### preprocess ###
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
39
  """
40
  Load an audio file and resample it to the specified sample rate.
@@ -89,6 +229,13 @@ class FeatureExtractor(nn.Module):
89
  self.win_length = kwargs.get("win_length", sample_rate // 40)
90
  self.n_fft = kwargs.get("n_fft", sample_rate // 40)
91
  self.center = kwargs.get("center", True)
 
 
 
 
 
 
 
92
  self.featurizer = nn.Sequential(
93
  torchaudio.transforms.MelSpectrogram(
94
  sample_rate=sample_rate,
@@ -97,10 +244,27 @@ class FeatureExtractor(nn.Module):
97
  hop_length=self.hop_length,
98
  n_fft=self.n_fft,
99
  center=self.center,
 
100
  ),
101
  SpecScaler(),
102
  )
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def out_len(self, input_lengths: Tensor) -> Tensor:
105
  """
106
  Calculates the output length after the feature extraction process.
@@ -1107,6 +1271,54 @@ class CTCGreedyDecoding:
1107
  return pred_texts
1108
 
1109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1110
  class RNNTGreedyDecoding:
1111
  def __init__(
1112
  self,
@@ -1121,29 +1333,88 @@ class RNNTGreedyDecoding:
1121
  self.blank_id = len(self.tokenizer)
1122
  self.max_symbols = max_symbols_per_step
1123
 
1124
- def _greedy_decode(self, head: RNNTHead, x: Tensor, seqlen: Tensor) -> str:
1125
- """
1126
- Internal helper function for performing greedy decoding on a single sequence.
1127
- """
 
 
 
 
 
1128
  hyp: List[int] = []
1129
  dec_state: Optional[Tensor] = None
1130
  last_label: Optional[Tensor] = None
 
 
 
 
 
 
 
 
 
1131
  for t in range(seqlen):
1132
  f = x[t, :, :].unsqueeze(1)
1133
  not_blank = True
1134
  new_symbols = 0
1135
  while not_blank and new_symbols < self.max_symbols:
1136
  g, hidden = head.decoder.predict(last_label, dec_state)
1137
- k = head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1138
  if k == self.blank_id:
 
1139
  not_blank = False
1140
  else:
 
 
 
1141
  hyp.append(int(k))
1142
  dec_state = hidden
1143
- last_label = torch.tensor([[hyp[-1]]]).to(x.device)
1144
  new_symbols += 1
1145
 
1146
- return self.tokenizer.decode(hyp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1147
 
1148
  @torch.inference_mode()
1149
  def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[str]:
@@ -1159,6 +1430,23 @@ class RNNTGreedyDecoding:
1159
  return pred_texts
1160
 
1161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1162
  ### models ###
1163
 
1164
 
@@ -1180,9 +1468,20 @@ class GigaAM(nn.Module):
1180
  Perform forward pass through the preprocessor and encoder.
1181
  """
1182
  features, feature_lengths = self.preprocessor(features, feature_lengths)
 
 
 
 
 
1183
  if self._device.type == "cpu":
1184
  return self.encoder(features, feature_lengths)
1185
- with torch.autocast(device_type=self._device.type, dtype=torch.float16):
 
 
 
 
 
 
1186
  return self.encoder(features, feature_lengths)
1187
 
1188
  @property
@@ -1197,8 +1496,30 @@ class GigaAM(nn.Module):
1197
  """
1198
  Prepare an audio file for processing by loading it onto
1199
  the correct device and converting its format.
 
 
 
 
1200
  """
1201
  wav = load_audio(wav_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1202
  wav = wav.to(self._device).to(self._dtype).unsqueeze(0)
1203
  length = torch.full([1], wav.shape[-1], device=self._device)
1204
  return wav, length
@@ -1252,6 +1573,100 @@ class GigaAMASR(GigaAM):
1252
  encoded, encoded_len = self.forward(wav, length)
1253
  return self.decoding.decode(self.head, encoded, encoded_len)[0]
1254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1255
  def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
1256
  """
1257
  Encoder-decoder forward to save model entirely in onnx format.
 
5
  import sys
6
  import warnings
7
  from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from contextlib import contextmanager, nullcontext
10
  from pathlib import Path
11
  from subprocess import CalledProcessError, run
12
  from typing import Any, Dict, List, Optional, Tuple, Union
 
37
  ### preprocess ###
38
 
39
 
40
+ # --- Debug/robustness toggles (env vars, no config changes required) ---
41
+ # Set GIGAAM_DEBUG=1 to enable warnings and per-utterance stats printing
42
+ # Set GIGAAM_FORCE_FP32=1 to disable autocast and run encoder in fp32
43
+ # Set GIGAAM_PAD_START_MS / GIGAAM_PAD_END_MS to pad waveform with silence (milliseconds)
44
+ # Set GIGAAM_MELS_PAD_MODE to override torchaudio MelSpectrogram pad_mode (e.g. "constant" or "reflect")
45
+ # Set GIGAAM_MELS_CENTER to override center (0/1) for MelSpectrogram
46
+ def _env_flag(name: str, default: bool = False) -> bool:
47
+ v = os.environ.get(name, None)
48
+ if v is None:
49
+ return default
50
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
51
+
52
+ def _env_int(name: str, default: int = 0) -> int:
53
+ v = os.environ.get(name, None)
54
+ if v is None or v == "":
55
+ return default
56
+ try:
57
+ return int(float(v))
58
+ except Exception:
59
+ return default
60
+
61
+ def _env_str(name: str, default: str = "") -> str:
62
+ v = os.environ.get(name, None)
63
+ if v is None or v == "":
64
+ return default
65
+ return str(v)
66
+
67
+ def _env_opt_bool(name: str):
68
+ v = os.environ.get(name, None)
69
+ if v is None or v == "":
70
+ return None
71
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
72
+
73
+ def _print_once(msg: str):
74
+ # avoid spamming in batched scenarios
75
+ key = "_GIGAAM_PRINTED"
76
+ printed = globals().setdefault(key, set())
77
+ if msg not in printed:
78
+ print(msg)
79
+ printed.add(msg)
80
+
81
def audio_stats(wav: Tensor, sr: int = SAMPLE_RATE) -> Dict[str, Any]:
    """Compute quick health/quality statistics for a mono waveform.

    Args:
        wav: waveform tensor, expected roughly in [-1, 1]; any shape is
            flattened and cast to float32 before analysis.
        sr: sample rate used to convert sample counts to seconds.

    Returns:
        JSON-serializable dict with duration, level (mean/rms/peak), a
        clipping heuristic, rough leading/trailing-silence estimates, and
        NaN/Inf flags. Empty input yields ``{"samples": 0, "seconds": 0.0}``.
    """
    # wav: 1D float tensor in [-1, 1] (best-effort)
    if wav.numel() == 0:
        return {"samples": 0, "seconds": 0.0}
    x = wav.detach()
    x = x.float().view(-1)
    mean = x.mean().item()
    # RMS is taken on the DC-removed signal so a constant offset
    # doesn't inflate the level estimate.
    x0 = x - mean
    rms = torch.sqrt(torch.mean(x0 * x0)).item()
    peak = torch.max(torch.abs(x)).item()
    # "clipping" heuristic for int16-style inputs: near full-scale
    clip_frac = (torch.abs(x) >= 0.999).float().mean().item()
    # leading/trailing silence (rough): threshold at -45 dBFS ~= 0.0056
    thr = 10 ** (-45 / 20)
    above = (torch.abs(x) > thr).nonzero(as_tuple=False).view(-1)
    if above.numel() == 0:
        # Fully silent signal: report the whole duration on both ends.
        lead_s = x.numel() / sr
        trail_s = x.numel() / sr
    else:
        lead_s = (above[0].item() / sr)
        trail_s = ((x.numel() - 1 - above[-1].item()) / sr)
    return {
        "samples": int(x.numel()),
        "seconds": float(x.numel() / sr),
        "dtype": str(wav.dtype),  # original dtype, before the float32 cast
        "mean": float(mean),
        "rms": float(rms),
        "peak": float(peak),
        "clip_frac": float(clip_frac),
        "lead_silence_s": float(lead_s),
        "trail_silence_s": float(trail_s),
        "nan": bool(torch.isnan(x).any().item()),
        "inf": bool(torch.isinf(x).any().item()),
    }
115
+
116
def pad_wav(wav: Tensor, sr: int, pad_start_ms: int = 0, pad_end_ms: int = 0) -> Tensor:
    """Pad a waveform with leading and/or trailing silence.

    Args:
        wav: waveform tensor (flattened to 1D in the padded result).
        sr: sample rate, used to convert milliseconds to sample counts.
        pad_start_ms: silence to prepend, in milliseconds.
        pad_end_ms: silence to append, in milliseconds.

    Returns:
        The original tensor unchanged when no padding is requested (or a
        negative amount sneaks through), otherwise a new 1D tensor with
        zeros on either side, matching the input's dtype and device.
    """
    if pad_start_ms <= 0 and pad_end_ms <= 0:
        return wav
    n_start = int(sr * pad_start_ms / 1000.0)
    n_end = int(sr * pad_end_ms / 1000.0)
    if n_start < 0 or n_end < 0:
        # Mixed negative/positive requests are treated as a no-op.
        return wav
    lead = wav.new_zeros(n_start)
    tail = wav.new_zeros(n_end)
    return torch.cat([lead, wav.view(-1), tail], dim=0)
128
+
129
def print_env_versions():
    """Log torch/torchaudio/transformers versions once, for debug reports."""
    try:
        import transformers as _transformers
        tfv = getattr(_transformers, "__version__", "unknown")
    except Exception:
        # transformers is optional at this point; fall back gracefully.
        tfv = "unknown"
    _print_once(f"[GigaAM debug] torch={torch.__version__} torchaudio={torchaudio.__version__} transformers={tfv}")
136
+
137
+
138
+
139
@contextmanager
def temp_environ(**updates: str):
    """Temporarily set os.environ keys for the duration of a context.

    A value of None removes the key; any other value is stringified and
    assigned. The previous state (including absence) is restored on exit,
    even when the body raises.
    """

    def _apply(key, value):
        # Shared setter: None means "ensure absent".
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = str(value)

    saved = {key: os.environ.get(key) for key in updates}
    try:
        for key, value in updates.items():
            _apply(key, value)
        yield
    finally:
        for key, previous in saved.items():
            _apply(key, previous)
157
+
158
@contextmanager
def temporary_module_dtype(module: nn.Module, dtype: torch.dtype):
    """Temporarily cast *module* to *dtype*; restores the original on exit.

    Parameter-less modules have no dtype to infer, so they are treated as
    already matching and the context becomes a no-op. Restoration runs in
    a finally block, so it happens even if the body raises.
    """
    first_param = next(iter(module.parameters()), None)
    original = dtype if first_param is None else first_param.dtype
    if original == dtype:
        yield
        return
    module.to(dtype)
    try:
        yield
    finally:
        module.to(original)
174
@dataclass
class DecodeDebug:
    """Decoded text for one utterance plus optional decoding diagnostics."""

    # Final decoded transcript (may be empty).
    text: str
    # Free-form diagnostic summary (blank ratios, probe frames, etc.);
    # empty dict when diagnostics were not collected.
    stats: Dict[str, Any]
178
  def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
179
  """
180
  Load an audio file and resample it to the specified sample rate.
 
229
  self.win_length = kwargs.get("win_length", sample_rate // 40)
230
  self.n_fft = kwargs.get("n_fft", sample_rate // 40)
231
  self.center = kwargs.get("center", True)
232
+ env_center = _env_opt_bool("GIGAAM_MELS_CENTER")
233
+ if env_center is not None:
234
+ self.center = bool(env_center)
235
+ self.pad_mode = kwargs.get("pad_mode", "reflect")
236
+ env_pad_mode = _env_str("GIGAAM_MELS_PAD_MODE", "")
237
+ if env_pad_mode:
238
+ self.pad_mode = env_pad_mode
239
  self.featurizer = nn.Sequential(
240
  torchaudio.transforms.MelSpectrogram(
241
  sample_rate=sample_rate,
 
244
  hop_length=self.hop_length,
245
  n_fft=self.n_fft,
246
  center=self.center,
247
+ pad_mode=self.pad_mode,
248
  ),
249
  SpecScaler(),
250
  )
251
 
252
    def set_mels_padding(self, *, center: Optional[bool] = None, pad_mode: Optional[str] = None) -> None:
        """Hot-swap MelSpectrogram padding behavior for debugging.

        Args:
            center: new STFT centering flag; None leaves it unchanged.
            pad_mode: new padding mode (e.g. "constant" or "reflect");
                None leaves it unchanged.

        Best-effort: the underlying torchaudio transform is only poked
        when it exposes the attribute — NOTE(review): the hasattr probing
        presumably guards against torchaudio versions that keep pad_mode
        on the nested Spectrogram instead; confirm against the installed
        torchaudio version.
        """
        if center is not None:
            self.center = bool(center)
            # try to update the underlying transform if possible
            m = self.featurizer[0]
            if hasattr(m, "center"):
                m.center = self.center  # type: ignore[attr-defined]
        if pad_mode is not None:
            self.pad_mode = str(pad_mode)
            m = self.featurizer[0]
            if hasattr(m, "pad_mode"):
                m.pad_mode = self.pad_mode  # type: ignore[attr-defined]
            # fall back to the nested Spectrogram transform, if present
            elif hasattr(m, "spectrogram") and hasattr(m.spectrogram, "pad_mode"):
                m.spectrogram.pad_mode = self.pad_mode  # type: ignore[attr-defined]
267
+
268
  def out_len(self, input_lengths: Tensor) -> Tensor:
269
  """
270
  Calculates the output length after the feature extraction process.
 
1271
  return pred_texts
1272
 
1273
 
1274
    @torch.inference_mode()
    def decode_with_debug(
        self, head: CTCHead, encoded: Tensor, lengths: Tensor, topk: int = 5
    ) -> Tuple[List[str], List[DecodeDebug]]:
        """Like decode(), but also returns per-utterance blank/argmax diagnostics.

        Args:
            head: CTC head producing per-frame log-probs over vocab + blank.
            encoded: encoder output batch.
            lengths: valid encoder frame counts per utterance.
            topk: how many top classes to record at each probe frame.

        Returns:
            (pred_texts, debugs): the same texts decode() would produce,
            plus one DecodeDebug per utterance with blank-ratio and
            probe-frame statistics.

        NOTE(review): log-probs are computed here AND inside decode(), so
        the head runs twice — acceptable for a debug path.
        """
        log_probs = head(encoder_output=encoded)
        labels = log_probs.argmax(dim=-1, keepdim=False)
        b, t, c = log_probs.shape

        # Reuse the production decoding path so texts match exactly.
        pred_texts = self.decode(head, encoded, lengths)

        debugs: List[DecodeDebug] = []
        for i in range(b):
            # Clamp the reported length into [0, t] to stay in-bounds.
            L = int(lengths[i].item())
            L = max(0, min(L, t))
            if L == 0:
                debugs.append(DecodeDebug(text=pred_texts[i], stats={"enc_len": 0}))
                continue
            lab = labels[i, :L]
            blank = (lab == self.blank_id)
            blank_ratio = float(blank.float().mean().item())
            # first frame where argmax != blank
            nonblank_idx = (~blank).nonzero(as_tuple=False).view(-1)
            first_nonblank = int(nonblank_idx[0].item()) if nonblank_idx.numel() else None
            # top-k distribution at a few frames (start/mid/end) for quick inspection
            probe_frames = sorted(set([0, L // 2, max(0, L - 1)]))
            probes: Dict[str, Any] = {}
            for pf in probe_frames:
                vals, idxs = torch.topk(log_probs[i, pf, :], k=min(topk, c), dim=-1)
                probes[str(pf)] = {
                    "topk_ids": idxs.detach().cpu().tolist(),
                    "topk_logp": [float(v) for v in vals.detach().cpu().tolist()],
                    "blank_logp": float(log_probs[i, pf, self.blank_id].item()),
                }
            debugs.append(
                DecodeDebug(
                    text=pred_texts[i],
                    stats={
                        "enc_len": L,
                        "blank_ratio_argmax": blank_ratio,
                        "first_nonblank_frame": first_nonblank,
                        "probe_frames": probes,
                    },
                )
            )
        return pred_texts, debugs
1320
+
1321
+
1322
  class RNNTGreedyDecoding:
1323
  def __init__(
1324
  self,
 
1333
  self.blank_id = len(self.tokenizer)
1334
  self.max_symbols = max_symbols_per_step
1335
 
1336
    def _greedy_decode_impl(
        self,
        head: RNNTHead,
        x: Tensor,
        seqlen: Tensor,
        collect_stats: bool = False,
        topk: int = 5,
    ) -> DecodeDebug:
        """Greedy RNNT decode for a single sequence, with optional blank diagnostics.

        Args:
            head: RNNT head with `decoder.predict` and `joint.joint`.
            x: encoder output for one utterance, time-major
               (indexed as x[t, :, :] — assumes (time, 1, dim); TODO confirm).
            seqlen: number of valid encoder frames (0-dim int tensor or int).
            collect_stats: when True, also gather blank-margin and
                probe-frame summaries (adds per-step CPU syncs).
            topk: how many top classes to record at each probe frame.

        Returns:
            DecodeDebug with the decoded text; `stats` is empty unless
            collect_stats is True.
        """
        hyp: List[int] = []
        dec_state: Optional[Tensor] = None
        last_label: Optional[Tensor] = None

        # Diagnostics (kept lightweight unless collect_stats=True)
        total_joint_steps = 0
        blank_steps = 0
        emitted_steps = 0
        first_emit_frame: Optional[int] = None
        blank_margins: List[float] = []
        probe_frames: Dict[str, Any] = {}

        for t in range(seqlen):
            f = x[t, :, :].unsqueeze(1)
            not_blank = True
            new_symbols = 0
            # Emit symbols for this frame until blank or the per-frame cap.
            while not_blank and new_symbols < self.max_symbols:
                g, hidden = head.decoder.predict(last_label, dec_state)
                logp = head.joint.joint(f, g)[0, 0, 0, :]  # log-probs over vocab+blank
                total_joint_steps += 1

                k = int(logp.argmax(0).item())
                if collect_stats:
                    # how strongly blank beats the best non-blank
                    blank_lp = float(logp[self.blank_id].item())
                    best_nonblank_lp = float(logp[: self.blank_id].max().item())
                    blank_margins.append(blank_lp - best_nonblank_lp)
                    # Probe only start / middle / last frame, once each.
                    if t in (0, int(seqlen) // 2, max(0, int(seqlen) - 1)) and str(t) not in probe_frames:
                        vals, idxs = torch.topk(logp, k=min(topk, logp.numel()))
                        probe_frames[str(int(t))] = {
                            "topk_ids": idxs.detach().cpu().tolist(),
                            "topk_logp": [float(v) for v in vals.detach().cpu().tolist()],
                            "blank_logp": blank_lp,
                        }

                if k == self.blank_id:
                    blank_steps += 1
                    not_blank = False
                else:
                    emitted_steps += 1
                    if first_emit_frame is None:
                        first_emit_frame = int(t)
                    hyp.append(int(k))
                    # Only advance decoder state on a real emission.
                    dec_state = hidden
                    last_label = torch.tensor([[hyp[-1]]], device=x.device)
                    new_symbols += 1

        text = self.tokenizer.decode(hyp)

        stats: Dict[str, Any] = {}
        if collect_stats:
            # Summaries only (avoid huge blobs)
            if blank_margins:
                bm = torch.tensor(blank_margins)
                stats["blank_margin_mean"] = float(bm.mean().item())
                stats["blank_margin_p50"] = float(bm.median().item())
                stats["blank_margin_p90"] = float(torch.quantile(bm, 0.9).item())
            stats.update(
                {
                    "enc_len": int(seqlen),
                    "total_joint_steps": int(total_joint_steps),
                    "blank_steps": int(blank_steps),
                    "emitted_steps": int(emitted_steps),
                    "blank_step_frac": float(blank_steps / max(1, total_joint_steps)),
                    "first_emit_frame": first_emit_frame,
                    "probe_frames": probe_frames,
                }
            )
        return DecodeDebug(text=text, stats=stats)
1414
+
1415
+ def _greedy_decode(self, head: RNNTHead, x: Tensor, seqlen: Tensor) -> str:
1416
+ """Backward-compatible greedy decode (no stats)."""
1417
+ return self._greedy_decode_impl(head, x, seqlen, collect_stats=False).text
1418
 
1419
  @torch.inference_mode()
1420
  def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[str]:
 
1430
  return pred_texts
1431
 
1432
 
1433
+ @torch.inference_mode()
1434
+ def decode_with_debug(
1435
+ self, head: RNNTHead, encoded: Tensor, enc_len: Tensor, topk: int = 5
1436
+ ) -> Tuple[List[str], List[DecodeDebug]]:
1437
+ """Like decode(), but also returns per-utterance blank diagnostics."""
1438
+ b = encoded.shape[0]
1439
+ encoded_t = encoded.transpose(1, 2)
1440
+ texts: List[str] = []
1441
+ debugs: List[DecodeDebug] = []
1442
+ for i in range(b):
1443
+ inseq = encoded_t[i, :, :].unsqueeze(1)
1444
+ dbg = self._greedy_decode_impl(head, inseq, enc_len[i], collect_stats=True, topk=topk)
1445
+ texts.append(dbg.text)
1446
+ debugs.append(dbg)
1447
+ return texts, debugs
1448
+
1449
+
1450
  ### models ###
1451
 
1452
 
 
1468
  Perform forward pass through the preprocessor and encoder.
1469
  """
1470
  features, feature_lengths = self.preprocessor(features, feature_lengths)
1471
+
1472
+ if _env_flag("GIGAAM_DEBUG", False):
1473
+ print_env_versions()
1474
+
1475
+ # CPU: no autocast
1476
  if self._device.type == "cpu":
1477
  return self.encoder(features, feature_lengths)
1478
+
1479
+ # GPU: optionally disable autocast to debug fp16-boundary failures
1480
+ force_fp32 = _env_flag("GIGAAM_FORCE_FP32", False)
1481
+ if force_fp32:
1482
+ features = features.float()
1483
+
1484
+ with torch.autocast(device_type=self._device.type, dtype=torch.float16, enabled=not force_fp32):
1485
  return self.encoder(features, feature_lengths)
1486
 
1487
  @property
 
1496
  """
1497
  Prepare an audio file for processing by loading it onto
1498
  the correct device and converting its format.
1499
+
1500
+ Debug/robustness (env vars):
1501
+ - GIGAAM_DEBUG=1 prints waveform stats
1502
+ - GIGAAM_PAD_START_MS / GIGAAM_PAD_END_MS pad silence (milliseconds)
1503
  """
1504
  wav = load_audio(wav_file)
1505
+
1506
+ # Optional padding to reduce edge effects from STFT centering/padding
1507
+ pad_start_ms = _env_int("GIGAAM_PAD_START_MS", 0)
1508
+ pad_end_ms = _env_int("GIGAAM_PAD_END_MS", 0)
1509
+ if pad_start_ms or pad_end_ms:
1510
+ wav = pad_wav(wav, SAMPLE_RATE, pad_start_ms=pad_start_ms, pad_end_ms=pad_end_ms)
1511
+
1512
+ if _env_flag("GIGAAM_DEBUG", False):
1513
+ st = audio_stats(wav, SAMPLE_RATE)
1514
+ # Very rough "this might be off-distribution" checks
1515
+ if abs(st.get("mean", 0.0)) > 1e-3:
1516
+ print(f"[GigaAM debug] WARNING: DC-ish mean={st['mean']:.4g} for {wav_file}")
1517
+ if st.get("clip_frac", 0.0) > 0.001:
1518
+ print(f"[GigaAM debug] WARNING: possible clipping frac={st['clip_frac']:.4g} for {wav_file}")
1519
+ if st.get("nan") or st.get("inf"):
1520
+ print(f"[GigaAM debug] ERROR: NaN/Inf in waveform for {wav_file}")
1521
+ print(f"[GigaAM debug] wav stats for {wav_file}: {json.dumps(st, ensure_ascii=False)}")
1522
+
1523
  wav = wav.to(self._device).to(self._dtype).unsqueeze(0)
1524
  length = torch.full([1], wav.shape[-1], device=self._device)
1525
  return wav, length
 
1573
  encoded, encoded_len = self.forward(wav, length)
1574
  return self.decoding.decode(self.head, encoded, encoded_len)[0]
1575
 
1576
    @torch.inference_mode()
    def transcribe_debug(
        self,
        wav_file: str,
        *,
        topk: int = 5,
        try_fixes: bool = True,
        pad_ms: int = 500,
    ) -> Dict[str, Any]:
        """Run transcription plus diagnostics. If empty, optionally try common fixes.

        Attempts, in order: baseline, fp32 (autocast disabled), silence
        padding at both ends, constant mel pad_mode + padding. Stops at the
        first non-empty transcription (or after baseline when try_fixes is
        False).

        Args:
            wav_file: path to the audio file.
            topk: top-k classes recorded at decoder probe frames.
            try_fixes: whether to retry with fixes when baseline is empty.
            pad_ms: silence (ms) added to each end for the padding attempts.

        Returns a JSON-serializable dict with:
        - attempts: list of {strategy, text, decode_stats}
        - final_text: text of the last attempt executed
        """
        report: Dict[str, Any] = {
            "wav_file": wav_file,
            "torch": torch.__version__,
            "torchaudio": torchaudio.__version__,
            "attempts": [],
        }

        # Snapshot preprocessor padding settings so attempts can restore them.
        pre = self.preprocessor
        orig_center = getattr(pre, "center", None)
        orig_pad_mode = getattr(pre, "pad_mode", None)

        def _run(strategy: str, *, force_fp32: bool = False, pad_start_ms: int = 0, pad_end_ms: int = 0, mels_pad_mode: Optional[str] = None):
            # Apply per-attempt toggles via env (forward()/prepare_wav() read these)
            env = {
                "GIGAAM_DEBUG": "1",
                "GIGAAM_FORCE_FP32": "1" if force_fp32 else None,
                "GIGAAM_PAD_START_MS": str(pad_start_ms) if pad_start_ms else None,
                "GIGAAM_PAD_END_MS": str(pad_end_ms) if pad_end_ms else None,
            }
            with temp_environ(**env):
                # Hot-swap mel padding mode if requested
                if mels_pad_mode is not None and hasattr(pre, "set_mels_padding"):
                    pre.set_mels_padding(pad_mode=mels_pad_mode)

                # fp32 attempts also cast the whole model, not just inputs.
                dtype_ctx = temporary_module_dtype(self, torch.float32) if force_fp32 else nullcontext()
                with dtype_ctx:
                    wav, length = self.prepare_wav(wav_file)
                    if length.item() > LONGFORM_THRESHOLD:
                        raise ValueError("Too long wav file, use 'transcribe_longform' method.")
                    encoded, encoded_len = self.forward(wav, length)

                    # Prefer the diagnostics-enabled decode when available.
                    if hasattr(self.decoding, "decode_with_debug"):
                        texts, debugs = self.decoding.decode_with_debug(self.head, encoded, encoded_len, topk=topk)  # type: ignore[attr-defined]
                        text = texts[0]
                        dec_stats = debugs[0].stats
                    else:
                        text = self.decoding.decode(self.head, encoded, encoded_len)[0]
                        dec_stats = {}

                # Restore mel settings after attempt
                if hasattr(pre, "set_mels_padding"):
                    pre.set_mels_padding(center=orig_center if isinstance(orig_center, bool) else None, pad_mode=orig_pad_mode if isinstance(orig_pad_mode, str) else None)

            report["attempts"].append(
                {"strategy": strategy, "text": text, "decode_stats": dec_stats}
            )
            return text

        # Attempt 0: baseline
        txt = _run("baseline")
        if txt != "" or not try_fixes:
            report["final_text"] = txt
            return report

        # Fix 1: rerun with fp32 (disable autocast)
        txt = _run("force_fp32", force_fp32=True)
        if txt != "":
            report["final_text"] = txt
            return report

        # Fix 2: pad both ends (helps with STFT centering + reflect padding edge artifacts)
        txt = _run("pad_silence_both_ends", pad_start_ms=pad_ms, pad_end_ms=pad_ms)
        if txt != "":
            report["final_text"] = txt
            return report

        # Fix 3: stop reflect padding in the spectrogram (pad_mode=constant) + pad both ends
        txt = _run("mels_pad_mode_constant_plus_pad", pad_start_ms=pad_ms, pad_end_ms=pad_ms, mels_pad_mode="constant")
        report["final_text"] = txt
        return report
1660
+
1661
+ @torch.inference_mode()
1662
+ def transcribe_resilient(self, wav_file: str, **kwargs) -> str:
1663
+ """Convenience wrapper: return non-empty transcription if any fix works."""
1664
+ rep = self.transcribe_debug(wav_file, **kwargs)
1665
+ for att in rep.get("attempts", []):
1666
+ if att.get("text", "") != "":
1667
+ return att["text"]
1668
+ return rep.get("final_text", "")
1669
+
1670
  def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
1671
  """
1672
  Encoder-decoder forward to save model entirely in onnx format.