| |
| """ |
| train.py |
| ======== |
| CLI entry point to train the ComplexityAwareRLAgent, |
| generate plots, and run baseline comparisons. |
| """ |
|
|
| import sys |
| import os |
|
|
| |
# Prepend the sibling ``src`` directory (resolved relative to this file) to the
# import path. Presumably the project modules imported below live there --
# verify against the repository layout.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
|
| import torch |
| from profiler import TaskComplexityProfiler |
| from rl_env import ComplexityAwarePIMEnv |
| from rl_agent import ComplexityAwareRLAgent |
| from controller import ComplexityAwarePIMController |
| from training import train_complexity_aware_agent, compute_sample_efficiency |
| from plots import plot_training_behaviour, PolicyInterpreter |
| from baselines import BaselineEvaluator, AblationStudy |
| from benchmarks.mlperf_tiny import MLPerfTinyBenchmark |
|
|
|
|
def _confirm(prompt):
    """Ask a yes/no question on stdin and report whether the answer starts with "y".

    A closed or non-interactive stdin (``EOFError``) is treated as "no" so an
    unattended run continues instead of crashing after training has finished.
    """
    try:
        answer = input(prompt)
    except EOFError:
        print()  # terminate the dangling prompt line
        return False
    return answer.strip().lower().startswith("y")


def _build_demo_snn(snn, surrogate):
    """Build the small spiking network used by the live routing demo.

    Args:
        snn: The imported ``snntorch`` module.
        surrogate: The imported ``snntorch.surrogate`` module.

    Returns:
        A ``torch.nn.Module`` whose ``forward`` treats dimension 1 of its input
        as the time axis and returns the per-step output spikes stacked along
        a new leading dimension.
    """

    class DemoSNN(torch.nn.Module):
        """Two conv/LIF/pool stages followed by a linear/LIF readout."""

        def __init__(self):
            super().__init__()
            spike_grad = surrogate.fast_sigmoid()
            self.conv1 = torch.nn.Conv2d(2, 16, 3, padding=1)
            self.lif1 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True)
            self.pool1 = torch.nn.MaxPool2d(2)
            self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
            self.lif2 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True)
            self.pool2 = torch.nn.MaxPool2d(2)
            # 32 channels at 8x8: the two 2x2 pools halve the 32x32 demo
            # input twice (the padded 3x3 convolutions keep the spatial size).
            self.fc = torch.nn.Linear(32 * 8 * 8, 10)
            self.lif3 = snn.Leaky(
                beta=0.9, spike_grad=spike_grad, init_hidden=True, output=True
            )

        def forward(self, x):
            # NOTE(review): intended to reset the hidden membrane state for
            # each pass -- confirm init_leaky() does so in-place for the
            # installed snntorch version.
            self.lif1.init_leaky()
            self.lif2.init_leaky()
            self.lif3.init_leaky()

            spk_rec = []
            for t in range(x.size(1)):
                frame = x[:, t]
                hidden = self.pool1(self.lif1(self.conv1(frame)))
                hidden = self.pool2(self.lif2(self.conv2(hidden)))
                out = self.lif3(self.fc(hidden.flatten(1)))
                # An output layer created with init_hidden=True and
                # output=True returns (spk, mem); torch.stack needs tensors,
                # so keep only the spikes.
                spk = out[0] if isinstance(out, tuple) else out
                spk_rec.append(spk)
            return torch.stack(spk_rec)

    return DemoSNN()


def main():
    """Run the full benchmark / train / evaluate / demo pipeline.

    Steps, in order:
      1. Run the MLPerf Tiny benchmark with a profiler-score router.
      2. Train the complexity-aware agent, saving under ``results``.
      3. Generate the training-behaviour and policy-interpretability plots.
      4. Print the sample-efficiency summary.
      5. Compare the trained agent against the baselines.
      6. Optionally run the ablation study (asked on stdin; EOF means "no").
      7. Route a demo SNN through the controller when snntorch is installed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results_dir = "results"

    print("=" * 65)
    print(" Complexity-Aware Physics-Based RL PIM Controller v2")
    print(f" Device: {device}")
    print("=" * 65)

    # ---- MLPerf Tiny benchmark -------------------------------------------
    print("\n--- MLPerf Tiny Benchmark ---")
    bench = MLPerfTinyBenchmark(device=device)

    def simple_router(model, input_shape, timesteps):
        """Route to the target with the highest profiler suitability score."""
        profiler = TaskComplexityProfiler()
        p = profiler.profile(model, input_shape, timesteps)
        scores = profiler.compute_suitability_scores(p)
        return max(scores, key=scores.get)

    bench.run(simple_router, n_runs=10)
    bench.print_report()

    # ---- Training --------------------------------------------------------
    # Plots are written into this directory below; make sure it exists
    # regardless of what the training helper creates.
    os.makedirs(results_dir, exist_ok=True)

    print("\n--- Training Complexity-Aware Agent ---")
    agent, metrics = train_complexity_aware_agent(
        num_episodes=200, max_steps=300, device=device,
        save_dir=results_dir, checkpoint_every=50)

    # ---- Plots -----------------------------------------------------------
    print("\n--- Generating Monitoring Plots ---")
    profiler = TaskComplexityProfiler()
    plot_training_behaviour(metrics, profiler, save_dir=results_dir)

    print("\n--- Policy Interpretability ---")
    interp = PolicyInterpreter(agent)
    interp.plot_interpretability(
        save_path=f"{results_dir}/fig5_policy_interpretability.png")

    # ---- Sample efficiency -----------------------------------------------
    se = compute_sample_efficiency(metrics)
    print(f"\nSample Efficiency: {se}")

    # ---- Baseline comparison ---------------------------------------------
    print("\n--- Baseline Comparison ---")
    evaluator = BaselineEvaluator(num_eval_episodes=20, max_steps=100)
    results = evaluator.evaluate_all(agent)
    evaluator.print_comparison_table(results)

    # ---- Optional ablation study -----------------------------------------
    if _confirm("\nRun ablation study? (y/n) "):
        print("\n--- Ablation Study ---")
        study = AblationStudy(num_episodes=50, max_steps=100, device=device)
        study.run()

    # ---- Live routing demo -----------------------------------------------
    print("\n--- Live Routing Demo ---")
    ctrl = ComplexityAwarePIMController(device=device)
    ctrl.agent = agent  # route with the agent trained above

    # snntorch is an optional dependency; the demo is skipped without it.
    try:
        import snntorch as snn
        from snntorch import surrogate
    except ImportError:
        print(" snntorch not installed — skipping SNN demo")
    else:
        model = _build_demo_snn(snn, surrogate)
        target = ctrl.route_model(model, input_shape=(1, 2, 32, 32), timesteps=10)
        print(f" DemoSNN -> Routed to: {target}")
        stats = ctrl.get_stats()
        print(f" Stats: {stats}")

    print("=" * 65)
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|