trioskosmos committed on
Commit
e4accf1
·
verified ·
1 Parent(s): f6a813d

Upload ai/training/train_bc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train_bc.py +119 -0
ai/training/train_bc.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ # Add project root to path
5
+ import sys
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from sb3_contrib import MaskablePPO
12
+ from stable_baselines3.common.utils import get_device
13
+ from torch.utils.data import DataLoader, Dataset
14
+
15
+ sys.path.append(os.getcwd())
16
+
17
+ from ai.vec_env_adapter import VectorEnvAdapter
18
+
19
+
20
class BCDataset(Dataset):
    """In-memory dataset of (observation, action) pairs for behavior cloning.

    Loads a .npz archive that must contain two arrays:
      - "obs":     one observation per sample
      - "actions": the (heuristic) action chosen for each observation
    """

    def __init__(self, path):
        print(f"Loading {path}...")
        archive = np.load(path)
        self.obs, self.actions = archive["obs"], archive["actions"]
        print(f"Loaded {len(self.obs)} samples.")

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        # Cast the action to int64 so it is a valid class index for
        # CrossEntropyLoss after DataLoader collation.
        sample_obs = self.obs[idx]
        sample_act = self.actions[idx].astype(np.int64)
        return sample_obs, sample_act
33
+
34
+
35
def train_bc(data_path="data/bc_dataset.npz", save_path="models/bc_pretrained", epochs=10, batch_size=256, lr=1e-3):
    """Behavior-cloning pre-training for a MaskablePPO policy.

    Trains the policy's actor head with supervised cross-entropy on
    (observation, heuristic action) pairs, then saves the full PPO model
    so RL fine-tuning can resume from the pre-trained weights.

    Args:
        data_path: .npz file with "obs" and "actions" arrays. Its observations
            MUST match the observation space produced by VectorEnvAdapter
            (same OBS_MODE as generate_bc_data.py — confirm they agree).
        save_path: destination passed to MaskablePPO.save().
        epochs: number of passes over the dataset.
        batch_size: minibatch size for the DataLoader.
        lr: Adam learning rate (also handed to the PPO constructor).
    """
    device = get_device("auto")
    print(f"Using device: {device}")

    # 1. Dummy env: only needed so MaskablePPO can infer the observation and
    # action space shapes when building its policy network.
    env = VectorEnvAdapter(num_envs=1)
    try:
        # 2. Create the model whose policy network we will pre-train.
        model = MaskablePPO("MlpPolicy", env, verbose=1, device=device, learning_rate=lr)

        policy = model.policy.to(device)
        optimizer = optim.Adam(policy.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        # 3. Data loader
        dataset = BCDataset(data_path)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # 4. Training loop
        print("Starting BC Training...")

        for epoch in range(epochs):
            total_loss = 0.0
            total_acc = 0.0
            batches = 0

            policy.train()
            for batch_obs, batch_acts in dataloader:
                batch_obs = batch_obs.to(device).float()
                batch_acts = batch_acts.to(device)

                # Forward pass through the policy's actor head to obtain raw
                # action logits (bypassing the sampling distribution).
                features = policy.extract_features(batch_obs)
                latent_pi, _ = policy.mlp_extractor(features)
                logits = policy.action_net(latent_pi)

                # No action masking here: the heuristic action in the dataset
                # is valid by construction and the dataset carries no masks;
                # the net simply learns to maximize the valid action's logit.
                loss = loss_fn(logits, batch_acts)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                # Batch accuracy: argmax prediction vs. heuristic action.
                preds = torch.argmax(logits, dim=1)
                total_acc += (preds == batch_acts).float().mean().item()
                batches += 1

            if batches == 0:
                # Empty dataset / empty loader: avoid ZeroDivisionError below.
                print(f"Epoch {epoch + 1}/{epochs} | No batches produced (empty dataset?)")
                continue

            print(f"Epoch {epoch + 1}/{epochs} | Loss: {total_loss / batches:.4f} | Acc: {total_acc / batches:.4f}")

        # 5. Save the full PPO model (not just the policy weights) so it can
        # be reloaded with MaskablePPO.load for fine-tuning.
        print(f"Saving model to {save_path}...")
        model.save(save_path)
        print("Done.")
    finally:
        # Release the dummy env's resources even if training fails midway.
        env.close()
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument("--data", type=str, default="data/bc_dataset.npz")
115
+ parser.add_argument("--save", type=str, default="models/bc_pretrained")
116
+ parser.add_argument("--epochs", type=int, default=5)
117
+ args = parser.parse_args()
118
+
119
+ train_bc(data_path=args.data, save_path=args.save, epochs=args.epochs)