Training in progress - step 12000
asr_pipeline.py  CHANGED  +5 -12

@@ -26,6 +26,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         )
 
     def preprocess(self, inputs, **preprocess_params):
+        preprocess_params.setdefault("chunk_length_s", 0)
+
         # Handle dict with "array" key (from datasets)
         if isinstance(inputs, dict) and "array" in inputs:
             inputs = {
@@ -33,15 +35,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
                 "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
             }
 
-        for item in super().preprocess(inputs, **preprocess_params):
-            if "is_last" not in item:
-                item["is_last"] = True
-            yield item
+        return super().preprocess(inputs, **preprocess_params)
 
     def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
-        # Extract audio features
-        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
-
+        # Extract audio features
         if isinstance(model_inputs, dict):
            input_features = model_inputs.get("input_features")
            if input_features is not None:
@@ -54,13 +51,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             **generate_kwargs,
         )
 
-        return {"tokens": generated_ids, "is_last": is_last}
+        return {"tokens": generated_ids}
 
     def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
-        # Handle list of outputs (from chunking)
-        if isinstance(model_outputs, list):
-            model_outputs = model_outputs[0] if model_outputs else {}
-
         tokens = model_outputs.get("tokens")
         if tokens is None:
             return super().postprocess(model_outputs, **kwargs)
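
In short: preprocess now defaults chunk_length_s to 0 (no chunking) unless the caller overrides it and defers to the base AutomaticSpeechRecognitionPipeline, so the hand-rolled is_last chunk bookkeeping in preprocess, _forward, and postprocess is dropped. A minimal usage sketch of the updated class, assuming this checkpoint loads with the stock transformers auto classes; the "." path and the silent np.zeros audio are placeholders, not part of the commit:

import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from asr_pipeline import ASRPipeline  # the class modified in this commit

# Assumption: this repo holds both processor and model weights;
# "." stands in for the actual model id or local checkpoint path.
processor = AutoProcessor.from_pretrained(".")
model = AutoModelForSpeechSeq2Seq.from_pretrained(".")

pipe = ASRPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

# datasets-style input dict: preprocess picks up the "array" key and fills in
# sampling_rate from the feature extractor when that key is absent.
sample = {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}
result = pipe(sample)  # expected shape: {"text": "..."}
print(result)

With chunk_length_s left at 0 the base preprocess emits a single item per input, which appears to be why the pass-through is_last handling and the list unwrapping in postprocess became unnecessary.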