MoHamdyy commited on
Commit
403dd60
·
1 Parent(s): 574f683

Fix: Correctly pass GenerationConfig to Whisper model

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -403,10 +403,18 @@ def full_speech_translation_pipeline(audio_input_path: str):
403
 
404
  print("STT: Extracting features and transcribing...")
405
  inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
 
406
  forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
 
407
  with torch.no_grad():
408
- generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
409
- arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
 
 
 
 
 
 
410
  print(f"STT Output: {arabic_transcript}")
411
  except Exception as e:
412
  print(f"STT Error: {e}")
@@ -439,14 +447,6 @@ def full_speech_translation_pipeline(audio_input_path: str):
439
  if generated_mel is not None and generated_mel.numel() > 0:
440
  mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
441
  audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
442
- synthesized_audio_np = audio_tensor.cpu().numpy()
443
- print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
444
- except Exception as e:
445
- print(f"TTS Error: {e}")
446
-
447
- print(f"--- PIPELINE END ---")
448
- return arabic_transcript, english_translation, (hp.sr, synthesized_audio_np)
449
-
450
 
451
  # --- Part 4: Gradio Interface Definition ---
452
  # (Same as before)
 
403
 
404
  print("STT: Extracting features and transcribing...")
405
  inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
406
+
407
  forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
408
+
409
  with torch.no_grad():
410
+ # Pass forced_decoder_ids via a GenerationConfig to avoid unused kwargs error
411
+ generation_config = stt_model.generation_config.copy()
412
+ generation_config.forced_decoder_ids = forced_ids
413
+
414
+ generated_ids = stt_model.generate(inputs, generation_config=generation_config, max_length=448)
415
+
416
+ # Use batch_decode for robustness
417
+ arabic_transcript = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
418
  print(f"STT Output: {arabic_transcript}")
419
  except Exception as e:
420
  print(f"STT Error: {e}")
 
447
  if generated_mel is not None and generated_mel.numel() > 0:
448
  mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
449
  audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
 
 
 
 
 
 
 
 
450
 
451
  # --- Part 4: Gradio Interface Definition ---
452
  # (Same as before)