File size: 23,705 Bytes
77b2fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
import os
import sys
import time

# Immediate feedback
print(" [Init] Python process started. Loading libraries...")

print(" [Init] Loading Pytorch...", end="", flush=True)
import torch
import torch as th
import torch.nn.functional as F

print(" Done.")

print(" [Init] Loading Gymnasium & SB3...", end="", flush=True)
import glob
import warnings

import numpy as np
from gymnasium import spaces
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.utils import explained_variance
from tqdm import tqdm

print(" Done.")

# Filter Numba warning
warnings.filterwarnings("ignore", category=RuntimeWarning, message="nopython is set for njit")

# Ensure project root is in path
sys.path.append(os.getcwd())

print(" [Init] Loading LovecaSim Vector Engine...", end="", flush=True)
from ai.environments.vec_env_adapter import VectorEnvAdapter
from ai.utils.loveca_features_extractor import LovecaFeaturesExtractor

print(" Done.")


class TimeCheckpointCallback(BaseCallback):
    """
    Save the model every N minutes of wall-clock time.

    Every save goes to the same path, ``<save_path>/<name_prefix>_time_auto``,
    so each checkpoint overwrites the previous one.

    Args:
        save_freq_minutes: Minimum wall-clock interval between saves, in minutes.
        save_path: Directory the checkpoint is written into.
        name_prefix: Prefix of the checkpoint file name.
        verbose: When > 0, print a line after every save.
    """

    def __init__(self, save_freq_minutes: float, save_path: str, name_prefix: str, verbose: int = 0):
        super().__init__(verbose)
        self.save_freq_seconds = save_freq_minutes * 60
        self.save_path = save_path
        self.name_prefix = name_prefix
        # Timestamp of the last save; the first save happens one interval after construction.
        self.last_time_save = time.time()

    def _on_step(self) -> bool:
        if (time.time() - self.last_time_save) > self.save_freq_seconds:
            save_path = os.path.join(self.save_path, f"{self.name_prefix}_time_auto")
            self.model.save(save_path)
            if self.verbose > 0:
                # Report the configured interval instead of a hard-coded value.
                interval_minutes = self.save_freq_seconds / 60
                print(f" [Save] Model auto-saved after {interval_minutes:g} minutes to {save_path}")
            self.last_time_save = time.time()
        return True


class ModelSnapshotCallback(BaseCallback):
    """
    Periodically save a self-contained 'model snapshot' directory.

    Every ``save_freq_minutes`` of wall-clock time, a sub-directory named
    ``<YYYYmmdd_HHMMSS>_<timesteps>_steps`` is created under ``save_path``.
    It contains:

    - model.zip
    - verified_card_pool.json (copied from the working directory, best effort)
    - snapshot_meta.json (timesteps, observation/action sizes, feature tags)

    After each snapshot, only the newest ``max_keep`` sub-directories of
    ``save_path`` are kept. Note that pruning considers *every*
    sub-directory of ``save_path``, not only those created by this callback.

    Args:
        save_freq_minutes: Minimum wall-clock interval between snapshots.
        save_path: Root directory that holds the snapshot sub-directories.
        verbose: When > 0, print progress messages.
        max_keep: Number of newest snapshot directories to retain (at least 1
            is always kept so the snapshot just written survives).
    """

    def __init__(self, save_freq_minutes: float, save_path: str, verbose=0, max_keep: int = 5):
        super().__init__(verbose)
        self.save_freq_minutes = save_freq_minutes
        self.save_path = save_path
        self.max_keep = max_keep
        self.last_save_time = time.time()
        # Create the snapshot root up front (previously hard-coded to
        # "historiccheckpoints" regardless of save_path).
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if time.time() - self.last_save_time > self.save_freq_minutes * 60:
            self.last_save_time = time.time()
            self._save_snapshot()
        return True

    def _save_snapshot(self):
        """Write model, card pool and metadata into a new timestamped directory, then prune."""
        import json
        import shutil

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        steps = self.num_timesteps
        snapshot_name = f"{timestamp}_{steps}_steps"
        snapshot_dir = os.path.join(self.save_path, snapshot_name)

        if self.verbose > 0:
            print(f" [Snapshot] Saving to {snapshot_dir}...")

        os.makedirs(snapshot_dir, exist_ok=True)

        # 1. Save Model
        model_path = os.path.join(snapshot_dir, "model.zip")
        self.model.save(model_path)

        # 2. Save Card Pool (Context) - best effort; a missing file must not stop training.
        try:
            shutil.copy("verified_card_pool.json", os.path.join(snapshot_dir, "verified_card_pool.json"))
        except OSError as e:
            print(f" [Snapshot] Warning: Could not copy card pool: {e}")

        # 3. Save Metadata (Architecture)
        # action_space.n assumes a Discrete action space (as MaskablePPO is used here).
        meta = {
            "timestamp": timestamp,
            "timesteps": int(steps),
            "obs_dim": int(self.model.observation_space.shape[0]),
            "action_space_size": int(self.model.action_space.n),
            "features": ["GlobalVolumes", "LiveZone", "Traits", "TurnNumber"],
            "notes": "Generated by ModelSnapshotCallback",
        }
        try:
            with open(os.path.join(snapshot_dir, "snapshot_meta.json"), "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            print(f" [Snapshot] Warning: Could not save meta: {e}")

        # 4. Limit the number of retained snapshots.
        self._prune_snapshots()

    def _prune_snapshots(self):
        """Delete all but the newest ``max_keep`` sub-directories of ``save_path``."""
        import shutil

        search_dir = self.save_path
        if not os.path.isdir(search_dir):
            return

        try:
            subdirs = [
                os.path.join(search_dir, d)
                for d in os.listdir(search_dir)
                if os.path.isdir(os.path.join(search_dir, d))
            ]
            # Oldest first. Snapshot directories are written once and never
            # touched again, so their ctime reflects when they were saved.
            subdirs.sort(key=os.path.getctime)
        except OSError as e:
            print(f" [Snapshot] Warning: Pruning failed: {e}")
            return

        # Always keep at least one directory (the snapshot just written).
        keep = max(int(self.max_keep), 1)
        excess = len(subdirs) - keep
        if excess <= 0:
            return

        for d in subdirs[:excess]:
            try:
                shutil.rmtree(d)
                if self.verbose > 0:
                    print(f" [Snapshot] Pruned old snapshot: {d}")
            except OSError as e:
                print(f" [Snapshot] Warning: Failed to prune {d}: {e}")


class DetailedStatusCallback(BaseCallback):
    """
    Announce which training phase is running and report phase statistics.

    At the start of each rollout it prints the rollout length. At the end of
    each rollout (just before the policy update) it prints:

    - how long environment-step collection took, and the resulting steps/second;
    - the PPO optimization settings (epochs and batch size);
    - CUDA allocated/reserved memory in GB, when a CUDA device is available.

    Output always goes to stdout, regardless of ``verbose``.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Wall-clock time at which the current rollout began.
        self.collection_start_time = 0.0

    def _on_rollout_start(self) -> None:
        """Record when environment-step collection begins and announce it."""
        self.collection_start_time = time.time()
        rollout_len = self.model.n_steps
        print(f"\n [Phase] Starting Rollout Collection (Steps: {rollout_len})...")

    def _on_rollout_end(self) -> None:
        """Report collection throughput; runs right before the policy is updated."""
        elapsed = time.time() - self.collection_start_time
        collected = self.model.n_envs * self.model.n_steps
        throughput = collected / elapsed if elapsed > 0 else 0
        print(f" [Phase] Collection Complete. Duration: {elapsed:.2f}s ({throughput:.0f} FPS)")

        # The PPO optimization phase starts as soon as this hook returns.
        n_epochs, batch_size = self.model.n_epochs, self.model.batch_size
        print(f" [Phase] Starting PPO Optimization (Epochs: {n_epochs}, Batch: {batch_size})...")

        if not torch.cuda.is_available():
            return

        gib = 1024**3
        allocated_gb = torch.cuda.memory_allocated() / gib
        reserved_gb = torch.cuda.memory_reserved() / gib
        print(f" [VRAM] Allocated: {allocated_gb:.2f} GB | Reserved: {reserved_gb:.2f} GB")
        print(" [Info] Optimization may take time if batch size is large. Please wait...")

    def _on_step(self) -> bool:
        # No per-step work; returning True lets training continue.
        return True


class TrainingStatsCallback(BaseCallback):
    """
    Log mean episode reward and length for episodes reported on the current step.

    On every step it inspects the ``infos`` list of the rollout. Entries that
    carry an ``"episode"`` dict (with ``"r"`` reward and ``"l"`` length keys)
    are averaged and recorded as ``rollout/ep_rew_mean`` and
    ``rollout/ep_len_mean``.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        step_infos = self.locals.get("infos")
        if not step_infos:
            return True

        # NOTE(review): presumably the vector env only attaches 'episode' to
        # infos of environments whose episode just ended - confirm in the adapter.
        finished = [info.get("episode") for info in step_infos if "episode" in info]
        if not finished:
            return True

        mean_reward = np.mean([ep["r"] for ep in finished])
        mean_length = np.mean([ep["l"] for ep in finished])
        self.logger.record("rollout/ep_rew_mean", mean_reward)
        self.logger.record("rollout/ep_len_mean", mean_length)
        return True


class ProgressMaskablePPO(MaskablePPO):
    """
    MaskablePPO with a tqdm progress bar and CUDA mixed precision during optimization.

    ``train`` mirrors the upstream MaskablePPO update loop, with these changes:

    - the forward pass and losses run under CUDA autocast;
    - the loss is scaled with a ``GradScaler`` before backward;
    - a persistent tqdm bar advances once per minibatch.

    The bar is excluded from saved checkpoints.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.optimization_pbar = None

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        # AMP is only meaningful on CUDA, and GradScaler only accepts CUDA
        # tensors. Key it on the device the policy actually runs on, not on
        # whether some GPU happens to be present.
        use_amp = self.device.type == "cuda"
        # Created lazily so checkpoints saved without a scaler still load. It is
        # rebuilt if a restored scaler's enabled-state disagrees with the device.
        scaler = getattr(self, "scaler", None)
        if scaler is None or scaler.is_enabled() != use_amp:
            self.scaler = th.cuda.amp.GradScaler(enabled=use_amp)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # Persistent progress bar over all minibatches of all epochs.
        # rollout_buffer.get() walks buffer_size * n_envs samples in steps of
        # batch_size, i.e. ceil(samples / batch_size) minibatches per epoch.
        # (buffer_size alone is the per-env step count.)
        n_samples = self.rollout_buffer.buffer_size * self.rollout_buffer.n_envs
        batches_per_epoch = -(-n_samples // self.batch_size)
        total_steps = self.n_epochs * batches_per_epoch

        if self.optimization_pbar is None:
            self.optimization_pbar = tqdm(total=total_steps, desc="Optimization", unit="batch", leave=True)
        else:
            self.optimization_pbar.reset(total=total_steps)

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                with th.cuda.amp.autocast(enabled=use_amp):
                    values, log_prob, entropy = self.policy.evaluate_actions(
                        rollout_data.observations,
                        actions,
                        action_masks=rollout_data.action_masks,
                    )

                    values = values.flatten()
                    # Normalize advantage. A single-sample minibatch has an
                    # undefined std, so it is left as-is (matches upstream SB3).
                    advantages = rollout_data.advantages
                    if self.normalize_advantage and len(advantages) > 1:
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                    # ratio between old and new policy, should be one at the first iteration
                    ratio = th.exp(log_prob - rollout_data.old_log_prob)

                    # clipped surrogate loss
                    policy_loss_1 = advantages * ratio
                    policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                    # Logging
                    pg_losses.append(policy_loss.item())
                    clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                    clip_fractions.append(clip_fraction)

                    if self.clip_range_vf is None:
                        # No clipping
                        values_pred = values
                    else:
                        # Clip the difference between old and new value
                        # NOTE: this depends on the reward scaling
                        values_pred = rollout_data.old_values + th.clamp(
                            values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                        )
                    # Value loss using the TD(gae_lambda) target
                    value_loss = F.mse_loss(rollout_data.returns, values_pred)
                    value_losses.append(value_loss.item())

                    # Entropy loss favor exploration
                    if entropy is None:
                        # Approximate entropy when no analytical form
                        entropy_loss = -th.mean(-log_prob)
                    else:
                        entropy_loss = -th.mean(entropy)

                    entropy_losses.append(entropy_loss.item())

                    loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()

                # Backward pass on the (possibly) scaled loss
                self.scaler.scale(loss).backward()

                # Gradients must be unscaled before clipping their norm
                self.scaler.unscale_(self.policy.optimizer)
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)

                # Optimizer step (skipped by the scaler if grads contain inf/NaN)
                self.scaler.step(self.policy.optimizer)
                self.scaler.update()

                # Update Progress Bar
                self.optimization_pbar.update(1)

            if not continue_training:
                break

        # The bar is intentionally not closed; it is reset on the next call.

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _excluded_save_params(self) -> list[str]:
        """
        Return the attribute names that must not be saved with the model.

        The tqdm bar holds terminal state and is recreated on demand.
        """
        return super()._excluded_save_params() + ["optimization_pbar"]


def main():
    """
    Train a ProgressMaskablePPO agent on the Numba-backed LovecaSim vector env.

    Steps:

    1. Read configuration from environment variables.
    2. Build the vector environment and warm it up, forcing Numba compilation.
    3. Resume from a checkpoint when one is requested and its observation size
       matches the environment; otherwise start a fresh model.
    4. Run ``model.learn`` with periodic checkpoints and snapshots.

    Ctrl+C and unexpected exceptions both trigger a final save.

    Environment variables read here:
        TRAIN_STEPS, TRAIN_BATCH_SIZE, TRAIN_ENVS, TRAIN_N_STEPS,
        ENT_COEF, GAMMA, GAE_LAMBDA, LEARNING_RATE, NUM_EPOCHS,
        LOAD_CHECKPOINT (a path, or "LATEST"/"AUTO" for the newest ``*.zip``
        in the save directory), RESTART_TRAINING, OBS_MODE, SAVE_FREQ_MINS.
    """
    print("========================================================")
    print(" LovecaSim - STARTING VECTORIZED TRAINING (700k+ SPS)   ")
    print("========================================================")

    # Configuration from Environment Variables
    TOTAL_TIMESTEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "8192"))
    NUM_ENVS = int(os.getenv("TRAIN_ENVS", "4096"))
    N_STEPS = int(os.getenv("TRAIN_N_STEPS", "256"))

    # Advanced Hyperparameters. These are parsed once, up front, so a malformed
    # value fails before the slow Numba compilation instead of after it.
    ENT_COEF = float(os.getenv("ENT_COEF", "0.01"))
    GAMMA = float(os.getenv("GAMMA", "0.99"))
    GAE_LAMBDA = float(os.getenv("GAE_LAMBDA", "0.95"))
    LEARNING_RATE = float(os.getenv("LEARNING_RATE", "3e-4"))
    NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "4"))
    OBS_MODE = os.getenv("OBS_MODE", "STANDARD")
    SAVE_FREQ_MINS = float(os.getenv("SAVE_FREQ_MINS", "15.0"))

    SAVE_PATH = "./checkpoints/vector/"
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Log Hardware/Threading Config
    omp_threads = os.getenv("OMP_NUM_THREADS", "Unset (All Cores)")
    print(f" [Config] Batch Size: {BATCH_SIZE}")
    print(f" [Config] Num Envs:   {NUM_ENVS}")
    print(f" [Config] N Steps:    {N_STEPS}")
    print(f" [Config] CPU Cores:  {omp_threads}")

    # 1. Create Vector Environment (Numba)
    print(f" [Init] Creating {NUM_ENVS} parallel Numba environments...")
    env = VectorEnvAdapter(num_envs=NUM_ENVS)

    # --- WARMUP / COMPILATION ---
    print(" [Init] Compiling Numba functions (Reset)... This may take 30s+")
    env.reset()
    print(" [Init] Compiling Numba functions (Step)... This may take 60s+")
    # Perform a dummy step to force compilation of the massive step kernel
    dummy_actions = np.zeros(NUM_ENVS, dtype=np.int32)
    env.step(dummy_actions)
    print(" [Init] Compilation complete! Starting training...")
    # ----------------------------

    # 2. Setup or Load PPO Agent
    checkpoint_path = os.getenv("LOAD_CHECKPOINT", "")

    # Auto-resolve "LATEST" or "AUTO"
    force_restart = os.getenv("RESTART_TRAINING", "FALSE").upper() == "TRUE"
    if force_restart:
        print(" [Config] RESTART_TRAINING=TRUE. Ignoring checkpoints.")
        checkpoint_path = ""
    elif checkpoint_path.upper() in ("LATEST", "AUTO"):
        requested = checkpoint_path
        list_of_files = glob.glob(os.path.join(SAVE_PATH, "*.zip"))
        if list_of_files:
            # Use mtime, not ctime. On Windows ctime is the *creation* time, so
            # a checkpoint overwritten in place (final_model, crash_emergency,
            # interrupted_model) would keep its old timestamp and never win.
            checkpoint_path = max(list_of_files, key=os.path.getmtime)
            print(f" [Config] LOAD_CHECKPOINT='{requested}' -> Auto-resolved to: {checkpoint_path}")
        else:
            print(f" [Config] LOAD_CHECKPOINT='{requested}' but no checkpoints found. Starting fresh.")
            checkpoint_path = ""

    model = None
    reset_num_timesteps = True
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f" [Load] Scanning checkpoint: {checkpoint_path}")
        try:
            # Load once on CPU just to inspect the stored observation size.
            temp_model = ProgressMaskablePPO.load(checkpoint_path, device="cpu")
            model_obs_dim = temp_model.observation_space.shape[0]
            # Release the inspection copy before loading onto the training device.
            del temp_model
            env_obs_dim = env.observation_space.shape[0]

            if model_obs_dim != env_obs_dim:
                print(f" [Load] Dimension Mismatch! Model: {model_obs_dim}, Env: {env_obs_dim}")
                print(f" [Load] Cannot resume training across eras. Starting FRESH {env_obs_dim}-dim model.")
                model = None
            else:
                print(f" [Load] Dimensions match ({model_obs_dim}). Resuming training...")
                model = ProgressMaskablePPO.load(
                    checkpoint_path,
                    env=env,
                    device="cuda" if torch.cuda.is_available() else "cpu",
                    custom_objects={
                        "learning_rate": LEARNING_RATE,
                        "batch_size": BATCH_SIZE,
                        "n_epochs": NUM_EPOCHS,
                    },
                )
                reset_num_timesteps = False
                print(" [Load] Success.")
        except Exception as e:
            print(f" [Error] Failed to load checkpoint: {e}")
            print(" [Init] Falling back to fresh model...")
            model = None
            reset_num_timesteps = True

    if model is None:
        if checkpoint_path and not os.path.exists(checkpoint_path):
            print(f" [Warning] Checkpoint file not found: {checkpoint_path}")

        print(" [Init] Creating fresh ProgressMaskablePPO model...")

        # Determine Policy Args
        if OBS_MODE == "ATTENTION":
            print(" [Init] Using LovecaFeaturesExtractor (Attention)")
            policy_kwargs = dict(
                features_extractor_class=LovecaFeaturesExtractor,
                features_extractor_kwargs=dict(features_dim=256),
                net_arch=[],
            )
        else:
            policy_kwargs = dict(net_arch=[512, 512])

        model = ProgressMaskablePPO(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=LEARNING_RATE,
            n_steps=N_STEPS,
            batch_size=BATCH_SIZE,
            n_epochs=NUM_EPOCHS,
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            ent_coef=ENT_COEF,
            tensorboard_log="./logs/vector_tensorboard/",
            policy_kwargs=policy_kwargs,
        )
        reset_num_timesteps = True

    print(f" [Init] PPO Model initialized. Device: {model.device}")

    # 3. Callbacks
    # Standard step-count checkpoint (kept for compatibility/safety):
    # roughly every 1M total timesteps across all envs.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 1000000 // NUM_ENVS), save_path=SAVE_PATH, name_prefix="numba_ppo"
    )

    # Snapshot Callback (Replaces TimeCheckpointCallback)
    snapshot_callback = ModelSnapshotCallback(
        save_freq_minutes=SAVE_FREQ_MINS,
        save_path="historiccheckpoints",
        verbose=1,
    )

    # 4. Train
    print(" [Train] Starting training loop...")
    print(f" [Train] Model Mode:   {OBS_MODE}")
    print(f" [Train] Reset Timesteps: {reset_num_timesteps}")
    print(" [Note] Press Ctrl+C to stop and force-save.")

    # Generate a timestamped run name for TensorBoard
    run_name = f"ProgressPPO_{time.strftime('%m%d_%H%M%S')}"
    if not reset_num_timesteps:
        run_name += "_RESUME"

    try:
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS,
            callback=[
                checkpoint_callback,
                snapshot_callback,
                TrainingStatsCallback(),
                DetailedStatusCallback(),
            ],
            progress_bar=True,
            reset_num_timesteps=reset_num_timesteps,
            tb_log_name=run_name,
        )
        print(" [Train] Training finished.")
        model.save(os.path.join(SAVE_PATH, "final_model"))
    except KeyboardInterrupt:
        print("\n [Train] Interrupted by user. Saving model...")
        model.save(os.path.join(SAVE_PATH, "interrupted_model"))
        # Explicit snapshot on interrupt. This is only possible once learn() has
        # bound the callback to the model; an early Ctrl+C leaves it unbound.
        if getattr(snapshot_callback, "model", None) is not None:
            snapshot_callback._save_snapshot()
        print(" [Train] Model saved.")
    except Exception as e:
        print(f"\n [Error] Training crashed: {e}")
        import traceback

        traceback.print_exc()
        emergency_save = os.path.join(SAVE_PATH, "crash_emergency")
        model.save(emergency_save)
        print(f" [Save] Crash emergency checkpoint saved to: {emergency_save}")
    finally:
        print(" [Done] Exiting gracefully.")
        env.close()


if __name__ == "__main__":
    main()