# ai/training/train_bc.py
import argparse
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sb3_contrib import MaskablePPO
from stable_baselines3.common.utils import get_device
from torch.utils.data import DataLoader, Dataset

# Add the project root to the path so that ai.* imports resolve.
sys.path.append(os.getcwd())

from ai.vec_env_adapter import VectorEnvAdapter


class BCDataset(Dataset):
    """Behavior-cloning dataset stored as an .npz with "obs" and "actions" arrays."""

    def __init__(self, path):
        print(f"Loading {path}...")
        data = np.load(path)
        self.obs = data["obs"]
        self.actions = data["actions"]
        print(f"Loaded {len(self.obs)} samples.")

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        # Actions are cast to int64 so CrossEntropyLoss receives class indices.
        return self.obs[idx], self.actions[idx].astype(np.int64)
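
# Expected .npz layout (a sketch for reference; generate_bc_data.py is the
# actual producer):
#   np.savez("data/bc_dataset.npz", obs=obs_array, actions=action_array)
# with obs_array shaped (N, obs_dim) and action_array shaped (N,) holding
# discrete action indices.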


def train_bc(data_path="data/bc_dataset.npz", save_path="models/bc_pretrained", epochs=10, batch_size=256, lr=1e-3):
    device = get_device("auto")
    print(f"Using device: {device}")

    # 1. Initialize a dummy env to get the observation/action shapes.
    # The env uses COMPRESSED or STANDARD observations depending on env vars
    # (default: STANDARD, 2304 dims). The dataset MUST have been generated
    # under the same observation space: generate_bc_data.py uses VectorEnv's
    # default OBS_MODE, so make sure the two match.
    env = VectorEnvAdapter(num_envs=1)

    # 2. Create the model. MaskablePPO is instantiated only to obtain the
    # policy network structure; BC training updates its parameters directly.
    model = MaskablePPO("MlpPolicy", env, verbose=1, device=device, learning_rate=lr)
    policy = model.policy.to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # 3. Data loader
    dataset = BCDataset(data_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
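
    # Optional sanity check (a sketch added here, assuming a flat Box
    # observation space and a 2-D obs array): fail fast if the dataset was
    # generated under a different OBS_MODE than the env above.
    expected_dim = int(np.prod(env.observation_space.shape))
    if dataset.obs.ndim != 2 or dataset.obs.shape[1] != expected_dim:
        raise ValueError(
            f"Dataset obs shape {dataset.obs.shape} does not match the env's "
            f"observation dim {expected_dim}; regenerate the dataset with a "
            "matching OBS_MODE."
        )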

    # 4. Training loop
    print("Starting BC Training...")
    for epoch in range(epochs):
        total_loss = 0.0
        total_acc = 0.0
        batches = 0
        policy.train()
        for batch_obs, batch_acts in dataloader:
            batch_obs = batch_obs.to(device).float()
            batch_acts = batch_acts.to(device)

            # Forward pass: recompute the policy's action logits manually
            # (features -> actor latent -> action_net) rather than going
            # through get_distribution().
            features = policy.extract_features(batch_obs)
            latent_pi, _ = policy.mlp_extractor(features)
            logits = policy.action_net(latent_pi)
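            # Equivalent via the SB3 API (a sketch; relies on the Categorical
            # distribution used for Discrete action spaces):
            #   logits = policy.get_distribution(batch_obs).distribution.logits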

            # No masking here: the heuristic action recorded in the dataset
            # is valid by definition, and the dataset carries no masks, so
            # cross-entropy on the raw logits simply teaches the network to
            # rank the valid heuristic action highest. Ideally, invalid
            # actions would be suppressed before the softmax.
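            # If masks were added to the dataset (a hypothetical "masks"
            # array; generate_bc_data.py does not currently emit one), that
            # suppression could look like:
            #   logits = logits.masked_fill(~batch_masks.bool(), float("-inf"))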
            loss = loss_fn(logits, batch_acts)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Accuracy
            preds = torch.argmax(logits, dim=1)
            acc = (preds == batch_acts).float().mean().item()
            total_acc += acc
            batches += 1
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {total_loss / batches:.4f} | Acc: {total_acc / batches:.4f}")

    # 5. Save
    print(f"Saving model to {save_path}...")
    model.save(save_path)
    print("Done.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="data/bc_dataset.npz")
    parser.add_argument("--save", type=str, default="models/bc_pretrained")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    args = parser.parse_args()
    train_bc(data_path=args.data, save_path=args.save, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)