# arabic-tts-server/utils/text_processing.py
# Author: shada-elewa — "Create utils/text_processing.py" (commit dbee7f1, verified)
import os

import scipy.io.wavfile as wavfile
import torch
from transformers import VitsModel, AutoTokenizer
class TTSManager:
    """Arabic text-to-speech built on the Facebook MMS VITS checkpoint.

    The model and tokenizer are loaded once at construction time; each call
    to :meth:`tts` synthesizes one utterance and writes it as a WAV file
    under ``output_dir``.
    """

    def __init__(self, output_dir, use_cuda_if_available=True):
        """
        Args:
            output_dir: Directory where generated WAV files are written.
                Created on construction if it does not already exist.
            use_cuda_if_available: When True, place the model on GPU if
                CUDA is available; otherwise fall back to CPU.
        """
        self.output_dir = output_dir
        # BUGFIX: the original never created the directory, so the first
        # wavfile.write() would raise FileNotFoundError on a fresh deploy.
        os.makedirs(output_dir, exist_ok=True)
        self.device = "cuda" if use_cuda_if_available and torch.cuda.is_available() else "cpu"
        # Load a professional VITS model for Arabic (MMS project).
        self.model_name = "facebook/mms-tts-ara"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = VitsModel.from_pretrained(self.model_name).to(self.device)

    def tts(self, text, rate=1.0, denoise=0.01, filename="output.wav"):
        """Synthesize ``text`` and save it as a WAV file.

        Args:
            text: Arabic text to synthesize.
            rate: Speaking-rate multiplier (1.0 = model default speed).
            denoise: VITS noise scale; lower values yield cleaner but
                flatter-sounding speech.
            filename: Name of the WAV file inside ``output_dir``. New
                optional parameter — the default reproduces the original
                single-file behavior, while distinct names avoid clobbering
                between concurrent/successive requests.

        Returns:
            dict: ``{"audio_url": ...}`` pointing at the generated file.
        """
        # BUGFIX: `rate` and `denoise` were accepted but silently ignored.
        # Wire them into the VITS generation knobs (per the HF MMS-TTS docs,
        # these are set as attributes directly on the model).
        self.model.speaking_rate = rate
        self.model.noise_scale = denoise

        # 1. Tokenize the text and move the tensors to the model's device.
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        # 2. Generate audio; inference only, so no gradients are needed.
        with torch.no_grad():
            waveform = self.model(**inputs).waveform

        # 3. Save to a file. Use os.path.join instead of string concat so
        # output_dir may or may not carry a trailing separator.
        output_path = os.path.join(self.output_dir, filename)
        audio_data = waveform.cpu().numpy().squeeze()
        wavfile.write(output_path, self.model.config.sampling_rate, audio_data)

        # BUGFIX: the URL now tracks the actual file name instead of being
        # hard-coded to output.wav. NOTE(review): assumes output_dir is
        # served at /static/ — confirm against the server's static mount.
        return {"audio_url": f"/static/{filename}"}