# hexa-tts-5b/src/train.py
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate import Accelerator
from tqdm import tqdm
import os
from .model import build_model
from .config import HexaConfig
from .dataset import HexaDataset, collate_fn


def train():
    """
    Massive-scale training loop for the 5B-parameter model.
    """
    # 1. Setup
    config = HexaConfig()

    # Gradient accumulation is critical for large models on small GPUs:
    # gradients are summed over 16 micro-batches before each optimizer step.
    accelerator = Accelerator(gradient_accumulation_steps=16)
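    # With batch_size=1 in the dataloader below, the effective batch size is
    # 1 x 16 x accelerator.num_processes samples per optimizer update.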
    print("Initializing 5B-parameter model... (this takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        print(f"Error initializing full model: {e}")
        print("Fallback: your GPU memory is too small for 5B. Try reducing config.dim in config.py.")
        return
    # 2. Data
    data_root = "d:\\hexatts\\data"
    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return
    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
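    # Optional (not in the original): passing num_workers>0 and pin_memory=True
    # to the DataLoader can overlap data loading with GPU compute.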
    # 3. Optimizer
    optimizer = AdamW(model.parameters(), lr=1e-4)  # standard LR
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
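    # prepare() moves the model and optimizer to the active device and, under a
    # multi-process launch (accelerate launch), wraps the model for DDP and
    # shards the dataloader across processes.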
    print("Starting Training...")
    model.train()

    # 4. Loop
    global_step = 0
    epochs = 5  # arbitrary for demo
    for epoch in range(epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")
        for batch in dataloader:
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch

                # Output and target mels are both [Batch, Time, Channels]
                output_mels = model(text, speakers, langs, emotions)

                # Align lengths: truncate both to the shorter sequence for the loss
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]

                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)
                accelerator.backward(loss)
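                # Optional addition (not in the original script): clip gradients
                # at the accumulation boundary to stabilize large-model training.
                # accelerator.sync_gradients is True only on the micro-batch where
                # the optimizer step actually runs; max_norm=1.0 is an assumed
                # value, not tuned for this model.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)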
                # Under accumulate(), Accelerate makes these no-ops except on
                # the final micro-batch of each 16-step accumulation window.
                optimizer.step()
                optimizer.zero_grad()
            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)
            global_step += 1
        progress_bar.close()
        # Save a checkpoint at the end of each epoch
        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")
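        # To resume a run, accelerator.load_state(save_path) restores the model,
        # optimizer, and RNG state saved by save_state() above (call it after
        # accelerator.prepare()).
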
if __name__ == "__main__":
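    # Note: this file uses relative imports (from .model import ...), so run it
    # as a module from the project root, e.g. `python -m src.train`, rather than
    # `python train.py`.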
    train()