# Granis87's picture
# Initial upload of MnemoCore
# dbb04e4 verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import sys
# Add core to path to import components
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
from predictor import OmegaJEPA_Predictor
from omega_metrics import OmegaMetrics
from mock_data import get_dataloader
def train(
    *,
    embedding_dim=256,
    action_dim=64,
    latent_dim=64,
    batch_size=64,
    epochs=10,
    lr=1e-4,
    alpha=0.1,
    weight_decay=1e-5,
    log_every=50,
):
    """Train an ``OmegaJEPA_Predictor`` and save its weights to disk.

    Each batch from ``get_dataloader`` is unpacked as ``(s_t, a_t, s_t1)``.
    The model is called as ``model(s_t, a_t)`` and optimised with AdamW on
    ``mse(pred, s_t1) + alpha * mean(compute_tra(s_t, pred))``.  After the
    last epoch the model's ``state_dict`` is written to
    ``checkpoints/omega_jepa_latest.pt`` next to this file.

    Calling ``train()`` with no arguments reproduces the original
    hard-coded configuration.

    Args:
        embedding_dim: Passed to both the model and the dataloader factory.
        action_dim: Passed to both the model and the dataloader factory.
        latent_dim: Passed to the model constructor.
        batch_size: Passed to the dataloader factory.
        epochs: Number of full passes over the dataloader.
        lr: AdamW learning rate.
        alpha: Weight of the auxiliary TRA term in the total loss.
        weight_decay: AdamW weight decay.
        log_every: Print per-batch losses every this many batches; a falsy
            value (``0`` / ``None``) disables per-batch logging.

    Raises:
        RuntimeError: If the dataloader yields no batches during an epoch
            (the original code failed here with a bare ZeroDivisionError).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Model, auxiliary metric and data source.
    model = OmegaJEPA_Predictor(
        embedding_dim=embedding_dim,
        action_dim=action_dim,
        latent_dim=latent_dim,
    ).to(device)
    metrics_auditor = OmegaMetrics()
    dataloader = get_dataloader(
        batch_size=batch_size,
        embedding_dim=embedding_dim,
        action_dim=action_dim,
    )
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    for epoch in range(epochs):
        mse_sum = 0.0
        tra_sum = 0.0
        batches_seen = 0

        for batch_idx, (s_t, a_t, s_t1) in enumerate(dataloader):
            s_t, a_t, s_t1 = s_t.to(device), a_t.to(device), s_t1.to(device)

            optimizer.zero_grad()

            # The model is called without an explicit latent sample.
            # NOTE(review): per the original author's comment this corresponds
            # to z=None (deterministic) so the model learns the mean
            # transition -- confirm against OmegaJEPA_Predictor.forward.
            pred_s_t1 = model(s_t, a_t)

            # Primary loss: mean squared error against the true next state.
            mse_loss = torch.mean((pred_s_t1 - s_t1) ** 2)
            # Auxiliary loss: batch mean of the TRA score between the current
            # state and the prediction (original author's intent: penalise
            # anomalous transitions).
            tra_loss = metrics_auditor.compute_tra(s_t, pred_s_t1).mean()

            total_loss = mse_loss + (alpha * tra_loss)
            total_loss.backward()
            optimizer.step()

            # Read each scalar back exactly once; every .item() call forces a
            # host/device synchronisation.
            mse_value = mse_loss.item()
            tra_value = tra_loss.item()
            mse_sum += mse_value
            tra_sum += tra_value
            batches_seen += 1

            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch} [{batch_idx}/{len(dataloader)}] "
                      f"MSE: {mse_value:.6f} | TRA: {tra_value:.6f}")

        if batches_seen == 0:
            # Fail loudly instead of dividing by zero, and before a checkpoint
            # of an untrained model can be written.
            raise RuntimeError(
                "Dataloader yielded no batches; cannot compute epoch averages."
            )

        avg_mse = mse_sum / batches_seen
        avg_tra = tra_sum / batches_seen
        print(f"==> Epoch {epoch} Complete | Avg MSE: {avg_mse:.6f} | Avg TRA: {avg_tra:.6f}")

    # Persist the final weights alongside this script.
    checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    save_path = os.path.join(checkpoint_dir, "omega_jepa_latest.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Training finished and model saved to {save_path}.")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()