# tempo-snn-v2 / scripts / train.py
# Uploaded by KD099 using huggingface_hub (commit a157e36, verified).
#!/usr/bin/env python3
"""
train.py
========
CLI entry point to train the ComplexityAwareRLAgent,
generate plots, and run baseline comparisons.
"""
import sys
import os
# Add parent src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
import torch
from profiler import TaskComplexityProfiler
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent
from controller import ComplexityAwarePIMController
from training import train_complexity_aware_agent, compute_sample_efficiency
from plots import plot_training_behaviour, PolicyInterpreter
from baselines import BaselineEvaluator, AblationStudy
from benchmarks.mlperf_tiny import MLPerfTinyBenchmark
def _ask_yes_no(prompt):
    """Ask a yes/no question on stdin and report whether the answer was yes.

    Args:
        prompt: Text passed to ``input``.

    Returns:
        bool: True when the reply, ignoring surrounding whitespace and case,
        starts with "y"; False otherwise, including when stdin is closed.
    """
    try:
        reply = input(prompt)
    except EOFError:
        # Non-interactive runs (closed or exhausted stdin) have nobody to
        # answer; treat that as "no" rather than aborting after training.
        return False
    return reply.strip().lower().startswith("y")


def _build_demo_snn():
    """Build the small convolutional SNN used by the live routing demo.

    Returns:
        A ``torch.nn.Module`` instance, or ``None`` when snntorch cannot be
        imported.
    """
    try:
        import snntorch as snn
        from snntorch import surrogate
    except ImportError:
        return None

    class DemoSNN(torch.nn.Module):
        """Two conv -> LIF -> max-pool stages followed by a linear + LIF readout.

        Dimensions: a 32x32 input pooled twice by 2 gives 8x8 feature maps,
        which is what the ``32 * 8 * 8`` input size of ``fc`` relies on.
        """

        def __init__(self):
            super().__init__()
            sg = surrogate.fast_sigmoid()
            self.conv1 = torch.nn.Conv2d(2, 16, 3, padding=1)
            self.lif1 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool1 = torch.nn.MaxPool2d(2)
            self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
            self.lif2 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool2 = torch.nn.MaxPool2d(2)
            self.fc = torch.nn.Linear(32 * 8 * 8, 10)
            self.lif3 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True, output=True)

        def forward(self, x):
            """Run the network over every time step and stack the output spikes.

            Args:
                x: Input tensor; dimension 1 is iterated as time and ``x[:, t]``
                    is fed to ``conv1`` (assumes a (batch, time, 2, 32, 32)
                    layout — TODO confirm against the controller's profiler).

            Returns:
                Tensor of output spikes stacked along a new leading time axis.
            """
            # Reset the hidden membrane state of each LIF layer before a new
            # sequence. NOTE(review): the semantics of init_leaky() have
            # changed between snnTorch releases — confirm against the
            # installed version.
            self.lif1.init_leaky()
            self.lif2.init_leaky()
            self.lif3.init_leaky()

            spk_rec = []
            for t in range(x.size(1)):
                frame = x[:, t]
                feat = self.pool1(self.lif1(self.conv1(frame)))
                feat = self.pool2(self.lif2(self.conv2(feat)))
                # lif3 was built with output=True, so it returns (spk, mem).
                # Keep only the spikes; appending the tuple itself would make
                # torch.stack fail.
                spk_out, _mem_out = self.lif3(self.fc(feat.flatten(1)))
                spk_rec.append(spk_out)
            return torch.stack(spk_rec)

    return DemoSNN()


def main():
    """Run the full pipeline: benchmark, training, plots, baselines, demo.

    Steps, in order:
      1. MLPerf Tiny benchmark with a router that picks the highest
         suitability score.
      2. Train the complexity-aware agent (200 episodes, 300 steps each).
      3. Write the training and policy-interpretability plots to ``results/``.
      4. Print sample efficiency and the baseline comparison table.
      5. Optionally run the ablation study (asked on stdin; skipped when
         stdin is unavailable).
      6. Route a demo SNN through the controller if snntorch is installed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results_dir = "results"
    banner = "=" * 65

    print(banner)
    print(" Complexity-Aware Physics-Based RL PIM Controller v2")
    print(f" Device: {device}")
    print(banner)

    # --- MLPerf Tiny benchmark pre-check ---
    print("\n--- MLPerf Tiny Benchmark ---")
    bench = MLPerfTinyBenchmark(device=device)

    def simple_router(model, input_shape, timesteps):
        """Return the key with the highest suitability score for ``model``."""
        profiler = TaskComplexityProfiler()
        p = profiler.profile(model, input_shape, timesteps)
        scores = profiler.compute_suitability_scores(p)
        return max(scores, key=scores.get)

    bench.run(simple_router, n_runs=10)
    bench.print_report()

    # --- Training ---
    print("\n--- Training Complexity-Aware Agent ---")
    # Make sure the output directory exists before anything writes into it;
    # harmless if the training routine creates it as well.
    os.makedirs(results_dir, exist_ok=True)
    agent, metrics = train_complexity_aware_agent(
        num_episodes=200, max_steps=300, device=device,
        save_dir=results_dir, checkpoint_every=50)

    print("\n--- Generating Monitoring Plots ---")
    profiler = TaskComplexityProfiler()
    plot_training_behaviour(metrics, profiler, save_dir=results_dir)

    print("\n--- Policy Interpretability ---")
    interp = PolicyInterpreter(agent)
    interp.plot_interpretability(
        save_path=os.path.join(results_dir, "fig5_policy_interpretability.png"))

    # --- Sample Efficiency ---
    se = compute_sample_efficiency(metrics)
    print(f"\nSample Efficiency: {se}")

    # --- Baseline Comparison ---
    print("\n--- Baseline Comparison ---")
    evaluator = BaselineEvaluator(num_eval_episodes=20, max_steps=100)
    results = evaluator.evaluate_all(agent)
    evaluator.print_comparison_table(results)

    # --- Ablation Study (optional, slow) ---
    if _ask_yes_no("\nRun ablation study? (y/n) "):
        print("\n--- Ablation Study ---")
        study = AblationStudy(num_episodes=50, max_steps=100, device=device)
        study.run()

    # --- Live Demo ---
    print("\n--- Live Routing Demo ---")
    ctrl = ComplexityAwarePIMController(device=device)
    # Route with the freshly trained agent instead of the controller's own.
    ctrl.agent = agent

    model = _build_demo_snn()
    if model is None:
        print(" snntorch not installed — skipping SNN demo")
    else:
        target = ctrl.route_model(model, input_shape=(1, 2, 32, 32), timesteps=10)
        print(f" DemoSNN -> Routed to: {target}")
        stats = ctrl.get_stats()
        print(f" Stats: {stats}")

    print(banner)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()