File size: 8,253 Bytes
7c72478 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | import os
import glob
import math
import csv
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from torch.amp import autocast
from config import (PAD_TOKEN_ID, START_OF_SPEECH_TOKEN_ID,
END_OF_SPEECH_TOKEN_ID, AUDIO_OFFSET)
from model import create_model, save_checkpoint
from tokenizer import TTSTokenizer
# ── Hyperparameters ──────────────────────────────────────────────
# Learning-rate schedule consumed by get_lr(): linear warmup from START_LR
# to PEAK_LR, then cosine decay from PEAK_LR down to MIN_LR.
PEAK_LR = 7.0e-5
START_LR = 0
MIN_LR = 5.0e-6

# Optimizer / loop settings.
WEIGHT_DECAY = 1e-2      # AdamW weight decay
EPOCHS = 20
BATCH_SIZE = 64
ACCUM_STEPS = 1          # micro-batches per optimizer step (1 = no accumulation)
GRAD_CLIP = 1.0          # max gradient norm passed to clip_grad_norm_

# Output.
CKPT_EVERY = 1_000       # save a checkpoint every N optimizer steps
LOG_FILE = "train_log.csv"
# ── Dataset ──────────────────────────────────────────────────────
class ShardedTTSDataset(Dataset):
    """In-memory dataset made by concatenating every ``*.pt`` shard in a directory.

    Shards are loaded eagerly, in sorted filename order, and their records are
    kept in ``self.samples``. Memory use therefore grows with dataset size.
    Each shard must load to an iterable of dict-like records that provide
    tensor fields ``'text_ids'``, ``'audio_codes'`` and ``'speaker_emb'``.
    Those are the keys ``__getitem__`` reads; shapes are not validated here.
    """

    def __init__(self, data_dir):
        """Load every ``*.pt`` file under *data_dir* onto the CPU."""
        self.shard_files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
        self.samples = []
        print(f"Зареждане на {len(self.shard_files)} шарда...")
        if not self.shard_files:
            # An empty dataset makes DataLoader(shuffle=True) fail later with an
            # obscure sampler error, so say what actually went wrong.
            print(f"[ПРЕДУПРЕЖДЕНИЕ] Няма намерени .pt шардове в {data_dir}")
        for shard_path in self.shard_files:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code. Only point this at trusted, locally produced shards.
            # map_location="cpu" keeps samples on the host even if a shard was
            # saved from GPU tensors. CUDA tensors cannot be pinned and do not
            # survive forked DataLoader workers.
            self.samples.extend(
                torch.load(shard_path, map_location="cpu", weights_only=False)
            )
        print(f"Общо записи: {len(self.samples):,}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a fresh copy of record *idx* with normalized dtypes.

        ``text_ids`` and ``audio_codes`` are cast to int64, and
        ``speaker_emb`` to float32. Copies are returned so callers never
        mutate the cached sample in place. ``detach()`` comes before
        ``clone()`` so the copy is never recorded by autograd.
        """
        item = self.samples[idx]
        return {
            'text_ids': item['text_ids'].detach().clone().long(),
            'audio_codes': item['audio_codes'].detach().clone().long(),
            'speaker_emb': item['speaker_emb'].detach().clone().float(),
        }
def collate_fn(batch):
    """Pad a list of dataset records into model-ready batch tensors.

    Returns ``(enc_ids, dec_ids, enc_mask, labels, speaker_emb)``:

    * ``enc_ids``: text ids right-padded with PAD_TOKEN_ID.
    * ``dec_ids``: ``[START_OF_SPEECH, audio + AUDIO_OFFSET, END_OF_SPEECH]``,
      right-padded with PAD_TOKEN_ID.
    * ``enc_mask``: int64 mask, 1 wherever ``enc_ids`` differs from
      PAD_TOKEN_ID. This assumes PAD_TOKEN_ID never occurs inside real text
      ids (TODO confirm against the tokenizer).
    * ``labels``: same layout as ``dec_ids``, but -100 replaces the
      START_OF_SPEECH slot and fills the padding, so the loss ignores those
      positions.
    * ``speaker_emb``: per-sample speaker embeddings stacked along dim 0.

    ``dec_ids`` and ``labels`` are deliberately left aligned (not shifted).
    Per the original note, the model shifts internally, comparing
    ``logits[:, :-1]`` against ``labels[:, 1:]``.
    """
    sos = torch.tensor([START_OF_SPEECH_TOKEN_ID])
    eos = torch.tensor([END_OF_SPEECH_TOKEN_ID])
    ignored = torch.tensor([-100])

    text_seqs = [sample['text_ids'] for sample in batch]
    decoder_seqs = []
    target_seqs = []
    for sample in batch:
        # Move raw codec codes into the audio region of the shared vocabulary.
        shifted_codes = sample['audio_codes'] + AUDIO_OFFSET
        decoder_seqs.append(torch.cat([sos, shifted_codes, eos]))
        target_seqs.append(torch.cat([ignored, shifted_codes, eos]))

    enc_ids = pad_sequence(text_seqs, batch_first=True, padding_value=PAD_TOKEN_ID)
    dec_ids = pad_sequence(decoder_seqs, batch_first=True, padding_value=PAD_TOKEN_ID)
    labels = pad_sequence(target_seqs, batch_first=True, padding_value=-100)
    enc_mask = enc_ids.ne(PAD_TOKEN_ID).long()
    speaker_emb = torch.stack([sample['speaker_emb'] for sample in batch])
    return enc_ids, dec_ids, enc_mask, labels, speaker_emb
# ── LR Scheduler: Warmup + Cosine Decay ─────────────────────────
def get_lr(step: int, warmup_steps: int, total_steps: int, *,
           peak_lr=None, start_lr=None, min_lr=None) -> float:
    """Return the scheduled learning rate for optimizer step *step*.

    For ``step < warmup_steps`` the LR ramps linearly from ``start_lr`` to
    ``peak_lr``. After that it follows a half cosine from ``peak_lr`` down to
    ``min_lr``, reached at ``step == total_steps``. Progress is clamped to
    [0, 1], so steps at or beyond ``total_steps`` (or any step when
    ``total_steps <= warmup_steps``) stay at ``min_lr`` instead of the cosine
    climbing back up.

    Args:
        step: optimizer step index.
        warmup_steps: length of the linear warmup, in optimizer steps.
        total_steps: step at which the cosine decay reaches ``min_lr``.
        peak_lr, start_lr, min_lr: optional overrides. ``None`` (the default)
            falls back to the module-level PEAK_LR, START_LR and MIN_LR.

    Returns:
        The learning rate as a float.
    """
    peak = PEAK_LR if peak_lr is None else peak_lr
    if step < warmup_steps:
        start = START_LR if start_lr is None else start_lr
        return start + (peak - start) * (step / max(1, warmup_steps))

    floor = MIN_LR if min_lr is None else min_lr
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # hold at floor instead of rising again
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + (peak - floor) * cosine
# ── Main training loop ───────────────────────────────────────────
def train():
    """Train the model on the sharded dataset in ``../data/processed``.

    Side effects:
    * writes one CSV row per optimizer step to LOG_FILE;
    * saves a checkpoint every CKPT_EVERY optimizer steps and after each
      epoch under ``./checkpoints``;
    * prints progress to stdout.

    Returns early, after printing an error, if the data directory is missing.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print(f"Устройство: {device}")

    processed_dir = os.path.abspath("../data/processed")
    if not os.path.exists(processed_dir):
        print(f"[ГРЕШКА] {processed_dir} не съществува!")
        return

    dataset = ShardedTTSDataset(processed_dir)
    # Pinned host memory only makes sense for CPU->GPU copies; on CPU it just warns.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            collate_fn=collate_fn, num_workers=4,
                            pin_memory=use_cuda)

    steps_per_epoch = len(dataloader) // ACCUM_STEPS  # optimizer steps per epoch
    warmup_steps = steps_per_epoch * 2                # warmup = 2 epochs
    total_steps = steps_per_epoch * EPOCHS
    print(f"Батчове/епоха: {len(dataloader):,} | Optimizer стъпки/епоха: {steps_per_epoch:,} | Accum: {ACCUM_STEPS}")
    print(f"Warmup: {warmup_steps:,} стъпки (2 епохи) | Общо: {total_steps:,}")
    print(f"Peak LR: {PEAK_LR}, Min LR: {MIN_LR}, Weight Decay: {WEIGHT_DECAY}, Epochs: {EPOCHS}")
    print(f"Ефективен batch size: {BATCH_SIZE * ACCUM_STEPS}")

    model = create_model(device=device)
    model.train()
    # Start at the schedule's step-0 LR. The loop overwrites the LR before
    # every optimizer.step(), so the initial value is never used unscheduled.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=get_lr(0, warmup_steps, total_steps),
                                  weight_decay=WEIGHT_DECAY,
                                  betas=(0.9, 0.999), eps=1e-8)
    # bfloat16 autocast needs no GradScaler.

    os.makedirs("checkpoints", exist_ok=True)

    log_path = LOG_FILE
    # `with` guarantees the CSV log is closed even if training raises.
    with open(log_path, "w", newline="", encoding="utf-8") as log_f:
        writer = csv.writer(log_f)
        writer.writerow(["step", "batch_loss", "avg_loss", "lr"])
        log_f.flush()
        print(f"Loss лог: {log_path} (следи с: tail -f {log_path})\n")

        step = 0             # completed optimizer steps (global)
        running_loss = 0.0   # sum of per-step losses since the start of training
        running_count = 0

        for epoch in range(EPOCHS):
            epoch_loss_sum, valid_batches = 0.0, 0
            accum_loss = 0.0  # sum of micro-batch losses in the current step
            # Also discards gradients left over from an incomplete accumulation
            # window at the end of the previous epoch.
            optimizer.zero_grad(set_to_none=True)

            with tqdm(total=steps_per_epoch, desc=f"Епоха {epoch+1}/{EPOCHS}") as loop:
                for i, (enc_ids, dec_ids, enc_mask, labels, spk_emb) in enumerate(dataloader):
                    enc_ids = enc_ids.to(device, non_blocking=True)
                    dec_ids = dec_ids.to(device, non_blocking=True)
                    enc_mask = enc_mask.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    spk_emb = spk_emb.to(device, non_blocking=True)

                    # Autocast follows the actual device. Hard-coding 'cuda'
                    # only produced a warning (and no autocast) on CPU.
                    with autocast(device_type=device, dtype=torch.bfloat16,
                                  enabled=use_cuda):
                        out = model(enc_ids=enc_ids, dec_ids=dec_ids,
                                    enc_mask=enc_mask, dec_labels=labels,
                                    speaker_emb=spk_emb)
                        loss = out['loss'] / ACCUM_STEPS
                    loss.backward()

                    batch_loss = loss.item() * ACCUM_STEPS  # unscaled loss
                    epoch_loss_sum += batch_loss
                    valid_batches += 1
                    accum_loss += batch_loss

                    if (i + 1) % ACCUM_STEPS != 0:
                        continue

                    # Apply the scheduled LR *before* stepping. Setting it after
                    # the step made the first update run at PEAK_LR, skipping
                    # warmup, and made each logged LR off by one step.
                    current_lr = get_lr(step, warmup_steps, total_steps)
                    for pg in optimizer.param_groups:
                        pg['lr'] = current_lr
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    step += 1

                    # Report the mean over all micro-batches of this step,
                    # not just the last one.
                    step_loss = accum_loss / ACCUM_STEPS
                    accum_loss = 0.0
                    running_loss += step_loss
                    running_count += 1
                    avg_loss = running_loss / running_count

                    writer.writerow([step, f"{step_loss:.4f}", f"{avg_loss:.4f}", f"{current_lr:.2e}"])
                    log_f.flush()
                    loop.update(1)
                    loop.set_postfix(step=step, loss=f"{step_loss:.4f}",
                                     avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}")

                    if step % CKPT_EVERY == 0:
                        ckpt_dir = f"checkpoints/step_{step:06d}"
                        save_checkpoint(model, optimizer, None, step,
                                        avg_loss, ckpt_dir, best_val_loss=None)
                        tqdm.write(f"  ✓ Checkpoint запазен: {ckpt_dir} | step={step} | avg_loss={avg_loss:.4f}")

            epoch_avg = epoch_loss_sum / max(1, valid_batches)
            ckpt_dir = f"checkpoints/epoch_{epoch+1}_final"
            save_checkpoint(model, optimizer, None, step, epoch_avg, ckpt_dir, best_val_loss=None)
            print(f"\n✓ Епоха {epoch+1} завърши. Средна загуба: {epoch_avg:.4f}")
            print(f"  Checkpoint: {ckpt_dir}")

    print("\n[КРАЙ] Обучението приключи успешно!")
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train()
|