mana-tts / synthesis.py
abreza's picture
feat: improved number handling and audio processing
da2ee9a
raw
history blame
5.45 kB
import os
import sys
import re
import numpy as np
import torch
import soundfile as sf
import spaces
from config import models_path, results_path, sample_path, BASE_DIR
from sentence_splitter import PersianSentenceSplitter
from text_utils import convert_number_to_text
# Shared model handles. All of them start out as None and are filled in by
# load_models():
#   encoder           - the `encoder.inference` module (speaker encoder)
#   synthesizer       - Synthesizer instance
#   vocoder           - HiFiGAN vocoder model
#   sentence_splitter - PersianSentenceSplitter instance
encoder = synthesizer = vocoder = sentence_splitter = None
def load_models():
    """Load every model the pipeline needs and publish it in the module globals.

    Steps, in order:
      * add ``<BASE_DIR>/pmt2`` to ``sys.path`` (once) so the imports below
        can resolve,
      * load the speaker-encoder weights from ``encoder.pt``,
      * build the ``Synthesizer`` from ``synthesizer.pt``,
      * load the HiFiGAN vocoder from ``vocoder_HiFiGAN.pkl``, strip its
        weight norm, switch it to eval mode and move it to CUDA when
        available (CPU otherwise),
      * create the ``PersianSentenceSplitter`` (150 max / 30 min chars).

    The globals ``encoder``, ``synthesizer``, ``vocoder`` and
    ``sentence_splitter`` are only assigned after every step succeeded, so a
    failure never leaves them half-initialised.

    Returns:
        bool: True when everything loaded; False when any step raised (the
        traceback is printed, nothing is re-raised).
    """
    global encoder, synthesizer, vocoder, sentence_splitter
    try:
        # Register the pmt2 directory only once; repeated calls used to keep
        # appending the same entry to sys.path.
        pmt2_dir = os.path.join(BASE_DIR, 'pmt2')
        if pmt2_dir not in sys.path:
            sys.path.append(pmt2_dir)

        # Imported lazily because they depend on the sys.path entry above
        # (presumably `encoder` and `synthesizer` live in pmt2 - confirm).
        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print("Loading encoder model...")
        encoder_module.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        loaded_synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        loaded_vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        loaded_vocoder.remove_weight_norm()
        loaded_vocoder = loaded_vocoder.eval().to(device)

        loaded_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)
    except Exception:
        # Top-level boundary: report the failure and let the caller decide.
        import traceback
        print(f"Error loading models: {traceback.format_exc()}")
        return False
    else:
        encoder = encoder_module
        synthesizer = loaded_synthesizer
        vocoder = loaded_vocoder
        sentence_splitter = loaded_splitter
        print("Models loaded successfully!")
        return True
def normalize_text_for_synthesis(text: str) -> str:
    """Prepare raw input text for the synthesizer.

    The text goes through these steps, in order:
      1. Arabic kaf/yeh are replaced with their Persian forms and every
         underscore becomes a zero-width non-joiner (U+200C).
      2. Runs of whitespace collapse to a single space and the ends are
         stripped.
      3. Every number (ASCII, Persian or Arabic-Indic digits, optionally
         grouped with ``,``, ``،`` or ``٬``) is passed to
         ``convert_number_to_text``; if that call raises, the number is left
         as written.

    Args:
        text: The text to normalise.

    Returns:
        The normalised text.
    """
    char_map = str.maketrans({
        'ك': 'ک',  # Arabic kaf -> Persian kaf
        'ي': 'ی',  # Arabic yeh -> Persian yeh
        '_': '\u200c',  # underscore -> zero-width non-joiner
    })
    text = text.translate(char_map)
    text = re.sub(r'\s+', ' ', text).strip()

    number_pattern = r'[۰-۹0-9٠-٩]+(?:[,،٬][۰-۹0-9٠-٩]+)*'

    def replace_number(match):
        num_str = match.group(0)
        try:
            return convert_number_to_text(num_str)
        except Exception:
            # Best effort: keep the digits rather than lose the whole text.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still
            # propagate.
            return num_str

    return re.sub(number_pattern, replace_number, text)
def synthesize_segment(text_segment: str, embed: np.ndarray) -> "np.ndarray | None":
    """Turn one text segment into a waveform.

    The segment is normalised with ``normalize_text_for_synthesis``, converted
    to a spectrogram by the global ``synthesizer`` (conditioned on ``embed``)
    and rendered to audio by the global ``vocoder``.

    Args:
        text_segment: Text of a single segment.
        embed: Speaker embedding handed to the synthesizer.

    Returns:
        A 1-D NumPy array with the generated samples, or ``None`` when any
        step raised (the traceback is printed, nothing is re-raised).
    """
    try:
        text_segment = normalize_text_for_synthesis(text_segment)
        specs = synthesizer.synthesize_spectrograms([text_segment], [embed])

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # The vocoder is fed the transposed spectrogram
        # (presumably frames-first - confirm against the vocoder's API).
        mel = torch.from_numpy(specs[0].T).to(device)
        with torch.no_grad():
            wav = vocoder.inference(mel)

        # Always return a flat 1-D array: squeeze() could collapse a
        # single-sample result to 0-d, which np.concatenate rejects.
        return wav.cpu().numpy().reshape(-1)
    except Exception:
        import traceback
        # str() keeps the handler itself from raising when the input was not
        # a string in the first place.
        print(f"Error synthesizing segment '{str(text_segment)[:50]}...': {traceback.format_exc()}")
        return None
def add_silence(duration_ms: int = 300) -> np.ndarray:
    """Return ``duration_ms`` milliseconds of silence as float32 zeros.

    The sample count is derived from the global synthesizer's sample rate and
    truncated to an integer.
    """
    rate = synthesizer.sample_rate
    return np.zeros(int(rate * duration_ms / 1000), dtype=np.float32)
@spaces.GPU(duration=120)
def generate_speech(text, reference_audio=None, add_pauses: bool = True):
    """Synthesize ``text`` in the voice of a reference recording.

    The text is split into segments by the global ``sentence_splitter``, each
    segment is synthesized separately, the successful segments are joined
    (optionally with 300 ms of silence between them), peak-normalised to 0.97
    and written to a WAV file under ``results_path``.

    Args:
        text: Text to speak. Empty or whitespace-only text yields ``None``.
        reference_audio: Optional pair whose item 0 is the sample rate and
            item 1 the samples (presumably Gradio's audio tuple - confirm).
            It is written to ``reference_audio.wav`` under ``results_path``
            and used as the voice reference. When ``None`` the bundled
            ``sample_path`` recording is used instead.
        add_pauses: Insert silence between consecutive synthesized segments.

    Returns:
        Path of the generated WAV file, or ``None`` when the text is empty,
        no audio could be produced, or any step raised (the traceback is
        printed, nothing is re-raised).
    """
    if not text or not text.strip():
        return None

    try:
        if reference_audio is None:
            ref_wav_path = sample_path
        else:
            ref_wav_path = os.path.join(results_path, "reference_audio.wav")
            ref_rate, ref_samples = reference_audio[0], reference_audio[1]
            sf.write(ref_wav_path, ref_samples, ref_rate)
        print(f"Using reference audio: {ref_wav_path}")

        wav = synthesizer.load_preprocess_wav(ref_wav_path)
        encoder_wav = encoder.preprocess_wav(wav)
        # Only the utterance-level embedding is needed here.
        embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        text_segments = sentence_splitter.split(text)
        total = len(text_segments)
        print(f"Split text into {total} segments:")
        for i, segment in enumerate(text_segments, 1):
            print(f"  Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")

        # Collect the successful segments first; pauses are added afterwards
        # so a failed final segment does not leave a dangling pause.
        voiced = []
        for i, segment in enumerate(text_segments, 1):
            print(f"Processing segment {i}/{total}...")
            segment_wav = synthesize_segment(segment, embed)
            if segment_wav is None:
                print(f"Warning: Failed to synthesize segment {i}")
                continue
            voiced.append(np.asarray(segment_wav).reshape(-1))

        if not voiced:
            print("Error: No audio segments were generated successfully")
            return None

        if add_pauses and len(voiced) > 1:
            silence = add_silence(300)  # 300ms pause
            pieces = [voiced[0]]
            for segment_wav in voiced[1:]:
                pieces.extend((silence, segment_wav))
        else:
            pieces = voiced
        final_wav = np.concatenate(pieces)

        if final_wav.size == 0:
            print("Error: No audio segments were generated successfully")
            return None

        # Peak-normalise, but leave all-zero audio alone: dividing by a zero
        # peak would fill the file with NaNs.
        peak = np.abs(final_wav).max()
        if peak > 0:
            final_wav = final_wav / peak * 0.97

        # NOTE(review): hash() of a str is salted per process unless
        # PYTHONHASHSEED is fixed, so this name is not stable across runs,
        # and two texts can map to the same file and overwrite each other.
        output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
        output_path = os.path.join(results_path, output_filename)
        sf.write(output_path, final_wav, synthesizer.sample_rate)

        print(f"✓ Successfully generated speech: {output_path}")
        print(f"  Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")
        return output_path
    except Exception:
        # Top-level boundary for the request: report and signal failure.
        import traceback
        error_details = traceback.format_exc()
        print(f"Error generating speech: {error_details}")
        return None