File size: 4,649 Bytes
a157e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
train.py
========
CLI entry point to train the ComplexityAwareRLAgent,
generate plots, and run baseline comparisons.
"""

import sys
import os

# Add parent src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

import torch
from profiler import TaskComplexityProfiler
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent
from controller import ComplexityAwarePIMController
from training import train_complexity_aware_agent, compute_sample_efficiency
from plots import plot_training_behaviour, PolicyInterpreter
from baselines import BaselineEvaluator, AblationStudy
from benchmarks.mlperf_tiny import MLPerfTinyBenchmark


def _profiler_router(model, input_shape, timesteps):
    """Route ``model`` to the target with the highest suitability score.

    Used as the routing callback for the MLPerf Tiny benchmark. A fresh
    ``TaskComplexityProfiler`` is constructed for every call.

    Returns:
        The key of the ``compute_suitability_scores`` mapping whose value
        is largest.
    """
    profiler = TaskComplexityProfiler()
    profile = profiler.profile(model, input_shape, timesteps)
    scores = profiler.compute_suitability_scores(profile)
    return max(scores, key=scores.get)


def _ask_yes_no(prompt):
    """Return True if the answer typed at ``prompt`` starts with 'y'/'Y'.

    Returns False (instead of raising ``EOFError``) when stdin is closed or
    not attached, e.g. under CI, ``nohup`` or ``< /dev/null``. That way an
    unattended run does not crash at this prompt after training and
    evaluation have already finished.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return False
    return answer.strip().lower().startswith("y")


def _build_demo_snn():
    """Build the small spiking CNN used by the live routing demo.

    Returns:
        A ``DemoSNN`` instance, or ``None`` when ``snntorch`` is not installed.
    """
    try:
        import snntorch as snn
        from snntorch import surrogate
    except ImportError:
        return None

    class DemoSNN(torch.nn.Module):
        """Two conv + LIF stages with 2x2 max-pooling, then a linear + LIF readout.

        A 32x32 input is pooled twice (32 -> 16 -> 8), hence ``32 * 8 * 8``
        features into the readout layer.
        """

        def __init__(self):
            super().__init__()
            sg = surrogate.fast_sigmoid()
            self.conv1 = torch.nn.Conv2d(2, 16, 3, padding=1)
            self.lif1 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool1 = torch.nn.MaxPool2d(2)
            self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
            self.lif2 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True)
            self.pool2 = torch.nn.MaxPool2d(2)
            self.fc = torch.nn.Linear(32 * 8 * 8, 10)
            self.lif3 = snn.Leaky(beta=0.9, spike_grad=sg, init_hidden=True, output=True)

        def forward(self, x):
            # Time is indexed on dim 1 (x[:, t]); each frame goes through the
            # conv stack and the per-step output spikes are stacked on dim 0.
            self.lif1.init_leaky()
            self.lif2.init_leaky()
            self.lif3.init_leaky()
            spk_rec = []
            for t in range(x.size(1)):
                frame = x[:, t]
                s = self.lif1(self.conv1(frame))
                p = self.pool1(s)
                s = self.lif2(self.conv2(p))
                p = self.pool2(s)
                # With init_hidden=True *and* output=True, snntorch's Leaky
                # returns (spk, mem) rather than just spk; keep only the
                # spikes so torch.stack receives tensors, not tuples.
                spk, _mem = self.lif3(self.fc(p.flatten(1)))
                spk_rec.append(spk)
            return torch.stack(spk_rec)

    return DemoSNN()


def _run_live_demo(agent, device):
    """Route a demo SNN through a controller that uses the freshly trained agent."""
    print("\n--- Live Routing Demo ---")
    ctrl = ComplexityAwarePIMController(device=device)
    ctrl.agent = agent  # use the agent trained above instead of the controller's own

    model = _build_demo_snn()
    if model is None:
        print("  snntorch not installed — skipping SNN demo")
        return

    # NOTE(review): DemoSNN.forward iterates dim 1 as time, but input_shape has
    # no explicit time axis (timesteps is passed separately) — confirm how
    # route_model / the profiler builds the input tensor.
    target = ctrl.route_model(model, input_shape=(1, 2, 32, 32), timesteps=10)
    print(f"  DemoSNN -> Routed to: {target}")
    stats = ctrl.get_stats()
    print(f"  Stats: {stats}")


def main():
    """Run the full pipeline: benchmark, train, plot, evaluate, demo.

    Stages, in order:
      1. MLPerf Tiny benchmark using a profiler-based router (10 runs).
      2. Train the complexity-aware agent (num_episodes=200, max_steps=300,
         checkpoint_every=50) with outputs under ``results``.
      3. Training-behaviour plots and the policy-interpretability figure.
      4. Sample-efficiency summary and baseline comparison.
      5. Optional ablation study, gated by a y/n prompt on stdin (skipped
         when stdin is unavailable).
      6. Live routing demo on a small SNN (requires ``snntorch``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results_dir = "results"
    # Figures are written by explicit path below; make sure the directory
    # exists regardless of what the training stage does.
    os.makedirs(results_dir, exist_ok=True)

    print("=" * 65)
    print("  Complexity-Aware Physics-Based RL PIM Controller v2")
    print(f"  Device: {device}")
    print("=" * 65)

    # --- MLPerf Tiny benchmark pre-check ---
    print("\n--- MLPerf Tiny Benchmark ---")
    bench = MLPerfTinyBenchmark(device=device)
    bench.run(_profiler_router, n_runs=10)
    bench.print_report()

    # --- Training ---
    print("\n--- Training Complexity-Aware Agent ---")
    agent, metrics = train_complexity_aware_agent(
        num_episodes=200, max_steps=300, device=device,
        save_dir=results_dir, checkpoint_every=50)

    print("\n--- Generating Monitoring Plots ---")
    profiler = TaskComplexityProfiler()
    plot_training_behaviour(metrics, profiler, save_dir=results_dir)

    print("\n--- Policy Interpretability ---")
    interp = PolicyInterpreter(agent)
    interp.plot_interpretability(
        save_path=os.path.join(results_dir, "fig5_policy_interpretability.png"))

    # --- Sample Efficiency ---
    se = compute_sample_efficiency(metrics)
    print(f"\nSample Efficiency: {se}")

    # --- Baseline Comparison ---
    print("\n--- Baseline Comparison ---")
    evaluator = BaselineEvaluator(num_eval_episodes=20, max_steps=100)
    results = evaluator.evaluate_all(agent)
    evaluator.print_comparison_table(results)

    # --- Ablation Study (optional, slow) ---
    if _ask_yes_no("\nRun ablation study? (y/n) "):
        print("\n--- Ablation Study ---")
        study = AblationStudy(num_episodes=50, max_steps=100, device=device)
        study.run()

    # --- Live Demo ---
    _run_live_demo(agent, device)

    print("=" * 65)

if __name__ == "__main__":
    # Script entry point: the pipeline runs only when this file is executed
    # directly (e.g. `python train.py`), not when it is imported.
    main()