mazesmazes committed on
Commit
ec54f1d
·
verified ·
1 Parent(s): 3a943c7

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. asr_config.py +3 -4
  2. asr_modeling.py +5 -3
  3. handler.py +5 -32
asr_config.py CHANGED
@@ -21,10 +21,8 @@ class ASRConfig(transformers.PretrainedConfig):
21
  llm_dim: Optional[int] = None,
22
  audio_sample_rate: int = 16000,
23
  projector_init_std: float = 0.02,
24
- projector_pool_stride: int = 2,
25
- projector_hidden_dim: Optional[
26
- int
27
- ] = None,
28
  projector_dropout: float = 0.0, # Dropout rate for projector layers
29
  inference_diversity_penalty: float = 0.0,
30
  inference_warmup_tokens: int = 10,
@@ -120,4 +118,5 @@ class ASRConfig(transformers.PretrainedConfig):
120
  self.architectures = ["ASRModel"]
121
  self.pipeline_tag = "automatic-speech-recognition"
122
 
 
123
  transformers.AutoConfig.register("asr_model", ASRConfig)
 
21
  llm_dim: Optional[int] = None,
22
  audio_sample_rate: int = 16000,
23
  projector_init_std: float = 0.02,
24
+ projector_pool_stride: int = 2,
25
+ projector_hidden_dim: Optional[int] = None,
 
 
26
  projector_dropout: float = 0.0, # Dropout rate for projector layers
27
  inference_diversity_penalty: float = 0.0,
28
  inference_warmup_tokens: int = 10,
 
118
  self.architectures = ["ASRModel"]
119
  self.pipeline_tag = "automatic-speech-recognition"
120
 
121
+
122
  transformers.AutoConfig.register("asr_model", ASRConfig)
asr_modeling.py CHANGED
@@ -26,7 +26,6 @@ except ImportError:
26
 
27
 
28
  class SwiGLU(nn.Module):
29
-
30
  def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
31
  super().__init__()
32
  self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
@@ -44,7 +43,6 @@ class SwiGLU(nn.Module):
44
 
45
 
46
  class AudioProjector(nn.Module):
47
-
48
  def __init__(self, config):
49
  super().__init__()
50
  self.k = getattr(config, "projector_pool_stride", 2) # Downsampling rate
@@ -120,8 +118,12 @@ class ASRModel(PreTrainedModel):
120
  return WhisperFeatureExtractor.from_pretrained(
121
  audio_model_id,
122
  feature_size=num_mel_bins,
 
123
  )
124
- return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
 
 
 
125
 
126
  @classmethod
127
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
 
26
 
27
 
28
  class SwiGLU(nn.Module):
 
29
  def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
30
  super().__init__()
31
  self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
 
43
 
44
 
45
  class AudioProjector(nn.Module):
 
46
  def __init__(self, config):
47
  super().__init__()
48
  self.k = getattr(config, "projector_pool_stride", 2) # Downsampling rate
 
118
  return WhisperFeatureExtractor.from_pretrained(
119
  audio_model_id,
120
  feature_size=num_mel_bins,
121
+ do_normalize=True,
122
  )
123
+ return Wav2Vec2FeatureExtractor.from_pretrained(
124
+ audio_model_id,
125
+ do_normalize=True,
126
+ )
127
 
128
  @classmethod
129
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
handler.py CHANGED
@@ -97,48 +97,21 @@ class EndpointHandler:
97
  print(f"Warmup skipped due to: {e}")
98
 
99
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
100
- """Process audio transcription request.
101
-
102
- Supports both single and batch inputs for efficient concurrent processing.
103
- The endpoint infrastructure can batch multiple concurrent requests automatically.
104
- """
105
  inputs = data.get("inputs")
106
  if inputs is None:
107
  raise ValueError("Missing 'inputs' in request data")
108
 
109
- # Get generation parameters (matching SLAM-ASR paper defaults)
110
  params = data.get("parameters", {})
111
- max_new_tokens = params.get("max_new_tokens", 200) # Longer transcripts
112
-
113
- # Beam search for better quality (5 beams for higher quality)
114
- # Use num_beams=1 for faster inference at cost of ~2-3% WER increase
115
- num_beams = params.get("num_beams", 5)
116
-
117
  do_sample = params.get("do_sample", False)
118
-
119
- # Length penalty encourages appropriate transcript length
120
- # >1.0 = prefer longer outputs, <1.0 = prefer shorter
121
- # Slight positive bias helps avoid truncated transcripts
122
- length_penalty = params.get("length_penalty", 1.1)
123
-
124
- # Repetition penalty to prevent loops (1.1-1.2 is good for ASR)
125
- repetition_penalty = params.get("repetition_penalty", 1.15)
126
-
127
- # Alternative: use no_repeat_ngram_size to prevent exact n-gram repetition
128
- no_repeat_ngram_size = params.get("no_repeat_ngram_size", 3)
129
-
130
- # Early stopping for beam search: stop when all beams end
131
- # "never" = generate full max_new_tokens (more accurate but slower)
132
- # True = stop when all beams reach EOS (faster)
133
  early_stopping = params.get("early_stopping", True)
134
-
135
- # Diversity penalty encourages different beams (helps with rare words)
136
- # 0.0 = no diversity, 0.5-1.0 = good diversity
137
  default_diversity = self.pipe.model.config.inference_diversity_penalty
138
  diversity_penalty = params.get("diversity_penalty", default_diversity)
139
 
140
- # The pipeline's __call__ method handles both single and batch inputs
141
- # as well as automatic chunking for long audio files
142
  return self.pipe(
143
  inputs,
144
  max_new_tokens=max_new_tokens,
 
97
  print(f"Warmup skipped due to: {e}")
98
 
99
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
 
 
 
 
 
100
  inputs = data.get("inputs")
101
  if inputs is None:
102
  raise ValueError("Missing 'inputs' in request data")
103
 
 
104
  params = data.get("parameters", {})
105
+ max_new_tokens = params.get("max_new_tokens", 200)
106
+ num_beams = params.get("num_beams", 1)
 
 
 
 
107
  do_sample = params.get("do_sample", False)
108
+ length_penalty = params.get("length_penalty", 1.0)
109
+ repetition_penalty = params.get("repetition_penalty", 1.0)
110
+ no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0)
 
 
 
 
 
 
 
 
 
 
 
 
111
  early_stopping = params.get("early_stopping", True)
 
 
 
112
  default_diversity = self.pipe.model.config.inference_diversity_penalty
113
  diversity_penalty = params.get("diversity_penalty", default_diversity)
114
 
 
 
115
  return self.pipe(
116
  inputs,
117
  max_new_tokens=max_new_tokens,