Spaces:
Runtime error
Runtime error
File size: 5,769 Bytes
38a17ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import os
import torch
import soundfile as sf
from transformers import TrainerCallback
from safetensors.torch import load_file
from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3
from src.utils import setup_logger, trim_silence_with_vad
# Module-level logger for this file, created through the project's
# setup_logger helper under the name "InferenceCallback".
logger = setup_logger("InferenceCallback")
class InferenceCallback(TrainerCallback):
    """Synthesize a test utterance each time the trainer saves a checkpoint.

    After a checkpoint is written, the callback builds a fresh ``T3`` model,
    loads the saved weights into it, plugs it into a TTS engine loaded from
    ``config.model_dir`` and writes the generated audio to
    ``<config.output_dir>/inference_samples/checkpoint-<step>.wav``.

    Config attributes read: ``output_dir``, ``model_dir``,
    ``inference_prompt_path``, ``inference_test_text`` and, optionally,
    ``is_turbo`` and ``new_vocab_size``.
    """

    def __init__(self, config):
        """Create the sample directory and decide whether sampling is possible.

        Args:
            config: Training configuration object. Sampling is disabled
                (``self.skip_inference = True``) when either
                ``inference_prompt_path`` or ``inference_test_text`` is
                missing or empty.
        """
        self.config = config
        self.inference_dir = os.path.join(config.output_dir, "inference_samples")
        os.makedirs(self.inference_dir, exist_ok=True)

        if not getattr(config, "inference_prompt_path", None):
            logger.warning("The inference prompt path is not specified; sampling will be skipped.")
            self.skip_inference = True
        elif not getattr(config, "inference_test_text", None):
            logger.warning("The inference test text is not specified; the sample will be skipped.")
            self.skip_inference = True
        else:
            self.skip_inference = False
            logger.info(f"Inference Callback is ready. Examples will be saved here: {self.inference_dir}")

    def on_save(self, args, state, control, **kwargs):
        """Generate an audio sample from the checkpoint that was just saved.

        Looks for ``model.safetensors`` and then ``pytorch_model.bin`` inside
        ``<args.output_dir>/checkpoint-<global_step>``. Any exception raised
        while sampling is logged and swallowed so that a failed sample never
        interrupts training.

        Args:
            args: Trainer arguments; only ``output_dir`` is read.
            state: Trainer state; ``global_step`` and
                ``is_world_process_zero`` are read.
            control: Trainer control object (unused).
            **kwargs: Extra callback arguments (unused).
        """
        if self.skip_inference:
            return
        # The event fires on every rank in distributed runs; sampling once,
        # on the main process, avoids duplicate model loads and concurrent
        # writes to the same wav file.
        if not getattr(state, "is_world_process_zero", True):
            return

        step = state.global_step
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{step}")
        weights_path = os.path.join(checkpoint_dir, "model.safetensors")
        if not os.path.exists(weights_path):
            weights_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
            if not os.path.exists(weights_path):
                logger.warning(f"Checkpoint weights could not be found: {checkpoint_dir}")
                return

        logger.info(f"Initializing inference for checkpoint-{step}...")
        try:
            output_path = os.path.join(self.inference_dir, f"checkpoint-{step}.wav")
            self._generate_sample(weights_path, output_path)
        except Exception as e:
            # Deliberate best-effort boundary: log with traceback, keep training.
            logger.error(f"An error occurred during the inference (Step: {step}): {e}", exc_info=True)
        finally:
            # Runs after the except clause has dropped the exception (and the
            # frames its traceback kept alive), so memory held by a failed
            # sample can be returned to the CUDA allocator as well.
            torch.cuda.empty_cache()

    @staticmethod
    def _read_checkpoint(checkpoint_path: str):
        """Load a checkpoint file into a CPU state dict (safetensors or torch pickle)."""
        if checkpoint_path.endswith(".safetensors"):
            return load_file(checkpoint_path)
        # NOTE(review): torch.load unpickles the file. That is acceptable for
        # checkpoints written by this training run; do not point it at
        # untrusted files (consider weights_only=True where supported).
        return torch.load(checkpoint_path, map_location="cpu")

    @staticmethod
    def _extract_t3_weights(state_dict):
        """Map checkpoint keys onto the key names expected by a bare ``T3`` module.

        ``module.`` fragments are removed, a leading ``t3.`` prefix is
        stripped, and entries whose name contains ``s3gen``, ``ve.`` or
        ``tokenizer`` are dropped.
        """
        t3_prefix = "t3."
        foreign_tags = ("s3gen", "ve.", "tokenizer")
        t3_weights = {}
        for key, tensor in state_dict.items():
            name = key.replace("module.", "")
            if name.startswith(t3_prefix):
                # Strip only the leading prefix; a blanket str.replace would
                # also mangle any later "t3." occurrence inside the key.
                t3_weights[name[len(t3_prefix):]] = tensor
            elif not any(tag in name for tag in foreign_tags):
                t3_weights[name] = tensor
        return t3_weights

    @staticmethod
    def _report_missing_keys(missing_keys) -> None:
        """Log how complete the weight load was, based on the missing key names."""
        critical_missing = [k for k in missing_keys if "tfmr.layers" in k]
        if critical_missing:
            logger.error("[CRITICAL ERROR] Model weights COULD NOT BE LOADED!")
            logger.error(f"Number of missing keys: {len(missing_keys)}")
            logger.error(f"Examples of missing information: {critical_missing[:3]}")
            logger.error("The sound produced will be 100% NOISE (Static Noise). Check your checkpoint recording method.")
        elif missing_keys:
            non_wte_missing = [k for k in missing_keys if "wte" not in k]
            if non_wte_missing:
                logger.warning(f"Some weights are missing ({len(non_wte_missing)} pieces): {non_wte_missing[:3]}...")
            else:
                logger.info("The weights were successfully loaded (except for the WTE - normal for the Turbo).")
        else:
            logger.info("All the weights were loaded completely and successfully.")

    def _generate_sample(self, checkpoint_path: str, output_path: str) -> None:
        """Build an engine around the checkpoint weights and write one wav file.

        Args:
            checkpoint_path: Path to ``model.safetensors`` or
                ``pytorch_model.bin`` of the checkpoint to evaluate.
            output_path: Destination path of the generated wav file.

        Raises:
            Exception: Whatever model loading, generation or file writing
                raises; the caller (``on_save``) logs it.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        is_turbo = getattr(self.config, "is_turbo", False)
        engine_cls = ChatterboxTurboTTS if is_turbo else ChatterboxTTS

        # Load on CPU first; the modules are moved to `device` only after the
        # checkpoint's T3 has replaced the one shipped with the base model.
        tts_engine = engine_cls.from_local(self.config.model_dir, device="cpu")

        t3_config = tts_engine.t3.hp
        if hasattr(self.config, "new_vocab_size"):
            t3_config.text_tokens_dict_size = self.config.new_vocab_size

        new_t3 = T3(hp=t3_config)
        if is_turbo and hasattr(new_t3.tfmr, "wte"):
            del new_t3.tfmr.wte

        state_dict = self._read_checkpoint(checkpoint_path)
        clean_state_dict = self._extract_t3_weights(state_dict)

        # strict=False: the load result is inspected and reported instead of
        # raising on the first mismatch.
        missing_keys, _unexpected_keys = new_t3.load_state_dict(clean_state_dict, strict=False)
        self._report_missing_keys(missing_keys)

        tts_engine.t3 = new_t3
        tts_engine.t3.to(device).eval()
        tts_engine.s3gen.to(device).eval()
        tts_engine.ve.to(device).eval()
        tts_engine.device = device

        params = {
            "temperature": 0.8,
            "repetition_penalty": 1.2,
        }
        if not is_turbo:
            params["cfg_weight"] = 0.2
            # Must be a plain float: a trailing comma here would turn the
            # value into the tuple (1.2,).
            params["exaggeration"] = 1.2

        with torch.no_grad():
            wav = tts_engine.generate(
                text=self.config.inference_test_text,
                audio_prompt_path=self.config.inference_prompt_path,
                **params,
            )

        wav_np = wav.squeeze().cpu().numpy()
        trimmed_wav = trim_silence_with_vad(wav_np, tts_engine.sr)
        sf.write(output_path, trimmed_wav, tts_engine.sr)
        logger.info(f"Example saved: {output_path}")

        # Drop every reference to the temporary models before asking the CUDA
        # allocator to release its cached blocks.
        del tts_engine
        del new_t3
        del state_dict
        del clean_state_dict
        torch.cuda.empty_cache()