Spaces:
Sleeping
Sleeping
Fix: Correctly pass GenerationConfig to Whisper model
Browse files
app.py
CHANGED
|
@@ -403,10 +403,18 @@ def full_speech_translation_pipeline(audio_input_path: str):
|
|
| 403 |
|
| 404 |
print("STT: Extracting features and transcribing...")
|
| 405 |
inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
|
|
|
|
| 406 |
forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
|
|
|
|
| 407 |
with torch.no_grad():
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
print(f"STT Output: {arabic_transcript}")
|
| 411 |
except Exception as e:
|
| 412 |
print(f"STT Error: {e}")
|
|
@@ -439,14 +447,6 @@ def full_speech_translation_pipeline(audio_input_path: str):
|
|
| 439 |
if generated_mel is not None and generated_mel.numel() > 0:
|
| 440 |
mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
|
| 441 |
audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
|
| 442 |
-
synthesized_audio_np = audio_tensor.cpu().numpy()
|
| 443 |
-
print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
|
| 444 |
-
except Exception as e:
|
| 445 |
-
print(f"TTS Error: {e}")
|
| 446 |
-
|
| 447 |
-
print(f"--- PIPELINE END ---")
|
| 448 |
-
return arabic_transcript, english_translation, (hp.sr, synthesized_audio_np)
|
| 449 |
-
|
| 450 |
|
| 451 |
# --- Part 4: Gradio Interface Definition ---
|
| 452 |
# (Same as before)
|
|
|
|
| 403 |
|
| 404 |
print("STT: Extracting features and transcribing...")
|
| 405 |
inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
|
| 406 |
+
|
| 407 |
forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
|
| 408 |
+
|
| 409 |
with torch.no_grad():
|
| 410 |
+
# Pass forced_decoder_ids via a GenerationConfig to avoid unused kwargs error
|
| 411 |
+
generation_config = stt_model.generation_config.copy()
|
| 412 |
+
generation_config.forced_decoder_ids = forced_ids
|
| 413 |
+
|
| 414 |
+
generated_ids = stt_model.generate(inputs, generation_config=generation_config, max_length=448)
|
| 415 |
+
|
| 416 |
+
# Use batch_decode for robustness
|
| 417 |
+
arabic_transcript = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
| 418 |
print(f"STT Output: {arabic_transcript}")
|
| 419 |
except Exception as e:
|
| 420 |
print(f"STT Error: {e}")
|
|
|
|
| 447 |
if generated_mel is not None and generated_mel.numel() > 0:
|
| 448 |
mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
|
| 449 |
audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
# --- Part 4: Gradio Interface Definition ---
|
| 452 |
# (Same as before)
|