# File size: 4,066 Bytes
# e4accf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os

# Add project root to path
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sb3_contrib import MaskablePPO
from stable_baselines3.common.utils import get_device
from torch.utils.data import DataLoader, Dataset

sys.path.append(os.getcwd())

from ai.vec_env_adapter import VectorEnvAdapter


class BCDataset(Dataset):
    """In-memory dataset of (observation, action) pairs for behavior cloning.

    Expects a ``.npz`` archive containing an ``"obs"`` array and a parallel
    ``"actions"`` array (one action label per observation).
    """

    def __init__(self, path: str):
        print(f"Loading {path}...")
        # np.load on a .npz returns a lazy NpzFile that keeps the underlying
        # file handle open; use it as a context manager so the handle is
        # released once the arrays are materialized. (Assumes `path` is a
        # .npz archive, as the default suggests — a bare .npy would return a
        # plain ndarray with no context-manager support. TODO confirm.)
        with np.load(path) as data:
            self.obs = data["obs"]
            self.actions = data["actions"]
        print(f"Loaded {len(self.obs)} samples.")

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(self, idx):
        # Cast actions to int64 so CrossEntropyLoss receives class indices
        # of the dtype PyTorch expects for targets.
        return self.obs[idx], self.actions[idx].astype(np.int64)


def train_bc(data_path="data/bc_dataset.npz", save_path="models/bc_pretrained", epochs=10, batch_size=256, lr=1e-3):
    """Pretrain a MaskablePPO policy via behavior cloning (supervised learning).

    Trains the policy head to imitate the actions recorded in the dataset
    using cross-entropy loss, then saves the full SB3 model checkpoint.

    Args:
        data_path: Path to a ``.npz`` file with ``"obs"`` and ``"actions"`` arrays.
        save_path: Destination for the saved SB3 model (``model.save``).
        epochs: Number of full passes over the dataset.
        batch_size: Minibatch size for the supervised updates.
        lr: Learning rate for the Adam optimizer (also passed to PPO).

    Raises:
        ValueError: If the dataset at ``data_path`` contains no samples.
    """
    device = get_device("auto")
    print(f"Using device: {device}")

    # 1. Initialize Dummy Env to get shapes/Policy
    # We use COMPRESSED or STANDARD depending on env vars, defaults to STANDARD (2304)
    # The dataset MUST match the observation space used here.
    # generate_bc_data.py uses the default OBS_MODE from VectorEnv.
    # Ensure they match!
    env = VectorEnvAdapter(num_envs=1)

    # 2. Create Model (MaskablePPO)
    # We only need PPO for its policy network structure; PPO's internal
    # optimizer is unused here — BC runs its own Adam below.
    model = MaskablePPO("MlpPolicy", env, verbose=1, device=device, learning_rate=lr)

    policy = model.policy.to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # 3. Data Loader
    dataset = BCDataset(data_path)
    if len(dataset) == 0:
        # Guard: an empty dataset would otherwise crash with
        # ZeroDivisionError at the per-epoch summary below.
        raise ValueError(f"Dataset at {data_path} is empty; nothing to train on.")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 4. Training Loop
    print("Starting BC Training...")

    for epoch in range(epochs):
        total_loss = 0.0
        total_acc = 0.0
        batches = 0

        policy.train()
        for batch_obs, batch_acts in dataloader:
            batch_obs = batch_obs.to(device).float()
            batch_acts = batch_acts.to(device)

            # Forward pass through the policy's actor head to get raw logits.
            # (Equivalent to policy.get_distribution(obs) but exposes logits
            # directly for CrossEntropyLoss.)
            features = policy.extract_features(batch_obs)
            latent_pi, _ = policy.mlp_extractor(features)
            logits = policy.action_net(latent_pi)

            # Masking: the heuristic action in the dataset is valid by
            # definition, so no action mask is needed for the loss — the
            # network simply learns to maximize the logit of the valid
            # (demonstrated) action. Masks are not stored in the dataset.
            loss = loss_fn(logits, batch_acts)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Batch accuracy: fraction of argmax predictions matching labels.
            preds = torch.argmax(logits, dim=1)
            acc = (preds == batch_acts).float().mean().item()
            total_acc += acc
            batches += 1

        print(f"Epoch {epoch + 1}/{epochs} | Loss: {total_loss / batches:.4f} | Acc: {total_acc / batches:.4f}")

    # 5. Save — ensure the target directory exists first (SB3's save may not
    # create intermediate directories — TODO confirm against SB3 version).
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print(f"Saving model to {save_path}...")
    model.save(save_path)
    print("Done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Behavior-cloning pretraining for MaskablePPO.")
    parser.add_argument("--data", type=str, default="data/bc_dataset.npz", help="Path to the .npz BC dataset.")
    parser.add_argument("--save", type=str, default="models/bc_pretrained", help="Output path for the pretrained model.")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs.")
    # Expose the remaining train_bc knobs on the CLI; defaults match the
    # function's own defaults, so existing invocations behave identically.
    parser.add_argument("--batch-size", type=int, default=256, help="Minibatch size.")
    parser.add_argument("--lr", type=float, default=1e-3, help="Adam learning rate.")
    args = parser.parse_args()

    train_bc(
        data_path=args.data,
        save_path=args.save,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
    )