Update custom model files, README, and requirements
Browse files- asr_config.py +3 -4
- asr_modeling.py +5 -3
- handler.py +5 -32
asr_config.py
CHANGED
|
@@ -21,10 +21,8 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 21 |
llm_dim: Optional[int] = None,
|
| 22 |
audio_sample_rate: int = 16000,
|
| 23 |
projector_init_std: float = 0.02,
|
| 24 |
-
projector_pool_stride: int = 2,
|
| 25 |
-
projector_hidden_dim: Optional[
|
| 26 |
-
int
|
| 27 |
-
] = None,
|
| 28 |
projector_dropout: float = 0.0, # Dropout rate for projector layers
|
| 29 |
inference_diversity_penalty: float = 0.0,
|
| 30 |
inference_warmup_tokens: int = 10,
|
|
@@ -120,4 +118,5 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 120 |
self.architectures = ["ASRModel"]
|
| 121 |
self.pipeline_tag = "automatic-speech-recognition"
|
| 122 |
|
|
|
|
| 123 |
transformers.AutoConfig.register("asr_model", ASRConfig)
|
|
|
|
| 21 |
llm_dim: Optional[int] = None,
|
| 22 |
audio_sample_rate: int = 16000,
|
| 23 |
projector_init_std: float = 0.02,
|
| 24 |
+
projector_pool_stride: int = 2,
|
| 25 |
+
projector_hidden_dim: Optional[int] = None,
|
|
|
|
|
|
|
| 26 |
projector_dropout: float = 0.0, # Dropout rate for projector layers
|
| 27 |
inference_diversity_penalty: float = 0.0,
|
| 28 |
inference_warmup_tokens: int = 10,
|
|
|
|
| 118 |
self.architectures = ["ASRModel"]
|
| 119 |
self.pipeline_tag = "automatic-speech-recognition"
|
| 120 |
|
| 121 |
+
|
| 122 |
transformers.AutoConfig.register("asr_model", ASRConfig)
|
asr_modeling.py
CHANGED
|
@@ -26,7 +26,6 @@ except ImportError:
|
|
| 26 |
|
| 27 |
|
| 28 |
class SwiGLU(nn.Module):
|
| 29 |
-
|
| 30 |
def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
|
| 31 |
super().__init__()
|
| 32 |
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
|
|
@@ -44,7 +43,6 @@ class SwiGLU(nn.Module):
|
|
| 44 |
|
| 45 |
|
| 46 |
class AudioProjector(nn.Module):
|
| 47 |
-
|
| 48 |
def __init__(self, config):
|
| 49 |
super().__init__()
|
| 50 |
self.k = getattr(config, "projector_pool_stride", 2) # Downsampling rate
|
|
@@ -120,8 +118,12 @@ class ASRModel(PreTrainedModel):
|
|
| 120 |
return WhisperFeatureExtractor.from_pretrained(
|
| 121 |
audio_model_id,
|
| 122 |
feature_size=num_mel_bins,
|
|
|
|
| 123 |
)
|
| 124 |
-
return Wav2Vec2FeatureExtractor.from_pretrained(
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
@classmethod
|
| 127 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
class SwiGLU(nn.Module):
|
|
|
|
| 29 |
def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
|
| 30 |
super().__init__()
|
| 31 |
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
class AudioProjector(nn.Module):
|
|
|
|
| 46 |
def __init__(self, config):
|
| 47 |
super().__init__()
|
| 48 |
self.k = getattr(config, "projector_pool_stride", 2) # Downsampling rate
|
|
|
|
| 118 |
return WhisperFeatureExtractor.from_pretrained(
|
| 119 |
audio_model_id,
|
| 120 |
feature_size=num_mel_bins,
|
| 121 |
+
do_normalize=True,
|
| 122 |
)
|
| 123 |
+
return Wav2Vec2FeatureExtractor.from_pretrained(
|
| 124 |
+
audio_model_id,
|
| 125 |
+
do_normalize=True,
|
| 126 |
+
)
|
| 127 |
|
| 128 |
@classmethod
|
| 129 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
handler.py
CHANGED
|
@@ -97,48 +97,21 @@ class EndpointHandler:
|
|
| 97 |
print(f"Warmup skipped due to: {e}")
|
| 98 |
|
| 99 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
| 100 |
-
"""Process audio transcription request.
|
| 101 |
-
|
| 102 |
-
Supports both single and batch inputs for efficient concurrent processing.
|
| 103 |
-
The endpoint infrastructure can batch multiple concurrent requests automatically.
|
| 104 |
-
"""
|
| 105 |
inputs = data.get("inputs")
|
| 106 |
if inputs is None:
|
| 107 |
raise ValueError("Missing 'inputs' in request data")
|
| 108 |
|
| 109 |
-
# Get generation parameters (matching SLAM-ASR paper defaults)
|
| 110 |
params = data.get("parameters", {})
|
| 111 |
-
max_new_tokens = params.get("max_new_tokens", 200)
|
| 112 |
-
|
| 113 |
-
# Beam search for better quality (5 beams for higher quality)
|
| 114 |
-
# Use num_beams=1 for faster inference at cost of ~2-3% WER increase
|
| 115 |
-
num_beams = params.get("num_beams", 5)
|
| 116 |
-
|
| 117 |
do_sample = params.get("do_sample", False)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
# Slight positive bias helps avoid truncated transcripts
|
| 122 |
-
length_penalty = params.get("length_penalty", 1.1)
|
| 123 |
-
|
| 124 |
-
# Repetition penalty to prevent loops (1.1-1.2 is good for ASR)
|
| 125 |
-
repetition_penalty = params.get("repetition_penalty", 1.15)
|
| 126 |
-
|
| 127 |
-
# Alternative: use no_repeat_ngram_size to prevent exact n-gram repetition
|
| 128 |
-
no_repeat_ngram_size = params.get("no_repeat_ngram_size", 3)
|
| 129 |
-
|
| 130 |
-
# Early stopping for beam search: stop when all beams end
|
| 131 |
-
# "never" = generate full max_new_tokens (more accurate but slower)
|
| 132 |
-
# True = stop when all beams reach EOS (faster)
|
| 133 |
early_stopping = params.get("early_stopping", True)
|
| 134 |
-
|
| 135 |
-
# Diversity penalty encourages different beams (helps with rare words)
|
| 136 |
-
# 0.0 = no diversity, 0.5-1.0 = good diversity
|
| 137 |
default_diversity = self.pipe.model.config.inference_diversity_penalty
|
| 138 |
diversity_penalty = params.get("diversity_penalty", default_diversity)
|
| 139 |
|
| 140 |
-
# The pipeline's __call__ method handles both single and batch inputs
|
| 141 |
-
# as well as automatic chunking for long audio files
|
| 142 |
return self.pipe(
|
| 143 |
inputs,
|
| 144 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 97 |
print(f"Warmup skipped due to: {e}")
|
| 98 |
|
| 99 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
inputs = data.get("inputs")
|
| 101 |
if inputs is None:
|
| 102 |
raise ValueError("Missing 'inputs' in request data")
|
| 103 |
|
|
|
|
| 104 |
params = data.get("parameters", {})
|
| 105 |
+
max_new_tokens = params.get("max_new_tokens", 200)
|
| 106 |
+
num_beams = params.get("num_beams", 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
do_sample = params.get("do_sample", False)
|
| 108 |
+
length_penalty = params.get("length_penalty", 1.0)
|
| 109 |
+
repetition_penalty = params.get("repetition_penalty", 1.0)
|
| 110 |
+
no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
early_stopping = params.get("early_stopping", True)
|
|
|
|
|
|
|
|
|
|
| 112 |
default_diversity = self.pipe.model.config.inference_diversity_penalty
|
| 113 |
diversity_penalty = params.get("diversity_penalty", default_diversity)
|
| 114 |
|
|
|
|
|
|
|
| 115 |
return self.pipe(
|
| 116 |
inputs,
|
| 117 |
max_new_tokens=max_new_tokens,
|