Vladislav Gadzhikhanov committed
Commit 74e40ab · 1 Parent(s): ec1dc1f

Added BinaryIO input for transcribe_longform
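In practice this change means `transcribe_longform` can accept any object implementing the binary-stream protocol (an open file, `io.BytesIO`, an upload handle), not only a path. A minimal, torch-free sketch of the rewind-then-read pattern the new code relies on (`read_stream` is a hypothetical helper for illustration, not part of the commit):

```python
import io

def read_stream(file) -> bytes:
    """Rewind a binary stream if it is seekable, then read it fully.

    Mirrors the hasattr(file, 'seek') guard added in load_audio_binary,
    which lets both seekable files and one-shot pipes be passed in.
    """
    if hasattr(file, "seek"):
        file.seek(0)
    return file.read()

# A stream that was already partially consumed is still read in full.
buf = io.BytesIO(b"RIFF....WAVEdata")
buf.read(4)  # simulate a prior consumer, e.g. format sniffing
assert read_stream(buf) == b"RIFF....WAVEdata"
```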

Files changed (2)
  1. README.md +0 -65
  2. modeling_gigaam.py +38 -7
README.md CHANGED
@@ -1,65 +0,0 @@
----
-license: mit
-language:
-- ru
-- en
-pipeline_tag: automatic-speech-recognition
----
-
-# GigaAM-v3
-
-GigaAM-v3 is a Conformer-based foundation model with 220–240M parameters, pretrained on diverse Russian speech data using the HuBERT-CTC objective.
-It is the third generation of the GigaAM family and provides state-of-the-art performance on Russian ASR across a wide range of domains.
-
-GigaAM-v3 includes the following model variants:
-- `ssl` — self-supervised HuBERT-CTC encoder pre-trained on 700,000 hours of Russian speech
-- `ctc` — ASR model fine-tuned with a CTC decoder
-- `rnnt` — ASR model fine-tuned with an RNN-T decoder
-- `e2e_ctc` — end-to-end CTC model with punctuation and text normalization
-- `e2e_rnnt` — end-to-end RNN-T model with punctuation and text normalization
-
-`GigaAM-v3` training incorporates new internal datasets: callcenter conversations, speech with background music, natural speech, and speech with atypical characteristics.
-The models perform on average **30%** better on these new domains, while maintaining the same quality as previous GigaAM generations on public benchmarks.
-
-The table below reports the Word Error Rate (%) for `GigaAM-v3` and other existing models over diverse domains.
-
-| Set Name          | V3_CTC | V3_RNNT | T-One + LM | Whisper |
-|:------------------|-------:|--------:|-----------:|--------:|
-| Open Datasets     |    3.0 |     2.6 |        5.7 |    12.0 |
-| Golos Farfield    |    4.5 |     3.9 |       12.2 |    16.7 |
-| Natural Speech    |    7.8 |     6.9 |       14.5 |    13.6 |
-| Disordered Speech |   20.6 |    19.2 |       51.0 |    59.3 |
-| Callcenter        |   10.3 |     9.5 |       13.5 |    23.9 |
-| **Average**       |**9.2** | **8.4** |       19.4 |    25.1 |
-
-The end-to-end ASR models (`e2e_ctc` and `e2e_rnnt`) produce punctuated, normalized text directly.
-In end-to-end ASR comparisons of `e2e_ctc` and `e2e_rnnt` against Whisper-large-v3, using Gemini 2.5 Pro as an LLM-as-a-judge, GigaAM-v3 models win by an average margin of **70:30**.
-
-For detailed results, see [metrics](https://github.com/salute-developers/GigaAM/blob/main/evaluation.md).
-
-## Usage
-```python
-from transformers import AutoModel
-
-revision = "e2e_rnnt"  # can be any v3 model: ssl, ctc, rnnt, e2e_ctc, e2e_rnnt
-model = AutoModel.from_pretrained(
-    "ai-sage/GigaAM-v3",
-    revision=revision,
-    trust_remote_code=True,
-)
-
-transcription = model.transcribe("example.wav")
-print(transcription)
-```
-
-Recommended versions:
-- `torch==2.8.0`, `torchaudio==2.8.0`
-- `transformers==4.57.1`
-- `pyannote-audio==4.0.0`, `torchcodec==0.7.0`
-- (any) `hydra-core`, `omegaconf`, `sentencepiece`
-
-Full usage guide can be found in the [example](https://github.com/salute-developers/GigaAM/blob/main/colab_example.ipynb).
-
-**License:** MIT
-
-**Paper:** [GigaAM: Efficient Self-Supervised Learner for Speech Recognition (InterSpeech 2025)](https://arxiv.org/abs/2506.01192)
modeling_gigaam.py CHANGED
@@ -21,6 +21,7 @@ from torch import Tensor, nn
 from torch.jit import TracerWarning
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import cached_file
+from typing import BinaryIO

 DIR_NAME = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(DIR_NAME)  # enable using modules through modeling_gigaam.<module_name>
@@ -66,6 +67,35 @@ def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
     return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0


+def load_audio_binary(file: BinaryIO, sample_rate: int = SAMPLE_RATE) -> Tensor:
+    """
+    Load audio from a binary stream using an ffmpeg pipe.
+    Note: requires ffmpeg compiled with proper stdin support.
+    """
+    cmd = [
+        "ffmpeg",
+        "-i", "pipe:0",  # read from stdin
+        "-f", "s16le",
+        "-ac", "1",
+        "-acodec", "pcm_s16le",
+        "-ar", str(sample_rate),
+        "pipe:1",  # write to stdout
+    ]
+
+    if hasattr(file, 'seek'):
+        file.seek(0)
+
+    try:
+        result = run(cmd, input=file.read(), capture_output=True, check=True)
+        audio_bytes = result.stdout
+    except CalledProcessError as exc:
+        raise RuntimeError(f"FFmpeg failed: {exc.stderr.decode()}") from exc
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        return torch.frombuffer(audio_bytes, dtype=torch.int16).float() / 32768.0
+
+
 class SpecScaler(nn.Module):
     """
     Module that applies logarithmic scaling to spectrogram values.
@@ -296,7 +326,7 @@ def get_pipeline(device: torch.device):


 def segment_audio_file(
-    wav_file: str,
+    file: BinaryIO,
     sr: int,
     max_duration: float = 22.0,
     min_duration: float = 15.0,
@@ -309,9 +339,10 @@ def segment_audio_file(
     The segmentation is performed using a PyAnnote voice activity detection pipeline.
     """

-    audio = load_audio(wav_file)
+    audio = load_audio_binary(file)
     pipeline = get_pipeline(device)
-    sad_segments = pipeline(wav_file)
+    if hasattr(file, 'seek'): file.seek(0)
+    sad_segments = pipeline(file)

     segments: List[torch.Tensor] = []
     curr_duration = 0.0
@@ -1296,7 +1327,7 @@ class GigaAMASR(GigaAM):

     @torch.inference_mode()
     def transcribe_longform(
-        self, wav_file: str, **kwargs
+        self, file: BinaryIO, **kwargs
     ) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
         """
         Transcribes a long audio file by splitting it into segments and
@@ -1304,7 +1335,7 @@ class GigaAMASR(GigaAM):
         """
         transcribed_segments = []
         segments, boundaries = segment_audio_file(
-            wav_file, SAMPLE_RATE, device=self._device, **kwargs
+            file, SAMPLE_RATE, device=self._device, **kwargs
         )
         for segment, segment_boundaries in zip(segments, boundaries):
             wav = segment.to(self._device).unsqueeze(0).to(self._dtype)
@@ -1411,8 +1442,8 @@ class GigaAMModel(PreTrainedModel):
     def transcribe(self, wav_file: str) -> str:
         return self.model.transcribe(wav_file)

-    def transcribe_longform(self, wav_file: str) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
-        return self.model.transcribe_longform(wav_file)
+    def transcribe_longform(self, file: BinaryIO) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
+        return self.model.transcribe_longform(file)

     def get_probs(self, wav_file: str) -> Dict[str, float]:
         return self.model.get_probs(wav_file)
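Both `load_audio` and the new `load_audio_binary` finish by reinterpreting ffmpeg's `s16le` output as floats in [-1.0, 1.0). A torch-free sketch of that normalization step (`pcm16_to_float` is illustrative and not part of the commit; it matches `torch.frombuffer(audio_bytes, dtype=torch.int16).float() / 32768.0`):

```python
import struct

def pcm16_to_float(audio_bytes: bytes) -> list:
    """Decode little-endian signed 16-bit PCM into floats in [-1.0, 1.0)."""
    n = len(audio_bytes) // 2
    samples = struct.unpack(f"<{n}h", audio_bytes[: n * 2])
    return [s / 32768.0 for s in samples]

# Full-scale negative, silence, and half-scale positive samples.
raw = struct.pack("<3h", -32768, 0, 16384)
assert pcm16_to_float(raw) == [-1.0, 0.0, 0.5]
```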