| | import os |
| | import h5py |
| | import torch |
| | import random |
| | import yaml |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from tqdm import tqdm |
| | from diffusion import create_diffusion |
| | from models import DiT |
| | import torch.optim as optim |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
| | |
# Load all training hyperparameters (paths, model dims, optimizer settings)
# from the YAML config; the rest of the script reads from this dict.
with open('config/train.yaml', 'r') as file:
    config = yaml.safe_load(file)


# TensorBoard writer for epoch-level loss logging (default ./runs directory).
writer = SummaryWriter()
| |
|
class MelMetaDataset(Dataset):
    """Pairs of (mel-spectrogram segment, meta latent) read from an HDF5 file.

    Every top-level HDF5 group is one example holding a 'mel' array and a
    'meta' array. Mel spectrograms are randomly cropped (or zero-padded)
    along the time axis to exactly `mel_frames` frames, then rescaled with
    (x + 10) / 20.
    """

    def __init__(self, h5_file, mel_frames):
        self.h5_file = h5_file
        self.mel_frames = mel_frames
        # Cache the group names up front; the file itself is reopened per
        # item so the dataset object stays cheap to share across workers.
        with h5py.File(h5_file, 'r') as f:
            self.keys = list(f.keys())

    def __len__(self):
        return len(self.keys)

    def pad_mel(self, mel_segment, total_frames):
        """Zero-pad the last (time) axis up to ``self.mel_frames`` frames."""
        missing = self.mel_frames - total_frames
        if missing > 0:
            mel_segment = F.pad(mel_segment, (0, missing), mode='constant', value=0)
        return mel_segment

    def __getitem__(self, idx):
        key = self.keys[idx]
        with h5py.File(self.h5_file, 'r') as f:
            group = f[key]
            mel = torch.FloatTensor(group['mel'][:])
            meta_latent = torch.FloatTensor(group['meta'][:])

        n_frames = mel.shape[2]
        if n_frames > self.mel_frames:
            # Random temporal crop acts as data augmentation.
            offset = random.randint(0, n_frames - self.mel_frames)
            segment = mel[:, :, offset:offset + self.mel_frames]
        else:
            segment = self.pad_mel(mel, n_frames)
        # Rescale toward [0, 1]; assumes raw mel values roughly in
        # [-10, 10] — TODO confirm against the preprocessing pipeline.
        segment = (segment + 10) / 20
        return segment, meta_latent
| |
|
| | |
# Dataset / loader over the preprocessed HDF5 mel archive.
dataset = MelMetaDataset(config['h5_file_path'], mel_frames=config['mel_frames'])
dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)


# Fall back to CPU when CUDA is not actually available, regardless of config.
device = config['device'] if torch.cuda.is_available() else "cpu"
model = DiT(
    input_size=tuple(config['input_size']),
    patch_size=config['patch_size'],
    in_channels=config['in_channels'],
    hidden_size=config['hidden_size'],
    depth=config['depth'],
    num_heads=config['num_heads'],
)
model.to(device)


# Empty timestep_respacing selects the default (full) diffusion schedule.
diffusion = create_diffusion(timestep_respacing="")


optimizer = optim.AdamW(model.parameters(), lr=config['lr'])


# Ensure the checkpoint directory exists before training writes to it.
os.makedirs(config['checkpoint_dir'], exist_ok=True)
| |
|
| | |
def train_model(model, dataloader, optimizer, diffusion, num_epochs, sample_interval):
    """Run the diffusion training loop.

    Args:
        model: DiT network to optimize (already moved to `device`).
        dataloader: yields (mel_segment, meta_latent) batches.
        optimizer: optimizer over ``model.parameters()``.
        diffusion: helper exposing ``num_timesteps`` and ``training_losses``.
        num_epochs: total number of epochs to run.
        sample_interval: save a checkpoint every this many epochs.

    Side effects: logs per-epoch average loss to the module-level TensorBoard
    ``writer``, prints progress, and writes checkpoints under
    ``config['checkpoint_dir']`` (all three are module-level globals).
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        # NOTE: the previous `enumerate` produced an unused `step` variable.
        for mel_segment, meta_latent in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            mel_segment = mel_segment.to(device)
            meta_latent = meta_latent.to(device)
            # One uniformly random diffusion timestep per batch element.
            t = torch.randint(0, diffusion.num_timesteps, (mel_segment.shape[0],), device=device)
            model_kwargs = dict(y=meta_latent)
            loss_dict = diffusion.training_losses(model, mel_segment, t, model_kwargs)
            loss = loss_dict["loss"].mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1}/{num_epochs}: Average Loss: {avg_loss:.4f}")
        writer.add_scalar('Loss/epoch', avg_loss, epoch + 1)

        if (epoch + 1) % sample_interval == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            # os.path.join is portable across OSes, unlike manual '/' concatenation.
            checkpoint_path = os.path.join(config['checkpoint_dir'], f"model_epoch_{epoch + 1}.pt")
            torch.save(checkpoint, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch + 1}")
| |
|
| | |
# Kick off training; checkpoints are written every `sample_interval` epochs.
train_model(model, dataloader, optimizer, diffusion, num_epochs=config['num_epochs'], sample_interval=config['sample_interval'])


# Flush pending events and close the TensorBoard writer.
writer.close()
| |
|