E2-F5-TTS / api.py
Chouio's picture
Update api.py
43f167f verified
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
from model import DiT, UNetT
from model.utils import save_spectrogram
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from model.utils import seed_everything
import random
import sys
import requests
import gdown
import zipfile
import os
from pathlib import Path
class F5TTS:
    """High-level wrapper around F5-TTS / E2-TTS inference plus voice-pack download helpers.

    Construction loads the vocoder and the acoustic model onto ``self.device``;
    :meth:`infer` then synthesizes speech from a reference clip, its transcript
    and the text to generate. The static helpers download a zipped voice pack
    from Hugging Face or Google Drive and unpack it into a local folder.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Load the vocoder and the acoustic model.

        Args:
            model_type (str): "F5-TTS" (DiT backbone) or "E2-TTS" (UNetT backbone).
            ckpt_file (str): Checkpoint path; empty selects the default checkpoint
                for ``model_type`` from the Hugging Face Hub.
            vocab_file (str): Vocabulary file forwarded to ``load_model``.
            ode_method (str): ODE solver name forwarded to ``load_model``.
            use_ema (bool): Forwarded to ``load_model`` (presumably selects the EMA
                weights -- confirm in ``model.utils_infer``).
            local_path (str | None): Local vocoder location forwarded to
                ``load_vocoder``; ``None`` means no local copy is used.
            device (str | None): Torch device; auto-detected (cuda > mps > cpu)
                when ``None``.

        Raises:
            ValueError: If ``model_type`` is not recognised.
        """
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = 24000  # rate used by export_wav when writing files
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1  # overwritten by infer() with the seed actually used
        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )
        # Load models
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder into ``self.vocos``.

        The first argument tells ``load_vocoder`` whether a local copy should be
        used, i.e. whether ``local_path`` was given.
        """
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the acoustic model for ``model_type`` and store it in ``self.ema_model``.

        An empty ``ckpt_file`` selects the default ``model_1200000.safetensors``
        checkpoint of the chosen architecture, resolved (and cached) via ``cached_path``.

        Raises:
            ValueError: If ``model_type`` is neither "F5-TTS" nor "E2-TTS".
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true the written file is then passed to
        ``remove_silence_for_generated_wav`` (presumably rewritten in place -- confirm).
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Synthesize ``gen_text`` in the voice of the reference clip.

        Args:
            ref_file (str): Path to the reference audio clip.
            ref_text (str): Transcript of ``ref_file``.
            gen_text (str): Text to synthesize.
            show_info (callable): Logging callback forwarded to ``infer_process``.
            progress: Progress provider forwarded to ``infer_process``; the default
                is the ``tqdm`` module itself (how it is used is up to ``infer_process``).
            target_rms, cross_fade_duration, sway_sampling_coef, cfg_strength,
            nfe_step, speed, fix_duration: Sampling options forwarded unchanged
                to ``infer_process``.
            remove_silence (bool): Forwarded to :meth:`export_wav`; only has an
                effect when ``file_wave`` is given.
            file_wave (str | None): If set, the waveform is written there.
            file_spect (str | None): If set, the spectrogram is saved there.
            seed (int): RNG seed; -1 draws a random one. The seed actually used
                is stored in ``self.seed`` so a run can be reproduced.

        Returns:
            tuple: ``(wav, sr, spect)`` exactly as returned by ``infer_process``.
        """
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )
        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)
        return wav, sr, spect

    @staticmethod
    def download_from_huggingface(url, output_path, timeout=60, chunk_size=8192):
        """Stream a file from a Hugging Face URL to ``output_path``.

        Args:
            url (str): Direct download URL (a ``/resolve/`` link for repo files).
            output_path (str): Destination file; overwritten if it exists.
            timeout (float): Seconds requests waits for the connection and for each
                read; without it a stalled server would block forever.
            chunk_size (int): Bytes requested per iteration of the stream.

        Returns:
            bool: True on success, False if the request or the write failed. A
            partially written ``output_path`` is deleted on failure.
        """
        opened = False
        try:
            # Using the response as a context manager releases the connection
            # even when the body is not fully consumed.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # content-length may be absent; None makes tqdm show a plain counter.
                total_size = int(response.headers.get('content-length', 0)) or None
                with open(output_path, 'wb') as f:
                    opened = True
                    with tqdm.tqdm(total=total_size,
                                   unit='B',
                                   unit_scale=True,
                                   unit_divisor=1024,
                                   desc="Downloading from Hugging Face") as progress_bar:
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)
                                progress_bar.update(len(chunk))
            return True
        except (requests.RequestException, OSError, ValueError) as e:
            print(f"Error downloading from Hugging Face: {e}")
            if opened:
                try:
                    os.remove(output_path)
                except OSError:
                    pass  # best-effort cleanup of the partial file
            return False

    @staticmethod
    def download_from_google_drive(url, output_path):
        """Download a Google Drive file to ``output_path`` using ``gdown``.

        ``fuzzy=True`` lets gdown extract the file id from ordinary share links.

        Returns:
            bool: True if gdown reported a downloaded file, False otherwise.
        """
        try:
            # Use gdown for Google Drive downloads
            result = gdown.download(url=url, output=output_path, quiet=False, fuzzy=True)
        except Exception as e:  # gdown's exception types differ between versions
            print(f"Error downloading from Google Drive: {e}")
            return False
        # gdown returns the output path on success; a None return means nothing
        # was retrieved (some versions print an error and return None instead of raising).
        if result is None:
            print("Error downloading from Google Drive: no file was retrieved")
            return False
        return True

    @staticmethod
    def extract_zip(zip_path, extract_path):
        """Extract every member of ``zip_path`` into ``extract_path``.

        Returns:
            bool: True on success, False if the archive could not be read or unpacked.
        """
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
        except Exception as e:  # best effort: bad/corrupt archive, I/O error, encrypted member, ...
            print(f"Error extracting ZIP file: {e}")
            return False
        return True

    @staticmethod
    def download_and_setup_voice(voice_url, voice_name, base_path="voices"):
        """
        Download and setup a voice from URL (Hugging Face or Google Drive).

        The archive is downloaded to ``<base_path>/<voice_name>/<voice_name>.zip``,
        extracted next to itself and then deleted. The archive is deleted on
        failure as well, and a voice folder created by this call is removed again
        if it ends up empty, so failed attempts leave no stray files behind.
        Hugging Face ``/blob/`` links (the HTML file viewer) are rewritten to the
        matching ``/resolve/`` link, which serves the raw file.

        Args:
            voice_url (str): URL to download the voice from; its host must be
                huggingface.co (or a subdomain of it) or drive.google.com.
            voice_name (str): Name for the voice folder
            base_path (str): Base directory to store voices

        Returns:
            str: Path to the downloaded voice folder, or None if failed
        """
        from urllib.parse import urlparse

        # Create base directory if it doesn't exist
        os.makedirs(base_path, exist_ok=True)

        # Pick the backend from the URL's host rather than a substring match
        # anywhere in the URL (which would accept e.g. "https://x.com/?huggingface.co").
        parsed = urlparse(voice_url)
        host = (parsed.hostname or "").lower()
        is_huggingface = host == "huggingface.co" or host.endswith(".huggingface.co")
        is_google_drive = host == "drive.google.com"
        if not (is_huggingface or is_google_drive):
            print("Unsupported URL. Only Hugging Face and Google Drive links are supported.")
            return None
        if is_huggingface and "/blob/" in parsed.path:
            voice_url = parsed._replace(path=parsed.path.replace("/blob/", "/resolve/", 1)).geturl()

        # Create voice directory, remembering whether this call created it.
        voice_dir = os.path.join(base_path, voice_name)
        created_voice_dir = not os.path.isdir(voice_dir)
        os.makedirs(voice_dir, exist_ok=True)

        zip_path = os.path.join(voice_dir, f"{voice_name}.zip")
        try:
            if is_huggingface:
                success = F5TTS.download_from_huggingface(voice_url, zip_path)
            else:
                success = F5TTS.download_from_google_drive(voice_url, zip_path)
            if not success:
                print("Download failed")
                return None
            if not F5TTS.extract_zip(zip_path, voice_dir):
                print("Extraction failed")
                return None
        finally:
            # Always drop the (possibly partial) archive, and undo a voice folder
            # this call created if nothing ended up in it.
            if os.path.exists(zip_path):
                os.remove(zip_path)
            if created_voice_dir and os.path.isdir(voice_dir) and not os.listdir(voice_dir):
                os.rmdir(voice_dir)

        # Check if the voice was properly extracted
        if not os.path.isdir(voice_dir) or not os.listdir(voice_dir):
            print("Voice directory is empty after extraction")
            return None

        print(f"Voice '{voice_name}' successfully downloaded and setup at: {voice_dir}")
        return voice_dir
if __name__ == "__main__":
    # Demo: clone the voice of the reference clip and write the result to tests/.
    # (The gen_text literal was truncated and left unterminated; it is completed
    # with the same call used by the full demo.)
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
        seed=-1,  # random seed = -1
    )

    # Print the seed actually used so the run can be reproduced.
    print("seed :", f5tts.seed)
</think>
My previous response was cut off while generating the modified `api.txt` file, so the complete modifications for both files follow below.
## 📁 `api.txt` — complete modified version
```python
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
from model import DiT, UNetT
from model.utils import save_spectrogram
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from model.utils import seed_everything
import random
import sys
import requests
import gdown
import zipfile
import os
from pathlib import Path
class F5TTS:
    """High-level wrapper around F5-TTS / E2-TTS inference plus voice-pack management.

    Construction loads the vocoder and the acoustic model onto ``self.device``;
    :meth:`infer` then synthesizes speech from a reference clip, its transcript
    and the text to generate. The static helpers download a zipped voice pack
    from Hugging Face or Google Drive, unpack it into a local folder, and list
    the folders already present.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Load the vocoder and the acoustic model.

        Args:
            model_type (str): "F5-TTS" (DiT backbone) or "E2-TTS" (UNetT backbone).
            ckpt_file (str): Checkpoint path; empty selects the default checkpoint
                for ``model_type`` from the Hugging Face Hub.
            vocab_file (str): Vocabulary file forwarded to ``load_model``.
            ode_method (str): ODE solver name forwarded to ``load_model``.
            use_ema (bool): Forwarded to ``load_model`` (presumably selects the EMA
                weights -- confirm in ``model.utils_infer``).
            local_path (str | None): Local vocoder location forwarded to
                ``load_vocoder``; ``None`` means no local copy is used.
            device (str | None): Torch device; auto-detected (cuda > mps > cpu)
                when ``None``.

        Raises:
            ValueError: If ``model_type`` is not recognised.
        """
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = 24000  # rate used by export_wav when writing files
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1  # overwritten by infer() with the seed actually used
        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )
        # Load models
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder into ``self.vocos``.

        The first argument tells ``load_vocoder`` whether a local copy should be
        used, i.e. whether ``local_path`` was given.
        """
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the acoustic model for ``model_type`` and store it in ``self.ema_model``.

        An empty ``ckpt_file`` selects the default ``model_1200000.safetensors``
        checkpoint of the chosen architecture, resolved (and cached) via ``cached_path``.

        Raises:
            ValueError: If ``model_type`` is neither "F5-TTS" nor "E2-TTS".
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true the written file is then passed to
        ``remove_silence_for_generated_wav`` (presumably rewritten in place -- confirm).
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Synthesize ``gen_text`` in the voice of the reference clip.

        Args:
            ref_file (str): Path to the reference audio clip.
            ref_text (str): Transcript of ``ref_file``.
            gen_text (str): Text to synthesize.
            show_info (callable): Logging callback forwarded to ``infer_process``.
            progress: Progress provider forwarded to ``infer_process``; the default
                is the ``tqdm`` module itself (how it is used is up to ``infer_process``).
            target_rms, cross_fade_duration, sway_sampling_coef, cfg_strength,
            nfe_step, speed, fix_duration: Sampling options forwarded unchanged
                to ``infer_process``.
            remove_silence (bool): Forwarded to :meth:`export_wav`; only has an
                effect when ``file_wave`` is given.
            file_wave (str | None): If set, the waveform is written there.
            file_spect (str | None): If set, the spectrogram is saved there.
            seed (int): RNG seed; -1 draws a random one. The seed actually used
                is stored in ``self.seed`` so a run can be reproduced.

        Returns:
            tuple: ``(wav, sr, spect)`` exactly as returned by ``infer_process``.
        """
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )
        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)
        return wav, sr, spect

    @staticmethod
    def download_from_huggingface(url, output_path, timeout=60, chunk_size=8192):
        """Stream a file from a Hugging Face URL to ``output_path``.

        Args:
            url (str): Direct download URL (a ``/resolve/`` link for repo files).
            output_path (str): Destination file; overwritten if it exists.
            timeout (float): Seconds requests waits for the connection and for each
                read; without it a stalled server would block forever.
            chunk_size (int): Bytes requested per iteration of the stream.

        Returns:
            bool: True on success, False if the request or the write failed. A
            partially written ``output_path`` is deleted on failure.
        """
        opened = False
        try:
            # Using the response as a context manager releases the connection
            # even when the body is not fully consumed.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # content-length may be absent; None makes tqdm show a plain counter.
                total_size = int(response.headers.get('content-length', 0)) or None
                with open(output_path, 'wb') as f:
                    opened = True
                    with tqdm.tqdm(total=total_size,
                                   unit='B',
                                   unit_scale=True,
                                   unit_divisor=1024,
                                   desc="Downloading from Hugging Face") as progress_bar:
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)
                                progress_bar.update(len(chunk))
            return True
        except (requests.RequestException, OSError, ValueError) as e:
            print(f"Error downloading from Hugging Face: {e}")
            if opened:
                try:
                    os.remove(output_path)
                except OSError:
                    pass  # best-effort cleanup of the partial file
            return False

    @staticmethod
    def download_from_google_drive(url, output_path):
        """Download a Google Drive file to ``output_path`` using ``gdown``.

        ``fuzzy=True`` lets gdown extract the file id from ordinary share links.

        Returns:
            bool: True if gdown reported a downloaded file, False otherwise.
        """
        try:
            # Use gdown for Google Drive downloads
            result = gdown.download(url=url, output=output_path, quiet=False, fuzzy=True)
        except Exception as e:  # gdown's exception types differ between versions
            print(f"Error downloading from Google Drive: {e}")
            return False
        # gdown returns the output path on success; a None return means nothing
        # was retrieved (some versions print an error and return None instead of raising).
        if result is None:
            print("Error downloading from Google Drive: no file was retrieved")
            return False
        return True

    @staticmethod
    def extract_zip(zip_path, extract_path):
        """Extract every member of ``zip_path`` into ``extract_path``.

        Returns:
            bool: True on success, False if the archive could not be read or unpacked.
        """
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
        except Exception as e:  # best effort: bad/corrupt archive, I/O error, encrypted member, ...
            print(f"Error extracting ZIP file: {e}")
            return False
        return True

    @staticmethod
    def download_and_setup_voice(voice_url, voice_name, base_path="voices"):
        """
        Download and setup a voice from URL (Hugging Face or Google Drive).

        The archive is downloaded to ``<base_path>/<voice_name>/<voice_name>.zip``,
        extracted next to itself and then deleted. The archive is deleted on
        failure as well, and a voice folder created by this call is removed again
        if it ends up empty, so failed attempts do not show up as voices in
        :meth:`list_available_voices`. Hugging Face ``/blob/`` links (the HTML file
        viewer) are rewritten to the matching ``/resolve/`` link, which serves the
        raw file.

        Args:
            voice_url (str): URL to download the voice from; its host must be
                huggingface.co (or a subdomain of it) or drive.google.com.
            voice_name (str): Name for the voice folder
            base_path (str): Base directory to store voices

        Returns:
            str: Path to the downloaded voice folder, or None if failed
        """
        from urllib.parse import urlparse

        # Create base directory if it doesn't exist
        os.makedirs(base_path, exist_ok=True)

        # Pick the backend from the URL's host rather than a substring match
        # anywhere in the URL (which would accept e.g. "https://x.com/?huggingface.co").
        parsed = urlparse(voice_url)
        host = (parsed.hostname or "").lower()
        is_huggingface = host == "huggingface.co" or host.endswith(".huggingface.co")
        is_google_drive = host == "drive.google.com"
        if not (is_huggingface or is_google_drive):
            print("Unsupported URL. Only Hugging Face and Google Drive links are supported.")
            return None
        if is_huggingface and "/blob/" in parsed.path:
            voice_url = parsed._replace(path=parsed.path.replace("/blob/", "/resolve/", 1)).geturl()

        # Create voice directory, remembering whether this call created it.
        voice_dir = os.path.join(base_path, voice_name)
        created_voice_dir = not os.path.isdir(voice_dir)
        os.makedirs(voice_dir, exist_ok=True)

        zip_path = os.path.join(voice_dir, f"{voice_name}.zip")
        try:
            if is_huggingface:
                success = F5TTS.download_from_huggingface(voice_url, zip_path)
            else:
                success = F5TTS.download_from_google_drive(voice_url, zip_path)
            if not success:
                print("Download failed")
                return None
            if not F5TTS.extract_zip(zip_path, voice_dir):
                print("Extraction failed")
                return None
        finally:
            # Always drop the (possibly partial) archive, and undo a voice folder
            # this call created if nothing ended up in it.
            if os.path.exists(zip_path):
                os.remove(zip_path)
            if created_voice_dir and os.path.isdir(voice_dir) and not os.listdir(voice_dir):
                os.rmdir(voice_dir)

        # Check if the voice was properly extracted
        if not os.path.isdir(voice_dir) or not os.listdir(voice_dir):
            print("Voice directory is empty after extraction")
            return None

        print(f"Voice '{voice_name}' successfully downloaded and setup at: {voice_dir}")
        return voice_dir

    @staticmethod
    def list_available_voices(base_path="voices"):
        """List the voice folders found directly under ``base_path``.

        Args:
            base_path (str): Directory that :meth:`download_and_setup_voice` writes into.

        Returns:
            list[dict]: One ``{'name', 'path', 'files'}`` entry per sub-directory,
            sorted by name with ``files`` sorted too (``os.listdir`` order is
            arbitrary). Empty if ``base_path`` is missing or is not a directory.
        """
        # isdir (not exists): a plain file at base_path would make listdir raise.
        if not os.path.isdir(base_path):
            return []
        voices = []
        for item in sorted(os.listdir(base_path)):
            item_path = os.path.join(base_path, item)
            if os.path.isdir(item_path):
                voices.append({
                    'name': item,
                    'path': item_path,
                    'files': sorted(os.listdir(item_path)),
                })
        return voices
if __name__ == "__main__":
    # Example usage of voice download functionality
    print("=== F5TTS Voice Management Example ===")
    # Example URL (this is just an example, replace with actual voice URL)
    example_url = "https://huggingface.co/Chouio/Adam/resolve/main/AdamDefinitive.zip"
    voice_name = "Adam_Voice"
    print(f"Downloading voice from: {example_url}")
    voice_path = F5TTS.download_and_setup_voice(example_url, voice_name)
    if voice_path:
        print(f"Voice downloaded successfully to: {voice_path}")
        print("Available files in voice directory:")
        for file_name in sorted(os.listdir(voice_path)):
            print(f" - {file_name}")

    # List available voices
    available_voices = F5TTS.list_available_voices()
    print(f"\nAvailable voices ({len(available_voices)}):")
    for voice in available_voices:
        print(f" - {voice['name']}")
        print(f" Path: {voice['path']}")
        print(f" Files: {', '.join(voice['files'])}")

    # Initialize F5TTS for inference. Model loading may download checkpoints, so
    # it is guarded like inference instead of aborting the demo with a traceback.
    f5tts = None
    try:
        f5tts = F5TTS()
    except Exception as e:
        print(f"Model loading failed: {e}")

    # Example inference (requires actual audio files)
    if f5tts is not None:
        try:
            wav, sr, spect = f5tts.infer(
                ref_file="tests/ref_audio/test_en_1_ref_short.wav",
                ref_text="some call me nature, others call me mother nature.",
                gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
                file_wave="tests/out.wav",
                file_spect="tests/out.png",
                seed=-1,  # random seed = -1
            )
            print("seed :", f5tts.seed)
            print("Inference completed successfully!")
        except Exception as e:
            print(f"Inference failed: {e}")
            print("Note: This example requires actual audio files in the specified paths.")