import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate import Accelerator
from tqdm import tqdm
import os

from .model import build_model
from .config import HexaConfig
from .dataset import HexaDataset, collate_fn


def train():
    """
    Massive Scale Training Loop.
    """
    # 1. Setup
    config = HexaConfig()

    # Gradient accumulation is CRITICAL for large models on small GPUs
    accelerator = Accelerator(gradient_accumulation_steps=16)

    print("Initializing 5B Parameter Model... (This takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        print(f"Error initializing full model: {e}")
        print("Fallback: Your GPU memory is too small for 5B. Please try reducing config.dim in config.py")
        return

    # 2. Data
    data_root = "d:\\hexatts\\data"
    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return

    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    # 3. Optimize
    optimizer = AdamW(model.parameters(), lr=1e-4)  # Standard LR

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    print("Starting Training...")
    model.train()

    # 4. Loop
    global_step = 0
    epochs = 5  # arbitrary for demo

    for epoch in range(epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")

        for batch in dataloader:
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch

                # Check shapes
                # Output: [Batch, Time, Channels]
                # Target: [Batch, Time, Channels]
                output_mels = model(text, speakers, langs, emotions)

                # Align lengths (simple truncation to min length for loss)
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]

                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)
            global_step += 1

        progress_bar.close()

        # Save checkpoint at the end of each epoch
        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")


if __name__ == "__main__":
    train()