Wojtekb30's picture
Upload 11 files
316a030 verified
Raw
History Blame Contribute Delete
2.27 kB
import multiprocessing
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.optim as optim
from safetensors.torch import save_file
from rvq_model import MotionRVQ_VAE
if __name__ == "__main__":
    # Script entry point: train MotionRVQ_VAE on windows from HumanML3DDataset
    # and export the learned weights as a safetensors file next to this script.

    # Required when the script is frozen into a Windows executable and worker
    # processes are spawned (the loader below uses num_workers=4).
    multiprocessing.freeze_support()

    # Imported inside the main guard, as in the original layout of this script.
    from rvq_humanml_dataset import DataLoader, HumanML3DDataset

    base_dir = Path(__file__).resolve().parent
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    model = MotionRVQ_VAE().to(device)
    dataset = HumanML3DDataset(data_dir=str(base_dir / "new_joint_vecs"), window_size=100)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    # The batch count is loop-invariant, so compute it once.
    # NOTE(review): assumes DataLoader here is (or behaves like)
    # torch.utils.data.DataLoader, whose length is fixed - confirm.
    num_batches = len(dataloader)
    if num_batches == 0:
        # With drop_last=True an undersized dataset yields no batches at all;
        # without this check the loop would silently do nothing and the
        # per-epoch average below would fail with a bare ZeroDivisionError.
        raise RuntimeError(
            "DataLoader produced no batches; check that the data directory "
            f"{base_dir / 'new_joint_vecs'} contains at least one full batch of samples."
        )

    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    num_epochs = 5

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            optimizer.zero_grad()

            # The model returns the reconstruction, a value this script does
            # not use, and a loss term that is added to the objective below.
            reconstructed, _, commit_loss = model(batch)

            # Element-wise reconstruction error.
            pos_loss = F.mse_loss(reconstructed, batch)

            # First-order differences along the last axis, compared between
            # input and reconstruction.
            # NOTE(review): assumes the last axis is time (frames) - confirm
            # against the tensor layout produced by HumanML3DDataset.
            vel_orig = batch[:, :, 1:] - batch[:, :, :-1]
            vel_recon = reconstructed[:, :, 1:] - reconstructed[:, :, :-1]
            vel_loss = F.mse_loss(vel_recon, vel_orig)

            reconstruction_loss = pos_loss + (1.5 * vel_loss)
            loss = reconstruction_loss + commit_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            if batch_idx % 50 == 0:
                # The value labelled "MSE" is the combined position loss plus
                # weighted velocity loss, not the plain MSE alone.
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx}/{num_batches}] "
                    f"MSE: {reconstruction_loss.item():.4f} | Commit: {commit_loss.item():.4f}"
                )

        print(f"--- End of epoch {epoch + 1} | Avg loss: {epoch_loss / num_batches:.4f} ---")

    weights_path = base_dir / "motion_rvq_weights.safetensors"
    # save_file rejects non-contiguous tensors; .contiguous() is a no-op for
    # tensors that are already contiguous, so this only packs the odd ones.
    state_dict = {
        name: tensor.detach().cpu().contiguous()
        for name, tensor in model.state_dict().items()
    }
    save_file(state_dict, str(weights_path))
    print(f"Training complete and model saved to: {weights_path}")