# Provenance (Hugging Face Hub file-viewer metadata, kept as a comment):
#   audiobook-ru-tts / backends / espeech_backend.py
#   uploaded by danilahs via huggingface_hub
#   revision 4f6648e (verified)
# backends/espeech_backend.py
# Полная интеграция ESpeech/ESpeech-TTS-1_RL-V2 (F5-TTS) для инференса.
# Основано на коде из model card: загрузка весов, препроцессинг референса,
# вызов infer_process и возврат (wave, sample_rate).
from __future__ import annotations
from typing import Tuple, Optional
import os
import gc
import numpy as np
import torch
import torchaudio
# Force CPU usage on macOS to avoid MPS issues
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
# Disable MPS to force CPU usage
torch.backends.mps.is_available = lambda: False
torch.backends.mps.is_built = lambda: False
from huggingface_hub import hf_hub_download, snapshot_download
# F5-TTS imports (как в карточке модели)
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
from f5_tts.model import DiT
# Конфиг модели из карточки
MODEL_CFG = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
class EspeechBackend:
    """TTS inference backend for ESpeech/ESpeech-TTS-1_RL-V2 (an F5-TTS model).

    Follows the recipe from the model card: download the checkpoint and vocab
    from the Hugging Face Hub, load the DiT model plus vocoder, preprocess the
    reference audio/text, run ``infer_process`` and return
    ``(waveform, sample_rate)``.
    """

    def __init__(self, model_id: str = "ESpeech/ESpeech-TTS-1_RL-V2"):
        """Download and load the model eagerly so the first synthesis is fast.

        Args:
            model_id: Hugging Face repo id holding the checkpoint and vocab.
        """
        self.model_id = model_id
        self.model_file = "espeech_tts_rlv2.pt"
        self.vocab_file = "vocab.txt"
        # Prefer CUDA when present; otherwise CPU (MPS is disabled at module
        # import time to avoid known macOS issues).
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model = None
        self.vocoder = None
        self._ensure_loaded()

    def _download(self, repo: str, filename: str) -> str:
        """Return a local path to ``filename`` from ``repo``.

        Tries a single-file download first; if that fails for any reason,
        falls back to fetching the whole repo snapshot and locating the file
        inside it.

        Raises:
            FileNotFoundError: if the file is absent even from the snapshot.
        """
        try:
            return hf_hub_download(repo_id=repo, filename=filename)
        except Exception:
            # Fallback path: download the entire repo snapshot.
            local_dir = f"cache_{repo.replace('/', '_')}"
            snap_dir = snapshot_download(repo_id=repo, local_dir=local_dir)
            path = os.path.join(snap_dir, filename)
            if not os.path.exists(path):
                # Bug fix: the message previously contained the literal text
                # "(unknown)" instead of the missing filename.
                raise FileNotFoundError(f"{filename} not found in snapshot {snap_dir}")
            return path

    def _ensure_loaded(self):
        """Load checkpoint, vocab and vocoder, then move them to ``self.device``."""
        model_path = self._download(self.model_id, self.model_file)
        vocab_path = self._download(self.model_id, self.vocab_file)
        # Model + vocoder initialization, as in the model card.
        self.model = load_model(DiT, MODEL_CFG, model_path, vocab_file=vocab_path)
        self.vocoder = load_vocoder()
        # Move to the selected device; on any failure (e.g. OOM on the
        # accelerator) fall back to CPU so loading still succeeds.
        try:
            self.model.to(self.device)
            self.vocoder.to(self.device)
        except Exception as e:
            print(f"Warning: Failed to move model to {self.device}, falling back to CPU: {e}")
            self.device = torch.device("cpu")
            self.model.to(self.device)
            self.vocoder.to(self.device)

    def synthesize(
        self,
        text: str,
        ref_audio_path: Optional[str],
        ref_text: str,
        speed: float = 1.0,
        nfe_steps: int = 48,
        seed: Optional[int] = None,
        cross_fade_sec: float = 0.15,
        target_rms: float = 0.1,
        cfg_strength: float = 2.0,
        sway_sampling_coef: float = -1.0,
    ) -> Tuple[np.ndarray, int]:
        """Synthesize ``text`` in the voice of the reference clip.

        Args:
            text: Text to synthesize; must be non-empty.
            ref_audio_path: Path to the reference audio clip (~6-12 s).
            ref_text: Transcript of the same reference audio.
            speed: Playback-speed multiplier forwarded to the model.
            nfe_steps: Number of function evaluations (denoising steps).
            seed: Optional RNG seed for reproducible output.
            cross_fade_sec: Cross-fade between generated chunks, seconds.
            target_rms: Target RMS loudness of the output.
            cfg_strength: Classifier-free-guidance strength.
            sway_sampling_coef: Sway sampling coefficient (model-card default).

        Returns:
            Tuple ``(audio, sample_rate)`` where ``audio`` is a float32 mono
            array clipped to [-1, 1].

        Raises:
            ValueError: on empty ``text`` or empty ``ref_text``.
            FileNotFoundError: when ``ref_audio_path`` is missing.
        """
        if not text or not text.strip():
            raise ValueError("Пустой текст для синтеза.")
        if not ref_audio_path or not os.path.exists(ref_audio_path):
            raise FileNotFoundError("Укажите путь к reference audio (6–12 с).")
        if not ref_text or not ref_text.strip():
            raise ValueError("Укажите reference text (транскрипт того же reference audio).")
        if seed is not None:
            # torch.manual_seed seeds the CPU RNG and all CUDA devices.
            torch.manual_seed(int(seed))
        # Reference preprocessing (the helper handles resampling / mono mixdown).
        ref_audio_proc, ref_text_proc = preprocess_ref_audio_text(ref_audio_path, ref_text)
        # Main inference call with the extra voice-quality parameters.
        final_wave, final_sample_rate, _ = infer_process(
            ref_audio_proc,
            ref_text_proc,
            text,
            self.model,
            self.vocoder,
            cross_fade_duration=float(cross_fade_sec),
            nfe_step=int(nfe_steps),
            speed=float(speed),
            target_rms=float(target_rms),
            cfg_strength=float(cfg_strength),
            sway_sampling_coef=float(sway_sampling_coef),
        )
        # Defensively normalize dtype and value range.
        wav = np.asarray(final_wave, dtype=np.float32)
        wav = np.clip(wav, -1.0, 1.0)
        sr = int(final_sample_rate)
        # Release CUDA memory between calls (matters for long audiobooks).
        if self.device.type == "cuda":
            try:
                torch.cuda.empty_cache()
                gc.collect()
            except Exception:
                pass
        return wav, sr