# Source: chatterbox-bangla / chatterbox-finetuning / src / inference_callback.py
# Last commit (arijitx, 38a17ab): convert submodule to normal directory
import os
import torch
import soundfile as sf
from transformers import TrainerCallback
from safetensors.torch import load_file
from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3
from src.utils import setup_logger, trim_silence_with_vad
# Module-level logger for this file, created with the project's setup_logger
# helper under the name "InferenceCallback".
logger = setup_logger("InferenceCallback")
class InferenceCallback(TrainerCallback):
    """Synthesize one audio sample every time the trainer saves a checkpoint.

    After each save, the callback loads the freshly written checkpoint
    weights into a new T3 model, swaps it into a TTS engine loaded from
    ``config.model_dir``, generates ``config.inference_test_text`` with
    ``config.inference_prompt_path`` as the audio prompt, and writes the
    result to ``<config.output_dir>/inference_samples/checkpoint-<step>.wav``.

    Sampling is disabled (``skip_inference = True``) when either the prompt
    path or the test text is missing or empty on ``config``.
    """

    # Checkpoint weight file names, tried in order inside checkpoint-<step>/.
    _WEIGHT_FILE_NAMES = ("model.safetensors", "pytorch_model.bin")

    def __init__(self, config):
        """Store the config and prepare the sample output directory.

        Args:
            config: Training configuration object. Reads ``output_dir`` and,
                when present, ``inference_prompt_path`` and
                ``inference_test_text``.
        """
        self.config = config
        self.inference_dir = os.path.join(config.output_dir, "inference_samples")
        os.makedirs(self.inference_dir, exist_ok=True)

        # getattr with a default covers both "attribute absent" and "empty".
        if not getattr(config, "inference_prompt_path", None):
            logger.warning("The inference prompt path is not specified; sampling will be skipped.")
            self.skip_inference = True
        elif not getattr(config, "inference_test_text", None):
            logger.warning("The inference test text is not specified; the sample will be skipped.")
            self.skip_inference = True
        else:
            self.skip_inference = False
            logger.info(f"Inference Callback is ready. Examples will be saved here: {self.inference_dir}")

    def on_save(self, args, state, control, **kwargs):
        """Generate a sample from the checkpoint that was just saved.

        Any exception raised while generating is logged and swallowed so a
        failed sample never interrupts training.

        Args:
            args: Training arguments; ``args.output_dir`` locates the
                ``checkpoint-<step>`` directory.
            state: Trainer state; ``state.global_step`` names the checkpoint.
            control: Trainer control object (unused).
            **kwargs: Extra callback arguments (unused).
        """
        if self.skip_inference:
            return

        # Only the main process samples; otherwise every rank in a
        # distributed run would load a model and write the same wav file.
        if not getattr(state, "is_world_process_zero", True):
            return

        step = state.global_step
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{step}")

        weights_path = None
        for file_name in self._WEIGHT_FILE_NAMES:
            candidate = os.path.join(checkpoint_dir, file_name)
            if os.path.exists(candidate):
                weights_path = candidate
                break

        if weights_path is None:
            logger.warning(f"Checkpoint weights could not be found: {checkpoint_dir}")
            return

        logger.info(f"Initializing inference for checkpoint-{step}...")
        try:
            output_path = os.path.join(self.inference_dir, f"checkpoint-{step}.wav")
            self._generate_sample(weights_path, output_path)
        except Exception as e:
            # Best-effort by design: report the failure and keep training.
            logger.error(f"An error occurred during the inference (Step: {step}): {e}", exc_info=True)

    def _generate_sample(self, checkpoint_path: str, output_path: str) -> None:
        """Load the checkpoint into a TTS engine and write one wav sample.

        Args:
            checkpoint_path: Path to ``model.safetensors`` or
                ``pytorch_model.bin`` of the checkpoint to evaluate.
            output_path: Destination path of the generated wav file.

        Raises:
            Exception: Anything raised while loading, generating or writing
                propagates to the caller (``on_save`` logs it).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        is_turbo = getattr(self.config, "is_turbo", False)

        # Pre-bind so the finally block can always drop these references.
        tts_engine = None
        wav = None
        try:
            tts_engine = self._build_engine(checkpoint_path, device, is_turbo)

            params = {
                "temperature": 0.8,
                "repetition_penalty": 1.2,
            }
            if not is_turbo:
                params["cfg_weight"] = 0.2
                # Must be a plain float; a trailing comma here would pass a
                # one-element tuple to generate().
                params["exaggeration"] = 1.2

            with torch.no_grad():
                wav = tts_engine.generate(
                    text=self.config.inference_test_text,
                    audio_prompt_path=self.config.inference_prompt_path,
                    **params
                )

            wav_np = wav.squeeze().cpu().numpy()
            trimmed_wav = trim_silence_with_vad(wav_np, tts_engine.sr)
            sf.write(output_path, trimmed_wav, tts_engine.sr)
            logger.info(f"Example saved: {output_path}")
        finally:
            # Release the inference model and output tensor on success AND on
            # failure, then return cached GPU blocks so training can reuse
            # them.
            del tts_engine, wav
            torch.cuda.empty_cache()

    def _build_engine(self, checkpoint_path: str, device, is_turbo: bool):
        """Return a TTS engine whose T3 model holds the checkpoint weights."""
        engine_cls = ChatterboxTurboTTS if is_turbo else ChatterboxTTS
        # Load on CPU first; the sub-models are moved to `device` below once
        # the fine-tuned T3 has been swapped in.
        tts_engine = engine_cls.from_local(self.config.model_dir, device="cpu")

        t3_config = tts_engine.t3.hp
        if hasattr(self.config, "new_vocab_size"):
            t3_config.text_tokens_dict_size = self.config.new_vocab_size

        new_t3 = T3(hp=t3_config)
        if is_turbo and hasattr(new_t3.tfmr, "wte"):
            # The "wte" weights are reported as an expected miss for Turbo
            # checkpoints (see _report_missing_keys), so drop the module.
            del new_t3.tfmr.wte

        if checkpoint_path.endswith(".safetensors"):
            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only load checkpoints produced by a trusted
            # run; consider weights_only=True where the installed torch
            # version supports it.
            state_dict = torch.load(checkpoint_path, map_location="cpu")

        missing_keys, _unexpected = new_t3.load_state_dict(
            self._clean_state_dict(state_dict), strict=False
        )
        self._report_missing_keys(missing_keys)

        tts_engine.t3 = new_t3
        tts_engine.t3.to(device).eval()
        tts_engine.s3gen.to(device).eval()
        tts_engine.ve.to(device).eval()
        tts_engine.device = device
        return tts_engine

    @staticmethod
    def _clean_state_dict(state_dict):
        """Return the checkpoint entries renamed to match a bare T3 model."""
        cleaned = {}
        for key, tensor in state_dict.items():
            # Remove "module" path components (presumably left by wrappers
            # such as DataParallel) by whole component, so names that merely
            # contain the substring, e.g. "submodule", are left intact.
            name = ".".join(part for part in key.split(".") if part != "module")
            if name.startswith("t3."):
                # Strip only the leading prefix, never inner occurrences.
                cleaned[name[len("t3."):]] = tensor
            elif not any(marker in name for marker in ("s3gen", "ve.", "tokenizer")):
                # Keys mentioning the other engine components are skipped;
                # everything else is taken as-is.
                cleaned[name] = tensor
        return cleaned

    @staticmethod
    def _report_missing_keys(missing_keys):
        """Log how completely the checkpoint populated the T3 model."""
        critical_missing = [k for k in missing_keys if "tfmr.layers" in k]
        if critical_missing:
            # Transformer layers are missing: generation continues, but the
            # messages below warn that the output will be noise.
            logger.error("[CRITICAL ERROR] Model weights COULD NOT BE LOADED!")
            logger.error(f"Number of missing keys: {len(missing_keys)}")
            logger.error(f"Examples of missing information: {critical_missing[:3]}")
            logger.error("The sound produced will be 100% NOISE (Static Noise). Check your checkpoint recording method.")
            return

        if not missing_keys:
            logger.info("All the weights were loaded completely and successfully.")
            return

        non_wte_missing = [k for k in missing_keys if "wte" not in k]
        if non_wte_missing:
            logger.warning(f"Some weights are missing ({len(non_wte_missing)} pieces): {non_wte_missing[:3]}...")
        else:
            logger.info("The weights were successfully loaded (except for the WTE - normal for the Turbo).")