tts / src /audio /validator.py
hadadrjt's picture
[3/?] Pocket TTS: Handle multiple format extensions for voice cloning.
dae9fa5
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import os
import wave
from config import (
SUPPORTED_AUDIO_EXTENSIONS,
AUDIO_FORMAT_DISPLAY_NAME_OVERRIDES
)
def build_format_display_names_from_supported_extensions():
    """Build the mapping from format code to user-facing display label.

    Each entry of ``SUPPORTED_AUDIO_EXTENSIONS`` has its leading dots stripped
    to form a format code (e.g. ``".wav"`` -> ``"wav"``). The label comes from
    ``AUDIO_FORMAT_DISPLAY_NAME_OVERRIDES`` when the code is listed there and
    is the upper-cased code otherwise. An ``"unknown"`` entry mapped to
    ``"Unknown"`` is always set.

    Returns:
        dict: format code -> display label.
    """
    format_codes = (extension.lstrip(".") for extension in SUPPORTED_AUDIO_EXTENSIONS)
    display_names = {
        code: (
            AUDIO_FORMAT_DISPLAY_NAME_OVERRIDES[code]
            if code in AUDIO_FORMAT_DISPLAY_NAME_OVERRIDES
            else code.upper()
        )
        for code in format_codes
    }
    display_names["unknown"] = "Unknown"
    return display_names
# Lookup table from format code (extension without its leading dot, plus
# "unknown") to display label. It is built once, when the module is imported.
FORMAT_DISPLAY_NAMES = build_format_display_names_from_supported_extensions()
def get_audio_file_extension(file_path):
    """Return the lower-cased extension of *file_path*, including its dot.

    A falsy *file_path* (``None`` or ``""``) yields ``None``. A path without
    an extension yields ``""``.
    """
    if file_path:
        return os.path.splitext(file_path)[-1].lower()
    return None
def is_supported_audio_extension(file_path):
    """Return True when *file_path* ends in an extension from SUPPORTED_AUDIO_EXTENSIONS.

    The comparison uses the lower-cased extension returned by
    ``get_audio_file_extension``. A falsy *file_path* is never supported.
    """
    suffix = get_audio_file_extension(file_path)
    return suffix is not None and suffix in SUPPORTED_AUDIO_EXTENSIONS
def validate_file_exists_and_readable(file_path):
    """Run cheap filesystem checks before any audio parsing is attempted.

    Args:
        file_path: Path to the candidate audio file (may be ``None`` or empty).

    Returns:
        ``(True, None)`` when the path is an existing regular file of at least
        44 bytes that can be opened and read. Otherwise ``(False, message)``
        describing the first check that failed.
    """
    if not file_path:
        return False, "No audio file provided."

    location_problem = None
    if not os.path.exists(file_path):
        location_problem = "Audio file does not exist."
    elif not os.path.isfile(file_path):
        location_problem = "The provided path is not a valid file."
    if location_problem is not None:
        return False, location_problem

    try:
        byte_count = os.path.getsize(file_path)
    except OSError as stat_error:
        return False, f"Cannot read file size: {stat_error!s}"

    # NOTE(review): 44 presumably mirrors the size of a canonical PCM WAV
    # header. It is applied to every format here; confirm that is intended.
    minimum_byte_count = 44
    if byte_count == 0:
        return False, "Audio file is empty (0 bytes)."
    if byte_count < minimum_byte_count:
        return False, "Audio file is too small to be a valid audio file."

    try:
        with open(file_path, "rb") as probe:
            probe.read(1)  # reading one byte proves the file can be opened
    except OSError as read_error:
        return False, f"Audio file is not readable: {read_error!s}"
    return True, None
# Two-byte frame-sync prefixes accepted as the start of a raw (untagged) MP3 stream.
_MP3_FRAME_SYNC_PREFIXES = (
    b"\xff\xfb",
    b"\xff\xfa",
    b"\xff\xf3",
    b"\xff\xf2",
    b"\xff\xe0",
    b"\xff\xe2",
    b"\xff\xe3",
)
# Box names found at byte offset 4 that are treated as an MP4/M4A container.
_M4A_LEADING_BOX_TYPES = (b"mdat", b"moov", b"free", b"skip", b"wide")
# Leading 4 bytes (the EBML header magic) that are reported as "webm".
_EBML_MAGIC = b"\x1aE\xdf\xa3"


def _identify_format_from_magic(header_bytes):
    """Return the format code whose signature matches *header_bytes*, else None.

    Checks run in a fixed order and the first match wins. Slices past the end
    of a short header are simply shorter, so they never equal the 4-byte
    signatures and no explicit length guards are needed.
    """
    lead = header_bytes[:4]
    if lead == b"RIFF" and header_bytes[8:12] == b"WAVE":
        return "wav"
    if header_bytes[:3] == b"ID3":
        return "mp3"
    if header_bytes[:2] in _MP3_FRAME_SYNC_PREFIXES:
        return "mp3"
    if lead == b"fLaC":
        return "flac"
    if lead == b"OggS":
        return "ogg"
    if lead == b"FORM" and header_bytes[8:12] in (b"AIFF", b"AIFC"):
        return "aiff"
    if header_bytes[4:8] == b"ftyp":
        return "m4a"
    if lead == _EBML_MAGIC:
        return "webm"
    if header_bytes[4:8] in _M4A_LEADING_BOX_TYPES:
        return "m4a"
    return None


def detect_audio_format_from_header(file_path):
    """Guess the audio container of *file_path* from its leading bytes.

    Up to 32 bytes are read and compared against known signatures. If none
    match, the file extension is used as a fallback, provided it is one of
    ``SUPPORTED_AUDIO_EXTENSIONS``.

    Returns:
        A ``(format_code, error_message)`` tuple:

        * ``(code, None)``: recognised from the header or from a supported
          extension.
        * ``("unknown", message)``: neither the header nor the extension was
          recognised.
        * ``(None, message)``: the file holds fewer than 4 bytes, could not be
          read, or another ``Exception`` occurred. Such exceptions are
          converted to this tuple rather than raised.
    """
    try:
        with open(file_path, "rb") as audio_file:
            header_bytes = audio_file.read(32)
        if len(header_bytes) < 4:
            return None, "File is too small to determine audio format."

        detected_format = _identify_format_from_magic(header_bytes)
        if detected_format is not None:
            return detected_format, None

        # No signature matched: trust a supported extension instead.
        file_extension = get_audio_file_extension(file_path)
        if file_extension and file_extension in SUPPORTED_AUDIO_EXTENSIONS:
            return file_extension.lstrip("."), None
        return "unknown", "Could not determine audio format from file header. The file may be corrupted or in an unsupported format."
    except OSError as io_error:
        return None, f"Error reading file header: {io_error!s}"
    except Exception as detection_error:
        return None, f"Unexpected error detecting audio format: {detection_error!s}"
def _describe_wav_parameter_problem(channel_count, bytes_per_sample, frame_rate, frame_count):
    """Return a message for the first out-of-range WAV parameter, or None if all pass."""
    if channel_count < 1:
        return "WAV file has no audio channels."
    if channel_count > 16:
        return f"WAV file has too many channels ({channel_count}). Maximum supported is 16."
    if bytes_per_sample < 1:
        return "WAV file has invalid sample width (less than 1 byte)."
    if bytes_per_sample > 4:
        return f"WAV file has unsupported sample width ({bytes_per_sample} bytes). Maximum supported is 4 bytes (32-bit)."
    if frame_rate < 100:
        return f"WAV file has invalid sample rate ({frame_rate} Hz). Minimum supported is 100 Hz."
    if frame_rate > 384000:
        return f"WAV file has unsupported sample rate ({frame_rate} Hz). Maximum supported is 384000 Hz."
    if frame_count < 1:
        return "WAV file contains no audio frames."
    # frame_rate >= 100 at this point, so the division cannot fail.
    duration_seconds = frame_count / frame_rate
    if duration_seconds < 0.1:
        return f"Audio is too short ({duration_seconds:.2f} seconds). Minimum duration is 0.1 seconds."
    if duration_seconds > 3600:
        return f"Audio is too long ({duration_seconds:.0f} seconds). Maximum duration is 1 hour."
    return None


def validate_wav_file_structure(file_path):
    """Parse *file_path* with :mod:`wave` and sanity-check its header fields.

    The following are checked against fixed limits:

    * channel count (1-16);
    * sample width (1-4 bytes);
    * sample rate (100-384000 Hz);
    * frame count (at least 1);
    * duration (0.1 s to 1 hour).

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, message)``.
        Parsing errors (``wave.Error``, ``EOFError`` and any other
        ``Exception``) are converted into ``(False, message)`` rather than
        raised.
    """
    try:
        with wave.open(file_path, "rb") as wav_reader:
            problem = _describe_wav_parameter_problem(
                wav_reader.getnchannels(),
                wav_reader.getsampwidth(),
                wav_reader.getframerate(),
                wav_reader.getnframes(),
            )
    except wave.Error as wav_error:
        reason = str(wav_error)
        if "file does not start with RIFF id" in reason:
            return False, "File has .wav extension but is not a valid WAV file. It may be a different audio format renamed to .wav."
        if "unknown format" in reason.lower():
            return False, "WAV file uses an unsupported audio encoding format."
        return False, f"Invalid WAV file structure: {reason}"
    except EOFError:
        return False, "WAV file is truncated or corrupted (unexpected end of file)."
    except Exception as validation_error:
        return False, f"Error validating WAV file: {validation_error!s}"

    if problem is not None:
        return False, problem
    return True, None
def perform_comprehensive_audio_validation(file_path):
    """Run every validation stage for an audio file, stopping at the first failure.

    Stages:
      1. filesystem checks (``validate_file_exists_and_readable``);
      2. extension whitelist (``SUPPORTED_AUDIO_EXTENSIONS``);
      3. header-based format detection (``detect_audio_format_from_header``);
      4. structural checks (``validate_wav_file_structure``), only when the
         detected format is ``"wav"``.

    Returns:
        A 4-tuple ``(is_valid, is_wav_format, detected_format, error_message)``.
        ``error_message`` describes the failure and is ``None`` on success.
    """
    file_exists_valid, file_exists_error = validate_file_exists_and_readable(file_path)
    if not file_exists_valid:
        return False, False, None, file_exists_error

    if not is_supported_audio_extension(file_path):
        file_extension = get_audio_file_extension(file_path)
        supported_formats_list = ", ".join(SUPPORTED_AUDIO_EXTENSIONS)
        return False, False, None, f"Unsupported file format '{file_extension}'. Supported formats are: {supported_formats_list}"

    detected_format, detection_error = detect_audio_format_from_header(file_path)
    # Treat any reported detection error as a failure. The detector can return
    # ("unknown", message), and that must not be reported as a valid file.
    if detected_format is None or detection_error is not None:
        return False, False, detected_format, detection_error

    is_wav_format = detected_format == "wav"
    if is_wav_format:
        wav_structure_valid, wav_structure_error = validate_wav_file_structure(file_path)
        if not wav_structure_valid:
            return False, True, "wav", wav_structure_error
    return True, is_wav_format, detected_format, None
def get_format_display_name(format_code):
    """Map a format code (e.g. ``"wav"``) to a user-facing label.

    ``None`` yields ``"Unknown"``. A code present in ``FORMAT_DISPLAY_NAMES``
    uses that label; any other code is upper-cased.
    """
    if format_code is None:
        return "Unknown"
    try:
        return FORMAT_DISPLAY_NAMES[format_code]
    except KeyError:
        return format_code.upper()