Spaces:
Sleeping
Sleeping
| import torch | |
| import torchaudio | |
| from model.network import MiniTTS | |
| from model.dataset import TextProcessor # We reuse the text logic we already wrote! | |
class TTSInference:
    """Autoregressive inference wrapper for a trained MiniTTS checkpoint.

    Loads the model weights, generates a mel spectrogram frame-by-frame from
    input text, and converts it to a waveform with InverseMelScale +
    Griffin-Lim. All defaults match the original hard-coded training setup
    (40 characters, 80 mel bins, 22050 Hz, 1024-point FFT), so existing
    callers are unaffected.
    """

    def __init__(self, checkpoint_path, device='cpu',
                 num_chars=40, num_mels=80, sample_rate=22050, n_fft=1024):
        """Load the checkpoint and pre-build the vocoder transforms.

        Args:
            checkpoint_path: path to a state_dict saved during training.
            device: torch device string; 'cpu' works even for GPU-trained
                checkpoints thanks to map_location in load_model().
            num_chars: vocabulary size the model was trained with.
            num_mels: number of mel bins the model predicts.
            sample_rate: audio sample rate used at training time.
            n_fft: FFT size for the Griffin-Lim vocoder.
        """
        self.device = device
        self.num_chars = num_chars
        self.num_mels = num_mels
        self.sample_rate = sample_rate
        self.model = self.load_model(checkpoint_path)

        # Build the vocoder transforms ONCE here instead of on every
        # predict() call — they are stateless w.r.t. the input text.
        # Invariant: InverseMelScale's n_stft must equal n_fft // 2 + 1
        # (513 for the default 1024-point FFT).
        self._inverse_mel = torchaudio.transforms.InverseMelScale(
            n_stft=n_fft // 2 + 1, n_mels=num_mels, sample_rate=sample_rate
        ).to(device)
        self._griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=n_fft, n_iter=32
        ).to(device)

        print(f"Model loaded from {checkpoint_path}")

    def load_model(self, path):
        """Instantiate the training architecture, load weights, and return
        the model in eval mode on self.device."""
        # 1. Initialize the same architecture as training.
        model = MiniTTS(num_chars=self.num_chars, num_mels=self.num_mels)
        # 2. Load the weights.
        # map_location ensures it loads on CPU even if trained on GPU.
        # SECURITY: torch.load uses pickle under the hood and can execute
        # arbitrary code — only load checkpoints from trusted sources.
        # (Consider weights_only=True if the installed torch supports it.)
        state_dict = torch.load(path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model.eval().to(self.device)

    def predict(self, text, max_frames=150):
        """Synthesize audio for `text`.

        Args:
            text: input string, tokenized by TextProcessor.text_to_sequence.
            max_frames: number of mel frames to generate. The default of 150
                (~1.5 s of audio) preserves the original behavior; increase
                it for longer sentences.

        Returns:
            Tuple of (audio ndarray, sample_rate, mel ndarray) where mel has
            shape [num_mels, max_frames + 1] (the +1 is the initial silent
            frame, kept to match the original output).
        """
        # 1. Text preprocessing -> token ids shaped [1, seq_len].
        text_tensor = TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)

        # 2. Autoregressive inference: start from ONE silent frame; the
        #    model predicts the next frame, which we feed back in.
        with torch.no_grad():
            # [Batch=1, Time=1, Mels] of zeros as the seed frame.
            decoder_input = torch.zeros(1, 1, self.num_mels, device=self.device)
            for _ in range(max_frames):
                prediction = self.model(text_tensor, decoder_input)
                # Keep ONLY the newest predicted frame (the last one).
                new_frame = prediction[:, -1:, :]
                # Append it to the growing sequence of frames.
                decoder_input = torch.cat([decoder_input, new_frame], dim=1)

        # [1, T+1, mels] -> [1, mels, T+1], the layout the vocoder expects.
        mel_spec = decoder_input.transpose(1, 2)

        # 3. Vocoder: mel -> linear spectrogram -> waveform (Griffin-Lim).
        #    Transforms were pre-built in __init__, so this is allocation-free.
        linear_spec = self._inverse_mel(mel_spec)
        audio = self._griffin_lim(linear_spec)

        return audio.squeeze(0).cpu().numpy(), self.sample_rate, mel_spec.squeeze(0).cpu().numpy()