Spaces:
Configuration error
Configuration error
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,11 +28,21 @@ device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
|
|
| 28 |
|
| 29 |
whisper_model = whisper.load_model("turbo")
|
| 30 |
|
| 31 |
-
def detect_speech_language(
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def detect_text_language(text):
|
| 35 |
-
langid.classify(text)[0]
|
| 36 |
|
| 37 |
@torch.no_grad()
|
| 38 |
def get_prompt_text(speech_16k, language):
|
|
@@ -41,7 +51,6 @@ def get_prompt_text(speech_16k, language):
|
|
| 41 |
short_prompt_end_ts = 0.0
|
| 42 |
|
| 43 |
asr_result = whisper_model.transcribe(speech_16k, language=language)
|
| 44 |
-
print("asr_result:", asr_result)
|
| 45 |
full_prompt_text = asr_result["text"] # whisper asr result
|
| 46 |
#text = asr_result["segments"][0]["text"] # whisperx asr result
|
| 47 |
shot_prompt_text = ""
|
|
@@ -51,8 +60,6 @@ def get_prompt_text(speech_16k, language):
|
|
| 51 |
short_prompt_end_ts = segment['end']
|
| 52 |
if short_prompt_end_ts >= 4:
|
| 53 |
break
|
| 54 |
-
print("full prompt text:", full_prompt_text, " shot_prompt_text:", shot_prompt_text,
|
| 55 |
-
"short_prompt_end_ts:", short_prompt_end_ts)
|
| 56 |
return full_prompt_text, shot_prompt_text, short_prompt_end_ts
|
| 57 |
|
| 58 |
|
|
@@ -310,7 +317,7 @@ def maskgct_inference(
|
|
| 310 |
speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
|
| 311 |
speech = librosa.load(prompt_speech_path, sr=24000)[0]
|
| 312 |
|
| 313 |
-
prompt_language = detect_speech_language(
|
| 314 |
full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
|
| 315 |
prompt_language)
|
| 316 |
# use the first 4+ seconds wav as the prompt in case the prompt wav is too long
|
|
@@ -321,7 +328,7 @@ def maskgct_inference(
|
|
| 321 |
device,
|
| 322 |
speech_16k,
|
| 323 |
short_prompt_text,
|
| 324 |
-
|
| 325 |
target_text,
|
| 326 |
target_language,
|
| 327 |
target_len,
|
|
@@ -393,9 +400,17 @@ iface = gr.Interface(
|
|
| 393 |
outputs=gr.Audio(label="Generated Audio"),
|
| 394 |
title="MaskGCT TTS Demo",
|
| 395 |
description="""
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
"""
|
| 398 |
)
|
| 399 |
|
| 400 |
# Launch the interface
|
| 401 |
-
iface.launch(allowed_paths=["./output"])
|
|
|
|
| 28 |
|
| 29 |
whisper_model = whisper.load_model("turbo")
|
| 30 |
|
| 31 |
+
def detect_speech_language(speech_file):
|
| 32 |
+
# load audio and pad/trim it to fit 30 seconds
|
| 33 |
+
audio = whisper.load_audio(speech_file)
|
| 34 |
+
audio = whisper.pad_or_trim(audio)
|
| 35 |
+
|
| 36 |
+
# make log-Mel spectrogram and move to the same device as the model
|
| 37 |
+
mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
|
| 38 |
+
|
| 39 |
+
# detect the spoken language
|
| 40 |
+
_, probs = whisper_model.detect_language(mel)
|
| 41 |
+
return max(probs, key=probs.get)
|
| 42 |
+
|
| 43 |
|
| 44 |
def detect_text_language(text):
|
| 45 |
+
return langid.classify(text)[0]
|
| 46 |
|
| 47 |
@torch.no_grad()
|
| 48 |
def get_prompt_text(speech_16k, language):
|
|
|
|
| 51 |
short_prompt_end_ts = 0.0
|
| 52 |
|
| 53 |
asr_result = whisper_model.transcribe(speech_16k, language=language)
|
|
|
|
| 54 |
full_prompt_text = asr_result["text"] # whisper asr result
|
| 55 |
#text = asr_result["segments"][0]["text"] # whisperx asr result
|
| 56 |
shot_prompt_text = ""
|
|
|
|
| 60 |
short_prompt_end_ts = segment['end']
|
| 61 |
if short_prompt_end_ts >= 4:
|
| 62 |
break
|
|
|
|
|
|
|
| 63 |
return full_prompt_text, shot_prompt_text, short_prompt_end_ts
|
| 64 |
|
| 65 |
|
|
|
|
| 317 |
speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
|
| 318 |
speech = librosa.load(prompt_speech_path, sr=24000)[0]
|
| 319 |
|
| 320 |
+
prompt_language = detect_speech_language(prompt_speech_path)
|
| 321 |
full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
|
| 322 |
prompt_language)
|
| 323 |
# use the first 4+ seconds wav as the prompt in case the prompt wav is too long
|
|
|
|
| 328 |
device,
|
| 329 |
speech_16k,
|
| 330 |
short_prompt_text,
|
| 331 |
+
prompt_language,
|
| 332 |
target_text,
|
| 333 |
target_language,
|
| 334 |
target_len,
|
|
|
|
| 400 |
outputs=gr.Audio(label="Generated Audio"),
|
| 401 |
title="MaskGCT TTS Demo",
|
| 402 |
description="""
|
| 403 |
+
## MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer
|
| 404 |
+
|
| 405 |
+
[](https://arxiv.org/abs/2409.00750)
|
| 406 |
+
|
| 407 |
+
[](https://huggingface.co/amphion/maskgct)
|
| 408 |
+
|
| 409 |
+
[](https://huggingface.co/spaces/amphion/maskgct)
|
| 410 |
+
|
| 411 |
+
[](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
|
| 412 |
"""
|
| 413 |
)
|
| 414 |
|
| 415 |
# Launch the interface
|
| 416 |
+
iface.launch(allowed_paths=["./output"])
|