# Training entry point for HexaTTS: large-model training loop built on
# HuggingFace Accelerate with gradient accumulation and per-epoch checkpoints.
import os

import torch
from accelerate import Accelerator
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

from .config import HexaConfig
from .dataset import HexaDataset, collate_fn
from .model import build_model
def train():
    """Run the full training loop for the HexaTTS model.

    Builds the model, loads the dataset from a fixed local directory,
    trains for a fixed number of epochs with gradient accumulation via
    HuggingFace Accelerate, and writes a resumable checkpoint after
    every epoch. Returns early (with a message) if the model cannot be
    allocated or the data directory is missing.
    """
    # 1. Setup
    config = HexaConfig()
    # Gradient accumulation is CRITICAL for large models on small GPUs:
    # 16 micro-batches are accumulated before each real optimizer step.
    accelerator = Accelerator(gradient_accumulation_steps=16)

    # Fix: was an f-string with no placeholders.
    print("Initializing 5B Parameter Model... (This takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        # Most likely an out-of-memory error while allocating the weights.
        print(f"Error initializing full model: {e}")
        print("Fallback: Your GPU memory is too small for 5B. Please try reducing config.dim in config.py")
        return

    # 2. Data
    # NOTE(review): hard-coded Windows path — consider making this configurable.
    data_root = "d:\\hexatts\\data"
    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return
    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    # 3. Optimizer
    optimizer = AdamW(model.parameters(), lr=1e-4)  # Standard LR
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    print("Starting Training...")
    model.train()

    # 4. Loop
    global_step = 0
    epochs = 5  # arbitrary for demo
    for epoch in range(epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
        for batch in dataloader:
            # accumulate() defers gradient sync so step() only takes effect
            # every 16 micro-batches.
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch
                # Output and target are both [Batch, Time, Channels]; their
                # time lengths may differ, so truncate both to the shorter
                # one before computing the loss.
                output_mels = model(text, speakers, langs, emotions)
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]
                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(1)
                global_step += 1
        # Fix: close the per-epoch bar so stale bars don't stack up in the
        # terminal output across epochs.
        progress_bar.close()

        # Save a resumable checkpoint (model + optimizer + RNG state).
        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")
| if __name__ == "__main__": | |
| train() | |