File size: 5,139 Bytes
3f2dde4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations

from dataclasses import dataclass
from math import exp
from time import perf_counter


@dataclass(slots=True)
class StepTelemetry:
    """Metrics recorded at one training checkpoint.

    Field notes below describe the values as populated by
    ``run_tinygrad_gate_demo`` in this module.
    """

    # 1-based training-loop iteration at which the checkpoint was taken.
    epoch: int
    # Optimizer updates applied so far; the demo sets this to the same value as ``epoch``.
    steps: int
    # Seconds elapsed since the training loop started (``perf_counter`` difference).
    wall_time_sec: float
    # Resident set size of the current process in MiB; 0.0 when the process snapshot fails.
    memory_rss_mb: float
    # Number of child processes (counted recursively); 0 when the process snapshot fails.
    child_processes: int
    # Number of threads in the current process; 0 when the process snapshot fails.
    thread_count: int
    # ``exp(-loss) * 100`` for the loss measured at this checkpoint.
    predictability_score: float
    # Fraction of samples the student classifies like the labels at this checkpoint
    # (despite the name, this is the checkpoint value, not only the end-of-run value).
    final_accuracy: float
    # Cross-entropy loss of the student on the full dataset at this checkpoint.
    final_loss: float
    # Mean of the student's gates, i.e. of ``exp(log_gates)``.
    learned_gate_mean: float
    # Population standard deviation of the student's gates.
    learned_gate_std: float


@dataclass(slots=True)
class GateDemoResult:
    """Summary returned by ``run_tinygrad_gate_demo``.

    Field notes below describe the values as that function fills them in.
    """

    # Accuracy of the untrained student on the full dataset.
    initial_accuracy: float
    # Accuracy of the student on the full dataset after training stopped.
    final_accuracy: float
    # Cross-entropy loss of the student on the full dataset after training stopped.
    final_loss: float
    # True when a checkpoint measured accuracy >= ``target_accuracy`` (training stops there).
    reached_target: bool
    # Number of optimizer updates that were actually performed.
    trained_steps: int
    # The accuracy threshold the run was asked to reach (echo of the argument).
    target_accuracy: float
    # All learned gate values, i.e. ``exp(log_gates)`` of the student as a Python list.
    learned_gates: list[float]
    # The first eight entries of ``learned_gates``.
    learned_gate_sample: list[float]
    # Checkpoints recorded during training, in chronological order.
    telemetry: list[StepTelemetry]


def _process_snapshot() -> tuple[float, int, int]:
    try:
        import psutil

        process = psutil.Process()
        memory_rss_mb = process.memory_info().rss / (1024 * 1024)
        child_processes = len(process.children(recursive=True))
        thread_count = process.num_threads()
        return memory_rss_mb, child_processes, thread_count
    except Exception:
        return 0.0, 0, 0


def _gate_mean_std(values: list[float]) -> tuple[float, float]:
    """Return the mean and population standard deviation of *values*.

    The divisor is clamped to 1, so an empty list yields ``(0.0, 0.0)``
    instead of raising ``ZeroDivisionError``.
    """
    count = max(len(values), 1)
    mean = sum(values) / count
    variance = sum((value - mean) ** 2 for value in values) / count
    return mean, variance ** 0.5


def run_tinygrad_gate_demo(
    steps: int = 80,
    batch_size: int = 64,
    seed: int = 0,
    target_accuracy: float = 0.99,
) -> GateDemoResult:
    """Train a small gated linear probe with tinygrad and report its progress.

    A "teacher" probe with fixed gates labels 128 random 12-dimensional
    samples; a "student" probe with the same fixed base weights then learns
    its own log-gates with SGD (lr=0.8) so that its predictions match the
    teacher's labels.

    Telemetry is recorded every ``max(1, steps // 8)`` updates and on the last
    update. Training stops early at the first checkpoint whose accuracy is at
    least ``target_accuracy``.

    Args:
        steps: Maximum number of optimizer updates. Values <= 0 skip training.
        batch_size: Accepted for interface compatibility; the current
            implementation does not use it and every update runs on the full
            dataset.
        seed: Seed passed to ``Tensor.manual_seed``.
        target_accuracy: Accuracy threshold (fraction in [0, 1]) that triggers
            early stopping at a checkpoint.

    Returns:
        A ``GateDemoResult`` with the accuracy before and after training, the
        final loss, the learned gates, and the recorded checkpoints.

    Raises:
        RuntimeError: If tinygrad cannot be imported.
    """
    try:
        from tinygrad import Tensor, nn
        from tinygrad.nn.state import get_parameters
    except ImportError as exc:  # pragma: no cover - dependency gate
        raise RuntimeError("tinygrad demo requires tinygrad to be installed") from exc

    Tensor.manual_seed(seed)

    input_dim = 12
    samples = 128

    features = Tensor.randn(samples, input_dim)

    class GatedProbe:
        """Two-class linear scorer whose per-feature weights are scaled by exp(log_gates)."""

        def __init__(self) -> None:
            # NOTE(review): is_param_(False) presumably keeps this tensor out of
            # the optimizer's updates - confirm against the installed tinygrad.
            self.base_weights = Tensor.linspace(0.5, 1.5, input_dim).is_param_(False)
            self.log_gates = Tensor.zeros(input_dim)

        def __call__(self, x: Tensor) -> Tensor:
            score = (x * self.base_weights * self.log_gates.exp()).sum(axis=1)
            # Logits for class 0 and class 1 are -score and +score respectively.
            return Tensor.stack(-score, score, dim=1)

    teacher = GatedProbe()
    teacher.log_gates = Tensor.linspace(-0.25, 0.75, input_dim).is_param_(False)
    labels = teacher(features).argmax(-1)

    student = GatedProbe()
    optimizer = nn.optim.SGD(get_parameters(student), lr=0.8)

    def accuracy(model: GatedProbe) -> float:
        logits = model(features)
        pred = logits.argmax(-1)
        return float((pred == labels).sum().item()) / samples

    initial_accuracy = accuracy(student)

    telemetry: list[StepTelemetry] = []
    reached_target = False
    trained_steps = 0
    # Loop-invariant checkpoint interval; clamped so steps < 8 still reports.
    checkpoint_every = max(1, steps // 8)
    start_time = perf_counter()
    Tensor.training = True
    try:
        for epoch in range(1, steps + 1):
            optimizer.zero_grad()
            # Full-batch update: every step uses the whole dataset.
            student(features).sparse_categorical_crossentropy(labels).backward()
            optimizer.step()
            trained_steps = epoch

            if epoch == steps or epoch % checkpoint_every == 0:
                current_logits = student(features)
                current_loss = float(current_logits.sparse_categorical_crossentropy(labels).item())
                current_accuracy = accuracy(student)
                memory_rss_mb, child_processes, thread_count = _process_snapshot()
                gate_mean, gate_std = _gate_mean_std(
                    [float(x) for x in student.log_gates.exp().tolist()]
                )
                telemetry.append(
                    StepTelemetry(
                        epoch=epoch,
                        steps=epoch,
                        wall_time_sec=perf_counter() - start_time,
                        memory_rss_mb=memory_rss_mb,
                        child_processes=child_processes,
                        thread_count=thread_count,
                        predictability_score=float(exp(-current_loss) * 100.0),
                        final_accuracy=current_accuracy,
                        final_loss=current_loss,
                        learned_gate_mean=gate_mean,
                        learned_gate_std=gate_std,
                    )
                )
                if current_accuracy >= target_accuracy:
                    reached_target = True
                    break
    finally:
        # Tensor.training is global state; reset it even if a step raised.
        Tensor.training = False

    final_logits = student(features)
    final_loss = float(final_logits.sparse_categorical_crossentropy(labels).item())
    final_accuracy = accuracy(student)
    learned_gates = [float(x) for x in student.log_gates.exp().tolist()]
    gate_sample = learned_gates[:8]

    return GateDemoResult(
        initial_accuracy=initial_accuracy,
        final_accuracy=final_accuracy,
        final_loss=final_loss,
        reached_target=reached_target,
        trained_steps=trained_steps,
        target_accuracy=target_accuracy,
        learned_gates=learned_gates,
        learned_gate_sample=gate_sample,
        telemetry=telemetry,
    )