import os
import glob
import math
import csv

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from torch.amp import autocast

from config import (PAD_TOKEN_ID, START_OF_SPEECH_TOKEN_ID,
                    END_OF_SPEECH_TOKEN_ID, AUDIO_OFFSET)
from model import create_model, save_checkpoint
from tokenizer import TTSTokenizer

# ── Hyperparameters ──────────────────────────────────────────────
PEAK_LR = 7e-5
START_LR = 0
MIN_LR = 5e-6
WEIGHT_DECAY = 0.01
EPOCHS = 20
BATCH_SIZE = 64
ACCUM_STEPS = 1      # 1 = no gradient accumulation
GRAD_CLIP = 1.0
CKPT_EVERY = 1000    # Save a checkpoint every N optimizer steps
LOG_FILE = "train_log.csv"


# ── Dataset ──────────────────────────────────────────────────────
class ShardedTTSDataset(Dataset):
    """In-memory dataset built by concatenating every ``*.pt`` shard in a directory.

    Each shard is loaded with ``torch.load`` and its contents are appended to
    ``self.samples``. Items are read as dicts with ``text_ids``,
    ``audio_codes`` and ``speaker_emb`` tensors (schema inferred from
    ``__getitem__`` — confirm against the preprocessing script).
    """

    def __init__(self, data_dir):
        self.shard_files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
        self.samples = []
        print(f"Зареждане на {len(self.shard_files)} шарда...")
        for sf in self.shard_files:
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load shards produced by our own preprocessing pipeline.
            self.samples.extend(torch.load(sf, weights_only=False))
        print(f"Общо записи: {len(self.samples):,}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        # Return fresh copies (single copy each) so callers cannot mutate the
        # cached samples in place.
        return {
            'text_ids': item['text_ids'].detach().to(torch.long, copy=True),
            'audio_codes': item['audio_codes'].detach().to(torch.long, copy=True),
            'speaker_emb': item['speaker_emb'].detach().to(torch.float, copy=True),
        }


def collate_fn(batch):
    """Pad a list of dataset items into model-ready batch tensors.

    Per item:
      * ``dec_ids`` = [START_OF_SPEECH, audio_codes + AUDIO_OFFSET, END_OF_SPEECH]
      * ``labels``  = [-100,            audio_codes + AUDIO_OFFSET, END_OF_SPEECH]

    ``dec_ids`` and ``labels`` are position-aligned. The shift is done by
    model.py (per the original note: ``logits[:, :-1]`` vs ``labels[:, 1:]``),
    so the leading -100 is never scored.

    Returns:
        ``(enc_ids, dec_ids, enc_mask, labels, speaker_emb)``. Ids are padded
        with PAD_TOKEN_ID, labels with -100 (ignored by the loss), and
        ``enc_mask`` is 1 for real text tokens and 0 for padding.
    """
    enc_ids_list, dec_ids_list, labels_list, speaker_embs = [], [], [], []
    sos = torch.tensor([START_OF_SPEECH_TOKEN_ID])
    eos = torch.tensor([END_OF_SPEECH_TOKEN_ID])
    ignore = torch.tensor([-100])
    for item in batch:
        enc_ids_list.append(item['text_ids'])
        audio_codes = item['audio_codes'] + AUDIO_OFFSET
        dec_ids_list.append(torch.cat([sos, audio_codes, eos]))
        labels_list.append(torch.cat([ignore, audio_codes, eos]))
        speaker_embs.append(item['speaker_emb'])

    enc_ids = pad_sequence(enc_ids_list, batch_first=True, padding_value=PAD_TOKEN_ID)
    dec_ids = pad_sequence(dec_ids_list, batch_first=True, padding_value=PAD_TOKEN_ID)
    labels = pad_sequence(labels_list, batch_first=True, padding_value=-100)

    # Build the mask from the true lengths instead of `enc_ids != PAD_TOKEN_ID`,
    # so a PAD-valued id inside real text can never be masked out by accident.
    enc_lens = torch.tensor([t.size(0) for t in enc_ids_list])
    positions = torch.arange(enc_ids.size(1)).unsqueeze(0)
    enc_mask = (positions < enc_lens.unsqueeze(1)).long()

    speaker_emb = torch.stack(speaker_embs)
    return enc_ids, dec_ids, enc_mask, labels, speaker_emb


# ── LR Scheduler: Warmup + Cosine Decay ─────────────────────────
def get_lr(step: int, warmup_steps: int, total_steps: int) -> float:
    """Return the learning rate to use for optimizer step ``step`` (0-based).

    Phases:
      * Warmup: linear from START_LR to PEAK_LR over ``warmup_steps``.
      * Decay: cosine from PEAK_LR down to MIN_LR at ``total_steps``.

    Steps beyond ``total_steps`` stay at MIN_LR instead of following the
    cosine back up.
    """
    if step < warmup_steps:
        return START_LR + (PEAK_LR - START_LR) * (step / max(1, warmup_steps))
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return MIN_LR + (PEAK_LR - MIN_LR) * cosine


# ── Main training loop ───────────────────────────────────────────
def train():
    """Train the TTS model on the ``*.pt`` shards in ``../data/processed``.

    Training setup:
      * AdamW with the warmup + cosine schedule from ``get_lr``.
      * BF16 autocast on CUDA (no GradScaler needed).
      * Gradient clipping and optional gradient accumulation.

    Outputs:
      * One CSV row per optimizer step in LOG_FILE.
      * A checkpoint every CKPT_EVERY steps and one at the end of each epoch,
        both under ``checkpoints/``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print(f"Устройство: {device}")

    processed_dir = os.path.abspath("../data/processed")
    if not os.path.exists(processed_dir):
        print(f"[ГРЕШКА] {processed_dir} не съществува!")
        return

    dataset = ShardedTTSDataset(processed_dir)
    if len(dataset) == 0:
        # DataLoader(shuffle=True) raises an opaque ValueError on an empty
        # dataset; stop with a clear message instead.
        print(f"[ГРЕШКА] Няма данни в {processed_dir}!")
        return

    # Pinned memory only helps host->GPU copies; it just warns on CPU-only runs.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            collate_fn=collate_fn, num_workers=4,
                            pin_memory=use_cuda)

    steps_per_epoch = len(dataloader) // ACCUM_STEPS   # optimizer steps per epoch
    warmup_steps = steps_per_epoch * 2                 # warmup = 2 epochs
    total_steps = steps_per_epoch * EPOCHS

    print(f"Батчове/епоха: {len(dataloader):,} | Optimizer стъпки/епоха: {steps_per_epoch:,} | Accum: {ACCUM_STEPS}")
    print(f"Warmup: {warmup_steps:,} стъпки (2 епохи) | Общо: {total_steps:,}")
    print(f"Peak LR: {PEAK_LR}, Min LR: {MIN_LR}, Weight Decay: {WEIGHT_DECAY}, Epochs: {EPOCHS}")
    print(f"Ефективен batch size: {BATCH_SIZE * ACCUM_STEPS}")

    model = create_model(device=device)
    model.train()

    # The LR is set explicitly right before every optimizer.step() below.
    # Start from the schedule's step-0 value so the first update honours the
    # warmup instead of running at PEAK_LR.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=get_lr(0, warmup_steps, total_steps),
                                  weight_decay=WEIGHT_DECAY,
                                  betas=(0.9, 0.999), eps=1e-8)
    # BF16 autocast — no GradScaler needed (unlike float16).

    os.makedirs("checkpoints", exist_ok=True)

    # CSV log for live monitoring. The `with` block guarantees the file is
    # closed even if training crashes.
    log_path = LOG_FILE
    with open(log_path, "w", newline="") as log_f:
        writer = csv.writer(log_f)
        writer.writerow(["step", "batch_loss", "avg_loss", "lr"])
        log_f.flush()
        print(f"Loss лог: {log_path} (следи с: tail -f {log_path})\n")

        step = 0
        running_loss = 0.0   # sum of per-step losses since the start of training
        running_count = 0

        for epoch in range(EPOCHS):
            loop = tqdm(total=steps_per_epoch, desc=f"Епоха {epoch+1}/{EPOCHS}")
            epoch_loss_sum, valid_batches = 0.0, 0
            # Sum of un-scaled micro-batch losses in the current accumulation
            # window. A trailing partial window (len(dataloader) not divisible
            # by ACCUM_STEPS) is dropped together with its gradients by the
            # zero_grad below.
            window_loss = 0.0
            optimizer.zero_grad(set_to_none=True)

            for i, batch in enumerate(dataloader):
                enc_ids, dec_ids, enc_mask, labels, spk_emb = (
                    t.to(device, non_blocking=use_cuda) for t in batch)

                with autocast('cuda', dtype=torch.bfloat16, enabled=use_cuda):
                    out = model(enc_ids=enc_ids, dec_ids=dec_ids, enc_mask=enc_mask,
                                dec_labels=labels, speaker_emb=spk_emb)
                    loss = out['loss'] / ACCUM_STEPS
                loss.backward()

                micro_loss = loss.item() * ACCUM_STEPS   # the real (un-scaled) loss
                epoch_loss_sum += micro_loss
                valid_batches += 1
                window_loss += micro_loss

                if (i + 1) % ACCUM_STEPS != 0:
                    continue

                # Apply the LR for *this* update before stepping, so the
                # schedule (and the logged value) matches the step taken.
                current_lr = get_lr(step, warmup_steps, total_steps)
                for pg in optimizer.param_groups:
                    pg['lr'] = current_lr

                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                step += 1

                # Mean loss over the accumulation window (== the micro-batch
                # loss when ACCUM_STEPS == 1).
                batch_loss = window_loss / ACCUM_STEPS
                window_loss = 0.0
                running_loss += batch_loss
                running_count += 1
                avg_loss = running_loss / running_count

                writer.writerow([step, f"{batch_loss:.4f}", f"{avg_loss:.4f}", f"{current_lr:.2e}"])
                log_f.flush()

                loop.update(1)
                loop.set_postfix(step=step, loss=f"{batch_loss:.4f}",
                                 avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}")

                if step % CKPT_EVERY == 0:
                    ckpt_dir = f"checkpoints/step_{step:06d}"
                    save_checkpoint(model, optimizer, None, step, avg_loss,
                                    ckpt_dir, best_val_loss=None)
                    tqdm.write(f" ✓ Checkpoint запазен: {ckpt_dir} | step={step} | avg_loss={avg_loss:.4f}")

            loop.close()

            epoch_avg = epoch_loss_sum / max(1, valid_batches)
            ckpt_dir = f"checkpoints/epoch_{epoch+1}_final"
            save_checkpoint(model, optimizer, None, step, epoch_avg,
                            ckpt_dir, best_val_loss=None)
            print(f"\n✓ Епоха {epoch+1} завърши. Средна загуба: {epoch_avg:.4f}")
            print(f" Checkpoint: {ckpt_dir}")

    print("\n[КРАЙ] Обучението приключи успешно!")


if __name__ == "__main__":
    train()