Gen-HVAC committed on
Commit
1641a08
·
verified ·
1 Parent(s): 831718a

Upload 4 files

Browse files
training/data_loader.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import re
4
+ import hashlib
5
+ from typing import Dict, List, Optional, Any, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from tqdm import tqdm
11
+ import json
12
+
13
+
14
+ # CONFIG & REGISTRY
15
+ DROP_OBS_KEYS = []
16
+ DATA_DIR = "TrajectoryData_from_docker"
17
+ INDEX_CACHE_PATH = os.path.join(DATA_DIR, "episode_index_cache_topk.json")
18
+ NORM_CACHE_PATH = os.path.join(DATA_DIR, "norm_stats_v_topk.npz")
19
+
20
+ PAD_ID = 0
21
+ UNK_ID = 1
22
+ SENSOR_START_ID = 2
23
+ ACTION_START_ID = 300
24
+ VOCAB_SIZE = 512
25
+
26
+ CONTEXT_LEN = 48
27
+ MAX_TOKENS_PER_STEP = 64
28
+ MAX_ZONES = 32
29
+ PHYSICS_HORIZON = 16
30
+ SEED = 42
31
+
32
+ USE_TOPK = True
33
+ TOPK_FRAC = 0.8
34
+ TOPK_MODE = "filter"
35
+ TOPK_ON = "energy"
36
+ TOPK_BOOST = 3.0
37
+
38
+ # --- Action Discretization ---
39
+ NUM_ACTION_BINS = 64
40
+ HTG_LOW, HTG_HIGH = 15.0, 30.0
41
+ CLG_LOW, CLG_HIGH = 15.0, 30.0
42
+
43
+ # --- Normalization & Scaling ---
44
+ USE_NORMALIZATION = True
45
+ ACTION_VALUE_INPUT_MODE = "prev"
46
+ ACTION_VALUE_MASK_CONST = 0.0
47
+ COMFORT_SCALE = 1.0
48
+
49
+ # --- Preference conditioning ---
50
+ PREF_MODE = "sample"
51
+ PREF_FIXED_LAMBDA = 0.5
52
+ PREF_BETA_A = 5.0
53
+ PREF_BETA_B = 2.0
54
+ ZONE_SRC_REGEX = 1
55
+ ZONE_SRC_PAREN = 2
56
+ ZONE_SRC_CORE_PERIM = 3
57
+ ZONE_SRC_HASH = 4
58
+
59
+ HVAC_KEYWORD_MAP = {
60
+ # Sensors (2..299)
61
+ "temp": 10, "t_in": 10, "temperature": 10,
62
+ "humidity": 11, "rh": 11,
63
+ "co2": 12, "ppm": 12,
64
+ "power": 13, "energy": 13, "kw": 13,
65
+ "occupancy": 14, "occ": 14, "people": 14,
66
+ "solar": 15, "rad": 15, "radiation": 15,
67
+ "outdoor": 16, "site": 16, "environment": 16,
68
+ "pressure": 17, "flow": 18, "fan": 19, "speed": 19,
69
+ # Actions (offset from ACTION_START_ID)
70
+ "setpoint": 10, "stpt": 10,
71
+ "damper": 11, "position": 11, "valve": 12,
72
+ }
73
+
74
+ # ============================================================
75
+ # HELPER
76
+ # ============================================================
77
def compute_comfort_indices_from_state_keys(state_keys: List[str]) -> List[int]:
    """Return indices of ASHRAE-55 discomfort columns.

    Aggregate columns (name contains "any") are preferred; when none exist,
    every "ash55"+"notcomfortable" column is returned instead.
    """
    lowered = [str(key).lower() for key in state_keys]

    def is_discomfort(key: str, require_any: bool) -> bool:
        if "ash55" not in key or "notcomfortable" not in key:
            return False
        return "any" in key if require_any else True

    aggregate = [i for i, key in enumerate(lowered) if is_discomfort(key, True)]
    if aggregate:
        return aggregate
    return [i for i, key in enumerate(lowered) if is_discomfort(key, False)]
87
+
88
+
89
def extract_zone_id_with_source(name_lower: str) -> Tuple[int, int]:
    """Recover a zone id from a lowercased feature name.

    Tried in order: an explicit zone/z/zn number anywhere in the name, the
    same pattern inside parentheses, a perimeter/core pattern, and finally a
    stable MD5-hash fallback. Returns (zone_id, source) where source is one
    of the ZONE_SRC_* provenance codes.
    """
    zone_pattern = r'(?:\bzone\b|\bz\b|\bzn\b)[_\s\-]*?(\d+)\b'

    def clamp(zid: int) -> int:
        # Keep the id inside the embedding range [0, MAX_ZONES).
        return min(max(zid, 0), MAX_ZONES - 1)

    direct = re.search(zone_pattern, name_lower)
    if direct:
        return clamp(int(direct.group(1))), ZONE_SRC_REGEX

    for chunk in re.findall(r'\(([^)]+)\)', name_lower):
        inside = re.search(zone_pattern, chunk)
        if inside:
            return clamp(int(inside.group(1))), ZONE_SRC_PAREN

    core_perim = re.search(r'(?:perimeter|perim|core)[_\s\-]*?(?:zn[_\s\-]*)?(\d+)\b', name_lower)
    if core_perim:
        return clamp(int(core_perim.group(1))), ZONE_SRC_CORE_PERIM

    # Deterministic fallback into [1, MAX_ZONES) so unknown names still map
    # to a stable, non-zero zone id.
    digest = int(hashlib.md5(name_lower.encode()).hexdigest(), 16)
    return 1 + (digest % max(1, (MAX_ZONES - 1))), ZONE_SRC_HASH
104
+
105
def parse_feature_identity(name: str, is_action: bool = False) -> Tuple[int, int, int]:
    """Map a raw feature name to (vocab_id, zone_id, zone_source).

    The base id comes from the first HVAC_KEYWORD_MAP keyword found in the
    name; unknown names hash into a reserved band [50, 100). The base id is
    then offset by ACTION_START_ID or SENSOR_START_ID, and ids that overflow
    VOCAB_SIZE collapse to UNK_ID.
    """
    lowered = str(name).lower()
    zone_id, zone_src = extract_zone_id_with_source(lowered)

    base_id = next(
        (code for keyword, code in HVAC_KEYWORD_MAP.items() if keyword in lowered),
        UNK_ID,
    )
    if base_id == UNK_ID:
        digest = int(hashlib.md5(lowered.encode()).hexdigest(), 16)
        base_id = 50 + (digest % 50)

    offset = ACTION_START_ID if is_action else SENSOR_START_ID
    vocab_id = offset + base_id
    if vocab_id >= VOCAB_SIZE:
        vocab_id = UNK_ID
    return vocab_id, zone_id, zone_src
119
+
120
def discretize_actions_to_bins(actions: np.ndarray, action_keys: List[str]) -> np.ndarray:
    """Quantize continuous setpoint actions into NUM_ACTION_BINS integer bins.

    Columns whose key mentions "clg"/"cool" use the cooling clip range,
    all others the heating range. Input shape (T, A); output is int64 of
    the same shape with values in [0, NUM_ACTION_BINS).
    """
    binned = np.zeros_like(actions, dtype=np.int64)
    for col, key in enumerate(action_keys):
        key_l = key.lower()
        is_cooling = ("clg" in key_l) or ("cool" in key_l)
        lo, hi = (CLG_LOW, CLG_HIGH) if is_cooling else (HTG_LOW, HTG_HIGH)
        # Normalize into [0, 1] (epsilon guards a degenerate range), then round.
        frac = (np.clip(actions[:, col], lo, hi) - lo) / (hi - lo + 1e-12)
        raw_bins = np.rint(frac * (NUM_ACTION_BINS - 1)).astype(np.int64)
        binned[:, col] = np.clip(raw_bins, 0, NUM_ACTION_BINS - 1)
    return binned
131
+
132
def discounted_cumsum(x: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Return-to-go: y[t] = x[t] + gamma * y[t+1], computed right-to-left."""
    out = np.zeros_like(x, dtype=np.float32)
    acc = 0.0
    for t in reversed(range(len(x))):
        acc = x[t] + gamma * acc
        out[t] = acc
    return out
139
+
140
+ def _mix_u64(x: int) -> int:
141
+ x &= 0xFFFFFFFFFFFFFFFF
142
+ x ^= (x >> 33)
143
+ x = (x * 0xff51afd7ed558ccd) & 0xFFFFFFFFFFFFFFFF
144
+ x ^= (x >> 33)
145
+ x = (x * 0xc4ceb9fe1a85ec53) & 0xFFFFFFFFFFFFFFFF
146
+ x ^= (x >> 33)
147
+ return x & 0xFFFFFFFFFFFFFFFF
148
+
149
def dataset_signature(npz_paths: List[str]) -> str:
    """MD5 fingerprint of a file list using (path, size, mtime).

    Used to detect stale on-disk caches. Missing files contribute a
    "<path>|missing" marker instead of raising.
    """
    entries = []
    for path in npz_paths:
        try:
            info = os.stat(path)
        except FileNotFoundError:
            entries.append(f"{path}|missing")
        else:
            entries.append(f"{path}|{info.st_size}|{int(info.st_mtime)}")
    return hashlib.md5("\n".join(entries).encode("utf-8")).hexdigest()
159
+
160
def compute_occupancy_indices_from_state_keys(state_keys: List[str]) -> List[int]:
    """Indices of occupant-count columns: names containing both "occ" and "count"."""
    hits = []
    for i, key in enumerate(state_keys):
        lowered = str(key).lower()
        if "occ" in lowered and "count" in lowered:
            hits.append(i)
    return hits
163
+
164
+ # ============================================================
165
+ # 1) EPISODE INDEX
166
+ # ============================================================
167
+
168
class EpisodeIndex:
    """Lightweight per-episode metadata index over a set of .npz trajectory files.

    For each episode it records: length T, total energy/comfort returns,
    state/action key lists, parsed feature identities (s_meta/a_meta),
    the kept observation-column indices, and ASH55 comfort-column indices.
    Results are cached to INDEX_CACHE_PATH, keyed by dataset_signature().
    """

    def __init__(self, npz_paths: List[str]):
        self.paths = list(npz_paths)
        self.T: List[int] = []

        self.returns_energy: List[float] = []
        self.returns_comfort: List[float] = []

        self.s_meta: List[List[Tuple[int,int,int]]] = []
        self.a_meta: List[List[Tuple[int,int,int]]] = []
        self.state_keys: List[List[str]] = []
        self.action_keys: List[List[str]] = []
        self.keep_indices_map: List[List[int]] = []
        self.comfort_idx: List[List[int]] = []

        # Fingerprint of the file list; cache is only valid for a matching signature.
        sig = dataset_signature(self.paths)

        # --- Fast path: load a previously built index from disk ---
        if os.path.exists(INDEX_CACHE_PATH):
            try:
                with open(INDEX_CACHE_PATH, "r") as f:
                    cache = json.load(f)
                if cache.get("signature") == sig and "returns_energy" in cache:
                    print(f"[DataLoader] Loading cached index: {INDEX_CACHE_PATH}")
                    self.T = cache["T"]
                    self.returns_energy = cache["returns_energy"]
                    self.returns_comfort = cache["returns_comfort"]
                    self.state_keys = cache["state_keys"]

                    self.action_keys = cache["action_keys"]
                    self.keep_indices_map = cache.get("keep_indices_map", [])
                    # Feature identities are cheap to recompute, so they are
                    # rebuilt from the cached key lists rather than stored.
                    self.s_meta = [[parse_feature_identity(k, is_action=False) for k in ks] for ks in self.state_keys]
                    self.a_meta = [[parse_feature_identity(k, is_action=True) for k in ks] for ks in self.action_keys]
                    if "comfort_idx" in cache:
                        self.comfort_idx = cache["comfort_idx"]
                    else:
                        # Older cache format: force a full rebuild via the
                        # surrounding except handler.
                        print("[DataLoader] Cache missing comfort_idx. Rebuilding.")
                        raise ValueError("Outdated Cache")

                    print(f"[DataLoader] Cache loaded. Episodes indexed: {len(self.T)}")
                    return
                else:
                    print("[DataLoader] Cache signature mismatch")
            except Exception as e:
                # Any cache problem falls through to a full re-index below.
                print(f"[DataLoader] Failed load cache: {e}")

        # --- Slow path: scan every file and build the index from scratch ---
        for p in tqdm(self.paths, desc="Indexing"):
            try:
                with np.load(p, allow_pickle=True) as d:
                    obs = d["observations"]
                    # Newer files split rewards into energy/comfort streams;
                    # older ones only have a single "rewards" array.
                    if "rewards_energy" in d:
                        r_e = d["rewards_energy"]
                        r_c = d["rewards_comfort"]
                    else:
                        r_e = d["rewards"]
                        r_c = np.zeros_like(r_e)

                    ret_e = float(np.sum(r_e))
                    ret_c = float(np.sum(r_c))

                    T = int(obs.shape[0])

                    # Get RAW keys
                    raw_s_keys = d["state_keys"].astype(object).tolist() if "state_keys" in d else []
                    a_keys = d["action_keys"].astype(object).tolist() if "action_keys" in d else []
                    raw_s_keys = list(map(str, raw_s_keys))
                    a_keys = list(map(str, a_keys))
                    # Comfort indices are computed on the RAW key list, before
                    # any columns are dropped.
                    c_idx = compute_comfort_indices_from_state_keys(raw_s_keys)
                    keep_idxs = [i for i, k in enumerate(raw_s_keys) if k not in DROP_OBS_KEYS]
                    s_keys = [raw_s_keys[i] for i in keep_idxs]

                    s_meta = [parse_feature_identity(k, is_action=False) for k in s_keys]
                    a_meta = [parse_feature_identity(k, is_action=True) for k in a_keys]

                    self.T.append(T)
                    self.returns_energy.append(ret_e)
                    self.returns_comfort.append(ret_c)
                    self.state_keys.append(s_keys)
                    self.action_keys.append(a_keys)
                    self.comfort_idx.append(c_idx)  # Save indices relative to RAW array

                    self.s_meta.append(s_meta)
                    self.a_meta.append(a_meta)
                    self.keep_indices_map.append(keep_idxs)

            except Exception as e:
                # Best-effort: a corrupt file is skipped, not fatal.
                print(f"[IndexError] {p}: {e}")

        # Save Cache (JSON-serializable fields only; s_meta/a_meta are rebuilt on load)
        try:
            cache = {
                "signature": sig,
                "T": self.T,
                "returns_energy": self.returns_energy,
                "returns_comfort": self.returns_comfort,
                "state_keys": self.state_keys,
                "action_keys": self.action_keys,
                "keep_indices_map": self.keep_indices_map,
                "comfort_idx": self.comfort_idx,  # Added
            }
            with open(INDEX_CACHE_PATH, "w") as f:
                json.dump(cache, f)
            print(f"[DataLoader] Saved index cache: {INDEX_CACHE_PATH}")
        except Exception as e:
            print(f"[DataLoader] Warning: failed to save cache: {e}")

    def __len__(self):
        # Number of successfully indexed episodes.
        return len(self.T)
274
+
275
+ # ============================================================
276
+ # 2) NORMALIZATION
277
+ # ============================================================
278
+
279
def compute_and_save_norm_stats(npz_paths: List[str], index: "EpisodeIndex", max_episodes: int = 1000, stride: int = 4):
    """Compute per-feature mean/std for observations and actions plus RTG
    scales, and save them to NORM_CACHE_PATH as a compressed .npz.

    Stats are accumulated in one streaming pass over a random subset of at
    most `max_episodes` episodes, subsampled in time by `stride`.

    Raises:
        RuntimeError: if the index contains no episodes.
        ValueError: if no data was ever read (accumulators never initialized).
    """
    rng = np.random.default_rng(SEED)
    n = len(index)
    if n == 0:
        raise RuntimeError("EpisodeIndex is empty (no valid episodes).")

    k = min(max_episodes, n)
    eps_idx = rng.choice(np.arange(n), size=k, replace=False)

    # Streaming sum / sum-of-squares accumulators, lazily sized from the
    # first file (feature dimensions are not known up front).
    obs_sum, obs_sumsq = None, None
    act_sum, act_sumsq = None, None
    count = 0

    for ei in tqdm(eps_idx, desc="Computing norm stats"):
        p = index.paths[int(ei)]
        with np.load(p, allow_pickle=True) as d:
            obs = d["observations"].astype(np.float32)
            act = d["actions"].astype(np.float32)

            # Keep only the observation columns the index retained.
            keep_idxs = index.keep_indices_map[int(ei)]
            obs = obs[:, keep_idxs]

            # Temporal subsampling to bound cost.
            obs = obs[::stride]
            act = act[::stride]

            if obs_sum is None:
                obs_sum = np.zeros(obs.shape[1], dtype=np.float64)
                obs_sumsq = np.zeros(obs.shape[1], dtype=np.float64)
                act_sum = np.zeros(act.shape[1], dtype=np.float64)
                act_sumsq = np.zeros(act.shape[1], dtype=np.float64)

            obs_sum += obs.sum(axis=0)
            obs_sumsq += (obs**2).sum(axis=0)
            act_sum += act.sum(axis=0)
            act_sumsq += (act**2).sum(axis=0)
            count += obs.shape[0]

    if obs_sum is None or obs_sumsq is None or act_sum is None or act_sumsq is None:
        raise ValueError("obs_sum, obs_sumsq, act_sum, or act_sumsq is not initialized properly.")

    # mean / std from first and second moments; variance floored at 1e-6
    # so std never collapses to zero.
    obs_mean = (obs_sum / max(count, 1)).astype(np.float32)
    obs_std = np.sqrt(np.maximum((obs_sumsq / max(count, 1)) - obs_mean**2, 1e-6)).astype(np.float32)
    act_mean = (act_sum / max(count, 1)).astype(np.float32)
    act_std = np.sqrt(np.maximum((act_sumsq / max(count, 1)) - act_mean**2, 1e-6)).astype(np.float32)

    # RTG scales: 95th percentile of |episode return|, floored at 1.0.
    all_re = np.abs(np.array(index.returns_energy))
    all_rc = np.abs(np.array(index.returns_comfort))

    scale_energy = float(np.percentile(all_re, 95)) if len(all_re) > 0 else 1.0
    scale_comfort = float(np.percentile(all_rc, 95)) if len(all_rc) > 0 else 1.0

    scale_energy = max(scale_energy, 1.0)
    scale_comfort = max(scale_comfort, 1.0)

    np.savez_compressed(
        NORM_CACHE_PATH,
        obs_mean=obs_mean, obs_std=obs_std,
        act_mean=act_mean, act_std=act_std,
        scale_energy=np.array([scale_energy], dtype=np.float32),
        scale_comfort=np.array([scale_comfort], dtype=np.float32),
    )
340
+
341
+
342
+
343
class GeneralistDataset(Dataset):
    """Virtual-length dataset of fixed-size windows sampled from HVAC episodes.

    Each __getitem__ deterministically (seed + epoch + index) picks an
    episode, samples a reward-biased CONTEXT_LEN window, tokenizes state
    and action features, and returns tensors for action prediction,
    physics (next-state / PHYSICS_HORIZON-delta) targets, and dual
    (energy, comfort) return-to-go conditioning.
    """

    def __init__(
        self,
        npz_paths: List[str],
        max_tokens: int = MAX_TOKENS_PER_STEP,
        seed: int = SEED,
        virtual_len: int = 60_000,
        gamma_rtg: float = 1.0,
        topk_frac: Optional[float] = None,
        topk_mode: Optional[str] = None,
        topk_on: Optional[str] = None,
    ):

        self.index = EpisodeIndex(npz_paths)
        self.max_tokens = int(max_tokens)
        self.seed = int(seed)
        self.virtual_len = int(virtual_len)
        self.epoch = 0
        self.gamma_rtg = float(gamma_rtg)
        # Training mode adds small RTG noise in __getitem__.
        self.is_train = True

        self.all_eps = np.arange(len(self.index), dtype=np.int64)

        # ---------------- Top-K selection ----------------
        # Explicit constructor args override the module-level defaults;
        # passing topk_frac forces top-k on.
        self.use_topk = bool(USE_TOPK) if topk_frac is None else True
        self.topk_frac = float(TOPK_FRAC) if topk_frac is None else float(topk_frac)
        self.topk_mode = str(TOPK_MODE) if topk_mode is None else str(topk_mode)
        self.topk_on = str(TOPK_ON) if topk_on is None else str(topk_on)
        rets_e = np.asarray(self.index.returns_energy, dtype=np.float32)
        rets_c = np.asarray(self.index.returns_comfort, dtype=np.float32)

        # Default: sample uniformly over all episodes.
        self.sel_eps = self.all_eps
        self.weights = None

        if self.use_topk and len(self.all_eps) > 0:
            total_k = max(1, int(round(self.topk_frac * len(self.all_eps))))

            # === STRATEGY 1: PARETO UNION (Energy + Comfort + Mixed) ===
            if self.topk_on == "pareto":
                print("[Top-K] Strategy: Energy + Comfort + Mixed")
                k_part = max(1, total_k // 3)

                # 1. Best Energy
                idx_energy = np.argsort(rets_e)[::-1][:k_part]
                # 2. Best Comfort
                idx_comfort = np.argsort(rets_c)[::-1][:k_part]
                # 3. Best Mixed (Balanced): z-score both signals first so
                # neither dominates the sum.
                norm_e = (rets_e - rets_e.mean()) / (rets_e.std() + 1e-6)
                norm_c = (rets_c - rets_c.mean()) / (rets_c.std() + 1e-6)
                idx_mixed = np.argsort(norm_e + norm_c)[::-1][:k_part]

                # Combine unique indices
                top_eps = np.unique(np.concatenate([idx_energy, idx_comfort, idx_mixed]))

            else:
                if self.topk_on == "energy": rank_signal = rets_e
                elif self.topk_on == "comfort": rank_signal = rets_c
                elif self.topk_on == "mixed": rank_signal = rets_e + rets_c
                else: rank_signal = rets_e  # Fallback

                order = np.argsort(rank_signal)[::-1]
                top_eps = order[:total_k]
            # === APPLY FILTER ===
            # NOTE(review): "weighted" currently behaves exactly like
            # "filter" (weights stay None) — confirm whether importance
            # weights were intended here.
            if self.topk_mode == "filter":
                self.sel_eps = top_eps
                self.weights = None
            elif self.topk_mode == "weighted":
                self.sel_eps = top_eps
                self.weights = None

        # Load Norm Stats (computed lazily on first run).
        if USE_NORMALIZATION:
            if not os.path.exists(NORM_CACHE_PATH):
                print("[DataLoader] Computing Norm Stats...")
                compute_and_save_norm_stats(npz_paths, self.index)

            z = np.load(NORM_CACHE_PATH)
            self.obs_mean = z["obs_mean"].astype(np.float32)
            self.obs_std = z["obs_std"].astype(np.float32)
            self.act_mean = z["act_mean"].astype(np.float32)
            self.act_std = z["act_std"].astype(np.float32)

            self.scale_energy = float(z["scale_energy"][0])
            self.scale_comfort = float(z["scale_comfort"][0])
        else:
            self.obs_mean = None
            self.scale_energy = 1.0
            self.scale_comfort = 1.0

    def set_epoch(self, e: int):
        # Changing the epoch reshuffles the deterministic sampling stream.
        self.epoch = int(e)

    def __len__(self):
        # Virtual length: samples are drawn on demand, not enumerated.
        return self.virtual_len

    def __getitem__(self, i: int) -> Dict[str, Any]:
        # Deterministic 64-bit hash of (seed, epoch, index) drives all sampling.
        x = _mix_u64(self.seed ^ (self.epoch * 0x9E3779B97F4A7C15) ^ (int(i) * 0xD1B54A32D192ED03))

        # Preference sampling: lambda trades energy vs comfort conditioning.
        if PREF_MODE == "fixed":
            lam = float(PREF_FIXED_LAMBDA)
        else:
            rng = np.random.default_rng(int(x & 0xFFFFFFFF))
            lam = float(rng.beta(PREF_BETA_A, PREF_BETA_B))

        # Episode choice: uniform over sel_eps, or weighted via inverse-CDF.
        if self.weights is None:
            ep_i = int(self.sel_eps[x % len(self.sel_eps)])
        else:
            u = ((x & 0xFFFFFFFF) / 2**32)
            # Clip index to avoid out-of-bounds
            cdf = np.cumsum(self.weights)
            idx = int(np.searchsorted(cdf, u, side="right"))
            idx = min(idx, len(self.weights) - 1)
            ep_i = int(self.sel_eps[idx])

        p = self.index.paths[ep_i]
        T_total = int(self.index.T[ep_i])
        L = CONTEXT_LEN
        # 1. Load Data (arrays are materialized by astype, so they remain
        # valid after the file handle closes).
        with np.load(p, allow_pickle=True) as d:
            raw_obs = d["observations"].astype(np.float32)
            at = d["actions"].astype(np.float32)

            if "rewards_energy" in d:
                re = d["rewards_energy"].astype(np.float32)
                rc = d["rewards_comfort"].astype(np.float32)
            else:
                re = d["rewards"].astype(np.float32)
                rc = np.zeros_like(re)

        # 2. Window start: softmax-sample among random candidates, biased
        # toward high-total-reward windows.
        # NOTE(review): np.random.randint(0, T_total - L) raises when
        # T_total == L (empty range) — confirm episodes are strictly longer
        # than CONTEXT_LEN.
        if T_total >= L:
            total_r = re + rc
            num_candidates = 20
            candidates = np.random.randint(0, T_total - L, size=num_candidates)
            scores = np.array([total_r[c : c + L].sum() for c in candidates])

            scores_stab = (scores - np.max(scores)) / (np.std(scores) + 1e-6)
            probs = np.exp(scores_stab)
            probs /= probs.sum()
            s0 = np.random.choice(candidates, p=probs)
        else:
            s0 = 0

        # ASH55 discomfort columns are indexed on the RAW observation array.
        cidx = self.index.comfort_idx[ep_i]
        if len(cidx) > 0:
            ash55_raw_slice = raw_obs[:, cidx]
        else:
            ash55_raw_slice = np.zeros((T_total, 1), dtype=np.float32)

        # Drop excluded observation columns.
        keep_idxs = self.index.keep_indices_map[ep_i]
        st = raw_obs[:, keep_idxs]
        s_keys_ep = self.index.state_keys[ep_i]

        def find_idx(substring):
            # First state column whose (lowercased) key contains substring, else -1.
            for idx, k in enumerate(s_keys_ep):
                if substring in k.lower(): return idx
            return -1

        idx_out = find_idx("outdoor_temp")
        idx_dew = find_idx("dewpoint")
        idx_hr = find_idx("hour")
        idx_mth = find_idx("month")
        idx_occ = compute_occupancy_indices_from_state_keys(s_keys_ep)

        def get_window(arr, pad_val=0.0):
            # Slice [s0, s0+L) when possible, else pad a short episode to L.
            if T_total >= L:
                return arr[s0:s0+L]
            else:
                out = np.full((L, *arr.shape[1:]), pad_val, dtype=np.float32)
                out[:T_total] = arr
                return out

        st_win = get_window(st)
        at_win = get_window(at)
        at_win_raw = at_win.copy()  # kept un-normalized for discretization

        re_win = get_window(re)
        rc_win = get_window(rc)

        ash55_win = get_window(ash55_raw_slice)
        ash55_any = ash55_win.mean(axis=1).astype(np.float32)

        # time_mask: 1 for real timesteps, 0 for padding.
        tm_win = np.zeros((L,), dtype=np.float32)
        valid_len = min(T_total, L)
        tm_win[:valid_len] = 1.0

        valid_mask = (tm_win > 0.5)

        # Outdoor-temperature "forecast": mean of the next FORECAST_STEPS
        # steps after the window; falls back to the in-window mean.
        FORECAST_STEPS = 48
        future_start = s0 + L
        future_end = min(T_total, future_start + FORECAST_STEPS)

        forecast_temp = 0.0
        if idx_out != -1:
            current_vals = st_win[valid_mask, idx_out]
            if len(current_vals) > 0:
                forecast_temp = current_vals.mean()
            if future_end > future_start:
                future_vals = st[future_start:future_end, idx_out]
                if len(future_vals) > 0:
                    forecast_temp = future_vals.mean()

        # 3. Context Vector (coarse building/weather summary)
        t_mean, t_std = 0.0, 0.0
        if idx_out != -1 and valid_mask.sum() > 0:
            vals = st_win[valid_mask, idx_out]
            t_mean, t_std = vals.mean(), vals.std()

        d_mean = 0.0
        if idx_dew != -1 and valid_mask.sum() > 0:
            d_mean = st_win[valid_mask, idx_dew].mean()

        occ_frac = 0.0
        if len(idx_occ) > 0 and valid_mask.sum() > 0:
            occ_sum = st_win[valid_mask][:, idx_occ].sum(axis=1)
            occ_frac = (occ_sum > 0.5).mean()

        # Cyclical Time (encoded from the window's first valid step)
        hr_sin, hr_cos = 0.0, 0.0
        if idx_hr != -1 and valid_mask.sum() > 0:
            hr_val = st_win[valid_mask, idx_hr][0]
            hr_sin = np.sin(2 * np.pi * hr_val / 24.0)
            hr_cos = np.cos(2 * np.pi * hr_val / 24.0)

        mth_sin, mth_cos = 0.0, 0.0
        if idx_mth != -1 and valid_mask.sum() > 0:
            mth_val = st_win[valid_mask, idx_mth][0]
            mth_sin = np.sin(2 * np.pi * mth_val / 12.0)
            mth_cos = np.cos(2 * np.pi * mth_val / 12.0)
        ctx_vec = np.array([
            t_mean, t_std, d_mean, occ_frac,
            hr_sin, hr_cos, mth_sin, mth_cos,
            forecast_temp,
            0.0  # spare slot (CONTEXT_DIM = 10)
        ], dtype=np.float32)

        # Physics targets: next-step states and PHYSICS_HORIZON-step-ahead
        # states (zero-padded where the episode ends).
        next_st_win = np.zeros_like(st_win)
        future_4h_st_win = np.zeros_like(st_win)

        if T_total >= L:
            end_idx = min(s0 + L + 1, T_total)
            actual_len = end_idx - (s0 + 1)
            if actual_len > 0:
                next_st_win[:actual_len] = st[s0+1 : end_idx]
            f_end_idx = min(s0 + L + PHYSICS_HORIZON, T_total)
            f_actual_len = f_end_idx - (s0 + PHYSICS_HORIZON)
            if f_actual_len > 0:
                future_4h_st_win[:f_actual_len] = st[s0 + PHYSICS_HORIZON : f_end_idx]
        else:
            if T_total > 1:
                next_st_win[:T_total-1] = st[1:T_total]

        # Normalize states/actions with cached dataset statistics.
        if USE_NORMALIZATION and (self.obs_mean is not None):
            st_win = (st_win - self.obs_mean) / self.obs_std
            next_st_win = (next_st_win - self.obs_mean) / self.obs_std
            future_4h_st_win = (future_4h_st_win - self.obs_mean) / self.obs_std
            at_win = (at_win - self.act_mean) / self.act_std
        delta_4h_win = future_4h_st_win - st_win

        # Return-to-go over the FULL episode, then windowed and scaled.
        full_rtg_e = discounted_cumsum(re, gamma=self.gamma_rtg)
        full_rtg_c = discounted_cumsum(rc, gamma=self.gamma_rtg)

        rtg_e_win = get_window(full_rtg_e)
        rtg_c_win = get_window(full_rtg_c)

        rtg_e_norm = rtg_e_win / self.scale_energy
        rtg_c_norm = rtg_c_win / self.scale_comfort

        rtg_combined = np.stack([rtg_e_norm, rtg_c_norm], axis=-1)

        # Small conditioning noise during training only.
        if getattr(self, "is_train", True):
            rtg_combined += np.random.normal(0, 0.005, rtg_combined.shape).astype(np.float32)

        # Token buffers: states first, then actions, padded to max_tokens.
        feat_ids = np.full((L, self.max_tokens), PAD_ID, dtype=np.int64)
        feat_vals = np.zeros((L, self.max_tokens), dtype=np.float32)
        zone_ids = np.zeros((L, self.max_tokens), dtype=np.int64)
        attn_mask = np.zeros((L, self.max_tokens), dtype=np.int64)

        # -100 = ignored position for the cross-entropy target.
        target_toks = np.full((L, self.max_tokens), -100, dtype=np.int64)
        target_mask = np.zeros((L, self.max_tokens), dtype=np.float32)

        s_meta = self.index.s_meta[ep_i]
        a_meta = self.index.a_meta[ep_i]

        S_dim = min(len(s_meta), st_win.shape[1])
        A_dim = min(len(a_meta), at_win.shape[1])

        # Actions always fit; states take whatever token budget remains.
        num_act_toks = min(A_dim, self.max_tokens)
        num_state_toks = min(S_dim, self.max_tokens - num_act_toks)
        if num_state_toks > 0:
            feat_ids[:, :num_state_toks] = [m[0] for m in s_meta[:num_state_toks]]
            zone_ids[:, :num_state_toks] = [m[1] for m in s_meta[:num_state_toks]]
            feat_vals[:, :num_state_toks] = st_win[:, :num_state_toks]
            attn_mask[:, :num_state_toks] = 1
        if num_act_toks > 0:
            start = num_state_toks
            end = start + num_act_toks
            feat_ids[:, start:end] = [m[0] for m in a_meta[:num_act_toks]]
            zone_ids[:, start:end] = [m[1] for m in a_meta[:num_act_toks]]
            attn_mask[:, start:end] = 1

            # Input action values are the PREVIOUS step's actions (teacher
            # forcing without leaking the current target); step 0 gets zeros.
            a_in = np.zeros((L, num_act_toks), dtype=np.float32)
            if L > 1:
                a_in[1:] = at_win[:-1, :num_act_toks]
            feat_vals[:, start:end] = a_in

            # Targets are the CURRENT step's raw actions, discretized.
            a_keys = self.index.action_keys[ep_i]
            at_discrete = discretize_actions_to_bins(at_win_raw, a_keys)

            target_toks[:, start:end] = at_discrete[:, :num_act_toks]
            target_mask[:, start:end] = 1.0

        # Zero out padded timesteps everywhere.
        valid_t = (tm_win > 0.5)[:, None]
        attn_mask *= valid_t.astype(np.int64)
        target_mask *= valid_t

        return {
            "feature_ids": feat_ids,
            "feature_values": feat_vals,
            "zone_ids": zone_ids,
            "attention_mask": attn_mask,
            "target_action_tokens": target_toks,
            "target_mask": target_mask,
            "rtg": rtg_combined,
            "rtg_energy": rtg_e_norm,
            "rtg_comfort": rtg_c_norm,
            "rewards_energy": re_win,
            "rewards_comfort": rc_win,
            "pref_lambda": np.float32(lam),
            "ash55_any": ash55_any,
            "next_obs": next_st_win,
            "target_4h_delta": delta_4h_win,
            "time_mask": tm_win,
            "context": ctx_vec,
        }
679
+
680
def generalist_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate per-sample numpy arrays into batched torch tensors.

    Integer fields become int64 tensors; everything else float32. Key
    order matches the per-sample dict layout.
    """
    spec = [
        ("feature_ids", True),
        ("feature_values", False),
        ("zone_ids", True),
        ("attention_mask", True),
        ("target_action_tokens", True),
        ("target_mask", False),
        ("rtg", False),
        ("rtg_energy", False),
        ("rtg_comfort", False),
        ("rewards_energy", False),
        ("rewards_comfort", False),
        ("pref_lambda", False),
        ("ash55_any", False),
        ("next_obs", False),
        ("target_4h_delta", False),
        ("time_mask", False),
        ("context", False),
    ]
    collated: Dict[str, Any] = {}
    for key, as_long in spec:
        stacked = torch.from_numpy(np.stack([sample[key] for sample in batch]))
        collated[key] = stacked.long() if as_long else stacked.float()
    return collated
707
+
708
+ # ============================================================
709
+ # 4) DEBUG MAIN
710
+ # ============================================================
711
+
712
def main():
    """Smoke test: index one building's trajectories and pull a single batch."""
    pattern = os.path.join(
        DATA_DIR, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz"
    )
    candidates = sorted(glob.glob(pattern, recursive=True))
    # Exclude any stray stats file that happens to match the glob.
    npz_paths = [p for p in candidates if os.path.basename(p) != "norm_stats.npz"]

    if not npz_paths:
        print(f"No data found in {DATA_DIR}")
        return

    ds = GeneralistDataset(npz_paths, max_tokens=64)
    loader = DataLoader(ds, batch_size=4, collate_fn=generalist_collate_fn, num_workers=0)

    batch = next(iter(loader))


if __name__ == "__main__":
    main()
training/embeddings.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #embeddings.py
2
+
3
+ from __future__ import annotations
4
+ from typing import Dict, List, Optional, Tuple
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ # ============================================================
10
+ # 1.MLP HEAD
11
+ # ============================================================
12
class MLPHead(nn.Module):
    """Two-hidden-layer GELU MLP projection head.

    Shape: in_dim -> hidden_dim -> hidden_dim // 2 -> out_dim, with a
    LayerNorm after the first activation.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 512):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the trailing feature dimension of x."""
        return self.net(x)
26
+
27
+ # ============================================================
28
+ # 2. DECISION TRANSFORMER
29
+ # ============================================================
30
+
31
class GeneralistComfortDT(nn.Module):
    """Decision-Transformer-style encoder over per-timestep feature tokens.

    Each timestep contributes K feature tokens (id + zone + FiLM-modulated
    value embedding); sequences are flattened to T*K tokens, conditioned on
    a context vector and a 2-dim (energy, comfort) return-to-go, and encoded
    with a time-causal TransformerEncoder. Heads predict discretized action
    bins, next-state values, horizon state deltas, and returns.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        d_model = config["D_MODEL"]
        vocab_size = config["VOCAB_SIZE"]
        max_zones = config["MAX_ZONES"]
        context_dim = config.get("CONTEXT_DIM", 10)
        rtg_dim = config.get("RTG_DIM", 2)
        # Token identity / zone embeddings.
        self.feat_embed = nn.Embedding(vocab_size, d_model)
        self.zone_embed = nn.Embedding(max_zones, d_model)
        # Scalar value -> d_model, FiLM-modulated per feature id
        # (gamma * proj(value) + beta).
        self.val_proj = nn.Linear(1, d_model)
        self.val_gamma = nn.Embedding(vocab_size, d_model)
        self.val_beta = nn.Embedding(vocab_size, d_model)
        self.ctx_proj = nn.Linear(context_dim, d_model)
        self.rtg_embed = nn.Linear(rtg_dim, d_model)
        # Learned positional embedding per TIMESTEP (shared across the K
        # tokens of a step).
        self.pos_embed = nn.Parameter(torch.zeros(1, config["CONTEXT_LEN"], d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config["N_HEADS"],
            dim_feedforward=4 * d_model,
            dropout=config["DROPOUT"],
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.backbone = nn.TransformerEncoder(enc_layer, num_layers=config["N_LAYERS"])
        self.ln_out = nn.LayerNorm(d_model)
        self.action_head = MLPHead(d_model, config["NUM_ACTION_BINS"])
        self.state_head = nn.Linear(d_model, 1)
        self.state_head_4h = nn.Linear(d_model, 1)
        self.return_head = MLPHead(d_model, rtg_dim, hidden_dim=256)

        self._init_weights()

    def _init_weights(self):
        """Xavier linears, N(0, 0.02) embeddings, unit LayerNorm; FiLM
        starts as identity (gamma=1, beta=0)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.ones_(self.val_gamma.weight)
        nn.init.zeros_(self.val_beta.weight)

    @staticmethod
    def _build_time_causal_mask(T: int, K: int, device: torch.device) -> torch.Tensor:
        """Boolean (T*K, T*K) mask blocking attention to FUTURE timesteps;
        tokens within the same timestep attend to each other freely."""
        L = T * K
        ti = torch.arange(L, device=device) // K
        return (ti[None, :] > ti[:, None])

    def forward(
        self,
        feature_ids: torch.Tensor,       # (B, T, K) int token ids
        feature_vals: torch.Tensor,      # (B, T, K) float token values
        zone_ids: torch.Tensor,          # (B, T, K) int zone ids
        attn_mask: torch.Tensor,         # (B, T, K); 0 = padding token
        rtg: Optional[torch.Tensor] = None,      # (B, T, rtg_dim) return-to-go
        context: Optional[torch.Tensor] = None,  # (B, context_dim) summary
        rtg_dropout_prob: float = 0.0    # prob. of dropping RTG conditioning per sample
    ) -> Dict[str, torch.Tensor]:

        B, T, K = feature_ids.shape
        d_model = self.config["D_MODEL"]
        # Flatten (T, K) token grid into one sequence of T*K tokens.
        flat_fids = feature_ids.reshape(B, -1)
        flat_vals = feature_vals.reshape(B, -1, 1)
        flat_zids = zone_ids.reshape(B, -1)
        val_emb = self.val_proj(flat_vals)
        # FiLM: per-feature-id affine modulation of the value embedding.
        val_emb = self.val_gamma(flat_fids) * val_emb + self.val_beta(flat_fids)

        x_base = (
            self.feat_embed(flat_fids)
            + self.zone_embed(flat_zids)
            + val_emb
        )
        # Repeat each timestep's positional embedding across its K tokens.
        pos = self.pos_embed[:, :T, :].unsqueeze(2).expand(-1, -1, K, -1).reshape(1, -1, d_model)
        x_base = x_base + pos

        if context is not None:
            ctx_emb = self.ctx_proj(context).unsqueeze(1)
            x_base = x_base + ctx_emb
        rtg_emb = torch.zeros_like(x_base)
        if rtg is not None:
            # NOTE(review): reshape hard-codes 2 — breaks if RTG_DIM != 2.
            flat_rtg = rtg.unsqueeze(2).expand(-1, -1, K, -1).reshape(B, -1, 2)
            if self.training:
                flat_rtg = flat_rtg + torch.randn_like(flat_rtg) * 0.005  # Noise

            rtg_emb = self.rtg_embed(flat_rtg)

            if self.training:
                rtg_emb = F.dropout(rtg_emb, p=0.1)
            # Per-sample RTG dropout: zero the whole conditioning signal
            # with probability rtg_dropout_prob.
            if rtg_dropout_prob > 0.0:
                mask = torch.bernoulli(torch.full((B, 1, 1), 1.0 - rtg_dropout_prob, device=x_base.device))
                rtg_emb = rtg_emb * mask
        x = x_base + rtg_emb

        # Padding tokens are masked; causality is enforced per timestep.
        flat_mask = attn_mask.reshape(B, -1)
        key_padding_mask = (flat_mask == 0)
        attn_mask_2d = self._build_time_causal_mask(T, K, device=x.device)
        x_latent = self.backbone(x, mask=attn_mask_2d, src_key_padding_mask=key_padding_mask)
        x_latent = self.ln_out(x_latent)
        action_logits = self.action_head(x_latent).reshape(B, T, K, -1)
        # Physics/value heads read the latent with the RTG contribution
        # subtracted out, so they are not steered by the conditioning.
        x_phys = x_latent - rtg_emb
        state_preds = self.state_head(x_phys).reshape(B, T, K)
        state_preds_4h = self.state_head_4h(x_phys).reshape(B, T, K)
        return_preds_raw = self.return_head(x_phys).reshape(B, T, K, -1)
        return_preds = return_preds_raw.mean(dim=2)

        # NOTE(review): this second dropout of rtg_emb happens after all
        # outputs are computed and rtg_emb is not used again — it has no
        # effect on the returned tensors. Confirm whether it is dead code.
        if self.training and rtg_dropout_prob > 0.0:
            mask = torch.bernoulli(torch.full((B, 1, 1), 1.0 - rtg_dropout_prob, device=x_base.device))
            rtg_emb = rtg_emb * mask

        return {
            "action_logits": action_logits,
            "state_preds": state_preds,
            "state_preds_4h": state_preds_4h,
            "return_preds": return_preds,
            "building_latent": x_latent.mean(dim=1)
        }
training/losses.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ losses.py
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Dict, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ # ============================================================
14
+ # 1) CONFIG
15
+ # ============================================================
16
+
17
@dataclass
class GeneralistLossConfig:
    """Weights and options for compute_generalist_loss."""
    w_action: float = 1.0        # weight on the action cross-entropy term
    w_physics: float = 20.0      # weight on the state-delta (physics) MSE term
    w_value: float = 100.0       # weight on the RTG value-regression term
    label_smoothing: float = 0.0  # NOTE(review): declared but not passed to cross_entropy in this file — confirm intent
    use_rtg_weighting: bool = True   # if True, weight action tokens by returns-to-go ("stitching")
    rtg_weight_mode: str = "exp"     # one of "none", "clamp01", "softplus", "exp"
    rtg_weight_beta: float = 2.0     # sharpness of the RTG -> weight mapping
    min_token_weight: float = 0.05   # NOTE(review): unused in the visible code — confirm
27
+
28
+
29
+ # ============================================================
30
+ # 2) HELPERS
31
+ # ============================================================
32
+
33
+ def _expand_rtg_to_tokens(rtg_bt: torch.Tensor, K: int) -> torch.Tensor:
34
+ return rtg_bt.unsqueeze(-1).expand(-1, -1, K)
35
+
36
+
37
+ def _rtg_to_weights(rtg_input: torch.Tensor, mode: str, beta: float) -> torch.Tensor:
38
+ if mode == "none":
39
+ return torch.ones(rtg_input.shape[:2], device=rtg_input.device)
40
+ if rtg_input.dim() == 3:
41
+ mu = rtg_input.mean(dim=1, keepdim=True)
42
+ sig = rtg_input.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
43
+ rtg_norm = (rtg_input - mu) / sig
44
+ scalar_rtg = rtg_norm.sum(dim=-1)
45
+ else:
46
+ scalar_rtg = rtg_input
47
+ mu_s = scalar_rtg.mean(dim=1, keepdim=True)
48
+ sig_s = scalar_rtg.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
49
+
50
+ z = (scalar_rtg - mu_s) / sig_s
51
+ z = torch.clamp(z, -5.0, 5.0)
52
+ if mode == "clamp01":
53
+ w = torch.sigmoid(beta * z)
54
+ elif mode == "softplus":
55
+ w = F.softplus(beta * z)
56
+ elif mode == "exp":
57
+ w = torch.exp(beta * z)
58
+ else:
59
+ raise ValueError(f"Unknown rtg_weight_mode={mode}")
60
+ w = torch.clamp(w, min=0.01, max=50.0)
61
+ return w
62
+
63
+ # return total, metrics
64
def compute_generalist_loss(
    model_out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: GeneralistLossConfig
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the combined generalist training loss.

    Components:
      * action loss  — class-balanced cross-entropy over action tokens,
        optionally importance-weighted by returns-to-go ("stitching").
      * physics loss — MSE on predicted 1-step and 4-hour state *deltas*,
        up-weighted where the inputs changed between steps ("excitation").
      * value loss   — MSE between predicted and target returns-to-go,
        rescaled by a fixed factor of 500.

    Args:
        model_out: dict with "action_logits" [B,T,K,n_bins], "state_preds" [B,T,K],
            "state_preds_4h" [B,T,K], "return_preds" [B,T,2].
        batch: dataloader batch with target tokens/masks, rtg, next_obs,
            target_4h_delta and feature_values tensors (time_mask optional).
        config: loss weights and RTG-weighting options.

    Returns:
        (total_loss, metrics) — metrics holds detached per-component scalars.
    """
    action_logits = model_out["action_logits"]    # [B, T, K, n_bins]
    state_preds = model_out["state_preds"]        # [B, T, K]
    state_preds_4h = model_out["state_preds_4h"]  # [B, T, K]
    return_preds = model_out["return_preds"]      # [B, T, 2]

    target_tokens = batch["target_action_tokens"]
    target_mask = batch["target_mask"].float()
    attn_mask = batch["attention_mask"].float()
    target_rtg = batch["rtg"].float()
    time_mask = batch.get("time_mask", torch.ones(target_rtg.shape[:2], device=target_rtg.device)).float()

    B, T, K, n_bins = action_logits.shape
    is_state = (1.0 - target_mask)       # tokens that are observations, not actions
    valid_phys = attn_mask * is_state    # physics loss applies only to real state tokens

    # 1) Stitching: weight action tokens by (normalized) returns-to-go.
    if config.use_rtg_weighting:
        w_bt = _rtg_to_weights(target_rtg, config.rtg_weight_mode, config.rtg_weight_beta)
        w_btk = _expand_rtg_to_tokens(w_bt, K)
        # Renormalize so the mean weight over valid action tokens stays ~1.
        norm_factor = (target_mask * attn_mask).sum().clamp_min(1e-6) / (w_btk * target_mask * attn_mask).sum().clamp_min(1e-6)
        token_importance = w_btk * norm_factor
    else:
        w_bt = torch.ones((B, T), device=action_logits.device)
        token_importance = torch.ones((B, T, K), device=action_logits.device)

    # 2) ACTION LOSS (class-balanced CE)
    flat_logits = action_logits.reshape(-1, n_bins)
    flat_targets = target_tokens.reshape(-1)
    flat_mask = (target_mask * attn_mask).reshape(-1)
    flat_importance = token_importance.reshape(-1)

    with torch.no_grad():
        valid_t = flat_targets[flat_mask > 0.5]
        if valid_t.numel() > 0:
            counts = torch.bincount(valid_t, minlength=n_bins).float()
            # Inverse-frequency class weights (+10 smoothing), normalized to mean 1.
            class_weights = (1.0 / (counts + 10.0)) / (1.0 / (counts + 10.0)).mean()
        else:
            class_weights = torch.ones(n_bins, device=flat_logits.device)

    # FIX: honor config.label_smoothing — it was declared in GeneralistLossConfig
    # but previously never forwarded to cross_entropy (default 0.0 keeps old behavior).
    ce_per_token = F.cross_entropy(
        flat_logits, flat_targets, weight=class_weights, reduction="none",
        ignore_index=-100, label_smoothing=config.label_smoothing,
    )
    loss_action = (ce_per_token * flat_mask * flat_importance).sum() / flat_mask.sum().clamp_min(1e-6)

    # ============================================================
    # 3) PHYSICS LOSS (The Delta Fix)
    # ============================================================
    # Supervise predicted *deltas* against ground truth:
    #   next_obs is [B, T, K_limit]; feature_values is [B, T, 64] (padded tokens).
    true_next = batch["next_obs"].float()
    target_delta_4h = batch["target_4h_delta"].float()
    K_limit = true_next.shape[2]
    true_vals_sliced = batch["feature_values"].float().narrow(2, 0, K_limit)
    s_pred_valid = state_preds.narrow(2, 0, K_limit)
    s_pred_4h_valid = state_preds_4h.narrow(2, 0, K_limit)
    v_phys_mask = valid_phys.narrow(2, 0, K_limit)
    target_delta_1s = true_next - true_vals_sliced
    mse_1s = (s_pred_valid - target_delta_1s) ** 2
    mse_4h = (s_pred_4h_valid - target_delta_4h) ** 2
    with torch.no_grad():
        # "Excitation": up-weight timesteps whose inputs changed vs the previous step.
        act_diff = torch.zeros((B, T), device=true_next.device)
        if T > 1:
            act_diff[:, 1:] = torch.abs(true_vals_sliced[:, 1:] - true_vals_sliced[:, :-1]).sum(dim=-1)
        excitation = (1.0 + 5.0 * act_diff).unsqueeze(-1)
        denom = (v_phys_mask * excitation).sum().clamp_min(1e-6)

    loss_phys_1s = (mse_1s * v_phys_mask * excitation).sum() / denom
    loss_phys_4h = (mse_4h * v_phys_mask * excitation).sum() / denom
    loss_physics = loss_phys_1s + 0.5 * loss_phys_4h

    # 4) VALUE LOSS (RTG regression over valid timesteps)
    val_mse = ((return_preds - target_rtg) ** 2).sum(dim=-1)
    loss_value = (val_mse * w_bt * time_mask).sum() / time_mask.sum().clamp_min(1e-6)
    loss_value = loss_value * 500.0  # NOTE: fixed rescale applied on top of config.w_value

    total = (config.w_action * loss_action) + \
            (config.w_physics * loss_physics) + \
            (config.w_value * loss_value)
    with torch.no_grad():
        acc = ((torch.argmax(flat_logits, -1) == flat_targets).float() * flat_mask).sum() / flat_mask.sum().clamp_min(1e-6)
        # Occasional console breadcrumb (~0.1% of calls); .item() makes the
        # tensor-vs-float comparison explicit.
        if torch.rand(()).item() < 0.001:
            print(f"[Loss Debug] Action: {loss_action.item():.3f} | Phys: {loss_physics.item():.3f} | Val: {loss_value.item():.3f}")

    metrics = {
        "loss_action": loss_action.detach(),
        "loss_physics": loss_physics.detach(),
        "loss_value": loss_value.detach(),
        "accuracy": acc.detach(),
        "total_loss": total.detach(),
    }
    return total, metrics
training/training.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #train.py
2
+
3
+ import os
4
+ import time
5
+ import math
6
+ import glob
7
+ import json
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+ import traceback
13
+ import matplotlib.pyplot as plt
14
+ from collections import Counter
15
+
16
+ # --- New Modules ---
17
+ import dataloader as dl
18
+ from embeddings import GeneralistComfortDT
19
+ from losses import compute_generalist_loss, GeneralistLossConfig
20
+ import plots
21
+
22
+ # ============================================================
23
+ # CONFIGURATION
24
+ # ============================================================
25
# Paths
DATA_DIR = "TrajectoryData_from_docker"  # root folder globbed for episode .npz files
RUNS_DIR = "training-runs"               # per-run outputs (checkpoints, plots, metrics)

# Architecture
VOCAB_SIZE = 512
D_MODEL = 256
N_LAYERS = 6
N_HEADS = 8
DROPOUT = 0.1
MAX_ZONES = 32

# Training
BATCH_SIZE = 16
EPOCHS = 50
LR = 3e-4            # peak LR: linear warmup then cosine decay to MIN_LR (5e-5, set in main)
WARMUP_STEPS = 1000  # linear warmup steps before the cosine schedule applies
WEIGHT_DECAY = 1e-2
GRAD_CLIP = 1.0      # global grad-norm clip applied after scaler.unscale_

MAX_TOKENS_PER_STEP = 64  # max feature tokens per timestep (passed to the dataset)
CONTEXT_LEN = 48
CONTEXT_DIM = 10
RTG_DIM = 2  # Energy + Comfort

# Loss Weights
W_ACTION = 1.0
W_PHYSICS = 1.0
W_VALUE = 1.0

# Generalist Stitching Config
USE_TOPK = True   # NOTE(review): not referenced anywhere else in this file — confirm it's still needed
TOPK_FRACTION = 1.0
TOPK_MODE = "filter"
TOPK_ON = "pareto"
RTG_SCALE = 1.0   # multiplier applied to batch["rtg"] before it enters the model

# Robustness
RTG_DROPOUT_PROB = 0.2  # forwarded to the model's rtg_dropout_prob during training

SEED = 42
NUM_WORKERS = 12  # DataLoader worker processes
66
+
67
+ # ============================================================
68
+ # UTILITIES
69
+ # ============================================================
70
def set_seed(s):
    """Seed torch (CPU and all CUDA devices) and NumPy for reproducibility."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(s)
74
+
75
def list_episode_npzs(data_dir: str):
    """Return sorted episode .npz paths under *data_dir*, skipping cache/stats files.

    Bug fix: the original ignored the ``data_dir`` parameter and always globbed
    the module-level ``DATA_DIR``, so callers could not point it elsewhere.

    Args:
        data_dir: root directory containing ``TrajectoryData_officesmall``.

    Returns:
        Sorted list of paths matching ``traj_ep*_seed*.npz`` (recursively),
        excluding any path containing "norm_stats" or "cache".
    """
    pattern = os.path.join(data_dir, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz")
    paths = sorted(glob.glob(pattern, recursive=True))
    return [p for p in paths if "norm_stats" not in p and "cache" not in p]
79
+
80
def load_checkpoint_if_available(run_dir, model, opt, scaler, device):
    """Resume from <run_dir>/last.pt if it exists.

    Restores model/optimizer/scaler state in place and returns
    (start_epoch, global_step); returns (1, 0) for a fresh run.
    """
    last_path = os.path.join(run_dir, "last.pt")
    if not os.path.exists(last_path):
        return 1, 0  # no checkpoint -> start from scratch
    ckpt = torch.load(last_path, map_location=device)
    for target, key in ((model, "model"), (opt, "opt"), (scaler, "scaler")):
        target.load_state_dict(ckpt[key])
    start_epoch = int(ckpt.get("epoch", 0)) + 1
    global_step = int(ckpt.get("global_step", 0))
    print(f"[Resume] Loaded {last_path} | start_epoch={start_epoch} global_step={global_step}")
    return start_epoch, global_step
92
+
93
+ def save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, name):
94
+ ckpt = {
95
+ "epoch": epoch,
96
+ "global_step": global_step,
97
+ "model": model.state_dict(),
98
+ "opt": opt.state_dict(),
99
+ "scaler": scaler.state_dict(),
100
+ }
101
+ torch.save(ckpt, os.path.join(run_dir, name))
102
+
103
def get_run_dir(runs_dir=None):
    """Create and return the next sequential run directory (run_001, run_002, ...).

    Generalized: the base directory is now a parameter; the default (None)
    keeps the original behavior of using the module-level RUNS_DIR.

    Args:
        runs_dir: base directory for runs; defaults to RUNS_DIR.

    Returns:
        Path to the new run directory (with its "plots" subdirectory created).

    Note: numbering counts existing ``run_*`` entries, so deleting old runs
    can reuse a number.
    """
    base = RUNS_DIR if runs_dir is None else runs_dir
    os.makedirs(base, exist_ok=True)
    existing = len(glob.glob(os.path.join(base, "run_*")))
    path = os.path.join(base, f"run_{existing+1:03d}")
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "plots"), exist_ok=True)
    return path
110
+
111
+ def _atomic_write_json(path, obj):
112
+ tmp = path + ".tmp"
113
+ with open(tmp, "w") as f:
114
+ json.dump(obj, f, indent=2)
115
+ os.replace(tmp, path)
116
+
117
+ # ============================================================
118
+ # MAIN LOOP
119
+ # ============================================================
120
def main():
    """Entry point: end-to-end training of the generalist comfort DT.

    Builds the dataset/dataloader, constructs the model, then trains with
    AMP under a warmup + cosine LR schedule, logging per-step metrics to
    CSV/JSON, checkpointing every epoch, and dumping per-epoch debug
    distributions for plotting. Any exception is recorded to crash.json
    and re-raised.
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Allow TF32 matmuls/convs for speed on supporting GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    run_dir = get_run_dir()
    os.makedirs(os.path.join(run_dir, "plot_data"), exist_ok=True)

    report_path = os.path.join(run_dir, "report.json")
    metrics_csv = os.path.join(run_dir, "metrics.csv")

    # Per-step and per-epoch metric histories (handed to plots.save_plot_arrays).
    hist = {"step": [], "loss": [], "acc": [], "phy": [], "val": [], "lr": [], "grad_norm": [], "loss_action": []}
    epoch_hist = {"epoch": [], "loss_mean": [], "acc_mean": [], "phy_mean": [], "val_mean": []}

    report = {
        "run_dir": run_dir,
        "started_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "config": {
            "DATA_DIR": DATA_DIR, "MAX_TOKENS": MAX_TOKENS_PER_STEP,
            "BATCH_SIZE": BATCH_SIZE, "LR": LR, "SEED": SEED
        },
        "status": "running",
        # NOTE(review): "progress" is never updated during the loop below — confirm.
        "progress": {"epoch": 0, "global_step": 0},
    }
    _atomic_write_json(report_path, report)

    try:
        print(f"Loading data from {DATA_DIR}...")
        all_paths = list_episode_npzs(DATA_DIR)
        if not all_paths: raise RuntimeError(f"No valid npz files found in {DATA_DIR}")

        train_ds = dl.GeneralistDataset(
            all_paths, seed=SEED,
            max_tokens=MAX_TOKENS_PER_STEP,
            topk_frac=TOPK_FRACTION,
            topk_mode=TOPK_MODE,
            topk_on=TOPK_ON
        )
        train_ds.is_train = True
        # NOTE(review): pin_memory_device="cuda" will fail on CPU-only hosts and is
        # deprecated in recent torch — confirm the intended environment.
        train_loader = DataLoader(
            train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
            pin_memory=True, pin_memory_device="cuda", persistent_workers=True,
            prefetch_factor=4, collate_fn=dl.generalist_collate_fn, drop_last=True
        )

        # Architecture hyperparameters handed to the model constructor
        # (also persisted to model_config.json below for reproducibility).
        model_config = {
            "VOCAB_SIZE": VOCAB_SIZE, "D_MODEL": D_MODEL,
            "N_LAYERS": N_LAYERS, "N_HEADS": N_HEADS,
            "DROPOUT": DROPOUT, "MAX_ZONES": MAX_ZONES,
            "CONTEXT_LEN": CONTEXT_LEN,
            "NUM_ACTION_BINS": dl.NUM_ACTION_BINS,
            "CONTEXT_DIM": CONTEXT_DIM,
            "RTG_DIM": RTG_DIM
        }

        model = GeneralistComfortDT(model_config).to(device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"\n{'='*40}\nModel Params: {total_params:,}\n{'='*40}\n")

        opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scaler = torch.amp.GradScaler("cuda")
        # Resume if run_dir already holds a last.pt (fresh runs get (1, 0)).
        start_epoch, global_step = load_checkpoint_if_available(run_dir, model, opt, scaler, device)

        loss_cfg = GeneralistLossConfig(
            w_action=W_ACTION,
            w_physics=W_PHYSICS,
            w_value=W_VALUE,
            use_rtg_weighting=True,
            rtg_weight_mode="exp",
            rtg_weight_beta=2.0
        )

        _atomic_write_json(os.path.join(run_dir, "model_config.json"), model_config)

        total_steps = len(train_loader) * EPOCHS
        print(f"Starting Training | Steps: {total_steps}")

        csv_header = ["timestamp", "epoch", "step", "loss", "loss_action", "accuracy", "loss_physics", "loss_value", "lr", "grad_norm"]
        csv_buffer = []

        def flush_csv():
            # Append buffered metric rows to metrics.csv (header written once).
            nonlocal csv_buffer
            if not csv_buffer: return
            write_header = not os.path.exists(metrics_csv)
            with open(metrics_csv, "a") as f:
                if write_header: f.write(",".join(csv_header) + "\n")
                for row in csv_buffer:
                    f.write(",".join(str(row.get(k, "")) for k in csv_header) + "\n")
            csv_buffer = []

        for epoch in range(start_epoch, EPOCHS + 1):
            model.train()
            train_ds.set_epoch(epoch)
            pbar = tqdm(train_loader, desc=f"Ep {epoch}", dynamic_ncols=True)
            stats = {"loss": [], "acc": [], "phy": [], "val": []}

            for batch in pbar:
                # 1. LR schedule: cosine decay from LR to MIN_LR over total_steps...
                MIN_LR = 5e-5
                curr_lr = MIN_LR + 0.5 * (LR - MIN_LR) * (1 + math.cos(math.pi * global_step / total_steps))

                # ...overridden by linear warmup for the first WARMUP_STEPS steps.
                if global_step < WARMUP_STEPS:
                    curr_lr = LR * (global_step / WARMUP_STEPS)

                for pg in opt.param_groups:
                    pg['lr'] = curr_lr

                b_gpu = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

                # 2. RTG prep — rtg is [B, T, 2] (Energy, Comfort), optionally rescaled.
                rtg_input = b_gpu["rtg"] * RTG_SCALE

                with torch.amp.autocast("cuda"):
                    out = model(
                        feature_ids=b_gpu["feature_ids"],
                        feature_vals=b_gpu["feature_values"],
                        zone_ids=b_gpu["zone_ids"],
                        attn_mask=b_gpu["attention_mask"],
                        rtg=rtg_input,
                        context=b_gpu["context"],
                        rtg_dropout_prob=RTG_DROPOUT_PROB
                    )

                # 3. Loss
                loss, metrics = compute_generalist_loss(out, b_gpu, loss_cfg)

                # 4. AMP backward: scale -> backward -> unscale -> clip -> step -> update.
                opt.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(opt)
                if global_step % 500 == 0:
                    print(f"DEBUG: Step {global_step} | Grad Norm: {grad_norm:.4f} | LR: {curr_lr:.2e}")
                scaler.update()

                global_step += 1

                # 5. Logging: fan metrics out into per-step stat lists.
                for k in ["loss_action", "loss_physics", "loss_value", "accuracy", "total_loss"]:
                    val = metrics.get(k, 0.0)
                    if torch.is_tensor(val): val = val.item()

                    if k == "total_loss": stats["loss"].append(val)
                    elif k == "accuracy": stats["acc"].append(val)
                    elif k == "loss_physics": stats["phy"].append(val)
                    elif k == "loss_value": stats["val"].append(val)
                    elif k == "loss_action":
                        hist["loss_action"].append(val)


                hist["step"].append(global_step)
                hist["loss"].append(stats["loss"][-1])
                hist["acc"].append(stats["acc"][-1])
                hist["phy"].append(stats["phy"][-1])
                hist["val"].append(stats["val"][-1])
                hist["lr"].append(curr_lr)
                hist["grad_norm"].append(float(grad_norm.item()) if torch.is_tensor(grad_norm) else grad_norm)

                csv_buffer.append({
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "epoch": epoch, "step": global_step,
                    "loss": stats["loss"][-1],
                    # loss_action may be a 0-dim tensor; unwrap before CSV stringification.
                    "loss_action": metrics.get("loss_action", 0.0).item() if torch.is_tensor(metrics.get("loss_action", 0.0)) else metrics.get("loss_action", 0.0),
                    "accuracy": stats["acc"][-1],
                    "loss_physics": stats["phy"][-1], "loss_value": stats["val"][-1],
                    "lr": float(curr_lr), "grad_norm": hist["grad_norm"][-1]
                })

                if global_step % 50 == 0: flush_csv()
                if global_step % 20 == 0:
                    pbar.set_postfix(
                        act=f"{metrics.get('loss_action', 0):.2f}",  # Action CE
                        phy=f"{np.mean(stats['phy'][-20:]):.4f}",  # Physics Delta MSE
                        val=f"{np.mean(stats['val'][-20:]):.2f}",  # Rescaled Value MSE
                        acc=f"{np.mean(stats['acc'][-20:]):.2f}"
                    )
            # --- End-of-epoch debug snapshot: one batch, no grad, RTG dropout off ---
            model.eval()
            with torch.no_grad():
                # NOTE(review): both branches are identical, and iter(train_loader)
                # spawns a fresh iterator each epoch (costly with persistent workers)
                # — confirm this is intentional.
                try:
                    debug_batch = next(iter(train_loader))
                except StopIteration:
                    debug_batch = next(iter(train_loader))

                b_debug = {k: v.to(device) for k, v in debug_batch.items()}
                rtg_input_debug = b_debug["rtg"] * RTG_SCALE

                # Forward pass with deterministic RTG conditioning.
                out_debug = model(
                    feature_ids=b_debug["feature_ids"],
                    feature_vals=b_debug["feature_values"],
                    zone_ids=b_debug["zone_ids"],
                    attn_mask=b_debug["attention_mask"],
                    rtg=rtg_input_debug,
                    context=b_debug["context"],
                    rtg_dropout_prob=0.0
                )

                # Collect predicted vs target action bins and RTGs under their masks.
                logits = out_debug["action_logits"]
                pred_bins = torch.argmax(logits, dim=-1).cpu().numpy()
                target_bins = b_debug["target_action_tokens"].cpu().numpy()

                # time_mask selects valid timesteps [B, T]; target_mask selects
                # action tokens [B, T, K].
                t_mask = b_debug["time_mask"].cpu().numpy().astype(bool)  # [B, T]
                a_mask = b_debug["target_mask"].cpu().numpy().astype(bool)  # [B, T, K]
                valid_preds = pred_bins[a_mask]
                valid_targets = target_bins[a_mask]
                target_rtg_raw = b_debug["rtg"].cpu().numpy()
                pred_rtg_raw = out_debug["return_preds"].cpu().numpy()

                valid_target_rtg = target_rtg_raw[t_mask]
                valid_pred_rtg = pred_rtg_raw[t_mask]


                # Overwritten every epoch; consumed by the plotting module.
                np.savez_compressed(
                    os.path.join(run_dir, "plot_data", "distributions.npz"),
                    target_actions=valid_targets,
                    pred_actions=valid_preds,
                    target_rtg=valid_target_rtg,
                    pred_rtg=valid_pred_rtg
                )
            # ====================================

            flush_csv()
            save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, "last.pt")
            if epoch % 5 == 0:
                save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, f"ckpt_{epoch}.pt")

            epoch_hist["epoch"].append(epoch)
            epoch_hist["loss_mean"].append(np.mean(stats["loss"]))
            epoch_hist["acc_mean"].append(np.mean(stats["acc"]))
            epoch_hist["phy_mean"].append(np.mean(stats["phy"]))
            epoch_hist["val_mean"].append(np.mean(stats["val"]))

            # Plotting is best-effort; a failure must not kill the run.
            try:
                plots.save_plot_arrays(run_dir, hist, epoch_hist)
                plots.make_plots(run_dir)
            except Exception as e:
                print(f"Plotting failed: {e}")

        report["status"] = "complete"
        _atomic_write_json(report_path, report)
        print("Training Complete.")

    except Exception as e:
        # Persist the failure for post-mortem, then propagate.
        _atomic_write_json(os.path.join(run_dir, "crash.json"), {"error": str(e), "traceback": traceback.format_exc()})
        raise
373
+
374
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()