# XurmoTTS / model_trainer.py
# jahongirtech's picture
# Create model_trainer.py
# f0fc7d1 verified
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import librosa
import soundfile as sf
from pathlib import Path
# Restrict CUDA to devices 0 and 1, overriding any value already set in the environment.
# NOTE(review): torch is imported before this line; presumably this still takes effect
# because no CUDA call has been made yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# ─────────────────────────────────────────
# AUDIO PREPROCESSING
# ─────────────────────────────────────────
def preprocess_audio(dataset_path, target_sr=22050):
    """Rewrite every WAV in ``<dataset_path>/wavs`` as mono audio at ``target_sr``.

    Files are overwritten in place. After a complete pass an empty
    ``.preprocessed`` marker file is created in ``dataset_path``; when that
    marker already exists the function returns without touching any audio.

    Args:
        dataset_path: Dataset root directory that contains the ``wavs`` folder.
        target_sr: Sample rate in Hz passed to ``librosa.load`` and ``sf.write``.
    """
    root = Path(dataset_path)
    marker = root / ".preprocessed"

    # Check the marker before listing the directory so that an already
    # processed dataset does not pay for a directory scan.
    if marker.exists():
        print("βœ… Audio allaqachon tayyor.")
        return

    wav_files = list((root / "wavs").glob("*.wav"))
    print(f"πŸ”„ {len(wav_files)} ta wav qayta ishlanmoqda...")
    for wav_path in wav_files:
        # librosa does the resampling and the mono down-mix while loading.
        audio, _ = librosa.load(str(wav_path), sr=target_sr, mono=True)
        sf.write(str(wav_path), audio, target_sr)

    marker.touch()
    # Report the rate that was actually used instead of a hard-coded 22050.
    print(f"βœ… Barcha wav mono + {target_sr} Hz ga o'tkazildi.")
# Dataset root on the mounted Drive; preprocess_audio() reads its "wavs" folder
# and train() uses it as the dataset path together with "metadata.csv".
dataset_path = "/content/drive/MyDrive/tts/dataset_final"
# Executed at module level, i.e. every time this module is loaded, not only
# when it is run as a script. The ".preprocessed" marker makes repeat calls return early.
preprocess_audio(dataset_path)
# ─────────────────────────────────────────
# TRAIN FUNKSIYASI β€” har bir GPU uchun alohida ishga tushadi
# ─────────────────────────────────────────
def train(rank, world_size):
    """Run one DDP worker: rank 0 uses GPU 0, rank 1 uses GPU 1.

    Joins the NCCL process group, binds the process to the CUDA device whose
    index equals ``rank``, runs the training, and always leaves the process
    group again, even when setup or training raises.

    Args:
        rank: Index of this worker; also used as the CUDA device index.
        world_size: Total number of worker processes in the group.
    """
    # Rendezvous settings for the default env:// initialisation of DDP.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        print(f"βœ… GPU {rank}/{world_size} ishga tushdi: {torch.cuda.get_device_name(rank)}")
        _run_training(rank)
    finally:
        # Tear the group down on every exit path so a failed worker does not
        # leave its process-group state behind.
        if dist.is_initialized():
            dist.destroy_process_group()


def _run_training(rank):
    """Build the VITS config, samples, model and trainer for ``rank`` and fit."""
    # Imported here so the heavy TTS/trainer packages load inside the worker.
    from TTS.tts.configs.shared_configs import CharactersConfig, BaseDatasetConfig
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.datasets import load_tts_samples
    from TTS.tts.models.vits import Vits
    from TTS.utils.audio import AudioProcessor
    from TTS.tts.utils.text.tokenizer import TTSTokenizer
    from trainer import Trainer, TrainerArgs

    # ── CONFIG ──
    config = VitsConfig(
        run_name="Xurmo Media 20",
        batch_size=16,  # 16 per GPU (32 in total when two ranks run)
        eval_batch_size=8,
        num_loader_workers=2,
        num_eval_loader_workers=2,
        epochs=1000,
        text_cleaner="multilingual_cleaners",
        use_phonemes=False,
        mixed_precision=True,  # FP16; the original author expects ~2x speed on a T4
        run_eval=True,
        save_step=1000,
        save_n_checkpoints=3,
        print_step=50,
        output_path="/content/drive/MyDrive/tts/output",
        characters=CharactersConfig(
            characters="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzO'o'G'g'ShshChch'0123456789",
            punctuations="!,.? ",
            pad="<PAD>",
            eos="<EOS>",
            bos="<BOS>",
            blank="<BLNK>",
        ),
    )
    config.audio.sample_rate = 22050
    config.audio.do_trim_silence = True
    config.audio.resample = False

    # ── FORMATTER ──
    # Typographic apostrophes and the backtick all become a plain apostrophe.
    apostrophes = str.maketrans({
        "\u2018": "'",
        "\u2019": "'",
        "\u02bc": "'",
        "\u0060": "'",
    })

    def formatter(root_path, meta_file, **kwargs):
        """Turn ``<wav id>|<text>`` lines of ``meta_file`` into sample dicts.

        Blank lines, lines with fewer than two ``|``-separated columns and
        entries whose ``wavs/<wav id>.wav`` file does not exist are skipped.
        """
        items = []
        with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                cols = line.split("|")
                if len(cols) < 2:
                    continue
                wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav")
                if not os.path.exists(wav_file):
                    continue
                items.append({
                    "text": cols[1].strip().translate(apostrophes),
                    "audio_file": wav_file,
                    "root_path": root_path,
                    "speaker_name": "xurmo media",
                    "language": "uz",
                })
        # Only the first rank reports, to avoid one line per worker.
        if rank == 0:
            print(f"βœ… {len(items)} ta sample yuklandi.")
        return items

    # ── DATASET ──
    dataset_config = BaseDatasetConfig(
        dataset_name="uzbek_tts",
        path=dataset_path,
        meta_file_train="metadata.csv",
        meta_file_val="",
        language="uz",
    )
    # No validation metadata file is given; an eval split of size 0.1 is requested instead.
    train_samples, eval_samples = load_tts_samples(
        [dataset_config],
        eval_split=True,
        eval_split_size=0.1,
        formatter=formatter,
    )

    # ── MODEL ──
    tokenizer, config = TTSTokenizer.init_from_config(config)
    ap = AudioProcessor.init_from_config(config)
    model = Vits(config, ap, tokenizer, speaker_manager=None)

    # ── TRAINER — the rank is handed to the trainer ──
    # NOTE(review): use_ddp=True may make the Trainer set up distributed state
    # itself, and group_id differs per rank here — confirm neither conflicts
    # with the process group created in train().
    trainer_args = TrainerArgs(
        rank=rank,
        group_id=f"group_{rank}",
        use_ddp=True,
        grad_accum_steps=1,  # must stay 1 for the VITS GAN, per the original author
    )
    # NOTE(review): this output_path differs from config.output_path above
    # (Kaggle working dir vs. Drive) — confirm which location is intended.
    trainer = Trainer(
        trainer_args,
        config,
        output_path="/kaggle/working/output",
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    if rank == 0:
        # The banner lines stay at column 0 so the printed text is unchanged.
        print(f"""
╔══════════════════════════════════════╗
β•‘ πŸš€ Colab T4 O'QITISH β•‘
β•‘ Har GPU batch : 16 β•‘
β•‘ Effective batch: 32 β•‘
β•‘ Epochs : 1000
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
""")
    trainer.fit()
# ─────────────────────────────────────────
# ISHGA TUSHIRISH
# ─────────────────────────────────────────
if __name__ == "__main__":
    # One worker per visible CUDA device.
    world_size = torch.cuda.device_count()
    print(f"πŸ–₯️ Topilgan GPU: {world_size} ta")
    if world_size == 0:
        # Without a GPU the NCCL setup in train() cannot succeed; stop here
        # with a clear message instead of reporting "only 1 GPU found".
        raise SystemExit("GPU topilmadi! O'qitish uchun kamida 1 ta CUDA GPU kerak.")
    if world_size < 2:
        print("⚠️ Faqat 1 GPU topildi! Kaggle Settings β†’ Accelerator β†’ GPU T4 x2 tanlang.")
        # Training still runs, as a single-rank group in this process.
        train(0, 1)
    else:
        # Start one process per GPU; spawn passes the rank as the first argument.
        mp.spawn(
            train,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )