mazesmazes committed
Commit 2457a76 · verified · Parent: 63322b6

Training in progress - step 13000

Files changed (3):
  1. asr_config.py +0 -1
  2. asr_modeling.py +160 -6
  3. asr_pipeline.py +116 -3
asr_config.py CHANGED
@@ -46,7 +46,6 @@ class ASRConfig(transformers.PretrainedConfig):
         "min_new_tokens": 1,
         "do_sample": False,
         "repetition_penalty": 1.05,
-        "length_penalty": 1.0,
         "no_repeat_ngram_size": 0,
         "use_cache": True,
     }
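Reviewer note (context, not part of the commit): length_penalty is only consulted during beam search, so with the greedy defaults here (do_sample=False, num_beams left at 1) dropping it should be behavior-preserving. A minimal sketch, assuming only that transformers is installed:

# Sketch: length_penalty is a no-op under greedy decoding (num_beams=1).
from transformers import GenerationConfig

greedy = GenerationConfig(
    do_sample=False,           # greedy decoding
    repetition_penalty=1.05,
    no_repeat_ngram_size=0,
    use_cache=True,
)
print(greedy.num_beams)        # 1 -> beam-score rescaling never runs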
asr_modeling.py CHANGED
@@ -1,5 +1,8 @@
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, Generator, NamedTuple
+
+import threading
+from concurrent import futures
 
 import torch
 import torch.nn as nn
@@ -11,6 +14,7 @@ from transformers import (
     AutoTokenizer,
     PreTrainedModel,
     Wav2Vec2FeatureExtractor,
+    TextIteratorStreamer,
 )
 from transformers.generation.utils import (
     GenerateBeamDecoderOnlyOutput,
@@ -25,6 +29,17 @@ except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]
 
 
+class StreamChunk(NamedTuple):
+    """A chunk of streaming transcription text."""
+    text: str
+
+
+class StreamStats(NamedTuple):
+    """Statistics about the streaming inference."""
+    input_tokens: int
+    output_tokens: int
+
+
 class SwiGLU(nn.Module):
     def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
         super().__init__()
@@ -118,8 +133,12 @@ class ASRModel(PreTrainedModel):
             return WhisperFeatureExtractor.from_pretrained(
                 audio_model_id,
                 feature_size=num_mel_bins,
+                do_normalize=True,
             )
-        return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
+        return Wav2Vec2FeatureExtractor.from_pretrained(
+            audio_model_id,
+            do_normalize=True,
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
@@ -206,10 +225,6 @@ class ASRModel(PreTrainedModel):
         self.decoder = self._create_decoder(config)
         self.generation_config = self.decoder.generation_config
 
-        # Set default generation parameters
-        self.generation_config.num_beams = 1
-        self.generation_config.length_penalty = 1.0
-
         self._init_tokenizer()
 
         from types import SimpleNamespace
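Aside on the do_normalize change above (context, not from the commit): for Wav2Vec2FeatureExtractor, do_normalize=True standardizes each raw waveform to zero mean and unit variance before it reaches the encoder; WhisperFeatureExtractor has an analogous flag for its log-mel features. A quick sketch to check the assumption:

# Sketch: do_normalize standardizes the waveform (mean ~0, std ~1).
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

fe = Wav2Vec2FeatureExtractor(do_normalize=True)
wav = np.random.randn(16000) * 0.1 + 0.5          # offset, low-amplitude audio
batch = fe([wav], sampling_rate=16000, return_tensors="np")
print(batch.input_values.mean(), batch.input_values.std())  # ~0.0, ~1.0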
@@ -691,6 +706,145 @@ class ASRModel(PreTrainedModel):
 
         return generated_ids[:, prompt_length:]
 
+    @torch.no_grad()
+    def generate_stream(
+        self,
+        input_values: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        system_prompt: Optional[str] = None,
+        user_prompt: Optional[str] = None,
+        task: Optional[str] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        **generate_kwargs,
+    ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
+        """
+        Generate transcription in streaming mode, yielding text chunks as they're generated.
+
+        Args:
+            input_values: Audio input tensor for non-Whisper models
+            input_features: Audio input tensor for Whisper models
+            system_prompt: System prompt override
+            user_prompt: User prompt override
+            task: Task type (transcribe, describe, emotion, continue)
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            **generate_kwargs: Additional generation parameters
+
+        Yields:
+            StreamChunk: Text chunks as they're generated
+            StreamStats: Final statistics (input_tokens, output_tokens)
+        """
+        audio_inputs = input_values if input_values is not None else input_features
+        if audio_inputs is None:
+            raise ValueError("input_values or input_features must be provided for generation")
+
+        # Encode audio once and prepare prompt
+        audio_embeds = self._encode_audio(audio_inputs)
+        batch_size = audio_embeds.shape[0]
+        device = audio_embeds.device
+
+        if batch_size > 1:
+            raise ValueError("Streaming generation only supports batch_size=1")
+
+        if system_prompt is None:
+            system_prompt = self.system_prompt
+
+        if user_prompt is None:
+            user_prompt = (
+                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
+                or "Transcribe: <audio>"
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+            enable_thinking=False,
+        ).to(device)
+
+        if len(prompt_ids.shape) == 1:
+            prompt_ids = prompt_ids.unsqueeze(0)
+
+        if not (prompt_ids == self.audio_token_id).any():
+            raise ValueError("Audio token <audio> not found in prompt")
+
+        num_audio_tokens = audio_embeds.shape[1]
+        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
+        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
+        input_token_count = expanded_prompt_ids.shape[1]
+
+        attention_mask = torch.ones(
+            batch_size, input_token_count, dtype=torch.long, device=device
+        )
+
+        # Set up generation parameters
+        if max_new_tokens is None:
+            max_new_tokens = getattr(self.config, "max_new_tokens", 256)
+
+        generate_kwargs.setdefault("max_new_tokens", max_new_tokens)
+        generate_kwargs.setdefault("use_cache", True)
+        generate_kwargs.setdefault(
+            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        )
+        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
+
+        if temperature is not None:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs.setdefault("do_sample", True)
+
+        # Set up the streamer
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        # Generate in a separate thread
+        def generation_thread(future: futures.Future):
+            try:
+                result = self.decoder.generate(
+                    input_ids=expanded_prompt_ids,
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    streamer=streamer,
+                    **generate_kwargs,
+                )
+                future.set_result(result)
+            except Exception as e:
+                future.set_exception(e)
+
+        future: futures.Future[torch.Tensor] = futures.Future()
+        thread = threading.Thread(target=generation_thread, args=(future,))
+        thread.start()
+
+        # Stream the output
+        output_text = ""
+        output_token_count = 0
+
+        try:
+            for chunk in streamer:
+                if chunk:
+                    output_text += chunk
+                    output_token_count += 1
+                    yield StreamChunk(chunk)
+        finally:
+            # Wait for generation to complete
+            thread.join()
+
+        # Check if there was an exception
+        if future.exception():
+            raise future.exception()
+
+        # Yield final statistics
+        yield StreamStats(input_token_count, output_token_count)
+
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
         from pathlib import Path as PathlibPath
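For reviewers, a hypothetical consumption sketch of the new API (the checkpoint path and feature shape are placeholders, not from this commit). Note that output_tokens counts streamer chunks; TextIteratorStreamer emits text roughly once per decoded token, so the figure is an approximation rather than an exact token count:

# Hypothetical usage of ASRModel.generate_stream (path/shape assumed):
import torch
from asr_modeling import ASRModel, StreamChunk, StreamStats

model = ASRModel.from_pretrained("path/to/checkpoint").eval()
features = torch.randn(1, 128, 3000)  # placeholder log-mel batch for a Whisper-style encoder

for item in model.generate_stream(input_features=features, max_new_tokens=64):
    if isinstance(item, StreamChunk):
        print(item.text, end="", flush=True)   # partial transcript as it decodes
    elif isinstance(item, StreamStats):
        print(f"\n[{item.input_tokens} prompt tokens, ~{item.output_tokens} generated]")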
asr_pipeline.py CHANGED
@@ -1,13 +1,13 @@
-from typing import Any, Dict
+from typing import Any, Dict, Generator, Union
 
 import torch
 import transformers
 from truecase import get_true_case
 
 try:
-    from .asr_modeling import ASRModel
+    from .asr_modeling import ASRModel, StreamChunk, StreamStats
 except ImportError:
-    from asr_modeling import ASRModel  # type: ignore[no-redef]
+    from asr_modeling import ASRModel, StreamChunk, StreamStats  # type: ignore[no-redef]
 
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
@@ -31,6 +31,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
 
     def __call__(self, inputs, **kwargs):
+        # Check if streaming is requested
+        stream = kwargs.pop("stream", False)
+        if stream:
+            return self._stream_inference(inputs, **kwargs)
+
         generate_kwargs = {}
         for key in [
             "max_new_tokens",
@@ -292,3 +297,111 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             text = get_true_case(text)
 
         return {"text": text}
+
+    def _stream_inference(
+        self, inputs, **kwargs
+    ) -> Generator[Union[Dict[str, str], Dict[str, int]], None, None]:
+        """
+        Perform streaming inference on audio input.
+
+        Args:
+            inputs: Audio input (same format as __call__)
+            **kwargs: Generation parameters
+
+        Yields:
+            Dict with "text" key containing text chunks as they're generated,
+            followed by a final dict with "input_tokens" and "output_tokens" statistics
+        """
+        # Extract generation kwargs
+        generate_kwargs = {}
+        for key in [
+            "max_new_tokens",
+            "temperature",
+            "do_sample",
+            "top_k",
+            "top_p",
+            "user_prompt",
+            "task",
+            "system_prompt",
+        ]:
+            if key in kwargs:
+                generate_kwargs[key] = kwargs.pop(key)
+
+        # Disable chunking for streaming - we want the whole audio at once
+        kwargs.pop("chunk_length_s", None)
+        kwargs.pop("stride_length_s", None)
+
+        # Preprocess audio to get model inputs
+        model_inputs = self.preprocess(inputs, chunk_length_s=0, **kwargs)
+
+        # Handle different input formats
+        audio_inputs = None
+        is_whisper = False
+
+        # Check if preprocess returned an iterator (shouldn't with chunk_length_s=0)
+        from collections.abc import Iterator
+        if isinstance(model_inputs, Iterator):
+            # Get the first (and should be only) chunk
+            try:
+                model_inputs = next(model_inputs)
+            except StopIteration:
+                raise ValueError("Preprocess returned empty iterator")
+
+        if isinstance(model_inputs, torch.Tensor):
+            audio_inputs = model_inputs
+        elif isinstance(model_inputs, dict):
+            # Remove metadata fields
+            model_inputs.pop("is_last", None)
+            model_inputs.pop("stride", None)
+
+            # Get audio input (Whisper uses input_features, others use input_values)
+            if "input_features" in model_inputs:
+                audio_inputs = model_inputs["input_features"]
+                is_whisper = True
+            else:
+                audio_inputs = model_inputs.get("input_values")
+
+        if audio_inputs is None:
+            # Debug info
+            import sys
+            print(f"DEBUG: model_inputs type: {type(model_inputs)}", file=sys.stderr)
+            if isinstance(model_inputs, dict):
+                print(f"DEBUG: model_inputs keys: {model_inputs.keys()}", file=sys.stderr)
+            raise ValueError(f"Could not extract audio inputs from preprocessing. Got type: {type(model_inputs)}")
+
+        if isinstance(audio_inputs, torch.Tensor):
+            audio_inputs = audio_inputs.to(self.model.device)
+        else:
+            raise ValueError(f"audio inputs must be a tensor, got {type(audio_inputs)}")
+
+        # Call the streaming generate method
+        if is_whisper:
+            stream_generator = self.model.generate_stream(
+                input_features=audio_inputs,
+                **generate_kwargs,
+            )
+        else:
+            stream_generator = self.model.generate_stream(
+                input_values=audio_inputs,
+                **generate_kwargs,
+            )
+
+        # Track full text for post-processing
+        full_text = ""
+
+        # Stream the chunks
+        for item in stream_generator:
+            if isinstance(item, StreamChunk):
+                full_text += item.text
+                yield {"text": item.text}
+            elif isinstance(item, StreamStats):
+                # Apply post-processing to the full text
+                processed_text = self.text_normalizer.normalize(full_text)
+                processed_text = get_true_case(processed_text)
+
+                # Yield final statistics with processed text
+                yield {
+                    "input_tokens": item.input_tokens,
+                    "output_tokens": item.output_tokens,
+                    "full_text": processed_text,
+                }