| import os |
| import glob |
| import math |
| import csv |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
| from tqdm import tqdm |
| from torch.amp import autocast |
|
|
| from config import (PAD_TOKEN_ID, START_OF_SPEECH_TOKEN_ID, |
| END_OF_SPEECH_TOKEN_ID, AUDIO_OFFSET) |
| from model import create_model, save_checkpoint |
| from tokenizer import TTSTokenizer |
|
|
| |
# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------
# Learning-rate schedule (consumed by get_lr): linear warm-up from START_LR to
# PEAK_LR, then cosine decay down to MIN_LR.
PEAK_LR: float = 7e-5
START_LR: float = 0
MIN_LR: float = 5e-6
# Weight decay passed to torch.optim.AdamW.
WEIGHT_DECAY: float = 0.01
EPOCHS: int = 20
# Micro-batch size per forward pass; the effective batch size is
# BATCH_SIZE * ACCUM_STEPS (gradient accumulation).
BATCH_SIZE: int = 64
ACCUM_STEPS: int = 1
# max_norm passed to torch.nn.utils.clip_grad_norm_.
GRAD_CLIP: float = 1.0
# Save a checkpoint every CKPT_EVERY optimizer steps.
CKPT_EVERY: int = 1000
# Per-step CSV log (step, batch_loss, avg_loss, lr); overwritten on each run.
LOG_FILE: str = "train_log.csv"
|
|
| |
class ShardedTTSDataset(Dataset):
    """Map-style dataset that loads every ``*.pt`` shard of a directory into RAM.

    Shards are read in sorted filename order and their contents are
    concatenated into ``self.samples``, so the whole corpus is held in memory.
    Each sample is indexed as a mapping with ``'text_ids'``, ``'audio_codes'``
    and ``'speaker_emb'`` tensor entries (inferred from ``__getitem__``; the
    shard layout itself is defined by the preprocessing step — confirm there).

    Args:
        data_dir: Directory containing the ``*.pt`` shard files.
    """

    def __init__(self, data_dir):
        self.shard_files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
        print(f"Зареждане на {len(self.shard_files)} шарда...")
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only point this at shards you
        # produced yourself.
        self.samples = [
            sample
            for shard_path in self.shard_files
            for sample in torch.load(shard_path, weights_only=False)
        ]
        print(f"Общо записи: {len(self.samples):,}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as fresh tensors (int64 ids/codes, float32 embedding)."""
        sample = self.samples[idx]
        # clone().detach() hands the caller its own copy, so in-place edits on
        # the returned tensors cannot modify the cached sample.
        text_ids = sample['text_ids'].clone().detach().long()
        audio_codes = sample['audio_codes'].clone().detach().long()
        speaker_emb = sample['speaker_emb'].clone().detach().float()
        return {
            'text_ids': text_ids,
            'audio_codes': audio_codes,
            'speaker_emb': speaker_emb,
        }
|
|
def collate_fn(batch, *, pad_id=None, sos_id=None, eos_id=None,
               audio_offset=None):
    """Collate dataset items into right-padded tensors for one training step.

    For every item the audio codes are shifted by ``audio_offset`` and framed
    as::

        dec_ids: [SOS, c_0 + off, ..., c_n + off, EOS]
        labels : [-100, c_0 + off, ..., c_n + off, EOS]

    i.e. ``labels`` equals ``dec_ids`` with the SOS position set to -100.
    NOTE(review): this layout presumably relies on the model shifting
    logits/labels by one position internally — confirm against ``model``.

    Args:
        batch: List of dicts with 1-D ``'text_ids'`` / ``'audio_codes'``
            tensors and a ``'speaker_emb'`` tensor (as returned by
            ``ShardedTTSDataset.__getitem__``).
        pad_id, sos_id, eos_id, audio_offset: Optional overrides; ``None``
            falls back to ``PAD_TOKEN_ID``, ``START_OF_SPEECH_TOKEN_ID``,
            ``END_OF_SPEECH_TOKEN_ID`` and ``AUDIO_OFFSET`` from ``config``.

    Returns:
        Tuple ``(enc_ids, dec_ids, enc_mask, labels, speaker_emb)``:
        ``enc_ids`` / ``dec_ids`` padded with ``pad_id``, ``enc_mask``
        (int64, 1 at real text positions, 0 at padding), ``labels`` padded
        with -100, and the stacked speaker embeddings.
    """
    if pad_id is None:
        pad_id = PAD_TOKEN_ID
    if sos_id is None:
        sos_id = START_OF_SPEECH_TOKEN_ID
    if eos_id is None:
        eos_id = END_OF_SPEECH_TOKEN_ID
    if audio_offset is None:
        audio_offset = AUDIO_OFFSET

    # Frame tokens are identical for every item; build them once.
    sos = torch.tensor([sos_id], dtype=torch.long)
    eos = torch.tensor([eos_id], dtype=torch.long)
    ignore = torch.tensor([-100], dtype=torch.long)

    enc_ids_list, dec_ids_list, labels_list, speaker_embs = [], [], [], []
    for item in batch:
        enc_ids_list.append(item['text_ids'])
        audio_codes = item['audio_codes'] + audio_offset
        dec_ids_list.append(torch.cat([sos, audio_codes, eos]))
        labels_list.append(torch.cat([ignore, audio_codes, eos]))
        speaker_embs.append(item['speaker_emb'])

    enc_ids = pad_sequence(enc_ids_list, batch_first=True, padding_value=pad_id)
    dec_ids = pad_sequence(dec_ids_list, batch_first=True, padding_value=pad_id)
    labels = pad_sequence(labels_list, batch_first=True, padding_value=-100)

    # Derive the mask from the true lengths rather than comparing with pad_id,
    # so a genuine text token that happens to equal pad_id is not masked out.
    enc_lens = torch.tensor([t.size(0) for t in enc_ids_list])
    positions = torch.arange(enc_ids.size(1))
    enc_mask = (positions.unsqueeze(0) < enc_lens.unsqueeze(1)).long()

    speaker_emb = torch.stack(speaker_embs)
    return enc_ids, dec_ids, enc_mask, labels, speaker_emb
|
|
| |
def get_lr(step: int, warmup_steps: int, total_steps: int, *,
           start_lr=None, peak_lr=None, min_lr=None) -> float:
    """Return the learning rate for optimizer step ``step``.

    The rate rises linearly from ``start_lr`` to ``peak_lr`` over
    ``warmup_steps`` steps. It then follows a half-cosine decay from
    ``peak_lr`` to ``min_lr``, reaching ``min_lr`` at ``total_steps``.
    Cosine progress is clamped to 1, so steps past ``total_steps`` stay at
    ``min_lr`` instead of climbing back up.

    Args:
        step: 0-based optimizer step index.
        warmup_steps: Length of the warm-up phase in steps.
        total_steps: Step at which the decay reaches ``min_lr``.
        start_lr, peak_lr, min_lr: Optional overrides; ``None`` uses the
            module constants ``START_LR``, ``PEAK_LR`` and ``MIN_LR``.

    Returns:
        The learning rate as a float.
    """
    if start_lr is None:
        start_lr = START_LR
    if peak_lr is None:
        peak_lr = PEAK_LR
    if min_lr is None:
        min_lr = MIN_LR

    if step < warmup_steps:
        return start_lr + (peak_lr - start_lr) * (step / max(1, warmup_steps))

    # max(1, ...) guards against division by zero when there is no decay phase.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine
|
|
| |
def train():
    """Train the TTS model on the preprocessed shards in ``../data/processed``.

    The data path is resolved relative to the current working directory.
    Each optimizer step does three things:
    - uses the learning rate from ``get_lr`` (warm-up over the first two
      epochs, cosine decay afterwards);
    - clips gradients to ``GRAD_CLIP``;
    - appends one row to the CSV log ``LOG_FILE``.

    Checkpoints are written to ``checkpoints/step_XXXXXX`` every
    ``CKPT_EVERY`` steps and to ``checkpoints/epoch_N_final`` after every
    epoch. The function returns early, after printing an error, if the data
    directory is missing or contains no samples.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print(f"Устройство: {device}")

    processed_dir = os.path.abspath("../data/processed")
    if not os.path.exists(processed_dir):
        print(f"[ГРЕШКА] {processed_dir} не съществува!")
        return

    dataset = ShardedTTSDataset(processed_dir)
    if len(dataset) == 0:
        # DataLoader(shuffle=True) raises on an empty dataset; fail clearly instead.
        print(f"[ГРЕШКА] Няма записи в {processed_dir}!")
        return
    # Pinned host memory is only useful for (and only warning-free with) CUDA copies.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            collate_fn=collate_fn, num_workers=4,
                            pin_memory=use_cuda)

    # Only complete ACCUM_STEPS windows trigger an optimizer step, hence //.
    steps_per_epoch = len(dataloader) // ACCUM_STEPS
    warmup_steps = steps_per_epoch * 2
    total_steps = steps_per_epoch * EPOCHS
    print(f"Батчове/епоха: {len(dataloader):,} | Optimizer стъпки/епоха: {steps_per_epoch:,} | Accum: {ACCUM_STEPS}")
    print(f"Warmup: {warmup_steps:,} стъпки (2 епохи) | Общо: {total_steps:,}")
    print(f"Peak LR: {PEAK_LR}, Min LR: {MIN_LR}, Weight Decay: {WEIGHT_DECAY}, Epochs: {EPOCHS}")
    print(f"Ефективен batch size: {BATCH_SIZE * ACCUM_STEPS}")

    model = create_model(device=device)
    model.train()
    # The lr given here is overwritten from get_lr() before every optimizer step.
    optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LR, weight_decay=WEIGHT_DECAY,
                                  betas=(0.9, 0.999), eps=1e-8)

    os.makedirs("checkpoints", exist_ok=True)

    log_path = LOG_FILE
    step = 0
    # Average over every optimizer step since training started (never reset).
    running_loss = 0.0
    running_count = 0

    # `with` guarantees the log is closed even if training raises.
    with open(log_path, "w", newline="") as log_f:
        writer = csv.writer(log_f)
        writer.writerow(["step", "batch_loss", "avg_loss", "lr"])
        log_f.flush()
        print(f"Loss лог: {log_path} (следи с: tail -f {log_path})\n")

        for epoch in range(EPOCHS):
            epoch_loss_sum, valid_batches = 0.0, 0
            # Unscaled loss summed over the micro-batches of the current accumulation window.
            window_loss, window_batches = 0.0, 0
            # Also discards gradients of an incomplete window left from the previous epoch.
            optimizer.zero_grad(set_to_none=True)

            with tqdm(total=steps_per_epoch, desc=f"Епоха {epoch+1}/{EPOCHS}") as loop:
                for i, batch in enumerate(dataloader):
                    enc_ids, dec_ids, enc_mask, labels, spk_emb = (
                        t.to(device, non_blocking=True) for t in batch)

                    # bf16 autocast only on CUDA; on CPU the forward pass runs in full precision.
                    with autocast(device_type=device, dtype=torch.bfloat16, enabled=use_cuda):
                        out = model(enc_ids=enc_ids, dec_ids=dec_ids,
                                    enc_mask=enc_mask, dec_labels=labels,
                                    speaker_emb=spk_emb)
                        # Scale so accumulated gradients average over the window.
                        loss = out['loss'] / ACCUM_STEPS

                    loss.backward()

                    batch_loss = loss.item() * ACCUM_STEPS
                    epoch_loss_sum += batch_loss
                    valid_batches += 1
                    window_loss += batch_loss
                    window_batches += 1

                    if (i + 1) % ACCUM_STEPS != 0:
                        continue

                    # Set the LR *before* stepping, so step 0 uses the warm-up start
                    # (get_lr(0)) and the logged lr is the one this step actually used.
                    current_lr = get_lr(step, warmup_steps, total_steps)
                    for pg in optimizer.param_groups:
                        pg['lr'] = current_lr

                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    step += 1

                    # Mean loss over the whole accumulation window (equals batch_loss when ACCUM_STEPS == 1).
                    step_loss = window_loss / window_batches
                    window_loss, window_batches = 0.0, 0

                    running_loss += step_loss
                    running_count += 1
                    avg_loss = running_loss / running_count

                    writer.writerow([step, f"{step_loss:.4f}", f"{avg_loss:.4f}", f"{current_lr:.2e}"])
                    log_f.flush()

                    loop.update(1)
                    loop.set_postfix(step=step, loss=f"{step_loss:.4f}",
                                     avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}")

                    if step % CKPT_EVERY == 0:
                        ckpt_dir = f"checkpoints/step_{step:06d}"
                        save_checkpoint(model, optimizer, None, step,
                                        avg_loss, ckpt_dir, best_val_loss=None)
                        tqdm.write(f" ✓ Checkpoint запазен: {ckpt_dir} | step={step} | avg_loss={avg_loss:.4f}")

            epoch_avg = epoch_loss_sum / max(1, valid_batches)
            ckpt_dir = f"checkpoints/epoch_{epoch+1}_final"
            save_checkpoint(model, optimizer, None, step, epoch_avg, ckpt_dir, best_val_loss=None)
            print(f"\n✓ Епоха {epoch+1} завърши. Средна загуба: {epoch_avg:.4f}")
            print(f" Checkpoint: {ckpt_dir}")

    print("\n[КРАЙ] Обучението приключи успешно!")
|
|
if __name__ == "__main__":
    # Script entry point: importing this module does not start training.
    train()
|
|