# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved.
# (authors: Dehua Tao, Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility

from importlib.resources import files

import matplotlib

matplotlib.use("Agg")

import logging

import numpy as np
import torch
import torchaudio
import tqdm

# torch.set_printoptions(profile="full")

# from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.model.modules import MelSpec

device = (
    "cuda"
    if torch.cuda.is_available()
    else (
        "xpu"
        if torch.xpu.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
)

# -----------------------------------------

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
seed = 3214

# -----------------------------------------


def chunk_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    fix_duration=None,
    device=None,
    chunk_cond_proportion=0.5,
    chunk_look_ahead=0,
    max_ref_duration=4.5,
    ref_head_cut=False,
):
    """Chunked inference: synthesize gen_text_batches one chunk at a time,
    conditioning each chunk on the fixed reference mel plus the tail
    (chunk_cond_proportion) of the previously generated chunk, then
    cross-fade the chunk waveforms together. Returns
    (final_wave, target_sample_rate, combined_spectrogram).
    """
    audio, sr = ref_audio
    # Collapse multi-channel reference audio to mono.
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Normalize quiet reference audio up to the target RMS, and resample.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)

    logging.info(
        "audio shape: " + str(audio.shape) + "; ref_text len: " + str(len(ref_text))
    )

    # Truncate over-long reference audio to max_ref_duration seconds, keeping
    # either the head or the tail, and trim the transcript proportionally.
    ref_duration = audio.shape[1] / target_sample_rate
    if ref_duration > max_ref_duration:
        reserved_ref_audio_len = round(max_ref_duration * target_sample_rate)
        if ref_head_cut:
            logging.info(f"Using the first {max_ref_duration} seconds as ref audio")
            audio = audio[:, :reserved_ref_audio_len]
            ref_text = ref_text[
                : round(max_ref_duration * len(ref_text) / ref_duration)
            ]
        else:
            logging.info(f"Using the last {max_ref_duration} seconds as ref audio")
            audio = audio[:, -reserved_ref_audio_len:]
            ref_text = ref_text[
                -round(max_ref_duration * len(ref_text) / ref_duration) :
            ]

    logging.info(
        "audio shape: " + str(audio.shape) + "; ref_text len: " + str(len(ref_text))
    )

    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # fixed_ref_audio_len = audio.shape[-1] // hop_length
    mel_spec_module = MelSpec(mel_spec_type=mel_spec_type)
    fixed_ref_audio_mel_spec = mel_spec_module(audio)
    # The last dim should be num_channels
    fixed_ref_audio_mel_cond = fixed_ref_audio_mel_spec.permute(0, 2, 1)
    fixed_ref_audio_len = fixed_ref_audio_mel_cond.shape[1]
    assert isinstance(ref_text, list)
    fixed_ref_text = ref_text[:]
    fixed_ref_text_len = len(fixed_ref_text)
    mel_cond = fixed_ref_audio_mel_cond.clone()
    prev_chunk_audio_len = 0
    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        # Prepare the text: fixed reference transcript plus this chunk.
        final_text_list = [ref_text + gen_text]
        logging.info(f"final_text_list: {final_text_list}")

        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Calculate duration from the reference speech rate.
            assert isinstance(gen_text, list)
            gen_text_len = len(gen_text)
            duration = (
                fixed_ref_audio_len
                + prev_chunk_audio_len
                + int(fixed_ref_audio_len / fixed_ref_text_len * gen_text_len / speed)
            )
        logging.info(f"Duration: {duration}")

        # inference
        with torch.inference_mode():
            logging.info(
                f"generate with nfe_step: {nfe_step}, cfg_strength: {cfg_strength}, "
                f"sway_sampling_coef: {sway_sampling_coef}"
            )
            # logging.info("mel_cond: " + str(mel_cond))
            generated, _ = model_obj.sample(
                cond=mel_cond,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                seed=seed,
            )

            generated = generated.to(torch.float32)
            logging.info("gen mel shape: " + str(generated.shape))

            # Remove the condition mel (fixed reference + previous chunk tail).
            stripped_generated = generated[
                :, (fixed_ref_audio_len + prev_chunk_audio_len) :, :
            ]

            # Remove chunk_look_ahead from the tail of each generated mel.
            look_ahead_mel_len = round(
                (duration - fixed_ref_audio_len - prev_chunk_audio_len)
                * chunk_look_ahead
                / len(gen_text)
            )
            if look_ahead_mel_len > 0 and i < len(gen_text_batches) - 1:
                stripped_generated_without_look_ahead = stripped_generated[
                    :, :(-look_ahead_mel_len), :
                ]
                # Also remove the chunk_look_ahead from the tail of gen_text.
                gen_text = gen_text[:-chunk_look_ahead]
            else:
                stripped_generated_without_look_ahead = stripped_generated
            logging.info(
                "gen mel shape: %s, gen text len: %d"
                % (str(stripped_generated_without_look_ahead.shape), len(gen_text))
            )
            # logging.info("generated mel: " + str(generated))

            # prev chunk audio len is the length without fixed condition and chunk look ahead
            prev_chunk_audio_len = stripped_generated_without_look_ahead.shape[1]
            # prev_chunk_audio_len_with_look_ahead = stripped_generated.shape[1]

            # generate wav with look ahead
            generated_mel_spec = stripped_generated_without_look_ahead.permute(0, 2, 1)
            # generated_mel_spec = stripped_generated.permute(0, 2, 1)
            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated_mel_spec)

            # strip look ahead wav from generated wav
            # if look_ahead_mel_len > 0 and i < len(gen_text_batches) - 1:
            #     look_ahead_wav_len = round(
            #         look_ahead_mel_len
            #         * generated_wave.shape[1]
            #         / prev_chunk_audio_len_with_look_ahead
            #     )
            #     generated_wave = generated_wave[:, :-look_ahead_wav_len]

            # Undo the reference RMS normalization on the generated audio.
            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms
            logging.info("gen wav shape: " + str(generated_wave.shape))
            # logging.info("generated wav: " + str(generated_wave))

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()
            generated_waves.append(generated_wave)
            spectrograms.append(generated_mel_spec[0].cpu().numpy())

            # Keep only the tail (chunk_cond_proportion) of this chunk as the
            # rolling condition for the next chunk.
            prev_chunk_cond_audio_len = round(chunk_cond_proportion * prev_chunk_audio_len)
            if prev_chunk_audio_len > prev_chunk_cond_audio_len:
                gen_text_cond = gen_text[-round(chunk_cond_proportion * len(gen_text)) :]
                prev_chunk_audio_len = prev_chunk_cond_audio_len
                generated_cond = stripped_generated_without_look_ahead[
                    :, (-prev_chunk_audio_len):, :
                ]
            else:
                generated_cond = stripped_generated_without_look_ahead
                gen_text_cond = gen_text
            logging.info(
                "gen text cond len: %d, gen mel cond len: %d"
                % (len(gen_text_cond), generated_cond.shape[1])
            )
            ref_text = fixed_ref_text + gen_text_cond
            mel_cond = torch.cat([fixed_ref_audio_mel_cond, generated_cond], dim=1)
    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0:
        # Simply concatenate
        logging.info("simply concatenate")
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Overlapping parts
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            # Fade out and fade in
            # fade_out = np.linspace(1, 0, cross_fade_samples)
            # fade_in = np.linspace(0, 1, cross_fade_samples)
            wave_window = np.hamming(2 * cross_fade_samples)
            fade_out = wave_window[cross_fade_samples:]  # descending half
            fade_in = wave_window[:cross_fade_samples]  # ascending half

            # Cross-faded overlap
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            # Combine
            new_wave = np.concatenate(
                [
                    prev_wave[:-cross_fade_samples],
                    cross_faded_overlap,
                    next_wave[cross_fade_samples:],
                ]
            )

            final_wave = new_wave

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, target_sample_rate, combined_spectrogram
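
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library).
# `load_model_and_vocoder` is a hypothetical placeholder for whatever loading
# code your setup uses, and "ref.wav" / the transcripts are dummy inputs;
# everything else relies only on names defined or imported above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model_obj, vocoder = load_model_and_vocoder()  # hypothetical helper

    # Reference audio is passed as a (waveform, sample_rate) tuple.
    ref_audio = torchaudio.load("ref.wav")

    # Texts are character/pinyin lists, as produced by convert_char_to_pinyin.
    ref_text = convert_char_to_pinyin(["Some reference transcript. "])[0]
    gen_text_batches = [
        convert_char_to_pinyin([chunk])[0]
        for chunk in ["First chunk to synthesize. ", "Second chunk to synthesize. "]
    ]

    wave, sr, _spectrogram = chunk_infer_batch_process(
        ref_audio,
        ref_text,
        gen_text_batches,
        model_obj,
        vocoder,
        mel_spec_type=mel_spec_type,
        nfe_step=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        device=device,
        chunk_cond_proportion=0.5,
        # Set chunk_look_ahead > 0 only if each chunk already carries that many
        # look-ahead tokens from the next chunk; the function trims them after
        # generation.
        chunk_look_ahead=0,
    )
    torchaudio.save("out.wav", torch.from_numpy(wave).unsqueeze(0), sr)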