mazesmazes commited on
Commit
12f0409
·
verified ·
1 Parent(s): b13f900

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +4 -2
asr_pipeline.py CHANGED
@@ -42,7 +42,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
42
  yield item
43
 
44
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
45
- # Extract audio features
 
 
46
  if isinstance(model_inputs, dict):
47
  input_features = model_inputs.get("input_features")
48
  if input_features is not None:
@@ -55,7 +57,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
55
  **generate_kwargs,
56
  )
57
 
58
- return {"tokens": generated_ids}
59
 
60
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
61
  tokens = model_outputs.get("tokens")
 
42
  yield item
43
 
44
  def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
45
+ # Extract audio features and is_last flag
46
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
47
+
48
  if isinstance(model_inputs, dict):
49
  input_features = model_inputs.get("input_features")
50
  if input_features is not None:
 
57
  **generate_kwargs,
58
  )
59
 
60
+ return {"tokens": generated_ids, "is_last": is_last}
61
 
62
  def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
63
  tokens = model_outputs.get("tokens")