juraganvoice / model.py
dzakybd's picture
-
39e3a57
import instld
from instld.errors import InstallingPackageError
# --- Dependency bootstrap ----------------------------------------------------
# Makes a best effort to get the `f5_tts` package importable before the rest of
# the script imports from it. Every install uses --no-deps. Failures in this
# section are printed rather than raised, so the later
# `from f5_tts.api import F5TTS` is the single place that fails hard if the
# package is still unavailable.
import importlib
import subprocess
import sys


def _pip_install(package):
    """Install *package* with ``--no-deps`` via the running interpreter's pip.

    pip is driven as ``sys.executable -m pip`` in a subprocess instead of
    ``pip.main`` / a bare ``pip`` shell command: pip has no supported
    in-process API, and a bare ``pip`` on PATH may belong to a different
    Python installation than the one running this script.

    Returns True when pip exits with status 0. On any failure the problem is
    printed and False is returned (best-effort, never raises for a failed
    install).
    """
    command = [sys.executable, '-m', 'pip', 'install', package, '--no-deps']
    try:
        exit_code = subprocess.run(command, check=False).returncode
    except OSError as launch_error:
        print(launch_error)
        return False
    if exit_code != 0:
        print(f'pip install {package} failed with exit code {exit_code}')
        return False
    # Let the import system notice modules that appeared after start-up.
    importlib.invalidate_caches()
    return True


def _missing_module_name(import_error):
    """Return the top-level module name an import failure refers to.

    Prefers the structured ``ImportError.name`` attribute; falls back to the
    last word of the message with quotes stripped when that is not set.
    """
    name = getattr(import_error, 'name', None)
    if not name:
        name = str(import_error).split(' ')[-1].replace("'", "")
    return name.split('.')[0]


# Step 1: install through instld.
# NOTE(review): the flag is set inside the `with` block; confirm whether the
# package stays available after the block exits.
is_f5_installed = False
try:
    with instld('f5-tts', no_deps=True, catch_output=True):
        is_f5_installed = True
        print("f5-tts installed successfully")
except InstallingPackageError as install_error:
    print(install_error.stdout)
    print(install_error.stderr)

# No longer used by this section; the import is retained in case code
# elsewhere in the file relies on the module-level name.
import pip  # noqa: E402,F401

# Step 2: try the real import; if a module is missing, install exactly that
# module (f5-tts was installed without its dependencies).
_f5_importable = False
try:
    import f5_tts  # noqa: F401
except ImportError as import_error:
    print(import_error)
    _missing = _missing_module_name(import_error)
    if _missing:
        _pip_install(_missing)
        print(f'retrying install {_missing}')
except Exception as unexpected_error:
    # Any other failure while importing is reported but not fatal here.
    print(unexpected_error)
else:
    _f5_importable = True

# Step 3: last resort, install f5-tts itself with the interpreter's own pip.
if not _f5_importable:
    _pip_install('f5-tts')
from huggingface_hub import hf_hub_download
import os
from f5_tts.api import F5TTS
# Ensure the checkpoint is available locally, fetching it from the
# Hugging Face Hub into ./models when the expected file is absent.
os.makedirs("models", exist_ok=True)
model_path = 'models/checkpoints/model_220000.pt'

if not os.path.exists(model_path):
    print("Downloading model from HuggingFace...")
    try:
        downloaded_model = hf_hub_download(
            local_dir="models",
            local_dir_use_symlinks=False,
            repo_id="mesolitica/Malaysian-F5-TTS-v3",
            filename="checkpoints/model_220000.pt",
        )
        print(f"Model downloaded successfully to {downloaded_model}!")
        # Use the path reported by the download call from here on.
        model_path = downloaded_model
    except Exception as download_error:
        # Report the failure, then let it propagate: nothing below can work
        # without the checkpoint.
        print(f"Error downloading model: {download_error}")
        raise
# Seed passed to every inference call made by generate_tts.
seed = 1
# NOTE(review): presumably the expected output rate in Hz; generate_tts
# returns the rate reported by the model rather than this value -- confirm.
sampling_rate = 24000

# Voices available for cloning, keyed by display name. Each entry gives the
# reference audio file name ('path') and the text spoken in that reference
# ('transcript'). Commented-out entries are currently disabled.
CLONE_VOICES = {
    # 'Prabowo': {
    #     'path': 'prab.wav',
    #     'transcript': 'pada saat sekarang ini dimana bangsa indonesia ditengah tantangan global'
    # },
    'Ono': dict(
        path='ono.wav',
        transcript='kalau saya percaya pekerjaan manusia itu harus lebih banyak lagi melakukan',
    ),
    # 'Najwa': {
    #     'path': 'najwa.wav',
    #     'transcript': 'ada kekhawatiran dari masyarakat sipil proses pembentukan undang-undang kita'
    # },
    # 'Zilong': {
    #     'path': 'zilong.wav',
    #     'transcript': 'tidak ada yang menakutiku, bahkan kematian sekalipun'
    # },
}
# Progress marker emitted right before the model object is built.
print('pre model loaded')
# Single shared model instance used by generate_tts; built once at start-up
# from the checkpoint path resolved earlier, on the CUDA device.
f5tts = F5TTS(
    device='cuda',
    vocab_file='vocab.txt',
    ckpt_file=model_path,
)
def generate_tts(gen_text, voice, speed):
    """Synthesize ``gen_text`` in the cloned voice named ``voice``.

    Args:
        gen_text: Text to be spoken.
        voice: Key into ``CLONE_VOICES`` selecting the reference clip and its
            transcript; an unknown key raises ``KeyError``.
        speed: Speed setting forwarded unchanged to the model's ``infer``.

    Returns:
        A ``(wav, sampling_rate)`` tuple: the first two values returned by
        ``f5tts.infer``. The third value it returns is discarded.
    """
    voice_entry = CLONE_VOICES[voice]
    # Reference clips are looked up under the ref/ directory.
    reference_audio = f'ref/{voice_entry["path"]}'
    reference_text = voice_entry["transcript"]

    audio, sample_rate, _unused = f5tts.infer(
        ref_file=reference_audio,
        ref_text=reference_text,
        gen_text=gen_text,
        seed=seed,
        speed=speed,
    )
    return audio, sample_rate
# Start-up progress marker: reached once the statements above (model
# construction and the generate_tts definition) have run without raising.
print('post model loaded!')