Text-to-Speech
Safetensors
inf5
custom_code
File size: 11,304 Bytes
a3cbac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import sys
import os

# Put this file's directory on sys.path so that modules stored next to it can be
# imported (presumably the ``f5_tts`` package imported below — confirm).
# The membership check keeps this idempotent when the module is executed more
# than once (e.g. via importlib.reload), instead of appending duplicates.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
import torch
import numpy as np
from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
)
from f5_tts.model import DiT
import soundfile as sf
import io
from pydub import AudioSegment, silence
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os

class INF5Config(PretrainedConfig):
    """Configuration object for ``INF5Model`` (registered as model type ``"inf5"``).

    Args:
        ckpt_path: Checkpoint path stored on the config.
        vocab_path: Vocabulary file path stored on the config.
        speed: Speech speed passed to ``infer_process`` by ``INF5Model.forward``.
        remove_sil: If true, ``INF5Model.forward`` strips long silences from the
            generated audio.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    NOTE(review): ``INF5Model`` does not read ``ckpt_path`` / ``vocab_path``; it
    downloads ``model.safetensors`` and ``checkpoints/vocab.txt`` from
    ``config.name_or_path`` instead.
    """

    model_type = "inf5"

    def __init__(
        self,
        ckpt_path: str = "checkpoints/model_best.pt",
        vocab_path: str = "checkpoints/vocab.txt",
        speed: float = 1.0,
        remove_sil: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # File locations recorded on the config.
        self.ckpt_path, self.vocab_path = ckpt_path, vocab_path
        # Inference-time options consumed by INF5Model.forward().
        self.speed, self.remove_sil = speed, remove_sil

class INF5Model(PreTrainedModel):
    """Text-to-speech model wrapping an F5-TTS ``DiT`` (EMA weights) and a Vocos vocoder.

    ``__init__`` builds both sub-modules on CPU, loads weights from the Hub repo's
    ``model.safetensors`` and wraps each one with ``torch.compile(backend="eager")``.
    ``forward`` synthesizes speech for a text, conditioned on a reference audio
    clip and its transcript.
    """

    config_class = INF5Config
    _tied_weights_keys = []  # Fix for transformers 5.0.0 compatibility

    @property
    def all_tied_weights_keys(self):
        """Compatibility property for transformers 5.0.0"""
        return {}

    def __init__(self, config):
        """Build the vocoder and EMA model and load their weights from the Hub.

        Args:
            config (INF5Config): Model configuration. ``config.name_or_path`` is
                used as the Hub repo id from which ``model.safetensors`` and
                ``checkpoints/vocab.txt`` are downloaded.
        """
        super().__init__(config)
        # Determine target device for inference (GPU if available)
        self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Process-wide torch settings: make dynamo fall back to eager instead of
        # raising on compile errors, and make cuDNN deterministic.
        torch._dynamo.config.suppress_errors = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Load vocoder - force on actual device to avoid meta tensor issues in transformers 5.0+
        with torch.device('cpu'):
            # Use eager backend to keep _orig_mod structure without actual compilation
            self.vocoder = torch.compile(
                load_vocoder(vocoder_name="vocos", is_local=False, device='cpu'),
                backend="eager",
            )

        # Download and load model weights (load on CPU first for safe init,
        # model will be moved to target device in forward())
        safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
        print(f"Loading model weights from {safetensors_path} (safetensors)...")
        state_dict = load_file(safetensors_path, device='cpu')

        # Download vocab.txt from HF Hub
        vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")

        # Force model loading on CPU to avoid meta tensor issues
        with torch.device('cpu'):
            self.ema_model = load_model(
                DiT,
                dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
                mel_spec_type="vocos",
                vocab_file=vocab_path,
                device='cpu',
            )

        # Load the EMA weights BEFORE compiling, while the model's parameter
        # names carry no "_orig_mod." prefix.
        ema_state_dict, vocoder_state_dict = self._split_inf5_state_dict(state_dict)
        missing_keys, unexpected_keys = self.ema_model.load_state_dict(
            ema_state_dict, strict=False)

        if vocoder_state_dict:
            # self.vocoder is already a torch.compile wrapper whose own parameter
            # names are prefixed with "_orig_mod.". Loading the prefix-stripped
            # keys into the wrapper with strict=False would match nothing and
            # silently load no weights, so load into the wrapped module instead.
            vocoder_module = getattr(self.vocoder, '_orig_mod', self.vocoder)
            vocoder_module.load_state_dict(vocoder_state_dict, strict=False)

        # Use eager backend - disables actual compilation while keeping _orig_mod
        # structure for weight serialization. Full torch.compile with inductor
        # breaks the ODE solver in CFM.sample() causing jumbled/partial text output.
        self.ema_model = torch.compile(self.ema_model, backend="eager")
        print(f"Weight loading - Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        if missing_keys:
            print(f"Missing keys sample: {missing_keys[:5]}")
        if unexpected_keys:
            print(f"Unexpected keys sample: {unexpected_keys[:5]}")

        # Flag for lazy buffer recomputation (see _recompute_buffers).
        # We cannot recompute here because transformers 5.0 materializes
        # meta tensors AFTER __init__ returns, overwriting our values.
        self._buffers_need_recompute = True

    @staticmethod
    def _split_inf5_state_dict(state_dict):
        """Split a combined checkpoint into ``(ema_state_dict, vocoder_state_dict)``.

        Keys are routed by a leading ``ema_model.`` or ``vocoder.`` prefix; a
        directly following ``_orig_mod.`` segment (left by ``torch.compile``
        wrappers) is stripped as well. Keys with any other prefix are dropped.
        """
        ema_state_dict = {}
        vocoder_state_dict = {}
        # The "_orig_mod." variants come first so they win over the shorter prefix.
        routes = (
            ("ema_model._orig_mod.", ema_state_dict),
            ("ema_model.", ema_state_dict),
            ("vocoder._orig_mod.", vocoder_state_dict),
            ("vocoder.", vocoder_state_dict),
        )
        for key, value in state_dict.items():
            for prefix, target in routes:
                if key.startswith(prefix):
                    # Slice off only the leading prefix; str.replace would also
                    # rewrite any later occurrence of it inside the key.
                    target[key[len(prefix):]] = value
                    break
        return ema_state_dict, vocoder_state_dict

    @staticmethod
    def _reset_buffer_if_corrupt(module, name, expected):
        """Make buffer ``module.<name>`` hold ``expected`` if it does not already.

        The buffer is rewritten when it is still on the meta device or when it is
        not element-wise close to ``expected``. ``torch.allclose`` is False when
        NaNs are present, so NaN garbage is detected as well.

        Returns:
            bool: True if the buffer was rewritten.
        """
        current = getattr(module, name)
        if current.is_meta:
            # copy_() into a meta tensor has no effect; assign a real tensor
            # instead (nn.Module.__setattr__ replaces the registered buffer).
            setattr(module, name, expected.to(dtype=current.dtype))
            return True
        expected = expected.to(device=current.device, dtype=current.dtype)
        if torch.allclose(current.float(), expected.float(), rtol=1e-3, atol=1e-3):
            return False
        current.data.copy_(expected)
        return True

    def _recompute_buffers(self):
        """Recompute non-persistent buffers that were corrupted by
        transformers 5.0's meta device initialization.

        transformers 5.0 wraps __init__ in torch.device('meta') context,
        then materializes meta tensors with uninitialized (garbage) values.
        Non-persistent buffers (not in safetensors) never get correct values.
        This method must be called AFTER from_pretrained completes.

        Each buffer is compared in full against a freshly computed reference and
        rewritten only if it differs, is NaN, or is still a meta tensor."""
        from f5_tts.model.modules import precompute_freqs_cis

        # Get the underlying model (unwrap torch.compile if needed)
        ema = self.ema_model._orig_mod if hasattr(self.ema_model, '_orig_mod') else self.ema_model
        transformer = getattr(ema, 'transformer', None)

        # Recompute text_embed.freqs_cis (positional embeddings for text)
        text_embed = getattr(transformer, 'text_embed', None)
        if (
            text_embed is not None
            and getattr(text_embed, 'extra_modeling', False)
            and hasattr(text_embed, 'freqs_cis')
        ):
            text_dim = text_embed.text_embed.embedding_dim
            freqs_cis = precompute_freqs_cis(text_dim, text_embed.precompute_max_pos)
            if self._reset_buffer_if_corrupt(text_embed, 'freqs_cis', freqs_cis):
                print(f"Recomputed freqs_cis: shape={freqs_cis.shape}, first_val={freqs_cis[0,0].item():.4f}")

        # Recompute mel_spec.dummy buffer (expected to be all zeros)
        mel_spec = getattr(ema, 'mel_spec', None)
        if mel_spec is not None and hasattr(mel_spec, 'dummy'):
            zeros = torch.zeros(mel_spec.dummy.shape, dtype=mel_spec.dummy.dtype)
            if self._reset_buffer_if_corrupt(mel_spec, 'dummy', zeros):
                print("Recomputed mel_spec.dummy to 0")

        # Recompute rotary_embed.inv_freq if needed (assumes base theta=10000,
        # as the original recomputation did)
        rot = getattr(transformer, 'rotary_embed', None)
        if rot is not None and hasattr(rot, 'inv_freq'):
            dim = rot.inv_freq.shape[0] * 2
            theta = 10000.0
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            if self._reset_buffer_if_corrupt(rot, 'inv_freq', inv_freq):
                print(f"Recomputed rotary inv_freq: shape={inv_freq.shape}")

        self._buffers_need_recompute = False

    @property
    def device(self):
        """Get the target device of the model (GPU if available, else CPU)"""
        return getattr(self, '_target_device', torch.device('cpu'))

    def forward(self, text: str, ref_audio_path: str, ref_text: str):
        """
        Generate speech given a reference audio & text input.

        Args:
            text (str): The text to be synthesized.
            ref_audio_path (str): Path to the reference audio file.
            ref_text (str): The reference text.
        Returns:
            np.ndarray: Samples of the generated waveform after optional silence
            removal and loudness normalization to -20 dBFS. The array is built
            from pydub's ``get_array_of_samples()`` and carries no sample rate.
        Raises:
            FileNotFoundError: If ``ref_audio_path`` does not exist.
        """

        # Lazy recomputation of non-persistent buffers corrupted by transformers 5.0
        if getattr(self, "_buffers_need_recompute", False):
            self._recompute_buffers()

        if not os.path.exists(ref_audio_path):
            raise FileNotFoundError(f"Reference audio file {ref_audio_path} not found.")

        # Load reference audio & text
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)

        # Move models to target device (GPU if available) - only actually
        # transfers on first call; subsequent calls are no-ops
        self.ema_model.to(self.device)
        self.vocoder.to(self.device)

        # Perform inference
        audio, final_sample_rate, _ = infer_process(
            ref_audio,
            ref_text,
            text,
            self.ema_model,
            self.vocoder,
            mel_spec_type="vocos",
            speed=self.config.speed,
            device=self.device,
        )

        # Convert to pydub format, using the sample rate reported by
        # infer_process rather than assuming one.
        buffer = io.BytesIO()
        sf.write(buffer, audio, samplerate=final_sample_rate, format="WAV")
        buffer.seek(0)
        audio_segment = AudioSegment.from_file(buffer, format="wav")

        if self.config.remove_sil:
            non_silent_segs = silence.split_on_silence(
                audio_segment,
                min_silence_len=1000,
                silence_thresh=-50,
                keep_silence=500,
                seek_step=10,
            )
            non_silent_wave = sum(non_silent_segs, AudioSegment.silent(duration=0))
            audio_segment = non_silent_wave

        # Normalize loudness. pydub reports dBFS as -inf for empty or all-zero
        # audio (e.g. when silence removal drops everything); an infinite gain
        # is meaningless there, so normalization is skipped in that case.
        target_dBFS = -20.0
        current_dBFS = audio_segment.dBFS
        if np.isfinite(current_dBFS):
            audio_segment = audio_segment.apply_gain(target_dBFS - current_dBFS)

        return np.array(audio_segment.get_array_of_samples())