mazesmazes committed on
Commit
7ff69ca
·
verified ·
1 Parent(s): 3b0e54b

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +12 -5
asr_pipeline.py CHANGED
@@ -26,8 +26,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
26
  )
27
 
28
  def preprocess(self, inputs, **preprocess_params):
29
- preprocess_params.setdefault("chunk_length_s", 0)
30
-
31
  # Handle dict with "array" key (from datasets)
32
  if isinstance(inputs, dict) and "array" in inputs:
33
  inputs = {
@@ -35,10 +33,15 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
35
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
36
  }
37
 
38
- return super().preprocess(inputs, **preprocess_params)
 
 
 
39
 
40
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
41
- # Extract audio features
 
 
42
  if isinstance(model_inputs, dict):
43
  input_features = model_inputs.get("input_features")
44
  if input_features is not None:
@@ -51,9 +54,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
51
  **generate_kwargs,
52
  )
53
 
54
- return {"tokens": generated_ids}
55
 
56
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
 
 
 
 
57
  tokens = model_outputs.get("tokens")
58
  if tokens is None:
59
  return super().postprocess(model_outputs, **kwargs)
 
26
  )
27
 
28
  def preprocess(self, inputs, **preprocess_params):
 
 
29
  # Handle dict with "array" key (from datasets)
30
  if isinstance(inputs, dict) and "array" in inputs:
31
  inputs = {
 
33
  "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
34
  }
35
 
36
+ for item in super().preprocess(inputs, **preprocess_params):
37
+ if "is_last" not in item:
38
+ item["is_last"] = True
39
+ yield item
40
 
41
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
42
+ # Extract audio features and is_last flag
43
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
44
+
45
  if isinstance(model_inputs, dict):
46
  input_features = model_inputs.get("input_features")
47
  if input_features is not None:
 
54
  **generate_kwargs,
55
  )
56
 
57
+ return {"tokens": generated_ids, "is_last": is_last}
58
 
59
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
60
+ # Handle list of outputs (from chunking)
61
+ if isinstance(model_outputs, list):
62
+ model_outputs = model_outputs[0] if model_outputs else {}
63
+
64
  tokens = model_outputs.get("tokens")
65
  if tokens is None:
66
  return super().postprocess(model_outputs, **kwargs)