#!/usr/bin/env python3
# File size: 4,649 Bytes | a157e36
"""
train.py
========
CLI entry point to train the ComplexityAwareRLAgent,
generate plots, and run baseline comparisons.
"""
import sys
import os
# Add parent src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
import torch
from profiler import TaskComplexityProfiler
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent
from controller import ComplexityAwarePIMController
from training import train_complexity_aware_agent, compute_sample_efficiency
from plots import plot_training_behaviour, PolicyInterpreter
from baselines import BaselineEvaluator, AblationStudy
from benchmarks.mlperf_tiny import MLPerfTinyBenchmark
def _confirm(prompt):
    """Ask a yes/no question on stdin and report whether the answer was yes.

    Args:
        prompt: Text shown to the user before the answer is read.

    Returns:
        True when the reply (ignoring surrounding whitespace and case)
        starts with "y"; False otherwise, including when stdin has no
        data to offer (``EOFError``).
    """
    try:
        reply = input(prompt)
    except EOFError:
        # Unattended runs (piped / redirected stdin) must not crash after
        # training has already completed; no answer is treated as "no".
        return False
    return reply.strip().lower().startswith("y")


def _build_demo_snn():
    """Build the small convolutional spiking network used by the live demo.

    snntorch is imported lazily here because it is an optional dependency.

    Returns:
        A newly constructed ``DemoSNN`` module, or ``None`` when snntorch
        cannot be imported.
    """
    try:
        import snntorch as snn
        from snntorch import surrogate
    except ImportError:
        return None

    class DemoSNN(torch.nn.Module):
        """Two conv -> LIF -> max-pool stages followed by a linear LIF read-out."""

        def __init__(self):
            super().__init__()
            spike_grad = surrogate.fast_sigmoid()
            self.conv1 = torch.nn.Conv2d(2, 16, 3, padding=1)
            self.lif1 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True)
            self.pool1 = torch.nn.MaxPool2d(2)
            self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
            self.lif2 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True)
            self.pool2 = torch.nn.MaxPool2d(2)
            # 32 channels * 8 * 8: a 32x32 frame pooled twice ends up 8x8.
            self.fc = torch.nn.Linear(32 * 8 * 8, 10)
            self.lif3 = snn.Leaky(
                beta=0.9, spike_grad=spike_grad, init_hidden=True, output=True
            )

        def forward(self, x):
            """Run the network over every timestep of ``x``.

            Args:
                x: Input tensor whose dimension 1 is iterated as time
                    (``x[:, t]`` is fed to ``conv1``). Presumably shaped
                    (batch, T, 2, 32, 32) -- confirm against the caller.

            Returns:
                The output-layer spikes of all timesteps stacked along a
                new leading dimension.
            """
            # Same per-layer initialisation calls as before, one per LIF layer.
            for lif in (self.lif1, self.lif2, self.lif3):
                lif.init_leaky()

            spk_rec = []
            for t in range(x.size(1)):
                hidden = self.pool1(self.lif1(self.conv1(x[:, t])))
                hidden = self.pool2(self.lif2(self.conv2(hidden)))
                # lif3 is created with init_hidden=True and output=True, so
                # calling it yields a (spikes, membrane) pair. Only the spikes
                # are recorded; stacking the raw pairs would make torch.stack
                # fail because its elements must be tensors.
                spk_out, _mem_out = self.lif3(self.fc(hidden.flatten(1)))
                spk_rec.append(spk_out)
            return torch.stack(spk_rec)

    return DemoSNN()


def main():
    """Run the full pipeline: benchmark, training, plots, baselines and demo.

    Steps, in order:
      1. MLPerf Tiny benchmark using a suitability-score router.
      2. Training of the complexity-aware agent (results go to "results").
      3. Monitoring and policy-interpretability plots.
      4. Sample-efficiency report and baseline comparison.
      5. Optional ablation study, confirmed interactively on stdin
         (skipped when stdin provides no answer).
      6. Live routing demo with a small SNN, when snntorch is installed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    banner = "=" * 65
    results_dir = "results"

    print(banner)
    print(" Complexity-Aware Physics-Based RL PIM Controller v2")
    print(f" Device: {device}")
    print(banner)

    # --- MLPerf Tiny benchmark pre-check ---
    print("\n--- MLPerf Tiny Benchmark ---")
    bench = MLPerfTinyBenchmark(device=device)

    def simple_router(model, input_shape, timesteps):
        """Return the key with the highest profiler suitability score."""
        profiler = TaskComplexityProfiler()
        profile = profiler.profile(model, input_shape, timesteps)
        scores = profiler.compute_suitability_scores(profile)
        return max(scores, key=scores.get)

    bench.run(simple_router, n_runs=10)
    bench.print_report()

    # --- Training ---
    print("\n--- Training Complexity-Aware Agent ---")
    agent, metrics = train_complexity_aware_agent(
        num_episodes=200,
        max_steps=300,
        device=device,
        save_dir=results_dir,
        checkpoint_every=50,
    )

    print("\n--- Generating Monitoring Plots ---")
    profiler = TaskComplexityProfiler()
    plot_training_behaviour(metrics, profiler, save_dir=results_dir)

    print("\n--- Policy Interpretability ---")
    interp = PolicyInterpreter(agent)
    interp.plot_interpretability(
        save_path=f"{results_dir}/fig5_policy_interpretability.png"
    )

    # --- Sample Efficiency ---
    se = compute_sample_efficiency(metrics)
    print(f"\nSample Efficiency: {se}")

    # --- Baseline Comparison ---
    print("\n--- Baseline Comparison ---")
    evaluator = BaselineEvaluator(num_eval_episodes=20, max_steps=100)
    results = evaluator.evaluate_all(agent)
    evaluator.print_comparison_table(results)

    # --- Ablation Study (optional, slow) ---
    if _confirm("\nRun ablation study? (y/n) "):
        print("\n--- Ablation Study ---")
        study = AblationStudy(num_episodes=50, max_steps=100, device=device)
        study.run()

    # --- Live Demo ---
    print("\n--- Live Routing Demo ---")
    ctrl = ComplexityAwarePIMController(device=device)
    # Route with the agent that was just trained.
    ctrl.agent = agent

    # Demo SNN with correct dimensions (32x32 -> pool twice -> 8x8)
    model = _build_demo_snn()
    if model is not None:
        target = ctrl.route_model(model, input_shape=(1, 2, 32, 32), timesteps=10)
        print(f" DemoSNN -> Routed to: {target}")
        stats = ctrl.get_stats()
        print(f" Stats: {stats}")
    else:
        print(" snntorch not installed — skipping SNN demo")
    print(banner)
# Start the pipeline only when this file is executed as a script, so that
# importing the module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()
|