mazesmazes committed on
Commit
e696bea
·
verified ·
1 Parent(s): 29a8ec6

Training in progress - step 12000

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +5 -12
asr_pipeline.py CHANGED
@@ -26,6 +26,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
26
  )
27
 
28
  def preprocess(self, inputs, **preprocess_params):
 
 
29
  # Handle dict with "array" key (from datasets)
30
  if isinstance(inputs, dict) and "array" in inputs:
31
  inputs = {
@@ -33,15 +35,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
33
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
34
  }
35
 
36
- for item in super().preprocess(inputs, **preprocess_params):
37
- if "is_last" not in item:
38
- item["is_last"] = True
39
- yield item
40
 
41
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
42
- # Extract audio features and is_last flag
43
- is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
44
-
45
  if isinstance(model_inputs, dict):
46
  input_features = model_inputs.get("input_features")
47
  if input_features is not None:
@@ -54,13 +51,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
54
  **generate_kwargs,
55
  )
56
 
57
- return {"tokens": generated_ids, "is_last": is_last}
58
 
59
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
60
- # Handle list of outputs (from chunking)
61
- if isinstance(model_outputs, list):
62
- model_outputs = model_outputs[0] if model_outputs else {}
63
-
64
  tokens = model_outputs.get("tokens")
65
  if tokens is None:
66
  return super().postprocess(model_outputs, **kwargs)
 
26
  )
27
 
28
  def preprocess(self, inputs, **preprocess_params):
29
+ preprocess_params.setdefault("chunk_length_s", 0)
30
+
31
  # Handle dict with "array" key (from datasets)
32
  if isinstance(inputs, dict) and "array" in inputs:
33
  inputs = {
 
35
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
36
  }
37
 
38
+ return super().preprocess(inputs, **preprocess_params)
 
 
 
39
 
40
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
41
+ # Extract audio features
 
 
42
  if isinstance(model_inputs, dict):
43
  input_features = model_inputs.get("input_features")
44
  if input_features is not None:
 
51
  **generate_kwargs,
52
  )
53
 
54
+ return {"tokens": generated_ids}
55
 
56
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
 
 
 
 
57
  tokens = model_outputs.get("tokens")
58
  if tokens is None:
59
  return super().postprocess(model_outputs, **kwargs)