Ani-Voice-API / tts_engine.py
beleata74's picture
Initial release of Ani-Voice-API (created by Ani-Antigravity)
695fb87 verified
Raw
History Blame Contribute Delete
6.58 kB
import os
import io
import re
import wave
import torch
import numpy as np
import tempfile
import sys
import supertonic
# Добавяме BgTTS към sys.path, за да може вътрешните му импорти да работят
sys.path.append(os.path.join(os.path.dirname(__file__), 'BgTTS'))
from inference import synthesize
from normalizer import normalize_text
class TTSEngine:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Зареждам TTS Engine на устройство: {self.device}")
# Supertonic (Референтно аудио)
from supertonic import TTS
self.engine = TTS(auto_download=True)
# BgTTS (Основен модел)
self.bgtts_checkpoint = os.path.join(os.path.dirname(__file__), "BgTTS", "checkpoint_inference.pt")
# BgTTS inference.synthesize зарежда модела всеки път, ако не му подадем модела.
# В текущия BgTTS/inference.py synthesize() вика load_for_inference(), ако се подаде път.
# За сега ще ползваме пътя, тъй като така е написан BgTTS.
# Ако искаме пълно кеширане, може да се наложи леко пренаписване на BgTTS/inference.py.
# Но засега ще ползваме оригиналната synthesize функция.
print("TTS Engine зареден успешно.")
def split_text_for_tts(self, text: str) -> list[str]:
text = text.strip()
if not text:
return []
raw = re.split(r'(?<=[\.\!\?…])\s+|\n+', text)
chunks = []
buf = ""
for part in raw:
part = part.strip()
if not part: continue
if not buf or len(buf) < 80 or len(buf) + len(part) + 1 <= 200:
buf = (buf + " " + part).strip()
else:
chunks.append(buf)
buf = part
if buf: chunks.append(buf)
return chunks
def generate_chunk(self, chunk_text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
"""
Генерира аудио за едно изречение (chunk) и го връща като WAV байтове.
"""
clean_text = chunk_text.replace('"', '').replace('„', '').replace('“', '') \
.replace("’", "'").replace("–", "-").replace("—", "-") \
.replace("*", "")
if not clean_text.strip():
return b""
# 1. Генериране на референтно аудио
# Ако voice_style е стринг (напр. "F5"), взимаме съответния обект
if isinstance(voice_style, str):
v_style = self.engine.get_voice_style(voice_name=voice_style)
else:
v_style = voice_style
wav_array, _ = self.engine.synthesize(clean_text, voice_style=v_style, lang="bg", speed=speed)
wav_data = np.asarray(wav_array).flatten()
wav_max = np.max(np.abs(wav_data))
if wav_max > 0:
wav_data = wav_data / wav_max
pcm_data = (wav_data * 32767).astype(np.int16)
# Записваме временно референтното аудио (тъй като BgTTS изисква файл)
fd, ref_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
with wave.open(ref_path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(44100)
wf.writeframes(pcm_data.tobytes())
# 2. Генериране на крайното аудио
fd, final_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
try:
synthesize(checkpoint=self.bgtts_checkpoint,
text=clean_text,
output=final_path,
speaker_wav=ref_path,
device=self.device)
# Прочитане на резултата
with open(final_path, "rb") as f:
audio_bytes = f.read()
return audio_bytes
finally:
try:
os.remove(ref_path)
os.remove(final_path)
except OSError:
pass
def synthesize_stream(self, text: str, voice_style: str = "F5", speed: float = 1.6):
"""
Генератор, който нормализира текста, цепи го на парчета и връща WAV байтове за всяко парче.
"""
normalized_text = normalize_text(text)
chunks = self.split_text_for_tts(normalized_text)
for chunk in chunks:
audio_bytes = self.generate_chunk(chunk, voice_style, speed)
if audio_bytes:
yield audio_bytes
def synthesize_full(self, text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
"""
Нормализира текста, цепи го, генерира всички парчета и ги слепва в един общ WAV файл.
"""
normalized_text = normalize_text(text)
chunks = self.split_text_for_tts(normalized_text)
all_frames = b""
params = None
for chunk in chunks:
audio_bytes = self.generate_chunk(chunk, voice_style, speed)
if not audio_bytes:
continue
# Парсване на WAV данните, за да можем да ги слеем без да дублираме хедъри
with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
if not params:
params = wf.getparams()
all_frames += wf.readframes(wf.getnframes())
if not params:
return b""
# Създаване на крайния WAV
out_io = io.BytesIO()
with wave.open(out_io, "wb") as wf:
wf.setparams(params)
wf.writeframes(all_frames)
return out_io.getvalue()
# Глобална инстанция за по-лесно преизползване
engine = TTSEngine()