| import os
|
| import glob
|
| import math
|
| import csv
|
| import torch
|
| from torch.utils.data import Dataset, DataLoader
|
| from torch.nn.utils.rnn import pad_sequence
|
| from tqdm import tqdm
|
| from torch.amp import autocast
|
|
|
| from config import (PAD_TOKEN_ID, START_OF_SPEECH_TOKEN_ID,
|
| END_OF_SPEECH_TOKEN_ID, AUDIO_OFFSET)
|
| from model import create_model, save_checkpoint
|
| from tokenizer import TTSTokenizer
|
|
|
|
|
# --- Training hyperparameters -------------------------------------------------
# Learning-rate schedule (consumed by get_lr): linear warmup from START_LR to
# PEAK_LR, then cosine decay from PEAK_LR down to MIN_LR.
PEAK_LR = 7e-5
START_LR = 0
MIN_LR = 5e-6
# Weight decay passed to the AdamW optimizer.
WEIGHT_DECAY = 0.01
# Number of full passes over the dataset.
EPOCHS = 20
# Samples per DataLoader batch.
BATCH_SIZE = 64
# Number of batches whose gradients are accumulated per optimizer step.
ACCUM_STEPS = 1
# Maximum gradient norm passed to clip_grad_norm_.
GRAD_CLIP = 1.0
# A checkpoint is written every CKPT_EVERY optimizer steps.
CKPT_EVERY = 1000
# Path of the CSV file that receives one row per optimizer step.
LOG_FILE = "train_log.csv"
|
|
|
|
|
class ShardedTTSDataset(Dataset):
    """Dataset that eagerly loads every ``*.pt`` shard of a directory into memory.

    Shards are discovered with a glob, processed in sorted path order, and the
    contents of each shard are appended to one flat list of samples. Each
    sample is indexed with the keys ``'text_ids'``, ``'audio_codes'`` and
    ``'speaker_emb'`` when it is fetched.
    """

    def __init__(self, data_dir):
        """Find all ``*.pt`` files directly under ``data_dir`` and load them."""
        shard_paths = glob.glob(os.path.join(data_dir, "*.pt"))
        shard_paths.sort()
        self.shard_files = shard_paths
        self.samples = []
        print(f"Зареждане на {len(self.shard_files)} шарда...")
        for shard_path in self.shard_files:
            # NOTE(review): weights_only=False unpickles arbitrary Python
            # objects; only point this at shards produced by a trusted
            # preprocessing pipeline.
            shard_records = torch.load(shard_path, weights_only=False)
            self.samples += shard_records
        print(f"Общо записи: {len(self.samples):,}")

    def __len__(self):
        """Return the total number of samples across all loaded shards."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a detached copy of sample ``idx`` with normalised dtypes.

        The token/code tensors are converted to int64 and the speaker
        embedding to float32; the stored sample itself is left untouched.
        """
        record = self.samples[idx]
        field_dtypes = (
            ('text_ids', torch.int64),
            ('audio_codes', torch.int64),
            ('speaker_emb', torch.float32),
        )
        return {
            key: record[key].clone().detach().to(dtype)
            for key, dtype in field_dtypes
        }
|
|
|
def collate_fn(batch):
    """Assemble a list of dataset samples into padded batch tensors.

    For every sample the audio codes are shifted by ``AUDIO_OFFSET``. The
    decoder input is ``[START_OF_SPEECH, codes..., END_OF_SPEECH]`` and the
    label sequence is ``[-100, codes..., END_OF_SPEECH]``.

    Returns:
        Tuple ``(enc_ids, dec_ids, enc_mask, labels, speaker_emb)`` where
        ``enc_ids``/``dec_ids`` are padded with ``PAD_TOKEN_ID``, ``labels``
        is padded with ``-100``, ``enc_mask`` is 1 wherever ``enc_ids``
        differs from ``PAD_TOKEN_ID``, and ``speaker_emb`` is the stacked
        per-sample speaker embeddings.
    """
    # Single-token tensors reused for every sample; torch.cat copies them.
    sos = torch.tensor([START_OF_SPEECH_TOKEN_ID])
    eos = torch.tensor([END_OF_SPEECH_TOKEN_ID])
    ignore = torch.tensor([-100])

    text_seqs = [sample['text_ids'] for sample in batch]
    shifted_codes = [sample['audio_codes'] + AUDIO_OFFSET for sample in batch]
    decoder_seqs = [torch.cat([sos, codes, eos]) for codes in shifted_codes]
    label_seqs = [torch.cat([ignore, codes, eos]) for codes in shifted_codes]
    embeddings = [sample['speaker_emb'] for sample in batch]

    enc_ids = pad_sequence(text_seqs, batch_first=True, padding_value=PAD_TOKEN_ID)
    dec_ids = pad_sequence(decoder_seqs, batch_first=True, padding_value=PAD_TOKEN_ID)
    labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)
    enc_mask = (enc_ids != PAD_TOKEN_ID).long()
    speaker_emb = torch.stack(embeddings)
    return enc_ids, dec_ids, enc_mask, labels, speaker_emb
|
|
|
|
|
def get_lr(step: int, warmup_steps: int, total_steps: int, *,
           peak_lr=None, start_lr=None, min_lr=None) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate rises linearly from the
    start rate to the peak rate. Afterwards it follows a half-cosine from the
    peak rate down to the minimum rate, reaching the minimum at
    ``total_steps``.

    Args:
        step: Optimizer step index.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine decay reaches the minimum rate.
        peak_lr: Optional override for the module-level ``PEAK_LR``.
        start_lr: Optional override for the module-level ``START_LR``.
        min_lr: Optional override for the module-level ``MIN_LR``.

    Returns:
        The learning rate to use at ``step``.
    """
    peak = PEAK_LR if peak_lr is None else peak_lr
    start = START_LR if start_lr is None else start_lr
    floor = MIN_LR if min_lr is None else min_lr

    if step < warmup_steps:
        return start + (peak - start) * (step / max(1, warmup_steps))

    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Clamp so that steps beyond total_steps stay at the minimum rate instead
    # of riding the cosine back up towards the peak.
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + (peak - floor) * cosine
|
|
|
|
|
def train():
    """Run the full training loop and write checkpoints plus a CSV loss log.

    Steps performed:
      1. Pick the device and load all shards from ``../data/processed``
         (resolved against the current working directory); print an error and
         return if the directory does not exist.
      2. Build the DataLoader, model and AdamW optimizer.
      3. For each epoch, run forward/backward per batch with gradient
         accumulation, clip gradients, step the optimizer and update the
         learning rate from ``get_lr``.
      4. Append one CSV row per optimizer step to ``LOG_FILE``, save a
         checkpoint every ``CKPT_EVERY`` steps and at the end of every epoch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Устройство: {device}")

    processed_dir = os.path.abspath("../data/processed")
    if not os.path.exists(processed_dir):
        print(f"[ГРЕШКА] {processed_dir} не съществува!")
        return

    dataset = ShardedTTSDataset(processed_dir)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            collate_fn=collate_fn, num_workers=4, pin_memory=True)

    steps_per_epoch = len(dataloader) // ACCUM_STEPS
    warmup_steps = steps_per_epoch * 2
    total_steps = steps_per_epoch * EPOCHS
    print(f"Батчове/епоха: {len(dataloader):,} | Optimizer стъпки/епоха: {steps_per_epoch:,} | Accum: {ACCUM_STEPS}")
    print(f"Warmup: {warmup_steps:,} стъпки (2 епохи) | Общо: {total_steps:,}")
    print(f"Peak LR: {PEAK_LR}, Min LR: {MIN_LR}, Weight Decay: {WEIGHT_DECAY}, Epochs: {EPOCHS}")
    print(f"Ефективен batch size: {BATCH_SIZE * ACCUM_STEPS}")

    model = create_model(device=device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LR, weight_decay=WEIGHT_DECAY,
                                  betas=(0.9, 0.999), eps=1e-8)

    # The schedule is applied after each optimizer step, so without this the
    # very first update would run at the constructor's PEAK_LR and bypass
    # warmup. Start the optimizer at the schedule's step-0 value instead.
    current_lr = get_lr(0, warmup_steps, total_steps)
    for pg in optimizer.param_groups:
        pg['lr'] = current_lr

    os.makedirs("checkpoints", exist_ok=True)

    # Mixed precision is only requested on CUDA; on CPU the context is a no-op.
    use_amp = device == "cuda"

    log_path = LOG_FILE
    # The context manager guarantees the log is closed even if training raises.
    with open(log_path, "w", newline="") as log_f:
        writer = csv.writer(log_f)
        writer.writerow(["step", "batch_loss", "avg_loss", "lr"])
        log_f.flush()
        print(f"Loss лог: {log_path} (следи с: tail -f {log_path})\n")

        step = 0
        running_loss = 0.0
        running_count = 0

        for epoch in range(EPOCHS):
            epoch_loss_sum, valid_batches = 0.0, 0

            # Also discards gradients of trailing batches from the previous
            # epoch that did not fill a complete accumulation window.
            optimizer.zero_grad(set_to_none=True)

            with tqdm(total=steps_per_epoch, desc=f"Епоха {epoch+1}/{EPOCHS}") as loop:
                for i, (enc_ids, dec_ids, enc_mask, labels, spk_emb) in enumerate(dataloader):
                    enc_ids = enc_ids.to(device)
                    dec_ids = dec_ids.to(device)
                    enc_mask = enc_mask.to(device)
                    labels = labels.to(device)
                    spk_emb = spk_emb.to(device)

                    with autocast(device_type=device, dtype=torch.bfloat16, enabled=use_amp):
                        out = model(enc_ids=enc_ids, dec_ids=dec_ids,
                                    enc_mask=enc_mask, dec_labels=labels,
                                    speaker_emb=spk_emb)
                        # Scale so accumulated gradients average over the window.
                        loss = out['loss'] / ACCUM_STEPS

                    loss.backward()

                    # Undo the accumulation scaling for reporting.
                    batch_loss = loss.item() * ACCUM_STEPS
                    epoch_loss_sum += batch_loss
                    valid_batches += 1

                    if (i + 1) % ACCUM_STEPS == 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                        optimizer.step()
                        optimizer.zero_grad(set_to_none=True)
                        step += 1

                        # Learning rate that the next optimizer step will use.
                        current_lr = get_lr(step, warmup_steps, total_steps)
                        for pg in optimizer.param_groups:
                            pg['lr'] = current_lr

                        running_loss += batch_loss
                        running_count += 1
                        avg_loss = running_loss / running_count

                        writer.writerow([step, f"{batch_loss:.4f}", f"{avg_loss:.4f}", f"{current_lr:.2e}"])
                        # Flush every row so the log can be followed live.
                        log_f.flush()

                        loop.update(1)
                        loop.set_postfix(step=step, loss=f"{batch_loss:.4f}",
                                         avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}")

                        if step % CKPT_EVERY == 0:
                            ckpt_dir = f"checkpoints/step_{step:06d}"
                            save_checkpoint(model, optimizer, None, step,
                                            avg_loss, ckpt_dir, best_val_loss=None)
                            tqdm.write(f" ✓ Checkpoint запазен: {ckpt_dir} | step={step} | avg_loss={avg_loss:.4f}")

            epoch_avg = epoch_loss_sum / max(1, valid_batches)
            ckpt_dir = f"checkpoints/epoch_{epoch+1}_final"
            save_checkpoint(model, optimizer, None, step, epoch_avg, ckpt_dir, best_val_loss=None)
            print(f"\n✓ Епоха {epoch+1} завърши. Средна загуба: {epoch_avg:.4f}")
            print(f" Checkpoint: {ckpt_dir}")

    print("\n[КРАЙ] Обучението приключи успешно!")
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
|
|
|