# File: API_MC_AI/VietTTS/utils/file_utils.py
# Hosting-page residue: uploader "duyv", commit message "Upload 86 files", revision a257816 (verified).
import os
import subprocess
import torchaudio
import soundfile
import numpy as np
from glob import glob
from loguru import logger
from huggingface_hub import snapshot_download
from VietTTS.utils.vad import get_speech
import torchaudio
import os
import subprocess
import tempfile
def convert_to_wav(input_filepath: str, target_sr: int) -> str:
    """
    Convert an input audio file to a mono WAV file at the desired sample rate using FFmpeg.

    The result is written to a newly created temporary file. On success that
    file is NOT deleted here: the caller owns it and is responsible for
    removing it. On any failure the temporary file is removed before the
    exception propagates.

    Args:
        input_filepath (str): Path to the input audio file.
        target_sr (int): Target sample rate.

    Returns:
        str: Path to the converted WAV file.

    Raises:
        RuntimeError: If FFmpeg exits with a non-zero status.
        OSError: If the FFmpeg executable cannot be launched
            (e.g. FileNotFoundError when it is not installed / not on PATH).
    """
    # Reserve a unique output path and close the handle immediately so that
    # FFmpeg can overwrite the file itself (hence "-y" below).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_wav_file:
        temp_wav_filepath = temp_wav_file.name

    ffmpeg_command = [
        "ffmpeg", "-y",
        "-loglevel", "error",
        "-i", input_filepath,
        "-ar", str(target_sr),
        "-ac", "1",  # force a single (mono) channel
        temp_wav_filepath,
    ]

    try:
        result = subprocess.run(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            # FFmpeg output is not guaranteed to be valid UTF-8; never let a
            # decoding error mask the actual conversion failure.
            stderr_text = result.stderr.decode(errors="replace")
            raise RuntimeError(f"FFmpeg conversion failed: {stderr_text}")
    except BaseException:
        # Covers both a failed conversion and a failure to launch FFmpeg at
        # all: in either case the placeholder file must not be left behind.
        try:
            os.unlink(temp_wav_filepath)
        except OSError:
            pass
        raise

    return temp_wav_filepath
def load_wav(filepath: str, target_sr: int):
    """
    Load an audio file in any supported format, convert it to WAV, and load as a tensor.

    Files whose name does not end in ".wav" (case-insensitive) are first
    converted with `convert_to_wav`; the temporary WAV produced by that
    conversion is deleted again as soon as it has been read.

    Args:
        filepath (str): Path to the audio file in any format.
        target_sr (int): Target sample rate.

    Returns:
        Tensor: Loaded audio averaged over dim 0 (kept as a size-1 leading
        dimension) and resampled to the target sample rate.

    Raises:
        AssertionError: If the file's sample rate is lower than `target_sr`
            (only downsampling is supported).
        RuntimeError / OSError: Propagated from `convert_to_wav` when the
            FFmpeg conversion fails.
    """
    temp_filepath = None
    # Check if the file is already in WAV format
    if not filepath.lower().endswith(".wav"):
        logger.info(f"Converting {filepath} to WAV format")
        temp_filepath = convert_to_wav(filepath, target_sr)
        filepath = temp_filepath

    try:
        # Load the WAV file
        speech, sample_rate = torchaudio.load(filepath)
    finally:
        # The converted file is only an intermediate artifact; remove it
        # whether or not loading succeeded so repeated calls do not fill
        # the temp directory.
        if temp_filepath is not None:
            try:
                os.unlink(temp_filepath)
            except OSError:
                pass

    speech = speech.mean(dim=0, keepdim=True)  # Convert to mono if not already
    if sample_rate != target_sr:
        assert sample_rate > target_sr, f'WAV sample rate {sample_rate} must be greater than {target_sr}'
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech
def save_wav(wav: np.ndarray, sr: int, filepath: str):
    """
    Write audio samples to disk through `soundfile`.

    Args:
        wav (np.ndarray): Audio samples to store.
        sr (int): Sample rate to record in the output file.
        filepath (str): Destination path of the audio file.
    """
    soundfile.write(file=filepath, data=wav, samplerate=sr)
def load_prompt_speech_from_file(filepath: str, min_duration: float=3, max_duration: float=5, return_numpy: bool=False):
    """
    Load a prompt recording at 16 kHz, cap its peak amplitude, and pass it to `get_speech`.

    Args:
        filepath (str): Path to the prompt audio file.
        min_duration (float): Forwarded to `get_speech` as `min_duration`.
        max_duration (float): Forwarded to `get_speech` as `max_duration`.
        return_numpy (bool): Forwarded to `get_speech` as `return_numpy`.

    Returns:
        Whatever `get_speech` returns for the loaded audio
        (presumably the detected speech segment; see VietTTS.utils.vad).
    """
    waveform = load_wav(filepath, 16000)

    # Rescale only when the signal exceeds the 0.9 ceiling, so quieter
    # recordings are passed through untouched.
    peak = waveform.abs().max()
    if peak > 0.9:
        waveform = waveform / peak * 0.9

    # Drop the leading dimension before handing the audio to get_speech.
    return get_speech(
        audio_input=waveform.squeeze(0),
        min_duration=min_duration,
        max_duration=max_duration,
        return_numpy=return_numpy,
    )
def load_voices(voice_dir: str):
    """
    Build a mapping from voice name to audio file path for a directory.

    Every `*.wav` and `*.mp3` file directly inside `voice_dir` is included.
    The voice name is the file name with only its final extension removed,
    so names containing dots are preserved ("bob.v2.mp3" -> "bob.v2").
    If a `.wav` and a `.mp3` share the same name, the `.mp3` entry wins
    because it is inserted last.

    Args:
        voice_dir (str): Directory to scan (not searched recursively).

    Returns:
        dict: Maps voice name (str) to the matching file path (str); empty
        when the directory has no matching files or does not exist.
    """
    # Sort each listing: glob() makes no ordering promise, and a stable
    # order keeps the resulting mapping reproducible across runs.
    files = (
        sorted(glob(os.path.join(voice_dir, '*.wav')))
        + sorted(glob(os.path.join(voice_dir, '*.mp3')))
    )
    return {
        os.path.splitext(os.path.basename(f))[0]: f
        for f in files
    }
def download_model(save_dir: str):
    """
    Download the "duyv/viet-tts" repository snapshot via `huggingface_hub`.

    Args:
        save_dir (str): Passed to `snapshot_download` as `local_dir`.
    """
    hub_repo_id = "duyv/viet-tts"
    snapshot_download(repo_id=hub_repo_id, local_dir=save_dir)