File size: 899 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# Copyright © 2023 Apple Inc.
import time
import mlx.core as mx
# Problem size and optimizer settings for the synthetic regression task.
num_features = 100  # dimensionality of each example / length of w
num_examples = 1_000  # number of rows in the design matrix
num_iters = 10_000  # number of gradient-descent steps that are timed
lr = 0.01  # fixed step size applied to every gradient update

# True parameters: the ground-truth weight vector the optimizer should
# recover, of shape (num_features,).
w_star = mx.random.normal((num_features,))

# Input examples (design matrix) of shape (num_examples, num_features).
X = mx.random.normal((num_examples, num_features))

# Noisy labels: the linear response X @ w_star plus random noise scaled
# by 1e-2, one label per example.
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps

# Initialize random parameters: a small (1e-2 scaled) random starting
# point for the weights that gradient descent will update.
w = 1e-2 * mx.random.normal((num_features,))
def loss_fn(w):
    """Return half the mean squared error of the linear model ``X @ w``.

    Args:
        w: Candidate weight vector, multiplied against the module-level
            design matrix ``X``.

    Returns:
        A scalar array equal to ``0.5 * mean((X @ w - y) ** 2)``, where
        ``y`` is the module-level label vector.
    """
    residual = X @ w - y
    return 0.5 * mx.mean(mx.square(residual))
# Function returning the gradient of the loss with respect to w.
grad_fn = mx.grad(loss_fn)

# Time the optimization with a monotonic, high-resolution clock.
# time.time() follows the wall clock and can jump when the system time is
# adjusted, which would corrupt the reported throughput.
tic = time.perf_counter()
for _ in range(num_iters):
    # Plain gradient-descent step with a fixed learning rate.
    grad = grad_fn(w)
    w = w - lr * grad
    # Evaluate the updated weights on every iteration so the computation
    # happens inside the timed region rather than being deferred.
    mx.eval(w)
toc = time.perf_counter()

# Final loss and Euclidean distance between learned and true parameters.
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
# Iterations completed per second of elapsed time.
throughput = num_iters / (toc - tic)

print(
    f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
    f"Throughput {throughput:.5f} (it/s)"
)
|