| import multiprocessing | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from safetensors.torch import save_file | |
| from rvq_model import MotionRVQ_VAE | |
def _reconstruction_loss(
    reconstructed: torch.Tensor,
    batch: torch.Tensor,
    velocity_weight: float = 1.5,
) -> torch.Tensor:
    """Return the position MSE plus a weighted first-difference MSE.

    The first differences ("velocity") are taken along axis index 2 of both
    tensors and compared with the same MSE criterion as the raw values.

    NOTE(review): this presumes axis 2 is the time axis of the motion
    window -- confirm against the layout produced by the dataset/model.

    Args:
        reconstructed: Model output, same shape as ``batch``.
        batch: Ground-truth input tensor (at least 3 axes).
        velocity_weight: Multiplier applied to the first-difference term.

    Returns:
        Scalar tensor ``pos_loss + velocity_weight * vel_loss``.
    """
    pos_loss = F.mse_loss(reconstructed, batch)
    vel_orig = batch[:, :, 1:] - batch[:, :, :-1]
    vel_recon = reconstructed[:, :, 1:] - reconstructed[:, :, :-1]
    vel_loss = F.mse_loss(vel_recon, vel_orig)
    return pos_loss + (velocity_weight * vel_loss)


def main() -> None:
    """Train ``MotionRVQ_VAE`` on the HumanML3D windows and save its weights.

    Data is read from ``new_joint_vecs`` next to this file and the trained
    weights are written to ``motion_rvq_weights.safetensors`` in the same
    directory.

    Raises:
        RuntimeError: If the dataloader yields no batches (the dataset holds
            fewer samples than one batch while ``drop_last=True``).
    """
    # Kept as a local import so it still runs after freeze_support(), exactly
    # as in the original script ordering.
    from rvq_humanml_dataset import DataLoader, HumanML3DDataset

    base_dir = Path(__file__).resolve().parent
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    model = MotionRVQ_VAE().to(device)
    dataset = HumanML3DDataset(data_dir=str(base_dir / "new_joint_vecs"), window_size=100)
    batch_size = 32
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    # With drop_last=True a dataset smaller than one batch produces zero
    # batches; fail early with a clear message instead of hitting a
    # ZeroDivisionError when the epoch average is computed.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise RuntimeError(
            f"No full batches to train on: dataset has {len(dataset)} samples, "
            f"batch_size is {batch_size} and drop_last=True."
        )

    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    num_epochs = 5

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            optimizer.zero_grad()

            # The model returns three values; the middle one is unused here.
            reconstructed, _, commit_loss = model(batch)
            reconstruction_loss = _reconstruction_loss(reconstructed, batch)
            loss = reconstruction_loss + commit_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            if batch_idx % 50 == 0:
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx}/{num_batches}] "
                    f"MSE: {reconstruction_loss.item():.4f} | Commit: {commit_loss.item():.4f}"
                )
        print(f"--- End of epoch {epoch + 1} | Avg loss: {epoch_loss / num_batches:.4f} ---")

    weights_path = base_dir / "motion_rvq_weights.safetensors"
    # safetensors refuses non-contiguous tensors; .contiguous() is a no-op for
    # tensors that are already contiguous.
    state_dict = {k: v.detach().cpu().contiguous() for k, v in model.state_dict().items()}
    save_file(state_dict, str(weights_path))
    print(f"Training complete and model saved to: {weights_path}")


if __name__ == "__main__":
    # Must run before any worker processes are started (frozen Windows builds).
    multiprocessing.freeze_support()
    main()