import torch
import torchaudio

from model.network import MiniTTS
from model.dataset import TextProcessor  # Reuse the text logic written for training.

# Audio/vocoder constants — must match the values used at training time.
SAMPLE_RATE = 22050
N_MELS = 80
N_FFT = 1024
N_STFT = N_FFT // 2 + 1  # 513 linear-frequency bins, tied to N_FFT


class TTSInference:
    """Load a trained MiniTTS checkpoint and synthesize speech from text.

    The vocoder transforms (inverse mel scaling and Griffin-Lim) are built
    once here rather than on every ``predict`` call.
    """

    def __init__(self, checkpoint_path, device='cpu'):
        """Restore the model weights and prepare the vocoder on `device`."""
        self.device = device
        self.model = self.load_model(checkpoint_path)
        # Hoisted out of predict(): constructing these transforms per call
        # is wasted work — they are stateless w.r.t. the input text.
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=N_STFT, n_mels=N_MELS, sample_rate=SAMPLE_RATE
        ).to(self.device)
        self.griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=N_FFT, n_iter=32
        ).to(self.device)
        print(f"Model loaded from {checkpoint_path}")

    def load_model(self, path):
        """Rebuild the training architecture and load checkpoint weights.

        Returns the model in eval mode, moved to ``self.device``.
        """
        # 1. Initialize the same architecture as training.
        model = MiniTTS(num_chars=40, num_mels=N_MELS)
        # 2. Load the weights. map_location ensures the tensors land on
        # self.device even if the checkpoint was saved from a GPU run.
        # NOTE(review): torch.load unpickles arbitrary objects — for
        # untrusted checkpoints prefer weights_only=True (torch >= 1.13).
        state_dict = torch.load(path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model.eval().to(self.device)

    def predict(self, text, max_frames=150):
        """Synthesize audio for `text`.

        Args:
            text: Input string, converted via ``TextProcessor.text_to_sequence``.
            max_frames: Number of mel frames to generate autoregressively
                (150 ≈ 1.5 s of audio; raise it for longer sentences).

        Returns:
            Tuple of (audio as a 1-D numpy array, sample rate,
            mel spectrogram as a numpy array of shape [N_MELS, max_frames + 1]).
        """
        # 1. Text preprocessing: [1, text_len] tensor on the target device.
        text_tensor = (
            TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)
        )

        # 2. Autoregressive inference: start from ONE silent frame; the model
        # predicts the next frame, which is appended and fed back in.
        with torch.no_grad():
            # [Batch=1, Time=1, Mels] of zeros as the initial decoder input.
            decoder_input = torch.zeros(1, 1, N_MELS).to(self.device)

            for _ in range(max_frames):
                # Predict based on everything generated so far.
                prediction = self.model(text_tensor, decoder_input)
                # Keep only the newest predicted frame (the last one).
                new_frame = prediction[:, -1:, :]
                # Append it to the growing sequence of frames.
                decoder_input = torch.cat([decoder_input, new_frame], dim=1)

            # Generated spectrogram: [1, T+1, mels] -> [1, mels, T+1],
            # the layout the vocoder transforms expect.
            mel_spec = decoder_input.transpose(1, 2)

            # 3. Vocoder (spectrogram -> audio): inverse mel scale to a
            # linear spectrogram, then Griffin-Lim phase reconstruction.
            linear_spec = self.inverse_mel_scaler(mel_spec)
            audio = self.griffin_lim(linear_spec)

        return (
            audio.squeeze(0).cpu().numpy(),
            SAMPLE_RATE,
            mel_spec.squeeze(0).cpu().numpy(),
        )