mazesmazes committed on
Commit
4623ffa
·
verified ·
1 Parent(s): 6718653

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_modeling.py +4 -0
  2. asr_pipeline.py +1 -18
asr_modeling.py CHANGED
@@ -100,6 +100,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
100
  self.generation_config.max_new_tokens = config.max_new_tokens
101
  self.generation_config.num_beams = config.num_beams
102
  self.generation_config.do_sample = False
 
 
 
 
103
  self.generation_config.use_cache = config.use_cache
104
  self.generation_config.length_penalty = config.length_penalty
105
  self.generation_config.repetition_penalty = config.repetition_penalty
 
100
  self.generation_config.max_new_tokens = config.max_new_tokens
101
  self.generation_config.num_beams = config.num_beams
102
  self.generation_config.do_sample = False
103
+ # Clear sampling params (inherited from LLM) since we use greedy decoding
104
+ self.generation_config.temperature = None
105
+ self.generation_config.top_p = None
106
+ self.generation_config.top_k = None
107
  self.generation_config.use_cache = config.use_cache
108
  self.generation_config.length_penalty = config.length_penalty
109
  self.generation_config.repetition_penalty = config.repetition_penalty
asr_pipeline.py CHANGED
@@ -1,6 +1,5 @@
1
  from typing import Any
2
 
3
- import numpy as np
4
  import torch
5
  import transformers
6
 
@@ -10,14 +9,6 @@ except ImportError:
10
  from asr_modeling import ASRModel # type: ignore[no-redef]
11
 
12
 
13
- def normalize_audio(audio: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
14
- """Normalize audio to target peak amplitude for consistent input levels."""
15
- max_val = np.abs(audio).max()
16
- if max_val > 0:
17
- return audio / max_val * target_peak
18
- return audio
19
-
20
-
21
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
22
  """ASR Pipeline for audio-to-text transcription."""
23
 
@@ -37,18 +28,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
37
  def preprocess(self, inputs, **preprocess_params):
38
  # Handle dict with "array" key (from datasets)
39
  if isinstance(inputs, dict) and "array" in inputs:
40
- audio = inputs["array"]
41
- if isinstance(audio, np.ndarray):
42
- audio = normalize_audio(audio)
43
  inputs = {
44
- "raw": audio,
45
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
46
  }
47
- # Handle dict with "raw" key
48
- elif isinstance(inputs, dict) and "raw" in inputs:
49
- audio = inputs["raw"]
50
- if isinstance(audio, np.ndarray):
51
- inputs["raw"] = normalize_audio(audio)
52
 
53
  for item in super().preprocess(inputs, **preprocess_params):
54
  if "is_last" not in item:
 
1
  from typing import Any
2
 
 
3
  import torch
4
  import transformers
5
 
 
9
  from asr_modeling import ASRModel # type: ignore[no-redef]
10
 
11
 
 
 
 
 
 
 
 
 
12
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
13
  """ASR Pipeline for audio-to-text transcription."""
14
 
 
28
  def preprocess(self, inputs, **preprocess_params):
29
  # Handle dict with "array" key (from datasets)
30
  if isinstance(inputs, dict) and "array" in inputs:
 
 
 
31
  inputs = {
32
+ "raw": inputs["array"],
33
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
34
  }
 
 
 
 
 
35
 
36
  for item in super().preprocess(inputs, **preprocess_params):
37
  if "is_last" not in item: