"""Training script for the OmegaJEPA predictor on mock transition data."""

import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

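# The core/ directory next to this script is assumed to contain the
# predictor, omega_metrics, and mock_data modules imported below.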
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))

from predictor import OmegaJEPA_Predictor
from omega_metrics import OmegaMetrics
from mock_data import get_dataloader


def train():
    # Hyperparameters
    embedding_dim = 256
    action_dim = 64
    latent_dim = 64
    batch_size = 64
    epochs = 10
    lr = 1e-4
    alpha = 0.1  # weight of the TRA term in the total loss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

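    # OmegaJEPA_Predictor is assumed to map a state embedding and an action
    # to a predicted next-state embedding of the same dimensionality, so its
    # output can be compared directly against s_t1 below.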
    model = OmegaJEPA_Predictor(
        embedding_dim=embedding_dim,
        action_dim=action_dim,
        latent_dim=latent_dim
    ).to(device)

    metrics_auditor = OmegaMetrics()

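    # The mock dataloader is expected to yield (s_t, a_t, s_t1) batches of
    # synthetic embeddings matching the dimensions configured above.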
    dataloader = get_dataloader(
        batch_size=batch_size,
        embedding_dim=embedding_dim,
        action_dim=action_dim
    )

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    model.train()
    for epoch in range(epochs):
        epoch_losses = []
        epoch_tras = []

        for batch_idx, (s_t, a_t, s_t1) in enumerate(dataloader):
            s_t, a_t, s_t1 = s_t.to(device), a_t.to(device), s_t1.to(device)

            optimizer.zero_grad()

            # Predict the next-state embedding from the current state and action.
            pred_s_t1 = model(s_t, a_t)

            # Prediction error in embedding space.
            mse_loss = torch.mean((pred_s_t1 - s_t1) ** 2)

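            # compute_tra is assumed to return a per-sample score (lower is
            # better), so adding its batch mean to the loss penalizes it.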
            tra_loss = metrics_auditor.compute_tra(s_t, pred_s_t1).mean()

            # Combined objective: prediction error plus the weighted TRA term.
            total_loss = mse_loss + (alpha * tra_loss)

            total_loss.backward()
            optimizer.step()

            epoch_losses.append(mse_loss.item())
            epoch_tras.append(tra_loss.item())

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch} [{batch_idx}/{len(dataloader)}] "
                      f"MSE: {mse_loss.item():.6f} | TRA: {tra_loss.item():.6f}")

        avg_mse = sum(epoch_losses) / len(epoch_losses)
        avg_tra = sum(epoch_tras) / len(epoch_tras)
        print(f"==> Epoch {epoch} Complete | Avg MSE: {avg_mse:.6f} | Avg TRA: {avg_tra:.6f}")

    # Persist the trained predictor weights.
    checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    save_path = os.path.join(checkpoint_dir, "omega_jepa_latest.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Training finished and model saved to {save_path}.")


if __name__ == "__main__":
    train()