|
|
import torch
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torch.optim import AdamW
|
|
|
from accelerate import Accelerator
|
|
|
from tqdm import tqdm
|
|
|
import os
|
|
|
from .model import build_model
|
|
|
from .config import HexaConfig
|
|
|
from .dataset import HexaDataset, collate_fn
|
|
|
|
|
|
def train(data_root="d:\\hexatts\\data", epochs=5, lr=1e-4):
    """Massive-scale training loop for the 5B-parameter HexaTTS model.

    Builds the model, loads the dataset from *data_root*, and trains for
    *epochs* passes with device placement and 16-step gradient
    accumulation handled by Hugging Face Accelerate. A checkpoint is
    saved after every epoch under ``checkpoints/checkpoint_epoch_{N}``.

    Args:
        data_root: Directory containing the prepared training data.
            Defaults to the original hard-coded path for compatibility.
        epochs: Number of full passes over the dataset.
        lr: AdamW learning rate.

    Returns:
        None. Prints a message and returns early if the model cannot be
        allocated or the data directory is missing.
    """
    config = HexaConfig()

    # Accumulate gradients over 16 micro-batches before each real
    # optimizer step (effective batch size = 16 with batch_size=1 below).
    accelerator = Accelerator(gradient_accumulation_steps=16)

    print("Initializing 5B Parameter Model... (This takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        # Most likely failure here is an OOM while allocating weights.
        print(f"Error initializing full model: {e}")
        print("Fallback: Your GPU memory is too small for 5B. Please try reducing config.dim in config.py")
        return

    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return

    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    optimizer = AdamW(model.parameters(), lr=lr)

    # Let Accelerate wrap model/optimizer/loader for the target device(s).
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    print("Starting Training...")
    model.train()

    global_step = 0

    for epoch in range(epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")

        for batch in dataloader:
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch

                output_mels = model(text, speakers, langs, emotions)

                # Predicted and target mel sequences may differ in length
                # along the time axis; compare only the overlapping prefix.
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]

                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)

                accelerator.backward(loss)
                # Under accumulate(), Accelerate skips the actual step/zero
                # until the 16th micro-batch; calling them every iteration
                # is the documented pattern.
                optimizer.step()
                optimizer.zero_grad()

                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(1)
                global_step += 1

        # Fix: the bar was never closed, which garbles output across epochs.
        progress_bar.close()

        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")
|
|
|
|
|
|
# Script entry point: run the full training loop when executed directly.
if __name__ == "__main__":
    train()
|
|
|
|