nn-from-scratch / nn /layers.py
LaelaZ's picture
Upload folder using huggingface_hub
5041f39 verified
"""Layers with hand-written forward and backward passes.
Each layer caches what it needs on the forward pass and returns input-gradients
on the backward pass, plus parameter-gradients where it has parameters. There is
no automatic differentiation here on purpose: the point is to show the chain rule
applied by hand, then prove it correct with a numerical gradient check (see tests).
"""
from __future__ import annotations
import numpy as np
class Linear:
    """Fully-connected layer: ``y = x @ W + b``.

    ``W`` has shape ``(n_in, n_out)`` and ``b`` has shape ``(n_out,)``. The input may be
    a single sample of shape ``(n_in,)`` or carry any number of leading batch dimensions
    ``(..., n_in)``; parameter gradients are summed over all leading dimensions.
    """

    def __init__(self, n_in: int, n_out: int, rng: np.random.Generator):
        # He initialisation keeps activations from vanishing/exploding through ReLU.
        self.W = rng.standard_normal((n_in, n_out)) * np.sqrt(2.0 / n_in)
        self.b = np.zeros(n_out)
        self.dW = np.zeros_like(self.W)
        self.db = np.zeros_like(self.b)
        self._x = None  # input cached by forward() for use in backward()

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``x @ W + b`` and cache ``x`` for the backward pass."""
        self._x = x
        return x @ self.W + self.b

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Store dL/dW and dL/db on the layer and return dL/dx.

        Args:
            dy: upstream gradient dL/dy, same shape as the output of ``forward``.

        Returns:
            dL/dx, same shape as the input given to ``forward``.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        if self._x is None:
            raise RuntimeError("Linear.backward called before forward")
        n_in, n_out = self.W.shape
        # Flatten every leading (batch) dimension so the parameter gradients sum over all
        # samples. This also makes a single unbatched sample of shape (n_in,) work, where
        # a plain x.T @ dy would form an inner product instead of the outer product.
        x2d = np.reshape(self._x, (-1, n_in))
        dy2d = np.reshape(dy, (-1, n_out))
        # dL/dW = x^T · dy ; dL/db = sum over the batch ; dL/dx = dy · W^T
        self.dW = x2d.T @ dy2d
        self.db = dy2d.sum(axis=0)
        return dy @ self.W.T

    def params_and_grads(self):
        """Yield ``(parameter, gradient)`` pairs: ``(W, dW)`` then ``(b, db)``."""
        yield self.W, self.dW
        yield self.b, self.db
class ReLU:
    """Rectified linear unit applied element-wise.

    Entries greater than zero pass through unchanged; every other entry is multiplied by
    a False mask value and so becomes zero. The boolean mask is remembered so the
    backward pass can route the incoming gradient to exactly those positions.
    """

    def __init__(self):
        self._mask = None  # set by forward(): True where the input was > 0

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with non-positive entries zeroed, caching where ``x > 0``."""
        positive = x > 0
        self._mask = positive
        return x * positive

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Return ``dy`` kept only where the forward input was positive."""
        keep = self._mask
        return dy * keep

    def params_and_grads(self):
        """ReLU has no trainable parameters, so this yields nothing."""
        yield from ()
def softmax_cross_entropy(logits: np.ndarray, labels: np.ndarray):
    """Combined softmax + mean cross-entropy over a batch.

    Args:
        logits: 2-D array ``(n, num_classes)``, one row of unnormalised scores per example.
        labels: integer class index per row, shape ``(n,)``, each in ``[0, num_classes)``.

    Returns:
        ``(loss, dlogits)``. ``loss`` is the mean negative log-likelihood. ``dlogits``
        has the same shape as ``logits`` and is the loss gradient ``(p - y) / n``.

    Raises:
        ValueError: if ``labels`` does not hold exactly one entry per row of ``logits``.
        IndexError: if any label lies outside ``[0, num_classes)``.

    Done together for numerical stability. Subtracting the row-max before ``exp`` avoids
    overflow. The loss is taken from the log-softmax (log-sum-exp) rather than from
    ``log(probs)``, so it stays exact even when a probability underflows to zero.
    """
    labels = np.asarray(labels)
    n, num_classes = logits.shape
    # A wrongly shaped label array would otherwise broadcast against arange(n) silently.
    if labels.shape != (n,):
        raise ValueError(f"labels must have shape ({n},), got {labels.shape}")
    # Negative indices would otherwise wrap around silently and pick the wrong class.
    if np.any(labels < 0) or np.any(labels >= num_classes):
        raise IndexError(f"labels must lie in [0, {num_classes})")
    rows = np.arange(n)

    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    sum_exp = exp.sum(axis=1, keepdims=True)
    # log p_k = shifted_k - log(sum_j exp(shifted_j)). sum_exp >= 1 because the row max
    # contributes exp(0) = 1, so this log is always finite.
    log_probs = shifted - np.log(sum_exp)
    loss = -log_probs[rows, labels].mean()

    # d(loss)/d(logits) = (softmax - one_hot(labels)) / n
    dlogits = exp / sum_exp
    dlogits[rows, labels] -= 1.0
    dlogits /= n
    return loss, dlogits