# XurmoTTS / continue_training.py
# Author: jahongirtech — commit 5ec8604 (verified): "Create continue_training.py"
import os
import torch
import librosa
import soundfile as sf
from pathlib import Path
# ─────────────────────────────────────────
# AUDIO PREPROCESSING
# ─────────────────────────────────────────
def preprocess_audio(dataset_path, target_sr=22050):
    """Convert every ``<dataset_path>/wavs/*.wav`` to mono at ``target_sr``, in place.

    Each file is loaded with librosa (resampled to ``target_sr`` and downmixed
    to mono) and written back over the original path. After all files are
    rewritten, an empty ``.preprocessed`` marker file is created in
    ``dataset_path``; while that marker exists, later calls return immediately.

    NOTE: the marker does not record ``target_sr``. If you change the sample
    rate, delete the marker by hand so the files are converted again.

    Args:
        dataset_path: Dataset root directory that is expected to contain a
            ``wavs`` sub-directory.
        target_sr: Sample rate (Hz) to convert every file to.
    """
    already_done = os.path.join(dataset_path, ".preprocessed")
    # Check the marker first so an already converted dataset is not even listed.
    if os.path.exists(already_done):
        print("βœ… Audio allaqachon tayyor.")
        return

    wavs_dir = os.path.join(dataset_path, "wavs")
    wav_files = sorted(Path(wavs_dir).glob("*.wav"))
    if not wav_files:
        # Nothing to convert (missing or empty wavs/ folder). Do NOT write the
        # marker here, otherwise wavs added later would never be converted.
        print(f"[!] {wavs_dir} ichida wav topilmadi.")
        return

    print(f"πŸ”„ {len(wav_files)} ta wav qayta ishlanmoqda...")
    for wav_path in wav_files:
        # sr=target_sr resamples and mono=True downmixes during loading.
        audio, _ = librosa.load(str(wav_path), sr=target_sr, mono=True)
        sf.write(str(wav_path), audio, target_sr)

    # Mark the dataset as done only after every file has been rewritten, so an
    # interrupted run is retried from scratch next time.
    with open(already_done, "w"):
        pass
    print(f"βœ… Barcha wav mono + {target_sr} Hz ga o'tkazildi.")
# Input dataset location and training output location (Google Drive paths).
dataset_path, output_dir = (
    "/content/drive/MyDrive/tts/dataset_final",
    "/content/drive/MyDrive/tts/output",
)
# Normalise the wavs once (mono, default target sample rate) before any
# later step reads them.
preprocess_audio(dataset_path=dataset_path)
# ─────────────────────────────────────────
# IMPORT
# ─────────────────────────────────────────
from TTS.tts.configs.shared_configs import CharactersConfig, BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.datasets import formatters
from trainer import Trainer, TrainerArgs
# ─────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────
# Training configuration for a VITS model trained on raw characters
# (use_phonemes=False). Checkpoints and logs are written under output_dir.
config = VitsConfig(
    # --- run identity and logging / checkpointing ---
    run_name="Xurmo_Media_20",
    output_path=output_dir,
    print_step=50,
    save_step=1000,
    save_n_checkpoints=3,
    run_eval=True,
    # --- optimisation ---
    epochs=1000,
    batch_size=16,
    eval_batch_size=8,
    mixed_precision=True,
    # --- data loader workers ---
    num_loader_workers=2,
    num_eval_loader_workers=2,
    # --- text front-end ---
    use_phonemes=False,
    text_cleaner="multilingual_cleaners",
    characters=CharactersConfig(
        pad="<PAD>",
        eos="<EOS>",
        bos="<BOS>",
        blank="<BLNK>",
        punctuations="!,.? ",
        # NOTE(review): "Κ»" looks like a mis-decoded "ʻ" (U+02BB, used in
        # Uzbek oʻ/gʻ). If the file on disk really contains "Κ»", then "ʻ" is
        # not in this character set; verify the file encoding.
        # NOTE(review): "Sh"/"Ch" here are individual letters, not digraph
        # tokens; repeated letters are presumably de-duplicated by the
        # character set; confirm.
        characters="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzOΚ»oΚ»GΚ»gΚ»ShshChch'0123456789",
    ),
)
# Audio overrides. Resampling stays off because the wavs are converted to
# 22050 Hz up front by preprocess_audio().
config.audio.resample = False
config.audio.do_trim_silence = True
config.audio.sample_rate = 22050
# ─────────────────────────────────────────
# FORMATTER
# ─────────────────────────────────────────
def uzbek_formatter(root_path, meta_file, **kwargs):
    """Parse a pipe-separated ``id|text`` metadata file into TTS sample dicts.

    Each non-empty line must look like ``<wav id>|<transcript>[|...]``; any
    extra columns are ignored. The audio path is ``<root_path>/wavs/<id>.wav``.
    An id that already ends in ``.wav`` is used as-is instead of getting a
    second extension. Apostrophe look-alikes in the transcript (U+2018,
    U+2019, U+02BC and the backtick) are normalised to a plain ``'``.

    Lines are skipped, and counted in a summary message, when they have:
    - fewer than two columns,
    - an empty transcript, or
    - no matching wav file.

    Args:
        root_path: Dataset directory containing ``meta_file`` and ``wavs/``.
        meta_file: Metadata file name, relative to ``root_path``.
        **kwargs: Accepted so extra keyword arguments from the caller do not
            raise; unused.

    Returns:
        A list of dicts with keys ``text``, ``audio_file``, ``root_path``,
        ``speaker_name`` and ``language``.
    """
    meta_path = os.path.join(root_path, meta_file)
    wavs_dir = os.path.join(root_path, "wavs")
    # One translation table instead of a chain of str.replace calls.
    apostrophes = str.maketrans({
        "\u2018": "'",
        "\u2019": "'",
        "\u02bc": "'",
        "\u0060": "'",
    })

    items = []
    n_skipped = 0
    with open(meta_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            cols = line.split("|")
            if len(cols) < 2:
                n_skipped += 1
                continue

            wav_id = cols[0].strip()
            # Accept ids written with or without the extension; previously
            # "x.wav" became "x.wav.wav" and was silently dropped.
            if not wav_id.endswith(".wav"):
                wav_id += ".wav"
            wav_file = os.path.join(wavs_dir, wav_id)

            text = cols[1].strip().translate(apostrophes)
            # An empty transcript gives the model nothing to learn from.
            if not text or not os.path.exists(wav_file):
                n_skipped += 1
                continue

            items.append({
                "text": text,
                "audio_file": wav_file,
                "root_path": root_path,
                "speaker_name": "xurmo media",
                "language": "uz",
            })

    print(f"βœ… {len(items)} ta sample yuklandi.")
    if n_skipped:
        print(f"[!] {n_skipped} ta qator o'tkazib yuborildi.")
    return items
# ─────────────────────────────────────────
# DATASET
# ─────────────────────────────────────────
# Single dataset: <dataset_path>/metadata.csv, parsed by uzbek_formatter.
# There is no separate validation file (meta_file_val=""); with
# eval_split=True the loader is asked to hold out eval_split_size (10%) of the
# samples for evaluation.
dataset_config = BaseDatasetConfig(
    dataset_name="uzbek_tts",
    language="uz",
    path=dataset_path,
    meta_file_train="metadata.csv",
    meta_file_val="",
    # Left empty: the formatter callable is handed to load_tts_samples below.
    formatter="",
)
train_samples, eval_samples = load_tts_samples(
    dataset_config,  # a single config object, not wrapped in a list
    formatter=uzbek_formatter,  # the function object itself, not its name
    eval_split=True,
    eval_split_size=0.1,
)
# ─────────────────────────────────────────
# MODEL
# ─────────────────────────────────────────
# The tokenizer factory hands back a config as well; keep that one for the
# audio processor and the model.
tokenizer, config = TTSTokenizer.init_from_config(config)
ap = AudioProcessor.init_from_config(config)
# Single speaker: no speaker manager is attached.
model = Vits(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=None)
# ─────────────────────────────────────────
# RESUME — oxirgi checkpoint ni topamiz
# ─────────────────────────────────────────
# Checkpoint to resume from; set to None to start training from scratch.
restore_path = "/content/drive/MyDrive/tts/output/Xurmo_Media_20-May-19-2026_11+39AM-0000000/checkpoint_4000.pth"

# Optional auto-discovery of the newest checkpoint (disabled). To enable it,
# set run_dirs to a list of run directories sorted oldest -> newest, e.g.
#   sorted(Path(output_dir).glob("Xurmo_Media_20*"), key=os.path.getmtime)
# The newest *.pth inside the last directory then overrides restore_path.
run_dirs = None
if run_dirs:
    checkpoints = sorted(run_dirs[-1].glob("*.pth"), key=os.path.getmtime)
    if checkpoints:
        restore_path = str(checkpoints[-1])

# Report based on the path that will actually be used, not on whether
# auto-discovery ran.
if restore_path:
    # Fail fast with a clear message if the checkpoint path is wrong.
    if not os.path.isfile(restore_path):
        raise FileNotFoundError(f"Checkpoint topilmadi: {restore_path}")
    print(f"πŸ” Resume: {restore_path}")
else:
    print("πŸ†• Yangi training boshlanadi.")
# ─────────────────────────────────────────
# TRAINER — DDP YO'Q, oddiy single-GPU
# ─────────────────────────────────────────
# Plain single-GPU trainer (no DDP).
trainer_args = TrainerArgs(
    # Checkpoint to restore before training; a falsy value means no restore.
    restore_path=restore_path,
    # use_ddp=False is the default, so it is not passed explicitly.
)
trainer = Trainer(
    trainer_args,
    config,
    output_path=output_dir,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
# Startup banner. Batch size and epoch count are read from the config so the
# banner always shows the values actually used for training.
print(f"""
╔══════════════════════════════════════╗
β•‘ πŸš€ 1x GPU O'QITISH β•‘
β•‘ Batch size : {config.batch_size} β•‘
β•‘ Epochs : {config.epochs} β•‘
β•‘ Resume : {'HA ' if restore_path else "YO'Q"} β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
""")
trainer.fit()