hadadrjt's picture
[3/?] Pocket TTS: Handle multiple format extensions for voice cloning.
dae9fa5
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import gradio as gr
from config import VOICE_MODE_CLONE
from ..core.state import (
generation_state_lock,
get_stop_generation_requested,
set_stop_generation_requested
)
from ..core.authentication import get_huggingface_token
from ..core.memory import (
has_temporary_files_pending_cleanup,
cleanup_expired_temporary_files,
perform_memory_cleanup,
memory_cleanup,
trigger_background_cleanup_check
)
from ..tts.manager import text_to_speech_manager
from ..validation.text import validate_text_input
from ..audio.validator import (
perform_comprehensive_audio_validation,
get_format_display_name
)
from ..audio.converter import prepare_audio_file_for_voice_cloning
def check_if_generating():
    """Report whether a speech generation is currently in progress.

    Returns:
        The current value of ``is_currently_generating`` on the shared state
        module, read while holding ``generation_state_lock``.
    """
    # Import the module rather than the bare name: a
    # ``from ... import is_currently_generating`` statement copies the value
    # at the moment it executes, i.e. before the lock below is acquired.
    # Looking the attribute up on the module keeps the read inside the lock,
    # which is also how perform_speech_generation accesses the flag.
    from ..core import state as global_state

    with generation_state_lock:
        return global_state.is_currently_generating
def request_generation_stop():
    """Flag the running generation for cancellation.

    Sets the shared stop-request flag to ``True`` through
    ``set_stop_generation_requested``.

    Returns:
        A ``gr.update`` payload carrying ``interactive=False``.
    """
    set_stop_generation_requested(True)
    disabled_control_update = gr.update(interactive=False)
    return disabled_control_update
def validate_and_prepare_voice_clone_audio(voice_clone_audio_file):
    """Validate an uploaded reference clip and prepare it for voice cloning.

    Args:
        voice_clone_audio_file: The uploaded audio as handed over by the UI
            (presumably a file path -- confirm against the caller). A falsy
            value means nothing was uploaded.

    Returns:
        A 4-tuple ``(prepared_path, error_message, was_converted, audio_format)``.

        * Success: ``prepared_path`` is the value returned by
          ``prepare_audio_file_for_voice_cloning``, ``error_message`` is
          ``None``, ``was_converted`` is ``False`` for WAV input and ``True``
          otherwise, and ``audio_format`` is ``'wav'`` for WAV input or the
          detected format otherwise.
        * Failure: ``prepared_path`` and ``was_converted`` are ``None`` and
          ``error_message`` is a user-facing description of the problem.
    """
    if not voice_clone_audio_file:
        return None, "Please upload an audio file for voice cloning.", None, None

    is_valid, is_wav_format, detected_format, validation_error = (
        perform_comprehensive_audio_validation(voice_clone_audio_file)
    )

    if not is_valid:
        format_display_name = (
            get_format_display_name(detected_format) if detected_format else "Unknown"
        )
        if not validation_error:
            return None, "The uploaded file could not be validated as a valid audio file.", None, detected_format

        # Lower-case once; the validator's message is matched by keyword to
        # pick a friendlier, user-facing explanation. Order matters: the
        # first matching keyword wins.
        lowered_error = validation_error.lower()
        if "too short" in lowered_error:
            message = "The uploaded audio file is too short. Please upload a longer audio sample for better voice cloning results."
        elif "too long" in lowered_error:
            message = "The uploaded audio file is too long. Please upload a shorter audio sample (maximum 1 hour)."
        elif "empty" in lowered_error or "0 bytes" in lowered_error:
            message = "The uploaded audio file is empty. Please upload a valid audio file."
        elif "corrupted" in lowered_error or "truncated" in lowered_error:
            message = f"The uploaded {format_display_name} file appears to be corrupted or incomplete. Please upload a valid audio file."
        elif "unsupported" in lowered_error:
            # The validator's own wording is passed through unchanged.
            message = validation_error
        else:
            message = f"Invalid audio file: {validation_error}"
        return None, message, None, detected_format

    format_display_name = get_format_display_name(detected_format)

    # Both WAV and non-WAV inputs go through the same preparation call; only
    # the way the outcome is reported differs. The converter's own
    # was_converted / final_format values are intentionally not used here.
    prepared_path, preparation_error, _was_converted, _final_format = (
        prepare_audio_file_for_voice_cloning(voice_clone_audio_file)
    )

    if is_wav_format:
        if prepared_path is None:
            return None, f"Failed to process WAV file: {preparation_error}", None, 'wav'
        return prepared_path, None, False, 'wav'

    if prepared_path is None:
        # Guard against a missing error message so a failed conversion is
        # reported to the user instead of crashing on ``None.lower()``.
        lowered_preparation_error = str(preparation_error or "").lower()
        if "no audio conversion library" in lowered_preparation_error:
            return None, f"Cannot convert {format_display_name} format. Please upload a WAV file directly.", None, detected_format
        return None, f"Failed to convert {format_display_name} to WAV format: {preparation_error}", None, detected_format

    return prepared_path, None, True, detected_format
def _stop_was_requested(state_module):
    """Return the stop-request flag of *state_module*, read under the generation lock."""
    with generation_state_lock:
        return state_module.stop_generation_requested


def perform_speech_generation(
    text_input,
    voice_mode_selection,
    voice_preset_selection,
    voice_clone_audio_file,
    model_variant,
    lsd_decode_steps,
    temperature,
    noise_clamp,
    eos_threshold,
    frames_after_eos,
    enable_custom_frames
):
    """Generate speech for *text_input* and return the saved audio file.

    The text is validated first. In clone mode the reference audio is then
    validated and prepared. Only one generation may run at a time; a stop
    request is honoured after model loading, after voice-state resolution
    and after audio generation.

    Args:
        text_input: Raw text to synthesize.
        voice_mode_selection: Selected voice mode; compared against
            ``VOICE_MODE_CLONE``.
        voice_preset_selection: Preset voice used when not cloning.
        voice_clone_audio_file: Uploaded reference audio used when cloning.
        model_variant: Model variant passed to ``load_or_get_model``.
        lsd_decode_steps: Passed to ``load_or_get_model``.
        temperature: Passed to ``load_or_get_model``.
        noise_clamp: Passed to ``load_or_get_model``.
        eos_threshold: Passed to ``load_or_get_model``.
        frames_after_eos: Passed to ``generate_audio``.
        enable_custom_frames: Passed to ``generate_audio``.

    Returns:
        The value returned by ``text_to_speech_manager.save_audio_to_file``,
        or ``None`` when a stop was requested before the audio was saved.

    Raises:
        gr.Error: On invalid text, missing or invalid clone audio, missing
            Hugging Face token, a generation already in progress, or any
            failure during generation.
    """
    from ..core import state as global_state

    if has_temporary_files_pending_cleanup():
        cleanup_expired_temporary_files()
    # NOTE(review): the original indentation was lost; this call is assumed
    # to run unconditionally rather than inside the `if` above -- confirm.
    perform_memory_cleanup()

    is_valid, validation_result = validate_text_input(text_input)
    if not is_valid:
        # On failure validation_result is used as the error message; fall
        # back to a generic one when it is empty.
        raise gr.Error(validation_result or "Please enter valid text to generate speech.")

    prepared_audio_path = None
    was_audio_converted = False
    original_audio_format = None

    if voice_mode_selection == VOICE_MODE_CLONE:
        if not voice_clone_audio_file:
            raise gr.Error("Please upload an audio file for voice cloning.")
        if not get_huggingface_token():
            raise gr.Error("Voice cloning is not configured properly at the moment. Please try again later.")

        (
            prepared_audio_path,
            audio_error,
            was_audio_converted,
            original_audio_format,
        ) = validate_and_prepare_voice_clone_audio(voice_clone_audio_file)

        if prepared_audio_path is None:
            raise gr.Error(audio_error)
        if was_audio_converted:
            format_display_name = get_format_display_name(original_audio_format)
            gr.Warning(f"Audio converted from {format_display_name} to WAV format for voice cloning.")

    # Check and claim the "generating" flag in one critical section so two
    # concurrent requests cannot both start.
    with generation_state_lock:
        if global_state.is_currently_generating:
            raise gr.Error("A generation is already in progress. Please wait.")
        global_state.is_currently_generating = True
        global_state.stop_generation_requested = False

    generated_audio_tensor = None
    cloned_voice_state_tensor = None
    voice_state = None

    try:
        text_to_speech_manager.load_or_get_model(
            model_variant,
            temperature,
            lsd_decode_steps,
            noise_clamp,
            eos_threshold
        )
        if _stop_was_requested(global_state):
            return None

        if voice_mode_selection == VOICE_MODE_CLONE:
            cloned_voice_state_tensor = text_to_speech_manager.get_voice_state_for_clone(
                voice_clone_audio_file,
                prepared_audio_path=prepared_audio_path
            )
            voice_state = cloned_voice_state_tensor
        else:
            voice_state = text_to_speech_manager.get_voice_state_for_preset(voice_preset_selection)
        if _stop_was_requested(global_state):
            return None

        # On success validation_result holds the text to synthesize.
        generated_audio_tensor = text_to_speech_manager.generate_audio(
            validation_result,
            voice_state,
            frames_after_eos,
            enable_custom_frames
        )
        if _stop_was_requested(global_state):
            return None

        return text_to_speech_manager.save_audio_to_file(generated_audio_tensor)
    except gr.Error:
        # Already user-facing; pass through untouched.
        raise
    except RuntimeError as runtime_error:
        raise gr.Error(str(runtime_error)) from runtime_error
    except Exception as generation_error:
        # Boundary handler: translate anything else into a user-facing
        # error, keeping the original exception as the cause.
        error_message = str(generation_error)
        if "file does not start with RIFF id" in error_message:
            raise gr.Error("The audio file format is not supported. Please upload a valid WAV file or a common audio format (MP3, FLAC, OGG, M4A).") from generation_error
        if "unknown format" in error_message.lower():
            raise gr.Error("The audio file uses an unsupported encoding format. Please convert it to a standard format and try again.") from generation_error
        raise gr.Error(f"Speech generation failed: {error_message}") from generation_error
    finally:
        # Always release the generation slot and clear any pending stop
        # request, whichever way the try block was left.
        with generation_state_lock:
            global_state.is_currently_generating = False
            global_state.stop_generation_requested = False
        # Drop every local reference to the large intermediates before the
        # cleanup call. voice_state must be cleared too: in clone mode it
        # aliases cloned_voice_state_tensor and would otherwise keep that
        # object referenced from this frame during memory_cleanup().
        generated_audio_tensor = None
        cloned_voice_state_tensor = None
        voice_state = None
        memory_cleanup()
        trigger_background_cleanup_check()