#!/usr/bin/env python3
"""
train.py
========
CLI entry point to train the ComplexityAwareRLAgent, generate plots, and run
baseline comparisons.

Pipeline (in order):
  1. MLPerf Tiny benchmark pre-check using a profiler-based router.
  2. Training of the complexity-aware agent (outputs under ``results/``).
  3. Monitoring plots, policy interpretability plot, sample-efficiency summary.
  4. Baseline comparison table.
  5. Optional ablation study (interactive y/n prompt).
  6. Live routing demo on a small spiking CNN, if ``snntorch`` is installed.
"""
import os
import sys

# Add the sibling ``src`` directory (../src relative to this file) to the
# import path. This must run before the project-local imports below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

import torch  # noqa: E402

from profiler import TaskComplexityProfiler  # noqa: E402
from rl_env import ComplexityAwarePIMEnv  # noqa: E402,F401
from rl_agent import ComplexityAwareRLAgent  # noqa: E402,F401
from controller import ComplexityAwarePIMController  # noqa: E402
from training import train_complexity_aware_agent, compute_sample_efficiency  # noqa: E402
from plots import plot_training_behaviour, PolicyInterpreter  # noqa: E402
from baselines import BaselineEvaluator, AblationStudy  # noqa: E402
from benchmarks.mlperf_tiny import MLPerfTinyBenchmark  # noqa: E402

# Output directory passed to training / plotting and used for saved figures.
RESULTS_DIR = "results"


def _confirm(prompt: str) -> bool:
    """Ask a y/n question on stdin; return True only for answers starting with 'y'.

    Returns False when stdin is closed or non-interactive (EOFError), so
    unattended runs skip the optional step instead of crashing after training.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return False
    return answer.strip().lower().startswith("y")


def _build_demo_snn():
    """Return a small two-layer spiking CNN, or None if snntorch is not installed.

    ``forward`` treats dim 1 of its input as time (``x[:, t]``) and feeds each
    frame to a 2-channel Conv2d, so input is expected as (batch, T, 2, H, W).
    The FC layer is sized for 32x32 frames pooled twice to 8x8 (32 * 8 * 8).
    """
    try:
        import snntorch as snn
        from snntorch import surrogate
    except ImportError:
        return None

    # Demo SNN with correct dimensions (32x32 -> pool twice -> 8x8)
    class DemoSNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            sg = surrogate.fast_sigmoid()
            self.conv1 = torch.nn.Conv2d(2, 16, 3, padding=1)
            self.lif1 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool1 = torch.nn.MaxPool2d(2)
            self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
            self.lif2 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool2 = torch.nn.MaxPool2d(2)
            self.fc = torch.nn.Linear(32 * 8 * 8, 10)
            # output=True: this read-out neuron returns (spikes, membrane).
            self.lif3 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True,
                                  output=True)

        def forward(self, x):
            """Run the network over all timesteps; return spikes stacked on dim 0."""
            # NOTE(review): intended to reset the hidden membrane state between
            # calls; confirm init_leaky() does so in the installed snntorch.
            self.lif1.init_leaky()
            self.lif2.init_leaky()
            self.lif3.init_leaky()
            spk_rec = []
            for t in range(x.size(1)):
                frame = x[:, t]
                p = self.pool1(self.lif1(self.conv1(frame)))
                p = self.pool2(self.lif2(self.conv2(p)))
                # lif3 returns a (spikes, membrane) tuple because output=True;
                # keep only the spike tensor so torch.stack receives tensors.
                spk, _mem = self.lif3(self.fc(p.flatten(1)))
                spk_rec.append(spk)
            return torch.stack(spk_rec)

    return DemoSNN()


def _run_live_demo(agent, device):
    """Route the demo SNN through a controller driven by the trained *agent*."""
    print("\n--- Live Routing Demo ---")
    ctrl = ComplexityAwarePIMController(device=device)
    ctrl.agent = agent

    model = _build_demo_snn()
    if model is None:
        print(" snntorch not installed — skipping SNN demo")
        return

    target = ctrl.route_model(model, input_shape=(1, 2, 32, 32), timesteps=10)
    print(f" DemoSNN -> Routed to: {target}")
    stats = ctrl.get_stats()
    print(f" Stats: {stats}")


def main():
    """Run the full pipeline: benchmark, training, plots, baselines, demo."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("=" * 65)
    print(" Complexity-Aware Physics-Based RL PIM Controller v2")
    print(f" Device: {device}")
    print("=" * 65)

    # --- MLPerf Tiny benchmark pre-check ---
    print("\n--- MLPerf Tiny Benchmark ---")
    bench = MLPerfTinyBenchmark(device=device)

    def simple_router(model, input_shape, timesteps):
        """Pick the target with the highest suitability score for *model*."""
        profiler = TaskComplexityProfiler()
        p = profiler.profile(model, input_shape, timesteps)
        scores = profiler.compute_suitability_scores(p)
        return max(scores, key=scores.get)

    bench.run(simple_router, n_runs=10)
    bench.print_report()

    # --- Training ---
    # Create the output directory up front so every later save has a target.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    print("\n--- Training Complexity-Aware Agent ---")
    agent, metrics = train_complexity_aware_agent(
        num_episodes=200, max_steps=300, device=device,
        save_dir=RESULTS_DIR, checkpoint_every=50)

    print("\n--- Generating Monitoring Plots ---")
    profiler = TaskComplexityProfiler()
    plot_training_behaviour(metrics, profiler, save_dir=RESULTS_DIR)

    print("\n--- Policy Interpretability ---")
    interp = PolicyInterpreter(agent)
    interp.plot_interpretability(
        save_path=os.path.join(RESULTS_DIR, "fig5_policy_interpretability.png"))

    # --- Sample Efficiency ---
    se = compute_sample_efficiency(metrics)
    print(f"\nSample Efficiency: {se}")

    # --- Baseline Comparison ---
    print("\n--- Baseline Comparison ---")
    evaluator = BaselineEvaluator(num_eval_episodes=20, max_steps=100)
    results = evaluator.evaluate_all(agent)
    evaluator.print_comparison_table(results)

    # --- Ablation Study (optional, slow) ---
    if _confirm("\nRun ablation study? (y/n) "):
        print("\n--- Ablation Study ---")
        study = AblationStudy(num_episodes=50, max_steps=100, device=device)
        study.run()

    # --- Live Demo ---
    _run_live_demo(agent, device)

    print("=" * 65)


if __name__ == "__main__":
    main()