mazesmazes committed (verified)
Commit 8d608ce · Parent(s): 9fbfc75

Update custom model files, README, and requirements

Files changed (1)
  1. asr_pipeline.py +24 -2
asr_pipeline.py CHANGED
@@ -327,16 +327,33 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             if key in kwargs:
                 generate_kwargs[key] = kwargs.pop(key)
 
+        # Disable chunking for streaming - we want the whole audio at once
+        kwargs.pop("chunk_length_s", None)
+        kwargs.pop("stride_length_s", None)
+
         # Preprocess audio to get model inputs
-        model_inputs = self.preprocess(inputs, **kwargs)
+        model_inputs = self.preprocess(inputs, chunk_length_s=0, **kwargs)
 
         # Handle different input formats
         audio_inputs = None
         is_whisper = False
 
+        # Check if preprocess returned an iterator (shouldn't with chunk_length_s=0)
+        from collections.abc import Iterator
+        if isinstance(model_inputs, Iterator):
+            # Get the first (and should be only) chunk
+            try:
+                model_inputs = next(model_inputs)
+            except StopIteration:
+                raise ValueError("Preprocess returned empty iterator")
+
         if isinstance(model_inputs, torch.Tensor):
             audio_inputs = model_inputs
         elif isinstance(model_inputs, dict):
+            # Remove metadata fields
+            model_inputs.pop("is_last", None)
+            model_inputs.pop("stride", None)
+
             # Get audio input (Whisper uses input_features, others use input_values)
             if "input_features" in model_inputs:
                 audio_inputs = model_inputs["input_features"]
@@ -345,7 +362,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
                 audio_inputs = model_inputs.get("input_values")
 
         if audio_inputs is None:
-            raise ValueError("Could not extract audio inputs from preprocessing")
+            # Debug info
+            import sys
+            print(f"DEBUG: model_inputs type: {type(model_inputs)}", file=sys.stderr)
+            if isinstance(model_inputs, dict):
+                print(f"DEBUG: model_inputs keys: {model_inputs.keys()}", file=sys.stderr)
+            raise ValueError(f"Could not extract audio inputs from preprocessing. Got type: {type(model_inputs)}")
 
         if isinstance(audio_inputs, torch.Tensor):
             audio_inputs = audio_inputs.to(self.model.device)
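Note on the new preprocessing flow: AutomaticSpeechRecognitionPipeline.preprocess can yield its output as a stream of chunk dicts carrying "is_last"/"stride" metadata, which is why the commit collapses that stream into a single feature tensor before generation. Below is a minimal standalone sketch of the same collapse logic, written as a separate helper for illustration only; the helper name _collapse_preprocessed is hypothetical, and the committed code performs these steps inline inside ASRPipeline.

from collections.abc import Iterator

import torch


def _collapse_preprocessed(model_inputs):
    """Reduce the output of preprocess() to a single tensor of audio features (sketch)."""
    # preprocess() may come back as a generator; with chunking disabled it should yield once.
    if isinstance(model_inputs, Iterator):
        try:
            model_inputs = next(model_inputs)
        except StopIteration:
            raise ValueError("Preprocess returned empty iterator")

    if isinstance(model_inputs, torch.Tensor):
        return model_inputs

    if isinstance(model_inputs, dict):
        # Drop chunking metadata so only feature tensors remain.
        model_inputs.pop("is_last", None)
        model_inputs.pop("stride", None)
        # Whisper feature extractors emit input_features; wav2vec2-style ones emit input_values.
        return model_inputs.get("input_features", model_inputs.get("input_values"))

    return None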
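Caller-facing effect, sketched under assumptions: the repo id below is a placeholder (not taken from this commit), and it assumes the custom ASRPipeline is registered for loading with trust_remote_code=True. After this change, any chunk_length_s / stride_length_s the caller passes is popped before preprocessing, so the whole clip is decoded in a single generate call.

from transformers import pipeline

# Placeholder repo id; substitute the model repo that ships asr_pipeline.py.
asr = pipeline(
    "automatic-speech-recognition",
    model="<repo-id>",
    trust_remote_code=True,
)

# chunk_length_s is effectively ignored after this commit: it is popped from kwargs
# and preprocess() is called with chunk_length_s=0.
result = asr("sample.wav", chunk_length_s=30)
print(result["text"])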