import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate import Accelerator
from tqdm import tqdm
import os
from .model import build_model
from .config import HexaConfig
from .dataset import HexaDataset, collate_fn

def train():
    """

    Massive Scale Training Loop.

    """
    # 1. Setup
    config = HexaConfig()
    
    # Gradient accumulation is critical for fitting large models on small GPUs:
    # with batch_size=1 below, 16 accumulation steps give an effective batch of 16.
    accelerator = Accelerator(gradient_accumulation_steps=16)
    
    print(f"Initializing 5B Parameter Model... (This takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        print(f"Error initializing full model: {e}")
        print("Fallback: Your GPU memory is too small for 5B. Please try reducing config.dim in config.py")
        return

    # 2. Data
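    # NOTE: data_root is a hard-coded Windows path; adjust it for your setup.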
    data_root = "d:\\hexatts\\data"
    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return
        
    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
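    # collate_fn is assumed to pad variable-length text and mel sequences to a
    # common length within each batch (trivial here, since batch_size=1).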
    
    # 3. Optimizer
    optimizer = AdamW(model.parameters(), lr=1e-4)  # a standard starting LR for AdamW
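    # A short LR warmup is common for models this large. A minimal, commented-out
    # sketch using torch's LambdaLR (warmup_steps is a hypothetical value, not
    # part of the original config):
    #
    #   from torch.optim.lr_scheduler import LambdaLR
    #   warmup_steps = 1000
    #   scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))
    #   # ...then call scheduler.step() after each optimizer.step() below.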
    
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    
    print("Starting Training...")
    model.train()
    
    # 4. Loop
    global_step = 0
    epochs = 5 # arbitrary for demo
    
    for epoch in range(epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
        
        for batch in dataloader:
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch
                
                # Both output and target mels have shape [Batch, Time, Channels]
                
                output_mels = model(text, speakers, langs, emotions)
                
                # Align lengths (Simple truncation to min length for loss)
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]
                
                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)
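                # NOTE: plain MSE over truncated frames ignores any padding inside
                # the batch; with batch_size > 1, masking padded frames out of the
                # loss would be more accurate.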
                
                accelerator.backward(loss)
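                # Gradient clipping is a common safeguard at this scale (the 1.0
                # max-norm is an assumed value, not from the original). Clip only
                # when gradients actually sync across accumulation steps:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)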
                optimizer.step()
                optimizer.zero_grad()
                
                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(1)
                global_step += 1

        progress_bar.close()

        # Save Checkpoint
        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")

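# Because of the relative imports above, run this as a module from the package
# root, e.g. "python -m hexatts.train" (package name assumed), or via
# "accelerate launch" for multi-GPU training.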
if __name__ == "__main__":
    train()