File size: 2,598 Bytes
be29b5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torchaudio
from model.network import MiniTTS
from model.dataset import TextProcessor  # We reuse the text logic we already wrote!

class TTSInference:
    """Load a trained MiniTTS checkpoint and synthesize speech from text.

    The vocoder transforms (inverse mel scale + Griffin-Lim) are built once
    in ``__init__`` instead of on every ``predict()`` call, since they are
    stateless and constructing them repeatedly is pure overhead.
    """

    def __init__(self, checkpoint_path, device='cpu'):
        self.device = device
        self.model = self.load_model(checkpoint_path)

        # Audio settings must match the training pipeline.
        # NOTE(review): 22050 Hz / n_fft=1024 / 80 mels are assumed to match
        # the dataset preprocessing — confirm against model/dataset.py.
        self.sample_rate = 22050

        # n_stft = n_fft // 2 + 1 = 513 must agree with GriffinLim's n_fft.
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=513, n_mels=80, sample_rate=self.sample_rate
        ).to(self.device)
        self.griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=1024, n_iter=32
        ).to(self.device)

        print(f"Model loaded from {checkpoint_path}")

    def load_model(self, path):
        """Build the training architecture, load weights, return eval-mode model."""
        # 1. Initialize the same architecture as training.
        model = MiniTTS(num_chars=40, num_mels=80)

        # 2. Load the weights.
        # map_location ensures it loads on CPU even if trained on GPU.
        # NOTE(review): consider weights_only=True on torch >= 1.13 to avoid
        # unpickling arbitrary objects from untrusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        model.load_state_dict(state_dict)

        return model.eval().to(self.device)

    def predict(self, text, num_frames=150):
        """Synthesize speech for ``text``.

        Args:
            text: Input string to synthesize.
            num_frames: Number of mel frames to generate autoregressively
                (150 frames is roughly 1.5 s; raise for longer sentences).

        Returns:
            Tuple of (audio, sample_rate, mel_spectrogram) where ``audio``
            is a 1-D numpy waveform, ``sample_rate`` is 22050, and
            ``mel_spectrogram`` is a numpy array of shape [80, num_frames].
        """
        # 1. Text preprocessing -> [1, T_text] tensor on the target device.
        text_tensor = TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)

        # 2. Autoregressive inference: seed the decoder with one silent
        # frame, then repeatedly append the model's newest predicted frame.
        with torch.no_grad():
            # Start with [Batch=1, Time=1, Mels=80] of zeros.
            decoder_input = torch.zeros(1, 1, 80, device=self.device)

            for _ in range(num_frames):
                prediction = self.model(text_tensor, decoder_input)
                # Keep only the newest frame the model predicted.
                new_frame = prediction[:, -1:, :]
                decoder_input = torch.cat([decoder_input, new_frame], dim=1)

            # Drop the all-zero seed frame — the model never produced it,
            # and keeping it would leak a silent frame into the audio.
            # Shape: [1, num_frames, 80] -> [1, 80, num_frames]
            mel_spec = decoder_input[:, 1:, :].transpose(1, 2)

        # 3. Vocoder: mel -> linear spectrogram -> waveform (Griffin-Lim).
        linear_spec = self.inverse_mel_scaler(mel_spec)
        audio = self.griffin_lim(linear_spec)

        return audio.squeeze(0).cpu().numpy(), self.sample_rate, mel_spec.squeeze(0).cpu().numpy()