import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
import torch
import numpy as np
from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
)
from f5_tts.model import DiT
import soundfile as sf
import io
from pydub import AudioSegment, silence
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os


class INF5Config(PretrainedConfig):
    """Configuration for the INF5 (F5-TTS based) text-to-speech wrapper.

    Args:
        ckpt_path: Checkpoint path stored on the config. ``INF5Model`` in this
            file does not read it; it downloads ``model.safetensors`` from the
            Hub repo named by ``name_or_path`` instead.
        vocab_path: Vocabulary path stored on the config. Likewise not read by
            ``INF5Model`` here, which downloads ``checkpoints/vocab.txt``.
        speed: Speech-rate factor forwarded to ``infer_process``.
        remove_sil: When true, ``INF5Model.forward`` removes long silences
            from the generated audio.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "inf5"

    def __init__(
        self,
        ckpt_path: str = "checkpoints/model_best.pt",
        vocab_path: str = "checkpoints/vocab.txt",
        speed: float = 1.0,
        remove_sil: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ckpt_path = ckpt_path
        self.vocab_path = vocab_path
        self.speed = speed
        self.remove_sil = remove_sil


class INF5Model(PreTrainedModel):
    """Hugging Face wrapper around an F5-TTS ``DiT`` model plus a Vocos vocoder.

    Weights are fetched from the Hub repo given by ``config.name_or_path``.
    ``forward`` synthesizes speech for ``text`` in the voice of a reference
    audio clip and returns the waveform as a NumPy array.
    """

    config_class = INF5Config
    _tied_weights_keys = []

    # Fix for transformers 5.0.0 compatibility
    @property
    def all_tied_weights_keys(self):
        """Compatibility property for transformers 5.0.0 (no tied weights)."""
        return {}

    def __init__(self, config):
        """Build the vocoder and DiT model and load their weights on CPU.

        Args:
            config: An ``INF5Config``; ``config.name_or_path`` must name a Hub
                repo containing ``model.safetensors`` and
                ``checkpoints/vocab.txt``.

        Side effects: sets global ``torch._dynamo`` / cuDNN flags and
        downloads files via ``hf_hub_download``.
        """
        super().__init__(config)

        # Target device for inference (GPU if available). The modules stay on
        # CPU here and are moved in forward().
        self._target_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Disable torch.compile graph tracing to prevent ODE solver issues
        torch._dynamo.config.suppress_errors = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Load vocoder - force on actual device to avoid meta tensor issues
        # in transformers 5.0+. The eager backend keeps the _orig_mod
        # structure without actually compiling anything.
        with torch.device("cpu"):
            self.vocoder = torch.compile(
                load_vocoder(vocoder_name="vocos", is_local=False, device="cpu"),
                backend="eager",
            )

        # Download and load model weights (on CPU first for safe init).
        safetensors_path = hf_hub_download(
            config.name_or_path, filename="model.safetensors"
        )
        print(f"Loading model weights from {safetensors_path} (safetensors)...")
        state_dict = load_file(safetensors_path, device="cpu")

        # Download vocab.txt from HF Hub
        vocab_path = hf_hub_download(
            config.name_or_path, filename="checkpoints/vocab.txt"
        )

        # Force model loading on CPU to avoid meta tensor issues
        with torch.device("cpu"):
            self.ema_model = load_model(
                DiT,
                dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
                mel_spec_type="vocos",
                vocab_file=vocab_path,
                device="cpu",
            )

        # Load weights BEFORE compiling ema_model, so the plain module's keys
        # (without the "_orig_mod." prefix) line up with the stripped keys.
        ema_state_dict, vocoder_state_dict = self._split_state_dict(state_dict)

        missing_keys, unexpected_keys = self.ema_model.load_state_dict(
            ema_state_dict, strict=False
        )

        if vocoder_state_dict:
            # The vocoder is already wrapped by torch.compile. The wrapper's
            # own state-dict keys start with "_orig_mod.", so loading the
            # stripped keys into it with strict=False would match nothing and
            # silently keep the stock weights. Load into the wrapped module.
            # A RuntimeError (e.g. a shape mismatch) is deliberately not
            # swallowed: it means the checkpoint does not fit this vocoder.
            vocoder_target = getattr(self.vocoder, "_orig_mod", self.vocoder)
            voc_missing, voc_unexpected = vocoder_target.load_state_dict(
                vocoder_state_dict, strict=False
            )
            print(
                f"Vocoder weight loading - Missing keys: {len(voc_missing)}, "
                f"Unexpected keys: {len(voc_unexpected)}"
            )

        # Use eager backend - disables actual compilation while keeping _orig_mod
        # structure for weight serialization. Full torch.compile with inductor
        # breaks the ODE solver in CFM.sample() causing jumbled/partial text output.
        self.ema_model = torch.compile(self.ema_model, backend="eager")

        print(
            f"Weight loading - Missing keys: {len(missing_keys)}, "
            f"Unexpected keys: {len(unexpected_keys)}"
        )
        if missing_keys:
            print(f"Missing keys sample: {missing_keys[:5]}")
        if unexpected_keys:
            print(f"Unexpected keys sample: {unexpected_keys[:5]}")

        # Flag for lazy buffer recomputation (see _recompute_buffers).
        # We cannot recompute here because transformers 5.0 materializes
        # meta tensors AFTER __init__ returns, overwriting our values.
        self._buffers_need_recompute = True

    @staticmethod
    def _split_state_dict(state_dict):
        """Split a combined checkpoint into ``(ema_state_dict, vocoder_state_dict)``.

        Keys are routed by their leading prefix, and that prefix (including
        the ``_orig_mod.`` segment present on keys saved from a compiled
        module) is stripped. Keys matching none of the prefixes are dropped.
        """
        # Order matters: the longer "_orig_mod." forms must be tested first.
        routes = (
            ("ema_model._orig_mod.", "ema"),
            ("ema_model.", "ema"),
            ("vocoder._orig_mod.", "vocoder"),
            ("vocoder.", "vocoder"),
        )
        buckets = {"ema": {}, "vocoder": {}}
        for key, value in state_dict.items():
            for prefix, bucket in routes:
                if key.startswith(prefix):
                    buckets[bucket][key[len(prefix):]] = value
                    break
        return buckets["ema"], buckets["vocoder"]

    def _recompute_buffers(self):
        """Restore non-persistent buffers that may hold garbage after loading.

        Per the notes in ``__init__``, transformers 5.0 builds this model under
        ``torch.device('meta')`` and materializes tensors afterwards with
        uninitialized values. Buffers that are not stored in
        ``model.safetensors`` therefore never receive correct values. This
        method must run after ``from_pretrained`` has returned; ``forward``
        calls it lazily on first use.

        Each buffer is overwritten only when it fails a sanity check. The
        checks are written so that NaN garbage also fails them: a test such as
        ``abs(x - 1.0) > 0.01`` is False for NaN and would let it through.
        """
        from f5_tts.model.modules import precompute_freqs_cis

        # Get the underlying model (unwrap torch.compile if needed).
        ema = getattr(self.ema_model, "_orig_mod", self.ema_model)
        transformer = getattr(ema, "transformer", None)
        text_embed = getattr(transformer, "text_embed", None)

        # Recomputed values are placed on the device the buffers live on.
        if text_embed is not None and hasattr(text_embed, "freqs_cis"):
            buf_device = text_embed.freqs_cis.device
        else:
            buf_device = torch.device("cpu")

        # Text positional embeddings. The correct buffer is taken to be
        # precompute_freqs_cis(text_dim, max_pos), so compare the whole tensor
        # rather than only its first element.
        if text_embed is not None and getattr(text_embed, "extra_modeling", False):
            text_dim = text_embed.text_embed.embedding_dim
            max_pos = text_embed.precompute_max_pos
            current = text_embed.freqs_cis
            expected = precompute_freqs_cis(text_dim, max_pos)
            expected_on_device = expected.to(device=buf_device, dtype=current.dtype)
            if current.is_meta or not torch.allclose(
                current, expected_on_device, atol=1e-4
            ):
                current.data.copy_(expected_on_device)
                # Read first_val from the CPU copy so logging never calls
                # .item() on a meta tensor.
                print(
                    f"Recomputed freqs_cis: shape={expected.shape}, "
                    f"first_val={expected[0, 0].item():.4f}"
                )

        # mel_spec.dummy is expected to be 0 (NaN != 0, so NaN is caught too).
        mel_spec = getattr(ema, "mel_spec", None)
        dummy = getattr(mel_spec, "dummy", None)
        if dummy is not None and (dummy.is_meta or dummy.item() != 0):
            dummy.data.fill_(0)
            print("Recomputed mel_spec.dummy to 0")

        # Rotary inverse frequencies: 1 / theta ** (2i / dim) lies in (0, 1]
        # for any theta > 1. "All values in range" is False for NaN, so
        # garbage of that kind triggers a rebuild.
        rot = getattr(transformer, "rotary_embed", None)
        inv_freq = getattr(rot, "inv_freq", None)
        if inv_freq is not None:
            dim = inv_freq.shape[0] * 2
            looks_valid = (not inv_freq.is_meta) and bool(
                ((inv_freq > 0) & (inv_freq <= 1)).all()
            )
            if not looks_valid:
                theta = 10000.0
                new_inv_freq = 1.0 / (
                    theta ** (torch.arange(0, dim, 2).float().to(buf_device) / dim)
                )
                inv_freq.data.copy_(new_inv_freq)
                print(f"Recomputed rotary inv_freq: shape={new_inv_freq.shape}")

        self._buffers_need_recompute = False

    @property
    def device(self):
        """Get the target device of the model (GPU if available, else CPU)."""
        return getattr(self, "_target_device", torch.device("cpu"))

    def forward(self, text: str, ref_audio_path: str, ref_text: str):
        """Generate speech for ``text`` in the voice of a reference clip.

        Args:
            text: The text to be synthesized.
            ref_audio_path: Path to the reference audio file.
            ref_text: Transcript of the reference audio.

        Returns:
            np.ndarray: Integer PCM samples from pydub's
            ``get_array_of_samples`` (16-bit with soundfile's default WAV
            subtype — confirm). The sample rate is not returned; it is the
            rate reported by ``infer_process``.

        Raises:
            FileNotFoundError: If ``ref_audio_path`` does not exist.
        """
        # Lazy recomputation of non-persistent buffers corrupted by transformers 5.0
        if getattr(self, "_buffers_need_recompute", False):
            self._recompute_buffers()

        if not os.path.exists(ref_audio_path):
            raise FileNotFoundError(f"Reference audio file {ref_audio_path} not found.")

        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)

        # Move models to the target device. Only the first call actually
        # transfers anything; later calls are no-ops.
        self.ema_model.to(self.device)
        self.vocoder.to(self.device)

        audio, final_sample_rate, _ = infer_process(
            ref_audio,
            ref_text,
            text,
            self.ema_model,
            self.vocoder,
            mel_spec_type="vocos",
            speed=self.config.speed,
            device=self.device,
        )

        # Round-trip through an in-memory WAV so pydub can post-process it.
        # Use the rate infer_process reported rather than assuming 24 kHz, so
        # pydub's millisecond-based silence detection stays correct.
        buffer = io.BytesIO()
        sf.write(buffer, audio, samplerate=final_sample_rate, format="WAV")
        buffer.seek(0)
        audio_segment = AudioSegment.from_file(buffer, format="wav")

        if self.config.remove_sil:
            non_silent_segs = silence.split_on_silence(
                audio_segment,
                min_silence_len=1000,
                silence_thresh=-50,
                keep_silence=500,
                seek_step=10,
            )
            audio_segment = sum(
                non_silent_segs,
                AudioSegment.silent(duration=0, frame_rate=audio_segment.frame_rate),
            )

        # Normalize loudness. pydub reports dBFS as -inf for silent or empty
        # audio; applying an infinite gain would corrupt the samples, so leave
        # such audio untouched.
        target_dBFS = -20.0
        current_dBFS = audio_segment.dBFS
        if np.isfinite(current_dBFS):
            audio_segment = audio_segment.apply_gain(target_dBFS - current_dBFS)

        return np.array(audio_segment.get_array_of_samples())