| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import os |
| | import sys |
| |
|
# Ask PyTorch to fall back to the CPU for operators the MPS backend does not
# implement. This is set before `import torch` further down in this module.
# The variable name was previously misspelled ("PYTOCH_..."), which made the
# setting a no-op.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
| | from importlib.resources import files |
| | import matplotlib |
| |
|
# Select matplotlib's "Agg" backend immediately after `import matplotlib`,
# i.e. before any other matplotlib usage in the visible part of this module.
matplotlib.use("Agg")
| |
|
| |
|
| | import numpy as np |
| | import torch |
| | import torchaudio |
| | import tqdm |
| | import logging |
| | |
| | |
| | from f5_tts.model.utils import ( |
| | get_tokenizer, |
| | convert_char_to_pinyin, |
| | ) |
| | from f5_tts.model.modules import MelSpec |
| |
|
| |
|
# Default compute device, chosen in priority order: CUDA, then XPU, then MPS,
# falling back to the CPU. `torch.xpu` and `torch.backends.mps` are probed with
# `hasattr` first because not every PyTorch build exposes them; dereferencing
# them unconditionally raises AttributeError on such builds.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if hasattr(torch, "xpu") and torch.xpu.is_available()
    else "mps"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)
| |
|
| | |
| |
|
# Default inference settings.
# NOTE(review): in the visible code, chunk_infer_batch_process reads only
# `target_sample_rate`, `hop_length` and `seed` as module globals. The other
# names coincide with that function's keyword defaults or are presumably
# consumed by code outside this view -- confirm before removing any of them.

# Audio / mel-spectrogram front-end.
target_sample_rate = 24000  # sample rate the reference audio is resampled to
n_mel_channels = 100
hop_length = 256  # divides a sample count to obtain a mel-frame count
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"  # the inference function also handles "bigvgan"

# Loudness and chunk joining.
target_rms = 0.1
cross_fade_duration = 0.15  # seconds; multiplied by the sample rate when used

# Sampler settings.
ode_method = "euler"
nfe_step = 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
seed = 3214  # forwarded to `model_obj.sample(seed=...)` for every chunk
| |
|
| | |
def _cross_fade_join(waves, cross_fade_duration, sample_rate):
    """Join 1-D waveform chunks into one array, cross-fading each boundary.

    Args:
        waves: Non-empty list of NumPy waveforms, in playback order.
        cross_fade_duration: Overlap length in seconds. A value <= 0 disables
            cross-fading and the chunks are simply concatenated.
        sample_rate: Samples per second, used to convert the duration into a
            sample count.

    Returns:
        The joined waveform as a NumPy array.
    """
    if cross_fade_duration <= 0:
        logging.info("simply concatenate")
        return np.concatenate(waves)

    max_fade_samples = int(cross_fade_duration * sample_rate)
    final_wave = waves[0]
    for next_wave in waves[1:]:
        prev_wave = final_wave
        # The overlap can never be longer than either side of the boundary.
        fade_samples = min(max_fade_samples, len(prev_wave), len(next_wave))
        if fade_samples <= 0:
            final_wave = np.concatenate([prev_wave, next_wave])
            continue

        # Rising half of a Hamming window fades the new chunk in, the falling
        # half fades the previous chunk out.
        window = np.hamming(2 * fade_samples)
        fade_in = window[:fade_samples]
        fade_out = window[fade_samples:]
        overlap = (
            prev_wave[-fade_samples:] * fade_out
            + next_wave[:fade_samples] * fade_in
        )
        final_wave = np.concatenate(
            [prev_wave[:-fade_samples], overlap, next_wave[fade_samples:]]
        )
    return final_wave


def chunk_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    fix_duration=None,
    device=None,
    chunk_cond_proportion=0.5,
    chunk_look_ahead=0,
    max_ref_duration=4.5,
    ref_head_cut=False,
):
    """Synthesize text chunk by chunk, conditioning each chunk on the last.

    Every chunk is generated with the (possibly truncated) reference audio and
    text as a fixed prompt, followed by the tail of the previously generated
    chunk. Only the newly generated mel frames are vocoded; the resulting
    waveforms are joined, optionally with a cross-fade.

    Args:
        ref_audio: ``(audio, sample_rate)`` pair. ``audio`` is indexed as
            ``(channels, samples)``; multi-channel input is averaged to mono.
        ref_text: Reference transcript as a ``list`` of tokens.
        gen_text_batches: Non-empty sized sequence of token ``list``s, one per
            chunk to synthesize.
        model_obj: Object exposing ``sample(cond, text, duration, steps,
            cfg_strength, sway_sampling_coef, seed)`` that returns a pair whose
            first item is the generated mel (assumed ``(batch, frames, mels)``
            -- TODO confirm against the model).
        vocoder: Called as ``vocoder.decode(mel)`` for ``"vocos"`` and as
            ``vocoder(mel)`` for ``"bigvgan"``.
        mel_spec_type: ``"vocos"`` or ``"bigvgan"``.
        progress: Object exposing ``tqdm(iterable)``; defaults to the ``tqdm``
            module.
        target_rms: Reference audio quieter than this RMS is amplified to it,
            and the generated audio is scaled back down by the same factor.
        cross_fade_duration: Cross-fade length in seconds between consecutive
            chunks; ``<= 0`` concatenates without fading.
        nfe_step: Forwarded to ``model_obj.sample`` as ``steps``.
        cfg_strength: Forwarded to ``model_obj.sample``.
        sway_sampling_coef: Forwarded to ``model_obj.sample``.
        speed: Divides the estimated number of frames for the new text.
        fix_duration: If not ``None``, total duration in seconds used for every
            ``sample`` call instead of the text-length based estimate.
        device: Device the reference audio is moved to; ``None`` leaves it
            where it is.
        chunk_cond_proportion: Fraction of each generated chunk (mel frames
            and tokens) reused as conditioning for the next chunk.
        chunk_look_ahead: Number of trailing tokens of every non-final chunk
            whose estimated frames are generated but discarded.
        max_ref_duration: Reference audio longer than this many seconds is
            truncated, with the text cut proportionally.
        ref_head_cut: Keep the first ``max_ref_duration`` seconds when true,
            otherwise the last.

    Returns:
        ``(final_wave, target_sample_rate, combined_spectrogram)`` where
        ``final_wave`` is a NumPy waveform and ``combined_spectrogram`` is the
        per-chunk mel spectrograms concatenated along axis 1.

    Raises:
        ValueError: If ``gen_text_batches`` is empty, ``mel_spec_type`` is not
            supported, or the reference text is empty while ``fix_duration``
            is ``None``.
        TypeError: If ``ref_text`` or a chunk of ``gen_text_batches`` is not a
            ``list``.
    """
    num_batches = len(gen_text_batches)
    if num_batches == 0:
        raise ValueError("gen_text_batches must contain at least one chunk")
    if not isinstance(ref_text, list):
        raise TypeError("ref_text must be a list of tokens")
    if mel_spec_type not in ("vocos", "bigvgan"):
        # Fail before the expensive sampling instead of hitting an unbound
        # `generated_wave` afterwards.
        raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type!r}")

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    # Skip the gain for silent input (rms == 0), which would yield NaNs.
    boosted = bool(0 < rms < target_rms)
    if boosted:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        audio = torchaudio.transforms.Resample(sr, target_sample_rate)(audio)

    logging.info(
        "audio shape:%s; ref_text shape:%s", audio.shape, len(ref_text)
    )
    total_samples = audio.shape[1]
    ref_duration = total_samples / target_sample_rate
    if ref_duration > max_ref_duration:
        keep_samples = round(max_ref_duration * target_sample_rate)
        keep_tokens = round(max_ref_duration * len(ref_text) / ref_duration)
        if ref_head_cut:
            logging.info(
                "Using the first %s seconds as ref audio", max_ref_duration
            )
            audio = audio[:, :keep_samples]
            ref_text = ref_text[:keep_tokens]
        else:
            logging.info(
                "Using the last %s seconds as ref audio", max_ref_duration
            )
            # Explicit start indices: `x[-0:]` would keep everything rather
            # than nothing when the rounded length is 0.
            audio = audio[:, total_samples - keep_samples:]
            ref_text = ref_text[len(ref_text) - keep_tokens:]
        logging.info(
            "audio shape:%s; ref_text shape:%s", audio.shape, len(ref_text)
        )
    if device is not None:
        audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # Fixed prompt: the mel of the reference audio with the frame axis moved
    # to dim 1 (MelSpec output assumed (batch, mels, frames) -- TODO confirm).
    mel_spec_module = MelSpec(mel_spec_type=mel_spec_type)
    fixed_ref_audio_mel_cond = mel_spec_module(audio).permute(0, 2, 1)
    fixed_ref_audio_len = fixed_ref_audio_mel_cond.shape[1]

    fixed_ref_text = ref_text[:]
    fixed_ref_text_len = len(fixed_ref_text)
    if fix_duration is None and fixed_ref_text_len == 0:
        raise ValueError(
            "ref_text must not be empty when fix_duration is None"
        )

    mel_cond = fixed_ref_audio_mel_cond.clone()
    # Number of frames of the previous chunk currently appended to the prompt.
    prev_chunk_audio_len = 0

    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        if not isinstance(gen_text, list):
            raise TypeError(
                "each entry of gen_text_batches must be a list of tokens"
            )
        final_text_list = [ref_text + gen_text]
        logging.info("final_text_list: %s", final_text_list)

        # Frames occupied by the prompt (reference + previous-chunk tail).
        cond_len = fixed_ref_audio_len + prev_chunk_audio_len
        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Estimate new frames from the reference frames-per-token ratio.
            duration = cond_len + int(
                fixed_ref_audio_len / fixed_ref_text_len * len(gen_text) / speed
            )
        logging.info("Duration: %s", duration)

        with torch.inference_mode():
            logging.info(
                "generate with nfe_step:%s, cfg_strength:%s, "
                "sway_sampling_coef:%s",
                nfe_step,
                cfg_strength,
                sway_sampling_coef,
            )
            generated, _ = model_obj.sample(
                cond=mel_cond,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                seed=seed,
            )
            generated = generated.to(torch.float32)
            logging.info("gen mel shape: %s", generated.shape)

            # Drop the prompt frames; keep only what is new in this chunk.
            chunk_mel = generated[:, cond_len:, :]

            # For every chunk but the last, discard the frames estimated to
            # belong to the trailing look-ahead tokens, and those tokens too.
            look_ahead_mel_len = 0
            if gen_text and i < num_batches - 1:
                look_ahead_mel_len = round(
                    (duration - cond_len) * chunk_look_ahead / len(gen_text)
                )
            if look_ahead_mel_len > 0:
                chunk_mel = chunk_mel[:, :-look_ahead_mel_len, :]
                gen_text = gen_text[:-chunk_look_ahead]
            logging.info(
                "gen mel shape: %s, gen text len: %d",
                chunk_mel.shape,
                len(gen_text),
            )

            chunk_len = chunk_mel.shape[1]
            generated_mel_spec = chunk_mel.permute(0, 2, 1)

            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            else:  # "bigvgan" (validated above)
                generated_wave = vocoder(generated_mel_spec)

            if boosted:
                # Undo the gain applied to the reference audio.
                generated_wave = generated_wave * rms / target_rms

            logging.info("gen wav shape: %s", generated_wave.shape)

            generated_waves.append(generated_wave.squeeze().cpu().numpy())
            spectrograms.append(generated_mel_spec[0].cpu().numpy())

            # Keep the tail of this chunk as conditioning for the next one.
            cond_keep = round(chunk_cond_proportion * chunk_len)
            if chunk_len > cond_keep:
                text_keep = round(chunk_cond_proportion * len(gen_text))
                # Explicit start indices so a length of 0 keeps nothing
                # (a `-0:` slice would keep the whole chunk while the
                # bookkeeping below says 0 frames, duplicating audio).
                gen_text_cond = gen_text[len(gen_text) - text_keep:]
                generated_cond = chunk_mel[:, chunk_len - cond_keep:, :]
                prev_chunk_audio_len = cond_keep
            else:
                gen_text_cond = gen_text
                generated_cond = chunk_mel
                prev_chunk_audio_len = chunk_len

            logging.info(
                "gen text cond len: %d, gen mel cond len: %d",
                len(gen_text_cond),
                generated_cond.shape[1],
            )

            ref_text = fixed_ref_text + gen_text_cond
            mel_cond = torch.cat(
                [fixed_ref_audio_mel_cond, generated_cond], dim=1
            )

    final_wave = _cross_fade_join(
        generated_waves, cross_fade_duration, target_sample_rate
    )
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, target_sample_rate, combined_spectrogram
| |
|