File size: 5,769 Bytes
38a17ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import torch
import soundfile as sf
from transformers import TrainerCallback
from safetensors.torch import load_file

from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3
from src.utils import setup_logger, trim_silence_with_vad


# Module-level logger used by InferenceCallback for all of its status,
# warning and error messages.
logger = setup_logger("InferenceCallback")


class InferenceCallback(TrainerCallback):
    """Trainer callback that synthesizes one audio sample per saved checkpoint.

    After a checkpoint is saved, its weights are loaded into a freshly built
    T3 model, plugged into a Chatterbox engine (standard or Turbo, selected by
    ``config.is_turbo``), and ``config.inference_test_text`` is synthesized
    with ``config.inference_prompt_path`` as the audio prompt. The result is
    written to ``<config.output_dir>/inference_samples/checkpoint-<step>.wav``.

    Sampling is disabled (with a warning) when the prompt path or the test
    text is missing or empty in the config.
    """

    def __init__(self, config):
        """Create the sample directory and decide whether sampling is enabled.

        Args:
            config: Training configuration object. ``output_dir`` is read
                here; ``inference_prompt_path`` and ``inference_test_text``
                are optional and gate sampling. ``model_dir``, ``is_turbo``
                and ``new_vocab_size`` are read later, at generation time.
        """
        self.config = config
        self.inference_dir = os.path.join(config.output_dir, "inference_samples")
        os.makedirs(self.inference_dir, exist_ok=True)

        # getattr(..., None) treats "attribute absent" and "attribute empty"
        # the same way: both disable sampling.
        if not getattr(config, "inference_prompt_path", None):
            logger.warning("The inference prompt path is not specified; sampling will be skipped.")
            self.skip_inference = True
        elif not getattr(config, "inference_test_text", None):
            logger.warning("The inference test text is not specified; the sample will be skipped.")
            self.skip_inference = True
        else:
            self.skip_inference = False
            logger.info(f"Inference Callback is ready. Examples will be saved here: {self.inference_dir}")

    def on_save(self, args, state, control, **kwargs):
        """Generate a sample from the checkpoint that was just saved.

        Looks for ``model.safetensors`` (preferred) or ``pytorch_model.bin``
        inside ``<args.output_dir>/checkpoint-<global_step>``. Any exception
        raised during generation is logged and swallowed so that a failed
        sample never interrupts training.

        Args:
            args: Trainer arguments; only ``output_dir`` is read.
            state: Trainer state; ``global_step`` and, when present,
                ``is_world_process_zero`` are read.
            control: Unused.
            **kwargs: Unused.
        """
        if self.skip_inference:
            return

        # The save event is delivered to every process in a distributed run;
        # only the main process should load a model and write the wav file.
        if not getattr(state, "is_world_process_zero", True):
            return

        step = state.global_step
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{step}")

        candidates = (
            os.path.join(checkpoint_dir, "model.safetensors"),
            os.path.join(checkpoint_dir, "pytorch_model.bin"),
        )
        weights_path = next((p for p in candidates if os.path.exists(p)), None)
        if weights_path is None:
            logger.warning(f"Checkpoint weights could not be found: {checkpoint_dir}")
            return

        logger.info(f"Initializing inference for checkpoint-{step}...")
        output_path = os.path.join(self.inference_dir, f"checkpoint-{step}.wav")

        try:
            self._generate_sample(weights_path, output_path)
        except Exception as e:
            # Best-effort: a broken sample must not abort the training run.
            logger.error(f"An error occurred during the inference (Step: {step}): {e}", exc_info=True)
        finally:
            # By now the inference objects are out of scope (also on failure,
            # since the exception has been cleared), so cached GPU memory can
            # actually be handed back before training resumes.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @staticmethod
    def _load_t3_state_dict(checkpoint_path: str) -> dict:
        """Read a checkpoint file and return only the weights meant for T3.

        Keys have every ``"module."`` fragment removed; a leading ``"t3."``
        is then stripped. Keys without the ``"t3."`` prefix are kept unless
        they contain ``"s3gen"``, ``"ve."`` or ``"tokenizer"``.

        Args:
            checkpoint_path: Path to a ``.safetensors`` file or to a file
                readable by ``torch.load``.

        Returns:
            Mapping from cleaned parameter name to tensor.
        """
        if checkpoint_path.endswith(".safetensors"):
            raw_state = load_file(checkpoint_path)
        else:
            # NOTE(review): torch.load unpickles the file; only use it on
            # checkpoints produced by this training run.
            raw_state = torch.load(checkpoint_path, map_location="cpu")

        skip_markers = ("s3gen", "ve.", "tokenizer")
        cleaned = {}
        for raw_key, tensor in raw_state.items():
            key = raw_key.replace("module.", "")
            if key.startswith("t3."):
                # Strip only the leading prefix; a later "t3." inside the key
                # (e.g. a name ending in "...t3") must stay intact.
                cleaned[key.replace("t3.", "", 1)] = tensor
            elif not any(marker in key for marker in skip_markers):
                cleaned[key] = tensor
        return cleaned

    @staticmethod
    def _report_missing_keys(missing_keys) -> None:
        """Log how complete the weight loading was, based on the missing keys.

        Missing ``tfmr.layers`` entries are reported as a critical error;
        other missing keys produce a warning unless every one of them
        contains ``"wte"``, which is logged as normal for Turbo.
        """
        critical_missing = [k for k in missing_keys if "tfmr.layers" in k]

        if critical_missing:
            logger.error("[CRITICAL ERROR] Model weights COULD NOT BE LOADED!")
            logger.error(f"Number of missing keys: {len(missing_keys)}")
            logger.error(f"Examples of missing information: {critical_missing[:3]}")
            logger.error("The sound produced will be 100% NOISE (Static Noise). Check your checkpoint recording method.")
            return

        if not missing_keys:
            logger.info("All the weights were loaded completely and successfully.")
            return

        non_wte_missing = [k for k in missing_keys if "wte" not in k]
        if non_wte_missing:
            logger.warning(f"Some weights are missing ({len(non_wte_missing)} pieces): {non_wte_missing[:3]}...")
        else:
            logger.info("The weights were successfully loaded (except for the WTE - normal for the Turbo).")

    def _generate_sample(self, checkpoint_path: str, output_path: str) -> None:
        """Synthesize the configured test text with the given checkpoint.

        Builds the engine from ``config.model_dir`` on CPU, replaces its T3
        model with a new one carrying the checkpoint weights, moves the
        engine's modules to CUDA when available, generates audio, trims it
        with ``trim_silence_with_vad`` and writes it to ``output_path``.

        Generation is attempted even when weights are reported missing; the
        problem is only logged. Releasing the CUDA cache is left to the
        caller, once this function's objects have gone out of scope.

        Args:
            checkpoint_path: Path to the checkpoint weights file.
            output_path: Destination path of the generated audio file.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        is_turbo = getattr(self.config, "is_turbo", False)

        engine_cls = ChatterboxTurboTTS if is_turbo else ChatterboxTTS
        tts_engine = engine_cls.from_local(self.config.model_dir, device="cpu")

        # Reuse the engine's hyper-parameters, optionally with a resized text
        # vocabulary. An explicit None means "not set" and is ignored.
        t3_config = tts_engine.t3.hp
        new_vocab_size = getattr(self.config, "new_vocab_size", None)
        if new_vocab_size is not None:
            t3_config.text_tokens_dict_size = new_vocab_size

        new_t3 = T3(hp=t3_config)
        if is_turbo and hasattr(new_t3.tfmr, "wte"):
            del new_t3.tfmr.wte

        t3_state = self._load_t3_state_dict(checkpoint_path)
        missing_keys, _ = new_t3.load_state_dict(t3_state, strict=False)
        self._report_missing_keys(missing_keys)

        tts_engine.t3 = new_t3
        for component in (tts_engine.t3, tts_engine.s3gen, tts_engine.ve):
            component.to(device).eval()
        tts_engine.device = device

        params = {
            "temperature": 0.8,
            "repetition_penalty": 1.2,
        }
        if not is_turbo:
            params["cfg_weight"] = 0.2
            # Must be a plain float; a trailing comma here would silently
            # turn the value into a 1-tuple.
            params["exaggeration"] = 1.2

        with torch.no_grad():
            wav = tts_engine.generate(
                text=self.config.inference_test_text,
                audio_prompt_path=self.config.inference_prompt_path,
                **params
            )

        wav_np = wav.squeeze().cpu().numpy()
        trimmed_wav = trim_silence_with_vad(wav_np, tts_engine.sr)

        sf.write(output_path, trimmed_wav, tts_engine.sr)
        logger.info(f"Example saved: {output_path}")