|
|
import soundfile as sf |
|
|
import torch |
|
|
import tqdm |
|
|
from cached_path import cached_path |
|
|
from model import DiT, UNetT |
|
|
from model.utils import save_spectrogram |
|
|
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav |
|
|
from model.utils import seed_everything |
|
|
import random |
|
|
import sys |
|
|
import requests |
|
|
import gdown |
|
|
import zipfile |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
class F5TTS:
    """Convenience wrapper around F5-TTS / E2-TTS inference.

    Construction selects a torch device and loads both the vocoder and the
    TTS model; :meth:`infer` then synthesizes speech in the voice of a
    reference clip. The static helpers download and unpack voice packs
    distributed as ZIP archives on Hugging Face or Google Drive.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Pick a device and load the vocoder and TTS model.

        Args:
            model_type: ``"F5-TTS"`` (DiT backbone) or ``"E2-TTS"`` (UNetT backbone).
            ckpt_file: Checkpoint path; when empty, the default checkpoint for
                ``model_type`` is fetched via ``cached_path``.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Forwarded unchanged to ``load_model``.
            use_ema: Forwarded unchanged to ``load_model``.
            local_path: Vocoder location; when not ``None`` it is handed to
                ``load_vocoder`` as a local path.
            device: Torch device string; when falsy it is auto-detected
                (cuda, then mps, then cpu).

        Raises:
            ValueError: If ``model_type`` is not a supported name.
        """
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1

        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder into ``self.vocos`` (local when ``local_path`` is given)."""
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Resolve the architecture/checkpoint for ``model_type`` and load it into ``self.ema_model``.

        Raises:
            ValueError: If ``model_type`` is neither ``"F5-TTS"`` nor ``"E2-TTS"``.
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true, silence is then stripped from the
        written file in place.
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Synthesize ``gen_text`` in the voice of ``ref_file``.

        The generation knobs (``show_info`` .. ``fix_duration``) are forwarded
        unchanged to ``infer_process``.

        Args:
            ref_file: Reference audio path.
            ref_text: Transcript of the reference audio.
            gen_text: Text to synthesize.
            remove_silence: Strip silence from the exported wav (only used
                when ``file_wave`` is given).
            file_wave: Optional path to write the generated audio to.
            file_spect: Optional path passed to ``save_spectrogram``.
            seed: RNG seed; ``-1`` draws a random one. The seed actually used
                is stored in ``self.seed`` for reproducibility.

        Returns:
            The ``(wav, sr, spect)`` tuple produced by ``infer_process``.
        """
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect

    @staticmethod
    def _remove_file(path):
        """Best-effort delete of ``path``; a missing file or OS error is ignored."""
        try:
            os.remove(path)
        except OSError:
            pass

    @staticmethod
    def _host_matches(url, domain):
        """Return True if the host of ``url`` is ``domain`` or a subdomain of it."""
        from urllib.parse import urlparse

        # A scheme-less URL ("drive.google.com/...") only gets a host when
        # parsed as a network-path reference, so prefix "//" in that case.
        if "://" not in url and not url.startswith("//"):
            url = "//" + url
        host = urlparse(url).hostname or ""  # .hostname is already lower-cased
        return host == domain or host.endswith("." + domain)

    @staticmethod
    def download_from_huggingface(url, output_path, timeout=30):
        """Download file from Hugging Face

        Streams ``url`` into ``output_path`` with a byte-accurate progress bar.

        Args:
            url: Direct download URL.
            output_path: Destination file path (overwritten).
            timeout: Seconds passed to ``requests.get`` so a stalled server
                cannot hang the download forever.

        Returns:
            True on success; False on any error, in which case the partially
            written file is removed.
        """
        try:
            # Using the response as a context manager releases the connection.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # A missing/zero content-length means "unknown size" for tqdm.
                total_size = int(response.headers.get('content-length', 0)) or None
                with open(output_path, 'wb') as f, tqdm.tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc="Downloading from Hugging Face",
                ) as bar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            bar.update(len(chunk))
            return True
        except Exception as e:
            print(f"Error downloading from Hugging Face: {e}")
            F5TTS._remove_file(output_path)
            return False

    @staticmethod
    def download_from_google_drive(url, output_path):
        """Download file from Google Drive

        Returns:
            True if ``gdown`` produced ``output_path``; False otherwise.
        """
        try:
            result = gdown.download(url=url, output=output_path, quiet=False, fuzzy=True)
        except Exception as e:
            print(f"Error downloading from Google Drive: {e}")
            F5TTS._remove_file(output_path)
            return False
        # Some gdown versions report failure by returning None instead of raising.
        if result is None or not os.path.isfile(output_path):
            print("Error downloading from Google Drive: no file was produced")
            return False
        return True

    @staticmethod
    def extract_zip(zip_path, extract_path):
        """Extract ZIP file

        Returns:
            True on success; False (after printing the error) otherwise.
        """
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
            return True
        except Exception as e:
            print(f"Error extracting ZIP file: {e}")
            return False

    @staticmethod
    def download_and_setup_voice(voice_url, voice_name, base_path="voices"):
        """
        Download and setup a voice from URL (Hugging Face or Google Drive)

        The archive is saved as ``<base_path>/<voice_name>/<voice_name>.zip``,
        extracted into that folder and then deleted.

        Args:
            voice_url (str): URL to download the voice from
            voice_name (str): Name for the voice folder; must be a single path
                component (no separators, not "." or "..")
            base_path (str): Base directory to store voices

        Returns:
            str: Path to the downloaded voice folder, or None if failed
        """
        # Reject names that would place files outside base_path ("../x", "a/b").
        if not voice_name or voice_name in (".", "..") or os.path.basename(voice_name) != voice_name:
            print(f"Invalid voice name: {voice_name!r}")
            return None

        # Choose the backend from the URL's host rather than a substring match
        # anywhere in the URL (which e.g. "https://x.example/?huggingface.co" passes).
        if F5TTS._host_matches(voice_url, "huggingface.co"):
            downloader = F5TTS.download_from_huggingface
        elif F5TTS._host_matches(voice_url, "drive.google.com"):
            downloader = F5TTS.download_from_google_drive
        else:
            print("Unsupported URL. Only Hugging Face and Google Drive links are supported.")
            return None

        voice_dir = os.path.join(base_path, voice_name)
        os.makedirs(voice_dir, exist_ok=True)
        zip_path = os.path.join(voice_dir, f"{voice_name}.zip")

        if not downloader(voice_url, zip_path):
            print("Download failed")
            F5TTS._remove_file(zip_path)
            return None

        extracted = F5TTS.extract_zip(zip_path, voice_dir)
        # The archive is not needed afterwards, whether or not extraction worked.
        F5TTS._remove_file(zip_path)
        if not extracted:
            print("Extraction failed")
            return None

        if not os.path.isdir(voice_dir) or not os.listdir(voice_dir):
            print("Voice directory is empty after extraction")
            return None

        print(f"Voice '{voice_name}' successfully downloaded and setup at: {voice_dir}")
        return voice_dir
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: synthesize one passage in the voice of the bundled reference
    # clip and write the audio and spectrogram under tests/.
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
        seed=-1,  # random seed = -1
    )

    # Report the seed actually used so the run can be reproduced.
    print("seed :", f5tts.seed)
|
|
</think> |
|
|
|
|
|
I was cut off while generating the modified `api.txt` file, so here are the complete modifications for both files.
|
|
|
|
|
## 📁 `api.txt` — complete modified version
|
|
|
|
|
```python |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import tqdm |
|
|
from cached_path import cached_path |
|
|
from model import DiT, UNetT |
|
|
from model.utils import save_spectrogram |
|
|
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav |
|
|
from model.utils import seed_everything |
|
|
import random |
|
|
import sys |
|
|
import requests |
|
|
import gdown |
|
|
import zipfile |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
class F5TTS:
    """Convenience wrapper around F5-TTS / E2-TTS inference.

    Construction selects a torch device and loads both the vocoder and the
    TTS model; :meth:`infer` then synthesizes speech in the voice of a
    reference clip. The static helpers download, unpack and list voice packs
    distributed as ZIP archives on Hugging Face or Google Drive.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Pick a device and load the vocoder and TTS model.

        Args:
            model_type: ``"F5-TTS"`` (DiT backbone) or ``"E2-TTS"`` (UNetT backbone).
            ckpt_file: Checkpoint path; when empty, the default checkpoint for
                ``model_type`` is fetched via ``cached_path``.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Forwarded unchanged to ``load_model``.
            use_ema: Forwarded unchanged to ``load_model``.
            local_path: Vocoder location; when not ``None`` it is handed to
                ``load_vocoder`` as a local path.
            device: Torch device string; when falsy it is auto-detected
                (cuda, then mps, then cpu).

        Raises:
            ValueError: If ``model_type`` is not a supported name.
        """
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1

        # Set device: prefer CUDA, then Apple MPS, then CPU.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        # Load models
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder into ``self.vocos`` (local when ``local_path`` is given)."""
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Resolve the architecture/checkpoint for ``model_type`` and load it into ``self.ema_model``.

        Raises:
            ValueError: If ``model_type`` is neither ``"F5-TTS"`` nor ``"E2-TTS"``.
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true, silence is then stripped from the
        written file in place.
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Synthesize ``gen_text`` in the voice of ``ref_file``.

        The generation knobs (``show_info`` .. ``fix_duration``) are forwarded
        unchanged to ``infer_process``.

        Args:
            ref_file: Reference audio path.
            ref_text: Transcript of the reference audio.
            gen_text: Text to synthesize.
            remove_silence: Strip silence from the exported wav (only used
                when ``file_wave`` is given).
            file_wave: Optional path to write the generated audio to.
            file_spect: Optional path passed to ``save_spectrogram``.
            seed: RNG seed; ``-1`` draws a random one. The seed actually used
                is stored in ``self.seed`` for reproducibility.

        Returns:
            The ``(wav, sr, spect)`` tuple produced by ``infer_process``.
        """
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect

    @staticmethod
    def _remove_file(path):
        """Best-effort delete of ``path``; a missing file or OS error is ignored."""
        try:
            os.remove(path)
        except OSError:
            pass

    @staticmethod
    def _host_matches(url, domain):
        """Return True if the host of ``url`` is ``domain`` or a subdomain of it."""
        from urllib.parse import urlparse

        # A scheme-less URL ("drive.google.com/...") only gets a host when
        # parsed as a network-path reference, so prefix "//" in that case.
        if "://" not in url and not url.startswith("//"):
            url = "//" + url
        host = urlparse(url).hostname or ""  # .hostname is already lower-cased
        return host == domain or host.endswith("." + domain)

    @staticmethod
    def download_from_huggingface(url, output_path, timeout=30):
        """Download file from Hugging Face

        Streams ``url`` into ``output_path`` with a byte-accurate progress bar.

        Args:
            url: Direct download URL.
            output_path: Destination file path (overwritten).
            timeout: Seconds passed to ``requests.get`` so a stalled server
                cannot hang the download forever.

        Returns:
            True on success; False on any error, in which case the partially
            written file is removed.
        """
        try:
            # Using the response as a context manager releases the connection.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # A missing/zero content-length means "unknown size" for tqdm.
                total_size = int(response.headers.get('content-length', 0)) or None
                with open(output_path, 'wb') as f, tqdm.tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc="Downloading from Hugging Face",
                ) as bar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            bar.update(len(chunk))
            return True
        except Exception as e:
            print(f"Error downloading from Hugging Face: {e}")
            F5TTS._remove_file(output_path)
            return False

    @staticmethod
    def download_from_google_drive(url, output_path):
        """Download file from Google Drive

        Returns:
            True if ``gdown`` produced ``output_path``; False otherwise.
        """
        try:
            # Use gdown for Google Drive downloads
            result = gdown.download(url=url, output=output_path, quiet=False, fuzzy=True)
        except Exception as e:
            print(f"Error downloading from Google Drive: {e}")
            F5TTS._remove_file(output_path)
            return False
        # Some gdown versions report failure by returning None instead of raising.
        if result is None or not os.path.isfile(output_path):
            print("Error downloading from Google Drive: no file was produced")
            return False
        return True

    @staticmethod
    def extract_zip(zip_path, extract_path):
        """Extract ZIP file

        Returns:
            True on success; False (after printing the error) otherwise.
        """
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
            return True
        except Exception as e:
            print(f"Error extracting ZIP file: {e}")
            return False

    @staticmethod
    def download_and_setup_voice(voice_url, voice_name, base_path="voices"):
        """
        Download and setup a voice from URL (Hugging Face or Google Drive)

        The archive is saved as ``<base_path>/<voice_name>/<voice_name>.zip``,
        extracted into that folder and then deleted.

        Args:
            voice_url (str): URL to download the voice from
            voice_name (str): Name for the voice folder; must be a single path
                component (no separators, not "." or "..")
            base_path (str): Base directory to store voices

        Returns:
            str: Path to the downloaded voice folder, or None if failed
        """
        # Reject names that would place files outside base_path ("../x", "a/b").
        if not voice_name or voice_name in (".", "..") or os.path.basename(voice_name) != voice_name:
            print(f"Invalid voice name: {voice_name!r}")
            return None

        # Determine download type from the URL's host rather than a substring
        # match anywhere in the URL (which e.g. "https://x.example/?huggingface.co" passes).
        if F5TTS._host_matches(voice_url, "huggingface.co"):
            downloader = F5TTS.download_from_huggingface
        elif F5TTS._host_matches(voice_url, "drive.google.com"):
            downloader = F5TTS.download_from_google_drive
        else:
            print("Unsupported URL. Only Hugging Face and Google Drive links are supported.")
            return None

        # Create voice directory (this also creates base_path)
        voice_dir = os.path.join(base_path, voice_name)
        os.makedirs(voice_dir, exist_ok=True)

        # Download file
        zip_path = os.path.join(voice_dir, f"{voice_name}.zip")
        if not downloader(voice_url, zip_path):
            print("Download failed")
            F5TTS._remove_file(zip_path)
            return None

        # Extract ZIP file; the archive is not needed afterwards either way.
        extracted = F5TTS.extract_zip(zip_path, voice_dir)
        F5TTS._remove_file(zip_path)
        if not extracted:
            print("Extraction failed")
            return None

        # Check if the voice was properly extracted
        if not os.path.isdir(voice_dir) or not os.listdir(voice_dir):
            print("Voice directory is empty after extraction")
            return None

        print(f"Voice '{voice_name}' successfully downloaded and setup at: {voice_dir}")
        return voice_dir

    @staticmethod
    def list_available_voices(base_path="voices"):
        """List available downloaded voices

        Returns:
            A list of ``{'name', 'path', 'files'}`` dicts, one per
            sub-directory of ``base_path``, sorted by name (``files`` sorted
            too). Returns ``[]`` when ``base_path`` is missing or not a
            directory.
        """
        if not os.path.isdir(base_path):
            return []

        voices = []
        # Sort so the listing is stable across platforms/filesystems.
        for item in sorted(os.listdir(base_path)):
            item_path = os.path.join(base_path, item)
            if os.path.isdir(item_path):
                voices.append({
                    'name': item,
                    'path': item_path,
                    'files': sorted(os.listdir(item_path)),
                })
        return voices
|
|
|
|
|
if __name__ == "__main__":
    # Demo script: fetch a sample voice pack, enumerate the local voices,
    # then run one synthesis pass with the default model.
    print("=== F5TTS Voice Management Example ===")

    # Example URL (this is just an example, replace with actual voice URL)
    demo_url = "https://huggingface.co/Chouio/Adam/resolve/main/AdamDefinitive.zip"
    demo_voice = "Adam_Voice"

    print(f"Downloading voice from: {demo_url}")
    downloaded_dir = F5TTS.download_and_setup_voice(demo_url, demo_voice)
    if downloaded_dir:
        print(f"Voice downloaded successfully to: {downloaded_dir}")
        print("Available files in voice directory:")
        for entry in os.listdir(downloaded_dir):
            print(f" - {entry}")

    # Show every voice currently stored under the default voices/ folder.
    voices_found = F5TTS.list_available_voices()
    print(f"\nAvailable voices ({len(voices_found)}):")
    for info in voices_found:
        print(f" - {info['name']}")
        print(f" Path: {info['path']}")
        print(f" Files: {', '.join(info['files'])}")

    # Model construction is deliberately outside the try: a load failure
    # should surface as a traceback, not as the "missing audio" hint below.
    tts = F5TTS()

    demo_request = dict(
        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
        seed=-1,  # random seed = -1
    )

    try:
        wav, sr, spect = tts.infer(**demo_request)
        print("seed :", tts.seed)
        print("Inference completed successfully!")
    except Exception as err:
        print(f"Inference failed: {err}")
        print("Note: This example requires actual audio files in the specified paths.")