# NOTE(review): removed web-page scrape artifacts that were fused into this
# file ("Spaces: Sleeping" status banner and "File size: 2,598 Bytes") —
# they are viewer chrome, not Python source.
import torch
import torchaudio
from model.network import MiniTTS
from model.dataset import TextProcessor # We reuse the text logic we already wrote!
class TTSInference:
    """Inference wrapper for a trained MiniTTS model.

    Loads a checkpoint, generates a mel spectrogram autoregressively from
    input text, and converts it to a waveform with an InverseMelScale +
    Griffin-Lim vocoder.
    """

    def __init__(self, checkpoint_path, device='cpu'):
        """Load the model and build the vocoder transforms once.

        Args:
            checkpoint_path: Path to a state_dict checkpoint saved in training.
            device: torch device string; default 'cpu'.
        """
        self.device = device
        self.model = self.load_model(checkpoint_path)
        # Assumed training sample rate — TODO confirm against the dataset.
        self.sample_rate = 22050
        # Build the vocoder transforms once instead of on every predict()
        # call. n_stft = n_fft // 2 + 1 = 513 for n_fft=1024.
        self._inverse_mel = torchaudio.transforms.InverseMelScale(
            n_stft=513, n_mels=80, sample_rate=self.sample_rate
        ).to(self.device)
        self._griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=1024, n_iter=32
        ).to(self.device)
        print(f"Model loaded from {checkpoint_path}")

    def load_model(self, path, num_chars=40, num_mels=80):
        """Instantiate the training architecture and load its weights.

        Args:
            path: Checkpoint file containing a model state_dict.
            num_chars: Character-vocabulary size used at training time
                (default 40, matching the original hard-coded value).
            num_mels: Number of mel bands used at training time (default 80).

        Returns:
            The model in eval mode on ``self.device``.
        """
        # 1. Initialize the same architecture as training.
        model = MiniTTS(num_chars=num_chars, num_mels=num_mels)
        # 2. Load the weights.
        # map_location ensures it loads on CPU even if trained on GPU.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (or pass weights_only=True on torch >= 1.13).
        state_dict = torch.load(path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model.eval().to(self.device)

    def predict(self, text, max_frames=150):
        """Synthesize audio from text.

        Args:
            text: Input string to synthesize. Must be non-empty.
            max_frames: Number of mel frames to generate; 150 is roughly
                1.5 seconds of audio. Increase for longer sentences.

        Returns:
            Tuple of (audio, sample_rate, mel):
            audio is a 1-D numpy waveform, sample_rate is 22050, and mel is
            an [80, max_frames] numpy mel spectrogram.

        Raises:
            ValueError: If ``text`` is empty.
        """
        if not text:
            raise ValueError("text must be a non-empty string")

        # 1. Text preprocessing -> [1, T_text] tensor on the target device.
        text_tensor = TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)

        # 2. Autoregressive inference: seed the decoder with ONE silent
        # frame, then feed each predicted frame back as input.
        with torch.no_grad():
            # [Batch=1, Time=1, Mels=80] of zeros as the seed frame.
            decoder_input = torch.zeros(1, 1, 80, device=self.device)
            for _ in range(max_frames):
                # Predict based on everything generated so far.
                prediction = self.model(text_tensor, decoder_input)
                # Keep ONLY the newest frame (the last one).
                new_frame = prediction[:, -1:, :]
                decoder_input = torch.cat([decoder_input, new_frame], dim=1)

        # Drop the all-zero seed frame — it is not model output — then move
        # mels to dim 1: [1, max_frames, 80] -> [1, 80, max_frames].
        # (The original kept the seed, returning a spurious silent frame.)
        mel_spec = decoder_input[:, 1:, :].transpose(1, 2)

        # 3. Vocoder: mel -> linear spectrogram -> waveform.
        linear_spec = self._inverse_mel(mel_spec)
        audio = self._griffin_lim(linear_spec)
        return audio.squeeze(0).cpu().numpy(), self.sample_rate, mel_spec.squeeze(0).cpu().numpy()