Ani-Voice-API / tts_engine.py
ssasio's picture
Upload 12 files
caf9e2a verified
Raw
History Blame Contribute Delete
6.58 kB
import os
import io
import re
import wave
import torch
import numpy as np
import tempfile
import sys
import supertonic
# Добавяме BgTTS към sys.path, за да може вътрешните му импорти да работят
sys.path.append(os.path.join(os.path.dirname(__file__), 'BgTTS'))
from inference import synthesize
from normalizer import normalize_text
class TTSEngine:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Зареждам TTS Engine на устройство: {self.device}")
# Supertonic (Референтно аудио)
from supertonic import TTS
self.engine = TTS(auto_download=True)
# BgTTS (Основен модел)
self.bgtts_checkpoint = os.path.join(os.path.dirname(__file__), "BgTTS", "checkpoint_inference.pt")
# BgTTS inference.synthesize зарежда модела всеки път, ако не му подадем модела.
# В текущия BgTTS/inference.py synthesize() вика load_for_inference(), ако се подаде път.
# За сега ще ползваме пътя, тъй като така е написан BgTTS.
# Ако искаме пълно кеширане, може да се наложи леко пренаписване на BgTTS/inference.py.
# Но засега ще ползваме оригиналната synthesize функция.
print("TTS Engine зареден успешно.")
def split_text_for_tts(self, text: str) -> list[str]:
text = text.strip()
if not text:
return []
raw = re.split(r'(?<=[\.\!\?…])\s+|\n+', text)
chunks = []
buf = ""
for part in raw:
part = part.strip()
if not part: continue
if not buf or len(buf) < 80 or len(buf) + len(part) + 1 <= 200:
buf = (buf + " " + part).strip()
else:
chunks.append(buf)
buf = part
if buf: chunks.append(buf)
return chunks
def generate_chunk(self, chunk_text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
"""
Генерира аудио за едно изречение (chunk) и го връща като WAV байтове.
"""
clean_text = chunk_text.replace('"', '').replace('„', '').replace('“', '') \
.replace("’", "'").replace("–", "-").replace("—", "-") \
.replace("*", "")
if not clean_text.strip():
return b""
# 1. Генериране на референтно аудио
# Ако voice_style е стринг (напр. "F5"), взимаме съответния обект
if isinstance(voice_style, str):
v_style = self.engine.get_voice_style(voice_name=voice_style)
else:
v_style = voice_style
wav_array, _ = self.engine.synthesize(clean_text, voice_style=v_style, lang="bg", speed=speed)
wav_data = np.asarray(wav_array).flatten()
wav_max = np.max(np.abs(wav_data))
if wav_max > 0:
wav_data = wav_data / wav_max
pcm_data = (wav_data * 32767).astype(np.int16)
# Записваме временно референтното аудио (тъй като BgTTS изисква файл)
fd, ref_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
with wave.open(ref_path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(44100)
wf.writeframes(pcm_data.tobytes())
# 2. Генериране на крайното аудио
fd, final_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
try:
synthesize(checkpoint=self.bgtts_checkpoint,
text=clean_text,
output=final_path,
speaker_wav=ref_path,
device=self.device)
# Прочитане на резултата
with open(final_path, "rb") as f:
audio_bytes = f.read()
return audio_bytes
finally:
try:
os.remove(ref_path)
os.remove(final_path)
except OSError:
pass
def synthesize_stream(self, text: str, voice_style: str = "F5", speed: float = 1.6):
"""
Генератор, който нормализира текста, цепи го на парчета и връща WAV байтове за всяко парче.
"""
normalized_text = normalize_text(text)
chunks = self.split_text_for_tts(normalized_text)
for chunk in chunks:
audio_bytes = self.generate_chunk(chunk, voice_style, speed)
if audio_bytes:
yield audio_bytes
def synthesize_full(self, text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
"""
Нормализира текста, цепи го, генерира всички парчета и ги слепва в един общ WAV файл.
"""
normalized_text = normalize_text(text)
chunks = self.split_text_for_tts(normalized_text)
all_frames = b""
params = None
for chunk in chunks:
audio_bytes = self.generate_chunk(chunk, voice_style, speed)
if not audio_bytes:
continue
# Парсване на WAV данните, за да можем да ги слеем без да дублираме хедъри
with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
if not params:
params = wf.getparams()
all_frames += wf.readframes(wf.getnframes())
if not params:
return b""
# Създаване на крайния WAV
out_io = io.BytesIO()
with wave.open(out_io, "wb") as wf:
wf.setparams(params)
wf.writeframes(all_frames)
return out_io.getvalue()
# Глобална инстанция за по-лесно преизползване
engine = TTSEngine()