# File size: 5,139 Bytes | 3f2dde4
from __future__ import annotations
from dataclasses import dataclass
from math import exp
from time import perf_counter
@dataclass(slots=True)
class StepTelemetry:
epoch: int
steps: int
wall_time_sec: float
memory_rss_mb: float
child_processes: int
thread_count: int
predictability_score: float
final_accuracy: float
final_loss: float
learned_gate_mean: float
learned_gate_std: float
@dataclass(slots=True)
class GateDemoResult:
initial_accuracy: float
final_accuracy: float
final_loss: float
reached_target: bool
trained_steps: int
target_accuracy: float
learned_gates: list[float]
learned_gate_sample: list[float]
telemetry: list[StepTelemetry]
def _process_snapshot() -> tuple[float, int, int]:
    """Return ``(rss_mb, child_process_count, thread_count)`` for this process.

    The probe is best-effort: if ``psutil`` cannot be imported, or any of its
    calls raises, the function returns ``(0.0, 0, 0)`` instead of propagating
    the error.
    """
    try:
        import psutil

        current = psutil.Process()
        rss_mb = current.memory_info().rss / (1024 * 1024)  # bytes -> MiB
        descendants = len(current.children(recursive=True))
        threads = current.num_threads()
    except Exception:
        # Telemetry must never break the caller; fall back to zeros.
        return 0.0, 0, 0
    return rss_mb, descendants, threads
def _gate_mean_std(values: list[float]) -> tuple[float, float]:
    """Return the mean and population standard deviation of *values*."""
    # max(..., 1) keeps the divisions defined for an empty list -> (0.0, 0.0).
    count = max(len(values), 1)
    mean = sum(values) / count
    variance = sum((value - mean) ** 2 for value in values) / count
    return mean, variance ** 0.5


def run_tinygrad_gate_demo(
    steps: int = 80,
    batch_size: int = 64,
    seed: int = 0,
    target_accuracy: float = 0.99,
) -> GateDemoResult:
    """Train a gated linear probe to reproduce the labels of a fixed teacher.

    A teacher probe with hand-set log-gates labels 128 random 12-feature
    samples. A student probe whose log-gates start at zero is then trained
    with SGD (lr=0.8) on the cross-entropy against those labels. Telemetry is
    recorded roughly eight times per run and on the last step; training stops
    early if a checkpoint's accuracy reaches ``target_accuracy``.

    Args:
        steps: Maximum number of optimizer steps. Values <= 0 skip training.
        batch_size: Accepted for interface compatibility but currently
            unused; every step trains on the full 128-sample set.
        seed: Passed to ``Tensor.manual_seed`` before data generation.
        target_accuracy: Accuracy threshold for early stopping. It is only
            checked at telemetry checkpoints, not on every step.

    Returns:
        A ``GateDemoResult`` with initial/final metrics, the learned gates
        (``exp(log_gates)``) and the recorded telemetry checkpoints.

    Raises:
        RuntimeError: If tinygrad cannot be imported.
    """
    try:
        from tinygrad import Tensor, nn
        from tinygrad.nn.state import get_parameters
    except ImportError as exc:  # pragma: no cover - dependency gate
        raise RuntimeError("tinygrad demo requires tinygrad to be installed") from exc

    Tensor.manual_seed(seed)
    input_dim = 12
    samples = 128
    features = Tensor.randn(samples, input_dim)

    class GatedProbe:
        """Linear scorer whose per-feature weights are scaled by exp(log_gates)."""

        def __init__(self) -> None:
            # NOTE(review): is_param_(False) presumably keeps this tensor out
            # of the trainable set - confirm against the tinygrad version used.
            self.base_weights = Tensor.linspace(0.5, 1.5, input_dim).is_param_(False)
            # Zero log-gates mean every gate starts at exp(0) == 1.
            self.log_gates = Tensor.zeros(input_dim)

        def __call__(self, x: Tensor) -> Tensor:
            score = (x * self.base_weights * self.log_gates.exp()).sum(axis=1)
            # Two-class logits: class 1 wins exactly when the score is positive.
            return Tensor.stack(-score, score, dim=1)

    # The teacher only differs from a fresh probe by its fixed log-gates.
    teacher = GatedProbe()
    teacher.log_gates = Tensor.linspace(-0.25, 0.75, input_dim).is_param_(False)
    labels = teacher(features).argmax(-1)

    student = GatedProbe()
    optimizer = nn.optim.SGD(get_parameters(student), lr=0.8)

    def accuracy(logits: Tensor) -> float:
        """Fraction of samples whose argmax prediction matches the labels."""
        pred = logits.argmax(-1)
        return float((pred == labels).sum().item()) / samples

    initial_accuracy = accuracy(student(features))
    telemetry: list[StepTelemetry] = []
    reached_target = False
    trained_steps = 0
    # Aim for about eight checkpoints per run; never a zero interval.
    checkpoint_every = max(1, steps // 8)
    start_time = perf_counter()

    Tensor.training = True
    try:
        for epoch in range(1, steps + 1):
            optimizer.zero_grad()
            student(features).sparse_categorical_crossentropy(labels).backward()
            optimizer.step()
            trained_steps = epoch

            if epoch != steps and epoch % checkpoint_every != 0:
                continue

            # One forward pass feeds both the loss and the accuracy.
            current_logits = student(features)
            current_loss = float(
                current_logits.sparse_categorical_crossentropy(labels).item()
            )
            current_accuracy = accuracy(current_logits)
            memory_rss_mb, child_processes, thread_count = _process_snapshot()
            gate_values = [float(g) for g in student.log_gates.exp().tolist()]
            gate_mean, gate_std = _gate_mean_std(gate_values)
            telemetry.append(
                StepTelemetry(
                    epoch=epoch,
                    steps=epoch,
                    wall_time_sec=perf_counter() - start_time,
                    memory_rss_mb=memory_rss_mb,
                    child_processes=child_processes,
                    thread_count=thread_count,
                    predictability_score=float(exp(-current_loss) * 100.0),
                    final_accuracy=current_accuracy,
                    final_loss=current_loss,
                    learned_gate_mean=gate_mean,
                    learned_gate_std=gate_std,
                )
            )
            if current_accuracy >= target_accuracy:
                reached_target = True
                break
    finally:
        # Reset the global flag even if training raised, so a failure here
        # does not leave tinygrad in training mode for later code.
        Tensor.training = False

    final_logits = student(features)
    final_loss = float(final_logits.sparse_categorical_crossentropy(labels).item())
    final_accuracy = accuracy(final_logits)
    learned_gates = [float(g) for g in student.log_gates.exp().tolist()]
    return GateDemoResult(
        initial_accuracy=initial_accuracy,
        final_accuracy=final_accuracy,
        final_loss=final_loss,
        reached_target=reached_target,
        trained_steps=trained_steps,
        target_accuracy=target_accuracy,
        learned_gates=learned_gates,
        learned_gate_sample=learned_gates[:8],
        telemetry=telemetry,
    )
|