Text-to-Speech
Safetensors
inf5
custom_code
IndicF5 / model.py
Aditya02's picture
Upload folder using huggingface_hub
c465f9f verified
import sys
import os
# Append this file's own directory to the module search path. Presumably this
# lets the `f5_tts` package imported below be found next to this file when it
# is loaded as remote code — confirm against the repository layout.
_this_file = os.path.abspath(__file__)
current_dir = os.path.dirname(_this_file)
sys.path.append(current_dir)
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
import torch
import numpy as np
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
from f5_tts.model import DiT
import soundfile as sf
import io
from pydub import AudioSegment, silence
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os
class INF5Config(PretrainedConfig):
    """Configuration object for the ``inf5`` text-to-speech model.

    Args:
        ckpt_path: Checkpoint path recorded on the config. Not read by
            ``INF5Model.__init__`` in this file, which resolves its files
            from ``config.name_or_path``.
        vocab_path: Vocabulary path recorded on the config (likewise not
            read by ``INF5Model.__init__`` in this file).
        speed: Speed factor passed to ``infer_process`` by ``INF5Model.forward``.
        remove_sil: When true, ``INF5Model.forward`` strips long silences
            from the generated audio.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "inf5"

    def __init__(
        self,
        ckpt_path: str = "checkpoints/model_best.pt",
        vocab_path: str = "checkpoints/vocab.txt",
        speed: float = 1.0,
        remove_sil: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Tuple targets are assigned left to right, keeping the original order.
        self.ckpt_path, self.vocab_path = ckpt_path, vocab_path
        self.speed, self.remove_sil = speed, remove_sil
class INF5Model(PreTrainedModel):
    """Hugging Face model wrapper around an F5-TTS ``DiT`` (``ema_model``)
    and a Vocos vocoder (``vocoder``).

    ``forward`` synthesizes ``text`` conditioned on a reference audio clip
    and its transcript, post-processes the audio with pydub, and returns the
    samples as a NumPy array.
    """

    config_class = INF5Config
    _tied_weights_keys = []  # Fix for transformers 5.0.0 compatibility

    @property
    def all_tied_weights_keys(self):
        """Compatibility property for transformers 5.0.0 (no tied weights)."""
        return {}

    def __init__(self, config):
        """Build the vocoder and DiT model on CPU and load their weights.

        ``model.safetensors`` and ``checkpoints/vocab.txt`` are taken from the
        local directory ``config.name_or_path`` when they exist there, and are
        otherwise downloaded from the Hub repository of that name.
        """
        super().__init__(config)
        # Determine target device for inference (GPU if available)
        self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Disable torch.compile graph tracing to prevent ODE solver issues
        torch._dynamo.config.suppress_errors = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Load vocoder - force on actual device to avoid meta tensor issues in transformers 5.0+
        with torch.device("cpu"):
            # Use eager backend to keep _orig_mod structure without actual compilation
            self.vocoder = torch.compile(
                load_vocoder(vocoder_name="vocos", is_local=False, device="cpu"),
                backend="eager",
            )

        # Load model weights on CPU first for safe init; forward() moves the
        # models to the target device.
        safetensors_path = self._resolve_file(config.name_or_path, "model.safetensors")
        print(f"Loading model weights from {safetensors_path} (safetensors)...")
        state_dict = load_file(safetensors_path, device="cpu")
        vocab_path = self._resolve_file(config.name_or_path, "checkpoints/vocab.txt")

        # Force model loading on CPU to avoid meta tensor issues
        with torch.device("cpu"):
            self.ema_model = load_model(
                DiT,
                dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
                mel_spec_type="vocos",
                vocab_file=vocab_path,
                device="cpu",
            )

        # Load ema_model weights BEFORE compiling it, so the prefix-stripped
        # keys match the plain module's parameter names.
        ema_state_dict, vocoder_state_dict = self._split_state_dict(state_dict)
        missing_keys, unexpected_keys = self.ema_model.load_state_dict(
            ema_state_dict, strict=False)

        if vocoder_state_dict:
            # The vocoder is already wrapped by torch.compile, whose own
            # state-dict keys carry an "_orig_mod." prefix. Loading the
            # stripped keys into the wrapper with strict=False would match
            # nothing and silently load nothing, so target the wrapped module.
            vocoder_module = getattr(self.vocoder, "_orig_mod", self.vocoder)
            vocoder_module.load_state_dict(vocoder_state_dict, strict=False)

        # Use eager backend - disables actual compilation while keeping _orig_mod
        # structure for weight serialization. Full torch.compile with inductor
        # breaks the ODE solver in CFM.sample() causing jumbled/partial text output.
        self.ema_model = torch.compile(self.ema_model, backend="eager")
        print(f"Weight loading - Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        if missing_keys:
            print(f"Missing keys sample: {missing_keys[:5]}")
        if unexpected_keys:
            print(f"Unexpected keys sample: {unexpected_keys[:5]}")

        # Flag for lazy buffer recomputation (see _recompute_buffers).
        # We cannot recompute here because transformers 5.0 materializes
        # meta tensors AFTER __init__ returns, overwriting our values.
        self._buffers_need_recompute = True

    @staticmethod
    def _resolve_file(name_or_path, filename):
        """Return a local filesystem path for ``filename`` of checkpoint ``name_or_path``.

        Uses ``<name_or_path>/<filename>`` when ``name_or_path`` is a local
        directory containing that file; otherwise downloads the file from the
        Hugging Face Hub repository ``name_or_path``.
        """
        if name_or_path and os.path.isdir(name_or_path):
            local_path = os.path.join(name_or_path, filename)
            if os.path.isfile(local_path):
                return local_path
        return hf_hub_download(name_or_path, filename=filename)

    @staticmethod
    def _split_state_dict(state_dict):
        """Split a saved state dict into ``(ema_model, vocoder)`` state dicts.

        Each key is routed by its leading prefix, and that prefix (including
        any "_orig_mod." segment from torch.compile) is removed. Keys that
        match none of the prefixes are dropped.
        """
        ema_state_dict, vocoder_state_dict = {}, {}
        # Longer prefixes first, so "ema_model._orig_mod." is not consumed by
        # "ema_model." and left with a stray "_orig_mod." segment.
        routes = (
            ("ema_model._orig_mod.", ema_state_dict),
            ("ema_model.", ema_state_dict),
            ("vocoder._orig_mod.", vocoder_state_dict),
            ("vocoder.", vocoder_state_dict),
        )
        for key, value in state_dict.items():
            for prefix, target in routes:
                if key.startswith(prefix):
                    # Slice rather than str.replace: only the leading prefix goes.
                    target[key[len(prefix):]] = value
                    break
        return ema_state_dict, vocoder_state_dict

    @staticmethod
    def _restore_buffer(module, name, value):
        """Overwrite the registered buffer ``module.<name>`` with ``value``."""
        current = getattr(module, name)
        if current.is_meta:
            # A meta tensor has no storage, so copy_/fill_ would be a silent
            # no-op; replace the registered buffer with a real tensor instead.
            setattr(module, name, value.to(dtype=current.dtype))
        else:
            current.data.copy_(value.to(device=current.device, dtype=current.dtype))

    def _recompute_buffers(self):
        """Recompute non-persistent buffers that were corrupted by
        transformers 5.0's meta device initialization.

        transformers 5.0 wraps __init__ in torch.device('meta') context,
        then materializes meta tensors with uninitialized (garbage) values.
        Non-persistent buffers (not in safetensors) never get correct values.
        This method must be called AFTER from_pretrained completes.

        The following buffers of the unwrapped ``ema_model`` are checked:
        ``transformer.text_embed.freqs_cis``, ``mel_spec.dummy`` and
        ``transformer.rotary_embed.inv_freq``. Each is rewritten only when it
        is still a meta tensor or its values look wrong, and NaN/inf values
        count as wrong.
        """
        from f5_tts.model.modules import precompute_freqs_cis

        # Get the underlying model (unwrap torch.compile if needed)
        ema = getattr(self.ema_model, "_orig_mod", self.ema_model)
        transformer = getattr(ema, "transformer", None)

        # Text positional embeddings: recomputed from the module's own sizes,
        # so the whole tensor can be compared (allclose treats NaN as unequal).
        text_embed = getattr(transformer, "text_embed", None)
        if (
            text_embed is not None
            and getattr(text_embed, "extra_modeling", False)
            and hasattr(text_embed, "freqs_cis")
        ):
            freqs_cis = precompute_freqs_cis(
                text_embed.text_embed.embedding_dim, text_embed.precompute_max_pos
            )
            current = text_embed.freqs_cis
            if current.is_meta or not torch.allclose(
                current, freqs_cis.to(device=current.device, dtype=current.dtype), atol=1e-3
            ):
                self._restore_buffer(text_embed, "freqs_cis", freqs_cis)
                print(f"Recomputed freqs_cis: shape={freqs_cis.shape}, first_val={freqs_cis[0, 0].item():.4f}")

        # mel_spec.dummy must be zero (NaN != 0 is True, so NaN is caught).
        mel_spec = getattr(ema, "mel_spec", None)
        dummy = getattr(mel_spec, "dummy", None)
        if dummy is not None and (dummy.is_meta or dummy.item() != 0):
            self._restore_buffer(mel_spec, "dummy", torch.zeros(dummy.shape, dtype=dummy.dtype))
            print("Recomputed mel_spec.dummy to 0")

        # Rotary inverse frequencies. The first entry is theta ** 0 == 1, so
        # checking it against 1 (plus finiteness) detects garbage without
        # depending on which base the module was built with.
        rot = getattr(transformer, "rotary_embed", None)
        inv_freq = getattr(rot, "inv_freq", None)
        if inv_freq is not None and (
            inv_freq.is_meta
            or not bool(torch.isfinite(inv_freq).all())
            or abs(inv_freq[0].item() - 1.0) > 0.01
        ):
            dim = inv_freq.shape[0] * 2
            theta = 10000.0
            new_inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self._restore_buffer(rot, "inv_freq", new_inv_freq)
            print(f"Recomputed rotary inv_freq: shape={new_inv_freq.shape}")

        self._buffers_need_recompute = False

    @property
    def device(self):
        """Target inference device (CUDA if available, else CPU).

        Weights are loaded on CPU in ``__init__`` and moved to this device in
        ``forward``, so this reports the target, not the current placement.
        """
        return getattr(self, "_target_device", torch.device("cpu"))

    @staticmethod
    def _to_audio_segment(audio, sample_rate):
        """Encode ``audio`` as an in-memory WAV file and load it as a pydub AudioSegment."""
        buffer = io.BytesIO()
        sf.write(buffer, audio, samplerate=sample_rate, format="WAV")
        buffer.seek(0)
        return AudioSegment.from_file(buffer, format="wav")

    @staticmethod
    def _remove_silence(audio_segment):
        """Concatenate the non-silent chunks found by pydub's ``split_on_silence``.

        Uses min_silence_len=1000, silence_thresh=-50, keep_silence=500,
        seek_step=10. Returns an empty segment if no chunks are found.
        """
        non_silent_segs = silence.split_on_silence(
            audio_segment,
            min_silence_len=1000,
            silence_thresh=-50,
            keep_silence=500,
            seek_step=10,
        )
        return sum(non_silent_segs, AudioSegment.silent(duration=0))

    @staticmethod
    def _normalize_loudness(audio_segment, target_dBFS=-20.0):
        """Apply gain so that ``audio_segment.dBFS`` becomes ``target_dBFS``.

        pydub reports -inf dBFS for an empty or all-zero segment; such a
        segment is returned unchanged instead of receiving an infinite gain.
        """
        current_dBFS = audio_segment.dBFS
        if not np.isfinite(current_dBFS):
            return audio_segment
        return audio_segment.apply_gain(target_dBFS - current_dBFS)

    def forward(self, text: str, ref_audio_path: str, ref_text: str):
        """
        Generate speech given a reference audio & text input.

        Args:
            text (str): The text to be synthesized.
            ref_audio_path (str): Path to the reference audio file.
            ref_text (str): The reference text.

        Returns:
            np.array: Generated waveform samples. Long silences are removed
            when ``config.remove_sil`` is set, and gain is applied toward
            -20 dBFS unless the audio is silent.

        Raises:
            FileNotFoundError: If ``ref_audio_path`` does not exist.
        """
        # Lazy recomputation of non-persistent buffers corrupted by transformers 5.0
        if getattr(self, "_buffers_need_recompute", False):
            self._recompute_buffers()

        if not os.path.exists(ref_audio_path):
            raise FileNotFoundError(f"Reference audio file {ref_audio_path} not found.")

        # Load reference audio & text
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)

        # Move models to target device (GPU if available) - only actually
        # transfers on first call; subsequent calls are no-ops
        self.ema_model.to(self.device)
        self.vocoder.to(self.device)

        # Perform inference
        audio, final_sample_rate, _ = infer_process(
            ref_audio,
            ref_text,
            text,
            self.ema_model,
            self.vocoder,
            mel_spec_type="vocos",
            speed=self.config.speed,
            device=self.device,
        )

        # Encode at the sample rate inference actually produced, not a hard-coded one.
        audio_segment = self._to_audio_segment(audio, final_sample_rate)
        if self.config.remove_sil:
            audio_segment = self._remove_silence(audio_segment)
        audio_segment = self._normalize_loudness(audio_segment, target_dBFS=-20.0)
        return np.array(audio_segment.get_array_of_samples())