|
|
import os |
|
|
import torch |
|
|
import soundfile as sf |
|
|
from src.chatterbox_.tts import ChatterboxTTS |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration. All paths are relative, so they resolve against the current
# working directory the script is launched from.
# ---------------------------------------------------------------------------

# Directory containing the pretrained base Chatterbox model.
MODEL_DIR: str = "./pretrained_models"

# Finetuned checkpoint in safetensors format.
FINETUNED_WEIGHTS: str = "./models/best_finnish_multilingual_cp986.safetensors"

# Finnish sentence to synthesize.
TEXT: str = "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä."

# Reference clip passed to generation as the audio prompt.
REFERENCE_AUDIO: str = "./samples/reference_finnish.wav"

# Destination file for the generated waveform.
OUTPUT_FILE: str = "output_finnish.wav"
|
|
|
|
|
|
|
|
def _strip_t3_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"t3."`` removed from keys.

    Keys that do not start with ``"t3."`` are passed through unchanged
    (presumably a checkpoint saved from the bare T3 module; verify against
    the training script).
    """
    prefix = "t3."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_t3_weights(t3, checkpoint_state):
    """Load finetuned weights into the T3 module and report key mismatches.

    ``load_state_dict(strict=False)`` skips keys that do not match the
    module without complaint. If the checkpoint's naming does not line up,
    the base weights would silently stay in place. This helper prints what
    was and was not applied, so that case is visible.

    Args:
        t3: The ``torch.nn.Module`` receiving the weights (``engine.t3``).
        checkpoint_state: Mapping of parameter name to tensor, as loaded from
            the safetensors checkpoint.

    Returns:
        The ``(missing_keys, unexpected_keys)`` named tuple returned by
        ``load_state_dict``.
    """
    t3_state_dict = _strip_t3_prefix(checkpoint_state)
    result = t3.load_state_dict(t3_state_dict, strict=False)

    # Every checkpoint key that was not reported as unexpected was consumed.
    applied = len(t3_state_dict) - len(result.unexpected_keys)
    print(f"Applied {applied}/{len(t3_state_dict)} checkpoint tensors to T3.")

    if result.unexpected_keys:
        print(
            f"Warning: {len(result.unexpected_keys)} checkpoint keys did not match "
            f"T3 and were ignored (e.g. {result.unexpected_keys[:5]})."
        )
    if result.missing_keys:
        print(
            f"Warning: {len(result.missing_keys)} T3 parameters were not in the "
            f"checkpoint and keep their base values (e.g. {result.missing_keys[:5]})."
        )
    if applied == 0:
        print("Warning: no finetuned weights were applied; generation will use base weights.")
    return result


def main():
    """Synthesize TEXT with the (optionally finetuned) Chatterbox model.

    The function:

    1. Loads the base model from MODEL_DIR.
    2. Overlays the T3 weights from FINETUNED_WEIGHTS, if that file exists.
    3. Generates speech for TEXT conditioned on REFERENCE_AUDIO.
    4. Writes the result to OUTPUT_FILE at the engine's sample rate.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading base model from {MODEL_DIR}...")
    engine = ChatterboxTTS.from_local(MODEL_DIR, device=device)

    if os.path.exists(FINETUNED_WEIGHTS):
        print(f"Loading finetuned weights from {FINETUNED_WEIGHTS}...")
        checkpoint_state = load_file(FINETUNED_WEIGHTS)
        _load_t3_weights(engine.t3, checkpoint_state)
    else:
        print(f"Warning: Finetuned weights not found at {FINETUNED_WEIGHTS}. Using base weights.")

    print(f"Generating audio for: '{TEXT}'")
    wav_tensor = engine.generate(
        text=TEXT,
        audio_prompt_path=REFERENCE_AUDIO,
        repetition_penalty=1.2,
        temperature=0.8,
        exaggeration=0.6,
    )

    # squeeze() drops singleton dims so soundfile receives a plain sample
    # array (assumes generate returns e.g. (1, num_samples); TODO confirm).
    wav_np = wav_tensor.squeeze().cpu().numpy()
    sf.write(OUTPUT_FILE, wav_np, engine.sr)
    print(f"Successfully saved audio to {OUTPUT_FILE}")
|
|
|
|
|
if __name__ == "__main__":
    # Run the synthesis only when executed as a script, not when imported.
    main()
|
|
|